import logging
import ray

from agents.ppo.trainer import PPOTrainer
from agents.league.config import LEAGUE_DEFAULT_CONFIG
from worker.worker_set import WorkerSet
from trainer.trainer import Trainer
from utils.annotations import override
from utils.debug import update_global_seed_if_necessary
from utils.from_config import from_config
from utils.typing import (
    TrainerConfigDict,
    PartialTrainerConfigDict,
    ResultDict,
)

logger = logging.getLogger(__name__)


class LeagueTrainer(PPOTrainer):
    """PPO trainer that drives a League object alongside regular training.

    During ``setup`` a League is instantiated from ``config["league_config"]``
    (via ``from_config``), before the config is validated. After every
    training step, the step results are passed through
    ``League.update_league``. When the league is configured with a
    coordinator (``league_config["coordinator"]``, default True), the
    coordinator's state is saved into / restored from the trainer state.
    """

    # Accept arbitrary (unknown) sub-keys inside "league_config" when merging
    # the user config with the defaults.
    _allow_unknown_subkeys = PPOTrainer._allow_unknown_subkeys + [
        "league_config",
    ]
    # NOTE(review): presumably replaces "league_config" wholesale (instead of
    # deep-merging) when its type changes -- merge logic lives in the base
    # Trainer, not visible here.
    _override_all_subkeys_if_type_changes = (
        PPOTrainer._override_all_subkeys_if_type_changes + ["league_config"]
    )

    @classmethod
    @override(PPOTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config for this trainer (LEAGUE_DEFAULT_CONFIG)."""
        return LEAGUE_DEFAULT_CONFIG

    @override(PPOTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Build the trainer: config, League, rollout workers, exec plan.

        Args:
            config: The (possibly partial) user-supplied trainer config. It
                is merged on top of ``get_default_config()``; all later steps
                read from the merged ``self.config``.
        """
        # Setup our config: Merge the user-supplied config with the defaults.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        # Validate the framework settings in config.
        self.validate_framework()
        # Setup the "env creator" callable.
        self.setup_env_creator()
        # Set Trainer's seed. Read it from the *merged* config so a seed that
        # only comes from the defaults is honored as well, and pass it by
        # keyword so it can never bind to a different leading parameter.
        # NOTE(review): RLlib's helper of the same name has the signature
        # (framework=None, seed=None); confirm against utils.debug.
        update_global_seed_if_necessary(seed=self.config.get("seed"))

        # Create the League object, and build the multi-agent config.
        # This happens before `validate_config` so that validation sees the
        # config as the League leaves it.
        self.league = from_config(
            self.config["league_config"], trainer=self, trainer_config=self.config
        )

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()

        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info(
                "Current log_level is {}. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags.".format(log_level)
            )
        if log_level:
            logging.getLogger("rllib").setLevel(log_level)

        # Create local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Presumably pre-set so these attributes exist even if the worker set
        # or execution plan construction below raises.
        self.workers = None
        self.train_exec_impl = None

        # Create rollout workers for collecting samples for training.
        self.workers = WorkerSet(
            env_creator=self.env_creator,
            validate_env=self.validate_env,
            policy_class=self.get_default_policy_class(self.config),
            trainer_config=self.config,
            num_workers=self.config["num_workers"],
            local_worker=True,
            logdir=self.logdir,
        )

        # Function defining one single training iteration's behavior.
        # LocalIterator-creating "trainer plan".
        # Only call this once here to create `self.train_exec_impl`,
        # which is a ray.util.iter.LocalIterator that will be `next`'d
        # on each training iteration.
        self.train_exec_impl = self.execution_plan(
            self.workers, self.config, **self._kwargs_for_execution_plan()
        )

        # Evaluation WorkerSet setup.
        self.setup_eval_workers()

    @override(Trainer)
    def step_attempt(self) -> ResultDict:
        """Attempts a single training step, including evaluation, if required.

        The base ``Trainer.step_attempt`` is called explicitly (not via
        ``super()``), so any intermediate override in the MRO is bypassed.
        Its results are then handed to the League, whose return value is the
        final result of this step.

        Returns:
            The results dict with stats/infos on sampling, training,
            and - if required - evaluation, as returned by
            ``League.update_league``.
        """
        step_results = Trainer.step_attempt(self)
        return self.league.update_league(result=step_results)

    def _has_league_coordinator(self) -> bool:
        """Whether a League exists and is configured to use a coordinator."""
        return hasattr(self, "league") and bool(
            self.config["league_config"].get("coordinator", True)
        )

    def __getstate__(self) -> dict:
        """Return the trainer state, extended with the league coordinator's.

        When a coordinator is in use, the result of its ``save`` call is
        fetched synchronously and stored under the ``"league"`` key.
        """
        state = Trainer.__getstate__(self)
        if self._has_league_coordinator():
            state["league"] = ray.get(self.league.coordinator.save.remote())
        return state

    def __setstate__(self, state: dict):
        """Restore the trainer state and, if present, the coordinator's state.

        Args:
            state: A dict as produced by ``__getstate__``.
        """
        Trainer.__setstate__(self, state)
        if "league" in state and self._has_league_coordinator():
            # Block until the coordinator has restored (mirroring the
            # synchronous save in `__getstate__`). Errors are then raised here
            # instead of being lost on an unawaited ObjectRef, and league
            # state is in place before training resumes.
            ray.get(self.league.coordinator.restore.remote(state["league"]))
