from expground.common.exploitability import measure_exploitability
from expground.types import (
    Union,
    LambdaType,
    Dict,
    PolicyConfig,
    EnvDescription,
    TrainingConfig,
    RolloutConfig,
    Tuple,
    AgentID,
    PolicyID,
)
from expground.utils.logging import write_to_tensorboard
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS
from expground.algorithms.base_policy import Policy

from .psro import PSROLearner


class TTSLearner(PSROLearner):
    """PSRO learner variant that always runs with a learnable distribution.

    The constructor forwards every argument to ``PSROLearner`` unchanged and
    hard-codes ``use_learnable_dist=True``. Judging by the docstring of
    ``_run_decentralized_meta_solver``, "TTS" stands for two-timescale
    learning.
    """

    def __init__(
        self,
        meta_solver: str,
        policy_config: PolicyConfig,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = False,
        train_every: int = 1,
        experiment: str = None,
        ray_mode: bool = False,
        seed: int = None,
        mixed_at_every_step: bool = False,
        independent_learning: bool = True,
        evaluation_worker_num: int = 0,
        distribution_training_kwargs: Dict = None,
        centralized_critic_config: Dict = None,
        rectifier_type: int = 0,
    ):
        """Initialize the parent learner with ``use_learnable_dist`` forced on.

        All parameters are passed through to ``PSROLearner.__init__`` as-is;
        see that class for their meaning. ``use_learnable_dist`` is not
        exposed here and is always set to ``True``.
        """
        # Keyword options handed to the parent; only `use_learnable_dist`
        # is fixed by this subclass, the rest mirror the caller's values.
        parent_options = dict(
            episodic_training=episodic_training,
            train_every=train_every,
            use_learnable_dist=True,
            experiment=experiment,
            ray_mode=ray_mode,
            seed=seed,
            mixed_at_every_step=mixed_at_every_step,
            independent_learning=independent_learning,
            evaluation_worker_num=evaluation_worker_num,
            distribution_training_kwargs=distribution_training_kwargs,
            centralized_critic_config=centralized_critic_config,
            rectifier_type=rectifier_type,
        )
        super().__init__(
            meta_solver,
            policy_config,
            env_description,
            rollout_config,
            training_config,
            loss_func,
            learning_mode,
            **parent_options,
        )

        # Sanity check: the parent must have recorded the flag we forced.
        assert self._use_learnable_dist

    def _run_decentralized_meta_solver(
        self,
    ) -> Tuple[
        Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]
    ]:
        """Run two-timescale learning to find best response and equilibrium.

        Not implemented in this class: the body unconditionally raises.

        Raises:
            NotImplementedError: Always.

        Returns:
            Tuple[Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]]:
                The pair ``learn`` unpacks as ``(equilibrium, fixed_policies)``.
        """
        raise NotImplementedError

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict,
    ):
        """Repeat meta-solving and NashConv logging until the stopper ends.

        Each iteration runs ``_run_decentralized_meta_solver``, measures
        exploitability of the result, writes the ``NashConv`` value to
        tensorboard, calls ``self._sync_up()`` and advances the stopper.

        Args:
            sampler_config: Not used by this method.
            stop_conditions: Configuration for the outer stopper; a falsy
                value falls back to ``DEFAULT_STOP_CONDITIONS``.
            inner_stop_conditions: Not used by this method.

        Raises:
            NotImplementedError: On the first iteration, unless
                ``_run_decentralized_meta_solver`` has been overridden.
        """
        if stop_conditions:
            outer_conditions = stop_conditions
        else:
            outer_conditions = DEFAULT_STOP_CONDITIONS
        stopper = get_stopper(outer_conditions)
        stopper.reset()

        while True:
            if stopper.is_terminal():
                break

            meta_result = self._run_decentralized_meta_solver()
            equilibrium, fixed_policies = meta_result

            env_id = self._env_description["config"]["env_id"]
            nash_conv = measure_exploitability(env_id, fixed_policies, equilibrium)

            # Gather the logging arguments in the order they are passed.
            writer = self.summary_writer
            metrics = {"NashConv": nash_conv.nash_conv}
            step = stopper.counter
            write_to_tensorboard(
                writer,
                info=metrics,
                global_step=step,
                prefix="",
            )

            self._sync_up()
            # No per-step statistics are supplied to the stopper here.
            stopper.step(None, None, None, None)
