import json
import logging
import os
from typing import Dict

import numpy as np
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy import Policy
from ray.rllib.utils import try_import_torch
from ray.rllib.utils.typing import TensorType
from termcolor import colored

from grl.algos.p2sro.p2sro_manager.logger import SimpleP2SROManagerLogger
from grl.algos.p2sro.p2sro_manager.utils import get_latest_metanash_strategies
from grl.envs.oshi_zumo_multi_agent_env import OshiZumoMultiAgentEnv
from grl.envs.poker_multi_agent_env import PokerMultiAgentEnv
from grl.rl_apps.psro.poker_utils import psro_measure_exploitability_nonlstm
from grl.rl_apps.scenarios.psro_scenario import PSROScenario
from grl.rllib_tools.modified_policies.simple_q_torch_policy import SimpleQTorchPolicyPatched
from grl.rllib_tools.policy_checkpoints import load_pure_strat
from grl.utils.common import ensure_dir
from grl.utils.strategy_spec import StrategySpec

torch, _ = try_import_torch()


logger = logging.getLogger(__name__)


class ExploitabilityP2SROManagerLogger(SimpleP2SROManagerLogger):
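    """A SimpleP2SROManagerLogger that also measures exploitability each generation.

    Intended for small OpenSpiel-backed games (poker variants, Oshi-Zumo) where
    exact exploitability of the current metanash mixture can be computed.
    """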

    def __init__(self, p2sro_manager, log_dir: str, scenario: PSROScenario):
        # Note: the superclass signature spells this keyword "p2sro_manger".
        super(ExploitabilityP2SROManagerLogger, self).__init__(p2sro_manger=p2sro_manager, log_dir=log_dir)

        self._scenario = scenario
        if not issubclass(scenario.env_class, (PokerMultiAgentEnv, OshiZumoMultiAgentEnv)):
            raise ValueError(f"ExploitabilityP2SROManagerLogger is only meant to be used with PokerMultiAgentEnv "
                             f"or OshiZumoMultiAgentEnv, not {scenario.env_class}")
        if not scenario.calc_exploitability_for_openspiel_env:
            raise ValueError("Only use ExploitabilityP2SROManagerLogger if "
                             "scenario.calc_exploitability_for_openspiel_env is True.")

        self._exploitability_per_generation = []
        self._total_steps_per_generation = []
        self._total_episodes_per_generation = []
        self._num_policies_per_generation = []
        self._payoff_table_checkpoint_nums = []
        self._payoff_table_checkpoint_paths = []
        self._policy_nums_checkpoint_paths = []

        self._exploitability_stats_save_path = os.path.join(log_dir, "exploitability_stats.json")
        ensure_dir(self._exploitability_stats_save_path)

    def on_active_policy_moved_to_fixed(self, player: int, policy_num: int, fixed_policy_spec: StrategySpec):
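        """Record per-generation exploitability stats when a policy becomes fixed.

        Once both players have matching fixed-policy sets, approximates the
        metanash over the latest payoff table, measures its exploitability, and
        appends the results to a JSON stats file in the log directory.
        """
        # Snapshot the checkpoint number up front (the superclass call below may
        # write new checkpoints).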
        current_checkpoint_num = self.get_current_checkpoint_num()

        super(ExploitabilityP2SROManagerLogger, self).on_active_policy_moved_to_fixed(
            player=player, policy_num=policy_num, fixed_policy_spec=fixed_policy_spec
        )

        data = self._manager.get_copy_of_latest_data()
        latest_payoff_table, active_policy_nums_per_player, fixed_policy_nums_per_player = data

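        # Only compute exploitability once both players have at least one fixed
        # policy and their fixed-policy sets match, i.e. a full generation is done.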
        if len(fixed_policy_nums_per_player[0]) < 1 or len(fixed_policy_nums_per_player[1]) < 1:
            return
        if not np.array_equal(fixed_policy_nums_per_player[0], fixed_policy_nums_per_player[1]):
            return

        n_policies = len(fixed_policy_nums_per_player[0])
        latest_policy_index = max(fixed_policy_nums_per_player[0])

        env = self._scenario.env_class(self._scenario.env_config)

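        # Hook to expose Q-values and one-hot greedy action probabilities in each
        # policy's extra action outputs, so the exploitability measurement can
        # treat the (deterministic) best responses as tabular strategies.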
        def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
                                action_dist: ActionDistribution) -> Dict[str, TensorType]:
            action = action_dist.deterministic_sample()
            # One-hot distribution over actions for the greedy choice. Keep a float
            # tensor: an integer dtype would silently truncate probabilities.
            action_probs = torch.zeros_like(policy.q_values)
            action_probs[0][action[0]] = 1.0
            return {"q_values": policy.q_values, "action_probs": action_probs}

        if self._scenario.policy_classes["eval"] != SimpleQTorchPolicyPatched:
            raise NotImplementedError(f"This method isn't verified to work with policy classes other than "
                                      f"SimpleQTorchPolicyPatched. "
                                      f"You're using {self._scenario.policy_classes['eval']}")

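        # Clone the eval policy class with the extra-output hook attached.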
        policy_class = self._scenario.policy_classes["eval"].with_updates(
            extra_action_out_fn=extra_action_out_fn
        )

        trainer_config = self._scenario.get_trainer_config(env)
        trainer_config["explore"] = False
        policies = [policy_class(env.observation_space, env.action_space, with_common_config(trainer_config))
                    for _ in range(2)]

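        # Approximate the metanash for each player's fixed population via
        # fictitious play on the current empirical payoff table. Solving "as" one
        # player returns the opposing player's mixture.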
        metanash_probs_0 = get_latest_metanash_strategies(payoff_table=latest_payoff_table,
                                                          as_player=1,
                                                          as_policy_num=n_policies,
                                                          fictitious_play_iters=2000,
                                                          mix_with_uniform_dist_coeff=0.0,
                                                          print_matrix=False)[0].probabilities_for_each_strategy()

        if self._scenario.single_agent_symmetric_game:
            metanash_probs_1 = metanash_probs_0
        else:
            metanash_probs_1 = get_latest_metanash_strategies(payoff_table=latest_payoff_table,
                                                              as_player=0,
                                                              as_policy_num=n_policies,
                                                              fictitious_play_iters=2000,
                                                              mix_with_uniform_dist_coeff=0.0,
                                                              print_matrix=False)[1].probabilities_for_each_strategy()

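        # Strategy specs for each player's fixed population, in payoff-table order.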
        policy_specs_0 = latest_payoff_table.get_ordered_spec_list_for_player(player=0)[:n_policies]
        policy_specs_1 = latest_payoff_table.get_ordered_spec_list_for_player(player=1)[:n_policies]

        assert len(metanash_probs_1) == len(policy_specs_1), \
            f"len(metanash_probs_1): {len(metanash_probs_1)}, len(policy_specs_1): {len(policy_specs_1)}"
        assert len(metanash_probs_0) == len(policy_specs_0)
        assert len(policy_specs_0) == len(policy_specs_1)

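        # Pair each population member's BR checkpoint with its metanash weight.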
        br_checkpoint_paths = []
        metanash_weights = []

        for spec_0, prob_0, spec_1, prob_1 in zip(policy_specs_0, metanash_probs_0, policy_specs_1,
                                                  metanash_probs_1):
            br_checkpoint_paths.append((spec_0.metadata["checkpoint_path"], spec_1.metadata["checkpoint_path"]))
            metanash_weights.append((prob_0, prob_1))

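        # Exploitability of the metanash mixture, computed exactly in OpenSpiel.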
        exploitability_this_gen = psro_measure_exploitability_nonlstm(
            br_checkpoint_path_tuple_list=br_checkpoint_paths,
            metanash_weights=metanash_weights,
            set_policy_weights_fn=load_pure_strat,
            rllib_policies=policies,
            poker_game_version=env.game_version,
            open_spiel_env_config=env.open_spiel_env_config
        )

        logger.info(f"{n_policies} policies, {exploitability_this_gen} exploitability")

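        # Training cost of the newest BR pair; totals accumulate across generations.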
        policy_spec_added_this_gen = [latest_payoff_table.get_spec_for_player_and_pure_strat_index(
            player=p, pure_strat_index=n_policies - 1) for p in range(2)]

        latest_policy_steps = sum(policy_spec_added_this_gen[p].metadata["timesteps_training_br"] for p in range(2))
        latest_policy_episodes = sum(policy_spec_added_this_gen[p].metadata["episodes_training_br"] for p in range(2))

        if latest_policy_index > 0:
            total_steps_this_generation = latest_policy_steps + self._total_steps_per_generation[latest_policy_index - 1]
            total_episodes_this_generation = latest_policy_episodes + self._total_episodes_per_generation[latest_policy_index - 1]
        else:
            total_steps_this_generation = latest_policy_steps
            total_episodes_this_generation = latest_policy_episodes

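        # Append this generation's stats to the running history.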
        self._exploitability_per_generation.append(exploitability_this_gen)
        self._total_steps_per_generation.append(total_steps_this_generation)
        self._total_episodes_per_generation.append(total_episodes_this_generation)
        self._num_policies_per_generation.append(n_policies)
        self._payoff_table_checkpoint_nums.append(current_checkpoint_num)
        self._payoff_table_checkpoint_paths.append(self.get_latest_numbered_payoff_table_checkpoint_path())
        self._policy_nums_checkpoint_paths.append(self.get_latest_numbered_policy_nums_path())

        # Release the eval policies (and their torch models) now that we're done.
        del policies

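        # Persist the full history so exploitability curves can be plotted later.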
        stats_out = {'num_policies': self._num_policies_per_generation,
                     'exploitability': self._exploitability_per_generation,
                     'timesteps': self._total_steps_per_generation,
                     'episodes': self._total_episodes_per_generation,
                     'payoff_table_checkpoint_num': self._payoff_table_checkpoint_nums,
                     'payoff_table_checkpoint_path': self._payoff_table_checkpoint_paths,
                     'policy_nums_checkpoint_path': self._policy_nums_checkpoint_paths,
                     }

        with open(self._exploitability_stats_save_path, "w") as json_file:
            json.dump(stats_out, json_file)
        logger.info(colored(f"(Graph this in a notebook) "
                            f"Saved exploitability stats to {self._exploitability_stats_save_path}", "green"))

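# A minimal usage sketch (hypothetical wiring; the P2SROManager and PSROScenario
# objects come from elsewhere in grl, and the names here are illustrative):
#
#   manager_logger = ExploitabilityP2SROManagerLogger(
#       p2sro_manager=my_p2sro_manager,
#       log_dir="/tmp/p2sro_logs",
#       scenario=my_psro_scenario,
#   )
#
# Exploitability is then computed and saved automatically each time an active
# policy is moved to fixed.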