import torch
import open_spiel.python.rl_agent as rl_agent
from algorithms.escher.escher import ESCHERSolver
from torch.distributions import Categorical
import numpy as np


class EscherRLAgent(rl_agent.AbstractAgent):
    """OpenSpiel RL agent that samples actions from an ESCHER policy network.

    The agent evaluates ``policy_network.model`` on the player's
    ``info_state`` observation, masks out illegal actions and samples an
    action from the resulting categorical distribution.
    """

    def __init__(self, policy_network, player_id, n_obs, n_actions):
        """Store the policy model and the player/game dimensions.

        Args:
            policy_network: Object exposing a callable ``model`` attribute
                that also has a ``to`` method (presumably a
                ``torch.nn.Module`` -- TODO confirm).
            player_id: Index used to select this player's entries from
                ``time_step.observations``.
            n_obs: Observation size; stored but not read by this class.
            n_actions: Number of distinct actions; sizes the legality mask
                and the returned probability vector.
        """
        self.model = policy_network.model
        # NOTE(review): if ``model`` is a torch.nn.Module, ``to`` moves its
        # parameters in place, so every other holder of this model object
        # sees it on the CPU as well.
        self.model.to("cpu")
        self.player_id = player_id
        self.n_obs = n_obs
        self.n_actions = n_actions

    def step(self, time_step, is_evaluation=False):
        """Sample an action for this player among the legal ones.

        Args:
            time_step: Object whose ``observations`` mapping holds
                ``"info_state"`` and ``"legal_actions"`` entries indexed by
                player id.
            is_evaluation: Unused; accepted for interface compatibility.

        Returns:
            ``rl_agent.StepOutput`` whose ``action`` is the sampled action
            (a 0-d NumPy integer array) and whose ``probs`` is a one-hot
            vector of length ``n_actions`` on that action -- not the
            distribution the action was sampled from.
        """
        observations = time_step.observations
        info_state = observations["info_state"][self.player_id]
        # ``np.array([])`` defaults to float64; building a long index tensor
        # keeps the "no legal actions" case a well-defined no-op on the mask.
        legal_index = torch.as_tensor(
            observations["legal_actions"][self.player_id], dtype=torch.long
        )
        legal_mask = torch.zeros((self.n_actions,), dtype=torch.bool)
        legal_mask[legal_index] = True

        # Pure inference: only NumPy values leave this method, so there is
        # no reason to record an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self.model(
                torch.tensor(np.array(info_state), dtype=torch.float)
            )
            # A large negative logit makes illegal actions effectively
            # impossible to sample while keeping the distribution finite.
            logits = torch.where(legal_mask, logits, -1e6)
            action = Categorical(logits=logits).sample().numpy()

        probs = np.zeros(self.n_actions)
        probs[action] = 1.0
        return rl_agent.StepOutput(action=action, probs=probs)

    def get_model(self):
        """Return a callable mapping an info state to the model's raw output.

        The returned function converts ``info_state`` to a ``float32`` tensor
        and returns ``self.model`` applied to it, without masking or
        normalisation.
        """

        def model(info_state):
            return self.model(torch.tensor(info_state, dtype=torch.float32))

        return model


class RunEscher:
    """Single-process driver around ``ESCHERSolver``.

    Construction is deliberately blocked by an assertion that points users
    to ``escher_parallel``; the remaining code is kept for reference.
    """

    def __init__(self, args, game, expl_callback):
        """Build the solver from the experiment configuration.

        Args:
            args: Configuration object; the attributes read here are exactly
                those passed to ``ESCHERSolver`` below.
            game: Game object handed to the solver and later queried for
                observation size, action count and player count.
            expl_callback: Callable invoked from ``run`` as
                ``expl_callback(model_0, model_1, nodes_visited)``.

        Raises:
            AssertionError: Always, unless assertions are disabled.
        """
        # Deliberate guard: this runner is superseded by escher_parallel.
        # NOTE(review): ``assert`` is stripped under ``python -O``, in which
        # case construction proceeds.
        assert False, "You should be using escher_parallel"
        self.args = args
        self.game = game
        self.expl_callback = expl_callback
        self.deep_cfr_solver = ESCHERSolver(
            self.game,
            num_traversals=args.algorithm.num_traversals,
            num_val_fn_traversals=args.num_val_fn_traversals,
            num_iterations=args.iters,
            regret_network_train_steps=args.regret_train_steps,
            policy_network_train_steps=args.policy_net_train_steps,
            batch_size_regret=args.batch_size_regret,
            value_network_train_steps=args.val_train_steps,
            batch_size_value=args.batch_size_val,
            train_device=args.train_device,
            num_random_games=args.num_random_games,
            max_steps=args.max_steps,
            use_wandb=args.wandb,
            eval_every=args.eval_every,
        )

    def wrap_rl_agent(self, save_path=None):
        """Wrap the solver's policy network in one agent per player.

        Args:
            save_path: Optional file path; when truthy, the list of agents
                is pickled to it.

        Returns:
            List of ``EscherRLAgent``, one per player of ``self.game``, all
            built from the solver's current ``_policy_network``.
        """
        n_obs = self.game.observation_tensor_size()
        n_actions = self.game.num_distinct_actions()
        policy_network = self.deep_cfr_solver._policy_network
        agents = [
            EscherRLAgent(policy_network, player_id, n_obs, n_actions)
            for player_id in range(self.game.num_players())
        ]
        if save_path:
            import pickle

            with open(save_path, "wb") as f:
                pickle.dump(agents, f)
        return agents

    def get_model(self):
        """Return two policy callables backed by the solver's policy network.

        Each callable converts ``info_state`` to a ``float32`` tensor, applies
        the solver's current ``_policy_network`` and returns the softmax over
        dimension 1, so the input must be batched (2-D).

        Returns:
            A two-element list; both entries evaluate the same network.
        """

        def policy_probs(info_state):
            action_values = self.deep_cfr_solver._policy_network(
                torch.tensor(info_state, dtype=torch.float32)
            )
            return torch.softmax(action_values, dim=1)

        # The same closure serves both slots: the network is shared.
        return [policy_probs, policy_probs]

    def run(self):
        """Run solver iterations until ``args.max_steps`` nodes are visited.

        Whenever more than ``args.compute_exploitability_every`` nodes have
        been visited since the last check, the policy network is
        re-initialised and retrained on the average policy, wrapped into
        agents, and the first two agents' models are passed to
        ``expl_callback`` together with the total visited-node count.

        Returns:
            Whatever ``save_policy_network(args.experiment_dir)`` returns.
        """
        solver = self.deep_cfr_solver

        # ``state`` is the tuple threaded through every iteration:
        # (regret_losses, value_losses, timestr, avg_reward,
        #  avg_win_against_random, algo_start_time).
        state = solver.pre_iteration()
        nodes_since_expl_check = 0
        while True:
            nodes_before = solver._nodes_visited
            state = solver.iteration(*state)
            nodes_since_expl_check += solver._nodes_visited - nodes_before

            if nodes_since_expl_check > self.args.compute_exploitability_every:
                # Refit the average-policy network before evaluating it.
                solver._reinitialize_policy_network()
                solver.learn_average_policy_network()

                models = [agent.get_model() for agent in self.wrap_rl_agent()]
                self.expl_callback(models[0], models[1], solver._nodes_visited)
                nodes_since_expl_check = 0

            if solver._nodes_visited >= self.args.max_steps:
                break

        return solver.save_policy_network(self.args.experiment_dir)
