import logging
from typing import Optional, Type

from agents.dqn.simple_q.policy import SimpleQTorchPolicy
from agents.dqn.simple_q.config import SIMPLE_Q_DEFAULT_CONFIG
from trainer.trainer import Trainer
from trainer.concurrency_ops import Concurrently
from trainer.metric_ops import StandardMetricsReporting
from trainer.replay_ops import Replay, StoreToReplayBuffer
from trainer.rollout_ops import ParallelRollouts
from trainer.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
    UpdateTargetNetwork
from policy.policy import Policy
from utils.annotations import override
from utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)


class SimpleQTrainer(Trainer):
    """Trainer for the simple Q-learning algorithm.

    Uses ``SIMPLE_Q_DEFAULT_CONFIG`` as its default config and
    ``SimpleQTorchPolicy`` as its default policy class. Its execution plan
    alternates between storing new rollouts in a local replay buffer and
    training on batches replayed from that buffer.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Returns the default config dict for this trainer."""
        return SIMPLE_Q_DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Checks and updates the config based on settings.

        Runs the base-class validation first, then applies SimpleQ-specific
        rules:

        * ``ParameterNoise`` exploration forces ``batch_mode`` to
          ``"complete_episodes"`` (with a warning) and cannot be combined
          with ``noisy``.
        * ``prioritized_replay`` is incompatible with
          ``multiagent.replay_mode == "lockstep"`` and with
          ``replay_sequence_length > 1``.
        * ``worker_side_prioritization`` requires ``prioritized_replay``.

        Args:
            config: The trainer config dict. It may be modified in place
                (``batch_mode`` may be overwritten).

        Raises:
            ValueError: If the config contains an unsupported combination of
                settings.
        """
        super().validate_config(config)

        if config["exploration_config"]["type"] == "ParameterNoise":
            if config["batch_mode"] != "complete_episodes":
                logger.warning(
                    "ParameterNoise Exploration requires `batch_mode` to be "
                    "'complete_episodes'. Setting batch_mode="
                    "complete_episodes.")
                config["batch_mode"] = "complete_episodes"
            if config.get("noisy", False):
                raise ValueError(
                    "ParameterNoise Exploration and `noisy` network cannot be"
                    " used at the same time!")

        if config.get("prioritized_replay"):
            if config["multiagent"]["replay_mode"] == "lockstep":
                raise ValueError("Prioritized replay is not supported when "
                                 "replay_mode=lockstep.")
            elif config.get("replay_sequence_length", 0) > 1:
                raise ValueError("Prioritized replay is not supported when "
                                 "replay_sequence_length > 1.")
        else:
            if config.get("worker_side_prioritization"):
                raise ValueError(
                    "Worker side prioritization is not supported when "
                    "prioritized_replay=False.")

        # Multi-agent mode and multi-GPU optimizer: only an informational
        # hint, the config is not changed here.
        if config["multiagent"]["policies"] and \
                not config["simple_optimizer"]:
            logger.info(
                "In multi-agent mode, policies will be optimized sequentially"
                " by the multi-GPU optimizer. Consider setting "
                "`simple_optimizer=True` if this doesn't work for you.")

    @override(Trainer)
    def get_default_policy_class(
            self, config: TrainerConfigDict) -> Optional[Type[Policy]]:
        """Returns ``SimpleQTorchPolicy`` regardless of ``config``."""
        return SimpleQTorchPolicy

    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
        """Builds the off-policy execution plan for this trainer.

        The plan alternates deterministically between
        (1) collecting rollouts and storing them in the local replay buffer,
        and (2) replaying batches from that buffer, training on them and
        updating the target network.

        Args:
            workers: The rollout workers (presumably a ``WorkerSet``; TODO
                confirm against the caller).
            config: The trainer config dict. Reads ``simple_optimizer``,
                ``train_batch_size``, ``num_gpus``, ``_fake_gpus`` and
                ``target_network_update_freq``.
            **kwargs: Must contain ``local_replay_buffer``.

        Returns:
            The op produced by ``StandardMetricsReporting`` wrapping the
            round-robin train op.

        Raises:
            ValueError: If ``local_replay_buffer`` is missing from kwargs.
        """
        # Explicit check instead of `assert`, so the validation also runs
        # under `python -O` (which strips assert statements).
        if "local_replay_buffer" not in kwargs:
            raise ValueError(
                "GenericOffPolicy execution plan requires a local replay "
                "buffer.")

        local_replay_buffer = kwargs["local_replay_buffer"]

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # (1) Generate rollouts and store them in our local replay buffer.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer))

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            # Minibatch size equals the train batch size with a single SGD
            # iteration, presumably so each replayed batch is consumed in
            # one pass (TODO confirm against MultiGPUTrainOneStep).
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"])

        # (2) Read and train on experiences from the replay buffer.
        replay_op = Replay(local_buffer=local_replay_buffer) \
            .for_each(train_step_op) \
            .for_each(UpdateTargetNetwork(
                workers, config["target_network_update_freq"]))

        # Alternate deterministically between (1) and (2); output_indexes=[1]
        # presumably forwards only the outputs of the replay/train op.
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)
