from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torchrl.modules as trlm
from tensordict import TensorDictBase
from tensordict.nn import InteractionType, TensorDictModule
from torch.optim.optimizer import Optimizer
from torchrl.data import TensorSpec
from torchrl.objectives import SACLoss

from src.async_rl.module.trainer import RLTrainer, RLTrainerHook


def build_sac_module(
	n_obs: int,
	action_spec: TensorSpec,
	gamma: float,
	hidden_size: int = 256,
	num_layers: int = 2,
) -> Tuple[SACLoss, TensorDictModule]:
	"""Builds a module for reinforcement learning with SAC loss.

	The actor is an MLP producing ``2 * n_act`` outputs, which are split
	into the ``loc`` and ``scale`` of a ``TanhNormal`` distribution bounded
	by the action spec. The critic is an MLP over the concatenated action
	and observation. Both are wrapped in a ``SACLoss`` with two Q-value
	networks and delayed (target) Q-value parameters.

	Args:
		n_obs (int): Dimension of the observation space.
		action_spec (TensorSpec): Specification of the action space. Its
			last dimension is the action size, and its ``space.low`` /
			``space.high`` bound the sampled actions.
		gamma (float): Discount factor.
		hidden_size (int, optional): Number of units in the hidden
			layers. Defaults to `256`.
		num_layers (int, optional): Number of hidden layers. Defaults to
			`2`.

	Returns:
		Tuple[SACLoss, TensorDictModule]: Tuple containing the SAC loss
			and the actor module.
	"""
	n_act = action_spec.shape[-1]

	# Actor: observation -> (loc, scale), one pair per action dimension.
	actor_nn = trlm.MLP(
		num_cells=[hidden_size] * num_layers,
		activation_class=nn.ReLU,
		out_features=2 * n_act,
		in_features=n_obs,
	)
	# Splits the 2 * n_act MLP output into loc and a positive scale with
	# a lower bound of 0.1.
	actor_extractor = trlm.distributions.NormalParamExtractor(
		scale_mapping="biased_softplus_1.0",
		scale_lb=0.1,  # type: ignore
	)
	actor_module = trlm.SafeModule(
		nn.Sequential(actor_nn, actor_extractor),
		in_keys=["observation"],
		out_keys=["loc", "scale"],
	)
	actor = trlm.tensordict_module.actors.ProbabilisticActor(
		module=actor_module,
		spec=action_spec,
		in_keys=["loc", "scale"],  # type: ignore
		distribution_class=trlm.distributions.TanhNormal,
		distribution_kwargs={
			"low": action_spec.space.low,  # type: ignore
			"high": action_spec.space.high,  # type: ignore
			# Keyword is ``tanh_loc`` (was misspelled ``tahn_loc``, which
			# TanhNormal does not accept).
			"tanh_loc": False,
		},
		default_interaction_type=InteractionType.RANDOM,
		return_log_prob=False,
	)

	# Critic: Q(action, observation) -> scalar. The MLP input size matches
	# the concatenation of both in_keys.
	qvalue_nn = trlm.MLP(
		num_cells=[hidden_size] * num_layers,
		out_features=1,
		in_features=n_obs + n_act,
		activation_class=nn.ReLU,
	)
	qvalue = trlm.tensordict_module.actors.ValueOperator(
		module=qvalue_nn,
		in_keys=["action", "observation"],
	)
	loss = SACLoss(
		actor_network=actor,
		qvalue_network=qvalue,
		num_qvalue_nets=2,
		loss_function="l2",
		delay_actor=False,
		delay_qvalue=True,
		separate_losses=True,
		alpha_init=1.0,
	)
	loss.make_value_estimator(gamma=gamma)

	return loss, actor


class SACLossHook(RLTrainerHook):
	"""Trainer hook for the SAC algorithm.

	On every ``optim_step`` the hook runs three separate updates, each with
	its own optimizer: first the Q-value networks, then the actor, then the
	entropy coefficient (alpha).

	Args:
		loss (SACLoss): SAC loss module.
		opt_actor (Optimizer): Optimizer for the actor network.
		opt_qvalue (Optimizer): Optimizer for the Q-value network.
		opt_alpha (Optimizer): Optimizer for alpha parameter (automatic
			entropy tuning is assumed).
		gradient_clipping (Optional[float], optional): Maximum gradient
			norm, applied separately to each optimizer parameter group.
			Defaults to `None`, which means no clipping.
		pred_missing_signals (bool, optional): Whether the loss should
			also optimize for the prediction of missing signals in an
			asynchronous environment. When enabled, the mean of
			``batch["missing_signals_loss"]`` is added to the actor loss.
			Defaults to `False`.
	"""

	def __init__(
		self,
		loss: SACLoss,
		opt_actor: Optimizer,
		opt_qvalue: Optimizer,
		opt_alpha: Optimizer,
		gradient_clipping: Optional[float] = None,
		pred_missing_signals: bool = False,
	) -> None:
		self.loss = loss
		self.opt_actor = opt_actor
		self.opt_qvalue = opt_qvalue
		self.opt_alpha = opt_alpha

		self.gradient_clipping = gradient_clipping
		self.pred_missing_signals = pred_missing_signals

	def register(self, trainer: RLTrainer, name: str = "SACLossHook") -> None:
		"""Registers the SAC loss hook to the trainer.

		Registers ``_optimizer_step`` under the ``"optim_step"`` hook, and
		registers the loss and the three optimizers as trainer modules.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				"SACLossHook". Currently not used by this implementation.
		"""
		trainer.register_hook("optim_step", self._optimizer_step)
		modules = {
			"sac_loss": self.loss,
			"qvalue_optimizer": self.opt_qvalue,
			"actor_optimizer": self.opt_actor,
			"alpha_optimizer": self.opt_alpha,
		}
		trainer.register_modules(modules)

	def _backward_and_step(self, optimizer: Optimizer, loss: torch.Tensor) -> None:
		"""Zeroes gradients, back-propagates ``loss`` and steps ``optimizer``.

		When ``self.gradient_clipping`` is set, the gradient norm of each of
		the optimizer's parameter groups is clipped to that value before
		the step.

		Args:
			optimizer (Optimizer): Optimizer owning the parameters to update.
			loss (torch.Tensor): Scalar loss to differentiate.
		"""
		optimizer.zero_grad()
		loss.backward()
		if self.gradient_clipping is not None:
			for param_group in optimizer.param_groups:
				torch.nn.utils.clip_grad_norm_(
					param_group["params"], self.gradient_clipping
				)
		optimizer.step()

	def _optimizer_step(
		self, batch: TensorDictBase
	) -> Tuple[TensorDictBase, Dict[str, torch.Tensor]]:
		"""Runs one SAC update (Q-value, actor, alpha) on ``batch``.

		Args:
			batch (TensorDictBase): Sampled training batch. It must also
				contain ``"missing_signals_loss"`` when
				``pred_missing_signals`` is enabled.

		Returns:
			Tuple[TensorDictBase, Dict[str, torch.Tensor]]: The unchanged
				batch and the detached scalar losses, keyed by
				``"loss_qvalue"``, ``"loss_actor"`` and ``"loss_alpha"``.
		"""
		# 1) Critic update.
		qvalue_loss, _ = self.loss._qvalue_v2_loss(batch)
		qvalue_loss = qvalue_loss.mean()
		self._backward_and_step(self.opt_qvalue, qvalue_loss)

		# 2) Actor update. Its loss is computed after the critic step, so it
		# sees the freshly updated Q-value parameters.
		actor_loss, metadata = self.loss._actor_loss(batch)
		actor_loss = actor_loss.mean()
		if self.pred_missing_signals:
			# Out-of-place add: avoids mutating an autograd output in place.
			actor_loss = actor_loss + torch.mean(batch["missing_signals_loss"])
		self._backward_and_step(self.opt_actor, actor_loss)

		# 3) Entropy-coefficient update, driven by the log-probabilities
		# reported by the actor loss.
		alpha_loss = self.loss._alpha_loss(metadata["log_prob"]).mean()
		self._backward_and_step(self.opt_alpha, alpha_loss)

		# Detach so that callers logging or accumulating these values do not
		# keep the (already back-propagated) autograd graphs alive.
		losses = {
			"loss_qvalue": qvalue_loss.detach(),
			"loss_actor": actor_loss.detach(),
			"loss_alpha": alpha_loss.detach(),
		}
		return batch, losses
