import logging
import ray
import gym
from typing import Type

from agents.ppo.trainer import PPOTrainer
from agents.curriculum.config import PPO_CURRICULUM_DEFAULT_CONFIG, PPO_HRL_CURRICULUM_DEFAULT_CONFIG
from agents.ppo.communication.policy import PPOInvariantComPolicy
from agents.ppo.hierarchical.trainer import PPOHRLTrainer
from policy.policy import Policy, PolicySpec
from worker.worker_set import WorkerSet
from trainer.trainer import Trainer
from utils.annotations import override
from utils.debug import update_global_seed_if_necessary
from utils.from_config import from_config
from utils.typing import (
    TrainerConfigDict,
    PartialTrainerConfigDict,
    ResultDict,
)
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)


class PPOCurriculumTrainer(PPOTrainer):
    """PPO trainer whose training loop is coupled to a curriculum teacher.

    A teacher object is built from ``config["teacher_config"]`` in
    ``setup()`` and is handed the results of every training step through
    ``teacher.update_curriculum()``. The state of the teacher's task
    generator is added to the trainer state in ``__getstate__`` unless the
    generator reports the name ``"uniform"``.
    """

    # Register "teacher_config" with both sub-key lists inherited from
    # PPOTrainer.
    # NOTE(review): presumably these lists are consumed by
    # `merge_trainer_configs` when nested dicts are merged - confirm in Trainer.
    _allow_unknown_subkeys = PPOTrainer._allow_unknown_subkeys + [
        "teacher_config",
    ]
    _override_all_subkeys_if_type_changes = (
        PPOTrainer._override_all_subkeys_if_type_changes + ["teacher_config"]
    )

    @classmethod
    @override(PPOTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config used by the curriculum PPO trainer."""
        return PPO_CURRICULUM_DEFAULT_CONFIG

    @override(PPOTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Build the teacher, the rollout workers and the execution plan.

        Args:
            config: User-supplied (possibly partial) config. It is merged
                over the class' default config into ``self.config``.
        """
        # Setup our config: Merge the user-supplied config.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        # Validate the framework settings in config.
        self.validate_framework()
        # Setup the "env creator" callable.
        self.setup_env_creator()
        # Set Trainer's seed. It is read from the merged config so that a
        # seed defined only in the default config is honored as well, not
        # just one passed in explicitly by the user.
        update_global_seed_if_necessary(self.config.get("seed"))

        # Create the Teacher object. This happens before `validate_config`
        # and before any worker is created.
        self.teacher = from_config(
            self.config["teacher_config"], trainer=self, trainer_config=self.config
        )

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()

        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info(
                "Current log_level is {}. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags.".format(log_level)
            )
        if log_level:
            logging.getLogger("rllib").setLevel(log_level)

        # Create local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Pre-set both attributes so that they exist (as None) even if the
        # worker creation below raises.
        self.workers = None
        self.train_exec_impl = None

        # Create rollout workers for collecting samples for training.
        self.workers = WorkerSet(
            env_creator=self.env_creator,
            validate_env=self.validate_env,
            policy_class=self.get_default_policy_class(self.config),
            trainer_config=self.config,
            num_workers=self.config["num_workers"],
            local_worker=True,
            logdir=self.logdir,
        )

        # Function defining one single training iteration's behavior.
        # LocalIterator-creating "execution plan".
        # Only call this once here to create `self.train_exec_impl`,
        # which is a ray.util.iter.LocalIterator that will be `next`'d
        # on each training iteration.
        self.train_exec_impl = self.execution_plan(
            self.workers, self.config, **self._kwargs_for_execution_plan()
        )

        # Evaluation WorkerSet setup.
        self.setup_eval_workers()

    @override(Trainer)
    def step_attempt(self) -> ResultDict:
        """Attempt a single training step, then update the curriculum.

        Returns:
            The results dict of ``Trainer.step_attempt()`` after it has been
            passed through ``teacher.update_curriculum()``.
        """
        step_results = Trainer.step_attempt(self)
        return self.teacher.update_curriculum(result=step_results)

    def _has_stateful_task_generator(self) -> bool:
        """Tell whether the teacher's task generator state is checkpointed.

        Returns:
            True if a teacher exists, ``teacher_config["task_generator"]`` is
            truthy (or absent) and the remote task generator reports a name
            other than ``"uniform"``.
        """
        if not hasattr(self, "teacher"):
            return False
        if not self.config["teacher_config"].get("task_generator", True):
            return False
        # The task generator is accessed through `.remote()` calls, so its
        # name has to be fetched with `ray.get`.
        name = ray.get(self.teacher.task_generator.get_name.remote())
        return name != "uniform"

    def __getstate__(self) -> dict:
        """Return the trainer state, extended by the task generator's state.

        Returns:
            The dict from ``Trainer.__getstate__()``; the key ``"teacher"``
            is added with the result of the task generator's ``save()`` when
            that generator is stateful.
        """
        state = Trainer.__getstate__(self)
        if self._has_stateful_task_generator():
            state["teacher"] = ray.get(self.teacher.task_generator.save.remote())
        return state

    def __setstate__(self, state: dict):
        """Restore the trainer and, if present, the task generator's state.

        Args:
            state: A state dict as produced by ``__getstate__()``.
        """
        Trainer.__setstate__(self, state)
        if "teacher" in state and self._has_stateful_task_generator():
            # Wait for the remote restore to finish so that an error raised
            # by it surfaces here instead of being silently dropped.
            ray.get(self.teacher.task_generator.restore.remote(state["teacher"]))


class PPOComCurriculumTrainer(PPOCurriculumTrainer):
    """Curriculum PPO trainer that uses ``PPOInvariantComPolicy`` by default."""

    @override(PPOTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return ``PPOInvariantComPolicy``, regardless of ``config``."""
        return PPOInvariantComPolicy


class PPOHRLCurriculumTrainer(PPOCurriculumTrainer, PPOHRLTrainer):
    """Curriculum PPO trainer with a high-level and a low-level policy.

    ``setup()`` creates the teacher like ``PPOCurriculumTrainer`` does and
    additionally fills the ``"high_level_policy"`` and ``"low_level_policy"``
    entries of the multi-agent config with ``PolicySpec`` objects built from
    the spaces of a temporary environment.
    """

    # Same extension of PPOTrainer's sub-key lists as in the curriculum
    # trainer, plus the two per-level policy config keys.
    _allow_unknown_subkeys = PPOTrainer._allow_unknown_subkeys + [
        "teacher_config", "high_level_policy_config", "low_level_policy_config",
    ]
    _override_all_subkeys_if_type_changes = (
        PPOTrainer._override_all_subkeys_if_type_changes + ["teacher_config", "high_level_policy_config", "low_level_policy_config"]
    )

    @classmethod
    @override(PPOCurriculumTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config of the hierarchical curriculum trainer."""
        return PPO_HRL_CURRICULUM_DEFAULT_CONFIG

    @override(PPOCurriculumTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return None: there is no single default policy class.

        The policy classes are set per policy through the ``PolicySpec``
        entries that ``setup()`` writes into the multi-agent config.
        NOTE(review): the return annotation does not cover None.
        """
        return None

    @override(PPOCurriculumTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Build teacher, policy specs, rollout workers and execution plan.

        Args:
            config: User-supplied (possibly partial) config. It is merged
                over the class' default config into ``self.config``.
        """
        # Setup our config: Merge the user-supplied config.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        # Validate the framework settings in config.
        self.validate_framework()
        # Setup the "env creator" callable.
        self.setup_env_creator()
        # Set Trainer's seed. It is read from the merged config so that a
        # seed defined only in the default config is honored as well, not
        # just one passed in explicitly by the user.
        update_global_seed_if_necessary(self.config.get("seed"))

        # Create the Teacher object.
        self.teacher = from_config(
            self.config["teacher_config"], trainer=self, trainer_config=self.config
        )

        # Build the multi-agent config. A temporary env is created only to
        # read its observation/action spaces; it is used as a context
        # manager, so its exit handler runs when the block is left.
        with self.env_creator(self.config["env_config"]) as temp_env:
            # A Tuple high-level action space selects the communication
            # policy, any other space the plain PPO torch policy.
            if isinstance(temp_env.high_level_action_space, gym.spaces.Tuple):
                from agents.ppo.communication import PPOInvariantComPolicy
                high_level_policy_cls = PPOInvariantComPolicy
            else:
                from agents.ppo.policy import PPOTorchPolicy
                high_level_policy_cls = PPOTorchPolicy
            self.config["multiagent"]["policies"]["high_level_policy"] = PolicySpec(
                policy_class=high_level_policy_cls,
                observation_space=temp_env.high_level_observation_space,
                action_space=temp_env.high_level_action_space,
                config=self.config["high_level_policy_config"],
            )

            # The low-level policy keeps the class that is already
            # configured; only spaces and config of its spec are replaced.
            low_level_policy_cls = self.config["multiagent"]["policies"]["low_level_policy"].policy_class
            self.config["multiagent"]["policies"]["low_level_policy"] = PolicySpec(
                policy_class=low_level_policy_cls,
                observation_space=temp_env.low_level_observation_space,
                action_space=temp_env.low_level_action_space,
                config=self.config["low_level_policy_config"],
            )

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()

        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info(
                "Current log_level is {}. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags.".format(log_level)
            )
        if log_level:
            logging.getLogger("rllib").setLevel(log_level)

        # Create local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Pre-set both attributes so that they exist (as None) even if the
        # worker creation below raises.
        self.workers = None
        self.train_exec_impl = None

        # Create rollout workers for collecting samples for training.
        self.workers = WorkerSet(
            env_creator=self.env_creator,
            validate_env=self.validate_env,
            policy_class=self.get_default_policy_class(self.config),
            trainer_config=self.config,
            num_workers=self.config["num_workers"],
            local_worker=True,
            logdir=self.logdir,
        )

        # Function defining one single training iteration's behavior.
        # LocalIterator-creating "execution plan".
        # Only call this once here to create `self.train_exec_impl`,
        # which is a ray.util.iter.LocalIterator that will be `next`'d
        # on each training iteration.
        self.train_exec_impl = self.execution_plan(
            self.workers, self.config, **self._kwargs_for_execution_plan()
        )

        # Evaluation WorkerSet setup.
        self.setup_eval_workers()

    @staticmethod
    @override(PPOTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Delegate to the execution plan of ``PPOHRLTrainer``.

        Args:
            workers: The WorkerSet handed on to the plan.
            config: The trainer config handed on to the plan.
            **kwargs: Extra keyword arguments, passed through unchanged.

        Returns:
            Whatever ``PPOHRLTrainer.execution_plan()`` returns.
        """
        return PPOHRLTrainer.execution_plan(workers, config, **kwargs)
