import logging
from typing import Type

import gym

from agents.ppo.trainer import UpdateKL, PPOTrainer
from agents.ppo.hierarchical.config import PPO_HRL_DEFAULT_CONFIG
from worker.worker_set import WorkerSet
from trainer.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    StandardizeFields,
    SelectExperiences,
)
from trainer.train_ops import TrainOneStep
from trainer.metric_ops import StandardMetricsReporting
from trainer.concurrency_ops import Concurrently
from trainer.common import _get_shared_metrics
from policy.policy import Policy, PolicySpec
from utils.annotations import override
from utils.typing import TrainerConfigDict, PartialTrainerConfigDict
from utils.debug import update_global_seed_if_necessary
from ray.util.iter import LocalIterator

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class PPOHRLTrainer(PPOTrainer):
    """PPO trainer for a two-level hierarchical (HRL) multi-agent setup.

    Two policies, ``"high_level_policy"`` and ``"low_level_policy"``, are
    trained from the same rollouts, each through its own PPO sub-flow (see
    :meth:`execution_plan`). Their observation/action spaces are read from a
    temporary environment instance in :meth:`setup`; their per-policy
    settings come from the ``"high_level_policy_config"`` and
    ``"low_level_policy_config"`` entries of the trainer config.
    """

    # Extend the key lists inherited from PPOTrainer with the two per-policy
    # config sub-dicts.
    # NOTE(review): presumably consumed by the base class' config merging
    # (tolerating unknown sub-keys / replacing the whole sub-dict when its
    # type changes) -- confirm against the base Trainer.
    _allow_unknown_subkeys = PPOTrainer._allow_unknown_subkeys + [
        "high_level_policy_config",
        "low_level_policy_config",
    ]
    _override_all_subkeys_if_type_changes = (
        PPOTrainer._override_all_subkeys_if_type_changes
        + ["high_level_policy_config", "low_level_policy_config"]
    )

    @classmethod
    @override(PPOTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the default config dict of the hierarchical PPO trainer."""
        return PPO_HRL_DEFAULT_CONFIG

    @override(PPOTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Return ``None``: there is no single default policy class.

        Both policies get an explicit ``policy_class`` through the
        ``PolicySpec`` objects built in :meth:`setup`.
        """
        return None

    @override(PPOTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Build the trainer: config, policy specs, workers, execution plan.

        Args:
            config: The user-supplied (possibly partial) trainer config. It
                is merged on top of :meth:`get_default_config`.
        """
        # Setup our config: Merge the user-supplied config.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        # Validate the framework settings in config.
        self.validate_framework()
        # Setup the "env creator" callable.
        self.setup_env_creator()
        # Set Trainer's seed.
        # NOTE(review): this reads the user-supplied partial `config`, not the
        # merged `self.config`, so a seed that only exists in the defaults is
        # not seen here -- confirm this is intended.
        update_global_seed_if_necessary(config.get("seed"))

        # Build the multi-agent config. A throw-away env instance is created
        # (as a context manager) only to read the per-level spaces.
        policies = self.config["multiagent"]["policies"]
        with self.env_creator(self.config["env_config"]) as temp_env:
            # A Tuple high-level observation space selects the communication
            # policy; anything else uses the plain PPO torch policy. The
            # policy modules are imported lazily, only for the chosen branch.
            if isinstance(temp_env.high_level_observation_space, gym.spaces.Tuple):
                from agents.ppo.communication import PPOComPolicy
                high_level_policy_cls = PPOComPolicy
            else:
                from agents.ppo.policy import PPOTorchPolicy
                high_level_policy_cls = PPOTorchPolicy
            policies["high_level_policy"] = PolicySpec(
                policy_class=high_level_policy_cls,
                observation_space=temp_env.high_level_observation_space,
                action_space=temp_env.high_level_action_space,
                config=self.config["high_level_policy_config"],
            )

            # The low-level policy class is taken from the spec that is
            # already present in the config; only spaces and config are
            # replaced.
            low_level_policy_cls = policies["low_level_policy"].policy_class
            policies["low_level_policy"] = PolicySpec(
                policy_class=low_level_policy_cls,
                observation_space=temp_env.low_level_observation_space,
                action_space=temp_env.low_level_action_space,
                config=self.config["low_level_policy_config"],
            )

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()

        log_level = self.config.get("log_level")
        if log_level in ["WARN", "ERROR"]:
            logger.info(
                "Current log_level is %s. For more information, "
                "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
                "-vv flags.",
                log_level,
            )
        if log_level:
            logging.getLogger("rllib").setLevel(log_level)

        # Create local replay buffer if necessary.
        self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
            self.config
        )

        # Pre-initialise so both attributes exist even if the worker creation
        # below raises.
        self.workers = None
        self.train_exec_impl = None

        # Create rollout workers for collecting samples for training.
        self.workers = WorkerSet(
            env_creator=self.env_creator,
            validate_env=self.validate_env,
            policy_class=self.get_default_policy_class(self.config),
            trainer_config=self.config,
            num_workers=self.config["num_workers"],
            local_worker=True,
            logdir=self.logdir,
        )

        # Function defining one single training iteration's behavior.
        # LocalIterator-creating "trainer plan".
        # Only call this once here to create `self.train_exec_impl`,
        # which is a ray.util.iter.LocalIterator that will be `next`'d
        # on each training iteration.
        self.train_exec_impl = self.execution_plan(
            self.workers, self.config, **self._kwargs_for_execution_plan()
        )

        # Evaluation WorkerSet setup.
        self.setup_eval_workers()

    @staticmethod
    @override(PPOTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Execution plan of the HRL algorithm. Defines the distributed dataflow.

        The same rollouts are duplicated into two PPO sub-flows, one per
        policy. Each sub-flow selects that policy's experiences, concatenates
        them up to the policy's own ``train_batch_size``, standardizes the
        advantages, runs one training step and updates the KL coefficient.
        Both sub-flows are then combined with ``Concurrently`` in ``"async"``
        mode.

        Args:
            workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
                of the Trainer.
            config (TrainerConfigDict): The trainer's configuration dict. Its
                ``["multiagent"]["policies"]`` entries must be ``PolicySpec``
                objects whose ``config`` holds ``train_batch_size``,
                ``num_sgd_iter`` and ``sgd_minibatch_size``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            LocalIterator[dict]: A local iterator over training metrics.
        """
        policy_specs = config["multiagent"]["policies"]

        def make_step_counter(label, counter_key):
            """Build a pass-through op that logs a batch and counts its steps."""

            def count_steps(batch):
                agent_steps = batch.agent_steps()
                # Use the module logger (not print) so the configured log
                # level governs this per-batch message.
                logger.info(
                    "%s policy learning on samples from %s env steps %s "
                    "agent steps %s",
                    label,
                    batch.policy_batches.keys(),
                    batch.env_steps(),
                    agent_steps,
                )
                metrics = _get_shared_metrics()
                metrics.counters[counter_key] += agent_steps
                return batch

            return count_steps

        def build_ppo_sub_flow(policy_rollouts, policy_id, count_steps):
            """Build the PPO training sub-flow for the policy `policy_id`."""
            ppo_config = policy_specs[policy_id].config
            return (
                policy_rollouts
                .for_each(SelectExperiences([policy_id]))
                .combine(ConcatBatches(
                    min_batch_size=ppo_config["train_batch_size"],
                    count_steps_by="env_steps",
                ))
                .for_each(count_steps)
                .for_each(StandardizeFields(["advantages"]))
                .for_each(TrainOneStep(
                    workers,
                    policies=[policy_id],
                    num_sgd_iter=ppo_config["num_sgd_iter"],
                    sgd_minibatch_size=ppo_config["sgd_minibatch_size"],
                ))
                # Keep only the second item of the train step's output
                # (presumably the learner info that UpdateKL consumes --
                # confirm against TrainOneStep).
                .for_each(lambda t: t[1])
                .for_each(UpdateKL(workers))
            )

        # Generate common experiences.
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        high_level_rollouts, low_level_rollouts = rollouts.duplicate(n=2)

        # High-level PPO sub-flow.
        high_level_train_op = build_ppo_sub_flow(
            high_level_rollouts,
            "high_level_policy",
            make_step_counter("High-level", "num_high_level_steps"),
        )

        # Low-level PPO sub-flow.
        low_level_train_op = build_ppo_sub_flow(
            low_level_rollouts,
            "low_level_policy",
            make_step_counter("Low-level", "num_low_level_steps_of_all_agents"),
        )

        # Combined training flow. `output_indexes=[1]` refers to the second
        # op in the list, i.e. the high-level sub-flow.
        train_op = Concurrently(
            [low_level_train_op, high_level_train_op],
            mode="async",
            output_indexes=[1],
        )

        # Wrap the training flow in the standard metrics reporting.
        return StandardMetricsReporting(train_op, workers, config)
