python/microsoft/0xDeCA10B/simulation/decai/simulation/simulate.py

simulate.py
import json
import logging
import os
import random
import time
from dataclasses import asdict, dataclass
from functools import partial
from itertools import cycle
from logging import Logger
from platform import uname
from queue import PriorityQueue
from threading import Thread
from typing import List

import numpy as np
from bokeh import colors
from bokeh.document import Document
from bokeh.io import export_png
from bokeh.models import AdaptiveTicker, ColumnDataSource, FuncTickFormatter, PrintfTickFormatter
from bokeh.plotting import curdoc, figure
from injector import inject
from tornado import gen
from tqdm import tqdm

from decai.simulation.contract.balances import Balances
from decai.simulation.contract.collab_trainer import CollaborativeTrainer
from decai.simulation.contract.incentive.prediction_market import MarketPhase, PredictionMarket
from decai.simulation.contract.objects import Address, Msg, RejectException, TimeMock
from decai.simulation.data.data_loader import DataLoader
from decai.simulation.data.featuremapping.feature_index_mapper import FeatureIndexMapper


@dataclass
class Agent:
    """
    A user to run in the simulator.
    """
    address: Address
    start_balance: float
    mean_deposit: float
    stdev_deposit: float
    mean_update_wait_s: float
    stdev_update_wait_time: float = 1
    pay_to_call: float = 0
    good: bool = True
    prob_mistake: float = 0
    calls_model: bool = False

    def __post_init__(self):
        assert self.start_balance > self.mean_deposit

    def __lt__(self, other):
        return self.address  <  other.address

    def get_next_deposit(self) -> int:
        while True:
            result = int(random.normalvariate(self.mean_deposit, self.stdev_deposit))
            if result > 0:
                return result

    def get_next_wait_s(self) -> int:
        while True:
            result = int(random.normalvariate(self.mean_update_wait_s, self.stdev_update_wait_time))
            if result >= 1:
                return result


class Simulator(object):
    """
    A simulator for Decentralized & Collaborative AI.
    """

    @inject
    def __init__(self,
                 balances: Balances,
                 data_loader: DataLoader,
                 decai: CollaborativeTrainer,
                 feature_index_mapper: FeatureIndexMapper,
                 logger: Logger,
                 time_method: TimeMock,
                 ):

        self._balances = balances
        self._data_loader = data_loader
        self._decai = decai
        self._feature_index_mapper = feature_index_mapper
        self._logger = logger
        self._time = time_method
        self._warned_about_saving_plot = False

    def save_plot_image(self, plot, plot_save_path):
        try:
            export_png(plot, filename=plot_save_path)
        except Exception as e:
            if self._warned_about_saving_plot:
                return
            show_error_details = True
            message = "Could not save picture of the plot."
            try:
                # Check if in WSL.
                show_error_details = not ('microsoft' in uname().release.lower())
            except:
                pass
            if show_error_details:
                self._logger.exception(message, exc_info=e)
            else:
                self._logger.warning(f"{message} %s", e)
            self._warned_about_saving_plot = True

    def simulate(self,
                 agents: List[Agent],
                 baseline_accuracy: float = None,
                 init_train_data_portion: float = 0.1,
                 pm_test_sets: list = None,
                 accuracy_plot_wait_s=2E5,
                 train_size: int = None, test_size: int = None,
                 filename_indicator: str = None
                 ):
        """
        Run a simulation.

        :param agents: The agents that will interact with the data.
        :param baseline_accuracy: The baseline accuracy of the model.
            Usually the accuracy on a hidden test set when the model is trained with all data.
        :param init_train_data_portion: The portion of the data to initially use for training. Must be [0,1].
        :param pm_test_sets: The test sets for the prediction market incentive mechanism.
        :param accuracy_plot_wait_s: The amount of time to wait in seconds between plotting the accuracy.
        :param train_size: The amount of training data to use.
        :param test_size: The amount of test data to use.
        :param filename_indicator: Path of the filename to create for the run.
        """

        assert 0  < = init_train_data_portion  < = 1

        # Data to save.
        save_data = dict(agents=[asdict(a) for a in agents],
                         baselineAccuracy=baseline_accuracy,
                         initTrainDataPortion=init_train_data_portion,
                         accuracies=[],
                         balances=[],
                         )
        time_for_filenames = int(time.time())
        save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}-simulation_data.json'
        model_save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}-model.json'
        plot_save_path = f'saved_runs/{time_for_filenames}-{filename_indicator}.png'
        self._logger.info("Saving run info to \"%s\".", save_path)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Set up plots.
        doc: Document = curdoc()
        doc.title = "DeCAI Simulation"

        plot = figure(title="Balances & Accuracy on Hidden Test Set",
                      )
        plot.width = 800
        plot.height = 600

        plot.xaxis.axis_label = "Time (days)"
        plot.yaxis.axis_label = "Percent"
        plot.title.text_font_size = '20pt'
        plot.xaxis.major_label_text_font_size = '20pt'
        plot.xaxis.axis_label_text_font_size = '20pt'
        plot.yaxis.major_label_text_font_size = '20pt'
        plot.yaxis.axis_label_text_font_size = '20pt'

        plot.xaxis[0].ticker = AdaptiveTicker(base=5 * 24 * 60 * 60)
        plot.xgrid[0].ticker = AdaptiveTicker(base=24 * 60 * 60)

        balance_plot_sources_per_agent = dict()
        good_colors = cycle([
            colors.named.green,
            colors.named.lawngreen,
            colors.named.darkgreen,
            colors.named.limegreen,
        ])
        bad_colors = cycle([
            colors.named.red,
            colors.named.darkred,
        ])
        for agent in agents:
            source = ColumnDataSource(dict(t=[], b=[]))
            assert agent.address not in balance_plot_sources_per_agent
            balance_plot_sources_per_agent[agent.address] = source
            if agent.calls_model:
                color = 'blue'
                line_dash = 'dashdot'
            elif agent.good:
                color = next(good_colors)
                line_dash = 'dotted'
            else:
                color = next(bad_colors)
                line_dash = 'dashed'
            plot.line(x='t', y='b',
                      line_dash=line_dash,
                      line_width=2,
                      source=source,
                      color=color,
                      legend=f"{agent.address} Balance")

        plot.legend.location = 'top_left'
        plot.legend.label_text_font_size = '12pt'

        # JavaScript code.
        plot.xaxis[0].formatter = FuncTickFormatter(code="""
        return (tick / 86400).toFixed(0);
        """)
        plot.yaxis[0].formatter = PrintfTickFormatter(format="%0.1f%%")

        acc_source = ColumnDataSource(dict(t=[], a=[]))
        if baseline_accuracy is not None:
            plot.ray(x=[0], y=[baseline_accuracy * 100], length=0, angle=0, line_width=2,
                     legend=f"Accuracy when trained with all data: {baseline_accuracy * 100:0.1f}%")
        plot.line(x='t', y='a',
                  line_dash='solid',
                  line_width=2,
                  source=acc_source,
                  color='black',
                  legend="Current Accuracy")

        @gen.coroutine
        def plot_cb(agent: Agent, t, b):
            source = balance_plot_sources_per_agent[agent.address]
            source.stream(dict(t=[t], b=[b * 100 / agent.start_balance]))
            save_data['balances'].append(dict(t=t, a=agent.address, b=b))

        @gen.coroutine
        def plot_accuracy_cb(t, a):
            acc_source.stream(dict(t=[t], a=[a * 100]))
            save_data['accuracies'].append(dict(t=t, accuracy=a))

        continuous_evaluation = not isinstance(self._decai.im, PredictionMarket)

        def task():
            (x_train, y_train), (x_test, y_test) = \
                self._data_loader.load_data(train_size=train_size, test_size=test_size)
            classifications = self._data_loader.classifications()
            x_train, x_test, feature_index_mapping = self._feature_index_mapper.map(x_train, x_test)
            x_train_len = x_train.shape[0]
            init_idx = int(x_train_len * init_train_data_portion)
            self._logger.info("Initializing model with %d out of %d samples.",
                              init_idx, x_train_len)
            x_init_data, y_init_data = x_train[:init_idx], y_train[:init_idx]
            x_remaining, y_remaining = x_train[init_idx:], y_train[init_idx:]

            save_model = isinstance(self._decai.im, PredictionMarket) and self._decai.im.reset_model_during_reward_phase
            self._decai.model.init_model(x_init_data, y_init_data, save_model)

            if self._logger.isEnabledFor(logging.DEBUG):
                s = self._decai.model.evaluate(x_init_data, y_init_data)
                self._logger.debug("Initial training data evaluation: %s", s)
                if len(x_remaining) > 0:
                    s = self._decai.model.evaluate(x_remaining, y_remaining)
                    self._logger.debug("Remaining training data evaluation: %s", s)
                else:
                    self._logger.debug("There is no more remaining data to evaluate.")

            self._logger.info("Evaluating initial model.")
            accuracy = self._decai.model.log_evaluation_details(x_test, y_test)
            self._logger.info("Initial test set accuracy: %0.2f%%", accuracy * 100)
            t = self._time()
            doc.add_next_tick_callback(
                partial(plot_accuracy_cb, t=t, a=accuracy))

            q = PriorityQueue()
            random.shuffle(agents)
            for agent in agents:
                self._balances.initialize(agent.address, agent.start_balance)
                q.put((self._time() + agent.get_next_wait_s(), agent))
                doc.add_next_tick_callback(
                    partial(plot_cb, agent=agent, t=t, b=agent.start_balance))

            unclaimed_data = []
            next_data_index = 0
            next_accuracy_plot_time = 1E4
            desc = "Processing agent requests"
            current_time = 0
            with tqdm(desc=desc,
                      unit_scale=True, mininterval=2, unit=" requests",
                      total=len(x_remaining),
                      ) as pbar:
                while not q.empty():
                    # For now assume sending a transaction (editing) is free (no gas)
                    # since it should be relatively cheaper than the deposit required to add data.
                    # It may not be cheaper than calling `report`.

                    if next_data_index >= len(x_remaining):
                        if not continuous_evaluation or len(unclaimed_data) == 0:
                            break

                    current_time, agent = q.get()
                    update_balance_plot = False
                    if current_time > next_accuracy_plot_time:
                        self._logger.debug("Evaluating.")
                        next_accuracy_plot_time += accuracy_plot_wait_s
                        accuracy = self._decai.model.evaluate(x_test, y_test)
                        doc.add_next_tick_callback(
                            partial(plot_accuracy_cb, t=current_time, a=accuracy))

                        if continuous_evaluation:
                            self._logger.debug("Unclaimed data: %d", len(unclaimed_data))
                            pbar.set_description(f"{desc} ({len(unclaimed_data)} unclaimed)")

                        with open(save_path, 'w') as f:
                            json.dump(save_data, f, separators=(',', ':'))
                        self._decai.model.export(model_save_path, classifications,
                                                 feature_index_mapping=feature_index_mapping)

                        if os.path.exists(plot_save_path):
                            os.remove(plot_save_path)
                        self.save_plot_image(plot, plot_save_path)

                    self._time.set_time(current_time)

                    balance = self._balances[agent.address]
                    if balance > 0 and next_data_index  <  len(x_remaining):
                        # Pick data.
                        x, y = x_remaining[next_data_index], y_remaining[next_data_index]

                        if agent.calls_model:
                            # Only call the model if it's good.
                            if random.random()  <  accuracy:
                                update_balance_plot = True
                                self._decai.predict(Msg(agent.address, agent.pay_to_call), x)
                        else:
                            if not agent.good:
                                y = 1 - y
                            if agent.prob_mistake > 0 and random.random()  <  agent.prob_mistake:
                                y = 1 - y

                            # Bad agents always contribute.
                            # Good agents will only work if the model is doing well.
                            # Add a bit of chance they will contribute since 0.85 accuracy is okay.
                            if not agent.good or random.random()  <  accuracy + 0.15:
                                value = agent.get_next_deposit()
                                if value > balance:
                                    value = balance
                                msg = Msg(agent.address, value)
                                try:
                                    self._decai.add_data(msg, x, y)
                                    # Don't need to plot every time. Plot less as we get more data.
                                    update_balance_plot = next_data_index / len(x_remaining) + 0.1  <  random.random()
                                    balance = self._balances[agent.address]
                                    if continuous_evaluation:
                                        unclaimed_data.append((current_time, agent, x, y))
                                    next_data_index += 1
                                    pbar.update()
                                except RejectException:
                                    # Probably failed because they didn't pay enough which is okay.
                                    # Or if not enough time has passed since data was attempted to be added
                                    # which is okay too because a real contract would reject this
                                    # because the smallest unit of time we can use is 1s.
                                    if self._logger.isEnabledFor(logging.DEBUG):
                                        self._logger.exception("Error adding data.")

                    if balance > 0:
                        q.put((current_time + agent.get_next_wait_s(), agent))

                    claimed_indices = []
                    for i in range(len(unclaimed_data)):
                        added_time, adding_agent, x, classification = unclaimed_data[i]
                        if current_time - added_time  <  self._decai.im.refund_time_s:
                            break
                        if next_data_index >= len(x_remaining) \
                                and current_time - added_time  <  self._decai.im.any_address_claim_wait_time_s:
                            break
                        balance = self._balances[agent.address]
                        msg = Msg(agent.address, balance)

                        if current_time - added_time > self._decai.im.any_address_claim_wait_time_s:
                            # Attempt to take the entire deposit.
                            try:
                                self._decai.report(msg, x, classification, added_time, adding_agent.address)
                                update_balance_plot = True
                            except RejectException:
                                if self._logger.isEnabledFor(logging.DEBUG):
                                    self._logger.exception("Error taking reward.")
                        elif adding_agent.address == agent.address:
                            try:
                                self._decai.refund(msg, x, classification, added_time)
                                update_balance_plot = True
                            except RejectException:
                                if self._logger.isEnabledFor(logging.DEBUG):
                                    self._logger.exception("Error getting refund.")
                        else:
                            try:
                                self._decai.report(msg, x, classification, added_time, adding_agent.address)
                                update_balance_plot = True
                            except RejectException:
                                if self._logger.isEnabledFor(logging.DEBUG):
                                    self._logger.exception("Error taking reward.")

                        stored_data = self._decai.data_handler.get_data(x, classification,
                                                                        added_time, adding_agent.address)
                        if stored_data.claimable_amount  < = 0:
                            claimed_indices.append(i)

                    for i in claimed_indices[::-1]:
                        unclaimed_data.pop(i)

                    if update_balance_plot:
                        balance = self._balances[agent.address]
                        doc.add_next_tick_callback(
                            partial(plot_cb, agent=agent, t=current_time, b=balance))

            self._logger.info("Done going through data.")
            if continuous_evaluation:
                pbar.set_description(f"{desc} ({len(unclaimed_data)} unclaimed)")

            if isinstance(self._decai.im, PredictionMarket):
                self._time.add_time(agents[0].get_next_wait_s())
                self._decai.im.end_market()
                for i, test_set_portion in enumerate(pm_test_sets):
                    if i != self._decai.im.test_reveal_index:
                        self._decai.im.verify_next_test_set(test_set_portion)
                with tqdm(desc="Processing contributions",
                          unit_scale=True, mininterval=2, unit=" contributions",
                          total=self._decai.im.get_num_contributions_in_market(),
                          ) as pbar:
                    finished_first_round_of_rewards = False
                    while self._decai.im.remaining_bounty_rounds > 0:
                        self._time.add_time(agents[0].get_next_wait_s())
                        self._decai.im.process_contribution()
                        pbar.update()

                        if not finished_first_round_of_rewards:
                            accuracy = self._decai.im.prev_acc
                            # If we plot too often then we end up with a blob instead of a line.
                            if random.random()  <  0.1:
                                doc.add_next_tick_callback(
                                    partial(plot_accuracy_cb, t=self._time(), a=accuracy))

                        if self._decai.im.state == MarketPhase.REWARD_RESTART:
                            finished_first_round_of_rewards = True
                            if self._decai.im.reset_model_during_reward_phase:
                                # Update the accuracy after resetting all data.
                                accuracy = self._decai.im.prev_acc
                            else:
                                # Use the accuracy after training with all data.
                                pass
                            doc.add_next_tick_callback(
                                partial(plot_accuracy_cb, t=self._time(), a=accuracy))
                            pbar.total += self._decai.im.get_num_contributions_in_market()
                            self._time.add_time(self._time() * 0.001)

                            for agent in agents:
                                balance = self._balances[agent.address]
                                market_bal = self._decai.im._market_balances[agent.address]
                                self._logger.debug("\"%s\" market balance: %0.2f   Balance: %0.2f",
                                                   agent.address, market_bal, balance)
                                doc.add_next_tick_callback(
                                    partial(plot_cb, agent=agent, t=self._time(), b=max(balance + market_bal, 0)))

                self._time.add_time(self._time() * 0.02)
                for agent in agents:
                    msg = Msg(agent.address, 0)
                    # Find data submitted by them.
                    data = None
                    for key, stored_data in self._decai.data_handler:
                        if stored_data.sender == agent.address:
                            data = key[0]
                            break
                    if data is not None:
                        self._decai.refund(msg, np.array(data), stored_data.classification, stored_data.time)
                        balance = self._balances[agent.address]
                        doc.add_next_tick_callback(
                            partial(plot_cb, agent=agent, t=self._time(), b=balance))
                        self._logger.info("Balance for \"%s\": %.2f (%+.2f%%)",
                                          agent.address, balance,
                                          (balance - agent.start_balance) / agent.start_balance * 100)
                    else:
                        self._logger.warning("No data submitted by \"%s\" was found."
                                             "\nWill not update it's balance.", agent.address)

                self._logger.info("Done issuing rewards.")

            accuracy = self._decai.model.log_evaluation_details(x_test, y_test)
            doc.add_next_tick_callback(
                partial(plot_accuracy_cb, t=current_time + 100, a=accuracy))

            with open(save_path, 'w') as f:
                json.dump(save_data, f, separators=(',', ':'))
            self._decai.model.export(model_save_path, classifications, feature_index_mapping=feature_index_mapping)

            if os.path.exists(plot_save_path):
                os.remove(plot_save_path)
            self.save_plot_image(plot, plot_save_path)

        doc.add_root(plot)
        thread = Thread(target=task)
        thread.start()