from algorithms.abstract import ZeroWorker
from algorithms.bots import BellmanBot, RandomBot
from algorithms.mu_zero import MuZeroEvaluator, MuZeroBot, MuZeroGymnasium, MuZeroDynamicsLab
import ray
from algorithms.utils.types import SpielGame, DynamicsTestLog
from algorithms.utils.params import Params


def make_muzero_worker(id_: int,
                       g: SpielGame,
                       p: Params,
                       num_cpus: int = 1, num_gpus: float = 0.2):
    """Create a MuZero worker as a Ray actor and return its handle.

    The actor class is defined and decorated with ``ray.remote`` on every
    call, so the requested CPU/GPU reservation is bound into that class.

    Args:
        id_: Worker identifier, forwarded to the worker and its evaluator.
        g: Game that the worker's bots, gymnasium and dynamics lab use.
        p: Parameters shared by every component of the worker.
        num_cpus: CPUs requested from Ray for the actor.
        num_gpus: GPUs (possibly fractional) requested from Ray for the actor.

    Returns:
        Whatever ``MuZeroWorker.remote(id_, g, p)`` yields, which is Ray's
        handle to the newly started actor.
    """

    @ray.remote(num_cpus=num_cpus, num_gpus=num_gpus)
    class MuZeroWorker(ZeroWorker):
        """ZeroWorker with MuZero, Bellman and random bots plus a dynamics lab."""

        def __init__(self, worker_id: int, game: SpielGame, params: Params):
            # Report which GPU ids Ray assigned to this actor.
            assigned_gpus = ray.get_gpu_ids()
            print("This actor is allowed to use GPUs {}.".format(assigned_gpus))

            mz_evaluator = MuZeroEvaluator(params, worker_id)
            mz_bot = MuZeroBot(game, mz_evaluator, params)
            bellman = BellmanBot(game, params)
            rand_bot = RandomBot(game, params)
            # The gymnasium is constructed around the Bellman bot.
            gym = MuZeroGymnasium(game, bellman, params)

            ZeroWorker.__init__(self, worker_id, game, mz_evaluator,
                                (mz_bot, bellman, rand_bot), gym, params)

            # The dynamics lab reuses the same evaluator instance as the MuZero bot.
            self._dynamics_lab = MuZeroDynamicsLab(game, mz_evaluator, params)

        def eval_dynamics(self, num_iterations: int) -> DynamicsTestLog:
            """Run the single dynamics test ``num_iterations`` times.

            A fresh log, with every field an empty list, is created and passed
            to each ``run_dynamics_test_single`` call. Presumably those calls
            append their results to it. The same log object is returned.
            """
            result = DynamicsTestLog(
                choice_top_acc=[],
                choice_pass_action=[],
                choice_single_acc=[],
                tau_acc=[],
                chance_pass_action=[],
                choice_strict_acc=[],
            )
            for _ in range(num_iterations):
                self._dynamics_lab.run_dynamics_test_single(result)
            return result

        @staticmethod
        def kill():
            """Stop this actor from the inside via ``ray.actor.exit_actor()``."""
            ray.actor.exit_actor()

    handle = MuZeroWorker.remote(id_, g, p)
    return handle



