import json
import logging
import os

import numpy as np
from ray.rllib.utils import try_import_torch
from termcolor import colored

from grl.algos.p2sro.p2sro_manager.logger import SimpleP2SROManagerLogger
from grl.rl_apps.scenarios.psro_scenario import PSROScenario
from grl.utils.common import ensure_dir
from grl.utils.strategy_spec import StrategySpec

# Optional torch import via RLlib's helper; the second element of the returned pair is discarded.
# `torch` is not referenced in this module's visible code.
torch, _ = try_import_torch()


# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class ApproxExploitabilityP2SROManagerLogger(SimpleP2SROManagerLogger):
    """P2SRO manager logger that additionally tracks approximate exploitability.

    Whenever both players have the same, non-empty set of fixed policy numbers
    (i.e. both players have completed the same generations), one entry is appended
    to each per-generation series and all series are rewritten to
    ``<log_dir>/approx_exploitability_stats.json``:

    * ``approx_exploitability``: mean of the two newest fixed policies'
      ``average_br_reward`` metadata.
    * ``timesteps`` / ``episodes``: cumulative ``timesteps_training_br`` /
      ``episodes_training_br`` metadata over every fixed policy of both players.
    * ``num_policies``, the checkpoint number, and the payoff-table / policy-nums
      checkpoint paths reported by the base logger.
    """

    def __init__(self, p2sro_manger, log_dir: str, scenario: PSROScenario):
        """
        Args:
            p2sro_manger: the manager being logged, forwarded to the base logger by
                keyword (the spelling must match the base class's parameter name).
            log_dir: directory the approx-exploitability stats JSON is written to.
            scenario: the PSRO scenario; stored but not otherwise used in this class.
        """
        super().__init__(p2sro_manger=p2sro_manger, log_dir=log_dir)

        self._scenario = scenario

        # Parallel lists: entry i of each list describes the i-th generation logged by this instance.
        self._exploitability_per_generation = []
        self._total_steps_per_generation = []
        self._total_episodes_per_generation = []
        self._num_policies_per_generation = []
        self._payoff_table_checkpoint_nums = []
        self._payoff_table_checkpoint_paths = []
        self._policy_nums_checkpoint_paths = []

        self._exploitability_stats_save_path = os.path.join(log_dir, "approx_exploitability_stats.json")
        ensure_dir(self._exploitability_stats_save_path)

    def on_active_policy_moved_to_fixed(self, player: int, policy_num: int, fixed_policy_spec: StrategySpec):
        """Delegate to the base logger, then record stats if a full generation has completed.

        Stats are recorded only when both players have at least one fixed policy and
        identical sets of fixed policy numbers.

        Args:
            player: player whose active policy was just moved to fixed.
            policy_num: number of the policy that was moved to fixed.
            fixed_policy_spec: spec of the policy that was moved to fixed.
        """
        # Read before delegating, as the base handler presumably writes the checkpoint this
        # number refers to. NOTE(review): confirm against SimpleP2SROManagerLogger.
        current_checkpoint_num = self.get_current_checkpoint_num()

        super().on_active_policy_moved_to_fixed(
            player=player, policy_num=policy_num, fixed_policy_spec=fixed_policy_spec
        )

        latest_payoff_table, _, fixed_policy_nums_per_player = self._manager.get_copy_of_latest_data()

        # Wait until both players have fixed the same policies (a whole generation is done).
        if len(fixed_policy_nums_per_player[0]) < 1 or len(fixed_policy_nums_per_player[1]) < 1:
            return
        if not np.array_equal(fixed_policy_nums_per_player[0], fixed_policy_nums_per_player[1]):
            return

        n_policies = len(fixed_policy_nums_per_player[0])

        # Assumes pure-strategy indices 0..n_policies-1 are the fixed policies, newest last.
        specs_per_policy = self._fixed_policy_specs(latest_payoff_table, n_policies)
        newest_specs = specs_per_policy[-1]

        exploitability_this_gen = np.mean([spec.metadata["average_br_reward"] for spec in newest_specs])

        logger.info(f"{n_policies} policies, {exploitability_this_gen} approx exploitability")

        # Cumulative totals are summed directly from the payoff table rather than chained off
        # this logger's previous entry: those lists are indexed by *logged* generation, so
        # indexing them by policy number fails (IndexError / wrong entry) whenever this instance
        # did not observe every earlier generation, e.g. after resuming from a checkpoint.
        all_specs = [spec for specs in specs_per_policy for spec in specs]
        total_steps_this_generation = sum(spec.metadata["timesteps_training_br"] for spec in all_specs)
        total_episodes_this_generation = sum(spec.metadata["episodes_training_br"] for spec in all_specs)

        self._exploitability_per_generation.append(exploitability_this_gen)
        self._total_steps_per_generation.append(total_steps_this_generation)
        self._total_episodes_per_generation.append(total_episodes_this_generation)
        self._num_policies_per_generation.append(n_policies)
        self._payoff_table_checkpoint_nums.append(current_checkpoint_num)
        self._payoff_table_checkpoint_paths.append(self.get_latest_numbered_payoff_table_checkpoint_path())
        self._policy_nums_checkpoint_paths.append(self.get_latest_numbered_policy_nums_path())

        stats_out = {'num_policies': self._num_policies_per_generation,
                     'approx_exploitability': self._exploitability_per_generation,
                     'timesteps': self._total_steps_per_generation,
                     'episodes': self._total_episodes_per_generation,
                     'payoff_table_checkpoint_num': self._payoff_table_checkpoint_nums,
                     'payoff_table_checkpoint_path': self._payoff_table_checkpoint_paths,
                     'policy_nums_checkpoint_path': self._policy_nums_checkpoint_paths,
                     }

        self._write_stats(stats_out)
        logger.info(colored(f"(Graph this in a notebook) "
                            f"Saved approx exploitability stats to {self._exploitability_stats_save_path}", "green"))

    @staticmethod
    def _fixed_policy_specs(payoff_table, n_policies: int) -> list:
        """Return ``[[spec_player0, spec_player1], ...]`` for pure-strategy indices ``0 .. n_policies - 1``."""
        return [
            [payoff_table.get_spec_for_player_and_pure_strat_index(player=p, pure_strat_index=i) for p in range(2)]
            for i in range(n_policies)
        ]

    def _write_stats(self, stats_out: dict) -> None:
        """Replace the stats JSON file with ``stats_out``."""
        # Write a sibling temp file and swap it in with os.replace. A reader (e.g. a notebook
        # polling this file) then never sees half-written JSON, and a crash mid-write leaves
        # the previous stats intact.
        tmp_path = self._exploitability_stats_save_path + ".tmp"
        with open(tmp_path, "w") as json_file:
            json.dump(stats_out, json_file)
        os.replace(tmp_path, self._exploitability_stats_save_path)
