import torch
import triton
import triton.language as tl

# -------------------------------------------------------------------------
# 1. FUSED KERNELS
# -------------------------------------------------------------------------
@triton.jit
def combine_fn(left_dA, left_dBx, right_dA, right_dBx):
    """Compose two steps of the linear recurrence ``h -> dA * h + dBx``.

    Applying the ``left`` step and then the ``right`` step to a state ``h``
    gives ``right_dA * (left_dA * h + left_dBx) + right_dBx``; the returned
    pair is the (multiplier, offset) of that combined step. It is used as the
    combine function of ``tl.associative_scan``.
    """
    # Multipliers compose by product; the left offset is scaled by the right
    # multiplier before the right offset is added.
    return right_dA * left_dA, right_dA * left_dBx + right_dBx

@triton.jit
def fused_forward_kernel_mamba(
    # Inputs
    x_ptr, dt_ptr, k_ptr, q_ptr, A_ptr, gate_ptr, D_ptr,
    # Output
    out_ptr,
    # Dimensions
    BATCH, SEQ_LEN, V_DIM, QK_DIM,
    # Strides
    stride_x_b, stride_x_v, stride_x_l,
    stride_dt_b, stride_dt_v, stride_dt_l,
    stride_k_b, stride_k_qk, stride_k_l,
    stride_q_b, stride_q_qk, stride_q_l,
    stride_A_v, stride_A_qk,
    stride_gate_b, stride_gate_v, stride_gate_l,
    stride_D_v,
    stride_out_b, stride_out_v, stride_out_l,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused scan + skip connection + SiLU gate for one (batch, channel) row.

    The launch grid is 1-D; program ``pid`` handles batch ``pid // V_DIM``
    and channel ``pid % V_DIM`` and processes the whole sequence in a single
    block, so ``BLOCK_SIZE`` must be >= ``SEQ_LEN``.

    For every state index ``i_qk`` the recurrence

        h_t = exp(dt_t * A[v, qk]) * h_{t-1} + dt_t * k_t[qk] * x_t

    is evaluated with an associative scan along the sequence, and
    ``h_t * q_t[qk]`` is accumulated over ``qk``. The result then receives
    the skip term ``x * D[v]`` and is multiplied by ``SiLU(gate)``.

    Addressing, as implied by the stride arguments:
        x, dt, gate, out: (batch, v, position)
        k, q:             (batch, qk, position)
        A:                (v, qk)
        D:                (v,)
    Every access goes through the corresponding stride, including the
    sequence-axis strides (``stride_*_l``).
    """
    pid = tl.program_id(0)
    i_v = pid % V_DIM
    i_b = pid // V_DIM

    offs = tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond SEQ_LEN are padding: they load 0 and are never stored.
    mask = offs < SEQ_LEN

    # Row pointers for the tensors indexed by (batch, channel, position).
    x_ptrs = x_ptr + i_b * stride_x_b + i_v * stride_x_v + offs * stride_x_l
    dt_ptrs = dt_ptr + i_b * stride_dt_b + i_v * stride_dt_v + offs * stride_dt_l
    gate_ptrs = (
        gate_ptr + i_b * stride_gate_b + i_v * stride_gate_v + offs * stride_gate_l
    )
    out_ptrs = out_ptr + i_b * stride_out_b + i_v * stride_out_v + offs * stride_out_l

    x_val = tl.load(x_ptrs, mask=mask, other=0.0)
    dt_val = tl.load(dt_ptrs, mask=mask, other=0.0)

    # Accumulator for the scan result contracted with q (sum over qk of h * q).
    acc_y = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    for i_qk in range(QK_DIM):
        k_ptrs = k_ptr + i_b * stride_k_b + i_qk * stride_k_qk + offs * stride_k_l
        q_ptrs = q_ptr + i_b * stride_q_b + i_qk * stride_q_qk + offs * stride_q_l

        k_val = tl.load(k_ptrs, mask=mask, other=0.0)
        q_val = tl.load(q_ptrs, mask=mask, other=0.0)
        A_val = tl.load(A_ptr + i_v * stride_A_v + i_qk * stride_A_qk)

        # Padded lanes have dt == 0, hence dA == 1 and dBx == 0: they simply
        # carry the last valid state forward and cannot disturb valid lanes.
        dA_val = tl.exp(dt_val * A_val)
        dBx_val = dt_val * k_val * x_val

        _, h_val = tl.associative_scan((dA_val, dBx_val), 0, combine_fn)
        acc_y += h_val * q_val

    # Skip connection: y = scan_out + x * D
    D_val = tl.load(D_ptr + i_v * stride_D_v)
    y_skip = acc_y + x_val * D_val

    # Gate: out = y * SiLU(gate), with SiLU(g) = g * sigmoid(g)
    gate_val = tl.load(gate_ptrs, mask=mask, other=0.0)
    gate_silu = gate_val * tl.sigmoid(gate_val)

    tl.store(out_ptrs, y_skip * gate_silu, mask=mask)

@triton.jit
def fused_backward_kernel_mamba(
    # Gradients
    dout_ptr, dx_ptr, ddt_ptr, dk_ptr, dq_ptr, dA_ptr, dgate_ptr, dD_ptr,
    # Inputs
    x_ptr, dt_ptr, k_ptr, q_ptr, A_ptr, gate_ptr, D_ptr,
    # Scratchpad
    scratch_h_ptr,
    # Dimensions
    BATCH, SEQ_LEN, V_DIM, QK_DIM,
    # Strides (Inputs)
    stride_x_b, stride_x_v, stride_x_l,
    stride_dt_b, stride_dt_v, stride_dt_l,
    stride_k_b, stride_k_qk, stride_k_l,
    stride_q_b, stride_q_qk, stride_q_l,
    stride_A_v, stride_A_qk,
    stride_gate_b, stride_gate_v, stride_gate_l,
    stride_D_v,
    # Strides (Grads)
    stride_dx_b, stride_dx_v, stride_dx_l,
    stride_ddt_b, stride_ddt_v, stride_ddt_l,
    stride_dk_b, stride_dk_qk, stride_dk_l,
    stride_dq_b, stride_dq_qk, stride_dq_l,
    stride_dA_v, stride_dA_qk,
    stride_dgate_b, stride_dgate_v, stride_dgate_l,
    # Scratch Stride
    stride_h_b, stride_h_v, stride_h_qk, stride_h_l,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward pass of the fused scan + skip + SiLU-gate kernel.

    The launch grid is 1-D; program ``pid`` handles batch ``pid // V_DIM``
    and channel ``pid % V_DIM`` and processes the whole sequence in a single
    block, so ``BLOCK_SIZE`` must be >= ``SEQ_LEN``.

    Phases:
      1. Recompute the forward states ``h`` for every ``qk`` and write them
         to ``scratch_h`` (addressed as (batch, v, qk, position)).
      2. Differentiate the gate and the skip connection (forward order).
      3. Run the adjoint recurrence ``dh_t = g_t + dA_{t+1} * dh_{t+1}`` as
         an associative scan over the *reversed* sequence and derive the
         gradients of x, dt, k, q and A from it.

    ``dx``, ``ddt`` and ``dgate`` are written with plain stores (one program
    owns each row). ``dk``, ``dq``, ``dA`` and ``dD`` receive contributions
    from several programs and are accumulated with ``tl.atomic_add``; the
    caller must therefore zero-initialise them.

    No strides are passed for ``dout`` or ``dD``: ``dout`` is addressed with
    the ``dx`` strides (so it must share dx's layout) and ``dD`` is indexed
    with unit stride.
    """
    pid = tl.program_id(0)
    i_v = pid % V_DIM
    i_b = pid // V_DIM

    # Base offsets of this program's (batch, channel) row.
    off_x = i_b * stride_x_b + i_v * stride_x_v
    off_dt = i_b * stride_dt_b + i_v * stride_dt_v
    off_gate = i_b * stride_gate_b + i_v * stride_gate_v
    # dout has no stride arguments of its own; it reuses the dx strides.
    off_dout = i_b * stride_dx_b + i_v * stride_dx_v
    # Row base inside the scratch buffer, shared by all qk iterations.
    off_h_row = i_b * stride_h_b + i_v * stride_h_v

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < SEQ_LEN

    # Load shared inputs (forward order).
    x_val = tl.load(x_ptr + off_x + offs * stride_x_l, mask=mask, other=0.0)
    dt_val = tl.load(dt_ptr + off_dt + offs * stride_dt_l, mask=mask, other=0.0)
    gate_val = tl.load(
        gate_ptr + off_gate + offs * stride_gate_l, mask=mask, other=0.0
    )
    dout_val = tl.load(
        dout_ptr + off_dout + offs * stride_dx_l, mask=mask, other=0.0
    )
    D_val = tl.load(D_ptr + i_v * stride_D_v)

    # -----------------------------------------------------------
    # 1. Forward Pass Recomputation
    # -----------------------------------------------------------
    acc_y = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    for i_qk in range(QK_DIM):
        off_k = i_b * stride_k_b + i_qk * stride_k_qk
        off_q = i_b * stride_q_b + i_qk * stride_q_qk
        off_A = i_v * stride_A_v + i_qk * stride_A_qk

        k_val = tl.load(k_ptr + off_k + offs * stride_k_l, mask=mask, other=0.0)
        q_val = tl.load(q_ptr + off_q + offs * stride_q_l, mask=mask, other=0.0)
        A_val = tl.load(A_ptr + off_A)

        dA_val = tl.exp(dt_val * A_val)
        dBx_val = dt_val * k_val * x_val

        _, h_val = tl.associative_scan((dA_val, dBx_val), 0, combine_fn)
        acc_y += h_val * q_val

        # Keep h for phase 3, which needs h[t] (for dq) and h[t-1] (for dA/ddt).
        off_h = off_h_row + i_qk * stride_h_qk + offs * stride_h_l
        tl.store(scratch_h_ptr + off_h, h_val, mask=mask)

    # -----------------------------------------------------------
    # 2. Compute Gradients through Gate (Forward Order)
    # -----------------------------------------------------------
    gate_sig = tl.sigmoid(gate_val)
    # d/dg [g * sigmoid(g)] = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
    d_silu_gate = gate_sig * (1.0 + gate_val * (1.0 - gate_sig))

    y_skip = acc_y + x_val * D_val
    d_gate = dout_val * y_skip * d_silu_gate

    off_dgate = i_b * stride_dgate_b + i_v * stride_dgate_v
    tl.store(dgate_ptr + off_dgate + offs * stride_dgate_l, d_gate, mask=mask)

    # Gradient w.r.t. the pre-gate value y (scan output + skip).
    gate_silu = gate_val * gate_sig
    d_scan_out = dout_val * gate_silu

    # dD: reduce over the sequence here; the atomic add sums over batch.
    # Padded lanes contribute 0 because dout and x were loaded as 0 there.
    d_D_local = d_scan_out * x_val
    d_D_sum = tl.sum(d_D_local, axis=0)
    tl.atomic_add(dD_ptr + i_v, d_D_sum)

    # The skip contribution to dx is added in phase 3, because the dx
    # accumulator there is laid out in reverse sequence order.

    # -----------------------------------------------------------
    # 3. Backward Pass (Reverse Scan for SSM)
    # -----------------------------------------------------------

    # Lane j now corresponds to time step t = SEQ_LEN - 1 - j.
    rev_offs = (SEQ_LEN - 1) - offs
    rev_mask = (rev_offs >= 0) & (rev_offs < SEQ_LEN)

    # Reload inputs in reverse order.
    x_rev = tl.load(x_ptr + off_x + rev_offs * stride_x_l, mask=rev_mask, other=0.0)
    dt_rev = tl.load(
        dt_ptr + off_dt + rev_offs * stride_dt_l, mask=rev_mask, other=0.0
    )

    # Recompute d_scan_out in reverse order.
    gate_rev = tl.load(
        gate_ptr + off_gate + rev_offs * stride_gate_l, mask=rev_mask, other=0.0
    )
    dout_rev = tl.load(
        dout_ptr + off_dout + rev_offs * stride_dx_l, mask=rev_mask, other=0.0
    )
    d_scan_out_rev = dout_rev * (gate_rev * tl.sigmoid(gate_rev))

    # Seed the dx accumulator with the skip-connection gradient.
    d_x_skip_rev = d_scan_out_rev * D_val
    acc_dx = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + d_x_skip_rev
    acc_ddt = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # dt[t+1], needed for the decay factor of the adjoint recurrence.
    # The lower bound matters: for padded lanes rev_offs + 1 can be negative,
    # and without it those lanes would read out of bounds. Masked lanes get
    # dt == 0, i.e. a decay factor of exactly 1, which keeps padded scan
    # values finite so they cannot poison the tl.sum reductions below.
    rev_offs_plus1 = rev_offs + 1
    mask_plus1 = (rev_offs_plus1 >= 0) & (rev_offs_plus1 < SEQ_LEN)
    dt_rev_plus = tl.load(
        dt_ptr + off_dt + rev_offs_plus1 * stride_dt_l, mask=mask_plus1, other=0.0
    )

    # Position of h[t-1]; masked to 0 for t == 0 (zero initial state).
    rev_offs_minus1 = rev_offs - 1
    mask_minus1 = (rev_offs_minus1 >= 0) & (rev_offs_minus1 < SEQ_LEN)

    # The scratch values stored in phase 1 are read back below at different
    # positions than the ones each lane wrote; synchronise before reading.
    tl.debug_barrier()

    for i_qk in range(QK_DIM):
        off_k = i_b * stride_k_b + i_qk * stride_k_qk
        off_q = i_b * stride_q_b + i_qk * stride_q_qk
        off_A = i_v * stride_A_v + i_qk * stride_A_qk
        off_h_qk = off_h_row + i_qk * stride_h_qk

        q_rev = tl.load(q_ptr + off_q + rev_offs * stride_q_l, mask=rev_mask, other=0.0)
        k_rev = tl.load(k_ptr + off_k + rev_offs * stride_k_l, mask=rev_mask, other=0.0)
        A_val = tl.load(A_ptr + off_A)

        # 1. Direct gradient flowing into h[t] from the output.
        d_h_rev = d_scan_out_rev * q_rev

        # 2. Adjoint recurrence dh[t] = d_h[t] + dA[t+1] * dh[t+1], evaluated
        #    as a forward scan over the reversed sequence.
        dA_rev_shifted = tl.exp(dt_rev_plus * A_val)
        _, ddB_rev = tl.associative_scan((dA_rev_shifted, d_h_rev), 0, combine_fn)

        # 3. Read h[t-1].
        h_prev_rev = tl.load(
            scratch_h_ptr + off_h_qk + rev_offs_minus1 * stride_h_l,
            mask=mask_minus1,
            other=0.0,
        )

        # 4. Gradients.
        # dx: through dBx = dt * k * x.
        acc_dx += ddB_rev * dt_rev * k_rev

        # dk: shared by all channels of a batch, hence the atomic add.
        d_k_local = ddB_rev * dt_rev * x_rev
        off_dk = i_b * stride_dk_b + i_qk * stride_dk_qk + rev_offs * stride_dk_l
        tl.atomic_add(dk_ptr + off_dk, d_k_local, mask=rev_mask)

        # dq: y[t] = sum_qk h[t] * q[t]; also shared across channels.
        h_curr_rev = tl.load(
            scratch_h_ptr + off_h_qk + rev_offs * stride_h_l,
            mask=rev_mask,
            other=0.0,
        )
        d_q_local = d_scan_out_rev * h_curr_rev
        off_dq = i_b * stride_dq_b + i_qk * stride_dq_qk + rev_offs * stride_dq_l
        tl.atomic_add(dq_ptr + off_dq, d_q_local, mask=rev_mask)

        # ddt: one term through dBx, one through dA = exp(dt * A).
        d_dt_1 = ddB_rev * k_rev * x_rev
        dA_rev = tl.exp(dt_rev * A_val)
        grad_dA = ddB_rev * h_prev_rev
        d_dt_2 = grad_dA * A_val * dA_rev
        acc_ddt += (d_dt_1 + d_dt_2)

        # dA: reduce over the sequence; the atomic add sums over batch.
        # Addressed with dA's own strides rather than A's.
        d_A_el = grad_dA * dt_rev * dA_rev
        d_A_sum = tl.sum(d_A_el, axis=0)
        off_dA = i_v * stride_dA_v + i_qk * stride_dA_qk
        tl.atomic_add(dA_ptr + off_dA, d_A_sum)

    # dx and ddt rows belong to this program alone, so plain stores suffice.
    off_dx = i_b * stride_dx_b + i_v * stride_dx_v + rev_offs * stride_dx_l
    tl.store(dx_ptr + off_dx, acc_dx, mask=rev_mask)

    off_ddt = i_b * stride_ddt_b + i_v * stride_ddt_v + rev_offs * stride_ddt_l
    tl.store(ddt_ptr + off_ddt, acc_ddt, mask=rev_mask)


# -------------------------------------------------------------------------
# 2. PYTHON AUTOGRAD
# -------------------------------------------------------------------------

class FusedMambaScanSkipGateFn(torch.autograd.Function):
    """Autograd binding for the fused scan + skip + SiLU-gate Triton kernels.

    ``forward`` unpacks ``x`` as (B, L, V) and ``k`` as (B, L, QK); the
    sequence-last tensors are transposed to a contiguous sequence-innermost
    layout before being handed to the kernels, and results are transposed
    back on the way out.
    """

    @staticmethod
    def forward(ctx, x, dt, A, k, q, D, gate):
        """Launch the forward kernel and return the output as (B, L, V)."""
        n_batch, seq_len, n_chan = x.shape
        _, _, n_state = k.shape

        def to_seq_last(t):
            # (B, L, C) -> contiguous (B, C, L)
            return t.permute(0, 2, 1).contiguous()

        x_in = to_seq_last(x)
        gate_in = to_seq_last(gate)
        dt_in = to_seq_last(dt)
        k_in = to_seq_last(k)
        q_in = to_seq_last(q)
        D_in = D.contiguous()  # (V)

        out = torch.empty(n_batch, n_chan, seq_len, device=x.device, dtype=x.dtype)

        # The kernel handles a whole sequence in one block.
        block = triton.next_power_of_2(seq_len)

        # Kernel input order; also the order in which strides are passed.
        kernel_inputs = (x_in, dt_in, k_in, q_in, A, gate_in, D_in)
        stride_args = []
        for tensor in kernel_inputs + (out,):
            stride_args.extend(tensor.stride())

        # One program per (batch, channel) pair.
        launch_grid = (n_batch * n_chan,)
        fused_forward_kernel_mamba[launch_grid](
            *kernel_inputs,
            out,
            n_batch, seq_len, n_chan, n_state,
            *stride_args,
            BLOCK_SIZE=block,
        )

        ctx.save_for_backward(*kernel_inputs)
        ctx.BLOCK_SIZE = block
        return out.permute(0, 2, 1)

    @staticmethod
    def backward(ctx, dout):
        """Launch the backward kernel and return gradients in ``forward``'s argument order."""
        x_in, dt_in, k_in, q_in, A, gate_in, D_in = ctx.saved_tensors
        block = ctx.BLOCK_SIZE
        n_batch, n_chan, seq_len = x_in.shape
        _, n_state, _ = k_in.shape

        dout_in = dout.permute(0, 2, 1).contiguous()

        # Gradient buffers start at zero; the kernel accumulates into some
        # of them with atomic adds.
        d_x = torch.zeros_like(x_in)
        d_dt = torch.zeros_like(dt_in)
        d_k = torch.zeros_like(k_in)
        d_q = torch.zeros_like(q_in)
        d_gate = torch.zeros_like(gate_in)
        d_A = torch.zeros_like(A)
        d_D = torch.zeros_like(D_in)

        # Scratchpad for the recomputed hidden states, (B, V, QK, L).
        scratch_h = torch.empty(
            n_batch, n_chan, n_state, seq_len,
            device=x_in.device, dtype=torch.float32,
        )

        # Tensors whose strides the kernel expects, in signature order.
        # dout and d_D are deliberately absent: the kernel takes no strides
        # for them.
        strided = (
            x_in, dt_in, k_in, q_in, A, gate_in, D_in,
            d_x, d_dt, d_k, d_q, d_A, d_gate,
            scratch_h,
        )
        stride_args = [s for tensor in strided for s in tensor.stride()]

        launch_grid = (n_batch * n_chan,)
        fused_backward_kernel_mamba[launch_grid](
            dout_in, d_x, d_dt, d_k, d_q, d_A, d_gate, d_D,
            x_in, dt_in, k_in, q_in, A, gate_in, D_in,
            scratch_h,
            n_batch, seq_len, n_chan, n_state,
            *stride_args,
            BLOCK_SIZE=block,
        )

        # Same order as forward's inputs: x, dt, A, k, q, D, gate.
        grad_x = d_x.permute(0, 2, 1)
        grad_dt = d_dt.permute(0, 2, 1)
        grad_k = d_k.permute(0, 2, 1)
        grad_q = d_q.permute(0, 2, 1)
        grad_gate = d_gate.permute(0, 2, 1)
        return grad_x, grad_dt, d_A, grad_k, grad_q, d_D, grad_gate


def fused_mamba_scan(x, dt, A, k, q, D, gate):
    """Run the fused scan + skip + SiLU-gate op with autograd support.

    Thin functional entry point that forwards all arguments, unchanged and
    in the same order, to ``FusedMambaScanSkipGateFn.apply``.
    """
    fused_out = FusedMambaScanSkipGateFn.apply(x, dt, A, k, q, D, gate)
    return fused_out


# -------------------------------------------------------------------------
# 3. TESTING THE FUSED KERNEL
# -------------------------------------------------------------------------

if __name__ == "__main__":
    import torch

    # -------------------------------------------------------------------------
    # REFERENCE IMPLEMENTATION (Pure PyTorch)
    # -------------------------------------------------------------------------
    def reference_mamba_scan(x, dt, A, k, q, D, z):
        """
        Step-by-step PyTorch reference for the fused op.

        State update:  h_t = exp(dt_t * A) * h_{t-1} + dt_t * k_t * x_t
        Readout:       y_t = sum over QK of (h_t * q_t)
        Epilogue:      out = (y + x * D) * silu(z)

        x, dt, z are (B, L, V); k, q are (B, L, QK); A is (V, QK); D
        broadcasts against the last axis of x.
        """
        # Lift every operand to a common (B, L, V, QK) broadcast shape.
        dt_b = dt[..., None]                 # (B, L, V, 1)
        decay = torch.exp(dt_b * A[None, None])      # (B, L, V, QK)
        drive = dt_b * k[:, :, None] * x[..., None]  # (B, L, V, QK)
        q_b = q[:, :, None]                  # (B, L, 1, QK)

        n_batch, n_steps, n_chan, n_state = decay.shape

        y = torch.zeros(n_batch, n_steps, n_chan, device=x.device, dtype=x.dtype)
        state = torch.zeros(n_batch, n_chan, n_state, device=x.device, dtype=x.dtype)

        for step in range(n_steps):
            state = decay[:, step] * state + drive[:, step]
            # Contract the state axis against q for this time step.
            y[:, step] = (state * q_b[:, step]).sum(dim=-1)

        # Skip connection, then gate.
        y = y + x * D
        return y * torch.nn.functional.silu(z)

    # -------------------------------------------------------------------------
    # TEST SETUP
    # -------------------------------------------------------------------------
    torch.manual_seed(42)

    if not torch.cuda.is_available():
        print("CUDA is not available. Triton requires a GPU.")
    else:
        # Problem size
        BATCH = 8
        SEQ_LEN = 128
        V_DIM = 64
        QK_DIM = 16

        print(f"Testing with B={BATCH}, L={SEQ_LEN}, V={V_DIM}, QK={QK_DIM}")

        device = "cuda"
        dtype = torch.float32
        leaf_opts = dict(device=device, dtype=dtype, requires_grad=True)

        # NOTE: the creation order below fixes the random stream; keep it.
        x = torch.randn(BATCH, SEQ_LEN, V_DIM, **leaf_opts)
        dt_raw = torch.randn(BATCH, SEQ_LEN, V_DIM, **leaf_opts)
        # softplus keeps the step size positive
        dt = torch.nn.functional.softplus(dt_raw)
        dt.retain_grad()
        A_raw = torch.rand(V_DIM, QK_DIM, **leaf_opts)
        # negative A gives decay factors exp(dt * A) in (0, 1]
        A = -A_raw
        A.retain_grad()
        k = torch.randn(BATCH, SEQ_LEN, QK_DIM, **leaf_opts)
        q = torch.randn(BATCH, SEQ_LEN, QK_DIM, **leaf_opts)
        D_raw = torch.rand(V_DIM, **leaf_opts)
        D = 1 + 0.2 * D_raw
        D.retain_grad()
        z = torch.randn(BATCH, SEQ_LEN, V_DIM, **leaf_opts)

        # Independent leaf copies so the reference accumulates its own grads.
        x_ref, dt_ref, A_ref, k_ref, q_ref, D_ref, z_ref = (
            t.clone().detach().requires_grad_(True)
            for t in (x, dt, A, k, q, D, z)
        )

        # -------------------------------------------------------------------------
        # FORWARD PASS CHECK
        # -------------------------------------------------------------------------
        print("-" * 40)
        print("Checking Forward Pass...")

        out_ref = reference_mamba_scan(x_ref, dt_ref, A_ref, k_ref, q_ref, D_ref, z_ref)
        out_triton = fused_mamba_scan(x, dt, A, k, q, D, z)

        # The parallel scan and the sequential loop round differently, so
        # compare with a tolerance rather than exactly.
        forward_ok = torch.allclose(out_ref, out_triton, atol=1e-4, rtol=1e-4)
        if not forward_ok:
            print("❌ Forward Pass Mismatch!")
            diff = (out_ref - out_triton).abs().max()
            print(f"   Max Diff: {diff.item()}")
        else:
            print("✅ Forward Pass Matches!")

        # -------------------------------------------------------------------------
        # BACKWARD PASS CHECK
        # -------------------------------------------------------------------------
        print("-" * 40)
        print("Checking Backward Pass...")

        # Random upstream gradient, shared by both implementations.
        dout = torch.randn_like(out_ref)

        out_ref.backward(dout)
        out_triton.backward(dout)

        def check_grad(name, ref_grad, tri_grad):
            """Print whether the Triton gradient agrees with the reference."""
            if not torch.allclose(ref_grad, tri_grad, atol=1e-3, rtol=1e-3):
                print(f"❌ Gradient {name} Mismatch!")
                diff = (ref_grad - tri_grad).abs().max()
                print(f"   Max Diff: {diff.item()}")
                return
            print(f"✅ Gradient {name} Matches!")

        grad_pairs = (
            ("x", x_ref, x),
            ("dt", dt_ref, dt),
            ("A", A_ref, A),
            ("k", k_ref, k),
            ("q", q_ref, q),
            ("D", D_ref, D),
            ("z", z_ref, z),
        )
        for label, ref_tensor, tri_tensor in grad_pairs:
            check_grad(label, ref_tensor.grad, tri_tensor.grad)

        print("-" * 40)
        print("Test Complete.")
