import time
import itertools
import ray

from collections import OrderedDict
from scipy import stats
import scipy

from expground import settings
from expground import logger
from expground.types import AgentID, PolicyID, Union, Dict, LambdaType, List
from expground.logger import Log
from expground.utils.stoppers import DEFAULT_STOP_CONDITIONS, get_stopper
from expground.gt.payoff_server import Identifier
from expground.utils.logging import write_to_tensorboard
from expground.common.exploitability import measure_exploitability
from expground.algorithms.base_policy import Policy

from .independent import IndependentLearner


def gen_policy_mappings(
    active_population, full_population
) -> "List[Dict[AgentID, PolicyID]]":
    """Enumerate every joint agent -> policy assignment.

    Agents present in ``active_population`` draw their candidate policies from
    it. Every other agent in ``full_population`` draws its candidates from
    ``full_population``; the ``full_population`` entries of active agents are
    ignored.

    In each returned mapping, the non-active agents come first (in
    ``full_population`` order), followed by the active agents (in
    ``active_population`` order).

    Args:
        active_population: Mapping of active agent id -> iterable of policy ids.
        full_population: Mapping of agent id -> iterable of policy ids.

    Returns:
        One dict per element of the Cartesian product of the candidate lists,
        mapping agent id -> policy id. Two empty inputs yield ``[{}]``.
    """
    # Keep this as a plain list rather than ``OrderedDict(**...)``: keyword
    # unpacking only accepts string keys, and agent ids need not be strings.
    other_agents = [agent for agent in full_population if agent not in active_population]
    agents = other_agents + list(active_population)
    candidates = [full_population[agent] for agent in other_agents] + list(
        active_population.values()
    )
    return [dict(zip(agents, combo)) for combo in itertools.product(*candidates)]


class SubOracle(IndependentLearner):
    """Independent learner that alternates ego training with meta-strategy updates.

    Each outer step of :meth:`learn` runs one round of the parent
    ``IndependentLearner.learn``. It then evaluates the joint policy
    combinations, updates a shared payoff server, and sets the equilibrium
    returned by that server as the behavior distribution.
    """

    def __str__(self) -> str:
        """Return a readable name built from the ego agent ids."""
        identify = "&".join(self._ego_agents)
        return "SubOracle_{}".format(identify)

    def learn(
        self, sampler_config: Union[Dict, LambdaType], stop_conditions: Dict = None
    ):
        """Perform best-response learning and opponent meta-strategy learning.

        Each outer step does the following:

        1. Run one inner ``IndependentLearner.learn`` call.
        2. Evaluate every joint policy mapping built by ``gen_policy_mappings``.
        3. Push the results to the payoff server and fetch a new equilibrium
           distribution over the population.
        4. Measure the exploitability of that distribution.

        The loop stops early after ``patience`` consecutive steps in which the
        aggregated best-response improvement of the ego agents does not drop
        below its best value so far. The ego active policies snapshotted at that
        best step are restored before returning.

        Note:
            Make sure that agents loaded with meta strategies do not belong to
            `ego_agents`.

        Args:
            sampler_config (Union[Dict, LambdaType]): The configuration for
                building a sampler for each agent. If it is a lambda, a
                configuration is generated per agent id. If it is a dict, it is
                shared by all agents.
            stop_conditions (Dict): Outer-loop stop conditions. Must contain the
                key ``"inner_stop_conditions"``, whose value is forwarded to
                every inner ``learn`` call. The remaining entries build the
                outer stopper; if none remain, ``DEFAULT_STOP_CONDITIONS`` is
                used. The caller's dict is not modified.

        Returns:
            Dict: ``{"AGGBRImprovement": value}``, where ``value`` is the lowest
            aggregated best-response improvement observed.

        Raises:
            ValueError: If ``stop_conditions`` is missing or lacks
                ``"inner_stop_conditions"``.
            ConnectionError: If the payoff server actor cannot be found.
            RuntimeError: If no policy snapshot was ever taken (for example, the
                stopper was terminal before the first step).
        """
        if not stop_conditions or "inner_stop_conditions" not in stop_conditions:
            raise ValueError(
                "`stop_conditions` must contain the key 'inner_stop_conditions'."
            )
        inner_stop_conditions = stop_conditions["inner_stop_conditions"]
        # Work on a filtered copy so the caller's dict is left untouched even
        # if anything below raises.
        outer_stop_conditions = {
            key: value
            for key, value in stop_conditions.items()
            if key != "inner_stop_conditions"
        }

        # Bounded retry: ray.get_actor raises ValueError while the named actor
        # does not exist yet.
        max_tries = 10
        payoff_server = None
        for _ in range(max_tries):
            try:
                payoff_server = ray.get_actor(settings.PAYOFF_SERVER_ACTOR)
                break
            except ValueError:
                Log.warning("Payoff server is not ready yet, try again")
                time.sleep(1)
        if payoff_server is None:
            raise ConnectionError("No available payoff server.")

        stopper = get_stopper(outer_stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        population: Dict[AgentID, List[PolicyID]] = {}
        population_ins: Dict[AgentID, Dict[PolicyID, Policy]] = {}
        active_population: Dict[AgentID, List[PolicyID]] = {}
        for agent in self.agents:
            pool = self._policies[self.agent_mapping(agent)]
            fixed_pids = pool.get_fixed_policy_ids()
            fixed_policies = pool.get_fixed_policies()
            if agent in self.ego_agents:
                # Ego agents also expose their active policies. Build new
                # containers rather than extending the ones the pool returned,
                # which may be the pool's own internal state.
                active_pids = pool.get_active_policy_ids()
                fixed_pids = list(fixed_pids) + list(active_pids)
                fixed_policies = {**fixed_policies, **pool.get_active_policies()}
                active_population[agent] = active_pids
            population[agent] = fixed_pids
            population_ins[agent] = fixed_policies

        policy_mappings: List[Dict[AgentID, PolicyID]] = gen_policy_mappings(
            active_population, population
        )
        # Ego agent ids are expected to end in "_<index>"; that index selects
        # the agent's entry in `nash_conv.player_improvements`.
        improvement_idx = [int(agent.split("_")[-1]) for agent in self.ego_agents]

        patience = 2  # consecutive non-improving steps tolerated before stopping
        best_improvement = float("inf")
        policy_backups = None
        early_stop_n = 0

        while not stopper.is_terminal():
            super(SubOracle, self).learn(
                sampler_config=sampler_config, stop_conditions=inner_stop_conditions
            )
            results = self.evaluation(
                policy_mappings=policy_mappings,
                max_step=self._rollout_config.max_step,
                fragment_length=min(20, self._rollout_config.num_simulation)
                * self._rollout_config.max_step,
            )

            # Update the shared payoff table, then fetch the new equilibrium.
            ray.get(payoff_server.update_item.remote(results))
            population_dist = ray.get(
                payoff_server.get_equilibrium.remote(population=population)
            )
            self.set_behavior_dist(population_dist)

            nash_conv = measure_exploitability(
                self._env_desc["config"]["env_id"], population_ins, population_dist
            )
            agg_improvement = sum(
                nash_conv.player_improvements[k] for k in improvement_idx
            )
            Log.info(
                "\t+ step={} {} AGG_BRImprovement={}".format(
                    stopper.counter, str(self), agg_improvement
                )
            )

            if agg_improvement < best_improvement:
                best_improvement = agg_improvement
                early_stop_n = 0
                # Snapshot the best ego policies in the Ray object store so they
                # can be restored after the loop.
                policy_backups = ray.put(self.get_ego_active_policies())
            else:
                early_stop_n += 1

            Log.info(
                "\t\t+ update dist for {} as: {}".format(str(self), population_dist)
            )

            stopper.step(None, None, None, None)
            if early_stop_n >= patience:
                stopper.stop_all()

        if policy_backups is None:
            raise RuntimeError(
                "{} finished without taking a policy snapshot.".format(str(self))
            )
        self.set_ego_active_policies(ray.get(policy_backups))

        return {"AGGBRImprovement": best_improvement}

    def set_ego_active_policies(
        self, agent_policy_dict: Dict[AgentID, Dict[PolicyID, Policy]]
    ):
        """Write policies back into each agent's policy pool.

        Args:
            agent_policy_dict: Mapping of agent id -> {policy id: policy}. Each
                inner dict is passed to ``update_pool`` on the pool that
                ``agent_mapping(agent)`` selects.
        """
        for agent, policy_dict in agent_policy_dict.items():
            ppid = self.agent_mapping(agent)
            self._policies[ppid].update_pool(policy_dict)
