from typing import List, Optional

import torch
import torch.nn as nn
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase


class SignalPredLoss(nn.Module):
	"""Module for computing the loss for the signal prediction task.

	This module computes the loss of an individual signal prediction
	task. To compute a global loss, see :class:`SignalsPredLoss`.
	"""

	def __init__(self) -> None:
		super(SignalPredLoss, self).__init__()

	def forward(self, obs: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
		"""Computes the loss for the signal prediction task.

		Args:
			obs (torch.Tensor): The observation tensor. The last element
				of each observation is the time since the signal was
				last recorded.
			pred (torch.Tensor): The prediction tensor.

		Returns:
			torch.Tensor: The loss tensor.

		Raises:
			ValueError: If the indices for missing signals are invalid,
				caused by the observation tensor having more than three
				dimensions.
		"""
		indices = torch.nonzero(obs[..., -1] == 0, as_tuple=True)
		if len(indices) == 1:
			return torch.nn.functional.mse_loss(
				obs[indices[0], :-1], pred[indices[0]]
			)
		elif len(indices) == 2:
			return torch.nn.functional.mse_loss(
				obs[indices[0], indices[1], :-1], pred[indices[0], indices[1]]
			)
		raise ValueError("Invalid indices for missing signals.")


class SignalsPredLoss(TensorDictModuleBase):
	"""Module for computing the loss for the signals prediction task.

	This module is intended to be attached to a TorchRL loss module. It
	computes the MSE loss for predicted and observed signals in an
	environment with asynchronous signals.

	For every key ``k`` in ``obs_keys`` it reads the observation stored
	under ``k`` and the prediction stored under ``f"{pred_prefix}_{k}"``,
	and sums the per-signal losses computed by :class:`SignalPredLoss`.
	"""

	def __init__(
		self,
		obs_keys: List[str],
		pred_prefix: Optional[str] = "pred",
	) -> None:
		"""Initializes the module.

		Args:
			obs_keys (List[str]): Keys of the observed signals.
			pred_prefix (Optional[str]): Prefix of the prediction keys.
				It is inserted verbatim into ``f"{pred_prefix}_{obs_key}"``,
				so passing ``None`` yields keys such as ``"None_<obs_key>"``.
		"""
		super().__init__()

		self.pred_prefix = pred_prefix
		# Own copy, so later mutation of the caller's list cannot make
		# obs_keys and in_keys disagree.
		self.obs_keys = list(obs_keys)
		self.in_keys = self.obs_keys + [
			f"{pred_prefix}_{obs}" for obs in self.obs_keys
		]  # type: ignore
		self.out_keys = ["missing_signals_loss"]

		self.loss_nn = SignalPredLoss()

	def forward(self, td: TensorDictBase) -> TensorDictBase:
		"""Computes the loss for the signal prediction task.

		The scalar per-signal losses are summed and broadcast to
		``td.batch_size`` so that the result can be stored in ``td``.

		Args:
			td (TensorDictBase): The input tensor dictionary. It should
				contain the observations and predictions for the
				signals.

		Returns:
			TensorDictBase: The input tensor dictionary with the loss
				computed and stored under the key
				`"missing_signals_loss"`.
		"""
		loss = torch.zeros(td.batch_size, device=td.device)

		for obs_key in self.obs_keys:
			curr_loss = self.loss_nn(
				td.get(obs_key),
				td.get(f"{self.pred_prefix}_{obs_key}"),
			)
			# td.device may be None, in which case the accumulator starts on
			# CPU while curr_loss may live on another device; follow the
			# loss's device instead of adding across devices in place.
			loss = loss.to(curr_loss.device) + curr_loss

		td.set("missing_signals_loss", loss)

		return td
