from typing import Literal, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torchrl.modules as trlm
from tensordict.nn import (
	InteractionType,
	TensorDictModule,
	TensorDictModuleBase,
	TensorDictSequential,
)
from torchrl.data import TensorSpec
from torchrl.objectives import SACLoss

from src.async_rl.module.signals_pred import (
	MissingSignalFillerTensorDict,
	SignalsPredLoss,
)
from src.async_rl.module.utils import Key
from src.async_rl.utils import to_number


class SACLoss4Mamba(SACLoss):
	"""SAC loss variant usable with Mamba networks.

	The stock SAC loss evaluates its Q-value ensemble through `vmap`,
	which Mamba is not compatible with. This subclass swaps that call
	for a plain Python loop over the ensemble members so that Mamba2
	networks can be used.
	"""

	def _make_vmap(self) -> None:
		"""Install a loop-based ensemble evaluator as `_vmap_qnetworkN0`.

		The installed callable runs `self.qvalue_network` once per entry
		of `params_list` (with that entry's parameters swapped in) and
		stacks the results along a new leading dimension.
		"""

		# NOTE(review): `.to_module` is called on every item of
		# `params_list`, so the items are presumably TensorDict-like
		# parameter containers rather than plain `nn.Parameter`s; the
		# annotations below are kept as-is -- confirm.
		def _evaluate_member(inputs: torch.Tensor, member_params):
			# Parameters are only swapped in for the duration of the call.
			with member_params.to_module(self.qvalue_network):
				return self.qvalue_network(inputs)

		def ensemble_qnetwork(
			inputs: torch.Tensor,
			params_list: nn.ParameterList,
			_randomness: None = None,
		) -> torch.Tensor:
			per_member = [
				_evaluate_member(inputs, member_params)
				for member_params in params_list
			]
			return torch.stack(per_member, dim=0)

		self._vmap_qnetworkN0 = ensemble_qnetwork


def _make_rnn_module(
	type: Literal["lstm", "gru", "s4", "mamba2"],
	input_size: int,
	in_key: Optional[Key] = None,
	in_keys: Optional[Sequence[Key]] = None,
	out_key: Optional[Key] = None,
	out_keys: Optional[Sequence[Key]] = None,
) -> TensorDictModuleBase:
	"""Instantiate the recurrent module selected by `type`.

	Every variant is built with a hidden size of 128 and receives the
	key arguments unchanged.

	Args:
		type (Literal["lstm", "gru", "s4", "mamba2"]): Which recurrent
			module to build.
		input_size (int): Input feature size given to the module.
		in_key (Optional[Key], optional): Forwarded as `in_key`.
			Defaults to `None`.
		in_keys (Optional[Sequence[Key]], optional): Forwarded as
			`in_keys`. Defaults to `None`.
		out_key (Optional[Key], optional): Forwarded as `out_key`.
			Defaults to `None`.
		out_keys (Optional[Sequence[Key]], optional): Forwarded as
			`out_keys`. Defaults to `None`.

	Returns:
		TensorDictModuleBase: The freshly built recurrent module.

	Raises:
		ValueError: If `type` is not one of the supported names.
	"""
	# Key routing is identical for all variants, so collect it once.
	key_kwargs = dict(
		in_key=in_key,
		in_keys=in_keys,
		out_key=out_key,
		out_keys=out_keys,
	)

	if type in ("lstm", "gru"):
		# Both TorchRL modules take the exact same configuration.
		torchrl_cls = trlm.LSTMModule if type == "lstm" else trlm.GRUModule
		return torchrl_cls(
			input_size=input_size,
			hidden_size=128,
			num_layers=1,
			python_based=True,
			**key_kwargs,
		)

	if type == "s4":
		# Imported only when requested, as in the other project variant.
		from src.async_rl.module.recurrent import S4Module

		return S4Module(
			input_size=input_size,
			hidden_size=128,
			l_max=4,
			**key_kwargs,
		)

	if type == "mamba2":
		from src.async_rl.module.recurrent import Mamba2Module

		return Mamba2Module(
			input_size=input_size,
			hidden_size=128,
			**key_kwargs,
		)

	raise ValueError(f"Invalid RNN type: {type}")


def _make_pred_signals_modules(
	signals_keys: Sequence[str], signals_dims: Sequence[int]
) -> Tuple[TensorDictModuleBase, TensorDictModuleBase, TensorDictModuleBase]:
	"""Create the three modules used for signal prediction.

	Args:
		signals_keys (Sequence[str]): Names of the signals to predict.
		signals_dims (Sequence[int]): Output size of the prediction head
			of each signal; must have the same length as `signals_keys`.

	Returns:
		Tuple[TensorDictModuleBase, TensorDictModuleBase,
			TensorDictModuleBase]: The missing-signal filler, a
			sequential of linear heads reading `"embed"` and writing
			`("next", "pred_<key>")` for every signal, and the signals
			prediction loss module.

	Raises:
		ValueError: If `signals_keys` and `signals_dims` differ in
			length (raised by `zip(..., strict=True)`).
	"""
	missing_filler = MissingSignalFillerTensorDict(signals_keys, aggregate=True)  # type: ignore

	# One linear head per signal; the input width of 128 is the size of
	# the "embed" entry these heads read from.
	heads = [
		TensorDictModule(
			nn.Linear(128, dim),
			in_keys="embed",
			out_keys=("next", f"pred_{key}"),
		)
		for key, dim in zip(signals_keys, signals_dims, strict=True)
	]
	heads_module = TensorDictSequential(*heads)

	prediction_loss = SignalsPredLoss(signals_keys)  # type: ignore

	return missing_filler, heads_module, prediction_loss


def build_rec_sac_module(
	n_obs: int,
	action_spec: TensorSpec,
	gamma: float,
	dropout: float = 0.0,
	rnn_type: Literal["gru", "lstm", "s4", "mamba2"] = "gru",
	obs_emb_size: int = 32,
	depth: int = 2,
	hidden_size: int = 256,
	min_alpha: Optional[float] = None,
	max_alpha: Optional[float] = None,
	pred_signals: bool = False,
	signals_keys: Optional[Sequence[str]] = None,
	signals_dims: Optional[Sequence[int]] = None,
) -> Tuple[
	SACLoss,
	TensorDictModule,
	Tuple[TensorDictModuleBase, TensorDictModuleBase],
]:
	"""Builds a recurrent SAC module.

	Args:
		n_obs (int): Dimension of the observation space.
		action_spec (TensorSpec): Specification of the action space.
		gamma (float): Discount factor.
		dropout (float, optional): MLPs dropout. Defaults to `0.0`.
		rnn_type (Literal["gru", "lstm", "s4", "mamba2"], optional):
			Type of RNN to use. With `mamba2` the loss is built from
			`SACLoss4Mamba` instead of `SACLoss`. Defaults to `gru`.
		obs_emb_size (int, optional): Size of the observation embedding.
			Defaults to `32`.
		depth (int, optional): Depth of the MLPs. Defaults to `2`.
		hidden_size (int, optional): Size of the hidden layers. Defaults
			to `256`.
		min_alpha (Optional[float], optional): Minimum value of the
			temperature. Defaults to `None`.
		max_alpha (Optional[float], optional): Maximum value of the
			temperature. Defaults to `None`.
		pred_signals (bool, optional): Whether to have the RNN predict
			signals to then fill the missing observations for the next
			time step. Defaults to `False`.
		signals_keys (Optional[Sequence[str]], optional): Keys of the
			signals to predict. Only used if `pred_signals` is `True`.
			Defaults to `None`.
		signals_dims (Optional[Sequence[int]], optional): Dimension of
			the signals to predict. Only used if `pred_signals` is
			`True`. Defaults to `None`.

	Returns:
		Tuple[SACLoss, TensorDictModule, Tuple[TensorDictModuleBase,
			TensorDictModuleBase]]: Tuple containing the SAC loss, the
			actor module, and a tuple containing the actor and Q-value
			RNNs. The returned actor wraps `rnn_actor` as constructed,
			does not include the signals prediction loss and does not
			return log-probabilities; the actor held by the loss is a
			separate wrapper around the same sub-modules that uses
			`rnn_actor.set_recurrent_mode(True)` and returns
			log-probabilities. The returned RNNs are the objects
			produced by `_make_rnn_module`, not the results of
			`set_recurrent_mode(True)`.

	Raises:
		ValueError: If `pred_signals` is `True` but `signals_keys` or
			`signals_dims` is `None`.
	"""
	if pred_signals:
		# Explicit raise (rather than `assert`) so the check also runs
		# when Python is started with `-O`.
		if signals_keys is None or signals_dims is None:
			raise ValueError(
				"Signals keys and dims must be provided if predicting signals."
			)
		filler, pred_extractor, pred_signals_loss = _make_pred_signals_modules(
			signals_keys, signals_dims
		)

	n_act = action_spec.shape[-1]
	# Width of the RNN output consumed by the MLP heads; it has to agree
	# with the hidden size hard-coded in `_make_rnn_module`.
	rnn_out_size = 128

	# ------------------------------------------------------------------
	# Actor
	# ------------------------------------------------------------------
	feature_actor_pre = TensorDictModule(
		nn.Linear(n_obs, obs_emb_size),
		in_keys="observation",
		out_keys="obs_emb",
	)
	feature_actor_post = TensorDictModule(
		nn.Linear(n_obs, obs_emb_size),
		in_keys="observation",
		out_keys="obs_emb_post",
	)
	rnn_actor = _make_rnn_module(
		rnn_type,
		obs_emb_size,
		in_key="obs_emb",
		out_key="embed",
	)

	# The head sees the RNN output concatenated with a second, separate
	# embedding of the current observation.
	actor_nn = TensorDictModule(
		trlm.MLP(
			num_cells=[hidden_size] * depth,
			activation_class=nn.ReLU,
			out_features=2 * n_act,
			in_features=rnn_out_size + obs_emb_size,
			dropout=dropout,
		),
		in_keys=["embed", "obs_emb_post"],
		out_keys="loc_scale",
	)
	actor_extractor = TensorDictModule(
		trlm.distributions.NormalParamExtractor(
			scale_mapping="biased_softplus_1.0",
			scale_lb=to_number(0.1),
		),
		in_keys="loc_scale",
		out_keys=["loc", "scale"],
	)

	def _probabilistic_actor(
		rnn: TensorDictModuleBase,
		*,
		with_pred_loss: bool,
		return_log_prob: bool,
	):
		"""Wrap the shared actor sub-modules around `rnn`."""
		modules = [feature_actor_pre, feature_actor_post, rnn]
		if pred_signals:
			# Missing signals are filled before anything else runs and
			# the prediction heads read the RNN output right after it.
			modules.insert(0, filler)
			modules.append(pred_extractor)
			if with_pred_loss:
				modules.append(pred_signals_loss)
		modules.extend([actor_nn, actor_extractor])

		return trlm.tensordict_module.actors.ProbabilisticActor(
			module=TensorDictSequential(*modules),
			spec=action_spec,
			in_keys=["loc", "scale"],
			distribution_class=trlm.distributions.TanhNormal,
			distribution_kwargs={
				"low": action_spec.space.low,  # type: ignore
				"high": action_spec.space.high,  # type: ignore
				# The TanhNormal keyword is spelled `tanh_loc`.
				"tanh_loc": False,
			},
			default_interaction_type=InteractionType.RANDOM,
			return_log_prob=return_log_prob,
		)

	# Actor handed back to the caller.
	returned_actor = _probabilistic_actor(
		rnn_actor,
		with_pred_loss=False,
		return_log_prob=False,
	)
	# Actor consumed by the SAC loss.
	loss_actor = _probabilistic_actor(
		rnn_actor.set_recurrent_mode(True),
		with_pred_loss=True,
		return_log_prob=True,
	)

	# ------------------------------------------------------------------
	# Q-value
	# ------------------------------------------------------------------
	feature_qvalue_pre = TensorDictModule(
		trlm.MLP(
			depth=0,
			activation_class=nn.ReLU,
			out_features=obs_emb_size,
			in_features=n_obs + n_act,
		),
		in_keys=["observation", "action"],
		out_keys="qv_obs_emb",
	)
	feature_qvalue_post = TensorDictModule(
		trlm.MLP(
			depth=0,
			activation_class=nn.ReLU,
			out_features=obs_emb_size,
			in_features=n_obs + n_act,
		),
		in_keys=["observation", "action"],
		out_keys="qv_obs_emb_post",
	)
	rnn_qvalue = _make_rnn_module(
		rnn_type,
		obs_emb_size,
		in_keys=["qv_obs_emb", "qv_recurrent_state", "is_init"],
		out_keys=["embed_qvalue", ("next", "qv_recurrent_state")],
	)
	qvalue_nn = TensorDictModule(
		trlm.MLP(
			num_cells=[hidden_size] * depth,
			activation_class=nn.ReLU,
			out_features=1,
			in_features=rnn_out_size + obs_emb_size,
			dropout=dropout,
		),
		in_keys=["embed_qvalue", "qv_obs_emb_post"],
		out_keys=["embed_qvalue"],
	)

	if pred_signals:
		# Raw signals first, then their predictions, then the action
		# history.
		qvalue_in_keys = [
			*signals_keys,  # type: ignore
			*(f"pred_{key}" for key in signals_keys),  # type: ignore
			"action_history",
		]
		qvalue_selected_out_keys = ["embed_qvalue", "observation"]
	else:
		qvalue_in_keys = ["observation"]
		qvalue_selected_out_keys = ["embed_qvalue"]
	qvalue_in_keys.extend(["action", "qv_recurrent_state", "is_init"])

	qvalue_modules = [
		feature_qvalue_pre,
		feature_qvalue_post,
		rnn_qvalue.set_recurrent_mode(True),
		qvalue_nn,
	]
	if pred_signals:
		qvalue_modules.insert(0, filler)

	qvalue = trlm.tensordict_module.actors.ValueOperator(
		module=TensorDictSequential(
			*qvalue_modules,
			selected_out_keys=qvalue_selected_out_keys,
		),
		in_keys=qvalue_in_keys,
	)

	# ------------------------------------------------------------------
	# Loss
	# ------------------------------------------------------------------
	# Mamba2 needs the loop-based ensemble evaluation of SACLoss4Mamba.
	sac_loss_cls = SACLoss4Mamba if rnn_type == "mamba2" else SACLoss
	loss = sac_loss_cls(
		actor_network=loss_actor,
		qvalue_network=qvalue,
		num_qvalue_nets=2,
		loss_function="l2",
		delay_actor=False,
		delay_qvalue=True,
		separate_losses=True,
		alpha_init=1.0,
		min_alpha=min_alpha,  # type: ignore
		max_alpha=max_alpha,  # type: ignore
	)
	loss.make_value_estimator(gamma=gamma)

	return loss, returned_actor, (rnn_actor, rnn_qvalue)
