from abc import ABCMeta
import logging
import ray
import numpy as np

from trainer.trainer import Trainer
from policy.policy import PolicySpec
from agents.league.coordinator import Coordinator, AsymmetricCoordinator, PopulationCoordinator
from utils.annotations import override
from utils.typing import TrainerConfigDict, ResultDict

# Configure the root logger as an import-time side effect of this module:
# records at INFO and above go both to the file "league.log" (a relative
# path, so it is resolved against the process's working directory) and to a
# default StreamHandler.
# NOTE(review): `logging.basicConfig` leaves an already-configured root logger
# untouched, but the FileHandler below is still constructed when this module
# is imported — confirm that creating "league.log" on import is acceptable.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s",
    handlers=[logging.FileHandler("league.log"), logging.StreamHandler()],
)
# Module-level logger shared by the league classes defined in this file.
logger = logging.getLogger(__name__)


class League(metaclass=ABCMeta):
    """Base class for league builders that are attached to a Trainer.

    The base class only stores the trainer and its config. Subclasses are
    expected to override `update_league()`; the implementation here always
    raises `NotImplementedError`.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        **kwargs,
    ):
        """Initializes a League instance.

        Args:
            trainer: The Trainer object by which this league builder is used.
                `update_league()` is meant to be called after each
                training step.
            trainer_config: The (not yet validated) config dict to be
                used on the Trainer. Child classes of `League`
                should preprocess this to add e.g. multiagent settings
                to this config.
            **kwargs: Extra keyword arguments. They are accepted but not
                used by this base class.
        """
        self.trainer = trainer
        self.config = trainer_config

    def update_league(self, result: ResultDict) -> ResultDict:
        """Method containing league-building logic. Called after train step.

        Args:
            result: The most recent result dict with all necessary stats in
                it (e.g. episode rewards) to perform league building
                operations.

        Returns:
            The result dict, possibly extended with league statistics.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                provide the actual logic.
        """
        # Name the concrete class so a missing override is easy to track down.
        raise NotImplementedError(
            f"{type(self).__name__} must override update_league()."
        )


class SelfPlayLeague(League):
    """League for symmetric self-play with a single learning policy.

    The learning policy is registered as "main" next to an initial opponent
    "main_v0". A copy of "main" is added to the trainer as a new policy on
    every iteration where
    `trainer.iteration % iter_threshold == iter_threshold - 1`.
    """

    # Match functions for which a snapshot rule is implemented.
    _SUPPORTED_MATCH_FUNCS = ("FSP", "PFSP", "ALP")

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        match_func: str = "FSP",
        win_rate_threshold: float = 0.8,
        iter_threshold: int = 1000,
        newest_prob: float = 0.7,
        max_league_size: int = 20,
        coordinator: bool = True,
    ):
        """Initializes a SelfPlayLeague instance.

        Args:
            trainer: The Trainer object by which this league builder is used.
                `update_league()` is meant to be called after each
                training step.
            trainer_config: The (not yet validated) config dict to be
                used on the Trainer. Its "multiagent" section is overwritten
                here (policies, policies_to_train, policy_mapping_fn) and
                its "seed" entry is forwarded to a newly created coordinator.
            match_func: The match function. `update_league()` accepts
                "FSP", "PFSP" and "ALP".
            win_rate_threshold: Forwarded to the coordinator and stored on
                the instance. The snapshot rule in `update_league()` does
                not read it; snapshots are purely iteration-based.
            iter_threshold: Period (in trainer iterations) of the snapshot
                rule; also forwarded to the coordinator.
            newest_prob: Forwarded to the coordinator and stored on the
                instance.
            max_league_size: Forwarded to the coordinator and stored on the
                instance.
            coordinator: If True, create a new `Coordinator` actor named
                "coordinator". If False, look up an already running actor
                registered under that name instead.

        Raises:
            ValueError: If `coordinator` is False and no actor named
                "coordinator" exists (raised by `ray.get_actor`).
        """
        super().__init__(
            trainer,
            trainer_config,
        )
        if coordinator:
            self.coordinator = Coordinator.options(name="coordinator").remote(
                match_func,
                win_rate_threshold,
                iter_threshold,
                newest_prob,
                max_league_size,
                seed=self.config["seed"],
            )
        else:
            # Without this branch `self.coordinator` stayed undefined and the
            # policy-mapping lookup below failed with an AttributeError.
            self.coordinator = ray.get_actor("coordinator")

        self.match_func = match_func
        self.win_rate_threshold = win_rate_threshold
        self.iter_threshold = iter_threshold
        self.newest_prob = newest_prob
        self.max_league_size = max_league_size

        # Modify the trainer's multi-agent config.
        multiagent = self.config["multiagent"]
        multiagent["policies"] = {
            "main": PolicySpec(),
            "main_v0": PolicySpec(),
        }
        multiagent["policies_to_train"] = ["main"]
        multiagent["policy_mapping_fn"] = ray.get(
            self.coordinator.get_new_policy_mapping_fn.remote()
        )
        logger.info(f"League config: {self.config['multiagent']}")

    @override(League)
    def update_league(self, result: ResultDict) -> ResultDict:
        """Snapshots "main" when the iteration rule fires and tags the result.

        Args:
            result: The most recent result dict. It is modified in place:
                the key "version" is set to the coordinator's league version.

        Returns:
            The same `result` dict.

        Raises:
            ValueError: If `match_func` is not one of "FSP", "PFSP", "ALP".
        """
        logger.info(f"League after iter {self.trainer.iteration}:")

        main_win_rates = ray.get(self.coordinator.get_win_rates.remote())
        logger.info(f"\tmain_win_rates: {main_win_rates}")

        # TODO: checkpoint condition
        if self.match_func not in self._SUPPORTED_MATCH_FUNCS:
            raise ValueError(f"Unsupported match function: {self.match_func}")
        # Fires on the last iteration of every `iter_threshold`-sized window.
        snapshot_condition = (
            self.trainer.iteration % self.iter_threshold
            == self.iter_threshold - 1
        )

        if snapshot_condition:
            policy_id = "main"
            new_pol_id = ray.get(self.coordinator.add_policy.remote())

            # Add the new policy as a copy (class + state) of the learner.
            source_policy = self.trainer.get_policy(policy_id)
            self.trainer.add_policy(
                policy_id=new_pol_id,
                policy_cls=type(source_policy),
                policy_state=source_policy.get_state(),
            )

        result["version"] = ray.get(self.coordinator.league_version.remote())

        return result


class AsymmetricSelfPlayLeague(League):
    """League for asymmetric self-play with separate left/right learners.

    Two learning policies, "main_left" and "main_right", are registered next
    to the initial opponents "left_v0" and "right_v0". On a fixed iteration
    schedule each learner is copied into the trainer as a new policy, and one
    iteration earlier the coordinator is asked for a policy to remove.
    """

    # Match functions for which a snapshot/remove rule is implemented.
    _SUPPORTED_MATCH_FUNCS = ("FSP", "PFSP", "ALP", "ALP-GMM")

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        match_func: str = "FSP",
        win_rate_threshold: float = 0.8,
        iter_threshold: int = 1000,
        newest_prob: float = 0.7,
        max_league_size: int = 20,
        coordinator: bool = True,
    ):
        """Initializes a AsymmetricSelfPlayLeague instance.

        Args:
            trainer: The Trainer object by which this league builder is used.
                `update_league()` is meant to be called after each
                training step. Its `env_creator` is used once here to read
                the per-side observation/action spaces.
            trainer_config: The (not yet validated) config dict to be
                used on the Trainer. Its "multiagent" section is overwritten
                here; "env_config" and "seed" are read from it.
            match_func: The match function. `update_league()` accepts
                "FSP", "PFSP", "ALP" and "ALP-GMM".
            win_rate_threshold: Forwarded to the coordinator and stored on
                the instance. The snapshot rule in `update_league()` does
                not read it; snapshots are purely iteration-based.
            iter_threshold: Period (in trainer iterations) of the
                snapshot/remove rule; also forwarded to the coordinator.
            newest_prob: Forwarded to the coordinator and stored on the
                instance.
            max_league_size: Forwarded to the coordinator and stored on the
                instance.
            coordinator: If True, create a new `AsymmetricCoordinator` actor
                named "coordinator". If False, look up an already running
                actor registered under that name instead.

        Raises:
            ValueError: If `coordinator` is False and no actor named
                "coordinator" exists (raised by `ray.get_actor`).
        """
        super().__init__(
            trainer,
            trainer_config,
        )
        if coordinator:
            self.coordinator = AsymmetricCoordinator.options(name="coordinator").remote(
                match_func=match_func,
                win_rate_threshold=win_rate_threshold,
                iter_threshold=iter_threshold,
                newest_prob=newest_prob,
                max_league_size=max_league_size,
                seed=self.config["seed"],
            )
        else:
            # Without this branch `self.coordinator` stayed undefined and the
            # policy-mapping lookup below failed with an AttributeError.
            self.coordinator = ray.get_actor("coordinator")

        self.match_func = match_func
        self.win_rate_threshold = win_rate_threshold
        self.iter_threshold = iter_threshold
        self.newest_prob = newest_prob
        self.max_league_size = max_league_size

        # Probe a throw-away env for the per-side spaces. The `finally`
        # guarantees the env is closed even if reading the spaces raises.
        temp_env = self.trainer.env_creator(self.config["env_config"])
        try:
            if isinstance(temp_env.observation_space, dict):
                self.left_obs_space = temp_env.observation_space.get("left", None)
                self.left_act_space = temp_env.action_space.get("left", None)
                self.right_obs_space = temp_env.observation_space.get("right", None)
                self.right_act_space = temp_env.action_space.get("right", None)
            else:
                # No per-side mapping available: leave every space unset.
                self.left_obs_space = None
                self.left_act_space = None
                self.right_obs_space = None
                self.right_act_space = None
        finally:
            temp_env.close()

        # Modify the trainer's multi-agent config.
        multiagent = self.config["multiagent"]
        multiagent["policies"] = {
            "main_left": PolicySpec(observation_space=self.left_obs_space, action_space=self.left_act_space),
            "main_right": PolicySpec(observation_space=self.right_obs_space, action_space=self.right_act_space),
            "left_v0": PolicySpec(observation_space=self.left_obs_space, action_space=self.left_act_space),
            "right_v0": PolicySpec(observation_space=self.right_obs_space, action_space=self.right_act_space),
        }
        multiagent["policies_to_train"] = [
            "main_left",
            "main_right",
        ]
        multiagent["policy_mapping_fn"] = ray.get(
            self.coordinator.get_new_policy_mapping_fn.remote()
        )
        logger.info(f"League config: {self.config['multiagent']}")

    @override(League)
    def update_league(self, result: ResultDict) -> ResultDict:
        """Snapshot current main policies if needed, and add custom results.

        Args:
            result: The most recent result dict. It is modified in place with
                the keys "main_left_win_mean", "main_right_win_mean",
                "main_left_training_ratio", "main_right_training_ratio",
                "left_version" and "right_version".

        Returns:
            The same `result` dict.

        Raises:
            ValueError: If `match_func` is not one of "FSP", "PFSP", "ALP",
                "ALP-GMM".
        """
        logger.info(f"League after iter {self.trainer.iteration}:")
        left_win_rates = ray.get(self.coordinator.get_win_rates.remote("main_left"))
        logger.info(f"\tmain_left_win_rates: {left_win_rates}")
        right_win_rates = ray.get(self.coordinator.get_win_rates.remote("main_right"))
        logger.info(f"\tmain_right_win_rates: {right_win_rates}")
        # NOTE: np.mean over an empty collection yields nan.
        result["main_left_win_mean"] = np.mean(list(left_win_rates.values()))
        result["main_right_win_mean"] = np.mean(list(right_win_rates.values()))
        training_ratio = ray.get(self.coordinator.update_training_ratio.remote())
        logger.info(f"\ttraining_ratio: {training_ratio}")
        result["main_left_training_ratio"] = training_ratio[0]
        result["main_right_training_ratio"] = training_ratio[1]

        result["left_version"], result["right_version"] = ray.get(
            self.coordinator.league_version.remote()
        )

        # TODO: checkpoint condition (a win-rate based rule driven by
        # `win_rate_threshold` was sketched for "PFSP" but is not active).
        if self.match_func not in self._SUPPORTED_MATCH_FUNCS:
            raise ValueError(f"Unsupported match function: {self.match_func}")
        # Both rules depend only on the iteration, so evaluate them once:
        # snapshot on the last iteration of each `iter_threshold` window,
        # remove on the iteration right before it.
        window_pos = self.trainer.iteration % self.iter_threshold
        snapshot_condition = window_pos == self.iter_threshold - 1
        remove_condition = window_pos == self.iter_threshold - 2

        # learner id -> (side name for the coordinator, obs space, act space)
        learners = {
            "main_left": ("left", self.left_obs_space, self.left_act_space),
            "main_right": ("right", self.right_obs_space, self.right_act_space),
        }
        for policy_id, (side, obs_space, act_space) in learners.items():
            if snapshot_condition:
                new_pol_id = ray.get(self.coordinator.add_policy.remote(side))

                # Add the new policy as a copy (class + state) of the learner.
                source_policy = self.trainer.get_policy(policy_id)
                self.trainer.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(source_policy),
                    observation_space=obs_space,
                    action_space=act_space,
                    policy_state=source_policy.get_state(),
                    evaluation_workers=False,
                )

            if remove_condition:
                # The coordinator is queried once per learner; both remote
                # calls are made even when nothing ends up being removed.
                policy_to_remove = ray.get(self.coordinator.get_policy_to_remove.remote())
                policy_mapping = ray.get(self.coordinator.get_new_policy_mapping_fn.remote())
                if policy_to_remove:
                    self.trainer.remove_policy(
                        policy_id=policy_to_remove,
                        policy_mapping_fn=policy_mapping,
                        evaluation_workers=False,
                    )

        return result


class PopulationLeague(League):
    """League that trains a fixed-size population of left and right policies.

    `population_size` policies per side are registered as "left_<i>" and
    "right_<i>", and all of them are marked as trainable.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        match_func: str = "FSP",
        population_size: int = 4,
        coordinator: bool = True,
    ):
        """Initializes a PopulationLeague instance.

        Args:
            trainer: The Trainer object by which this league builder is used.
                Its `env_creator` is used once here to read the per-side
                observation/action spaces.
            trainer_config: The (not yet validated) config dict to be
                used on the Trainer. Its "multiagent" section is overwritten
                here; "env_config" and "seed" are read from it.
            match_func: The match function; forwarded to the coordinator and
                stored on the instance.
            population_size: Number of policies created per side.
            coordinator: If True, create a new `PopulationCoordinator` actor
                named "coordinator". If False, look up an already running
                actor registered under that name instead.

        Raises:
            ValueError: If `coordinator` is False and no actor named
                "coordinator" exists (raised by `ray.get_actor`).
        """
        super().__init__(
            trainer,
            trainer_config,
        )
        if coordinator:
            self.coordinator = PopulationCoordinator.options(name="coordinator").remote(
                match_func=match_func,
                population_size=population_size,
                seed=self.config["seed"],
            )
        else:
            # Without this branch `self.coordinator` stayed undefined and the
            # policy-mapping lookup below failed with an AttributeError.
            self.coordinator = ray.get_actor("coordinator")

        self.match_func = match_func
        self.population_size = population_size

        # Probe a throw-away env for the per-side spaces. The `finally`
        # guarantees the env is closed even if reading the spaces raises.
        temp_env = self.trainer.env_creator(self.config["env_config"])
        try:
            if isinstance(temp_env.observation_space, dict):
                self.left_obs_space = temp_env.observation_space.get("left", None)
                self.left_act_space = temp_env.action_space.get("left", None)
                self.right_obs_space = temp_env.observation_space.get("right", None)
                self.right_act_space = temp_env.action_space.get("right", None)
            else:
                # No per-side mapping available: leave every space unset.
                self.left_obs_space = None
                self.left_act_space = None
                self.right_obs_space = None
                self.right_act_space = None
        finally:
            temp_env.close()

        left_ids = [f"left_{i}" for i in range(self.population_size)]
        right_ids = [f"right_{i}" for i in range(self.population_size)]

        # Modify the trainer's multi-agent config.
        multiagent = self.config["multiagent"]
        policies = {
            pol_id: PolicySpec(observation_space=self.left_obs_space, action_space=self.left_act_space)
            for pol_id in left_ids
        }
        policies.update(
            {
                pol_id: PolicySpec(observation_space=self.right_obs_space, action_space=self.right_act_space)
                for pol_id in right_ids
            }
        )
        multiagent["policies"] = policies
        # Train both populations. The list used to repeat the left ids and
        # omit the right ones, so the right policies were never trained.
        multiagent["policies_to_train"] = left_ids + right_ids
        multiagent["policy_mapping_fn"] = ray.get(
            self.coordinator.get_new_policy_mapping_fn.remote()
        )
        logger.info(f"League config: {self.config['multiagent']}")

    @override(League)
    def update_league(self, result: ResultDict) -> ResultDict:
        """Adds per-policy win rates and training ratios to the result.

        Args:
            result: The most recent result dict. It is modified in place with
                "left_<i>_win_mean" / "right_<i>_win_mean" for every
                population index, plus "training_ratio",
                "left_training_ratio" and "right_training_ratio".

        Returns:
            The same `result` dict.
        """
        logger.info(f"League after iter {self.trainer.iteration}:")
        for i in range(self.population_size):
            left_win_rates = ray.get(self.coordinator.get_win_rates.remote(f"left_{i}"))
            right_win_rates = ray.get(self.coordinator.get_win_rates.remote(f"right_{i}"))
            logger.info(f"\tleft_{i}_win_rates: {left_win_rates}")
            logger.info(f"\tright_{i}_win_rates: {right_win_rates}")
            # NOTE: np.mean over an empty collection yields nan.
            result[f"left_{i}_win_mean"] = np.mean(list(left_win_rates.values()))
            result[f"right_{i}_win_mean"] = np.mean(list(right_win_rates.values()))

        training_ratio, left_training_ratio, right_training_ratio = ray.get(
            self.coordinator.update_training_ratio.remote()
        )
        logger.info(f"\ttraining_ratio: {training_ratio}")
        logger.info(f"\tleft_training_ratio: {left_training_ratio}")
        logger.info(f"\tright_training_ratio: {right_training_ratio}")
        result["training_ratio"] = training_ratio
        result["left_training_ratio"] = left_training_ratio
        result["right_training_ratio"] = right_training_ratio

        return result
