import gym
import copy
import logging
import importlib.util
from types import FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.actor import ActorHandle
from worker.rollout_worker import RolloutWorker
from env.base_env import BaseEnv
from env.env_context import EnvContext
from worker.offline import (
    NoopOutput,
    JsonReader,
    MixedInput,
    JsonWriter,
    ShuffledInput,
    D4RLReader,
)
from policy.policy import Policy, PolicySpec
from utils.annotations import DeveloperAPI
from utils.from_config import from_config
from utils.typing import EnvType, PolicyID, TrainerConfigDict
from ray.tune.registry import registry_contains_input, registry_get_input

# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Generic type variable for the `foreach_*` methods: the per-worker results
# keep the return type of the user-supplied callback.
T = TypeVar("T")


@DeveloperAPI
class WorkerSet:
    """Set of RolloutWorkers with n @ray.remote workers and one local worker.

    Where n may be 0. If n > 0, the local worker is optional (see the
    `local_worker` constructor arg); if n == 0, a local worker always exists.
    """

    def __init__(
        self,
        *,
        env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        policy_class: Optional[Type[Policy]] = None,
        trainer_config: Optional[TrainerConfigDict] = None,
        num_workers: int = 0,
        local_worker: bool = True,
        logdir: Optional[str] = None,
    ):
        """Initializes a WorkerSet instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment. It is only passed to the local worker
                (worker_index=0); remote workers are created without it.
                The env to be checked is the one returned from the env
                creator.
            policy_class: An optional Policy class. The Policy class to use
                for creating the policies of the workers. It is also used to
                fill in every multi-agent PolicySpec whose `policy_class` is
                None. If None, PolicySpecs can (presumably) be generated
                automatically by using the Trainer's default class or via a
                given multi-agent policy config dict.
            trainer_config: Trainer config dict used for all workers. If None
                (or empty), `COMMON_CONFIG` from `trainer.trainer` is used.
                The dict is used as-is for remote workers; the local worker
                gets a deep copy.
            num_workers: Number of remote rollout workers to create. 0 for
                local only.
            local_worker: Whether to create a local (non @ray.remote) worker
                in the returned set as well (default: True). If `num_workers`
                is 0, a local worker is always created.
            logdir: Optional logging directory for workers.
        """

        if not trainer_config:
            # Imported lazily; presumably to avoid an import cycle with the
            # trainer module -- TODO confirm.
            from trainer.trainer import COMMON_CONFIG

            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._logdir = logdir

        # Force a local worker if num_workers == 0 (no remote workers).
        # Otherwise, this WorkerSet would be empty.
        self._local_worker = None
        if num_workers == 0:
            local_worker = True

        self._local_config = copy.deepcopy(trainer_config)

        # Create a number of @ray.remote workers.
        self._remote_workers = []
        self.add_workers(num_workers)

        # Create a local worker, if needed.
        # If num_workers > 0 and we don't have an env on the local worker,
        # get the observation- and action spaces for each policy from
        # the first remote worker (which does have an env).
        if (
            local_worker
            and self._remote_workers
            and not trainer_config.get("create_env_on_driver")
            and (
                not trainer_config.get("observation_space")
                or not trainer_config.get("action_space")
            )
        ):
            remote_spaces = ray.get(
                self.remote_workers()[0].foreach_policy.remote(
                    lambda p, pid: (pid, p.observation_space, p.action_space)
                )
            )
            # Prefer an `original_space` attribute (if the obs space has one)
            # over the space object itself.
            spaces = {
                e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                for e in remote_spaces
            }
            # Try to add the actual env's obs/action spaces.
            try:
                env_spaces = ray.get(
                    self.remote_workers()[0].foreach_env.remote(
                        lambda env: (env.observation_space, env.action_space)
                    )
                )[0]
                spaces["__env__"] = env_spaces
            except Exception:
                # Best effort only (e.g. an empty result raises IndexError);
                # keep going without the "__env__" entry but leave a trace.
                logger.debug(
                    "Could not infer env spaces from remote worker.",
                    exc_info=True,
                )

            logger.info(
                "Inferred observation/action spaces from remote "
                f"worker (local worker has no env): {spaces}"
            )
        else:
            spaces = None

        if local_worker:
            self._local_worker = self._make_worker(
                cls=RolloutWorker,
                env_creator=env_creator,
                validate_env=validate_env,
                policy_cls=self._policy_class,
                worker_index=0,
                num_workers=num_workers,
                config=self._local_config,
                spaces=spaces,
            )

    def local_worker(self) -> RolloutWorker:
        """Returns the local rollout worker (None if this set has none)."""
        return self._local_worker

    def remote_workers(self) -> List[ActorHandle]:
        """Returns a list of remote rollout workers."""
        return self._remote_workers

    def sync_weights(
        self,
        policies: Optional[List[PolicyID]] = None,
        from_worker: Optional[RolloutWorker] = None,
    ) -> None:
        """Syncs model weights from the local worker to all remote workers.

        Args:
            policies: Optional list of PolicyIDs to sync weights for.
                If None (default), sync weights to/from all policies.
            from_worker: Optional RolloutWorker instance to sync from.
                If None (default), sync from this WorkerSet's local worker.

        Raises:
            TypeError: If this WorkerSet has no local worker and no
                `from_worker` was given.
        """
        if self.local_worker() is None and from_worker is None:
            raise TypeError(
                "No `local_worker` in WorkerSet, must provide `from_worker` "
                "arg in `sync_weights()`!"
            )

        # Only sync if we have remote workers or `from_worker` is provided.
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker or self.local_worker()).get_weights(policies)
            # Put the weights into the object store once and share the ref.
            weights_ref = ray.put(weights)
            # Sync to all remote workers in this WorkerSet.
            for to_worker in self.remote_workers():
                to_worker.set_weights.remote(weights_ref)

            # If from_worker is provided, also sync to this WorkerSet's local
            # worker.
            if from_worker is not None and self.local_worker() is not None:
                self.local_worker().set_weights(weights)

    def add_workers(
        self, num_workers: int, config: Optional[TrainerConfigDict] = None
    ) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same WorkerSet to add more
        RolloutWorkers to the set. New workers continue the index numbering
        of the existing remote workers (indices start at 1; 0 is reserved for
        the local worker), and each one is told the new total number of
        remote workers.

        Args:
            num_workers: The number of remote Workers to add to this
                WorkerSet.
            config: Optional config dict to build the new workers from,
                instead of this WorkerSet's remote config. Resource requests
                (CPUs/GPUs/custom resources per worker) are always taken from
                this WorkerSet's remote config.
        """
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        # Offset by the already existing remote workers so that repeated
        # calls never hand out duplicate worker indices (which would also
        # produce duplicate per-worker seeds, see `_make_worker`).
        num_existing = len(self._remote_workers)
        self._remote_workers.extend(
            [
                self._make_worker(
                    cls=cls,
                    env_creator=self._env_creator,
                    validate_env=None,
                    policy_cls=self._policy_class,
                    worker_index=num_existing + i + 1,
                    num_workers=num_existing + num_workers,
                    config=config or self._remote_config,
                )
                for i in range(num_workers)
            ]
        )

    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote workers in this set with the given one.

        Args:
            new_remote_workers: A list of new RolloutWorkers
                (as `ActorHandles`) to use as remote workers.
        """
        self._remote_workers = new_remote_workers

    def stop(self) -> None:
        """Calls `stop` on all rollout workers (including the local one).

        Errors raised while stopping are logged, not re-raised. All remote
        actors are terminated afterwards in any case.
        """
        try:
            # The local worker is optional; without this guard, a missing
            # local worker would raise and skip the graceful `stop()` of all
            # remote workers.
            if self.local_worker() is not None:
                self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Calls the given function with each worker instance as arg.

        Args:
            func: The function to call for each worker (as only arg).

        Returns:
             The list of return values of all calls to `func([worker])`.
                The local worker's result (if any) comes first, followed by
                all remote workers' results.
        """
        local_result = []
        if self.local_worker() is not None:
            local_result = [func(self.local_worker())]
        remote_results = ray.get([w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_worker_with_index(
        self, func: Callable[[RolloutWorker, int], T]
    ) -> List[T]:
        """Calls `func` with each worker instance and worker idx as args.

        The index will be passed as the second arg to the given function.

        Args:
            func: The function to call for each worker and its index
                (as args). The local worker has index 0, all remote workers
                have indices > 0.

        Returns:
             The list of return values of all calls to `func([worker, idx])`.
                The first entry in this list are the results of the local
                worker, followed by all remote workers' results.
        """
        local_result = []
        # Local worker: Index=0.
        if self.local_worker() is not None:
            local_result = [func(self.local_worker(), 0)]
        # Remote workers: Index > 0.
        remote_results = ray.get(
            [w.apply.remote(func, i + 1) for i, w in enumerate(self.remote_workers())]
        )
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Calls `func` with each worker's (policy, PolicyID) tuple.

        Note that in the multi-agent case, each worker may have more than one
        policy.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            The flat list of return values of func over all workers'
                policies. The local worker's results (if any) come first,
                followed by all remote workers' results.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.apply.remote(lambda w: w.foreach_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def trainable_policies(self) -> List[PolicyID]:
        """Returns the list of trainable policy ids.

        Raises:
            NotImplementedError: If this WorkerSet has no local worker.
        """
        if self.local_worker() is not None:
            return self.local_worker().policies_to_train
        else:
            raise NotImplementedError

    @DeveloperAPI
    def foreach_trainable_policy(
        self, func: Callable[[Policy, PolicyID], T]
    ) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
            List[any]: The flat list of return values of all
                `func([trainable policy], [ID])`-calls; the local worker's
                results (if any) come first.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_trainable_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(lambda w: w.foreach_trainable_policy(func))
            )
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment as arg, e.g. a
        gym.Env object.

        Args:
            func: A function - taking an EnvType (normally a gym.Env object)
                as arg. Each worker's `foreach_env` collects its results
                into one list (presumably one entry per sub-environment).

        Returns:
            The list (workers; local worker first, if any) of lists
                (sub environments) of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env.remote(func))
        return local_results + ray.get(ray_gets)

    @DeveloperAPI
    def foreach_env_with_context(
        self, func: Callable[[BaseEnv, EnvContext], List[T]]
    ) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments and env_ctx as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment and the env_context
        as args.

        Args:
            func: A function - taking a BaseEnv object and an EnvContext as
                arg - whose results each worker collects into one list.

        Returns:
            The list (1 item per worker; local worker first, if any) of lists
                (1 item per sub-environment) of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env_with_context(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env_with_context.remote(func))
        return local_results + ray.get(ray_gets)

    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: Callable[[EnvContext], EnvType],
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        config: TrainerConfigDict,
        spaces: Optional[
            Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
        ] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Builds a single (local or remote) rollout worker from `config`.

        Args:
            cls: `RolloutWorker` for a local worker, or the `.remote`
                constructor of its actor class for a remote one.
            env_creator: Function that returns env given env config.
            validate_env: Optional env validation callable (or None).
            policy_cls: Default Policy class; also fills in multi-agent
                PolicySpecs whose `policy_class` is None.
            worker_index: Index of the worker (0 = local worker).
            num_workers: Number of remote workers, passed to the worker.
            config: Config dict the worker is built from. NOTE: its
                `["multiagent"]["policies"]` dict is updated in place.
            spaces: Optional mapping of PolicyID (plus "__env__") to
                (observation space, action space) tuples.

        Returns:
            The created worker: a RolloutWorker or an ActorHandle, depending
            on `cls`.
        """

        def valid_module(class_path):
            # True iff `class_path` looks like "pkg.module.Name" and
            # "pkg.module" can be located by the import system.
            if isinstance(class_path, str) and "." in class_path:
                module_path, _ = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}"
                    )
            return False

        # Build the input creator from config["input"]; branch order matters
        # (e.g. registry names win over the "d4rl" substring check).
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = lambda ioctx: ioctx.default_sampler_input()
        elif isinstance(config["input"], dict):
            input_creator = lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx), config["shuffle_buffer_size"]
            )
        elif isinstance(config["input"], str) and registry_contains_input(
            config["input"]
        ):
            input_creator = registry_get_input(config["input"])
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = lambda ioctx: D4RLReader(env_name, ioctx)
        elif valid_module(config["input"]):
            # NOTE(review): unlike the other ShuffledInput branches, this one
            # does not pass config["shuffle_buffer_size"] -- confirm intended.
            input_creator = lambda ioctx: ShuffledInput(
                from_config(config["input"], ioctx=ioctx)
            )
        else:
            input_creator = lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx), config["shuffle_buffer_size"]
            )

        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = lambda ioctx: NoopOutput()
        elif config["output"] == "logdir":
            output_creator = lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )
        else:
            output_creator = lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )

        # No off-policy input evaluation when sampling from the env itself.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                # NOTE(review): plain lists/tuples pass this check but have no
                # `.policy_class`; presumably they are converted to
                # PolicySpec before reaching here -- TODO confirm.
                assert isinstance(policy_spec, (PolicySpec, list, tuple))
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls
                    )
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Distinct seed per worker (derived from its index).
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None
            else None,
            spaces=spaces,
        )

        return worker
