import math

from re import U
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from einops import rearrange, repeat

from src.models.ssm.ssm_kernel import SSKernel
from src.models.ssm.components import LinearActivation, Activation

try:
    from src.ops.fftconv import fftconv_func
except ImportError:
    fftconv_func = None

try:
    from src.ops.fused_dense import FusedDenseTD
except ImportError:
    FusedDenseTD = None


class S4(nn.Module):
    """S4 sequence layer: long convolution with an SSKernel-generated kernel plus pointwise parts.

    Maps a (B, L, d_model) input to a (B, L, d_model) output. Besides the plain layer it supports
    several experimental variants: input/output gating, a bottleneck projection, MIMO kernels,
    softmax mixing over kernel channels, short depthwise convolutions on the SSM input/output,
    and an optional fused fftconv op.
    """

    def __init__(
            self,
            d_model,
            d_state=64,
            l_max=None,
            channels=1,
            bidirectional=False,
            # Arguments for position-wise feedforward components
            activation='gelu',
            postact='glu',
            initializer=None,
            dropout=0.0, dropout_cls=nn.Dropout1d,
            bottleneck=None,
            gate=None,
            linear=False,
            use_fast_fftconv=False,
            add_uy=False,
            num_linear=1,
            conv_u=False,
            conv_y=False,
            conv_kernel_size=3,
            mimo=False,
            mimo_D = False,
            mimo_channels=2,
            y_bias=False,
            softmax_channels=False,
            tie_softmax_proj=False,
            tied_softmax_size=1,
            # SSM Kernel arguments
            **kernel_args,
        ):
        """
        d_model: input/output feature dimension
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided

        Position-wise feedforward components:
        --------------------
        activation: activation in between SS and FF
        postact: activation after FF
        initializer: initializer on FF
        dropout: standard dropout argument.
        dropout_cls: dropout module class applied after the activation

        Other arguments:
        --------------------
        gate: add gated activation (GSS); expansion factor of the input/output gate projections
        bottleneck: reduce SSM dimension (GSS); the SSM runs on d_model // bottleneck features
        linear: Remove pointwise components so that the entire module is a linear SSM
        use_fast_fftconv: use the fused fftconv op (requires: not linear, channels == 1, no MIMO,
            gelu activation, nn.Dropout1d)
        add_uy: add a multiplicative term y * (D_uy * u) to the SSM output (not supported with MIMO)
        num_linear: number of stacked output projections
        conv_u / conv_y: apply a depthwise causal Conv1d of width conv_kernel_size to the SSM
            input / output
        mimo, mimo_D, mimo_channels: split the SSM features into groups of mimo_channels (see the
            mimo option of SSKernel); mimo_D makes the skip term D a per-group matrix
        y_bias: stored but not used by this module
        softmax_channels, tie_softmax_proj, tied_softmax_size: combine the C kernel channels with
            input-dependent softmax weights instead of concatenating them

        See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """

        super().__init__()
        self.d_model = d_model
        self.H = d_model
        self.N = d_state
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.linear = linear
        self.use_fast_fftconv = use_fast_fftconv
        if self.use_fast_fftconv:
            assert not linear
            assert channels == 1
            assert not mimo
            assert activation == 'gelu'
            assert dropout_cls is nn.Dropout1d
            assert fftconv_func is not None, 'Need to install fftconv'
        self.add_uy = add_uy
        self.num_linear = num_linear
        self.conv_u = conv_u
        self.conv_y = conv_y
        self.conv_kernel_size = conv_kernel_size
        self.mimo = mimo
        self.mimo_D = mimo_D
        self.mimo_channels = mimo_channels
        self.y_bias = y_bias
        self.softmax_channels = softmax_channels
        self.tie_softmax_proj = tie_softmax_proj
        self.tied_softmax_size = tied_softmax_size

        self.gate = gate
        self.bottleneck = bottleneck

        if bottleneck is not None:
            self.H = self.H // bottleneck
            self.input_linear = LinearActivation(
                self.d_model,
                self.H,
                initializer=initializer,
                # TODO: do we want activation here? GSS paper has gelu activation but from Anthropic's blogpost
                # we just want input projection to be an nn.Linear
                # activation=activation,
                activation=None,
                activate=True,
            )

        if self.mimo:
            # The kernel operates on H / mimo_channels groups of mimo_channels features each
            self.H = self.H // self.mimo_channels

        if gate is not None:
            self.input_gate = LinearActivation(
                self.d_model,
                self.d_model * gate,
                initializer=initializer,
                # activation=activation,
                activation=None,
                activate=True,
            )
            self.output_gate = LinearActivation(
                self.d_model * gate,
                self.d_model,
                initializer=initializer,
                activation=None,
                activate=False,
            )

        if self.mimo and self.mimo_D:
            self.D = nn.Parameter(torch.randn(channels, self.H, self.mimo_channels, self.mimo_channels))
        else:
            self.D = nn.Parameter(torch.randn(channels, self.H))

        if self.add_uy:
            self.D_uy = nn.Parameter(torch.randn(channels, self.H))

        # A bidirectional kernel stores a forward and a backward half as separate kernel channels;
        # only the kernel sees the doubled count (D above and self.channels keep the original one).
        if self.bidirectional:
            channels *= 2

        if self.conv_u:
            self.ConvU = nn.Conv1d(self.H, self.H, self.conv_kernel_size, padding=self.conv_kernel_size - 1, groups=self.H)
        if self.conv_y:
            self.ConvY = nn.Conv1d(self.H, self.H, self.conv_kernel_size, padding=self.conv_kernel_size - 1, groups=self.H)

        # SSM Kernel
        self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=channels, mimo=mimo, mimo_channels=mimo_channels, **kernel_args)

        if self.softmax_channels:
            in_dim = self.H
            out_dim = self.H * self.channels
            if self.tie_softmax_proj:
                in_dim = tied_softmax_size
                out_dim = self.channels * tied_softmax_size
            self.u_C_proj = LinearActivation(
                in_dim, out_dim,
                initializer=initializer,
                activation=postact,
                activate=True,
            )
            self.C_softmax = nn.Softmax(dim=-1)

        # Pointwise
        if not self.linear:
            self.activation = Activation(activation)
            self.dropout = dropout_cls(dropout)
        # position-wise output transform to mix features
        if not self.linear and gate is None:
            num_channels = self.channels if not self.softmax_channels else 1
            d_in = self.H * num_channels * (self.mimo_channels if self.mimo else 1)
            d_out = self.d_model * (1 if self.gate is None else self.gate)
            if self.num_linear == 1:
                self.output_linear = LinearActivation(
                    d_in,
                    d_out,
                    initializer=initializer,
                    activation=postact,
                    activate=True,
                )
            else:
                # Only the first layer consumes the flattened SSM features; every later layer
                # receives the previous layer's d_out features, so the stack composes even when
                # d_in != d_out.
                self.output_linear = nn.Sequential(*[
                    LinearActivation(
                        d_in if i == 0 else d_out,
                        d_out,
                        initializer=initializer,
                        activation=postact,
                        activate=True,
                    )
                    for i in range(self.num_linear)
                ])

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        u: (B L H)
        state: (H N) never needed unless you know what you're doing
        rate: forwarded to the kernel; the kernel length is min(L, round(l_max / rate)) when l_max is set
        lengths: optional int or 1-D LongTensor of valid lengths (assumes suffix padding);
            positions at or beyond the length are zeroed before the convolution

        Returns: same shape as u
        """
        L = u.size(-2)
        if self.gate is not None:
            v = self.input_gate(u)
        if self.bottleneck is not None:
            u = self.input_linear(u)

        # these have to go before the L and H are switched
        if self.softmax_channels:
            if self.tie_softmax_proj:
                u_C_projection = self.u_C_proj(rearrange(u, 'b l (h s) -> b l h s', s=self.tied_softmax_size))
                u_C_projection = rearrange(u_C_projection, 'b l h (s c) -> b l (h s) c', c=self.channels)
                u_C_projection = self.C_softmax(u_C_projection)
            else:
                u_C_projection = self.u_C_proj(u)
                u_C_projection = rearrange(u_C_projection, 'b l (h c) -> b l h c', c=self.channels)
                u_C_projection = self.C_softmax(u_C_projection)
        if self.mimo:
            u = rearrange(u, 'b l (h m) -> b h m l', m=self.mimo_channels)
        else:
            u = rearrange(u, 'b l h -> b h l')
        if self.conv_u:
            # Conv1d pads k-1 on both sides; keeping the first L outputs makes it causal
            u = self.ConvU(u)[..., :L]

        # Mask out padding tokens
        # TODO handle option for mask - instead of lengths, which assumes suffix padding
        if isinstance(lengths, int):
            if lengths != L:
                # Must be 1-D: a 0-dim tensor would fail the shape check below
                lengths = torch.tensor([lengths], dtype=torch.long, device=u.device)
            else:
                lengths = None
        if lengths is not None:
            assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)]
            # Reshape to (B, 1, ..., 1) with one singleton per non-batch dim of u, so the resulting
            # mask broadcasts against both (B H L) and MIMO (B H M L) layouts.
            lengths_b = lengths.reshape(-1, *([1] * (u.dim() - 1)))
            mask = torch.where(torch.arange(L, device=lengths.device) < lengths_b, 1., 0.)
            u = u * mask
        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        k, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)

        # Convolution
        if self.bidirectional:
            # Forward half padded on the right, time-reversed backward half padded on the left
            k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
            k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))

        if not self.use_fast_fftconv:
            # FFT size L_kernel + L avoids circular wrap-around of the linear convolution
            k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L) or (C H M M L)
            u_f = torch.fft.rfft(u.to(k.dtype), n=L_kernel+L) # (B H L) or (B H M L)
            if self.mimo:
                y_f = torch.einsum('bhnl,chmnl->bchml', u_f, k_f)
            else:
                y_f = torch.einsum('bhl,chl->bchl', u_f, k_f)
            y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L) or (B C H M L)

            if self.mimo:
                if self.mimo_D:
                    y = y + torch.einsum('bhml,chnm->bchnl', u, self.D)
                else:
                    y = y + torch.einsum('bhml,ch->bchml', u, self.D)
            else:
                # Compute D term in state space equation - essentially a skip connection
                y = y + torch.einsum('bhl,ch->bchl', u, self.D)

            if self.add_uy:
                assert not self.mimo
                uy = y * torch.einsum('bhl,ch->bchl', u, self.D_uy)
                y = y + uy

            # Compute state update
            if state is not None:
                assert not self.bidirectional, "Bidirectional not supported with state forwarding"
                y = y + k_state
                # NOTE: next_state is computed but not returned (forward returns only y)
                next_state = self.kernel.forward_state(u, state)
            else:
                next_state = None

            # if using c C's, project U to r x L, softmax over r, and project channels down to one channel
            if self.softmax_channels:
                if self.mimo:
                    y = torch.einsum('bchml,blhc->bhml', y, u_C_projection).unsqueeze(1)
                else:
                    y = torch.einsum('bchl,blhc->bhl', y, u_C_projection).unsqueeze(1)

            # Reshape to flatten channels
            if self.mimo:
                y = rearrange(y, '... c h m l -> ... (c h m) l')
            else:
                y = rearrange(y, '... c h l -> ... (c h) l')

            if not self.linear:
                y = self.dropout(self.activation(y))
        else:
            assert not self.mimo, "MIMO not supported with fast FFTConv"
            assert state is None, "fast fftconv doesn't support state yet"
            batch_size, H, L = u.shape
            dropout_mask = (None if not self.training
                            else F.dropout(torch.ones(batch_size, H, device=u.device), self.dropout.p))
            y = fftconv_func(u, rearrange(k, '1 h l -> h l'), rearrange(self.D, '1 h -> h'),
                             # Only set output_hbl_format if we're not gating
                             dropout_mask, True, torch.is_autocast_enabled(), self.gate is None)
            next_state = None

        if self.conv_y:
            y = self.ConvY(y)[..., :L]

        y = rearrange(y, 'b h l -> b l h')

        if self.gate is not None:
            # Do v * y instead of y * v to preserve the strides of v (contiguous in the H dimension)
            y = self.output_gate(v * y)
        else:
            if not self.linear:
                y = self.output_linear(y)

        # return y, next_state
        # We don't support returning a tuple for now (to be compatible with attention)
        return y

    def setup_step(self, **kwargs):
        """Prepare the kernel for recurrent stepping (delegates to SSKernel._setup_step)."""
        self.kernel._setup_step(**kwargs)

    def step(self, u, state):
        """ Step one time step as a recurrent model. Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)

        Mirrors the non-gated path of forward: the pointwise activation and output projection
        are skipped when the module was built with linear=True.
        NOTE(review): gating, bottleneck, MIMO and softmax_channels are not handled here.
        """
        assert not self.training

        y, next_state = self.kernel.step(u, state) # (B C H)
        y = y + u.unsqueeze(-2) * self.D
        y = rearrange(y, 'b c h -> b (c h)')
        if not self.linear:
            y = self.activation(y)
            # `transposed` is never set in __init__; default to the (B, H) feature-last layout
            # that forward uses for output_linear.
            if getattr(self, 'transposed', False):
                y = self.output_linear(y.unsqueeze(-1)).squeeze(-1)
            else:
                y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        # kernel is not a SequenceModule so it doesn't need to adhere to same interface
        # the kernel will know the device of its own parameters
        return self.kernel.default_state(*batch_shape)

    @property
    def d_state(self):
        """Flattened state size H * N."""
        return self.H * self.N

    @property
    def d_output(self):
        """Output feature dimension (equal to d_model)."""
        return self.d_model

    @property
    def state_to_tensor(self):
        """Callable that flattens the trailing (h, n) state dims into one dim of size h * n."""
        return lambda state: rearrange(state, '... h n -> ... (h n)')


class S4DoubleGate(nn.Module):
    """Doubly gated SSM layer operating on (B, L, d_model) inputs.

    The input is projected to q, k, v. q and k go through short depthwise causal convolutions,
    or, with comp_k, k goes through a 'shift'-mode SSKernel. Then:
      1. k is activated (and optionally layer-normalized).
      2. The outer product of k and v over head_dim groups is convolved with the main SSKernel,
         plus a D skip term.
      3. The result is contracted with the activated q over the first head-dim axis.
      4. The output goes through a final linear map.
    """

    def __init__(
            self,
            d_model,
            d_state=64,
            l_max=None,
            head_dim=1,
            # Arguments for position-wise feedforward components
            activation='gelu', # not used
            dropout=0.0, dropout_cls=nn.Dropout1d,
            comp_k=False,
            comp_k_dstate=64,
            use_fast_fftconv=False,
            fused_bias_fc=False,
            k_activation='gelu',
            q_activation='gelu',
            layer_norm=False,
            # SSM Kernel arguments
            **kernel_args,
        ):
        """
        d_model: input/output feature dimension; must be divisible by head_dim
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        head_dim: number of head-dim groups; the main SSM runs on d_model // head_dim features

        Position-wise feedforward components:
        --------------------
        activation: stored as a module but not applied in forward
        dropout, dropout_cls: stored as a module but not applied in forward
            (dropout must be 0.0 with use_fast_fftconv)

        Other arguments:
        --------------------
        comp_k: filter k with an SSKernel in 'shift' mode (lr=0.0) plus a D skip, instead of
            the width-3 depthwise convolutions on q and k
        comp_k_dstate: state dimension of the comp_k kernel
        use_fast_fftconv: use the fused fftconv op instead of torch.fft
        fused_bias_fc: use FusedDenseTD for the input projection (requires fused_dense)
        k_activation, q_activation: one of 'gelu', 'sigmoid', 'relu', 'identity'
        layer_norm: apply LayerNorm(d_model) to k after its activation

        See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """

        super().__init__()
        self.d_model = d_model
        self.head_dim = head_dim
        assert d_model % head_dim == 0
        self.H = d_model // head_dim
        self.N = d_state
        self.L = l_max
        self.comp_k = comp_k
        self.use_fast_fftconv = use_fast_fftconv
        if self.use_fast_fftconv:
            assert dropout_cls is nn.Dropout1d
            assert dropout == 0.0
            assert fftconv_func is not None, 'Need to install fftconv'

        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD

        # Single projection producing q, k and v stacked along the feature dim
        self.input_projection = linear_cls(self.d_model, self.d_model * 3)

        if self.comp_k:
            self.comp_k_kernel = SSKernel(self.d_model, N=comp_k_dstate, L=self.L, channels=1,
                                          mode='shift', lr=0.0)
            self.comp_k_D = nn.Parameter(torch.randn(self.d_model))
        else:
            self.ConvQ = nn.Conv1d(self.d_model, self.d_model, 3, padding=2, groups=self.d_model)
            self.ConvK = nn.Conv1d(self.d_model, self.d_model, 3, padding=2, groups=self.d_model)

        # SSM Kernel
        self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=1, **kernel_args)
        self.D = nn.Parameter(torch.randn(self.H))

        self.layer_norm = layer_norm
        if self.layer_norm:
            self.ln = nn.LayerNorm(self.d_model)

        # Pointwise
        self.activation = Activation(activation)
        self.dropout = dropout_cls(dropout)
        # position-wise output transform to mix features
        # Don't use FusedDense since the layout is H first
        self.output_linear = nn.Linear(self.d_model, self.d_model)

        # torch.sigmoid instead of the deprecated F.sigmoid (same computation)
        activation_map = {
            'gelu': F.gelu,
            'sigmoid': torch.sigmoid,
            'relu': F.relu,
            'identity': nn.Identity(),
        }

        self.k_activation = activation_map[k_activation]
        self.q_activation = activation_map[q_activation]


    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        u: (B L H)
        state: must be None (state passing is not supported)

        Returns: same shape as u
        """
        assert state is None
        L = u.size(-2)

        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        ssm_kernel, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
        ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l')

        q, k, v = self.input_projection(u).chunk(3, dim=-1)
        q, k, v = [rearrange(x, 'b l h -> b h l') for x in [q, k, v]]

        if not self.comp_k:
            # Conv1d pads 2 on both sides; keeping the first L outputs makes it causal
            q = self.ConvQ(q)[..., :L]
            k = self.ConvK(k)[..., :L]
        else:
            comp_k_kernel, _ = self.comp_k_kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
            comp_k_kernel = rearrange(comp_k_kernel, '1 h l -> h l')
            if not self.use_fast_fftconv:
                fft_size = L_kernel + L
                comp_k_kernel_f = torch.fft.rfft(comp_k_kernel, n=fft_size) # (H 2L)
                k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L)
                comp_k_out = torch.fft.irfft(comp_k_kernel_f * k_f, n=fft_size)[..., :L]
                k = comp_k_out + rearrange(self.comp_k_D, 'h -> h 1') * k
            else:
                dropout_mask = None
                # No GeLU after the SSM
                # NOTE(review): this call passes 5 positional args while the other fftconv_func
                # calls in this file pass 7+; confirm it matches fftconv_func's signature.
                k = fftconv_func(k, comp_k_kernel, self.comp_k_D, dropout_mask, False)

        k = self.k_activation(k)
        if self.layer_norm:
            k = rearrange(self.ln(rearrange(k, 'b h l -> b l h')), 'b l h -> b h l')
        # Outer product of k and v over the head_dim groups: (b d1 d2 h l), folded into the batch
        kv = (rearrange(k, 'b (d1 h) l -> b d1 1 h l', d1=self.head_dim)
              * rearrange(v, 'b (d2 h) l -> b 1 d2 h l', d2=self.head_dim))
        kv = rearrange(kv, 'b d1 d2 h l -> (b d1 d2) h l')

        if not self.use_fast_fftconv:
            fft_size = L_kernel + L
            ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # (H 2L)
            kv_f = torch.fft.rfft(kv.to(ssm_kernel.dtype), n=fft_size) # (B H 2L)
            numerator = torch.fft.irfft(ssm_kernel_f * kv_f, n=fft_size)[..., :L]
            y = numerator + rearrange(self.D, 'h -> h 1') * kv
        else:
            dropout_mask = None
            # No GeLU after the SSM
            y = fftconv_func(kv, ssm_kernel, self.D,
                             # should we set output_hbl_layout = True?
                             dropout_mask, False, torch.is_autocast_enabled(), False)

        q = rearrange(self.q_activation(q), 'b (d1 h) l -> b d1 1 h l', d1=self.head_dim)
        y = rearrange(y, '(b d1 d2) h l -> b d1 d2 h l', d1=self.head_dim, d2=self.head_dim)

        # Contract over d1: (b d1 1 h l) * (b d1 d2 h l) summed over dim 1 -> (b d2 h l).
        # einsum is way slower than multiply and then sum; the JIT-scripted helper is a bit faster.
        y = mul_sum(q, y)
        y = rearrange(y, 'b d h l -> b l (d h)')

        y = self.output_linear(y)

        # We don't support returning a tuple for now (to be compatible with attention)
        return y


class S4GSS(nn.Module):
    """Gated State Space (GSS) style block built around a single-channel SSKernel.

    Two branches are computed from the same (B, L, d_model) input:
      - dense branch: GELU(Linear(d_model -> f_dim))
      - SSM branch: LayerNorm(GELU(Linear(d_model -> ssm_dim))), long-convolved with the
        SSKernel plus a D skip term, then projected ssm_dim -> f_dim.
    The branches are multiplied elementwise and mapped back to d_model.
    """

    def __init__(
            self,
            d_model,
            d_state=512,
            l_max=None,
            f_dim=4096,
            ssm_dim=None,
            use_fast_fftconv=False,
            dropout=0.0,  # Just to absorb the kwarg
            **kernel_args,
        ):
        """
        d_model: input/output feature dimension
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        f_dim: width of the dense (gating) branch
        ssm_dim: width of the SSM branch; a falsy value means d_model // 4
        use_fast_fftconv: use the fused fftconv op instead of torch.fft
        dropout: ignored

        See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"
        """
        super().__init__()
        self.d_model = d_model
        if not ssm_dim:
            ssm_dim = self.d_model // 4
        self.ssm_dim = ssm_dim
        self.H = ssm_dim
        self.N = d_state
        self.L = l_max
        self.f_dim = f_dim
        self.use_fast_fftconv = use_fast_fftconv
        if use_fast_fftconv:
            assert fftconv_func is not None, 'Need to install fftconv'

        # Submodules are created in a fixed order so parameter init consumes the RNG identically.
        dense = nn.Linear
        self.input_projection_dense = dense(self.d_model, f_dim)
        self.input_projection_ssm = dense(self.d_model, ssm_dim)
        self.post_ssm_projection = dense(ssm_dim, f_dim)

        self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=1, **kernel_args)
        self.D = nn.Parameter(torch.randn(self.H))
        self.norm = nn.LayerNorm(self.ssm_dim)

        # Mixes the gated f_dim features back to d_model (kept as plain nn.Linear)
        self.output_linear = nn.Linear(f_dim, self.d_model)

        self.gelu = F.gelu


    def forward(self, x, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        x: (B L d_model)
        state: must be None (state passing is not supported)
        rate: forwarded to the kernel; kernel length is min(L, round(l_max / rate)) when l_max is set
        lengths: ignored

        Returns: (B L d_model)
        """
        assert state is None

        u = self.norm(self.gelu(self.input_projection_ssm(x)))
        v = self.gelu(self.input_projection_dense(x))
        seq_len = u.size(-2)

        if self.L is None:
            kernel_len = seq_len
        else:
            kernel_len = min(seq_len, round(self.L / rate))
        ssm_kernel, _ = self.kernel(L=kernel_len, rate=rate, state=state)  # (1 H L)
        ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l')

        u = rearrange(u, 'b l h -> b h l')

        if self.use_fast_fftconv:
            # No GeLU after the SSM; no dropout mask
            y = fftconv_func(u, ssm_kernel, self.D,
                             None, False, torch.is_autocast_enabled(), output_hbl_layout=True)
        else:
            # Zero-padded FFT of size kernel_len + seq_len gives a linear (non-circular) convolution
            n_fft = kernel_len + seq_len
            kernel_spec = torch.fft.rfft(ssm_kernel, n=n_fft)
            u_spec = torch.fft.rfft(u.to(ssm_kernel.dtype), n=n_fft)
            conv_out = torch.fft.irfft(kernel_spec * u_spec, n=n_fft)[..., :seq_len]
            y = conv_out + self.D.unsqueeze(-1) * u

        gated = self.post_ssm_projection(rearrange(y, 'b h l -> b l h')) * v
        # Only y is returned (no state) to stay call-compatible with attention layers
        return self.output_linear(gated)

@torch.jit.script
def mul_sum(q, y):
    """Return the broadcasted elementwise product of q and y, reduced by summation over dim 1.

    TorchScript-compiled; used as a faster stand-in for the equivalent einsum contraction.
    """
    prod = q * y
    return torch.sum(prod, dim=1)


class H3(nn.Module):
    """H3 sequence layer on (B, L, d_model) inputs.

    Pipeline:
      1. Project the input to q, k, v.
      2. Filter k with either a 'shift'-mode SSKernel plus a D skip (ssm_k=True), or width-3
         depthwise causal convolutions on both q and k.
      3. Convolve the outer product of k and v over head_dim groups with the main SSKernel,
         plus a D skip term.
      4. Contract the result with q and apply the output projection.
    """

    def __init__(
            self,
            d_model,
            d_state=64,
            l_max=None,
            head_dim=1,
            ssm_k=False,
            ssm_k_dstate=64,
            bidirectional=False,
            use_fast_fftconv=False,
            fused_bias_fc=False,
            dropout=0.0,   # Just to absorb the kwarg
            # SSM Kernel arguments
            **kernel_args,
        ):
        """
        d_model: input/output feature dimension; must be divisible by head_dim
        d_state: the dimension of the state, also denoted by N
        l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel
        head_dim: number of head-dim groups; the main SSM runs on d_model // head_dim features
        ssm_k: filter k with a second SSKernel in 'shift' mode (with a D skip) instead of the
            width-3 depthwise convolutions on q and k
        ssm_k_dstate: state dimension of the ssm_k kernel
        bidirectional: give both kernels a second (reverse) channel; its spectrum is conjugated
            (time-reversed) and added to the forward one
        use_fast_fftconv: use the fused fftconv op instead of torch.fft
        fused_bias_fc: only checked for availability of fused_dense; projections use nn.Linear
        dropout: ignored

        See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr"

        Other options are all experimental and should not need to be configured
        """

        super().__init__()
        self.d_model = d_model
        self.head_dim = head_dim
        assert d_model % head_dim == 0
        self.H = d_model // head_dim
        self.N = d_state
        self.L = l_max
        self.ssm_k = ssm_k
        self.bidirectional = bidirectional
        nchannels = 1 if not self.bidirectional else 2
        self.use_fast_fftconv = use_fast_fftconv
        if self.use_fast_fftconv:
            assert fftconv_func is not None, 'Need to install fftconv'

        # FusedDenseTD is currently not used for the projections, but the availability check is
        # kept so configs requesting it keep failing loudly when fused_dense is missing.
        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear

        self.q_proj = linear_cls(self.d_model, self.d_model)
        self.k_proj = linear_cls(self.d_model, self.d_model)
        self.v_proj = linear_cls(self.d_model, self.d_model)

        if self.ssm_k:
            self.ssm_k_kernel = SSKernel(self.d_model, N=ssm_k_dstate, L=self.L, channels=nchannels,
                                         mode='shift', lr=kernel_args.get('lr', None))
            self.ssm_k_D = nn.Parameter(torch.randn(self.d_model))
        else:
            self.ConvQ = nn.Conv1d(self.d_model, self.d_model, 3, padding=2, groups=self.d_model)
            self.ConvK = nn.Conv1d(self.d_model, self.d_model, 3, padding=2, groups=self.d_model)

        # SSM Kernel
        self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=nchannels, **kernel_args)
        self.D = nn.Parameter(torch.randn(self.H))

        # Pointwise
        # position-wise output transform to mix features
        # Don't use FusedDense since the layout is H first
        self.output_linear = nn.Linear(self.d_model, self.d_model)

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        u: (B L H)
        state: must be None (state passing is not supported)

        Returns: same shape as u
        """
        assert state is None
        L_og = u.size(-2)
        # The fused fftconv path requires an even sequence length: pad one step and crop at the end
        if self.use_fast_fftconv and L_og % 2 != 0:
            u = F.pad(u, (0, 0, 0, 1))
        L = u.size(-2)

        # Compute SS Kernel
        L_kernel = L if self.L is None else min(L, round(self.L / rate))
        ssm_kernel, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
        if not self.bidirectional:
            ssm_kernel, ssm_kernel_rev = rearrange(ssm_kernel, '1 h l -> h l'), None
        else:
            assert ssm_kernel.shape[0] == 2
            ssm_kernel, ssm_kernel_rev = ssm_kernel.unbind(0)

        u = rearrange(u, 'b l h -> (b l) h')
        # We want q, k, v to be in fp16/bf16 if running under AMP. What's the right way to get the
        # dtype when running under AMP?
        q = self.q_proj.weight @ u.T
        q = q + self.q_proj.bias.to(q.dtype).unsqueeze(-1)
        k = self.k_proj.weight @ u.T + self.k_proj.bias.to(q.dtype).unsqueeze(-1)
        v = self.v_proj.weight @ u.T + self.v_proj.bias.to(q.dtype).unsqueeze(-1)
        q, k, v = [rearrange(x, 'h (b l) -> b h l', l=L) for x in [q, k, v]]

        if not self.ssm_k:
            # Conv1d pads 2 on both sides; keeping the first L outputs makes it causal
            q = self.ConvQ(q)[..., :L]
            k = self.ConvK(k)[..., :L]
        else:
            ssm_k_kernel, _ = self.ssm_k_kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
            if not self.bidirectional:
                ssm_k_kernel, ssm_k_kernel_rev = rearrange(ssm_k_kernel, '1 h l -> h l'), None
            else:
                assert ssm_k_kernel.shape[0] == 2
                ssm_k_kernel, ssm_k_kernel_rev = ssm_k_kernel.unbind(0)
            if not self.use_fast_fftconv:
                fft_size = L_kernel + L
                ssm_k_kernel_f = torch.fft.rfft(ssm_k_kernel, n=fft_size) # (H 2L)
                if ssm_k_kernel_rev is not None:
                    ssm_k_kernel_rev_f = torch.fft.rfft(ssm_k_kernel_rev, n=fft_size) # (H 2L)
                    ssm_k_kernel_f = ssm_k_kernel_f + ssm_k_kernel_rev_f.conj()
                k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L)
                comp_k_out = torch.fft.irfft(ssm_k_kernel_f * k_f, n=fft_size)[..., :L]
                k = comp_k_out + rearrange(self.ssm_k_D, 'h -> h 1') * k
            else:
                dropout_mask = None
                # No GeLU after the SSM
                # We want output_hbl=True so that k has the same layout as q and v for the next
                # fftconv
                k = fftconv_func(k, ssm_k_kernel, self.ssm_k_D, dropout_mask, False,
                                 False, True, k_rev=ssm_k_kernel_rev)
                # This line below looks like it doesn't do anything, but it gets the stride right
                # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has
                # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but
                # the C++ code doesn't like that.
                k = rearrange(rearrange(k, 'b h l -> h b l'), 'h b l -> b h l')

        if not self.use_fast_fftconv:
            fft_size = L_kernel + L
            # kv = k * v (outer product over head_dim groups)
            kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim)
                    * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim))  # b d1 d2 h l
            # The 1/fft_size scaling here pairs with irfft(norm='forward') below
            kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size
            ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size)  # h L+1
            if ssm_kernel_rev is not None:
                ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size)  # h L+1
                ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj()
            y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :L]  # b d1 d2 h l
            y = y + kv * self.D.unsqueeze(-1)  # b d1 d2 h l
            q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim)
            # einsum is way slower than multiply and then sum.
            if self.head_dim > 1:
                y = mul_sum(y, q)
                y = rearrange(y, 'b d h l -> b (d h) l')
            else:
                y = rearrange(y * q, 'b 1 1 h l -> b h l')
        else:
            dropout_mask = None
            # No GeLU after the SSM
            # Set output_hbl_layout=True since we'll be doing a matmul right after
            y = fftconv_func(k, ssm_kernel, self.D,
                             dropout_mask, False, torch.is_autocast_enabled(), True,
                             v, self.head_dim, q, k_rev=ssm_kernel_rev)

        y = rearrange(y, 'b h l -> b l h')

        y = self.output_linear(y)
        # Drop the padding step added for the fast fftconv path
        if L_og < L:
            y = y[:, :L_og, :]

        # We don't support returning a tuple for now (to be compatible with attention)
        return y

    def setup_step(self, **kwargs):
        """Prepare the main kernel for recurrent stepping (the ssm_k kernel is not set up)."""
        self.kernel._setup_step(**kwargs)

    def step(self, u, state):
        """ Step one time step as a recurrent model. Intended to be used during validation.

        u: (B H)
        state: (B H N)
        Returns: output (B H), state (B H N)

        NOTE(review): the k-filter step calls self.kernel.step (not self.ssm_k_kernel.step) and
        both steps consume the same incoming `state`; the first next_state is discarded. This
        does not match forward's two-kernel computation — confirm before relying on it.
        """
        assert not self.training
        assert self.ssm_k

        q = self.q_proj(u)
        k = self.k_proj(u)
        v = self.v_proj(u)
        kp, next_state = self.kernel.step(k, state) # (B C H)
        kp = rearrange(kp, 'b 1 h -> b h')
        kp = kp + k * self.ssm_k_D
        kv = kp * v
        y, next_state = self.kernel.step(kv, state)
        y = rearrange(y, 'b 1 h -> b h')
        y = y + kv * self.D
        y = y * q
        y = self.output_linear(y)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        # kernel is not a SequenceModule so it doesn't need to adhere to same interface
        # the kernel will know the device of its own parameters
        return self.kernel.default_state(*batch_shape)

    @property
    def d_state(self):
        """Flattened state size H * N."""
        return self.H * self.N

    @property
    def d_output(self):
        """Output feature dimension (equal to d_model)."""
        return self.d_model

    @property
    def state_to_tensor(self):
        """Callable that flattens the trailing (h, n) state dims into one dim of size h * n."""
        return lambda state: rearrange(state, '... h n -> ... (h n)')


def test_state(random_init=False, **kwargs):
    """Compare full-sequence, step-by-step and chunked evaluation of an S4 layer.

    Builds a tiny S4 layer and runs an all-ones input through it three ways:
    one forward pass, recurrent ``step`` calls, and forward passes over 4
    chunks. It prints each result and compares the chunked run against the
    full pass with ``utils.compare_outputs``.

    Relies on the module-level names ``device`` and ``utils``, which are bound
    in the ``__main__`` block.

    random_init: if True, start from a random initial state instead of the
        kernel's default state. When the last state axis has size ``N``, the
        random state is built as a conjugate-symmetric pair of halves.
    **kwargs: forwarded to the ``S4`` constructor.
    """
    batch, width, n_state, length = 2, 3, 4, 8
    layer = S4(width, d_state=n_state, l_max=length, **kwargs)
    layer.to(device)
    layer.eval()
    layer.setup_step()

    u = torch.ones(batch, width, length).to(device)
    initial_state = layer.default_state(batch)
    if random_init:
        if initial_state.size(-1) != n_state:
            initial_state = torch.randn_like(initial_state)
        else:
            half = torch.randn_like(initial_state[..., :n_state // 2])
            initial_state = torch.cat([half, half.conj()], dim=-1)

    # --- Full-sequence pass ---
    # NOTE(review): the visible tail of S4.forward returns only `y` (the tuple
    # return is commented out), so this two-value unpack acts on the output
    # tensor itself rather than yielding a state. Confirm.
    y, final_state = layer(u, state=initial_state.clone())
    print("output:\n", y, y.shape)
    print("final state:\n", final_state, final_state.shape)

    # --- Recurrent stepping ---
    layer.setup_step()
    state = initial_state.clone()
    step_outputs = []
    for u_t in torch.unbind(u, dim=-1):
        y_t, state = layer.step(u_t, state=state)
        step_outputs.append(y_t)
    step_outputs = torch.stack(step_outputs, dim=-1)
    print("step outputs:\n", step_outputs)
    print("step final state:\n", state)

    # --- Chunked evaluation ---
    chunks = 4
    state = initial_state.clone()
    chunk_outputs = []
    for u_chunk in u.chunk(chunks, dim=-1):
        y_chunk, state = layer(u_chunk, state=state)
        chunk_outputs.append(y_chunk)
    chunk_outputs = torch.cat(chunk_outputs, dim=-1)
    print("chunk outputs:\n", chunk_outputs)
    print("chunk final state:\n", state)
    print("chunk output error:")
    utils.compare_outputs(y, chunk_outputs)
    print("chunk final state error:")
    utils.compare_outputs(final_state, state)


if __name__ == '__main__':
    from benchmark import utils
    torch.manual_seed(42)

    # Module-level `device` is read by test_state; use torch.device('cpu')
    # to run without a GPU.
    device = torch.device('cuda')

    # Other configurations that have been exercised:
    #   test_state(random_init=True, mode='nplr', measure='legt', rank=2, channels=2)
    #   test_state(random_init=False, mode='diag', measure='legs', rank=1)
    test_state(random_init=True, mode='diag', measure='legs', rank=1, disc='zoh', channels=3)
