"""SSM convolution kernels.

SSKernelNPLR is the S4 kernel, implementing the 'normal plus low-rank' algorithm from the original S4 paper. This stores parameters A, B, C, dt, and calling it creates the SSM convolution kernel bar{K}.

A much simpler version SSKernelSlow is included for illustration purposes: it has the same output, but uses the naive algorithm which is much slower. This module is meant for testing and exposition, to understand what the State Space Kernel actually does.

SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A.

SSKernel wraps these with common options and handles the initialization.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from einops import rearrange, repeat
from opt_einsum import contract

from src.models.ssm import hippo
from src.models.ssm import dplr
from src.ops.krylov import krylov, power
from src.ops.companion_krylov import companion_krylov, companion_from_p

from src.utils.utils import get_logger

logger = get_logger(__name__)

from src.ops.vandermonde import log_vandermonde_naive

try:
    import pykeops
    from src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose

    has_pykeops = True
    logger.info("Pykeops installation found.")
except ImportError:
    has_pykeops = False
    from src.ops.vandermonde import log_vandermonde_naive as log_vandermonde
    from src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose
    logger.warning(
        "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency."
    )
from src.ops.vandermonde import log_vandermonde_fast


def _conj(x):
    """Concatenate `x` with its complex conjugate along the last dimension."""
    return torch.cat([x, x.conj()], dim=-1)


# Short aliases for switching between a complex tensor and its real (..., 2) view.
_c2r = torch.view_as_real
_r2c = torch.view_as_complex

# (major, minor) of the installed torch, used to pick the conjugation helper below.
_torch_major_minor = tuple(int(part) for part in torch.__version__.split('.')[:2])

if _torch_major_minor >= (1, 10):
    def _resolve_conj(x):
        """Return the conjugate of `x`, materialized via `resolve_conj()`."""
        return x.conj().resolve_conj()
else:
    def _resolve_conj(x):
        """Return the conjugate of `x` (torch < 1.10 path, plain `conj()`)."""
        return x.conj()


class OptimModule(nn.Module):
    """Module base class whose tensors can carry per-tensor optimizer hyperparameters."""

    def register(self, name, tensor, lr=None):
        """Attach `tensor` to the module under `name`.

        lr == 0.0 : the tensor is frozen and stored as a buffer.
        otherwise : the tensor becomes an nn.Parameter tagged with an `_optim`
                    dict holding weight_decay=0.0 and, when lr is given, "lr".
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        param = nn.Parameter(tensor)
        self.register_parameter(name, param)

        # The dict is only attached here; reading it is left to the optimizer setup code.
        hyperparams = {"weight_decay": 0.0}
        if lr is not None:
            hyperparams["lr"] = lr
        param._optim = hyperparams


class SSKernelSlow(OptimModule):
    """Slow version of SSKernel function for illustration and benchmarking.
    - Caches discretized matrices A^(dt), B^(dt)
    - Computes K_L(A^dt, B^dt, C)
    Usage:
    ```
    k, k_state = SSKernelSlow(A, B, C, log_dt, L=L)()
    ```
    Result is expected to be equal to SSKernelNPLR(L, w, P, B, C, log_dt, P)() if A = w - PP^*
    """

    @staticmethod
    def bilinear(dt, A, B=None, mimo=False):
        """Discretize (A, B) with step sizes dt using the bilinear transform.

        dt: (...) timescales
        A: (... N N)
        B: (... N); with mimo=True an extra channel axis is expected before N
           (presumably (... M N) -- TODO confirm against callers)

        Returns (dA, dB) where dA = (I - dt/2 A)^{-1} (I + dt/2 A) and
        dB = dt * (I - dt/2 A)^{-1} B; dB is None when B is None.
        """
        N = A.shape[-1]
        I = torch.eye(N).to(A)
        # Index with `...` so that dt may carry any number of leading batch dimensions,
        # matching both the documented shape and the dB computation below.
        A_backwards = I - dt[..., None, None] / 2 * A
        A_forwards = I + dt[..., None, None] / 2 * A

        if B is None:
            dB = None
        else:
            if mimo:
                # Insert a broadcast axis so the same A is solved against every input channel
                dB = dt[..., None, None] * torch.linalg.solve(
                    A_backwards.unsqueeze(1), B.unsqueeze(-1)
                ).squeeze(-1) # (... N)
            else:
                dB = dt[..., None] * torch.linalg.solve(
                    A_backwards, B.unsqueeze(-1)
                ).squeeze(-1) # (... N)

        dA = torch.linalg.solve(A_backwards, A_forwards)  # (... N N)
        return dA, dB


    def __init__(self, A, B, C, log_dt, L=None, lr=None, comp=False, mimo=False, mimo_channels=2):
        """
        A, B, C: complex tensors (they are stored through their real views)
        log_dt: (H) log step sizes
        L: default kernel length used when forward() is called without L
        lr: None, a float, or a dict with optional keys 'dt', 'A', 'B'; a value of 0.0 freezes
            the corresponding tensor (it is registered as a buffer)
        comp: if True, forward() builds a companion-style dA directly from A instead of discretizing
        mimo, mimo_channels: multi-input variant with `mimo_channels` channels
        """
        super().__init__()
        self.comp = comp
        self.L = L
        self.N = A.size(-1)
        self.H = log_dt.size(-1)
        self.mimo = mimo
        self.mimo_channels = mimo_channels

        if self.mimo:
            C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.mimo_channels, self.N))) # (C, H, M, N)
        else:
            C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N)

        # Register parameters
        if lr is None or isinstance(lr, float): lr_dict = {}
        else: lr_dict, lr = lr, None
        self.register("log_dt", log_dt, lr_dict.get('dt', lr))
        self.register("A", _c2r(A), lr_dict.get('A', lr))
        self.register("B", _c2r(B), lr_dict.get('B', lr))
        # C is stored as the real view of conj(C) and is always a trainable parameter.
        # This class shouldn't be used for anything other than testing and simple ablations.
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Cache if nothing is trained.
        # A learning rate of None means "use the optimizer default", i.e. the tensor is trained;
        # it must not be compared against 0.0 directly (None > 0.0 raises TypeError).
        def _is_trained(value):
            return value is None or value > 0.0
        self.trainable = any(_is_trained(lr_dict.get(key, lr)) for key in ('dt', 'A', 'B'))
        self.K = None # Compute in forward pass since that ensures correct device

    def forward(self, state=None, rate=1.0, L=None):
        """Compute the convolution kernel of length L (defaults to self.L).

        Only rate == 1.0 is supported.
        Returns (k, k_state); k_state is None when no state is given.
        """
        if L is None: L = self.L
        # This class shouldn't support the more advanced sampling and variable length functionalities, since it's just for testing
        # But the code from NPLR could be pasted here if desired
        assert rate == 1.0 and L is not None

        A = _r2c(self.A)

        if self.trainable or self.K is None:
            # Need to calculate dA, dB
            if self.comp:
                # Companion-style matrix: ones on the subdiagonal, L1-normalized A as the last column
                dA = A.new_zeros((self.H, self.N, self.N))
                dA[:, 1:, :-1] = torch.eye(self.N-1, dtype=A.dtype, device=A.device)
                A = A / torch.linalg.norm(A,ord=1,dim=-1,keepdims=True)
                dA[:, :, -1] = A
                dB = _r2c(self.B).expand((self.H, self.mimo_channels, self.N) if self.mimo else (self.H, self.N))
                # Keep only the real parts, but stay in complex dtype
                dA = dA.real + 0j
                dB = dB.real + 0j
            else:
                dA, dB = SSKernelSlow.bilinear(torch.exp(self.log_dt), A, _r2c(self.B), mimo=self.mimo)

        if self.trainable:
            k = krylov(L, dA, dB, _r2c(self.C), mimo=self.mimo)  # (H L) or (C H M L) if mimo
        else:
            if self.K is None:
                self.K = krylov(L, dA, dB, mimo=self.mimo) # (H N L)
            k = contract('hnl,chn->chl', self.K[..., :L], _r2c(self.C))
        k = k.float()

        if state is not None:
            # NOTE(review): this branch propagates the state with the continuous-time A
            # (not the discretized dA computed above) -- confirm this is intended.
            state = state.to(A)
            state = contract("... n m, ... m -> ... n", A, state)
            k_state = krylov(L, A, state.unsqueeze(-3), _r2c(self.C))
            k_state = k_state.float()
        else:
            k_state = None
        return k, k_state

    def default_state(self, *batch_shape):
        """Return a zero state of shape (*batch_shape, H, N)."""
        # self.C holds the real view of a complex tensor, so take the dtype from the complex
        # view: step() contracts the state with the complex dA / dB / dC.
        C = _r2c(self.C)
        state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
        return state

    def _setup_state(self):
        """Return the bilinear discretization (dA, dB) of the current parameters."""
        A, B = _r2c(self.A), _r2c(self.B)
        dA, dB = SSKernelSlow.bilinear(torch.exp(self.log_dt), A, B)
        return dA, dB

    def _setup_step(self):
        """Cache dA, dB, dC on the module for use by step()."""
        self.dA, self.dB = self._setup_state()
        self.dC = _r2c(self.C)

    def step(self, u, state):
        """Advance the recurrence by one step.

        u: (B H) input, state: (B H N) (shapes as used by the contractions below)
        Returns (y, next_state) with y: (B C H) real.
        """
        next_state = contract("h m n, b h n -> b h m", self.dA, state) \
                + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return y.real, next_state


class SSKernelDiag(OptimModule):
    """Version using (complex) diagonal state matrix (S4D)"""

    def __init__(
        self,
        A, B, C, log_dt,
        L=None,
        disc='bilinear',
        real_type='exp',
        lr=None,
        bandlimit=None,
        mimo=False,
        mimo_channels=2,
    ):
        """
        A: (n_ssm, N) complex diagonal of the state matrix
        B: (n_ssm, N) complex, or (n_ssm, M, N) when mimo
        C: complex, leading dimension is the number of channels, last dimension N
        log_dt: (H) log step sizes
        L: default kernel length used when forward() receives L=None
        disc: discretization, one of 'bilinear', 'zoh', 'dss'
        real_type: parameterization of the real part of A
                   ('none', 'exp', 'relu', 'sigmoid', 'softplus')
        lr: None, a float, or a dict with optional keys 'dt', 'A', 'B'
        bandlimit: if not None, C entries whose frequency is at or above bandlimit/2 are masked out
        """

        super().__init__()

        self.L = L
        self.disc = disc
        self.bandlimit = bandlimit
        self.real_type = real_type
        self.mimo = mimo
        self.mimo_channels = mimo_channels

        # Shape checks
        assert A.size(-1) == C.size(-1)
        self.H = log_dt.size(-1)
        self.N = A.size(-1)
        assert A.size(-2) == B.size(-3 if self.mimo else -2) # Number of independent SSMs trained
        assert self.H % A.size(-2) == 0
        self.n_ssm = A.size(-2)
        # Each of the n_ssm (A, B) pairs is shared by `repeat` of the H features
        self.repeat = self.H // A.size(-2)

        self.channels = C.shape[0]
        self.C = nn.Parameter(_c2r(_resolve_conj(C)))

        # Register parameters
        if lr is None or isinstance(lr, float): lr_dict = {}
        else: lr_dict, lr = lr, None

        self.register("log_dt", log_dt, lr_dict.get('dt', lr))
        self.register("B", _c2r(B), lr_dict.get('B', lr))
        # A is stored as (transformed real part, imaginary part); see _A_init / _A
        self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr))
        self.register("A_imag", A.imag, lr_dict.get('A', lr))

    def _A_init(self, A_real):
        """Map the real part of A to the stored parameter (inverse of the transform in _A)."""
        A_real = torch.clamp(A_real, max=-1e-4)
        if self.real_type == 'none':
            return -A_real
        elif self.real_type == 'exp':
            return torch.log(-A_real) # Some of the HiPPO methods have real part 0
        elif self.real_type == 'relu':
            return -A_real
        elif self.real_type == 'sigmoid':
            return torch.logit(-A_real)
        elif self.real_type == 'softplus':
            return torch.log(torch.exp(-A_real)-1)
        else: raise NotImplementedError

    def _A(self):
        """Rebuild the complex diagonal A from the stored (inv_A_real, A_imag) parameters."""
        # Get the internal A (diagonal) parameter
        if self.real_type == 'none':
            A_real = -self.inv_A_real
        elif self.real_type == 'exp':
            A_real = -torch.exp(self.inv_A_real)
        elif self.real_type == 'relu':
            # JAX version seems to NaN if you allow 0's, although this code was fine without it
            A_real = -F.relu(self.inv_A_real)-1e-4
        elif self.real_type == 'sigmoid':
            # torch.sigmoid instead of the deprecated F.sigmoid
            A_real = -torch.sigmoid(self.inv_A_real)
        elif self.real_type == 'softplus':
            A_real = -F.softplus(self.inv_A_real)
        else: raise NotImplementedError
        A = A_real + 1j * self.A_imag
        return A

    def forward(self, L, state=None, rate=1.0, u=None):
        """
        state: (B, H, N) initial state
        rate: sampling rate factor (None is treated as 1.0)
        L: target length (None falls back to self.L)
        u: unused

        returns:
        (C, H, L) convolution kernel (generally C=1)
        (B, C, H, L) output from initial state, or None if no state is given

        In mimo mode the return value is (kernel, None) where the kernel comes from the
        'c h m n l' einsum below.
        """
        # Same fallbacks as the other kernels, so wrappers may forward L=None / rate=None
        if L is None: L = self.L
        if rate is None: rate = 1.0

        dt = torch.exp(self.log_dt) * rate # (H)
        C = _r2c(self.C) # (C H N)
        A = self._A() # (n_ssm N)

        B = _r2c(self.B)
        if self.mimo:
            B = repeat(B, 't m n -> 1 (v t) m n', v=self.repeat)
        else:
            B = repeat(B, 't n -> 1 (v t) n', v=self.repeat)

        # Broadcast A to all H features before it is combined with dt (H); doing this
        # ahead of the bandlimit mask keeps that mask (H N) even when 1 < n_ssm < H.
        A = repeat(A, 't n -> (v t) n', v=self.repeat)

        if self.bandlimit is not None:
            freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N)
            mask = torch.where(freqs < self.bandlimit * .5, 1, 0)
            C = C * mask

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)


        # Augment B with state
        if state is not None:
            s = state / dt.unsqueeze(-1)
            if self.disc == 'bilinear':
                s = s * (1. + dtA/2)
            elif self.disc == 'zoh':
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.)
            B = torch.cat([s, B], dim=-3) # (1+B H N)

        if not self.mimo:
            # Fold B into C; the leading (1+B) and C axes are flattened together
            C = (B[:, None, :, :] * C).view(-1, self.H, self.N)
        if self.disc == 'zoh':
            assert not self.mimo
            # Power up
            C = C * (torch.exp(dtA)-1.) / A
            if log_vandermonde_fast is not None and C.shape[0] == 1:
                K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L)
            else:
                K = log_vandermonde(C, dtA, L) # (H L)
        elif self.disc == 'bilinear':
            if self.mimo:
                C = C * ((1. - dtA/2).reciprocal() * dt.unsqueeze(-1)).unsqueeze(1) # or * dtA / A
            else:
                C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A
            dA = (1. + dtA/2) / (1. - dtA/2)
            if self.mimo:
                K = log_vandermonde_naive(rearrange(C, 'c h m d->m d h c').contiguous(), dA.log(), L, conj=False)
            else:
                if log_vandermonde_fast is not None:
                    # The fast kernel takes a single leading axis, so fold (c h) together
                    dA_log = dA.log()
                    dA_log_repeat = repeat(dA_log, 'h d -> (c h) d', c=C.shape[0])
                    K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log_repeat, L),
                                  '(c h) d -> c h d', c=C.shape[0])
                else:
                    K = log_vandermonde(C, dA.log(), L)
        elif self.disc == 'dss':
            assert not self.mimo
            # Implementation from DSS meant for case when real eigenvalues can be positive
            P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L]
            A_gt_0 = A.real > 0                                      # [N]
            if A_gt_0.any():
                with torch.no_grad():
                    P_max = dtA * (A_gt_0 * (L-1))                   # [H N]
                P = P - P_max.unsqueeze(-1)                          # [H N L]
            S = P.exp()                                              # [H N L]

            dtA_neg = dtA * (1 - 2*A_gt_0)                           # [H N]
            num = dtA_neg.exp() - 1                                  # [H N]
            den = (dtA_neg * L).exp() - 1                            # [H N]

            # Inline reciprocal function for DSS logic
            x = den * A
            x_conj = _resolve_conj(x)
            r = x_conj / (x*x_conj + 1e-7)

            C = C * num * r             # [C H N]
            K = contract('chn,hnl->chl', C, S).float()
        else: assert False, f"{self.disc} not supported"

        if self.mimo:
            K = rearrange(K, 'm d (c h) l -> m d c h l', c = self.channels)
            K = torch.einsum('m d c h l, b h n d -> c h m n l', K, B)
            return 2*K.real, None # need to fix this

        K = K.view(-1, self.channels, self.H, L) # (1+B C H L)
        # The rows stacked in front of the kernel come from the augmented state
        K_state = K[:-1, ...] if state is not None else None # (B C H L)
        K = K[-1, ...] # (C H L)
        return K, K_state

    def _setup_step(self):
        """Cache dA, dB, dC for the recurrent step().

        Only the 'zoh' and 'bilinear' discretizations set dA / dB here.
        NOTE(review): A and B are not repeated to H rows in this method, so it
        appears to rely on n_ssm == H (or n_ssm == 1) -- confirm.
        """
        # These methods are organized like this to be compatible with the NPLR kernel interface
        dt = torch.exp(self.log_dt) # (H)
        B = _r2c(self.B) # (H N)
        C = _r2c(self.C) # (C H N)
        self.dC = C
        A = self._A() # (H N)

        # Incorporate dt into A
        dtA = A * dt.unsqueeze(-1)  # (H N)
        if self.disc == 'zoh':
            self.dA = torch.exp(dtA) # (H N)
            self.dB = B * (torch.exp(dtA)-1.) / A # (C H N)
        elif self.disc == 'bilinear':
            self.dA = (1. + dtA/2) / (1. - dtA/2)
            self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A


    def default_state(self, *batch_shape):
        """Return a zero complex state of shape (*batch_shape, H, N)."""
        C = _r2c(self.C)
        state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Advance the diagonal recurrence by one step; returns (y, next_state)."""
        next_state = contract("h n, b h n -> b h n", self.dA, state) \
                + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2*y.real, next_state

    def forward_state(self, u, state):
        """Return the state after feeding the whole sequence u (B H L) starting from `state`."""
        self._setup_step()
        AL = self.dA ** u.size(-1)
        u = u.flip(-1).to(self.dA).contiguous() # (B H L)
        v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state


class SSKernelFromTwoKernels(nn.Module):
    """Concatenate the outputs of two kernels along the H (feature) axis."""

    def __init__(self, kernel1, kernel2):
        super().__init__()
        self.kernel1 = kernel1
        self.kernel2 = kernel2

    def forward(self, state=None, L=None, rate=None):
        """Evaluate both kernels and concatenate their results.

        rate: sampling rate factor; None is treated as 1.0, which is the default of the
              wrapped kernels' own forward().

        Returns (k, state) where k concatenates the two (C, H_i, L) kernels along H.
        The state output is the concatenation of the two kernel states along H, or None
        unless both kernels return one.
        """
        # The kernels in this file assert `rate == 1.0` or multiply by rate, so a literal
        # None must not be forwarded.
        if rate is None:
            rate = 1.0

        k1, k1_state = self.kernel1(state=state, L=L, rate=rate)
        k2, k2_state = self.kernel2(state=state, L=L, rate=rate)

        # H is the second-to-last axis of both the kernels (C H L) and the state outputs
        # (B C H L); concatenating the states on dim=1 would stack them along C instead.
        # NOTE(review): the same `state` is passed unsplit to both kernels -- confirm
        # whether it should be split along H as well.
        k = torch.cat([k1, k2], dim=-2)
        if k1_state is not None and k2_state is not None:
            state = torch.cat([k1_state, k2_state], dim=-2)
        else:
            state = None

        return k, state


class SSKernelComp(OptimModule):
    """Discrete SSM kernel whose state matrix is a companion matrix parameterized by A."""

    def __init__(self, A, B, C, L=None, fast_companion=False, lr=None):
        """
        A: (H, d), real. The vector represented by A is the last column of the companion matrix.
        B: (H, d), real
        C: (H, d), real. If fast_companion, we actually store C_tilde = C (I - A^L).
        L: kernel length; required when fast_companion is True
        lr: None, a float, or a dict with optional keys 'A', 'B'
        """
        super().__init__()
        self.L = L
        self.N = A.size(-1)
        self.H = B.shape[0]
        self.fast_companion = fast_companion
        if self.fast_companion:
            assert self.L is not None

        # Register parameters
        if lr is None or isinstance(lr, float): lr_dict = {}
        else: lr_dict, lr = lr, None
        self.register("A", A, lr_dict.get('A', lr))
        self.register("B", B, lr_dict.get('B', lr))
        self.C = nn.Parameter(C)

    def forward(self, state=None, rate=1.0, L=None):
        """Return (k, None) where k is the length-L kernel; `state` is not supported."""
        assert state is None, 'state is not supported yet'
        if self.fast_companion:
            assert L is None or L == self.L
        if L is None:
            L = self.L
        # This class doesn't support variable length functionalities, since it's a discrete SSM
        assert rate == 1.0 and L is not None

        # Cap the L1 norm of each head's coefficient vector A[h] at 1, i.e. the clipped
        # counterpart of A / ||A||_1 taken over the last axis.
        # torch.renorm treats `dim` as the axis that *indexes* the sub-tensors being
        # renormalized, so per-head (row-wise) normalization of an (H, d) tensor is dim=0;
        # dim=-1 would renormalize each column across heads instead.
        A = torch.renorm(self.A, p=1, dim=0, maxnorm=1.0)
        if not self.fast_companion:
            k = krylov(L, companion_from_p(A), self.B, self.C)  # (H L)
        else:
            k = companion_krylov(L, A, self.B, c_tilde=self.C)

        k = k.float()  # Otherwise it could be dtype half
        k_state = None
        return k, k_state


class SSKernelShift(OptimModule):
    """Discrete SSM kernel computed from B and C alone with an FFT of size 2N."""

    def __init__(self, B, C, L=None, lr=None):
        """
        B: (H, d), real
        C: (H, d), real
        L: default kernel length used when forward() is called without L
        lr: None, a float, or a dict with an optional 'B' key
        """
        super().__init__()
        self.L = L
        self.N = B.size(-1)
        self.H = B.shape[0]

        # A plain float (or None) applies globally; anything else is a per-tensor dict
        if lr is None or isinstance(lr, float):
            lr_dict = {}
        else:
            lr_dict, lr = lr, None
        self.register("B", B, lr_dict.get('B', lr))
        self.C = nn.Parameter(C)

    def forward(self, state=None, rate=1.0, L=None):
        """Return (k, None), with k padded or truncated to length L along the last axis."""
        assert state is None, 'state is not supported yet'
        L = self.L if L is None else L
        # Discrete SSM: no resampling, and the length has to be known
        assert rate == 1.0 and L is not None

        fft_len = 2*self.N
        B_spec = torch.fft.rfft(self.B, n=fft_len)
        C_spec = torch.fft.rfft(self.C, n=fft_len)
        full = torch.fft.irfft(B_spec.conj() * C_spec, n=fft_len)
        k = full[..., :min(self.N, L)]
        if L > self.N:
            # Only N taps are kept; zero-fill the tail up to length L
            k = F.pad(k, (0, L - self.N))
        # Cast up in case the parameters are half precision
        return k.float(), None


class SSKernel(nn.Module):
    """Wrapper around SSKernel parameterizations.

    The SSKernel is expected to support the interface
    forward()
    default_state()
    _setup_step()
    step()
    """

    @staticmethod
    def _warn_if_not_diag_measure(measure):
        """Log a warning when a diagonal kernel is paired with a non-'diag*' initialization."""
        if not measure.startswith("diag"):
            logger.warning("Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv.")

    def __init__(
        self,
        H,
        N=64,
        L=None,
        measure="legs",
        rank=1,
        channels=1,
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        lr=None,
        mode="nplr",
        n_ssm=None,
        verbose=False,
        measure_args=None,
        mimo=False,
        mimo_channels=2,
        **kernel_args,
    ):
        """State Space Kernel which computes the convolution kernel $\\bar{K}$

        H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config.
        N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doesn't affect speed much.
        L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known.
        measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin)
        rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt"
        channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead
        dt_min, dt_max: min and max values for the step size dt (\\Delta)
        deterministic: Use a fixed (non-random) initialization of dt and C
        mode: Which kernel algorithm to use. Modes handled by this constructor: 'diag' (S4D), 'slow' (dense version for testing), 'real', 'comp', 'shift', 'diag_comp'.
            NOTE(review): the default 'nplr' has no branch below and ends in NotImplementedError -- confirm whether the NPLR kernel is expected to be wired in here.
        n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H
        lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
        measure_args: Optional dict of extra keyword arguments forwarded to dplr.combination
        mimo, mimo_channels: Enable the multi-input variant with `mimo_channels` channels
        kernel_args: Extra keyword arguments forwarded to the selected kernel class
        """
        super().__init__()
        self.N = N
        self.H = H
        dtype, cdtype = torch.float, torch.cfloat
        self.channels = channels
        self.n_ssm = n_ssm if n_ssm is not None else H
        self.mode = mode
        self.verbose = verbose
        self.kernel_args = kernel_args
        # Avoid a shared mutable default argument
        if measure_args is None:
            measure_args = {}

        # Generate dt
        if deterministic:
            # The kernels recover dt as exp(log_dt), so store log-spaced values *in log space*
            # (exponentiating here as well would put dt outside [dt_min, dt_max]).
            log_dt = torch.linspace(math.log(dt_min), math.log(dt_max), H)
        else:
            log_dt = torch.rand(self.H, dtype=dtype) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)

        # Compute the preprocessed representation
        if mode == "real":  # For testing and ablation purposes
            # NOTE(review): SSKernelSlow stores A and B via torch.view_as_real, which requires
            # complex inputs; the real tensors built here look incompatible -- confirm this mode works.
            # Generate A, B
            A, B = hippo.transition(measure, self.N)
            A = torch.as_tensor(A, dtype=dtype)
            B = torch.as_tensor(B, dtype=dtype)[:, 0]

            # Generate C
            if deterministic:
                C = torch.zeros(channels, self.H, self.N, dtype=dtype)
                C[..., :1] = 1.0
            else:
                C = torch.randn(channels, self.H, self.N, dtype=dtype)

            self.kernel = SSKernelSlow(
                A, B, C, log_dt, L=L,
                lr=lr,
            )
        else:
            w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args)

            # Broadcast C to have H channels
            if deterministic:
                C = torch.zeros(channels, self.H, self.N, dtype=cdtype)
                C[:, :, :1] = 1.
                C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C
            else:
                C = torch.randn(channels, self.H, self.N//2, dtype=cdtype)

            # Broadcast other parameters to have n_ssm copies
            assert self.n_ssm % B.size(-2) == 0 \
                    and self.n_ssm % P.size(-2) == 0 \
                    and self.n_ssm % w.size(-2) == 0
            # Broadcast tensors to n_ssm copies
            # These will be the parameters, so make sure tensors are materialized and contiguous
            B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous()
            P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous()
            w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous()
            C = C.contiguous()

            if mimo:
                if mode == "diag":
                    self._warn_if_not_diag_measure(measure)
                    B = repeat(B, 't n -> t v n', v=mimo_channels)
                    C = repeat(C, 'c h n -> c h v n', v=mimo_channels)
                    self.kernel = SSKernelDiag(
                        w, B, C, log_dt, L=L,
                        lr=lr, mimo=mimo, mimo_channels=mimo_channels,
                        **kernel_args,
                    )
                else:
                    # Every other mode falls back to the dense kernel when mimo is set
                    if mode == "comp":
                        A = torch.randn_like(_conj(w))
                    else:
                        A = torch.diag_embed(_conj(w)) \
                                - contract("... r p, ... r q -> ... p q", _conj(P), _conj(P).conj())
                    B = repeat(B, 't n -> t v n', v=mimo_channels)
                    C = repeat(C, 'c h n -> c h v n', v=mimo_channels)
                    self.kernel = SSKernelSlow(
                        A, _conj(B), _conj(C), log_dt, L=L,
                        lr=lr, mimo=mimo, mimo_channels=mimo_channels, comp=(mode == "comp"),
                    )
            elif mode == "diag":
                self._warn_if_not_diag_measure(measure)
                C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm)
                self.kernel = SSKernelDiag(
                    w, B, C, log_dt, L=L,
                    lr=lr,
                    **kernel_args,
                )
            elif mode == "slow":  # Mainly for testing
                A = torch.diag_embed(_conj(w)) \
                        - contract("... r p, ... r q -> ... p q", _conj(P), _conj(P).conj())
                self.kernel = SSKernelSlow(
                    A, _conj(B), _conj(C), log_dt, L=L,
                    lr=lr,
                )
            elif mode == "comp":
                # Init A to be shift
                A = torch.zeros(self.H, self.N)
                # Initializing B to be e_1 or randn makes it not learn on induction task, but
                # if we init to all ones then it learns fine.
                B = torch.ones(self.H, self.N)
                C = torch.randn(1, self.H, self.N)
                self.kernel = SSKernelComp(A, B, C, L=L, lr=lr, **kernel_args)
            elif mode == "shift":
                # Initializing B to be e_1
                B = torch.zeros(self.H, self.N)
                B[..., 0] = 1.0
                # Match torch.Conv1d init
                C = torch.randn(self.H, self.channels, self.N)
                nn.init.kaiming_uniform_(C, a=math.sqrt(5))
                C = rearrange(C, 'h c n -> c h n')
                self.kernel = SSKernelShift(B, C, L=L, lr=lr, **kernel_args)
            elif mode == "diag_comp":
                # First half of the SSMs uses the companion (dense) kernel, the rest the diagonal one
                companion_H = self.n_ssm // 2

                # companion
                A = torch.randn_like(_conj(w[:companion_H]))
                kernel1 = SSKernelSlow(
                    A, _conj(B[:companion_H]), _conj(C[:, :companion_H]), log_dt[:companion_H], L=L,
                    lr=lr,
                    comp=True,
                )

                self._warn_if_not_diag_measure(measure)
                C = C[:, companion_H:] * repeat(B[companion_H:], 't n -> (v t) n', v=1)
                kernel2 = SSKernelDiag(
                    w[companion_H:], B[companion_H:], C, log_dt[companion_H:], L=L,
                    lr=lr,
                    **kernel_args,
                )

                self.kernel = SSKernelFromTwoKernels(kernel1, kernel2)
            else: raise NotImplementedError(f"{mode=} is not valid")

    def forward(self, state=None, L=None, rate=None):
        """Evaluate the wrapped kernel; rate=None is treated as 1.0 (no resampling)."""
        # The wrapped kernels default to rate=1.0 and assert on / multiply by it,
        # so a literal None must not be passed through.
        if rate is None:
            rate = 1.0
        return self.kernel(state=state, L=L, rate=rate)

    @torch.no_grad()
    def forward_state(self, u, state):
        """ Forward the state through a sequence, i.e. computes the state after passing chunk through SSM

        state: (B, H, N)
        u: (B, H, L)

        Returns: (B, H, N)
        """

        # Kernels with their own implementation take precedence
        if hasattr(self.kernel, "forward_state"):
            return self.kernel.forward_state(u, state)

        dA, dB = self.kernel._setup_state() # Construct dA, dB matrices

        # If the state has a different size than dA, extend it with its conjugate
        # and keep only the first half of the result at the end
        conj = state.size(-1) != dA.size(-1)
        if conj: state = _conj(state)

        v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2)
        AL, v = power(u.size(-1), dA, v)
        next_state = contract("h m n, b h n -> b h m", AL, state)
        next_state = next_state + v

        if conj: next_state = next_state[..., : next_state.size(-1) // 2]
        return next_state

    def _setup_step(self, **kwargs):
        # This method is intended to be private so that setting up an S4 module with
        # ```
        # if hasattr(module, 'setup_step'): module.setup_step()
        # ```
        # will not trigger this method multiple times
        self.kernel._setup_step(**kwargs)

    def step(self, u, state, **kwargs):
        """Delegate a single recurrent step to the wrapped kernel; returns (y, state)."""
        y, state = self.kernel.step(u, state, **kwargs)
        return y, state

    def default_state(self, *args, **kwargs):
        """Delegate to the wrapped kernel's default_state()."""
        return self.kernel.default_state(*args, **kwargs)
