from typing import Optional, Sequence, Tuple, cast

import torch
import torch.nn as nn
from tensordict import TensorDictBase, unravel_key_list
from tensordict.base import NO_DEFAULT
from tensordict.nn import TensorDictModuleBase
from torchrl.envs import TensorDictPrimer

from src.async_rl.module.utils import Key, append_to_last, complex_spec
from src.async_rl.ssms import S4Block


class S4Network(nn.Module):
	"""A stack of S4 blocks preceded by a linear input projection.

	Args:
		input_size (int): The dimension of the input tensor.
		hidden_size (int): The dimension of the hidden state; it is also
			the feature dimension fed to every S4 block.
		d_state (int, optional): The dimension of the SSM kernel.
			Defaults to `64`.
		num_layers (int, optional): The number of S4 blocks in the
			network. Defaults to `1`.
		l_max (int, optional): The maximum length of the sequence. It
			must be an even number. Defaults to `4`.

	Raises:
		ValueError: If `l_max` is not an even number.
	"""

	def __init__(
		self,
		input_size: int,
		hidden_size: int,
		d_state: int = 64,
		num_layers: int = 1,
		l_max: int = 4,
	) -> None:
		super().__init__()
		# Validate with an explicit exception rather than `assert`, which
		# is stripped when Python runs with `-O`.
		if l_max % 2 != 0:
			raise ValueError("l_max must be an even number.")

		self.first_linear = nn.Linear(input_size, hidden_size)
		self.s4s = nn.ModuleList(
			[
				S4Block(
					hidden_size,
					transposed=False,
					d_state=d_state,
					l_max=l_max,
				)
				for _ in range(num_layers)
			]
		)

		# Run one dummy forward pass through each block (every block gets
		# the same zero tensor; the outputs are discarded).
		# NOTE(review): presumably this initialises state that S4Block
		# builds lazily on its first call — confirm against S4Block.
		mock_tensor = torch.zeros(1, l_max if l_max else 1, hidden_size)
		for s4 in self.s4s:
			_, _ = s4(mock_tensor)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Processes a whole sequence through the network.

		Args:
			x (torch.Tensor): The input sequence. Its last dimension must
				be `input_size`. Presumably laid out as
				(batch, length, input_size), like the warm-up tensor built
				in `__init__` — confirm.

		Returns:
			torch.Tensor: The output of the last S4 block.
		"""
		x = self.first_linear(x)
		for s4 in self.s4s:
			# Each block returns a pair; only the first element is
			# propagated to the next block.
			x, _ = s4(x)
		return x

	def step(
		self, x: torch.Tensor, state: torch.Tensor
	) -> Tuple[torch.Tensor, torch.Tensor]:
		"""Advances the network by a single time step.

		Args:
			x (torch.Tensor): The input for one time step; its last
				dimension must be `input_size`.
			state (torch.Tensor): The previous hidden state. Dimension 1
				indexes the layers: `state[:, i]` is handed to the i-th
				S4 block.

		Returns:
			Tuple[torch.Tensor, torch.Tensor]: The output of the last S4
				block and the next hidden state.

		Note:
			The per-layer next states are joined with `torch.cat` along
			dimension 1, not stacked. If each block returns a state shaped
			like the one it received (presumably — confirm against
			`S4Block.step`), the result therefore does not have the layout
			of `state`. Reshape it to `state.shape` before feeding it back;
			because the concatenation is contiguous along dimension 1, that
			reshape recovers the per-layer layout.
		"""
		x = self.first_linear(x)
		states = []
		for index, s4 in enumerate(self.s4s):
			x, curr_state = s4.step(x, state[:, index])
			states.append(curr_state)
		return x, torch.cat(states, dim=1)

	def setup_for_step(self) -> None:
		"""Prepares every S4 block for step-by-step processing.

		Calls `setup_step()` and then `default_state(1)` on each block.
		"""
		for s4 in self.s4s:
			s4.setup_step()
			# NOTE(review): the returned default state is discarded;
			# presumably this call is made for its side effects — confirm.
			s4.default_state(1)


class S4Module(TensorDictModuleBase):
	"""A `tensordict` wrapper around an `S4Network`.

	This class wraps an S4 network so that it can be used as a
	`tensordict` module. Most importantly, it can be used in policy and
	value networks in TorchRL.

	The module is not in recurrent mode by default, which means that it
	expects input tensors corresponding to single time steps. This is
	the behavior that should be used when the module is part of a policy
	network that is being tested.

	When the module is in recurrent mode, it expects input tensors
	corresponding to full sequences. This is the behavior that should be
	used when the module is part of a network that is being trained.

	The `set_recurrent_mode` method can be used to switch between these
	two modes. The module it returns reuses the same `S4Network`
	instance, so weights are shared and the same parameters serve both
	training and testing.

	The hidden state is complex-valued. In the tensordict it is stored as
	two real tensors, under the keys obtained by applying
	`append_to_last(..., "real")` and `append_to_last(..., "imag")` to the
	state key.

	Args:
		input_size (int): The dimension of the input tensor.
		hidden_size (int): The dimension of the hidden state.
		d_state (int, optional): The dimension of the SSM kernel.
			Defaults to `64`.
		num_layers (int, optional): The number of layers in the S4
			network. Defaults to `1`.
		l_max (int, optional): The maximum length of the sequence. It
			must be an even number. Only used when `s4` is not given.
			Defaults to `4`.
		in_key (Key, optional): The key for the input tensor. When used,
			the state key is `"recurrent_state"` and the third input key
			is `"is_init"`. Exactly one of `in_key` and `in_keys` must be
			provided. Defaults to `None`.
		in_keys (Sequence[Key], optional): Exactly three keys: the input
			tensor, the hidden state, and an `is_init` entry. Defaults to
			`None`.
		out_key (Key, optional): The key for the output tensor. When
			used, the next state is written under
			`("next", "recurrent_state")`. Exactly one of `out_key` and
			`out_keys` must be provided. Defaults to `None`.
		out_keys (Sequence[Key], optional): Exactly two keys: the output
			tensor and the next hidden state. Defaults to `None`.
		device (torch.device, optional): The device on which to place
			the S4 network. Defaults to `None`.
		s4 (S4Network, optional): An S4 network to use instead of
			creating a new one. Defaults to `None`.

	Raises:
		ValueError: If not exactly one of `in_key`/`in_keys` (or of
			`out_key`/`out_keys`) is provided, or if `in_keys`/`out_keys`
			do not have length 3/2 respectively.
	"""

	def __init__(
		self,
		input_size: int,
		hidden_size: int,
		d_state: int = 64,
		num_layers: int = 1,
		l_max: int = 4,
		*,
		in_key: Optional[Key] = None,
		in_keys: Optional[Sequence[Key]] = None,
		out_key: Optional[Key] = None,
		out_keys: Optional[Sequence[Key]] = None,
		device: Optional[torch.device] = None,
		s4: Optional[S4Network] = None,
	) -> None:
		super().__init__()

		# Reuse the given network (sharing its parameters) or build a new
		# one; either way it is moved to `device`.
		if s4 is None:
			s4 = S4Network(
				input_size,
				hidden_size,
				d_state=d_state,
				num_layers=num_layers,
				l_max=l_max,
			)
		self.ssm = s4.to(device)
		# NOTE(review): when `s4` is given, these hyperparameters come from
		# the arguments and are not checked against the network.
		self.input_dim = input_size
		self.hidden_size = hidden_size
		self.num_layers = num_layers
		self.d_state = d_state

		# Exactly one of the singular / plural forms may be given. Compare
		# against `None` explicitly so that falsy keys are still honoured.
		if (in_key is None) == (in_keys is None):
			raise ValueError(
				"Exactly one of in_key and in_keys must be provided"
			)
		if in_key is not None:
			in_keys = [in_key, "recurrent_state", "is_init"]

		if (out_key is None) == (out_keys is None):
			raise ValueError(
				"Exactly one of out_key and out_keys must be provided"
			)
		if out_key is not None:
			out_keys = [out_key, ("next", "recurrent_state")]

		in_keys = unravel_key_list(in_keys)
		out_keys = unravel_key_list(out_keys)

		if len(in_keys) != 3:
			raise ValueError(
				"Expected in_keys to have length 3, got {}".format(
					len(in_keys)
				)
			)
		if len(out_keys) != 2:
			raise ValueError(
				"Expected out_keys to have length 2, got {}".format(
					len(out_keys)
				)
			)

		# in_keys = [input, state, is_init]; out_keys = [output, next state].
		# NOTE(review): in_keys[2] ("is_init" by default) is never read by
		# this module.
		self.in_keys = in_keys
		self.out_keys = out_keys

		# The complex state is stored as separate real / imaginary entries.
		self.in_real = append_to_last(in_keys[1], "real")
		self.in_imag = append_to_last(in_keys[1], "imag")
		self.out_real = append_to_last(out_keys[1], "real")
		self.out_imag = append_to_last(out_keys[1], "imag")

		# Attribute name (sic) kept for compatibility with existing code.
		self._reccurrent_mode = False

	def forward(self, td: TensorDictBase) -> TensorDictBase:
		"""Forward pass of the module.

		In recurrent mode the whole input sequence is processed at once
		and only the output key is written. Otherwise a single time step
		is processed and both the output and the next state are written.

		Args:
			td (TensorDictBase): The input tensor dictionary.

		Returns:
			TensorDictBase: The same tensor dictionary, updated in place.
		"""
		return (
			self._forward_recurrent(td)
			if self._reccurrent_mode
			else self._forward_non_recurrent(td)
		)

	def set_recurrent_mode(self, mode: bool = True) -> "S4Module":
		"""Returns a module in the requested recurrent mode.

		Args:
			mode (bool, optional): Whether the returned module should be
				in recurrent mode. Defaults to `True`.

		Returns:
			S4Module: `self` if it is already in the requested mode;
				otherwise a new module with the same keys that shares
				this module's `S4Network` (and therefore its weights).
		"""
		if mode == self._reccurrent_mode:
			return self
		to_return = S4Module(
			input_size=self.input_dim,
			hidden_size=self.hidden_size,
			d_state=self.d_state,
			num_layers=self.num_layers,
			in_keys=self.in_keys,
			out_keys=cast(Sequence[Key], self.out_keys),
			s4=self.ssm,
		)
		to_return._reccurrent_mode = mode
		return to_return

	def make_tensordict_primer(self) -> TensorDictPrimer:
		"""Makes a tensordict primer for the environment.

		Returns:
			TensorDictPrimer: A primer whose spec, built by `complex_spec`
				under the state key `in_keys[1]`, has shape
				(num_layers, hidden_size, d_state).
		"""
		return TensorDictPrimer(
			complex_spec(
				shape=torch.Size(
					[self.num_layers, self.hidden_size, self.d_state]
				),
				key=self.in_keys[1],
			)
		)

	def setup_for_step(self) -> None:
		"""Sets up the S4 network for processing individual time steps."""
		self.ssm.setup_for_step()

	def _forward_recurrent(self, td: TensorDictBase) -> TensorDictBase:
		"""Runs the network over a full input sequence."""
		input_ = td.get(self.in_keys[0], NO_DEFAULT)
		output = self.ssm(input_)
		td.set(self.out_keys[0], output)
		return td

	def _read_state(self, td: TensorDictBase) -> torch.Tensor:
		"""Rebuilds the complex hidden state from its real/imag entries.

		Raises:
			KeyError: If either state entry is missing from `td`.
		"""
		state_real = td.get(self.in_real, NO_DEFAULT)
		state_imag = td.get(self.in_imag, NO_DEFAULT)
		return torch.view_as_complex(
			torch.stack([state_real, state_imag], dim=-1)
		)

	def _write_state(self, td: TensorDictBase, state: torch.Tensor) -> None:
		"""Stores a complex hidden state as separate real/imag entries."""
		state_real, state_imag = torch.unbind(
			torch.view_as_real(state), dim=-1
		)
		td.set(self.out_real, state_real)
		td.set(self.out_imag, state_imag)

	def _forward_non_recurrent(self, td: TensorDictBase) -> TensorDictBase:
		"""Advances the network by one time step and stores the new state."""
		input_ = td.get(self.in_keys[0], NO_DEFAULT)
		state = self._read_state(td)

		# `S4Network.step` needs a batch dimension on the input and a 4-D
		# state whose dimension 1 indexes the layers. Pad both with leading
		# singleton dims, remembering how many were added to the state.
		squeeze_out = input_.dim() == 1
		if squeeze_out:
			input_ = input_.unsqueeze(0)
		n_added = max(0, 4 - state.dim())
		for _ in range(n_added):
			state = state.unsqueeze(0)

		output, next_state = self.ssm.step(input_, state)

		# `S4Network.step` joins the per-layer states along dim 1 with
		# `torch.cat`; reshape back to the padded input-state layout so
		# the layer axis is restored (a no-op if the layout already
		# matches). Then undo exactly the padding added above instead of
		# squeezing until 3-D, which mangled multi-layer and batched
		# states and could loop forever when dim 0 was not of size 1.
		next_state = next_state.reshape(
			next_state.shape[0], *state.shape[1:]
		)
		for _ in range(n_added):
			next_state = next_state.squeeze(0)
		if squeeze_out:
			output = output.squeeze(0)

		td.set(self.out_keys[0], output)
		self._write_state(td, next_state)
		return td

	@property
	def recurrent_mode(self) -> bool:
		"""Whether the module is in recurrent mode."""
		return self._reccurrent_mode

	@recurrent_mode.setter
	def recurrent_mode(self, mode: bool) -> None:
		# Mode switches must go through `set_recurrent_mode`, which returns
		# a (possibly new) module rather than mutating this one.
		raise RuntimeError(
			"recurrent_mode cannot be changed in-place. Call"
			+ " `module.set_recurrent_mode` instead."
		)
