# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import itertools
from dataclasses import dataclass

import torch
from tensordict import set_lazy_legacy
from tensordict.nn import InteractionType
from torch import nn
from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    NoisyLinear,
    SafeModule,
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
    SafeSequential,
)
from torchrl.modules.distributions import (
    Delta,
    OneHotCategorical,
    TanhDelta,
    TanhNormal,
)
from torchrl.modules.models.model_based import (
    DreamerActor,
    ObsDecoder,
    ObsEncoder,
    RSSMPosterior,
    RSSMPrior,
    RSSMRollout,
)
from torchrl.modules.models.models import DuelingCnnDQNet, DuelingMlpDQNet, MLP
from torchrl.modules.tensordict_module import (
    Actor,
    DistributionalQValueActor,
    QValueActor,
)
from torchrl.modules.tensordict_module.world_models import WorldModelWrapper
from torchrl.trainers.helpers import transformed_env_constructor

# Registry of policy distribution classes, keyed by name. The underscore
# spellings ("tanh_normal", "tanh_delta") are aliases of the hyphenated keys,
# so that names like the ``distribution: str = "tanh_normal"`` default used by
# the model config dataclasses of this module resolve as well.
DISTRIBUTIONS = {
    "delta": Delta,
    "tanh-normal": TanhNormal,
    "categorical": OneHotCategorical,
    "tanh-delta": TanhDelta,
    # underscore aliases (backward-compatible additions)
    "tanh_normal": TanhNormal,
    "tanh_delta": TanhDelta,
}

# Lookup table from activation name to the corresponding ``torch.nn`` class.
ACTIVATIONS = dict(
    elu=nn.ELU,
    tanh=nn.Tanh,
    relu=nn.ReLU,
)


def make_dqn_actor(
    proof_environment: EnvBase, cfg: DictConfig, device: torch.device  # noqa: F821
) -> Actor:
    """DQN constructor helper function.

    Builds a dueling Q-network (:class:`DuelingCnnDQNet` when ``cfg.from_pixels``
    is set, :class:`DuelingMlpDQNet` otherwise) and wraps it in a
    :class:`QValueActor`. When ``cfg.distributional`` is set, it uses a
    :class:`DistributionalQValueActor` instead, with ``cfg.atoms`` atoms
    spanning ``[-3, 3]``. The actor is cast to ``device``. One forward pass is
    then run under ``torch.no_grad()`` on a fake tensordict of
    ``proof_environment`` to initialize it.

    Args:
        proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec.
        cfg (DictConfig): contains arguments of the DQN script. The fields read
            here are ``from_pixels``, ``noisy``, ``distributional`` and ``atoms``.
        device (torch.device): device on which the model must be cast

    Returns:
         A DQN policy operator.

    Raises:
        ValueError: if the ``"action"`` spec of ``proof_environment`` is not
            discrete. Also raised if ``cfg.from_pixels`` is False and the
            observation spec has no entry from which to infer the network
            input key.
        RuntimeError: if ``cfg.distributional`` is set and ``cfg.atoms`` is
            not a positive integer.

    Examples:
        >>> import dataclasses
        >>> import torch
        >>> from hydra import compose, initialize
        >>> from hydra.core.config_store import ConfigStore
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms import ToTensorImage, TransformedEnv
        >>> from torchrl.trainers.helpers.envs import EnvConfig
        >>> from torchrl.trainers.helpers.models import make_dqn_actor, DiscreteModelConfig
        >>> proof_environment = TransformedEnv(GymEnv("ALE/Pong-v5",
        ...    pixels_only=True), ToTensorImage())
        >>> device = torch.device("cpu")
        >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in
        ...                    (DiscreteModelConfig, EnvConfig)
        ...                   for config_field in dataclasses.fields(config_cls)]
        >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
        >>> cs = ConfigStore.instance()
        >>> cs.store(name="config", node=Config)
        >>> with initialize(config_path=None):
        ...     cfg = compose(config_name="config")
        >>> actor = make_dqn_actor(proof_environment, cfg, device)
        >>> td = proof_environment.reset()
        >>> print(actor(td))
        TensorDict(
            fields={
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                pixels: Tensor(torch.Size([3, 210, 160]), dtype=torch.float32),
                action: Tensor(torch.Size([6]), dtype=torch.int64),
                action_value: Tensor(torch.Size([6]), dtype=torch.float32),
                chosen_action_value: Tensor(torch.Size([1]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)


    """
    env_specs = proof_environment.specs

    atoms = cfg.atoms if cfg.distributional else None
    # NoisyLinear replaces the regular linear layers when cfg.noisy is set.
    linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear

    action_spec = env_specs["input_spec", "full_action_spec", "action"]
    if action_spec.domain != "discrete":
        raise ValueError(
            f"env {proof_environment} has an action domain "
            f"{action_spec.domain} which is incompatible with "
            f"DQN. Make sure your environment has a discrete "
            f"domain."
        )

    if cfg.from_pixels:
        net_class = DuelingCnnDQNet
        default_net_kwargs = {
            "cnn_kwargs": {
                "bias_last_layer": True,
                "depth": None,
                "num_cells": [32, 64, 64],
                "kernel_sizes": [8, 4, 3],
                "strides": [4, 2, 1],
            },
            "mlp_kwargs": {"num_cells": 512, "layer_class": linear_layer_class},
        }
        in_key = "pixels"

    else:
        net_class = DuelingMlpDQNet
        default_net_kwargs = {
            "mlp_kwargs_feature": {},  # see class for details
            "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class},
        }
        # Automatically infer the input key: use the first entry of the
        # observation spec. An empty spec would otherwise surface as a bare
        # "not enough values to unpack" error.
        try:
            (in_key,) = itertools.islice(
                env_specs["output_spec", "full_observation_spec"], 1
            )
        except ValueError as err:
            raise ValueError(
                f"env {proof_environment} has an empty observation spec: "
                f"cannot infer the input key of the DQN network."
            ) from err

    actor_class = QValueActor
    actor_kwargs = {}

    if isinstance(action_spec, Categorical):
        # If the action spec is modeled as a categorical variable, we still need
        # one output feature per possible choice, and the actor must be told to
        # use categorical behavior.
        actor_kwargs["action_space"] = "categorical"
        out_features = action_spec.space.n
    else:
        out_features = action_spec.shape[0]

    if cfg.distributional:
        # `not atoms` alone would let negative values reach torch.linspace.
        if atoms is None or atoms <= 0:
            raise RuntimeError(f"Expected atoms to be a positive integer, got {atoms}")
        vmin = -3
        vmax = 3

        # One output per (atom, action) pair; the network's value output
        # ("out_features_value") also gets one entry per atom.
        out_features = (atoms, out_features)
        support = torch.linspace(vmin, vmax, atoms)
        actor_class = DistributionalQValueActor
        actor_kwargs["support"] = support
        default_net_kwargs["out_features_value"] = (atoms, 1)

    net = net_class(
        out_features=out_features,
        **default_net_kwargs,
    )

    model = actor_class(
        module=net,
        spec=Composite(action=action_spec),
        in_keys=[in_key],
        safe=True,
        **actor_kwargs,
    ).to(device)

    # init: a dry forward pass on a fake (unsqueezed) tensordict
    with torch.no_grad():
        td = proof_environment.fake_tensordict()
        td = td.unsqueeze(-1)
        model(td.to(device))
    return model


@set_lazy_legacy(False)
def make_dreamer(
    cfg: DictConfig,  # noqa: F821
    proof_environment: EnvBase | None = None,
    device: DEVICE_TYPING = "cpu",
    action_key: str = "action",
    value_key: str = "state_value",
    use_decoder_in_env: bool = False,
    obs_norm_state_dict=None,
) -> tuple[
    WorldModelWrapper,
    TransformedEnv,
    SafeProbabilisticTensorDictSequential,
    SafeModule,
    SafeSequential,
]:
    """Create Dreamer components.

    Args:
        cfg (DictConfig): Config object. The fields read here are
            ``rssm_hidden_dim``, ``state_dim`` and ``mlp_num_units``. The whole
            config is also forwarded to ``transformed_env_constructor`` when
            ``proof_environment`` is ``None``.
        proof_environment (EnvBase, optional): Environment to initialize the model.
            If ``None``, one is built from ``cfg`` and closed before returning
            (also when construction fails). Defaults to ``None``.
        device (DEVICE_TYPING, optional): Device to use.
            Defaults to "cpu".
        action_key (str, optional): Key to use for the action.
            Defaults to "action".
        value_key (str, optional): Key to use for the value.
            Defaults to "state_value".
        use_decoder_in_env (bool, optional): Whether to use the decoder in the model based dreamer env.
            Defaults to `False`.
        obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform used
            when proof_environment is missing. Defaults to None.

    Returns:
        A 5-tuple ``(world_model, model_based_env, actor_simulator, value_model, actor_realworld)``:

        - world_model (WorldModelWrapper): Dreamer world model.
        - model_based_env (TransformedEnv): Dreamer model-based environment.
        - actor_simulator (SafeProbabilisticTensorDictSequential): Dreamer actor
          in the world-model space.
        - value_model (SafeModule): Dreamer value model.
        - actor_realworld (SafeSequential): Dreamer actor for the real-world
          space (it reads ``"pixels"``).

    """
    proof_env_is_none = proof_environment is None
    if proof_env_is_none:
        proof_environment = transformed_env_constructor(
            cfg=cfg, use_env_creator=False, obs_norm_state_dict=obs_norm_state_dict
        )()

    try:
        # Modules
        obs_encoder = ObsEncoder()
        obs_decoder = ObsDecoder()

        rssm_prior = RSSMPrior(
            hidden_dim=cfg.rssm_hidden_dim,
            rnn_hidden_dim=cfg.rssm_hidden_dim,
            state_dim=cfg.state_dim,
            action_spec=proof_environment.action_spec,
        )
        rssm_posterior = RSSMPosterior(
            hidden_dim=cfg.rssm_hidden_dim, state_dim=cfg.state_dim
        )
        reward_module = MLP(
            out_features=1, depth=2, num_cells=cfg.mlp_num_units, activation_class=nn.ELU
        )

        world_model = _dreamer_make_world_model(
            obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
        ).to(device)
        # Dry forward pass on a fake tensordict (presumably to materialize any
        # lazily-initialized layers before the model is used).
        with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
            tensordict = proof_environment.fake_tensordict().unsqueeze(-1)
            tensordict = tensordict.to(device)
            world_model(tensordict)

        model_based_env = _dreamer_make_mbenv(
            reward_module,
            rssm_prior,
            obs_decoder,
            proof_environment,
            use_decoder_in_env,
            cfg.state_dim,
            cfg.rssm_hidden_dim,
        )
        model_based_env = model_based_env.to(device)

        actor_simulator, actor_realworld = _dreamer_make_actors(
            obs_encoder,
            rssm_prior,
            rssm_posterior,
            cfg.mlp_num_units,
            action_key,
            proof_environment,
        )
        actor_simulator = actor_simulator.to(device)

        value_model = _dreamer_make_value_model(cfg.mlp_num_units, value_key)
        value_model = value_model.to(device)
        # Same dry run for the simulator actor and value model, on the
        # model-based env's fake tensordict.
        with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
            tensordict = model_based_env.fake_tensordict().unsqueeze(-1)
            tensordict = tensordict.to(device)
            tensordict = actor_simulator(tensordict)
            value_model(tensordict)

        actor_realworld = actor_realworld.to(device)
        del tensordict
    finally:
        # Only tear down the environment this function created itself, and do
        # so even if building the models raised. A caller-provided
        # proof_environment is left untouched.
        if proof_env_is_none:
            proof_environment.close()
            torch.cuda.empty_cache()
            del proof_environment

    return world_model, model_based_env, actor_simulator, value_model, actor_realworld


def _dreamer_make_world_model(
    obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
):
    """Wire the raw Dreamer modules into a :class:`WorldModelWrapper`.

    The transition part does three things in sequence:

    - encodes ``("next", "pixels")``;
    - rolls the RSSM prior/posterior out with :class:`RSSMRollout`;
    - decodes ``("next", "reco_pixels")``.

    The reward part predicts ``("next", "reward")`` from the next state and
    belief.
    """
    # Prior step: (state, belief, action) -> next prior mean/std and next
    # belief. The third output is discarded ("_").
    prior_step = SafeModule(
        rssm_prior,
        in_keys=["state", "belief", "action"],
        out_keys=[
            ("next", "prior_mean"),
            ("next", "prior_std"),
            "_",
            ("next", "belief"),
        ],
    )
    # Posterior step: (next belief, next encoded observation) -> next
    # posterior mean/std and next state.
    posterior_step = SafeModule(
        rssm_posterior,
        in_keys=[("next", "belief"), ("next", "encoded_latents")],
        out_keys=[
            ("next", "posterior_mean"),
            ("next", "posterior_std"),
            ("next", "state"),
        ],
    )
    rollout = RSSMRollout(prior_step, posterior_step)

    encoder = SafeModule(
        obs_encoder,
        in_keys=[("next", "pixels")],
        out_keys=[("next", "encoded_latents")],
    )
    decoder = SafeModule(
        obs_decoder,
        in_keys=[("next", "state"), ("next", "belief")],
        out_keys=[("next", "reco_pixels")],
    )
    transition_model = SafeSequential(encoder, rollout, decoder)

    reward_model = SafeModule(
        reward_module,
        in_keys=[("next", "state"), ("next", "belief")],
        out_keys=[("next", "reward")],
    )
    return WorldModelWrapper(transition_model, reward_model)


def _dreamer_make_actors(
    obs_encoder,
    rssm_prior,
    rssm_posterior,
    mlp_num_units,
    action_key,
    proof_environment,
):
    """Build the simulator and real-world Dreamer actors.

    Both actors wrap the very same :class:`DreamerActor` instance, so they
    share its parameters.

    Returns:
        ``(actor_simulator, actor_realworld)``.
    """
    action_dim = proof_environment.action_spec.shape[0]
    shared_policy = DreamerActor(
        out_features=action_dim,
        depth=3,
        num_cells=mlp_num_units,
        activation_class=nn.ELU,
    )
    # The tuple is evaluated left to right: simulator first, then real-world actor.
    return (
        _dreamer_make_actor_sim(action_key, proof_environment, shared_policy),
        _dreamer_make_actor_real(
            obs_encoder,
            rssm_prior,
            rssm_posterior,
            shared_policy,
            action_key,
            proof_environment,
        ),
    )


def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
    """Build the actor that acts inside the world model.

    ``actor_module`` maps ``("state", "belief")`` to ``("loc", "scale")``.
    A :class:`TanhNormal` head (``tanh_loc=True``, random interaction type by
    default) then produces ``action_key``.
    """
    action_spec = proof_environment.action_spec
    # "loc" and "scale" take the shape and device of the action spec.
    loc_scale_spec = Composite(
        **{
            entry: Unbounded(action_spec.shape, device=action_spec.device)
            for entry in ("loc", "scale")
        }
    )
    policy_net = SafeModule(
        actor_module,
        in_keys=["state", "belief"],
        out_keys=["loc", "scale"],
        spec=loc_scale_spec,
    )
    action_head = SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys=[action_key],
        default_interaction_type=InteractionType.RANDOM,
        distribution_class=TanhNormal,
        distribution_kwargs={"tanh_loc": True},
        spec=Composite(**{action_key: action_spec}),
    )
    return SafeProbabilisticTensorDictSequential(policy_net, action_head)


def _dreamer_make_actor_real(
    obs_encoder, rssm_prior, rssm_posterior, actor_module, action_key, proof_environment
):
    """Build the actor that interacts with the real environment (states ~ posterior).

    The pipeline has four stages:

    1. encode ``"pixels"``;
    2. infer ``"state"`` from the RSSM posterior given ``"belief"`` and the
       encoded latents;
    3. pick ``action_key`` with a deterministic :class:`TanhNormal` head;
    4. run the RSSM prior to write ``("next", "belief")``.

    Our actor differs from the original paper, where the prior and posterior
    are computed first and the actor then acts on them; we found that this
    approach worked better.
    """
    action_spec = proof_environment.action_spec

    encoder = SafeModule(
        obs_encoder,
        in_keys=["pixels"],
        out_keys=["encoded_latents"],
    )
    # Only the third posterior output ("state") is kept.
    posterior = SafeModule(
        rssm_posterior,
        in_keys=["belief", "encoded_latents"],
        out_keys=["_", "_", "state"],
    )
    policy = SafeProbabilisticTensorDictSequential(
        SafeModule(
            actor_module,
            in_keys=["state", "belief"],
            out_keys=["loc", "scale"],
            spec=Composite(
                **{entry: Unbounded(action_spec.shape) for entry in ("loc", "scale")}
            ),
        ),
        SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=[action_key],
            default_interaction_type=InteractionType.DETERMINISTIC,
            distribution_class=TanhNormal,
            distribution_kwargs={"tanh_loc": True},
            spec=Composite(**{action_key: action_spec.to("cpu")}),
        ),
    )
    # The prior is only used to advance the belief; its other outputs
    # (including the prior state) are discarded.
    prior = SafeModule(
        rssm_prior,
        in_keys=["state", "belief", action_key],
        out_keys=["_", "_", "_", ("next", "belief")],
    )
    return SafeSequential(encoder, posterior, policy, prior)


def _dreamer_make_value_model(mlp_num_units, value_key):
    """Build the Dreamer value model.

    It is an :class:`MLP` (``depth=3``, ELU activations, one output) that
    reads ``"state"`` and ``"belief"`` and writes ``value_key``.
    """
    value_net = MLP(
        out_features=1,
        depth=3,
        num_cells=mlp_num_units,
        activation_class=nn.ELU,
    )
    return SafeModule(value_net, in_keys=["state", "belief"], out_keys=[value_key])


def _dreamer_make_mbenv(
    reward_module,
    rssm_prior,
    obs_decoder,
    proof_environment,
    use_decoder_in_env,
    state_dim,
    rssm_hidden_dim,
):
    """Build the model-based Dreamer environment.

    Construction steps:

    - The RSSM prior alone drives transitions, and ``reward_module`` predicts
      ``"reward"``.
    - The optional decoder writes ``("next", "reco_pixels")``.
    - Specs are copied from ``proof_environment``.
    - The env is wrapped in a :class:`TransformedEnv` with a
      :class:`TensorDictPrimer` for ``"state"`` and ``"belief"``.
    """
    # The pixel decoder is only wrapped when requested.
    mb_env_obs_decoder = (
        SafeModule(
            obs_decoder,
            in_keys=[("next", "state"), ("next", "belief")],
            out_keys=[("next", "reco_pixels")],
        )
        if use_decoder_in_env
        else None
    )

    # The prior's third and fourth outputs overwrite "state" and "belief".
    prior_transition = SafeSequential(
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", "action"],
            out_keys=["_", "_", "state", "belief"],
        ),
    )
    reward_model = SafeModule(
        reward_module,
        in_keys=["state", "belief"],
        out_keys=["reward"],
    )
    dreamer_env = DreamerEnv(
        world_model=WorldModelWrapper(prior_transition, reward_model),
        prior_shape=torch.Size([state_dim]),
        belief_shape=torch.Size([rssm_hidden_dim]),
        obs_decoder=mb_env_obs_decoder,
    )
    dreamer_env.set_specs_from_env(proof_environment)

    wrapped_env = TransformedEnv(dreamer_env)
    # Non-random primer filling "state" and "belief" with zeros.
    wrapped_env.append_transform(
        TensorDictPrimer(
            random=False,
            default_value=0,
            state=Unbounded(state_dim),
            belief=Unbounded(rssm_hidden_dim),
        )
    )
    return wrapped_env


@dataclass
class DreamerConfig:
    """Dreamer model config struct.

    In this module, ``state_dim``, ``rssm_hidden_dim`` and ``mlp_num_units``
    are read by ``make_dreamer``. The remaining fields are presumably consumed
    by the Dreamer training script (not visible here).
    """

    batch_length: int = 50
    state_dim: int = 30
    # dimension of the RSSM state (state_dim of RSSMPrior/RSSMPosterior)
    rssm_hidden_dim: int = 200
    # hidden size of the RSSM prior/posterior; also used as the belief dimension
    mlp_num_units: int = 400
    # number of cells of the reward, actor and value MLPs
    grad_clip: int = 100
    world_model_lr: float = 6e-4
    actor_value_lr: float = 8e-5
    imagination_horizon: int = 15
    model_device: str = ""
    exploration: str = "additive_gaussian"
    # One of "additive_gaussian", "ou_exploration" or ""


@dataclass
class REDQModelConfig:
    """REDQ model config struct."""

    annealing_frames: int = 1000000
    # number of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network.
    ou_exploration: bool = False
    # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
    # efficient entropy-based exploration, this should be left for experimentation only.
    ou_sigma: float = 0.2
    # Ornstein-Uhlenbeck sigma
    ou_theta: float = 0.15
    # Ornstein-Uhlenbeck theta
    distributional: bool = False
    # whether a distributional loss should be used (TODO: not implemented yet).
    atoms: int = 51
    # number of atoms used for the distributional loss (TODO)
    gSDE: bool = False
    # if True, exploration is achieved using the gSDE technique.
    tanh_loc: bool = False
    # if True, uses a Tanh-Normal transform for the policy location of the form
    # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
    default_policy_scale: float = 1.0
    # Default policy scale parameter
    distribution: str = "tanh_normal"
    # name of the policy distribution (a string, not a flag); the default
    # presumably selects a Tanh-Normal distribution. NOTE(review): the
    # DISTRIBUTIONS table of this module spells it "tanh-normal" (hyphen);
    # confirm how consumers resolve this name.
    actor_cells: int = 256
    # cells of the actor
    qvalue_cells: int = 256
    # cells of the qvalue net
    scale_lb: float = 0.1
    # min value of scale
    value_cells: int = 256
    # cells of the value net
    activation: str = "tanh"
    # activation function, either relu or elu or tanh, Default=tanh


@dataclass
class ContinuousModelConfig:
    """Continuous control model config struct."""

    annealing_frames: int = 1000000
    # number of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network.
    ou_exploration: bool = False
    # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
    # efficient entropy-based exploration, this should be left for experimentation only.
    ou_sigma: float = 0.2
    # Ornstein-Uhlenbeck sigma
    ou_theta: float = 0.15
    # Ornstein-Uhlenbeck theta
    distributional: bool = False
    # whether a distributional loss should be used (TODO: not implemented yet).
    atoms: int = 51
    # number of atoms used for the distributional loss (TODO)
    gSDE: bool = False
    # if True, exploration is achieved using the gSDE technique.
    tanh_loc: bool = False
    # if True, uses a Tanh-Normal transform for the policy location of the form
    # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
    default_policy_scale: float = 1.0
    # Default policy scale parameter
    distribution: str = "tanh_normal"
    # name of the policy distribution (a string, not a flag); the default
    # presumably selects a Tanh-Normal distribution. NOTE(review): the
    # DISTRIBUTIONS table of this module spells it "tanh-normal" (hyphen);
    # confirm how consumers resolve this name.
    lstm: bool = False
    # if True, uses an LSTM for the policy.
    shared_mapping: bool = False
    # if True, the first layers of the actor-critic are shared.
    actor_cells: int = 256
    # cells of the actor
    qvalue_cells: int = 256
    # cells of the qvalue net
    scale_lb: float = 0.1
    # min value of scale
    value_cells: int = 256
    # cells of the value net
    activation: str = "tanh"
    # activation function, either relu or elu or tanh, Default=tanh


@dataclass
class DiscreteModelConfig:
    """Discrete model config struct.

    ``noisy``, ``distributional`` and ``atoms`` are read by ``make_dqn_actor``
    in this module.
    """

    annealing_frames: int = 1000000
    # Number of frames used for annealing of the EGreedy exploration. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network
    distributional: bool = False
    # whether a distributional loss should be used.
    atoms: int = 51
    # number of atoms used for the distributional loss (must be a positive integer
    # when distributional is True)
