import torch
import triton
import triton.language as tl

# -------------------------------------------------------------------------
# 1. TRITON KERNELS (Optimized for Contiguous L)
# -------------------------------------------------------------------------


@triton.jit
def combine_fn(left_dA, left_dBx, right_dA, right_dBx):
    # Associative operator for tl.associative_scan.
    # Each (dA, dBx) pair represents the affine map h -> dA * h + dBx.
    # The result is the composition "apply left, then right":
    #     right(left(h)) = (right_dA * left_dA) * h + (right_dA * left_dBx + right_dBx)
    composed_dA = left_dA * right_dA
    composed_dBx = left_dBx * right_dA + right_dBx
    return composed_dA, composed_dBx


@triton.jit
def mamba_scan_forward_kernel_contiguous(
    dA_ptr,
    dBx_ptr,
    out_ptr,
    BATCH_SIZE,  # Total number of independent scans (B * S * V); not read by the kernel
    SEQ_LEN,  # The scan length
    BLOCK_SIZE: tl.constexpr,
):
    """Inclusive linear-recurrence scan over one contiguous row.

    For the row selected by program_id(0), computes
        out[t] = dA[t] * out[t-1] + dBx[t],   out[-1] = 0.
    All three buffers are treated as row-major [BATCH_SIZE, SEQ_LEN] matrices
    (row stride SEQ_LEN, column stride 1). The whole row is scanned by this
    single program, so BLOCK_SIZE must be a power of two >= SEQ_LEN.
    """
    pid = tl.program_id(0)

    # Promote to int64 before multiplying: program_id is 32-bit, so
    # pid * SEQ_LEN would overflow once the flattened tensor exceeds 2**31
    # elements and the row pointers would wrap to the wrong memory.
    row_offset = pid.to(tl.int64) * SEQ_LEN

    p_dA = dA_ptr + row_offset
    p_dBx = dBx_ptr + row_offset
    p_out = out_ptr + row_offset

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < SEQ_LEN

    # Lanes past SEQ_LEN load the scan identity (dA = 1, dBx = 0). They sit
    # after every valid lane, so they cannot influence any stored prefix.
    dA_val = tl.load(p_dA + offs, mask=mask, other=1.0)
    dBx_val = tl.load(p_dBx + offs, mask=mask, other=0.0)

    # The dBx component of the scanned pair is the hidden state h[t]; the dA
    # component (running product of dA) is not needed.
    _, h_val = tl.associative_scan((dA_val, dBx_val), 0, combine_fn)

    tl.store(p_out + offs, h_val, mask=mask)


@triton.jit
def mamba_scan_backward_kernel_contiguous(
    dout_ptr,
    dA_ptr,
    h_ptr,  # Inputs
    ddA_ptr,
    ddB_ptr,  # Outputs
    BATCH_SIZE,  # Number of rows (B * S * V); not read by the kernel
    SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """Gradient of the forward scan for one contiguous row.

    With h[t] = dA[t] * h[t-1] + dBx[t] (h[-1] = 0) and upstream gradient dout:
        g[t]   = dout[t] + dA[t+1] * g[t+1]    (g[SEQ_LEN-1] = dout[SEQ_LEN-1])
        ddB[t] = g[t]                          (gradient w.r.t. dBx[t])
        ddA[t] = g[t] * h[t-1]                 (h[-1] = 0)
    The reverse-time recurrence for g is evaluated with an ordinary forward
    associative scan by loading the row reversed: lane i holds time step
    t = SEQ_LEN - 1 - i. Buffers use the same [BATCH_SIZE, SEQ_LEN] row-major
    layout as the forward kernel; BLOCK_SIZE must be a power of two >= SEQ_LEN.
    """
    pid = tl.program_id(0)
    # int64 so pid * SEQ_LEN cannot overflow for tensors > 2**31 elements.
    row_offset = pid.to(tl.int64) * SEQ_LEN

    p_dout = dout_ptr + row_offset
    p_dA = dA_ptr + row_offset
    p_h = h_ptr + row_offset
    p_ddA = ddA_ptr + row_offset
    p_ddB = ddB_ptr + row_offset

    offs = tl.arange(0, BLOCK_SIZE)

    # --- Reverse Scan Logic ---
    # Lane i handles time step t = SEQ_LEN - 1 - i; lanes i >= SEQ_LEN map to
    # t < 0 and are padding.
    rev_offs = (SEQ_LEN - 1) - offs
    mask = (rev_offs >= 0) & (rev_offs < SEQ_LEN)

    # 1. Upstream gradient dout[t]; padding lanes contribute 0.
    dout_val = tl.load(p_dout + rev_offs, mask=mask, other=0.0)

    # 2. Carry coefficient dA[t+1] of g[t] = dout[t] + dA[t+1] * g[t+1].
    #    Lane 0 (t = SEQ_LEN - 1) has no successor, and padding lanes are
    #    never stored, so both take 0.0 without touching memory.
    rev_offs_plus_1 = rev_offs + 1
    mask_shifted = mask & (rev_offs_plus_1 < SEQ_LEN)
    dA_val_shifted = tl.load(p_dA + rev_offs_plus_1, mask=mask_shifted, other=0.0)

    # 3. Inclusive scan over lanes yields g[t] for every valid lane.
    _, ddB_rev_val = tl.associative_scan((dA_val_shifted, dout_val), 0, combine_fn)

    # 4. ddA[t] = g[t] * h[t-1]; t = 0 uses the zero initial state.
    rev_offs_minus_1 = rev_offs - 1
    mask_h = (rev_offs_minus_1 >= 0) & (rev_offs_minus_1 < SEQ_LEN)
    h_val_shifted = tl.load(p_h + rev_offs_minus_1, mask=mask_h, other=0.0)

    ddA_rev_val = ddB_rev_val * h_val_shifted

    # 5. Store through rev_offs, which puts results back in natural time order.
    tl.store(p_ddB + rev_offs, ddB_rev_val, mask=mask)
    tl.store(p_ddA + rev_offs, ddA_rev_val, mask=mask)


# -------------------------------------------------------------------------
# 2. AUTOGRAD FUNCTION (Handles Transpose)
# -------------------------------------------------------------------------


class MambaScanFn(torch.autograd.Function):
    """Autograd wrapper around the Triton linear-recurrence scan kernels.

    Forward takes dA and dBx of identical shape (B, L, S, V) and returns, with
    the same shape,
        h[:, t] = dA[:, t] * h[:, t-1] + dBx[:, t],   h[:, -1] = 0.
    Internally the inputs are permuted to (B, S, V, L) and made contiguous, so
    each of the B * S * V independent scans reads one contiguous row.
    """

    @staticmethod
    def _launch_config(seq_len):
        """Return (BLOCK_SIZE, num_warps) for scanning rows of length seq_len.

        One program scans an entire row, so BLOCK_SIZE is the next power of
        two >= seq_len. Very long sequences are therefore bounded by Triton's
        per-block tensor size limit and by register pressure.
        """
        block_size = triton.next_power_of_2(seq_len)
        # More warps for long rows to spread the per-row work.
        num_warps = 4 if block_size <= 1024 else 8
        return block_size, num_warps

    @staticmethod
    def forward(ctx, dA, dBx):
        # The kernels index both inputs with the same row layout, so a shape
        # or device mismatch would silently read out of bounds.
        if dA.shape != dBx.shape:
            raise ValueError(
                f"dA and dBx must have the same shape, got {tuple(dA.shape)} "
                f"and {tuple(dBx.shape)}"
            )
        if dA.device != dBx.device:
            raise ValueError(
                f"dA and dBx must be on the same device, got {dA.device} and {dBx.device}"
            )

        # Input shape: (B, L, S, V); unpacking raises ValueError if not 4-D.
        B, L, S, V = dA.shape

        # 1. Transpose so L is last (contiguous dimension): (B, S, V, L)
        dA_in = dA.permute(0, 2, 3, 1).contiguous()
        dBx_in = dBx.permute(0, 2, 3, 1).contiguous()

        out_temp = torch.empty_like(dA_in)

        # 2. Launch one program per independent scan (B * S * V of them).
        #    Empty inputs skip the launch: next_power_of_2(0) == 0 is not a
        #    valid block size and a zero-size grid must not be launched.
        n_scans = B * S * V
        if n_scans > 0 and L > 0:
            BLOCK_SIZE, num_warps = MambaScanFn._launch_config(L)
            mamba_scan_forward_kernel_contiguous[(n_scans,)](
                dA_in,
                dBx_in,
                out_temp,
                n_scans,
                L,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

        # The backward kernel needs dA (for the carry) and h (for ddA), both
        # in the contiguous (B, S, V, L) layout.
        ctx.save_for_backward(dA_in, out_temp)

        # 3. Restore layout for output: (B, S, V, L) -> (B, L, S, V)
        return out_temp.permute(0, 3, 1, 2)

    @staticmethod
    def backward(ctx, dout):
        # dA_in and out_temp are already (B, S, V, L) and contiguous.
        dA_in, out_temp = ctx.saved_tensors
        B, S, V, L = dA_in.shape

        # dout arrives as (B, L, S, V) and may be non-contiguous (e.g. the
        # expanded gradient of .sum()); .contiguous() materializes it.
        dout_in = dout.permute(0, 2, 3, 1).contiguous()

        ddA_temp = torch.empty_like(dA_in)
        ddB_temp = torch.empty_like(dA_in)

        n_scans = B * S * V
        if n_scans > 0 and L > 0:
            BLOCK_SIZE, num_warps = MambaScanFn._launch_config(L)
            mamba_scan_backward_kernel_contiguous[(n_scans,)](
                dout_in,
                dA_in,
                out_temp,
                ddA_temp,
                ddB_temp,
                n_scans,
                L,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

        # Permute gradients back to (B, L, S, V), matching the inputs.
        ddA = ddA_temp.permute(0, 3, 1, 2)
        ddB = ddB_temp.permute(0, 3, 1, 2)
        return ddA, ddB


def mamba_scan(dA, dBx):
    """Differentiable linear-recurrence scan backed by the Triton kernels.

    Args:
        dA: multiplicative coefficients, shape (B, L, S, V).
        dBx: additive inputs, same shape as dA.

    Returns:
        Tensor of shape (B, L, S, V) with
        h[:, t] = dA[:, t] * h[:, t-1] + dBx[:, t] and h[:, -1] = 0.
    """
    scan = MambaScanFn.apply
    return scan(dA, dBx)


# --- Verification ---

if __name__ == "__main__":
    # Check the Triton scan (forward + backward) against a sequential PyTorch
    # loop differentiated by autograd.

    def _reference_scan(dA, dBx):
        """Sequential h[:, t] = dA[:, t] * h[:, t-1] + dBx[:, t], h[:, -1] = 0."""
        h = torch.zeros_like(dBx[:, 0])
        out_list = []
        for t in range(dA.shape[1]):
            h = dA[:, t] * h + dBx[:, t]
            out_list.append(h)
        return torch.stack(out_list, dim=1)

    if not torch.cuda.is_available():
        raise SystemExit("This verification needs a CUDA device to run the Triton kernels.")

    # Setup
    torch.manual_seed(0)
    batch_size, qk_dim, v_dim = 2, 4, 8
    device = "cuda"

    # 64 is a power of two, so no lane is masked. 61 and 17 leave padding
    # lanes in the block and exercise the masked-load paths of both kernels.
    for seq_len in (64, 61, 17):
        # Inputs with Gradient
        dA = torch.randn(
            batch_size, seq_len, qk_dim, v_dim, device=device, requires_grad=True
        )
        dBx = torch.randn(
            batch_size, seq_len, qk_dim, v_dim, device=device, requires_grad=True
        )

        # 1. Triton Forward + Backward
        out_triton = mamba_scan(dA, dBx)
        out_triton.sum().backward()
        grad_dA_triton = dA.grad.clone()
        grad_dBx_triton = dBx.grad.clone()

        # Reset Grads
        dA.grad = None
        dBx.grad = None

        # 2. PyTorch Reference Forward + Backward
        out_ref = _reference_scan(dA, dBx)
        out_ref.sum().backward()
        grad_dA_ref = dA.grad.clone()
        grad_dBx_ref = dBx.grad.clone()

        # 3. Compare
        fwd_diff = torch.max(torch.abs(out_triton - out_ref)).item()
        dA_diff = torch.max(torch.abs(grad_dA_triton - grad_dA_ref)).item()
        dB_diff = torch.max(torch.abs(grad_dBx_triton - grad_dBx_ref)).item()
        print(f"[seq_len={seq_len}] Forward diff: {fwd_diff}")
        print(f"[seq_len={seq_len}] Grad dA diff: {dA_diff}")
        print(f"[seq_len={seq_len}] Grad dB diff: {dB_diff}")

        assert torch.allclose(out_triton, out_ref, atol=1e-4)
        assert torch.allclose(grad_dA_triton, grad_dA_ref, atol=1e-4)
        assert torch.allclose(grad_dBx_triton, grad_dBx_ref, atol=1e-4)

    print("Test Passed: Triton backward matches PyTorch autograd.")
