import math
from typing import Any, Dict, List, Optional, Type

import ray
from ray.actor import ActorHandle
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict


class DistributedLearners:
    """Container class for n learning @ray.remote-turned policies.

    The container contains n "learner shards", each one consisting of one
    multi-agent replay buffer and m policy actors that share this replay
    buffer.
    """

    def __init__(
        self,
        *,
        config,
        max_num_policies_to_train: int,
        replay_actor_class: "Type[ActorHandle]",
        replay_actor_args: List[Any],
        num_learner_shards: Optional[int] = None,
    ):
        """Initializes a DistributedLearners instance.

        Args:
            config: The Trainer's config dict. Its "num_gpus" entry is the
                total GPU budget that gets split across all shards.
            max_num_policies_to_train: Maximum number of policies that will ever be
                trainable. For these policies, we'll have to create remote
                policy actors, distributed across n "learner shards".
            replay_actor_class: The class to use to produce one multi-agent
                replay buffer on each learner shard (shared by all policy actors
                on that shard).
            replay_actor_args: The args to pass to the remote replay buffer
                actor's constructor.
            num_learner_shards: Optional number of "learner shards" to reserve.
                Each one consists of one multi-agent replay actor and
                m policy actors that share this replay buffer. If None,
                will infer this number automatically from the number of GPUs
                and the max. number of learning policies.
        """
        self.config = config
        self.num_gpus = self.config["num_gpus"]
        self.max_num_policies_to_train = max_num_policies_to_train
        self.replay_actor_class = replay_actor_class
        self.replay_actor_args = replay_actor_args

        # Auto-num-learner-shard detection:
        # Examples:
        # 4 GPUs + max. 10 policies to train -> 4 shards (0.4 GPU/pol).
        # 8 GPUs + max. 3 policies to train -> 3 shards (2 GPUs/pol).
        # 8 GPUs + max. 2 policies to train -> 2 shards (4 GPUs/pol).
        # 2 GPUs + max. 5 policies to train -> 2 shards (0.4 GPUs/pol).
        if num_learner_shards is None:
            # With no GPUs at all (`num_gpus` falsy), use one shard per policy.
            self.num_learner_shards = min(
                self.num_gpus or self.max_num_policies_to_train,
                self.max_num_policies_to_train,
            )
        else:
            self.num_learner_shards = num_learner_shards

        # Hand out whole GPUs per shard where possible; only fall back to a
        # fractional share if there are fewer GPUs than shards.
        self.num_gpus_per_shard = self.num_gpus // self.num_learner_shards
        if self.num_gpus_per_shard == 0:
            self.num_gpus_per_shard = self.num_gpus / self.num_learner_shards

        # The GPU share per policy is derived from the exact (possibly
        # fractional) policies-per-shard ratio, whereas the shard capacity is
        # rounded up so that all policies are guaranteed to fit.
        num_policies_per_shard = (
            self.max_num_policies_to_train / self.num_learner_shards
        )
        self.num_gpus_per_policy = self.num_gpus_per_shard / num_policies_per_shard
        self.num_policies_per_shard = math.ceil(num_policies_per_shard)

        self.shards = [
            _Shard(
                config=self.config,
                max_num_policies=self.num_policies_per_shard,
                num_gpus_per_policy=self.num_gpus_per_policy,
                replay_actor_class=self.replay_actor_class,
                replay_actor_args=self.replay_actor_args,
            )
            for _ in range(self.num_learner_shards)
        ]

    def add_policy(self, policy_id, policy_spec):
        """Adds a new policy to the first shard that still has a free slot.

        Args:
            policy_id: The ID under which to register the new policy actor.
            policy_spec: The spec from which the shard builds the policy actor.

        Returns:
            Whatever the chosen shard's `add_policy()` returns (the new
            policy actor).

        Raises:
            RuntimeError: If none of the shards has a free slot left.
        """
        # Find first empty slot.
        for shard in self.shards:
            if shard.max_num_policies > len(shard.policy_actors):
                return shard.add_policy(policy_id, policy_spec)
        raise RuntimeError("All shards are full!")

    def remove_policy(self):
        """Not supported yet; always raises NotImplementedError."""
        raise NotImplementedError

    def get_policy_actor(self, policy_id):
        """Returns the policy actor registered under `policy_id`.

        Returns:
            The policy actor or None, if no shard holds `policy_id`.
        """
        for shard in self.shards:
            if policy_id in shard.policy_actors:
                return shard.policy_actors[policy_id]
        # `raise None` is itself a TypeError; report "not found" as None, in
        # line with `get_replay_and_policy_actors()`.
        return None

    def get_replay_and_policy_actors(self, policy_id):
        """Returns the (replay actor, policy actor) pair serving `policy_id`.

        Returns:
            A tuple of the replay actor of the shard holding `policy_id` and
            the policy actor itself, or (None, None) if `policy_id` is unknown.
        """
        for shard in self.shards:
            if policy_id in shard.policy_actors:
                return shard.replay_actor, shard.policy_actors[policy_id]
        return None, None

    def get_policy_id(self, policy_actor):
        """Returns the policy ID under which `policy_actor` is registered.

        Returns:
            The policy ID or None, if `policy_actor` is not part of any shard.
        """
        for shard in self.shards:
            for pid, act in shard.policy_actors.items():
                if act == policy_actor:
                    return pid
        # See `get_policy_actor()`: None signals "not found".
        return None

    def get_replay_actors(self):
        """Returns the replay actors of all shards (one list entry per shard).

        An entry is None for a shard whose replay actor has not been created.
        """
        return [shard.replay_actor for shard in self.shards]

    def stop(self) -> None:
        """Terminates all ray actors."""
        for shard in self.shards:
            shard.stop()

    def __len__(self):
        """Returns the number of all Policy actors in all our shards."""
        return sum(len(s) for s in self.shards)

    def __iter__(self):
        """Yields (policy ID, policy actor, replay actor) for each policy."""
        for shard in self.shards:
            for pid, policy_actor in shard.policy_actors.items():
                yield pid, policy_actor, shard.replay_actor


class _Shard:
    """One learner shard: a single replay actor plus its policy actors.

    The replay actor is created lazily, together with the first policy that
    gets added. Every policy added afterwards is placed on the node that
    hosts this replay actor.
    """

    def __init__(
        self,
        config,
        max_num_policies,
        num_gpus_per_policy,
        replay_actor_class,
        replay_actor_args,
    ):
        """Initializes an (empty) _Shard instance; creates no ray actors yet.

        Args:
            config: The main config dict that policy-specific overrides are
                merged into.
            max_num_policies: Maximum number of policy actors on this shard.
            num_gpus_per_policy: Number of GPUs to request for each policy
                actor (ignored in favor of 0, if `_fake_gpus` is set).
            replay_actor_class: The class to use for this shard's replay actor.
            replay_actor_args: The args to pass to the replay actor's
                constructor.
        """
        self.config = config
        # Set to True once the replay actor has been created.
        self.has_replay_buffer = False
        self.max_num_policies = max_num_policies
        self.num_gpus_per_policy = num_gpus_per_policy
        self.replay_actor_class = replay_actor_class
        self.replay_actor_args = replay_actor_args

        self.replay_actor: Optional[ActorHandle] = None
        self.policy_actors: Dict[str, ActorHandle] = {}

    def add_policy(self, policy_id: "PolicyID", policy_spec: "PolicySpec"):
        """Creates a new policy actor on this shard and registers it.

        If the shard has no replay actor yet, the replay actor is created
        alongside the policy actor.

        Args:
            policy_id: The ID under which to register the new policy actor.
            policy_spec: Provides the policy class, the observation- and
                action spaces, and the config overrides for the new policy.

        Returns:
            The newly created policy actor.
        """
        # Merge the policies config overrides with the main config.
        # Also, adjust `num_gpus` (to indicate an individual policy's
        # num_gpus, not the total number of GPUs).
        # `or {}` tolerates a spec that carries no overrides (config=None).
        cfg = Trainer.merge_trainer_configs(
            self.config,
            dict(policy_spec.config or {}, num_gpus=self.num_gpus_per_policy),
        )

        # Need to create the replay actor first. Then add the first policy.
        if self.replay_actor is None:
            return self._add_replay_buffer_and_policy(policy_id, policy_spec, cfg)

        # Replay actor already exists -> Just add a new policy here.
        assert len(self.policy_actors) < self.max_num_policies

        colocated = create_colocated_actors(
            actor_specs=[self._policy_actor_spec(policy_spec, cfg)],
            # Force co-locate on the already existing replay actor's node.
            node=ray.get(self.replay_actor.get_host.remote()),
        )

        # Indexed by actor spec first, then by instance (count=1 -> [0]).
        new_actor = colocated[0][0]
        self.policy_actors[policy_id] = new_actor
        return new_actor

    def _add_replay_buffer_and_policy(
        self,
        policy_id: "PolicyID",
        policy_spec: "PolicySpec",
        config: "TrainerConfigDict",
    ):
        """Creates the replay actor and the first policy actor in one go.

        Args:
            policy_id: The ID under which to register the new policy actor.
            policy_spec: Provides the policy class and the observation- and
                action spaces for the new policy.
            config: The already merged config to pass to the policy.

        Returns:
            The newly created policy actor.
        """
        assert self.replay_actor is None
        assert len(self.policy_actors) == 0

        colocated = create_colocated_actors(
            actor_specs=[
                (self.replay_actor_class, self.replay_actor_args, {}, 1),
                self._policy_actor_spec(policy_spec, config),
            ],
            # No node constraint for the first actors of this shard.
            node=None,
        )

        # Same order as `actor_specs`: replay actor first, then the policy.
        self.replay_actor = colocated[0][0]
        self.policy_actors[policy_id] = colocated[1][0]
        self.has_replay_buffer = True

        return self.policy_actors[policy_id]

    def _policy_actor_spec(self, policy_spec, config):
        """Builds the `create_colocated_actors` spec for one policy actor.

        Returns:
            A tuple (remote policy class, c'tor args, c'tor kwargs, count).
        """
        # With fake GPUs, no real GPU must be requested from ray.
        num_gpus = 0 if config["_fake_gpus"] else self.num_gpus_per_policy
        remote_policy_cls = ray.remote(num_cpus=1, num_gpus=num_gpus)(
            policy_spec.policy_class
        )
        return (
            remote_policy_cls,
            # Policy c'tor args.
            (policy_spec.observation_space, policy_spec.action_space, config),
            # Policy c'tor kwargs={}.
            {},
            # Count=1,
            1,
        )

    def stop(self):
        """Terminates all ray actors (replay and n policy actors)."""
        # Terminate the replay actor. A shard that never received a policy
        # has none, in which case there is nothing to terminate.
        if self.replay_actor is not None:
            self.replay_actor.__ray_terminate__.remote()

        # Terminate all policy actors.
        for policy_actor in self.policy_actors.values():
            policy_actor.__ray_terminate__.remote()

    def __len__(self):
        """Returns the number of Policy actors in this shard."""
        return len(self.policy_actors)