from typing import Tuple

import torch
import torch.nn as nn
from torchrl.data import Composite, UnboundedContinuous
from torchrl.envs.transforms import Compose, Transform
from torchrl.modules.utils import get_primers_from_module

# Key of a spec entry (used, for example, as a `Composite` key): either a
# plain string or a 2-tuple of strings (a two-level nested key). Only
# two-level nested keys are admitted by this alias.
# Uses the `type` statement, so Python 3.12+ is required.
type Key = str | Tuple[str, str]


def append_to_last(key: Key, suffix: str) -> Key:
	"""Append a suffix to the last element of a key.

	The key may be a plain string or a tuple of strings (a nested key).
	For a string, the suffix is joined to it with an underscore. For a
	tuple, only the last element is suffixed; every preceding element is
	kept unchanged, whatever the length of the tuple.

	Args:
		key (Key): The key to append the suffix to.
		suffix (str): The suffix to append to the last element of the
			key. It is joined with an underscore.

	Returns:
		Key: The key with ``"_" + suffix`` appended to its last element.
			A tuple key yields a tuple of the same length.

	Raises:
		IndexError: If ``key`` is an empty tuple.

	Examples:
		>>> append_to_last("key", "suffix")
		'key_suffix'
		>>> append_to_last(("key1", "key2"), "suffix")
		('key1', 'key2_suffix')
		>>> append_to_last(("a", "b", "c"), "suffix")
		('a', 'b', 'c_suffix')
		>>> append_to_last(("key",), "suffix")
		('key_suffix',)
	"""
	if isinstance(key, str):
		return f"{key}_{suffix}"
	# Keep every leading element and suffix only the last one. Indexing
	# fixed positions would drop elements of deeper nested keys or fail on
	# single-element tuples.
	return (*key[:-1], f"{key[-1]}_{suffix}")


def complex_spec(shape: torch.Size, key: Key) -> Composite:
	"""Build a composite spec holding a complex value as two real tensors.

	Two unbounded continuous leaf specs with identical shape are created:
	one for the real component and one for the imaginary component. Their
	keys are derived from ``key`` by suffixing its last element with
	``"_real"`` and ``"_imag"`` respectively, via ``append_to_last``.

	Args:
		shape (torch.Size): Shape shared by both leaf specs.
		key (Key): Base key from which the two leaf keys are derived.

	Returns:
		Composite: A spec with one entry for the real part and one for the
			imaginary part, in that order.
	"""
	components = ("real", "imag")
	# A fresh UnboundedContinuous is built per component, so the two
	# entries never share a spec object.
	entries = {
		append_to_last(key, component): UnboundedContinuous(shape)  # type: ignore
		for component in components
	}
	return Composite(entries)


def get_primers_robust(module: nn.Module) -> Transform:
	"""Return the primer transform(s) of a module, never ``None``.

	Thin wrapper around TorchRL's ``get_primers_from_module``. When that
	function returns ``None`` (no primers found), an empty ``Compose`` is
	returned instead, so callers always receive a ``Transform``.

	Args:
		module (nn.Module): The module whose primers should be collected.

	Returns:
		Transform: Whatever ``get_primers_from_module`` returned for
			``module``, or an empty ``Compose`` if it returned ``None``.
	"""
	# NOTE(review): depending on the installed TorchRL version,
	# get_primers_from_module may also emit a warning when no primers exist;
	# this wrapper only normalizes a None result. Confirm against the
	# pinned version.
	found = get_primers_from_module(module)
	return Compose() if found is None else found
