import logging
from typing import Optional, Type

import ray
from agents.impala.config import IMPALA_DEFAULT_CONFIG
from trainer.trainer import Trainer
from trainer.learner_thread import LearnerThread
from trainer.multi_gpu_learner_thread import MultiGPULearnerThread
from trainer.tree_agg import gather_experiences_tree_aggregation
from trainer.common import (STEPS_TRAINED_COUNTER,
                            STEPS_TRAINED_THIS_ITER_COUNTER,
                            _get_global_vars, _get_shared_metrics)
from trainer.replay_ops import MixInReplay
from trainer.rollout_ops import ParallelRollouts, ConcatBatches
from trainer.concurrency_ops import Concurrently, Enqueue, Dequeue
from trainer.metric_ops import StandardMetricsReporting
from policy.policy import Policy
from utils.annotations import override
from utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from utils.typing import PartialTrainerConfigDict, TrainerConfigDict
from ray.tune.utils.placement_groups import PlacementGroupFactory

logger = logging.getLogger(__name__)


def make_learner_thread(local_worker, config):
    """Create the learner thread that matches the optimizer settings.

    The thread is only constructed here; starting it is left to the caller.

    Args:
        local_worker: Worker handed to the learner thread as its first
            positional argument.
        config: Trainer config dict. May be modified in place: in multi-GPU
            mode `minibatch_buffer_size` is lowered to
            `num_multi_gpu_tower_stacks` whenever it exceeds that value.

    Returns:
        A `LearnerThread` if `config["simple_optimizer"]` is truthy,
        otherwise a `MultiGPULearnerThread`.
    """
    # Simple-optimizer path: nothing to validate, build and hand back.
    if config["simple_optimizer"]:
        return LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])

    logger.info(
        "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
            config["num_gpus"], config["num_multi_gpu_tower_stacks"]))

    num_stacks = config["num_multi_gpu_tower_stacks"]
    buffer_size = config["minibatch_buffer_size"]
    # The buffer must not have more slots than there are tower stacks;
    # shrink it (and tell the user) if it does.
    if buffer_size > num_stacks:
        logger.warning(
            "In multi-GPU mode you should have at least as many "
            "multi-GPU tower stacks (to load data into on one device) as "
            "you have stack-index slots in the buffer! You have "
            f"configured {num_stacks} stacks and a buffer of size "
            f"{buffer_size}. Setting "
            f"`minibatch_buffer_size={num_stacks}`.")
        config["minibatch_buffer_size"] = num_stacks

    # NOTE(review): the (possibly adjusted) `minibatch_buffer_size` is only
    # written back to `config`; it is not passed to the constructor below.
    # Confirm that MultiGPULearnerThread picks it up some other way.
    return MultiGPULearnerThread(
        local_worker,
        num_gpus=config["num_gpus"],
        lr=config["lr"],
        train_batch_size=config["train_batch_size"],
        num_multi_gpu_tower_stacks=num_stacks,
        num_sgd_iter=config["num_sgd_iter"],
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"])


def gather_experiences_directly(workers, config):
    """Build the sample pipeline that feeds the learner without aggregation.

    Rollouts are requested asynchronously from `workers`, decompressed if
    needed, mixed with replayed data, flattened and finally concatenated
    into batches of at least `config["train_batch_size"]`.

    Args:
        workers: Worker set passed to `ParallelRollouts`.
        config: Trainer config dict. Keys read:
            `max_sample_requests_in_flight_per_worker`,
            `replay_buffer_num_slots`, `replay_proportion`,
            `train_batch_size` and `multiagent.count_steps_by`.

    Returns:
        The iterator produced by the final `combine(ConcatBatches(...))`
        stage.
    """
    sampled = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_sample_requests_in_flight_per_worker"])

    decompressed = sampled.for_each(
        lambda sample_batch: sample_batch.decompress_if_needed())

    # Augment the fresh samples with replayed ones.
    replay_mixer = MixInReplay(
        num_slots=config["replay_buffer_num_slots"],
        replay_proportion=config["replay_proportion"])
    mixed = decompressed.for_each(replay_mixer).flatten()

    # Accumulate until the desired train batch size is reached.
    batch_builder = ConcatBatches(
        min_batch_size=config["train_batch_size"],
        count_steps_by=config["multiagent"]["count_steps_by"],
    )
    return mixed.combine(batch_builder)


# Update worker weights as they finish generating experiences.
class BroadcastUpdateLearnerWeights:
    """Callable pipeline stage that pushes weights to sampling actors.

    Each call handles one `(actor, batch)` item: the actor receives the
    currently cached weights, and the local worker's global vars are
    refreshed. The cached weights are re-read from the local worker only
    after at least `broadcast_interval` calls AND once the learner thread
    has flagged `weights_updated`.
    """

    def __init__(self, learner_thread, workers, broadcast_interval):
        """Store collaborators and cache the local worker's weights.

        Args:
            learner_thread: Object exposing a mutable `weights_updated` flag.
            workers: Worker set exposing `local_worker()`.
            broadcast_interval: Minimum number of calls between two weight
                refreshes.
        """
        self.workers = workers
        self.learner_thread = learner_thread
        self.broadcast_interval = broadcast_interval
        self.steps_since_broadcast = 0
        # Start out with whatever the local worker currently holds.
        self.weights = workers.local_worker().get_weights()

    def _refresh_weights(self):
        """Re-read the local weights into the object store, reset state."""
        self.weights = ray.put(self.workers.local_worker().get_weights())
        self.steps_since_broadcast = 0
        self.learner_thread.weights_updated = False
        # Update metrics.
        _get_shared_metrics().counters["num_weight_broadcasts"] += 1

    def __call__(self, item):
        """Send the cached weights to the actor that produced `item`."""
        actor, _batch = item
        self.steps_since_broadcast += 1
        # Only look at the learner flag once the interval has elapsed.
        if self.steps_since_broadcast >= self.broadcast_interval:
            if self.learner_thread.weights_updated:
                self._refresh_weights()
        actor.set_weights.remote(self.weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())


class ImpalaTrainer(Trainer):
    """Trainer wiring up the IMPALA execution plan.

    Experiences are gathered asynchronously (directly or through tree
    aggregation), enqueued for a background learner thread, and the
    learner's output is dequeued to keep the trained-steps counters current.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the IMPALA default config dict."""
        return IMPALA_DEFAULT_CONFIG

    @override(Trainer)
    def get_default_policy_class(self, config: PartialTrainerConfigDict) -> \
            Optional[Type[Policy]]:
        """Select the torch policy class according to `config["vtrace"]`.

        The policy modules are imported lazily, only for the branch taken.
        """
        if not config["vtrace"]:
            from agents.a3c.policy import A3CTorchPolicy
            return A3CTorchPolicy

        from agents.impala.vtrace_policy import VTraceTorchPolicy
        return VTraceTorchPolicy

    @override(Trainer)
    def validate_config(self, config):
        """Validate (and partially normalize) the IMPALA specific settings.

        Runs the parent class' validation first. A non-deprecated
        `num_data_loader_buffers` value is copied in place into
        `num_multi_gpu_tower_stacks`.

        Raises:
            ValueError: If `entropy_coeff` is negative, if
                `num_aggregation_workers` exceeds `num_workers`, or if
                `_separate_vf_optimizer` is True.
        """
        super().validate_config(config)

        # Honor the deprecated key by forwarding its value to the new one.
        legacy_stacks = config["num_data_loader_buffers"]
        if legacy_stacks != DEPRECATED_VALUE:
            deprecation_warning(
                "num_data_loader_buffers",
                "num_multi_gpu_tower_stacks",
                error=False)
            config["num_multi_gpu_tower_stacks"] = legacy_stacks

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        num_aggregators = config["num_aggregation_workers"]
        num_rollout_workers = config["num_workers"]
        if num_aggregators > num_rollout_workers:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "`num_workers`! Aggregation makes no sense otherwise.")
        if num_aggregators > num_rollout_workers / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less.")

        # A separate value-function optimizer is rejected outright: this
        # trainer only hands out torch policy classes.
        # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
        #  models to return separate sets of weights in order to create
        #  the different torch optimizers).
        if config["_separate_vf_optimizer"] is True:
            raise ValueError(
                "`_separate_vf_optimizer` only supported to tf so far!")

    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
        """Assemble the IMPALA dataflow and start the learner thread.

        Args:
            workers: Worker set providing the local and remote workers.
            config: Trainer config dict.
            **kwargs: Must be empty.

        Returns:
            The result of `StandardMetricsReporting(...)` with the learner
            thread's `add_learner_metrics` applied to each reported item.
        """
        assert not kwargs, (
            "IMPALA execution_plan does NOT take any additional parameters")

        # Pick the sampling front-end.
        if config["num_aggregation_workers"] > 0:
            gather = gather_experiences_tree_aggregation
        else:
            gather = gather_experiences_directly
        train_batches = gather(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            actor_tagged = enqueue_op.zip_with_source_actor()
            enqueue_op = actor_tagged.for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread, workers,
                    broadcast_interval=config["broadcast_interval"]))

        def record_steps_trained(item):
            """Copy the learner's step count into the shared counters."""
            steps, _fetches = item
            counters = _get_shared_metrics().counters
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps
            counters[STEPS_TRAINED_COUNTER] += steps
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        learner_output = Dequeue(
            learner_thread.outqueue, check=learner_thread.is_alive)
        dequeue_op = learner_output.for_each(record_steps_trained)

        # Only the dequeue sub-flow (index 1) contributes output items.
        merged_op = Concurrently(
            [enqueue_op, dequeue_op], mode="async", output_indexes=[1])

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        after_train_step = config["after_train_step"]
        if after_train_step:
            fetches_only = merged_op.for_each(lambda result: result[1])
            merged_op = fetches_only.for_each(
                after_train_step(workers, config))

        reporting = StandardMetricsReporting(merged_op, workers, config)
        return reporting.for_each(learner_thread.add_learner_metrics)

    @classmethod
    @override(Trainer)
    def default_resource_request(cls, config):
        """Describe the resources needed as a `PlacementGroupFactory`.

        Bundles are laid out as: one driver bundle, one bundle per rollout
        worker and, if `evaluation_interval` is truthy, one bundle per
        evaluation worker. Resource amounts come from the defaults merged
        with `config`; the placement strategy is read from the passed-in
        `config` only and falls back to "PACK".
        """
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Driver + Aggregation Workers:
        # Force to be on same node to maximize data bandwidth
        # between aggregation workers and the learner (driver).
        # Aggregation workers tree-aggregate experiences collected
        # from RolloutWorkers (n rollout workers map to m
        # aggregation workers, where m < n) and always use 1 CPU
        # each.
        driver_bundle = {
            "CPU": cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
            "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
        }

        # RolloutWorkers.
        rollout_bundles = [
            {
                "CPU": cf["num_cpus_per_worker"],
                "GPU": cf["num_gpus_per_worker"],
            } for _ in range(cf["num_workers"])
        ]

        # Evaluation (remote) workers.
        # Note: The local eval worker is located on the driver CPU.
        if cf["evaluation_interval"]:
            eval_bundles = [
                {
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"])
            ]
        else:
            eval_bundles = []

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[driver_bundle] + rollout_bundles + eval_bundles,
            strategy=config.get("placement_strategy", "PACK"))
