"""
The minimax learner performs general minimax optimization for a group of agents.
Note that this learner is suited to purely competitive games, such as zero-sum games.

Descriptions
============
Minimax optimization is an online optimization method (TODO: cite the work that proposed it).

ego agent: train policy support
other agent: train policy distribution
"""

import numpy as np

from expground.types import (
    LambdaType,
    Dict,
    Any,
    RolloutConfig,
    TrainingConfig,
    Sequence,
    AgentID,
    PolicyConfig,
    List,
)
from expground.logger import monitor, log
from expground.utils.sampler import SamplerInterface
from expground.algorithms.base_trainer import Trainer
from expground.learner.independent import IndependentLearner
from expground.envs.agent_interface import AgentInterface
from expground.utils.data import EpisodeKeys


# TODO(review): unfinished note -- "configuration should be in ..." (the location was never specified).
class MinMaxLearner(IndependentLearner):
    """Independent learner that alternates minimization and maximization steps.

    For each agent in turn, the remaining agents run a minimization step
    (`min_opt`) on that agent's batch, then the agent itself runs a
    maximization step (`max_opt`). Both steps raise `NotImplementedError`
    here and must be provided by a subclass.
    """

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: Dict[str, Any],
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = True,
        train_every: int = 1,
        ego_agents: Sequence[AgentID] = None,
        enable_policy_pool: bool = False,
        experiment: str = None,
        seed: int = None,
        # NOTE(review): the default is the literal Ellipsis object and is
        # forwarded as-is to the parent class -- confirm the parent handles it.
        agent_mapping: LambdaType = ...,
        mini_epoch: int = 5,
        **kwargs
    ):
        """Initialize the learner.

        Every argument is forwarded unchanged to `IndependentLearner.__init__`;
        this class adds no construction logic of its own.
        """
        super().__init__(
            policy_config,
            env_description,
            rollout_config,
            training_config,
            loss_func,
            learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            ego_agents=ego_agents,
            enable_policy_pool=enable_policy_pool,
            experiment=experiment,
            seed=seed,
            agent_mapping=agent_mapping,
            mini_epoch=mini_epoch,
            **kwargs
        )

    def min_opt(self, agents: List[AgentID], batch: Dict[str, np.ndarray]):
        """Run the minimization step for `agents` on `batch`.

        Args:
            agents (List[AgentID]): The agents that perform the minimization.
            batch (Dict[str, np.ndarray]): The batch to optimize on.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def max_opt(self, agents: List[AgentID], batch: Dict[str, np.ndarray]):
        """Run the maximization step for `agents` on `batch`.

        Args:
            agents (List[AgentID]): The agents that perform the maximization.
            batch (Dict[str, np.ndarray]): The batch to optimize on.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def minimax(self, batches: Dict[AgentID, Dict[str, np.ndarray]]):
        """Perform one round of minimax optimization over all agents.

        Each agent in `self.agents` is treated as the ego agent once: its
        policy's state-value estimate is written into its batch, every other
        agent minimizes on that batch, and then the ego agent maximizes on it.

        Args:
            batches (Dict[AgentID, Dict[str, np.ndarray]]): A dict of agent
                batches. Each agent's batch is modified in place (the state
                value is stored under `EpisodeKeys.STATE_VALUE.value`).
        """
        for idx, ego in enumerate(self.agents):
            # Everyone except the current ego agent.
            opponents = self.agents[:idx] + self.agents[idx + 1 :]
            batch = batches[ego]
            ego_policy = self._policies[ego]
            batch[EpisodeKeys.STATE_VALUE.value] = ego_policy.state_value_function(
                batch
            )
            # The other agents minimize first ...
            self.min_opt(opponents, batch)
            # ... then the ego agent maximizes against the result.
            self.max_opt([ego], batch)

    @monitor(enable_returns=False, enable_time=True)
    def train(
        self,
        env_desc: Dict[str, Any],
        agent_interfaces: Dict[AgentID, AgentInterface],
        trainer: Dict[AgentID, Trainer],
        sampler: SamplerInterface,
        rollout_config: RolloutConfig,
        behavior_policies=None,
    ) -> Sequence[Dict[AgentID, Dict]]:
        """Run one rollout phase followed by minimax training.

        The rollout generator built from `rollout_config.caller` is driven to
        exhaustion; the values it yields are ignored and only its final return
        value (expected to provide `"total_timesteps"` and `"num_episode"`)
        is used to advance the learner state. Afterwards, if the sampler is
        ready and episodic training is enabled, `minimax` is run
        `self._mini_epoch` times on freshly sampled batches.

        Args:
            env_desc (Dict[str, Any]): A dict of environment description.
            agent_interfaces (Dict[AgentID, AgentInterface]): A dict of
                environment agent interfaces.
            trainer (Dict[AgentID, Trainer]): A dict mapping ego agents to
                trainers. Not used by this method.
            sampler (SamplerInterface): An instance of sampler, inherits from
                `SamplerInterface`.
            rollout_config (RolloutConfig): An instance of `RolloutConfig`.
            behavior_policies ([type], optional): Passed to the rollout caller
                as `agent_policy_mapping`; per the original note, a dict
                mapping environment agents to policy ids, generally only for
                agents that are not linked to ego agent ids. Defaults to None.

        Returns:
            Sequence[Dict[AgentID, Dict]]: NOTE(review): the annotation
                promises a sequence of per-epoch training results, but this
                implementation does not return anything (i.e. returns None).
        """
        # Collect trajectories.
        generator = rollout_config.caller(
            sampler=sampler,
            agent_policy_mapping=behavior_policies,  # agent_mapping,
            agent_interfaces=agent_interfaces,
            env_description=env_desc,
            fragment_length=rollout_config.fragment_length,
            max_step=rollout_config.max_step,
            episodic=self._episode_training,
            train_every=self._train_every,
            evaluate=False,
        )
        # Snapshot of the counters before this rollout is accounted for.
        last_learner_state = self.get_learner_state()

        # Drain the generator: intermediate yields are not needed, the rollout
        # summary arrives as the generator's return value (StopIteration.value).
        while True:
            try:
                next(generator)
            except StopIteration as stop:
                info = stop.value
                break

        total_timesteps = last_learner_state.total_timesteps + info["total_timesteps"]
        total_episodes = last_learner_state.total_episodes + info["num_episode"]
        self.update_learner_state(
            total_episodes=total_episodes, total_timesteps=total_timesteps
        )

        if sampler.is_ready() and self._episode_training:
            batch_size = self._training_config.hyper_params["batch_size"]
            for _ in range(self._mini_epoch):
                self.minimax(sampler.sample(batch_size=batch_size))
