"""
This file gives an implementation of Double Oracle (https://dl.acm.org/doi/10.5555/3041838.3041906).
It is a tabular case of the known PSRO algorithm.

The DOLearner accepts only 2-player matrix games. A policy individual is actually a mixture over a player's actions.
"""

import random
import time
import copy
import numpy as np

from collections import defaultdict, namedtuple
from expground.common.policy_pool import PolicyPool
from expground.envs import agent_interface

from expground.logger import Log
from expground.types import (
    AgentID,
    EnvDescription,
    LambdaType,
    Dict,
    PolicyID,
    RolloutConfig,
    TrainingConfig,
    Union,
    List,
    Sequence,
    Tuple,
    Any,
)


from expground.utils import rollout
from expground.utils.logging import write_to_tensorboard
from expground.gt.payoff_matrix import PayoffMatrix
from expground.utils.rollout import Evaluator
from expground.utils.sampler import SamplerInterface, get_sampler
from expground.utils.stoppers import DEFAULT_STOP_CONDITIONS, get_stopper

from expground.common.exploitability import measure_exploitability
from expground.gt.meta_solver import MetaSolver

from expground.envs.agent_interface import AgentInterface
from expground.envs.utils import (
    from_payoff_tables,
    from_simple_description,
    from_game_type,
)

from expground.algorithms.base_policy import pack_action_to_policy, Policy
from expground.learner.base_learner import Learner

# Result record of one best response computation (built by
# `DOLearner._compute_best_response`):
#   br    - the policy held by the trained agent's interface after training
#   agent - the id of the trained (main) agent
#   data  - list of the values returned by each `policy.optimize` call
BestResponseInfo = namedtuple("BestResponseInfo", "br, agent, data")


class DOLearner(Learner):

    NAME = "TabularDO"

    def __init__(
        self,
        experiment: str,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        ray_mode: bool = False,
        seed: int = None,
        evaluation_worker_num: int = 0,
        agent_mapping: LambdaType = None,
        **kwargs,
    ) -> None:
        """Build a Double Oracle learner.

        Args:
            experiment (str): Experiment tag. A falsy value is replaced by "DOLearner_<current time>".
            env_description (EnvDescription): Environment description. Its "creator" entry is called with
                the "config" entry as keyword arguments; "config" must provide "observation_spaces",
                "action_spaces", "possible_agents" and "full_policy_set".
            rollout_config (RolloutConfig): Rollout configuration, stored on the learner.
            training_config (TrainingConfig): Training configuration, stored on the learner.
            ray_mode (bool, optional): Enable ray mode or not. Defaults to False.
            seed (int, optional): Random seed, forwarded to the base learner. Defaults to None.
            evaluation_worker_num (int, optional): Ray worker num for evaluation, enabled when ray_mode is on. Defaults to 0.
            agent_mapping (LambdaType, optional): Agent mapping closure. A falsy value is replaced by the identity mapping.
            **kwargs: Accepted but not used by this constructor.
        """
        tag = experiment if experiment else "DOLearner_{}".format(time.time())
        mapping = agent_mapping if agent_mapping else (lambda x: x)
        super(DOLearner, self).__init__(
            tag,
            seed=seed,
            evaluation_worker_num=evaluation_worker_num,
            ray_mode=ray_mode,
            agent_mapping=mapping,
        )

        config = env_description["config"]
        self.observation_spaces = config["observation_spaces"]
        self.action_spaces = config["action_spaces"]

        self.env_description = env_description
        self.rollout_config = rollout_config
        self.training_config = training_config

        # running totals, accumulated by best response training
        self.total_episodes = 0
        self.total_timesteps = 0

        # every possible agent is both an ego agent and an environment agent
        self.register_ego_agents(config["possible_agents"])
        self.register_env_agents(config["possible_agents"])

        self.payoff_manager = PayoffMatrix(self.agents)
        self.meta_solver = MetaSolver.from_type("fictitious_play")

        # containers keyed by mapped agent id; several environment agents may share one entry.
        # They are filled by `_init_agent_interfaces`.
        self.agent_interfaces: Dict[AgentID, AgentInterface] = {}
        self.full_policy_set: Dict[AgentID, List] = {}
        self.sub_policy_set: Dict[AgentID, List] = {}
        self.populations: Dict[AgentID, Dict[PolicyID, Policy]] = {}
        self._init_agent_interfaces()

        # instantiate the environment once, only to read its payoff matrix
        probe_env = self.env_description["creator"](**self.env_description["config"])
        probe_env.seed(1)
        probe_env.reset()
        self.env_payoff_matrix = probe_env.unwrapped.payoff_matrix

    def _init_agent_interfaces(self):
        """Create one agent interface, and the per-agent containers, for each mapped agent id.

        The first environment agent that maps to a given id creates the interface (without a policy),
        registers the full policy set from the environment config and empty population / sub policy
        set entries. Any later environment agent mapping to the same id must have observation and
        action spaces of the same shape as the interface already created.
        """
        for env_agent in self.agents:
            mapped = self.agent_mapping(env_agent)
            shared = self.agent_interfaces.get(mapped)

            if shared is not None:
                # interface is shared with an earlier agent: check space compatibility
                assert (
                    self.observation_spaces[env_agent].shape
                    == shared.observation_space.shape
                )
                assert self.action_spaces[env_agent].shape == shared.action_space.shape
                continue

            self.agent_interfaces[mapped] = AgentInterface(
                mapped,
                policy=None,
                observation_space=self.observation_spaces[env_agent],
                action_space=self.action_spaces[env_agent],
            )
            env_config = self.env_description["config"]
            self.full_policy_set[mapped] = env_config["full_policy_set"]
            self.populations[mapped] = {}
            self.sub_policy_set[mapped] = []

        Log.debug("init agent interfaces done")

    def _compute_best_response(
        self,
        main_agent: AgentID,
        meta_strategies: Dict[AgentID, Dict[PolicyID, float]],
        sampler_config: SamplerInterface,
        existing_brs: Dict[AgentID, List[PolicyID]] = None,
    ) -> BestResponseInfo:
        """Run best response computation, and return a named tuple which describes the corresponding training info.

        Environment agents mapped to `main_agent` are trained; every other agent plays its fixed
        meta strategy. The interfaces stored in `self.agent_interfaces` are modified in place.

        Args:
            main_agent (AgentID): Trainable agent id, mapped from an environment agent.
            meta_strategies (Dict[AgentID, Dict[PolicyID, float]]): Meta strategies, maps from environment agent ids to a dict of policy distribution.
            sampler_config (SamplerInterface): Sampler configuration, forwarded to `get_sampler`.
            existing_brs (Dict[AgentID, List[PolicyID]], optional): Maps environment agent ids to policy ids.
                When a trained agent has an entry, the first listed policy is taken from its population
                and trained further instead of starting from a new policy. Defaults to None.

        Returns:
            BestResponseInfo: A tuple (trained policy, main agent id, list of `optimize` results).
        """

        # parse meta strategies, then maps policy ids to policies. Populations are keyed by
        # *mapped* agent ids, so the environment agent id must be mapped before the lookup.
        opponent_meta_strategies = {
            env_agent: [
                self.populations[self.agent_mapping(env_agent)][pid] for pid in mixture
            ]
            for env_agent, mixture in meta_strategies.items()
            if self.agent_mapping(env_agent) != main_agent
        }
        # parse meta strategies, then depart distribution from it
        distribution = {
            env_agent: np.asarray(list(mixture.values()), dtype=np.float32)
            for env_agent, mixture in meta_strategies.items()
        }

        existing_brs = existing_brs or {}

        # pack opponent meta strategies to agent interfaces
        agent_interfaces = {}
        main_env_agents = []

        for agent in self.agents:
            _agent = self.agent_mapping(agent)
            if _agent == main_agent:
                # br not share
                if existing_brs.get(agent, None) is not None:
                    policy = self.populations[_agent][existing_brs[agent][0]]
                else:
                    policy = pack_action_to_policy(
                        self.full_policy_set[_agent],
                        self.agent_interfaces[_agent].observation_space,
                        is_fixed=False,
                    )
                self.agent_interfaces[_agent].policy = policy
                self.agent_interfaces[_agent].is_active = True
                main_env_agents.append(agent)
            else:
                self.agent_interfaces[_agent].policy = pack_action_to_policy(
                    opponent_meta_strategies[agent],
                    self.agent_interfaces[_agent].observation_space,
                    is_fixed=True,
                    distribution=distribution[agent],
                )
                self.agent_interfaces[_agent].is_active = False
            # the interface is shared by reference (not copied) with the rollout
            agent_interfaces[agent] = self.agent_interfaces[_agent]

        sampler = get_sampler(main_env_agents, sampler_config)
        generator = rollout.sequential_rollout(
            sampler=sampler,
            agent_interfaces=agent_interfaces,
            env_description=self.env_description,
            fragment_length=self.rollout_config.fragment_length,
            max_step=self.rollout_config.max_step,
            agent_filter=main_env_agents,
            train_every=1,
        )

        epoch_training_statistic = []
        batch_size = self.training_config.hyper_params["batch_size"]

        while True:
            # Only the generator step is guarded: a StopIteration leaking out of sampling or
            # optimization must not be mistaken for the end of the rollout.
            try:
                next(generator)
            except StopIteration as finished:
                # the generator's return value carries the rollout summary
                summary = finished.value
                break
            # do training
            agent_batches = sampler.sample(
                batch_size=batch_size,
                agent_filter=main_env_agents,
            )
            # merge to one batch or random sample one agent
            selected_batch = agent_batches[random.choice(list(agent_batches))]
            epoch_training_statistic.append(
                self.agent_interfaces[main_agent].policy.optimize(selected_batch)
            )

        self.total_timesteps += summary["total_timesteps"]
        self.total_episodes += summary["num_episode"]

        return BestResponseInfo(
            self.agent_interfaces[main_agent].policy,
            main_agent,
            epoch_training_statistic,
        )

    def compute_best_responses(
        self,
        sampler_config,
        meta_strategies: Dict[AgentID, Dict[str, float]],
        existing_brs: Dict[AgentID, List[PolicyID]] = None,
    ) -> Dict[AgentID, List[PolicyID]]:
        """Compute the exact best responses to current meta strategies, and update populations if id of BRs not exist in current pool.

        Args:
            sampler_config ([type]): [description]
            meta_strategies (Dict[AgentID, Dict[str, float]]): A dict of meta strategies.
            existing_brs (Dict[AgentID, PolicyID]): A dict of existing BRs, if your wanna train existing brs, fill it.

        Returns:
            Dict[AgentID, List[PolicyID]]: A dict of best responses.
        """

        tasks = []
        cache = []
        brs = {}

        for agent in self.ego_agents:
            _agent = self.agent_mapping(agent)
            if _agent in cache:
                continue
            cache.append(_agent)
            tasks.append(
                self._compute_best_response(
                    _agent, meta_strategies, sampler_config, existing_brs
                )
            )

        if existing_brs is None:
            status = defaultdict(lambda: False)
            for res in tasks:
                _agent = self.agent_mapping(res.agent)
                if status[_agent]:
                    continue
                pid = f"policy_{len(self.populations[_agent])}"
                self.populations[_agent][pid] = res.br
                # import pdb; pdb.set_trace()
                status[_agent] = True
                brs[_agent] = [pid]
            # add br for other agents
            for agent in self.agents:
                _agent = self.agent_mapping(agent)
                brs[agent] = brs[_agent]
        else:
            brs = existing_brs
        return brs

    def _assign_policy_to_agents(
        self, policy_mapping: Dict[AgentID, PolicyID]
    ) -> Dict[AgentID, AgentInterface]:
        """Assign a policy mapping to agent interfaces, then return a dict of new agent interfaces.

        Args:
            policy_mapping (Dict[AgentID, PolicyID]): A dict of policy mapping, maps from environment agent id to policy id.

        Returns:
            Dict[AgentID, AgentInterface]: A dict of agent interfaces.
        """

        agent_interfaces = {}
        for agent, pid in policy_mapping.items():
            _agent = self.agent_mapping(agent)
            agent_interfaces[agent] = self.agent_interfaces[_agent].copy()
            # hard reset # if agent_interfaces[agent].policy is None:
            agent_interfaces[_agent].policy = self.populations[_agent][pid]
            agent_interfaces[agent].is_fixed = True
        return agent_interfaces

    def run_simulation(
        self, simulations: List[Dict[AgentID, PolicyID]] = None, keep_dim: bool = False
    ) -> Sequence[Tuple[Dict, Dict]]:
        """Run simulations for expanded policy combinations.

        Returns:
            Sequence[Tuple[Dict, Dict]]: A sequence of tuple (policy combination, evaluation result).
        """

        simulations = simulations or self.payoff_manager.gen_simulations(split=False)
        results = []
        # parse simulations
        for task in simulations:
            # set agent interface with policy
            task = [task] if not isinstance(task, (list, tuple)) else task
            tmp = Evaluator.run(
                task,
                max_step=self.rollout_config.max_step,
                fragment_length=self.rollout_config.fragment_length,
                agent_interfaces=self._assign_policy_to_agents(task[0]),
                rollout_caller=rollout.sequential_rollout,
                env_desc=self.env_description,
                seed=1,
            )
            if keep_dim:
                results.append(tmp)
            else:
                results.extend(tmp)
        return results

    def compute_nash_conv(
        self, meta_strategies: Dict[AgentID, Dict[PolicyID, float]]
    ) -> "NashConv":
        """Build a game from the environment payoff matrix, then measure exploitability on it.

        The payoff tables of "player_0" (row) and "player_1" (column) are read from the
        environment payoff matrix, and the scenario "dim" is used for both game dimensions.

        Args:
            meta_strategies (Dict[AgentID, Dict[PolicyID, float]]): A dict of agent policy mixture.

        Returns:
            NashConv: The result of `measure_exploitability`.
        """

        # convert game to spiel version
        row_payoffs = self.env_payoff_matrix["player_0"]
        column_payoffs = self.env_payoff_matrix["player_1"]
        size = self.env_description["config"]["scenario_config"]["dim"]
        game = from_game_type(size, size, row_payoffs, column_payoffs)

        # populations are stored per mapped agent; expose them per environment agent
        populations = {}
        for env_agent in self.agents:
            populations[env_agent] = self.populations[self.agent_mapping(env_agent)]

        return measure_exploitability(
            game,
            populations=populations,
            policy_mixture_dict=meta_strategies,
        )

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict = None,
    ) -> Dict:
        """Run the Double Oracle loop until the stopper reports termination.

        Each population starts with "policy_0", a fixed policy that plays one randomly chosen
        element of the full policy set. Every iteration then simulates the pending policy
        combinations, updates the payoff table, solves the meta game, logs the NashConv of the
        resulting meta strategies, trains new best responses and expands the payoff table with them.

        Args:
            sampler_config (Union[Dict, LambdaType]): Sampler configuration for best response training.
            stop_conditions (Dict): Stop conditions of the loop; a falsy value selects
                DEFAULT_STOP_CONDITIONS.
            inner_stop_conditions (Dict, optional): Not used by this implementation. Defaults to None.

        Returns:
            Dict: Nothing is returned explicitly, i.e. the call evaluates to None.
        """
        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        # random init sub sets: pack one randomly chosen action as the initial policy
        for env_agent in self.agents:
            mapped = self.agent_mapping(env_agent)
            candidates = self.full_policy_set[mapped]
            chosen = random.choice(candidates)
            self.populations[mapped]["policy_0"] = pack_action_to_policy(
                candidates,
                self.agent_interfaces[mapped].observation_space,
                is_fixed=True,
                # one-hot distribution on the chosen element
                distribution=np.array(
                    [float(candidate == chosen) for candidate in candidates],
                    dtype=np.float32,
                ),
            )

        self.payoff_manager.expand(
            {env_agent: ["policy_0"] for env_agent in self.agents}
        )

        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)

            # evaluate pending combinations and feed the payoff table
            sim_results = self.run_simulation()
            Log.debug("simulation result report: {}".format(sim_results))
            self.payoff_manager.update_payoff_and_simulation_status(sim_results)

            # solve the meta game restricted to the current populations
            policy_sets = {}
            for env_agent in self.agents:
                policy_sets[env_agent] = list(
                    self.populations[self.agent_mapping(env_agent)].keys()
                )
            equilibrium = self.meta_solver.solve(
                self.payoff_manager.get_sub_matrix(policy_sets)
            )
            meta_strategies = {}
            for env_agent, probabilities in equilibrium.items():
                meta_strategies[env_agent] = dict(
                    zip(policy_sets[env_agent], probabilities)
                )

            nash_conv = self.compute_nash_conv(meta_strategies)
            Log.info("\t* computed NashConv={}".format(nash_conv))
            write_to_tensorboard(
                self.summary_writer,
                info={"NashConv/base": nash_conv.nash_conv},
                global_step=stopper.counter,
                prefix="",
            )

            # here is the learned br policies
            new_brs = self.compute_best_responses(sampler_config, meta_strategies)
            self.payoff_manager.expand(new_brs)

            stopper.step(
                None,
                None,
                time_step=self.total_timesteps,
                episode_th=self.total_episodes,
            )

    def load(self, **kwargs) -> None:
        """Placeholder for model loading: only logs a warning, `kwargs` are ignored."""
        message = "Not implemented model load yet."
        Log.warning(message)

    def save(self, **kwargs) -> None:
        """Placeholder for model saving: only logs a warning, `kwargs` are ignored."""
        message = "Not implemented model saving yet."
        Log.warning(message)
