from typing import Type

from agents.qmix.config import QMIX_DEFAULT_CONFIG
from agents.dqn.simple_q.trainer import SimpleQTrainer
from agents.qmix.policy import QMixTorchPolicy
from worker.worker_set import WorkerSet
from trainer.concurrency_ops import Concurrently
from trainer.metric_ops import StandardMetricsReporting
from trainer.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from trainer.rollout_ops import ParallelRollouts, ConcatBatches
from trainer.train_ops import TrainOneStep, UpdateTargetNetwork
from policy.policy import Policy
from utils.annotations import override
from utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator


class QMixTrainer(SimpleQTrainer):
    """SimpleQTrainer subclass configured for QMIX.

    Overrides three hooks of the parent trainer: the default config
    (``QMIX_DEFAULT_CONFIG``), the default policy class
    (``QMixTorchPolicy``) and the execution plan, which stores sampled
    batches in a ``SimpleReplayBuffer`` and trains on replayed data.
    """

    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return ``QMIX_DEFAULT_CONFIG``.

        The module-level object itself is returned, not a copy.
        """
        return QMIX_DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def get_default_policy_class(
            self, config: TrainerConfigDict) -> Type[Policy]:
        """Return ``QMixTorchPolicy``; ``config`` is not consulted."""
        return QMixTorchPolicy

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        """Assemble the QMIX sampling / replay / training dataflow.

        Args:
            workers: Worker set handed to the rollout, train-step,
                target-update and metrics-reporting ops.
            config: Trainer config. Keys read here: ``buffer_size``,
                ``train_batch_size``, ``multiagent.count_steps_by`` and
                ``target_network_update_freq``.
            **kwargs: Must be empty; anything else trips the assertion
                below.

        Returns:
            The result of ``StandardMetricsReporting`` applied to the
            merged store/train op.
        """
        assert not kwargs, (
            "QMIX execution_plan does NOT take any additional parameters")

        sampled = ParallelRollouts(workers, mode="bulk_sync")
        buffer = SimpleReplayBuffer(config["buffer_size"])

        # Branch 1: every sampled batch is written into the local buffer.
        store_step = StoreToReplayBuffer(local_buffer=buffer)
        store_op = sampled.for_each(store_step)

        # Branch 2: replay from the same buffer, concatenate up to the
        # configured train batch size, run one train step, then apply the
        # target-network update op.
        replayed = Replay(local_buffer=buffer)
        batcher = ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        )
        batched = replayed.combine(batcher)
        trained = batched.for_each(TrainOneStep(workers))
        target_update = UpdateTargetNetwork(
            workers, config["target_network_update_freq"])
        train_op = trained.for_each(target_update)

        # Both branches are driven in round-robin mode. output_indexes=[1]
        # refers to the second entry of the list, i.e. train_op.
        # NOTE(review): presumably this means only train_op results reach
        # the metrics reporter -- confirm against trainer.concurrency_ops.
        branches = [store_op, train_op]
        merged_op = Concurrently(
            branches, mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(merged_op, workers, config)
