# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from warnings import warn

import torch
from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors.collectors import DataCollectorBase
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import ExplorationType
from torchrl.modules import reset_noise
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.record.loggers import Logger
from torchrl.trainers.trainers import (
    BatchSubSampler,
    ClearCudaCache,
    CountFramesLog,
    LogScalar,
    LogValidationReward,
    ReplayBufferTrainer,
    RewardNormalizer,
    SelectKeys,
    Trainer,
    UpdateWeights,
)

# Registry mapping the ``TrainerConfig.optimizer`` string to a torch optimizer class.
OPTIMIZERS = dict(
    adam=optim.Adam,
    sgd=optim.SGD,
    adamax=optim.Adamax,
)


@dataclass
class TrainerConfig:
    """Trainer config struct."""

    optim_steps_per_batch: int = 500
    # Number of optimization steps in between two collection of data. See frames_per_batch below.
    optimizer: str = "adam"
    # Optimizer to be used.
    lr_scheduler: str = "cosine"
    # LR scheduler.
    selected_keys: list | None = None
    # a list of strings that indicate the data that should be kept from the data collector. Since storing and
    # retrieving information from the replay buffer does not come for free, limiting the amount of data
    # passed to it can improve the algorithm performance.
    batch_size: int = 256
    # batch size of the TensorDict retrieved from the replay buffer. Default=256.
    log_interval: int = 10000
    # logging interval, in terms of optimization steps. Default=10000.
    lr: float = 3e-4
    # Learning rate used for the optimizer. Default=3e-4.
    weight_decay: float = 0.0
    # Weight-decay to be used with the optimizer. Default=0.0.
    clip_norm: float = 1000.0
    # value at which the total gradient norm / single derivative should be clipped. Default=1000.0
    clip_grad_norm: bool = False
    # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold.
    normalize_rewards_online: bool = False
    # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
    normalize_rewards_online_scale: float = 1.0
    # Final scale of the normalized rewards.
    normalize_rewards_online_decay: float = 0.9999
    # Decay of the reward moving averaging
    sub_traj_len: int = -1
    # length of the trajectories that sub-samples must have in online settings.


def make_trainer(
    collector: DataCollectorBase,
    loss_module: LossModule,
    recorder: EnvBase | None = None,
    target_net_updater: TargetNetUpdater | None = None,
    policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
    replay_buffer: ReplayBuffer | None = None,
    logger: Logger | None = None,
    cfg: DictConfig = None,  # noqa: F821
) -> Trainer:
    """Creates a Trainer instance given its constituents.

    Args:
        collector (DataCollectorBase): A data collector to be used to collect data.
        loss_module (LossModule): A TorchRL loss module
        recorder (EnvBase, optional): a recorder environment. If None, the trainer will train the policy without
            testing it.
        target_net_updater (TargetNetUpdater, optional): A target network update object.
        policy_exploration (TDModule or TensorDictModuleWrapper, optional): a policy to be used for recording and exploration
            updates (should be synced with the learnt policy).
        replay_buffer (ReplayBuffer, optional): a replay buffer to be used to collect data.
        logger (Logger, optional): a Logger to be used for logging.
        cfg (DictConfig, optional): a DictConfig containing the arguments of the script. If None, a
            :class:`TrainerConfig` with debug defaults is used. Besides the :class:`TrainerConfig` fields,
            ``cfg`` must provide ``frame_skip``, ``total_frames`` and ``frames_per_batch``, plus
            ``record_frames`` and ``record_interval`` when a ``recorder`` is given. ``recorder_log_keys``
            and ``noisy`` are optional.

    Returns:
        A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided.

    Examples:
        >>> import tempfile
        >>> import torch
        >>> from tensordict.nn import TensorDictSequential
        >>> from torchrl.collectors.collectors import SyncDataCollector
        >>> from torchrl.data import TensorDictReplayBuffer
        >>> from torchrl.envs import EnvCreator
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.modules import EGreedyModule, SafeModule, ValueOperator
        >>> from torchrl.objectives import DDPGLoss
        >>> from torchrl.record.loggers import TensorboardLogger
        >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v1"))
        >>> env_proof = env_maker()
        >>> obs_dim = env_proof.observation_spec["observation"].shape[-1]
        >>> action_spec = env_proof.action_spec
        >>> net = torch.nn.Linear(obs_dim, action_spec.shape[-1])
        >>> net_value = torch.nn.Linear(obs_dim, 1)  # for the purpose of testing
        >>> policy = SafeModule(net, in_keys=["observation"], out_keys=["action"], spec=action_spec)
        >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
        >>> collector = SyncDataCollector(env_maker, policy, frames_per_batch=10, total_frames=100)
        >>> loss_module = DDPGLoss(policy, value)
        >>> recorder = env_proof
        >>> target_net_updater = None
        >>> policy_exploration = TensorDictSequential(policy, EGreedyModule(action_spec))
        >>> replay_buffer = TensorDictReplayBuffer()
        >>> logger = TensorboardLogger(exp_name=tempfile.gettempdir())
        >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
        ...    replay_buffer, logger)
        >>> print(trainer)

    """
    if cfg is None:
        warn(
            "Getting default cfg for the trainer. "
            "This should be only used for debugging."
        )
        cfg = TrainerConfig()
        cfg.frame_skip = 1
        cfg.total_frames = 1000
        # frames_per_batch is not a TrainerConfig field, but it is read below
        # (cosine scheduler horizon, exploration annealing); without it the
        # default cfg raised AttributeError with the default "cosine" scheduler.
        cfg.frames_per_batch = 100
        cfg.record_frames = 10
        cfg.record_interval = 10

    # Adam is built with betas=(0.0, 0.9) instead of PyTorch's default betas.
    optimizer_kwargs = {} if cfg.optimizer != "adam" else {"betas": (0.0, 0.9)}
    optimizer = OPTIMIZERS[cfg.optimizer](
        loss_module.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        **optimizer_kwargs,
    )
    device = next(loss_module.parameters()).device
    if cfg.lr_scheduler == "cosine":
        # Total number of optimization steps over training.
        t_max = int(
            cfg.total_frames / cfg.frames_per_batch * cfg.optim_steps_per_batch
        )
        # CosineAnnealingLR uses T_max as a divisor when stepping, so a horizon
        # that rounds down to 0 would crash on the first optimization step.
        optim_scheduler = CosineAnnealingLR(optimizer, T_max=max(t_max, 1))
    elif cfg.lr_scheduler == "":
        optim_scheduler = None
    else:
        raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}")

    if VERBOSE:
        torchrl_logger.info(
            f"collector = {collector}; \n"
            f"loss_module = {loss_module}; \n"
            f"recorder = {recorder}; \n"
            f"target_net_updater = {target_net_updater}; \n"
            f"policy_exploration = {policy_exploration}; \n"
            f"replay_buffer = {replay_buffer}; \n"
            f"logger = {logger}; \n"
            f"cfg = {cfg}; \n"
        )

    if logger is not None:
        # log hyperparams
        logger.log_hparams(cfg)

    trainer = Trainer(
        collector=collector,
        frame_skip=cfg.frame_skip,
        total_frames=cfg.total_frames * cfg.frame_skip,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        optim_steps_per_batch=cfg.optim_steps_per_batch,
        clip_grad_norm=cfg.clip_grad_norm,
        clip_norm=cfg.clip_norm,
    )

    if torch.cuda.device_count() > 0:
        trainer.register_op("pre_optim_steps", ClearCudaCache(1))

    # "noisy" is an optional cfg entry.
    if getattr(cfg, "noisy", False):
        trainer.register_op("pre_optim_steps", lambda: loss_module.apply(reset_noise))

    if cfg.selected_keys:
        trainer.register_op("batch_process", SelectKeys(cfg.selected_keys))
    trainer.register_op("batch_process", lambda batch: batch.cpu())

    if replay_buffer is not None:
        # replay buffer is used 2 or 3 times: to register data, to sample
        # data and to update priorities
        rb_trainer = ReplayBufferTrainer(
            replay_buffer,
            cfg.batch_size,
            flatten_tensordicts=False,
            memmap=False,
            device=device,
        )

        trainer.register_op("batch_process", rb_trainer.extend)
        trainer.register_op("process_optim_batch", rb_trainer.sample)
        trainer.register_op("post_loss", rb_trainer.update_priority)
    else:
        # Without a replay buffer, optimization batches are sub-sampled
        # directly from the collected batch and moved to the loss device.
        trainer.register_op(
            "process_optim_batch",
            BatchSubSampler(batch_size=cfg.batch_size, sub_traj_len=cfg.sub_traj_len),
        )
        trainer.register_op("process_optim_batch", lambda batch: batch.to(device))

    if optim_scheduler is not None:
        trainer.register_op("post_optim", optim_scheduler.step)

    if target_net_updater is not None:
        trainer.register_op("post_optim", target_net_updater.step)

    if cfg.normalize_rewards_online:
        # if used the running statistics of the rewards are computed and the
        # rewards used for training will be normalized based on these.
        reward_normalizer = RewardNormalizer(
            scale=cfg.normalize_rewards_online_scale,
            decay=cfg.normalize_rewards_online_decay,
        )
        trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
        trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)

    if policy_exploration is not None and hasattr(policy_exploration, "step"):
        trainer.register_op(
            "post_steps", policy_exploration.step, frames=cfg.frames_per_batch
        )

    # Log the current learning rate; positional args passed by the hook are
    # ignored (previously they were named ``cfg``, shadowing the config).
    trainer.register_op(
        "post_steps_log", lambda *_: {"lr": optimizer.param_groups[0]["lr"]}
    )

    if recorder is not None:
        # recorder_log_keys is optional: when cfg lacks it (e.g. the debug
        # default cfg), LogValidationReward's own default is used.
        log_keys_kwargs = (
            {"log_keys": cfg.recorder_log_keys}
            if hasattr(cfg, "recorder_log_keys")
            else {}
        )
        # create recorder object
        recorder_obj = LogValidationReward(
            record_frames=cfg.record_frames,
            frame_skip=cfg.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.record_interval,
            **log_keys_kwargs,
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj,
        )
        # call recorder - could be removed
        recorder_obj(None)
        # create explorative recorder - could be optional
        recorder_obj_explore = LogValidationReward(
            record_frames=cfg.record_frames,
            frame_skip=cfg.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.record_interval,
            exploration_type=ExplorationType.RANDOM,
            suffix="exploration",
            out_keys={("next", "reward"): "r_evaluation_exploration"},
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj_explore,
        )
        # call recorder - could be removed
        recorder_obj_explore(None)

    trainer.register_op(
        "post_steps", UpdateWeights(collector, update_weights_interval=1)
    )

    trainer.register_op("pre_steps_log", LogScalar())
    trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))

    return trainer
