# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from neuromamba.utils.torch import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_fwd_function = None
    causal_conv1d_bwd_function = None
    causal_conv1d_update_function = None

from neuromamba.ops.triton.layer_norm import _layer_norm_fwd

import selective_scan_cuda


class HMFn(torch.autograd.Function):
    """Autograd bridge to the ``selective_scan_cuda`` forward/backward kernels.

    The forward pass yields two differentiable outputs (the scan output and
    ``out_cathree``) and, on request, a slice of the scan intermediates that is
    exposed as the "last state".
    """

    @staticmethod
    def forward(ctx, u, delta, mf, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        """Run ``selective_scan_cuda.fwd`` and stash what backward needs.

        Returns ``(out, out_cathree)``, or ``(out, out_cathree, last_state)``
        when ``return_last_state`` is true. When ``z`` is given, the first
        element is the extra output the kernel returns for the gated path
        instead of the raw scan output.
        """
        # Tensors are handed to the kernel with a unit-stride last dimension.
        u = u if u.stride(-1) == 1 else u.contiguous()
        delta = delta if delta.stride(-1) == 1 else delta.contiguous()
        mf = mf if mf.stride(-1) == 1 else mf.contiguous()
        B = B if B.stride(-1) == 1 else B.contiguous()
        C = C if C.stride(-1) == 1 else C.contiguous()
        if D is not None:
            D = D.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()

        # 3-D B/C get a singleton axis inserted at position 1; remember that so
        # backward can hand back gradients in the caller's original layout.
        if B.dim() == 3:
            ctx.squeeze_B = True
            B = rearrange(B, "b dstate l -> b 1 dstate l")
        if C.dim() == 3:
            ctx.squeeze_C = True
            C = rearrange(C, "b dstate l -> b 1 dstate l")

        out, out_cathree, scan_states, *extra = selective_scan_cuda.fwd(
            u, delta, mf, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = scan_states[:, :, -1, 1::2]  # (batch, dim, dstate)

        if ctx.has_z:
            ctx.save_for_backward(u, delta, mf, A, B, C, D, z, delta_bias, scan_states, out)
            primary = extra[0]
        else:
            ctx.save_for_backward(u, delta, mf, A, B, C, D, delta_bias, scan_states)
            primary = out

        if return_last_state:
            return (primary, out_cathree, last_state)
        return (primary, out_cathree)

    @staticmethod
    def backward(ctx, dout, dout_cathree, *args):
        """Run ``selective_scan_cuda.bwd``.

        Any incoming gradient for ``last_state`` arrives in ``*args`` and is
        ignored. One entry is returned per ``forward`` argument; the two
        trailing ``None`` values stand for ``delta_softplus`` and
        ``return_last_state``.
        """
        if ctx.has_z:
            u, delta, mf, A, B, C, D, z, delta_bias, scan_states, out = ctx.saved_tensors
        else:
            u, delta, mf, A, B, C, D, delta_bias, scan_states = ctx.saved_tensors
            z = None
            out = None

        dout = dout if dout.stride(-1) == 1 else dout.contiguous()
        dout_cathree = dout_cathree if dout_cathree.stride(-1) == 1 else dout_cathree.contiguous()

        # The kernel accepts a pre-allocated dz buffer (useful when fusing with
        # the backward of chunk). None is passed here, so dz is allocated in
        # the C++ code.
        du, ddelta, dmf, dA, dB, dC, dD, ddelta_bias, *extra = selective_scan_cuda.bwd(
            u, delta, mf, A, B, C, D, z, delta_bias, dout, dout_cathree, scan_states, out, None,
            ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )

        dz = extra[0] if ctx.has_z else None
        if getattr(ctx, "squeeze_B", False):
            dB = dB.squeeze(1)
        if getattr(ctx, "squeeze_C", False):
            dC = dC.squeeze(1)
        if D is None:
            dD = None
        if delta_bias is None:
            ddelta_bias = None

        return (du, ddelta, dmf, dA, dB, dC, dD, dz, ddelta_bias, None, None)


def rms_norm_forward(x, weight, bias, eps=1e-6, is_rms_norm=True):
    """Forward-only normalization via the Triton ``_layer_norm_fwd`` helper.

    ``x`` is treated as a 2-D ``(b l) d`` tensor (per the call sites in this
    file). Only the first element returned by ``_layer_norm_fwd`` is kept, so
    whatever the helper returns alongside it is dropped.
    """
    # Ensure a unit-stride last dimension for x and contiguous parameters.
    x = x if x.stride(-1) == 1 else x.contiguous()
    weight = weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    results = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )
    # Normalized output, same (b l) d layout as the input.
    return results[0]


def hm_fn(u, delta, mf, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
          return_last_state=False):
    """Functional entry point for :class:`HMFn`.

    Returns ``(out, out_cathree)``. If ``return_last_state`` is True, returns
    ``(out, out_cathree, last_state)`` where ``last_state`` has shape
    (batch, dim, dstate). Note that the gradient of the last state is not
    considered in the backward pass.
    """
    scan_args = (u, delta, mf, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return HMFn.apply(*scan_args)



class NeuroMambaInnerFn(torch.autograd.Function):
    """Fused inner block: two causal convs, input projections, scan, output projections.

    The input ``xz_gc`` is split along its channel axis into ``x`` (first
    ``d_inner`` channels), ``z`` (next ``d_inner``) and ``gc`` (next ``d_gc``),
    where ``d_inner = A.shape[0]`` and ``d_gc = mf_proj_weight.shape[1]``.
    The result is ``out_proj(out_z) + out_cathree_proj(out_cathree)``.

    NOTE(review): backward does not differentiate through the optional RMS
    norms on B, C and delta. The rms weights receive no gradient, and the
    gradients flowing to the projections are computed as if the norms were
    absent. This matches the pre-existing behaviour; confirm it is intended
    before relying on it for training.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz_gc, conv1d_weight, conv1d_bias,
                conv1d_gc_weight, conv1d_gc_bias,
                x_proj_weight, delta_proj_weight,
                mf_proj_weight, mf_proj_bias,
                out_proj_weight, out_proj_bias,
                out_cathree_proj_weight, out_cathree_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1,
                b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6):
        """
             xz_gc: (batch, dim, seqlen)

        ``B`` / ``C`` left as None means they are input-dependent and are
        sliced out of the ``x_proj`` output. ``checkpoint_lvl`` must be 0 or 1;
        at level 1 the conv outputs, ``mf`` and ``delta`` are recomputed in
        backward instead of being saved.

        NOTE(review): ``mf_proj_bias`` is accepted but is not added to ``mf``
        here, so it has no effect on the output and backward returns no
        gradient for it. Confirm whether the bias was meant to be applied.
        """
        assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz_gc.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        d_inner = A.shape[0]
        d_gc = mf_proj_weight.shape[1]
        if torch.is_autocast_enabled():
            # This Function bypasses autocast's own casting, so the projection
            # parameters are cast to the autocast dtype by hand.
            autocast_dtype = torch.get_autocast_gpu_dtype()
            x_proj_weight = x_proj_weight.to(dtype=autocast_dtype)
            delta_proj_weight = delta_proj_weight.to(dtype=autocast_dtype)
            mf_proj_weight = mf_proj_weight.to(dtype=autocast_dtype)
            out_proj_weight = out_proj_weight.to(dtype=autocast_dtype)
            out_proj_bias = (out_proj_bias.to(dtype=autocast_dtype)
                             if out_proj_bias is not None else None)
            out_cathree_proj_weight = out_cathree_proj_weight.to(dtype=autocast_dtype)
            out_cathree_proj_bias = (out_cathree_proj_bias.to(dtype=autocast_dtype)
                                     if out_cathree_proj_bias is not None else None)
        if xz_gc.stride(-1) != 1:
            xz_gc = xz_gc.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        conv1d_gc_weight = rearrange(conv1d_gc_weight, "d 1 w -> d w")
        x = xz_gc[:, :d_inner, :]
        z = xz_gc[:, d_inner:d_inner + d_inner, :]
        gc = xz_gc[:, 2 * d_inner:2 * d_inner + d_gc, :]
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_fwd_function(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        conv1d_gc_bias = conv1d_gc_bias.contiguous() if conv1d_gc_bias is not None else None
        conv1d_gc_out = causal_conv1d_fwd_function(
            gc, conv1d_gc_weight, conv1d_gc_bias, None, None, None, True
        )
        mf = rearrange(
            mf_proj_weight @ rearrange(conv1d_gc_out, "b d l -> d (b l)"),
            "d (b l) -> b d l",
            l=L,
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        # Optional RMS norms. From here on B, C and delta refer to the
        # normalized tensors, and those are what get saved for backward.
        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, out_cathree, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, mf, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.out_cathree_proj_bias_is_None = out_cathree_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        # Only the (non-tensor) eps lives on ctx; the rms weights themselves go
        # through save_for_backward below.
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out, delta, conv1d_gc_out and mf in the backward pass
            conv1d_out, delta, conv1d_gc_out, mf = None, None, None, None
        ctx.save_for_backward(xz_gc, conv1d_weight, conv1d_bias, conv1d_gc_weight, conv1d_gc_bias,
                              x_dbl, x_proj_weight, delta_proj_weight,
                              out_proj_weight, mf_proj_weight, out_cathree_proj_weight,
                              conv1d_out, delta, conv1d_gc_out, mf,
                              A, B, C, D, delta_bias, scan_intermediates,
                              b_rms_weight, c_rms_weight, dt_rms_weight, out, out_cathree)
        main_branch = F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
        cathree_branch = F.linear(rearrange(out_cathree, "b d l -> b l d"),
                                  out_cathree_proj_weight, out_cathree_proj_bias)
        return main_branch + cathree_branch

    @staticmethod
    @custom_bwd
    def backward(ctx, dout_add):
        """Backward pass; returns one gradient slot per ``forward`` argument.

        ``dout_add`` is the gradient of the summed output, (batch, seqlen, dim).
        Because the output is the sum of two linear projections, both
        projections receive ``dout_add`` unchanged as their upstream gradient.
        """
        assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz_gc, conv1d_weight, conv1d_bias, conv1d_gc_weight, conv1d_gc_bias,
         x_dbl, x_proj_weight, delta_proj_weight,
         out_proj_weight, mf_proj_weight, out_cathree_proj_weight,
         conv1d_out, delta, conv1d_gc_out, mf,
         A, B, C, D, delta_bias, scan_intermediates,
         b_rms_weight, c_rms_weight, dt_rms_weight, out, out_cathree) = ctx.saved_tensors
        L = xz_gc.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        d_inner = A.shape[0]
        d_gc = mf_proj_weight.shape[1]
        x = xz_gc[:, :d_inner, :]
        z = xz_gc[:, d_inner:d_inner + d_inner, :]
        gc = xz_gc[:, 2 * d_inner:2 * d_inner + d_gc, :]

        if dout_add.stride(-1) != 1:
            dout_add = dout_add.contiguous()
        dout = dout_add
        dout_cathree = dout_add

        if ctx.checkpoint_lvl == 1:
            # Rebuild the activations that forward dropped to save memory.
            conv1d_out = causal_conv1d_fwd_function(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            conv1d_gc_out = causal_conv1d_fwd_function(
                gc, conv1d_gc_weight, conv1d_gc_bias, None, None, None, True
            )
            mf = rearrange(
                mf_proj_weight @ rearrange(conv1d_gc_out, "b d l -> d (b l)"),
                "d (b l) -> b d l", l=L,
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(delta, dt_rms_weight, None, ctx.b_c_dt_rms_eps)
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            # B and C are NOT re-normalized here: forward saved them after the
            # RMS norm, so normalizing again would feed the scan backward
            # different tensors than the scan forward used (and would make the
            # result depend on checkpoint_lvl).

        dxz_gc = torch.empty_like(xz_gc)  # (batch, dim, seqlen)
        dx = dxz_gc[:, :d_inner, :]
        dz = dxz_gc[:, d_inner:d_inner + d_inner, :]
        dgc = dxz_gc[:, 2 * d_inner:2 * d_inner + d_gc, :]

        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dout_cathree = rearrange(dout_cathree, "b l e -> e (b l)")
        dout_cathree_y = rearrange(out_cathree_proj_weight.t() @ dout_cathree, "d (b l) -> b d l", l=L)

        dconv1d_out, ddelta, dmf, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, mf, A, B, C, D, z, delta_bias, dout_y, dout_cathree_y,
            scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is (e, b*l) at this point, so the bias gradient is the sum over
        # the flattened batch/sequence axis only; summing both axes would
        # collapse it to a scalar that autograd rejects for an (e,) bias.
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dout_cathree_proj_weight = torch.einsum(
            "eB,dB->ed", dout_cathree, rearrange(out_cathree, "b d l -> d (b l)")
        )
        dout_cathree_proj_bias = (dout_cathree.sum(dim=1)
                                  if not ctx.out_cathree_proj_bias_is_None else None)
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None

        dmf = rearrange(dmf, "b d l -> d (b l)")
        dmf_proj_weight = torch.einsum("dB,Br->dr", dmf, rearrange(conv1d_gc_out, "b d l -> (b l) d"))
        # Forward never adds mf_proj_bias to mf, so its gradient is identically
        # zero; report "no gradient" rather than a value for a term that does
        # not influence the output.
        dmf_proj_bias = None
        dconv1d_gc_out = torch.einsum("dB,de->eB", dmf, mf_proj_weight)
        dconv1d_gc_out = rearrange(dconv1d_gc_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # dgc (a view into dxz_gc) is passed as the pre-allocated input-gradient buffer.
        dgc, dconv1d_gc_weight, dconv1d_gc_bias, *_ = causal_conv1d_bwd_function(
            gc, conv1d_gc_weight, conv1d_gc_bias, dconv1d_gc_out, None, None, None, dgc, False, True
        )
        dconv1d_gc_bias = dconv1d_gc_bias if conv1d_gc_bias is not None else None
        dconv1d_gc_weight = rearrange(dconv1d_gc_weight, "d w -> d 1 w")

        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into the scan's gradient in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")

        # Explicit write-back kept in case a kernel returns a fresh tensor
        # rather than filling the view it was given.
        dxz_gc[:, :d_inner, :] = dx
        dxz_gc[:, d_inner:d_inner + d_inner, :] = dz
        dxz_gc[:, 2 * d_inner:2 * d_inner + d_gc, :] = dgc

        return (dxz_gc, dconv1d_weight, dconv1d_bias, dconv1d_gc_weight, dconv1d_gc_bias,
                dx_proj_weight, ddelta_proj_weight,
                dmf_proj_weight, dmf_proj_bias,
                dout_proj_weight, dout_proj_bias,
                dout_cathree_proj_weight, dout_cathree_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
                dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)


def neuromamba_inner_fn(
    xz_gc, conv1d_weight, conv1d_bias, conv1d_gc_weight, conv1d_gc_bias, x_proj_weight, delta_proj_weight,
    mf_proj_weight, mf_proj_bias,
    out_proj_weight, out_proj_bias,
    out_cathree_proj_weight, out_cathree_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6
):
    """Functional entry point for :class:`NeuroMambaInnerFn`.

    ``torch.autograd.Function.apply`` takes positional arguments only, so this
    wrapper provides the keyword names and defaults and forwards everything in
    the order ``NeuroMambaInnerFn.forward`` declares.
    """
    return NeuroMambaInnerFn.apply(
        xz_gc,
        conv1d_weight,
        conv1d_bias,
        conv1d_gc_weight,
        conv1d_gc_bias,
        x_proj_weight,
        delta_proj_weight,
        mf_proj_weight,
        mf_proj_bias,
        out_proj_weight,
        out_proj_bias,
        out_cathree_proj_weight,
        out_cathree_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )

