import logging
import platform
from typing import List, Dict, Any

import ray
from worker.worker_set import WorkerSet
from trainer.common import (
    AGENT_STEPS_SAMPLED_COUNTER,
    STEPS_SAMPLED_COUNTER,
    _get_shared_metrics,
)
from trainer.replay_ops import MixInReplay
from trainer.rollout_ops import ParallelRollouts, ConcatBatches
from policy.sample_batch import MultiAgentBatch
from utils.actors import create_colocated
from utils.typing import SampleBatchType, ModelWeights
from ray.util.iter import (
    ParallelIterator,
    ParallelIteratorWorker,
    from_actors,
    LocalIterator,
)

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
    """An aggregation worker used by gather_experiences_tree_aggregation().

    Each of these actors is a shard of a parallel iterator that consumes
    batches from RolloutWorker actors, and emits batches of size
    train_batch_size. This allows expensive decompression / concatenation
    work to be offloaded to these actors instead of run in the learner.
    """

    def __init__(
        self, config: Dict, rollout_group: "ParallelIterator[SampleBatchType]"
    ):
        self.weights = None
        self.global_vars = None

        def generator():
            it = rollout_group.gather_async(
                num_async=config["max_sample_requests_in_flight_per_worker"]
            )

            # Update the rollout worker with our latest policy weights.
            def update_worker(item):
                worker, batch = item
                if self.weights:
                    worker.set_weights.remote(self.weights, self.global_vars)
                return batch

            # Augment with replay and concat to desired train batch size.
            it = (
                it.zip_with_source_actor()
                .for_each(update_worker)
                .for_each(lambda batch: batch.decompress_if_needed())
                .for_each(
                    MixInReplay(
                        num_slots=config["replay_buffer_num_slots"],
                        replay_proportion=config["replay_proportion"],
                    )
                )
                .flatten()
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
            )

            for train_batch in it:
                yield train_batch

        super().__init__(generator, repeat=False)

    def get_host(self) -> str:
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        self.weights = weights
        self.global_vars = global_vars


def gather_experiences_tree_aggregation(
    workers: WorkerSet, config: Dict
) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""

    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node to maximize
    # data bandwidth between them and the driver.
    train_batches = from_actors(
        [create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups]
    )

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
