from typing import Optional, Sequence, cast

import torch
import torch.nn as nn
from tensordict import TensorDictBase, unravel_key_list
from tensordict.base import NO_DEFAULT
from tensordict.nn import TensorDictModuleBase
from torchrl.data import UnboundedContinuous, UnboundedDiscrete
from torchrl.envs import TensorDictPrimer
from transformers import Mamba2Config, Mamba2Model

from src.async_rl.module.utils import Key, append_to_last


class Mamba2Module(TensorDictModuleBase):
	"""A `tensordict` wrapper around a Hugging Face `Mamba2Model`.

	The wrapped model's `embeddings` layer is a linear projection from
	`input_size` to `hidden_size`, so continuous inputs can be fed to
	the model. The module can be used in policy and value networks in
	TorchRL.

	The module is not in recurrent mode by default. In that mode it
	expects the input of a single, unbatched time step (two leading
	singleton axes are added before calling the model, and stripped from
	the output). The Mamba2 convolution/SSM states and the cache position
	are read from and written to the tensordict. This is the behavior to
	use when the module is part of a policy that is being rolled out or
	tested.

	In recurrent mode the input is passed to the model unchanged
	(presumably shaped `(batch, time, input_size)` -- confirm against
	callers). No recurrent state is read from or written to the
	tensordict. This is the behavior to use when the module is part of a
	network that is being trained on full sequences.

	`set_recurrent_mode` returns a module in the requested mode that
	wraps the very same `Mamba2Model` instance, so both modes share
	their parameters.

	Args:
		input_size (int): The dimension of the input tensor.
		hidden_size (int): The hidden dimension of the Mamba2 model.
		num_layers (int, optional): The number of Mamba2 layers.
			Defaults to 1.
		head_dim (int, optional): The dimension of each SSM head.
			Defaults to 8.
		in_key (Key, optional): The key for the input tensor. When
			given, `in_keys` becomes
			`[in_key, "recurrent_state", "is_init"]`. Exactly one of
			`in_key` and `in_keys` must be provided. Defaults to `None`.
		in_keys (Sequence[Key], optional): Three keys. The first is the
			input tensor. The second is the prefix of the recurrent
			state; the `conv_states`, `ssm_states` and `cache_position`
			sub-keys are derived from it with `append_to_last`. The third
			is an "is init" flag, which is validated but not read by
			this module. Defaults to `None`.
		out_key (Key, optional): The key for the output tensor. When
			given, `out_keys` becomes
			`[out_key, ("next", "recurrent_state")]`. Exactly one of
			`out_key` and `out_keys` must be provided. Defaults to
			`None`.
		out_keys (Sequence[Key], optional): Two keys: the output tensor
			and the prefix under which the next recurrent state is
			written. Defaults to `None`.
		device (torch.device, optional): The device on which to place
			the Mamba2 model. Defaults to `None`.
		mamba2 (Mamba2Model, optional): An existing Mamba2 model to wrap
			instead of creating a new one. Its `embeddings` are replaced
			by a new `nn.Linear(input_size, hidden_size)` unless they
			already are an `nn.Linear` of exactly that shape (which is
			the case for models shared via `set_recurrent_mode`).
			Defaults to `None`.

	Raises:
		ValueError: If not exactly one of `in_key`/`in_keys` (or of
			`out_key`/`out_keys`) is provided, or if the resolved
			`in_keys`/`out_keys` do not have length 3/2.
	"""

	def __init__(
		self,
		input_size: int,
		hidden_size: int,
		num_layers: int = 1,
		head_dim: int = 8,
		*,
		in_key: Optional[Key] = None,
		in_keys: Optional[Sequence[Key]] = None,
		out_key: Optional[Key] = None,
		out_keys: Optional[Sequence[Key]] = None,
		device: Optional[torch.device] = None,
		mamba2: Optional[Mamba2Model] = None,
	) -> None:
		super().__init__()

		self.input_size = input_size

		if mamba2 is None:
			config = Mamba2Config(
				hidden_size=hidden_size,
				head_dim=head_dim,
				num_hidden_layers=num_layers,
				vocab_size=hidden_size,
			)
			# Derive the head count from the inner width so that
			# num_heads * head_dim == expand * hidden_size. This must be
			# set before the model is built from the config.
			config.num_heads = (
				config.expand * config.hidden_size
			) // config.head_dim
			ssm: Mamba2Model = Mamba2Model(config)
		else:
			ssm = mamba2

		# Swap the token-embedding table for a linear projection so that
		# continuous inputs can be passed positionally as `input_ids`.
		# Skip the swap when the model already carries a matching
		# projection. Otherwise a model shared through
		# `set_recurrent_mode` would have its trained input layer
		# re-initialized, and that would also affect the original module.
		embeddings = getattr(ssm, "embeddings", None)
		has_input_projection = (
			isinstance(embeddings, nn.Linear)
			and embeddings.in_features == input_size
			and embeddings.out_features == hidden_size
		)
		if not has_input_projection:
			ssm.embeddings = nn.Linear(input_size, hidden_size)
		self.ssm = ssm.to(device)  # type: ignore
		# HF cache object carried between non-recurrent calls.
		self.cache = None

		if (in_key is None) == (in_keys is None):
			raise ValueError(
				"Exactly one of in_key and in_keys must be provided"
			)
		if in_key is not None:
			in_keys = [in_key, "recurrent_state", "is_init"]

		if (out_key is None) == (out_keys is None):
			raise ValueError(
				"Exactly one of out_key and out_keys must be provided"
			)
		if out_key is not None:
			out_keys = [out_key, ("next", "recurrent_state")]

		in_keys = unravel_key_list(in_keys)
		out_keys = unravel_key_list(out_keys)

		if len(in_keys) != 3:
			raise ValueError(
				"Expected in_keys to have length 3, got {}".format(
					len(in_keys)
				)
			)
		if len(out_keys) != 2:
			raise ValueError(
				"Expected out_keys to have length 2, got {}".format(
					len(out_keys)
				)
			)

		self.in_keys = in_keys
		self.out_keys = out_keys

		# Sub-keys of the recurrent state, for the current and next step.
		self.in_conv_states = append_to_last(in_keys[1], "conv_states")
		self.in_ssm_states = append_to_last(in_keys[1], "ssm_states")
		self.in_cache_position = append_to_last(in_keys[1], "cache_position")
		self.out_conv_states = append_to_last(out_keys[1], "conv_states")
		self.out_ssm_states = append_to_last(out_keys[1], "ssm_states")
		self.out_cache_position = append_to_last(out_keys[1], "cache_position")

		# Attribute name (with its typo) kept for compatibility.
		self._reccurrent_mode = False

	def forward(self, td: TensorDictBase) -> TensorDictBase:
		"""Forward pass of the module.

		Dispatches to the recurrent (full-sequence) or non-recurrent
		(single-step) implementation depending on the module's mode.

		Args:
			td (TensorDictBase): The input tensordict. It is updated in
				place with the output entries.

		Returns:
			TensorDictBase: The same tensordict, with the output written.
		"""
		return (
			self._forward_recurrent(td)
			if self._reccurrent_mode
			else self._forward_non_recurrent(td)
		)

	def set_recurrent_mode(self, mode: bool = True) -> "Mamba2Module":
		"""Returns a module in the requested recurrent mode.

		Args:
			mode (bool, optional): Whether the returned module should be
				in recurrent mode. Defaults to `True`.

		Returns:
			Mamba2Module: `self` if it is already in the requested mode.
				Otherwise, a new module that wraps the same `Mamba2Model`
				instance (and therefore shares all of its weights,
				including the input projection).
		"""
		if mode == self._reccurrent_mode:
			return self
		config = self.ssm.config
		to_return = Mamba2Module(
			input_size=self.input_size,
			hidden_size=config.hidden_size,
			num_layers=config.num_hidden_layers,
			head_dim=config.head_dim,
			in_keys=self.in_keys,
			out_keys=cast(Sequence[Key], self.out_keys),
			mamba2=self.ssm,
		)
		to_return._reccurrent_mode = mode
		return to_return

	def make_tensordict_primer(self) -> TensorDictPrimer:
		"""Makes a tensordict primer for the environment.

		The primer declares the recurrent-state entries read in
		non-recurrent mode:
		- `conv_states` of shape `(num_layers, conv_dim, conv_kernel)`;
		- `ssm_states` of shape
		  `(num_layers, num_heads, head_dim, state_size)`;
		- `cache_position` of shape `(1,)`.

		Returns:
			TensorDictPrimer: A primer for the environment which
				populates the tensordict with the expected keys and
				shapes.
		"""
		config: Mamba2Config = cast(Mamba2Config, self.ssm.config)
		conv_dim = (
			int(config.expand * config.hidden_size)
			+ 2 * config.n_groups * config.state_size
		)
		conv_states_shape = torch.Size(
			[config.num_hidden_layers, conv_dim, config.conv_kernel]
		)
		ssm_states_shape = torch.Size(
			[
				config.num_hidden_layers,
				config.num_heads,
				config.head_dim,
				config.state_size,
			]
		)
		return TensorDictPrimer(
			{
				self.in_conv_states: UnboundedContinuous(conv_states_shape),  # type: ignore
				self.in_ssm_states: UnboundedContinuous(ssm_states_shape),  # type: ignore
				self.in_cache_position: UnboundedDiscrete(torch.Size([1])),  # type: ignore
			}
		)

	def _forward_recurrent(self, td: TensorDictBase) -> TensorDictBase:
		# Full-sequence pass. No recurrent state is read or written, so
		# no cache is needed and allocating one would be wasted work.
		input_ = td.get(self.in_keys[0], NO_DEFAULT)
		output = self.ssm(input_, use_cache=False, return_dict=False)[0]
		td.set(self.out_keys[0], output)
		return td

	def _forward_non_recurrent(self, td: TensorDictBase) -> TensorDictBase:
		# Single-step pass that threads the Mamba2 state through `td`.
		input_ = td.get(self.in_keys[0], NO_DEFAULT)
		conv_states = td.get(self.in_conv_states, None)
		ssm_states = td.get(self.in_ssm_states, None)
		cache_position = td.get(self.in_cache_position, None)

		# Add leading batch and time axes of size 1 (assumes an unbatched
		# single-step input, e.g. `(input_size,)`).
		input_ = input_[torch.newaxis, torch.newaxis]
		if self.cache is not None:
			# Load the state from `td` into the cache, inserting the batch
			# axis at dim 1. Clone first: the cache is updated in place
			# by the model, and using views would overwrite the input
			# tensordict and alias the current and next states.
			# NOTE(review): before the first call `self.cache` is None, so
			# the state in `td` is ignored on that call.
			self.cache.conv_states = conv_states[:, torch.newaxis].clone()
			self.cache.ssm_states = ssm_states[:, torch.newaxis].clone()

		# `use_cache=True` is explicit so the call always returns
		# `(hidden_states, cache)`, even when the module is in training
		# mode.
		output, cache = self.ssm(
			input_,
			cache_params=self.cache,
			cache_position=cache_position,
			use_cache=True,
			return_dict=False,
		)
		self.cache = cache

		output = output[0, 0]
		# The batch axis sits at dim 1 (see the loading step above); the
		# leading axis is the layer axis and must be kept whole.
		conv_states = cache.conv_states[:, 0]
		ssm_states = cache.ssm_states[:, 0]

		td.set(self.out_keys[0], output)
		td.set(self.out_conv_states, conv_states)
		td.set(self.out_ssm_states, ssm_states)
		td.set(self.out_cache_position, cache_position + 1)

		return td

	@property
	def recurrent_mode(self) -> bool:
		"""Whether the module is in recurrent mode."""
		return self._reccurrent_mode

	@recurrent_mode.setter
	def recurrent_mode(self, mode: bool) -> None:
		raise RuntimeError(
			"recurrent_mode cannot be changed in-place. Call "
			"`module.set_recurrent_mode(mode)` instead, which returns a "
			"module in the requested mode sharing the same weights."
		)
