import itertools
import random
import ray
import numpy as np

from expground.types import (
    AgentID,
    Union,
    Dict,
    LambdaType,
    EnvDescription,
    RolloutConfig,
    TrainingConfig,
)
from expground.logger import Log
from expground.utils.logging import write_to_tensorboard
from expground.utils.stoppers import DEFAULT_STOP_CONDITIONS, get_stopper

from .do import DOLearner, pack_action_to_policy


class BDOLearner(DOLearner):
    """Learner that adds a bounded inner refinement loop on top of ``DOLearner``.

    In every global iteration the meta game over the current populations is
    solved, best responses are computed, and then up to ``max_bi_step`` inner
    steps re-solve the meta game on a payoff matrix from which the
    best-response payoffs have been subtracted.
    """

    def __init__(
        self,
        experiment: str,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        ray_mode: bool = False,
        seed: int = None,
        evaluation_worker_num: int = 0,
        agent_mapping: LambdaType = None,
        *,
        max_bi_step: int = 10,
        **kwargs
    ) -> None:
        """Initialize the learner.

        All arguments except ``max_bi_step`` are forwarded unchanged to
        ``DOLearner.__init__``.

        Args:
            max_bi_step: Upper bound on the number of inner refinement steps
                per global iteration. Keyword-only; defaults to 10. A value
                ``<= 0`` skips the inner loop entirely.
        """
        super().__init__(
            experiment,
            env_description,
            rollout_config,
            training_config,
            ray_mode=ray_mode,
            seed=seed,
            evaluation_worker_num=evaluation_worker_num,
            agent_mapping=agent_mapping,
            **kwargs
        )

        self.max_bi_step = max_bi_step

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict = None,
    ) -> Dict:
        """Run the main learning loop until the stopper reports termination.

        Args:
            sampler_config: Passed through to ``compute_best_responses``.
            stop_conditions: Stopper configuration; when falsy,
                ``DEFAULT_STOP_CONDITIONS`` is used instead.
            inner_stop_conditions: Accepted for interface compatibility; it is
                not read by this implementation.

        NOTE(review): the signature is annotated ``-> Dict`` but no value is
        returned (callers receive ``None``). Confirm against the parent class
        contract before relying on a return value.
        """
        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        # Seed every population with a fixed "policy_0" whose distribution is
        # one-hot on a randomly chosen element of the agent's full policy set.
        for agent in self.agents:
            mapped_agent = self.agent_mapping(agent)
            candidates = self.full_policy_set[mapped_agent]
            chosen = random.choice(candidates)
            self.populations[mapped_agent]["policy_0"] = pack_action_to_policy(
                candidates,
                self.agent_interfaces[mapped_agent].observation_space,
                is_fixed=True,
                distribution=np.array(
                    [float(e == chosen) for e in candidates], dtype=np.float32
                ),
            )

        self.payoff_manager.expand({agent: ["policy_0"] for agent in self.agents})

        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)
            results = self.run_simulation()
            # Lazy %-formatting: the (potentially large) report is only
            # rendered when the debug level is actually enabled.
            Log.debug("simulation result report: %s", results)
            self.payoff_manager.update_payoff_and_simulation_status(results)

            # Policy support (population keys) for each agent.
            policy_sets = {
                agent: list(self.populations[self.agent_mapping(agent)].keys())
                for agent in self.agents
            }

            utilities = self.payoff_manager.get_sub_matrix(policy_sets)
            equilibrium = self.meta_solver.solve(utilities)
            meta_strategies = {
                k: dict(zip(policy_sets[k], v)) for k, v in equilibrium.items()
            }

            base_nash_conv = self.compute_nash_conv(meta_strategies)
            Log.info("\t* computed NashConv=%s", base_nash_conv)

            brs = self.compute_best_responses(sampler_config, meta_strategies)

            # Fallback so the reporting below is well defined even when the
            # inner loop does not execute (max_bi_step <= 0); with at least one
            # inner step this value is always overwritten.
            optim_nash_conv = base_nash_conv

            policy_sets_list = [policy_sets[e] for e in self.agents]
            bi_step = 0
            while bi_step < self.max_bi_step:
                # For each agent, pair its best response(s) with every joint
                # combination of the other agents' current policy supports.
                simulations = []
                for i, agent in enumerate(self.agents):
                    per_agent_keys = (
                        policy_sets_list[:i] + [brs[agent]] + policy_sets_list[i + 1 :]
                    )
                    simulations.append(
                        [
                            dict(zip(self.agents, combo))
                            for combo in itertools.product(*per_agent_keys)
                        ]
                    )
                results = self.run_simulation(simulations=simulations, keep_dim=True)
                results = dict(zip(self.agents, results))
                results = self.payoff_manager.dict_to_matrix(results)

                # Insert a length-1 axis at each agent's own axis index so the
                # best-response payoffs broadcast against the sub-matrix.
                sorted_agent_mapping: Dict[
                    AgentID, int
                ] = self.payoff_manager.agent_axes_mapping
                partial_matrix = {
                    agent: np.expand_dims(results[agent].copy(), axis=idx)
                    for agent, idx in sorted_agent_mapping.items()
                }

                matrix = {
                    agent: sub_utilities - partial_matrix[agent]
                    for agent, sub_utilities in utilities.items()
                }

                equilibrium = self.meta_solver.solve(matrix)
                meta_strategies = {
                    k: dict(zip(policy_sets[k], v)) for k, v in equilibrium.items()
                }

                optim_nash_conv = self.compute_nash_conv(meta_strategies)
                # Return value intentionally ignored; presumably `brs` is
                # refined in place via `existing_brs` — TODO confirm.
                self.compute_best_responses(
                    sampler_config, meta_strategies, existing_brs=brs
                )
                # Stop refining on the very first global iteration, or as soon
                # as the refined NashConv beats the baseline one.
                if (
                    stopper.counter == 0
                    or optim_nash_conv.nash_conv < base_nash_conv.nash_conv
                ):
                    break
                bi_step += 1

            Log.info(
                "\t* final nash_conv: %s\nmeta=%s",
                optim_nash_conv.nash_conv,
                equilibrium,
            )

            write_to_tensorboard(
                self.summary_writer,
                info={
                    "NashConv/base": base_nash_conv.nash_conv,
                    "NashConv/optim": optim_nash_conv.nash_conv,
                },
                global_step=stopper.counter,
                prefix="",
            )

            self.payoff_manager.expand(brs)

            stopper.step(
                None,
                None,
                time_step=self.total_timesteps,
                episode_th=self.total_episodes,
            )
