import torch
from torch import nn

from norse.torch.utils.state import _is_module_stateful

class Lift(torch.nn.Module):
    """Apply a ``torch.nn.Module`` over a temporal sequence.

    Adopted from the Norse library (https://github.com/norse/norse) and
    https://github.com/Intelligent-Computing-Lab-Yale/NDA_SNN/blob/main/models/VGG_models.py

    The outer (0-th) dimension of the input is treated as time (length T).
    The module is applied once per time step and the outputs are stacked
    back along the time dimension:

    * stateful modules (as detected by ``_is_module_stateful``) are called
      step by step as ``module(x[t], state)``, and the returned state is
      fed into the next step;
    * non-stateful modules are called once on the input with the time and
      batch dimensions merged into one, i.e. on a tensor of shape
      ``(T * B, ...)``, and the result is split back to ``(T, B, ...)``.

    Parameters:
        module: Module to apply.
        return_state: If True and ``module`` is stateful, ``forward`` returns
            ``(outputs, final_state)`` instead of only ``outputs``. It has no
            effect for non-stateful modules.

    Examples:

        >>> batch_size = 16
        >>> seq_length = 1000
        >>> in_channels = 64
        >>> out_channels = 32
        >>> conv2d = Lift(torch.nn.Conv2d(in_channels, out_channels, 5, 1))
        >>> data = torch.randn(seq_length, batch_size, in_channels, 20, 30)
        >>> output = conv2d(data)  # shape (seq_length, batch_size, out_channels, 16, 26)

        >>> module = torch.nn.Sequential(
        ...     Lift(torch.nn.Conv2d(in_channels, out_channels, 5, 1)),
        ...     LIF(),
        ... )
        >>> output, _ = module(data)
    """

    def __init__(self, module: torch.nn.Module, return_state: bool = False):
        super().__init__()
        # Decided once at construction time; selects the per-step (stateful)
        # or merged time/batch (non-stateful) execution path in forward().
        self.is_stateful = _is_module_stateful(module)
        self.lifted_module = module
        self.return_state = return_state

    def forward(
        self, x: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply the wrapped module over the input along the 0-th (time) dimension
        and accumulate the outputs in an output tensor.

        Parameters:
            x : torch.Tensor or tuple[torch.Tensor, torch.Tensor]
                Either the input of shape (T, B, ...) or a tuple
                ``(input, state)``.

        Returns:
            The outputs stacked along the time dimension, or
            ``(outputs, final_state)`` when the wrapped module is stateful and
            ``return_state`` was set at construction.

        Note:
            If the input is a tuple, its second entry is used as the initial
            state of a stateful module. For a non-stateful module it is ignored.
        """
        state = None
        if isinstance(x, tuple):
            x, state = x

        if self.is_stateful:
            return stateful_forward(self.lifted_module, x, state, self.return_state)
        # Non-stateful modules never return a state, so return_state is not forwarded.
        return nonstateful_forward(self.lifted_module, x)

def stateful_forward(
    model: nn.Module, x: torch.Tensor, state: torch.Tensor, return_state: bool = False
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """Apply the module over the input along the 0-th (time) dimension
    and accumulate the outputs in an output tensor.

    Parameters:
        x : torch.Tensor
            - Shape (T, B, ...)
        state : torch.Tensor
    """
    T = x.shape[0]
    outputs = []
    for ts in range(T):
        out, state = model(x[ts], state)
        outputs.append(out)

    outputs = torch.stack(outputs)
    if return_state:
        return outputs, state
    return outputs

def nonstateful_forward(
    model: nn.Module, x: torch.Tensor
) -> torch.Tensor:
    """Run a stateless module over every time step in a single call.

    The time and batch dimensions are merged into one leading dimension, so
    the module sees a single large batch. Its result is then split back into
    separate time and batch dimensions.

    Parameters:
        model : nn.Module
            Called once on a tensor of shape (T * B, ...). Its output must keep
            T * B as the leading dimension.
        x : torch.Tensor
            - Shape (T, B, ...)

    Returns:
        A contiguous tensor of shape (T, B, ...), where the trailing dimensions
        are those produced by ``model``.
    """
    n_steps, n_batch, *in_dims = x.shape
    result: torch.Tensor = model(x.reshape(n_steps * n_batch, *in_dims))
    # Drop the merged leading dimension; keep whatever trailing shape the model produced.
    _, *out_dims = result.shape
    return result.view(n_steps, n_batch, *out_dims).contiguous()
