from typing import List, Optional

import torch
import torch.nn as nn
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase


class MissingSignalFiller(nn.Module):
	"""Replace a stale signal observation with its prediction.

	The trailing entry of every observation encodes how long ago the
	signal was last recorded. Whenever that entry is strictly positive
	the signal values are swapped for the supplied prediction; the
	trailing entry itself is always passed through untouched.
	"""

	def __init__(self, signal_key: str) -> None:
		super().__init__()

		# Stored for reference only; `forward` does not read it.
		self.signal_key = signal_key

	def forward(self, obs: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
		"""Return `obs` with stale signal values replaced by `pred`.

		Args:
			obs (torch.Tensor): The observation tensor. Its last entry
				along the final dimension is the time since the signal
				was last recorded; a value greater than `0` marks the
				signal as missing.
			pred (torch.Tensor): The prediction tensor, used in place of
				`obs[..., :-1]` wherever the signal is missing.

		Returns:
			torch.Tensor: The filled signal values followed by the
			unchanged time-since-recorded entry.
		"""
		# Keep the last dimension so the mask broadcasts over the signal.
		elapsed = obs[..., -1:]
		signal = obs[..., :-1]
		is_missing = elapsed > 0

		filled_signal = torch.where(is_missing, pred, signal)
		return torch.concat((filled_signal, elapsed), dim=-1)


class MissingSignalFillerTensorDict(TensorDictModuleBase):
	"""Module to fill missing signals in a `TensorDict`.

	This module is used to possibly fill missing signaled observations
	with predictions, depending on whether the signals have been
	recorded in the last time step.

	This module handles multiple signals at once, each stored in an
	input `TensorDictBase` with the key being the signal name. The
	prediction for a signal is read from the key
	`f"{pred_prefix}_{obs_key}"`.

	Args:
		obs_keys (List[str]): The keys of the observations corresponding
			to signals in the input `TensorDictBase`.
		pred_prefix (Optional[str], optional): The prefix for the keys
			corresponding to the predictions. Defaults to `"pred"`.
		aggregate (bool, optional): Whether to aggregate the filled
			observations into a single tensor under the key
			`"observation"`. If `False`, each filled observation is
			written back under its own key instead. Defaults to `True`.
		action_history (bool, optional): Whether the aggregated
			observation should also include the tensor at key
			`"action_history"`. Defaults to `True`. Only relevant if
			`aggregate` is `True`.
	"""

	def __init__(
		self,
		obs_keys: List[str],
		pred_prefix: Optional[str] = "pred",
		aggregate: bool = True,
		action_history: bool = True,
	) -> None:
		super().__init__()

		self.pred_prefix = pred_prefix
		self.obs_keys = obs_keys
		self.aggregate = aggregate
		self.action_history = action_history

		# Build the full key list before assigning it, so that `in_keys`
		# names exactly the entries `forward` reads: "action_history" is
		# only consumed when aggregating with the action history enabled.
		in_keys = list(obs_keys) + [
			f"{pred_prefix}_{obs}" for obs in obs_keys
		]
		if aggregate and action_history:
			in_keys.append("action_history")
		self.in_keys = in_keys  # type: ignore
		self.out_keys = ["observation"] if aggregate else list(obs_keys)

		self.fillers = nn.ModuleDict(
			{obs_key: MissingSignalFiller(obs_key) for obs_key in obs_keys}
		)

	def forward(self, td: TensorDictBase) -> TensorDictBase:
		"""Forward pass of the module.

		Fills every signal listed in `obs_keys` and writes the result
		into `td` in place: either concatenated along the last dimension
		under `"observation"` (when aggregating) or back under each
		signal's own key.

		Args:
			td (TensorDictBase): The input tensor dictionary.

		Returns:
			TensorDictBase: The same tensor dictionary, updated with the
			filled observation(s).
		"""
		to_aggregate = []

		for obs_key in self.obs_keys:
			filled = self.fillers[obs_key](
				td.get(obs_key), td.get(f"{self.pred_prefix}_{obs_key}")
			)
			if self.aggregate:
				to_aggregate.append(filled)
			else:
				td.set(obs_key, filled)

		if self.aggregate:
			if self.action_history:
				# The action history goes last in the aggregated tensor.
				to_aggregate.append(td.get("action_history"))
			td.set("observation", torch.concat(to_aggregate, dim=-1))

		return td
