# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

import selective_scan_cuda


class SelectiveScanFn(torch.autograd.Function):
    """Autograd wrapper around the fused ``selective_scan_cuda`` forward/backward kernels.

    Arguments mirror ``selective_scan_fn``. When ``return_last_state`` is set, the forward also
    returns the final SSM state; its gradient is not propagated (backward ignores any extra
    output gradients passed in ``*args``).
    """

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        """Run the fused selective scan.

        Inputs whose last dimension is not unit-stride are copied to contiguous memory (D is
        always made contiguous). 3-D ``B``/``C`` get a singleton group axis inserted, and
        ``ctx.squeeze_B`` / ``ctx.squeeze_C`` record this so backward can remove it from the
        gradients again.

        Returns ``out`` (or the gated ``out_z`` when ``z`` is given), or a tuple with
        ``last_state`` of shape (batch, dim, dstate) appended when ``return_last_state`` is True.
        """
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        # The kernel returns out, the scan intermediates x, and (when z is given) the gated
        # output as rest[0].
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        # NOTE(review): relies on the kernel's internal layout of x (last slot along dim 2,
        # odd entries of the last axis) — confirm against selective_scan_cuda if it changes.
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            # The ungated `out` is saved too; the backward kernel takes it as an input.
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        """Backward through the fused kernel.

        ``dout`` is the gradient of the first output; extra gradients in ``*args`` (for
        ``last_state``) are ignored. Returns one gradient per forward input, with None for
        ``delta_softplus``, ``return_last_state`` and any optional input that was None.
        """
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        # Undo the singleton group axis that forward inserted for 3-D B / C.
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (du, ddelta, dA, dB, dC,
                dD if D is not None else None,
                dz,
                ddelta_bias if delta_bias is not None else None,
                None,
                None)


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """Differentiable fused selective scan (see ``SelectiveScanFn``).

    Returns ``out``, or ``(out, last_state)`` when ``return_last_state`` is True, where
    ``last_state`` has shape (batch, dim, dstate). The gradient of ``last_state`` is not
    propagated in the backward pass.
    """
    scan_args = (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return SelectiveScanFn.apply(*scan_args)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """Step-by-step PyTorch reference for the selective scan.

    Shapes (``r`` = real, ``c`` = complex):
        u: r(B D L)
        delta: r(B D L)
        A: c(D N) or r(D N)
        B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    Returns:
        out: r(B D L), cast back to ``u``'s dtype.
        last_state (only when ``return_last_state``): r(B D dstate) or c(B D dstate).
    """
    input_dtype = u.dtype
    u = u.float()
    dt = delta.float()
    if delta_bias is not None:
        dt = dt + delta_bias[..., None].float()
    if delta_softplus:
        dt = F.softplus(dt)
    n_batch, n_dim, n_state = u.shape[0], A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    B_varies = B.dim() >= 3
    C_varies = C.dim() >= 3

    if not A.is_complex():
        B = B.float()
        C = C.float()
    else:
        # Input-dependent B/C store (real, imag) pairs interleaved along the last axis.
        if B_varies:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if C_varies:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))

    decay = torch.exp(torch.einsum('bdl,dn->bdln', dt, A))
    if not B_varies:
        drive = torch.einsum('bdl,dn,bdl->bdln', dt, B, u)
    elif B.dim() == 3:
        drive = torch.einsum('bdl,bnl,bdl->bdln', dt, B, u)
    else:
        # Grouped B: broadcast each group over its dim // G channels.
        B = repeat(B, "B G N L -> B (G H) N L", H=n_dim // B.shape[1])
        drive = torch.einsum('bdl,bdnl,bdl->bdln', dt, B, u)
    if C_varies and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=n_dim // C.shape[1])

    state = A.new_zeros((n_batch, n_dim, n_state))
    last_state = None
    outputs = []
    for t in range(seqlen):
        state = decay[:, :, t] * state + drive[:, :, t]
        if not C_varies:
            y_t = torch.einsum('bdn,dn->bd', state, C)
        elif C.dim() == 3:
            y_t = torch.einsum('bdn,bn->bd', state, C[:, :, t])
        else:
            y_t = torch.einsum('bdn,bdn->bd', state, C[:, :, :, t])
        if t == seqlen - 1:
            last_state = state
        if y_t.is_complex():
            # Complex A: the conjugate pair contributes twice the real part.
            y_t = y_t.real * 2
        outputs.append(y_t)
    y = torch.stack(outputs, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=input_dtype)
    return (out, last_state) if return_last_state else out


class MambaInnerFn(torch.autograd.Function):
    """Fused Mamba block (causal conv1d -> x_proj -> dt_proj -> selective scan -> out_proj)
    with a hand-written backward, optionally also returning a "hidden attention" view.

    Depending on the flags passed to ``forward``, an auxiliary second output is returned:
      * ``compute_attn_matrix``: the per-channel (L x L) matrix ``attn_mat_fin``;
      * ``compute_attn_vector``: only the row for the last position (``L - 1``);
      * ``ablate_attn_mat`` / ``ablate_attn_vec``: a dict of intermediate / ablated variants.
    ``backward`` ignores any gradient flowing into these auxiliary outputs.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, compute_attn_matrix=False,
                old_attention=False, compute_attn_vector=False, ablate_attn_mat=False,
                ablate_attn_vec=False):
        """Run the fused Mamba block forward.

        xz: (batch, dim, seqlen); split along dim 1 into the scan input ``x`` and the gate ``z``.
        conv1d_weight: (d, 1, w) depthwise causal-conv kernel.
        B / C: when None they are projected from the input ("variable" B / C).
        checkpoint_lvl: 0 or 1; with 1, conv1d_out and delta are recomputed in backward.
        compute_attn_matrix / compute_attn_vector / ablate_attn_mat / ablate_attn_vec:
            additionally return the attention matrix, its last row, or a dict of ablations.
        old_attention: return the raw S6 attention instead of the one composed with the gate,
            the activation and the convolution.

        Returns ``out`` of shape (batch, seqlen, out_features), or ``(out, extra)`` when an
        attention flag is set (checked in the order ablate_attn_mat, ablate_attn_vec,
        compute_attn_matrix, compute_attn_vector).
        """
        # Debug switch: when True, the attention reconstruction is compared against the fused
        # kernel output, errors are printed and the forward aborts with ValueError("test done").
        Test = False
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        if Test:
            conv1d_bias = conv1d_bias * 0
        # NOTE(review): causal_conv1d_fwd is called with 5 positional arguments here but with 7
        # in backward's recompute (and causal_conv1d_bwd with 10). The causal-conv1d binding has
        # a fixed arity per release, so one of these calls likely fails — confirm against the
        # installed causal-conv1d version.
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, True
        )
        if Test or compute_attn_matrix or compute_attn_vector or ablate_attn_mat or ablate_attn_vec:
            # Same conv without the activation; its sigmoid is used as the activation factor Z.
            conv1d_out_no_act = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, False
            )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if Test:
            D = D * 0
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )

        H = conv1d_weight.shape[0]
        CLSposition = L - 1  # row used for the attention vector
        # Attention matrix
        if Test or compute_attn_matrix or ablate_attn_mat:
            if ablate_attn_mat:
                ablate_dict = {}
            # S6 attention, (batch, dim, L, L).
            attn_mat = compute_attn_matrix_fn(delta.to(torch.float32), delta_bias.to(torch.float32), A.to(torch.float32), B.to(torch.float32), C.to(torch.float32), L, x.shape, dtype=torch.float32)
            if ablate_attn_mat:
                ablate_dict["S6Attention"] = attn_mat
            # Depthwise causal conv expressed as a per-channel (L, L) matrix.
            M = conv2Mat(conv1d_weight, x.shape[-1]).to(x.device).transpose(-1, -2)
            if ablate_attn_mat:
                ablate_dict["convMat"] = M
            gate = torch.diag_embed(F.silu(z)).to(torch.float32)
            if ablate_attn_mat:
                ablate_dict["gate"] = gate
            Z = torch.diag_embed(torch.sigmoid(conv1d_out_no_act)).to(torch.float32)
            if ablate_attn_mat:
                ablate_dict["act"] = Z
            attn_mat_fin = gate @ attn_mat @ Z @ M.to(attn_mat.device)
            if ablate_attn_mat:
                ablate_dict["finalMat"] = attn_mat_fin.squeeze(-2)
                ablate_dict["withoutConv"] = (gate @ attn_mat @ Z).squeeze(-2)
                ablate_dict["withoutGate"] = (attn_mat @ Z @ M.to(attn_mat.device)).squeeze(-2)
                ablate_dict["withoutAct"] = (gate @ attn_mat @ M.to(attn_mat.device)).squeeze(-2)
            if old_attention:
                attn_mat_fin = attn_mat
            if Test:
                semi_ssm_all = torch.einsum('bcij,bcj->bci', attn_mat_fin.to(torch.float32), x.to(torch.float32))
                semi_ssm = torch.einsum('bcij,bcj->bci', attn_mat.to(torch.float32), conv1d_out.to(torch.float32))
                semi_ssm_c = semi_ssm + (x * rearrange(D, "d -> d 1"))
                semi_ssm_z = semi_ssm_c * F.silu(z)
                print("test1 err", ((out - semi_ssm_c).abs()).max())
                print('test2 err', ((semi_ssm_z - semi_ssm_all).abs()).max())
                print('test3 err', ((semi_ssm_z - out_z).abs()).max())
                print('test4 err', ((semi_ssm_all - out_z).abs()).max())

        # Attention vector (row CLSposition only)
        if Test or compute_attn_vector or ablate_attn_vec:
            # Internal ablation switches (all off).
            ablate_conv, ablate_act, ablate_gate = False, False, False
            if ablate_attn_vec:
                ablate_dict = {}
            # (batch, dim, 1, L)
            attn_mat = compute_attn_vector_fn(delta.to(torch.float32), delta_bias.to(torch.float32), A.to(torch.float32), B.to(torch.float32), C.to(torch.float32), L, x.shape, dtype=torch.float32, pos=CLSposition, numlines=1)
            if ablate_attn_vec:
                ablate_dict["S6Attention"] = attn_mat.squeeze(-2)
            gate = (F.silu(z[:, :, CLSposition].to(torch.float32)))
            if ablate_attn_vec:
                ablate_dict["gate"] = gate
            Z = torch.sigmoid(conv1d_out_no_act.to(torch.float32))
            if ablate_attn_vec:
                ablate_dict["act"] = Z
            if ablate_act:
                temp = attn_mat.squeeze(-2)
            else:
                temp = attn_mat.squeeze(-2) * Z
            # NOTE(review): padding=4 and the [4:L+4] slice look tied to a conv kernel width of
            # 4 — confirm before using other kernel widths.
            if ablate_conv:
                conv2 = temp.squeeze(-2)
            else:
                conv2 = F.conv1d(temp.squeeze(-2), conv1d_weight.flip(dims=[-1]).unsqueeze(1).to(torch.float32), groups=H, padding=4, stride=1)[..., 4:L + 4]
                conv2AblateZ = F.conv1d(attn_mat.squeeze(-2).squeeze(-2), conv1d_weight.flip(dims=[-1]).unsqueeze(1).to(torch.float32), groups=H, padding=4, stride=1)[..., 4:L + 4]
            if not ablate_gate:
                attn_vec_fin = gate.unsqueeze(-1) * conv2
            else:
                attn_vec_fin = conv2

            if ablate_attn_vec:
                ablate_dict["finalMat"] = attn_vec_fin.squeeze(-2)
                ablate_dict["withoutConv"] = (gate.unsqueeze(-1) * temp.squeeze(-2))
                ablate_dict["withoutGate"] = conv2.squeeze(-2)
                ablate_dict["withoutAct"] = gate.unsqueeze(-1) * conv2AblateZ.squeeze(-2)
            if old_attention:
                attn_vec_fin = attn_mat.squeeze(-2)

            if Test:
                cls_val = (attn_vec_fin.to(torch.float32) * x.to(torch.float32)).sum(-1)
                print('test0 err', ((attn_mat_fin[:, :, CLSposition, :] - attn_vec_fin).abs()).max())
                print('test1 err', ((out_z[:, :, CLSposition] - cls_val).abs()).max())
                print('test2 err', ((semi_ssm_all[:, :, CLSposition] - cls_val).abs()).max())
                raise ValueError("test done")

        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)

        out = F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

        if ablate_attn_mat:
            return out, ablate_dict
        if ablate_attn_vec:
            return out, ablate_dict
        elif compute_attn_matrix:
            return out, attn_mat_fin
        elif compute_attn_vector:
            return out, attn_vec_fin
        else:
            return out

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        """Backward pass for ``forward``.

        dout: (batch, seqlen, dim) gradient of ``out``. Gradients for the auxiliary attention /
        ablation outputs arrive in ``*args`` and are ignored.

        Returns exactly one gradient per ``forward`` input (21 in total): None for the
        flags and for optional inputs that were None.
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # Recompute what forward dropped (see the arity NOTE in forward for this call).
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is (e, b*l) here, so the bias gradient sums over dim 1 only; summing over (0, 1)
        # would collapse it to a scalar, which autograd rejects for an (e,)-shaped bias.
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias,
                # delta_softplus, checkpoint_lvl, compute_attn_matrix, old_attention,
                # compute_attn_vector, ablate_attn_mat, ablate_attn_vec
                None, None, None, None, None, None, None)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True, compute_attn_matrix=False, old_attention=False,
    compute_attn_vector=False, ablate_attn_mat=False, ablate_attn_vec=False
):
    """Functional entry point for ``MambaInnerFn`` (always run with checkpoint_lvl=1).

    Returns ``out`` or ``(out, extra)`` exactly as ``MambaInnerFn.forward`` does for the given
    attention / ablation flags.
    """
    checkpoint_lvl = 1  # conv1d_out and delta are recomputed in backward
    block_inputs = (xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    out_proj_weight, out_proj_bias)
    ssm_params = (A, B, C, D, delta_bias, B_proj_bias, C_proj_bias)
    flags = (delta_softplus, checkpoint_lvl, compute_attn_matrix, old_attention,
             compute_attn_vector, ablate_attn_mat, ablate_attn_vec)
    return MambaInnerFn.apply(*block_inputs, *ssm_params, *flags)


def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Unfused counterpart of ``mamba_inner_fn`` (no attention outputs).

    Runs causal conv1d (SiLU), the x/delta projections and ``selective_scan_fn`` as separate
    differentiable ops, then the output projection.

    xz: (batch, dim, seqlen); split along dim 1 into ``x`` and the gate ``z``.
    delta_softplus: whether softplus is applied to delta (+ delta_bias) inside the scan.

    Returns a tensor of shape (batch, seqlen, out_features).
    """
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    # Honour the caller's delta_softplus (previously hard-coded to True).
    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias,
                          delta_softplus=delta_softplus)
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)

def compute_attn_matrix_fn(delta, delta_bias, A, B, C, L, x_shape, dtype=torch.float32):
    """Materialize the selective scan (without D and gating) as a per-channel attention matrix.

    With dt = softplus(delta + delta_bias), dA_k = exp(dt_k * A) and dB_c = dt_c * B_c, entry
    [b, d, r, c] for c <= r is

        sum_n C[b, 0, n, r] * prod_{k=c+1..r} dA[b, k, d, n] * dB[b, c, d, n],

    so multiplying the matrix by an input sequence u reproduces the scan output for u.
    Entries with c > r, and rows r >= L, are zero.

    Args:
        delta: (batch, dim, seqlen) pre-softplus step sizes.
        delta_bias: (dim,) bias added before softplus, or None for no bias.
        A: (dim, dstate).
        B, C: (batch, 1, dstate, seqlen).
        L: number of rows to fill.
        x_shape: (batch, dim, seqlen); determines the output shape.
        dtype: dtype of the output buffer and of the running decay products.

    Returns:
        Tensor of shape (batch, dim, seqlen, seqlen).
    """
    if delta_bias is not None:
        delta = delta + delta_bias.unsqueeze(0).unsqueeze(-1)
    dt = F.softplus(delta)
    dA = torch.exp(torch.einsum("bdl,dn->bldn", dt, A))  # (batch, seqlen, dim, dstate)
    dB = torch.einsum("bdl,bnl->bldn", dt, B.squeeze(1))  # (batch, seqlen, dim, dstate)
    # Plain buffer: a leaf created with requires_grad=True cannot be written in place while
    # grad mode is on. Gradients still flow through the written values.
    attn = torch.zeros((x_shape[0], x_shape[1], x_shape[2], x_shape[2]),
                       dtype=dtype, device=dA.device)
    state_shape = (dA.shape[0], dA.shape[2], dA.shape[3])
    for r in range(L):
        curr_C = C[:, :, :, r]  # (batch, 1, dstate); broadcasts over dim
        # Walk c from r down to 0, extending prod_{k=c+1..r} dA_k by one factor per step.
        # This is the same multiplication order as rebuilding each product from scratch, but
        # costs O(L) instead of O(L^2) per row.
        decay = torch.ones(state_shape, dtype=dtype, device=dA.device)
        for c in range(r, -1, -1):
            attn[:, :, r, c] = torch.sum(curr_C * decay * dB[:, c, :, :], dim=-1)
            decay = decay * dA[:, c, :, :]
    return attn

def compute_attn_vector_fn(delta, delta_bias, A, B, C, L, x_shape, dtype=torch.float32, pos=-1, numlines=4):
    """Compute ``numlines`` consecutive rows of the selective-scan attention matrix.

    Row ``rr`` of the result is row ``r = pos + rr`` of the matrix whose entry (r, c), c <= r, is

        sum_n C[b, 0, n, r] * prod_{k=c+1..r} exp(dt_k * A)[d, n] * (dt_c * B[b, 0, n, c]),

    with dt = softplus(delta + delta_bias). Columns c > r are zero. ``pos`` is not wrapped when
    negative: rows with ``pos + rr < 0`` are left all-zero (so the default ``pos=-1`` yields a
    zero first row). ``L`` is unused.

    Args:
        delta: (batch, dim, seqlen) pre-softplus step sizes.
        delta_bias: (dim,) bias added before softplus, or None for no bias.
        A: (dim, dstate).
        B, C: (batch, 1, dstate, seqlen).
        x_shape: (batch, dim, seqlen); determines the output shape.
        dtype: dtype of the output buffer and of the running decay products.
        pos: index of the first row to compute.
        numlines: number of rows to compute.

    Returns:
        Tensor of shape (batch, dim, numlines, seqlen).
    """
    if delta_bias is not None:
        delta = delta + delta_bias.unsqueeze(0).unsqueeze(-1)
    dt = F.softplus(delta)
    dA = torch.exp(torch.einsum("bdl,dn->bldn", dt, A))  # (batch, seqlen, dim, dstate)
    dB = torch.einsum("bdl,bnl->bldn", dt, B.squeeze(1))  # (batch, seqlen, dim, dstate)
    # Plain buffer: a leaf created with requires_grad=True cannot be written in place while
    # grad mode is on. Gradients still flow through the written values.
    attn = torch.zeros((x_shape[0], x_shape[1], numlines, x_shape[2]),
                       dtype=dtype, device=dA.device)
    state_shape = (dA.shape[0], dA.shape[2], dA.shape[3])
    for rr in range(numlines):
        r = rr + pos
        if r < 0:
            continue  # no source positions: the row stays zero
        curr_C = C[:, :, :, r]  # (batch, 1, dstate); broadcasts over dim
        # Extend prod_{k=c+1..r} dA_k one factor per step while walking c from r down to 0
        # (same multiplication order as rebuilding each product, O(L) instead of O(L^2)).
        decay = torch.ones(state_shape, dtype=dtype, device=dA.device)
        for c in range(r, -1, -1):
            attn[:, :, rr, c] = torch.sum(curr_C * decay * dB[:, c, :, :], dim=-1)
            decay = decay * dA[:, c, :, :]
    return attn

def conv2Mat(conv1d_weight, L):
    """Expand a depthwise conv kernel into per-channel banded (L x L) matrices.

    With ``w = flip(conv1d_weight, -1)``, the result satisfies ``M[h, i, i + k] = w[h, k]`` for
    ``0 <= k < D`` and ``i + k < L``; all other entries are zero. The callers in this file
    transpose the last two axes to obtain the causal-convolution matrix.

    Args:
        conv1d_weight: (H, D) kernel, one row per channel.
        L: sequence length.

    Returns:
        Tensor of shape (H, L, L) in the default float dtype, on ``conv1d_weight``'s device.
    """
    H, D = conv1d_weight.shape
    flipped = torch.flip(conv1d_weight, [-1])
    device = conv1d_weight.device
    M = torch.zeros(H, L, L, device=device)
    rows = torch.arange(L, device=device)
    # Fill one diagonal per kernel tap (D iterations) instead of looping over every (h, i).
    for k in range(min(D, L)):
        i = rows[: L - k]
        M[:, i, i + k] = flipped[:, k:k + 1]
    return M