import gym
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
    TYPE_CHECKING, Union

from src.rllib.models.catalog import ModelCatalog
from src.rllib.models.jax.jax_modelv2 import JAXModelV2
from src.rllib.models.modelv2 import ModelV2
from src.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from src.rllib.models.torch.torch_modelv2 import TorchModelV2
from src.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.policy.torch_policy import TorchPolicy
from src.rllib.utils import add_mixins, force_list, NullContextManager
from src.rllib.utils.annotations import override, DeveloperAPI
from src.rllib.utils.framework import try_import_torch, try_import_jax
from src.rllib.utils.torch_ops import convert_to_non_torch_type
from src.rllib.utils.typing import ModelGradients, TensorType, \
    TrainerConfigDict

if TYPE_CHECKING:
    from src.rllib.evaluation import MultiAgentEpisode  # noqa

jax, _ = try_import_jax()
torch, _ = try_import_torch()


# TODO: (sven) Unify this with `build_tf_policy` as well.
@DeveloperAPI
def build_policy_class(
        name: str,
        framework: str,
        *,
        loss_fn: Optional[Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]]],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], Optional[
                "MultiAgentEpisode"]
        ], SampleBatch]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        _after_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, Type[TorchDistributionWrapper]]]] = None,
        compute_gradients_fn: Optional[Callable[[Policy, SampleBatch], Tuple[
            ModelGradients, dict]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
    """Helper function for creating a new Policy class at runtime.

    Supports frameworks JAX and PyTorch.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        framework (str): Either "jax" or "torch".
        loss_fn (Optional[Callable[[Policy, ModelV2,
            Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
            List[TensorType]]]]): Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[Any, SampleBatch]], Optional["MultiAgentEpisode"]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and training batch. If None,
            will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
            used for logging (e.g. in TensorBoard).
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): DEPRECATED: Use `before_loss_init`
            instead. Only used if `before_loss_init` is not provided.
        _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run after the loss init. If None, this step will be skipped.
            This will be deprecated at some point and renamed into `after_init`
            to match `build_tf_policy()` behavior.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType,
            Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
            that takes the Policy, Model, the observation batch, an
            explore-flag, a timestep, and an is_training flag and returns a
            tuple of a) distribution inputs (parameters), b) a dist-class to
            generate an action distribution object from, and c) internal-state
            outputs (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a tuple of model instance and torch action distribution
            class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        compute_gradients_fn (Optional[Callable[
            [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
            callable that takes the sampled batch and computes the gradients
            w.r.t. the loss function.
            If None, will call the `TorchPolicy.compute_gradients()` method
            instead.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            It is called with the policy and the `gradients` argument that
            was passed to the policy's `apply_gradients()` method.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        Type[TorchPolicy]: TorchPolicy child class constructed from the
            specified args.
    """

    # NOTE: Must stay the very first statement: at this point `locals()`
    # holds exactly the arguments of this call, which `with_updates()`
    # re-uses to build a modified copy of the class.
    original_kwargs = locals().copy()
    parent_cls = TorchPolicy
    base = add_mixins(parent_cls, mixins)

    class policy_cls(base):
        """Policy class assembled from the callables given to the builder."""

        def __init__(self, obs_space, action_space, config):
            """Builds the model, inits the parent Policy and the loss.

            Args:
                obs_space: The observation space of the policy.
                action_space: The action space of the policy.
                config: The (partial) config dict. Merged on top of
                    `get_default_config()` if that callable was provided.
            """
            # Set up the config from possible default-config fn and given
            # config arg.
            if get_default_config:
                config = dict(get_default_config(), **config)
            self.config = config

            # Set the DL framework for this Policy.
            self.framework = self.config["framework"] = framework

            # Validate observation- and action-spaces.
            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            # Do some pre-initialization steps.
            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework=framework)
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework=framework)

            # Make sure, we passed in a correct Model factory. The message
            # names the class that is actually checked against (Torch or
            # JAX), not always the torch one.
            model_cls = TorchModelV2 if framework == "torch" else JAXModelV2
            assert isinstance(self.model, model_cls), \
                "ERROR: Generated Model must be a {} object!".format(
                    model_cls.__name__)

            # Call the framework-specific Policy constructor.
            # `self.parent_cls` is kept as an attribute for outside readers;
            # the methods of this class use the closure variable directly.
            self.parent_cls = parent_cls
            parent_cls.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            # Merge Model's view requirements into Policy's.
            self.view_requirements.update(self.model.view_requirements)

            # `after_init` is the deprecated name for `before_loss_init`.
            _before_loss_init = before_loss_init or after_init
            if _before_loss_init:
                _before_loss_init(self, self.observation_space,
                                  self.action_space, config)

            # Perform test runs through postprocessing- and loss functions.
            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )

            if _after_loss_init:
                _after_loss_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            """Runs the super's postprocessing, then `postprocess_fn`."""
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak
            # in torch (issue #6962).
            with self._no_grad_context():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    sample_batch, other_agent_batches, episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)

                return sample_batch

        @override(parent_cls)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            else:
                return parent_cls.extra_grad_process(self, optimizer, loss)

        @override(parent_cls)
        def extra_compute_grad_fetches(self):
            """Returns `extra_learn_fetches_fn`'s output or the parent's."""
            if extra_learn_fetches_fn:
                # Use the framework-aware converter (as `extra_action_out`
                # and `extra_grad_info` do), so that non-torch frameworks do
                # not go through the torch-only conversion.
                fetches = self._convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            else:
                return parent_cls.extra_compute_grad_fetches(self)

        @override(parent_cls)
        def compute_gradients(self, batch):
            """Uses `compute_gradients_fn` if given, else the parent's."""
            if compute_gradients_fn:
                return compute_gradients_fn(self, batch)
            else:
                return parent_cls.compute_gradients(self, batch)

        @override(parent_cls)
        def apply_gradients(self, gradients):
            """Uses `apply_gradients_fn` if given, else the parent's."""
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                parent_cls.apply_gradients(self, gradients)

        @override(parent_cls)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            """Returns extra action outputs, converted for the framework."""
            with self._no_grad_context():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = parent_cls.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return self._convert_to_non_torch_type(stats_dict)

        @override(parent_cls)
        def optimizer(self):
            """Returns the list of optimizers to use for this policy."""
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = parent_cls.optimizer(self)
            # Always work with a list from here on.
            optimizers = force_list(optimizers)
            # Let the exploration component (if any) process the optimizers.
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            return optimizers

        @override(parent_cls)
        def extra_grad_info(self, train_batch):
            """Returns the stats dict from `stats_fn` or from the parent."""
            with self._no_grad_context():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    # Use the closure `parent_cls` like all other overrides,
                    # rather than the instance attribute set in `__init__`.
                    stats_dict = parent_cls.extra_grad_info(
                        self, train_batch)
                return self._convert_to_non_torch_type(stats_dict)

        def _no_grad_context(self):
            """Returns `torch.no_grad()` for torch, else a no-op context."""
            if self.framework == "torch":
                return torch.no_grad()
            return NullContextManager()

        def _convert_to_non_torch_type(self, data):
            """Converts `data` for torch; returns it unchanged otherwise."""
            if self.framework == "torch":
                return convert_to_non_torch_type(data)
            return data

    def with_updates(**overrides):
        """Creates a Torch|JAXPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_policy_class`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new Torch|JAXPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        return build_policy_class(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
