import torch
import triton
import triton.language as tl


@triton.jit
def combine_fn(
    left_alpha,
    left_beta,
    left_gamma,
    left_delta,
    right_alpha,
    right_beta,
    right_gamma,
    right_delta,
):
    """
    Compose two steps of the fractional-linear recurrence
        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t)

    Each step is carried as the four coefficients of the 2x2 matrix
    [[alpha, beta], [gamma, delta]]. The returned coefficients are those of
    the product right @ left, i.e. the left step applied first and the right
    step applied second.

    NOTE: a later definition in this file reuses the name `combine_fn`
    (the six-value variant), so at module level the name resolves to that
    later definition, not to this one.
    """
    # Top row of right @ left.
    alpha_out = right_alpha * left_alpha + right_beta * left_gamma
    beta_out = right_alpha * left_beta + right_beta * left_delta

    # Bottom row of right @ left.
    gamma_out = right_gamma * left_alpha + right_delta * left_gamma
    delta_out = right_gamma * left_beta + right_delta * left_delta

    return alpha_out, beta_out, gamma_out, delta_out


@triton.jit
def combine_fn(
    left_alpha,
    left_beta,
    left_gamma,
    left_delta,
    left_G,
    left_kpv,
    right_alpha,
    right_beta,
    right_gamma,
    right_delta,
    right_G,
    right_kpv,
):
    """
    Associative combine step for the joint scan over two recurrences:

        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t)
        H_t = G_t * H_{t-1} + kpv_t

    The `left_*` values describe the earlier segment and the `right_*` values
    the later one; the result describes "left first, then right".

    * Lambda part: coefficients of the 2x2 product
      [[right_alpha, right_beta], [right_gamma, right_delta]]
      @ [[left_alpha, left_beta], [left_gamma, left_delta]].
    * H part: composition of the two affine maps x -> G * x + kpv.
    """
    # Lambda: top row, then bottom row, of right @ left.
    alpha_out = right_alpha * left_alpha + right_beta * left_gamma
    beta_out = right_alpha * left_beta + right_beta * left_delta
    gamma_out = right_gamma * left_alpha + right_delta * left_gamma
    delta_out = right_gamma * left_beta + right_delta * left_delta

    # H: right(left(x)) = right_G * (left_G * x + left_kpv) + right_kpv.
    G_out = right_G * left_G
    kpv_out = right_G * left_kpv + right_kpv

    return alpha_out, beta_out, gamma_out, delta_out, G_out, kpv_out


@triton.jit
def kla_scan_forward_kernel(
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    delta_ptr,
    G_ptr,
    kpv_ptr,
    out_lambda_ptr,
    out_H_ptr,
    BATCH_SIZE,  # Total independent sequences (B * V * Q)
    SEQ_LEN,  # The scan length (L)
    BLOCK_SIZE: tl.constexpr,
):
    """
    Forward scan of the Lambda and H recurrences, one sequence per program.

    Every pointer is indexed as a flattened (N_scans, SEQ_LEN) buffer: program
    `pid` reads and writes elements [pid * SEQ_LEN, pid * SEQ_LEN + SEQ_LEN).
    The whole sequence is processed in a single block, so BLOCK_SIZE must be
    at least SEQ_LEN (lanes beyond SEQ_LEN are masked off).

    BATCH_SIZE is accepted for interface compatibility but is not read; the
    number of sequences is determined by the launch grid.

    Outputs (assuming the initial state Lambda_{-1} = 0, H_{-1} = 0):
        out_lambda[t] = beta_cum[t] / delta_cum[t]
        out_H[t]      = kpv_cum[t]
    where *_cum are the inclusive-scan results of `combine_fn`.
    """
    # 1. Setup Offsets
    # program_id is a 32-bit value; widen it before multiplying by SEQ_LEN so
    # the element offset cannot wrap around for buffers holding more than
    # 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    lane = tl.arange(0, BLOCK_SIZE)
    offsets = pid * SEQ_LEN + lane

    # Only the first SEQ_LEN lanes belong to this sequence.
    mask = lane < SEQ_LEN

    # 2. Load Inputs
    # Masked-off lanes are filled with the identity element of the combine
    # operation (identity 2x2 matrix for Lambda, the map x -> 1 * x + 0 for H)
    # so that padding lanes do not alter the scan.
    alpha = tl.load(alpha_ptr + offsets, mask=mask, other=1.0)  # Diagonal 1
    beta = tl.load(beta_ptr + offsets, mask=mask, other=0.0)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
    delta = tl.load(delta_ptr + offsets, mask=mask, other=1.0)  # Diagonal 1

    G = tl.load(G_ptr + offsets, mask=mask, other=1.0)  # Mult identity
    kpv = tl.load(kpv_ptr + offsets, mask=mask, other=0.0)  # Add identity

    # 3. Associative Scan
    res_alpha, res_beta, res_gamma, res_delta, res_G, res_kpv = tl.associative_scan(
        (alpha, beta, gamma, delta, G, kpv), axis=0, combine_fn=combine_fn
    )

    # 4. Compute Final Outputs
    # Initial state assumption: Lambda_{-1} = 0, H_{-1} = 0.
    # Lambda_{-1} = 0 is represented by the vector [0, 1]^T, and
    # M_cum @ [0, 1]^T = [beta_cum, delta_cum]^T, hence
    # Lambda = beta_cum / delta_cum. With H_{-1} = 0 the accumulated offset
    # kpv_cum is H itself.
    Lambda_out = res_beta / res_delta
    H_out = res_kpv

    # 5. Store Results
    tl.store(out_lambda_ptr + offsets, Lambda_out, mask=mask)
    tl.store(out_H_ptr + offsets, H_out, mask=mask)


@triton.jit
def kla_scan_backward_kernel(
    dout_ptr,
):
    # Placeholder for the backward-pass kernel: no gradient computation is
    # implemented, the body unconditionally raises.
    raise NotImplementedError


class KLAScanFn(torch.autograd.Function):
    """
    Autograd wrapper around `kla_scan_forward_kernel`.

    Only the forward pass is implemented; `backward` raises
    NotImplementedError.
    """

    @staticmethod
    def forward(ctx, alpha, beta, gamma, delta, G, kpv):
        """
        Run the fused Lambda / H scan along the sequence axis.

        Args:
            ctx: autograd context; the six inputs are saved on it.
            alpha, beta, gamma, delta, G, kpv: tensors of identical shape
                (B, L, V, Q), scanned along L.

        Returns:
            (Lambda_out, H_out), both contiguous with shape (B, L, V, Q).

        Raises:
            ValueError: if any input's shape differs from `alpha`'s. The
                kernel indexes every buffer with the same offsets, so a
                mismatch would read the wrong or out-of-range memory.
        """
        inputs = (alpha, beta, gamma, delta, G, kpv)
        names = ("alpha", "beta", "gamma", "delta", "G", "kpv")

        # Input shape: (B, L, V, Q)
        B, L, V, Q = alpha.shape
        for name, tensor in zip(names, inputs):
            if tensor.shape != alpha.shape:
                raise ValueError(
                    f"{name} has shape {tuple(tensor.shape)}, "
                    f"expected {tuple(alpha.shape)} (same as alpha)"
                )

        # Save tensors for backward pass
        ctx.save_for_backward(*inputs)

        # Nothing to scan: an empty grid or a zero-sized block cannot be
        # launched, so return empty outputs directly.
        if alpha.numel() == 0:
            return alpha.new_empty(alpha.shape), G.new_empty(G.shape)

        # 1. Prepare inputs for Triton
        # Transpose to (B, V, Q, L) so L is contiguous in memory.
        # This allows simple pointer arithmetic in the kernel.
        scan_inputs = [t.permute(0, 2, 3, 1).contiguous() for t in inputs]

        # 2. Allocate Outputs (Lambda follows alpha's dtype, H follows G's)
        out_lambda = torch.empty_like(scan_inputs[0])
        out_H = torch.empty_like(scan_inputs[4])

        # 3. Kernel Launch Config
        # One program per independent sequence; the block covers the whole
        # sequence, rounded up to a power of two.
        n_scans = B * V * Q
        BLOCK_SIZE = triton.next_power_of_2(L)
        num_warps = 4 if BLOCK_SIZE <= 1024 else 8

        kla_scan_forward_kernel[(n_scans,)](
            *scan_inputs,
            out_lambda,
            out_H,
            n_scans,
            L,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # 4. Restore Shapes (B, V, Q, L) -> (B, L, V, Q)
        Lambda_out = out_lambda.permute(0, 3, 1, 2).contiguous()
        H_out = out_H.permute(0, 3, 1, 2).contiguous()

        return Lambda_out, H_out

    @staticmethod
    def backward(ctx, dLambda, dH):
        """Gradient of the scan; not implemented."""
        raise NotImplementedError("KLAScanFn.backward is not implemented")


def kla_scan(
    alpha: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    beta: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    gamma: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    delta: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    G: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    kpv: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
):
    """
    Evaluate both state-space recurrences along the sequence axis by
    dispatching to the `KLAScanFn` autograd function.

    Precision SSM:
        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t)
    Information SSM:
        H_t = G_t * H_{t-1} + kpv_t

    Returns the pair produced by `KLAScanFn.apply`, i.e. (Lambda, H).
    """
    scan_inputs = (alpha, beta, gamma, delta, G, kpv)
    return KLAScanFn.apply(*scan_inputs)


def ref(
    alpha: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    beta: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    gamma: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    delta: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    G: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
    kpv: torch.Tensor,  # batch_size, seq_len, v_dim, qk_dim
):
    """
    Sequential PyTorch reference for the Precision SSM and Information SSM.

    Precision SSM:
        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t)
    Information SSM:
        H_t = G_t * H_{t-1} + kpv_t

    Both recurrences start from a zero state, so the first step reduces to
    Lambda_0 = beta_0 / delta_0 and H_0 = kpv_0.

    The per-step states are carried functionally and stacked at the end
    rather than written into a preallocated buffer: in-place writes would
    invalidate tensors that autograd saved from earlier steps (breaking
    `backward()`), and a fixed-dtype buffer would silently downcast inputs.

    Returns:
        (Lambda, H): contiguous tensors of shape
        (batch_size, seq_len, v_dim, qk_dim) in the dtype of the computation.
    """
    batch_size, seq_len, v_dim, qk_dim = alpha.shape

    # torch.stack cannot stack an empty list; an empty sequence has empty outputs.
    if seq_len == 0:
        return alpha.new_zeros(alpha.shape), kpv.new_zeros(alpha.shape)

    # Initial step (zero previous state).
    lam = beta[:, 0] / delta[:, 0]
    h = kpv[:, 0]
    lambda_steps = [lam]
    h_steps = [h]

    for t in range(1, seq_len):
        # Precision SSM
        lam = (alpha[:, t] * lam + beta[:, t]) / (gamma[:, t] * lam + delta[:, t])
        # Information SSM
        h = G[:, t] * h + kpv[:, t]
        lambda_steps.append(lam)
        h_steps.append(h)

    # Stack along the sequence axis -> (batch_size, seq_len, v_dim, qk_dim).
    Lambda = torch.stack(lambda_steps, dim=1)
    H = torch.stack(h_steps, dim=1)

    return Lambda, H


def main(test_backward: bool = False, atol: float = 1e-4) -> None:
    """
    Self-test: compare `kla_scan` (Triton) against `ref` (PyTorch) on CUDA.

    Args:
        test_backward: also compare input gradients. Off by default because
            the Triton backward pass raises NotImplementedError.
        atol: absolute tolerance used for every comparison.

    Raises:
        AssertionError: if any compared tensor differs by more than `atol`.
    """
    # Setup
    torch.manual_seed(0)
    shape = (2, 128, 64, 16)  # batch_size, seq_len, v_dim, qk_dim
    device = "cuda"

    def leaf(values):
        # Every input is a leaf so that both backward passes can populate
        # `.grad` independently without sharing an upstream graph.
        return values.detach().requires_grad_(True)

    # Sampling order is fixed (alpha, beta, gamma, delta, G, kpv) so the
    # seeded values are reproducible.
    inputs = {
        "alpha": leaf(torch.rand(shape, device=device)),
        "beta": leaf(torch.rand(shape, device=device)),
        # The log is applied before the tensor becomes a leaf; otherwise the
        # second backward pass would traverse the already-freed log graph.
        "gamma": leaf(torch.log(0.1 * torch.rand(shape, device=device))),
        "delta": leaf(torch.rand(shape, device=device)),
        "G": leaf(torch.rand(shape, device=device)),
        "kpv": leaf(torch.randn(shape, device=device)),
    }

    def run(scan_fn):
        """Run one implementation; return its outputs and (optionally) grads."""
        # Reset Grads
        for tensor in inputs.values():
            tensor.grad = None

        out_lambda, out_h = scan_fn(*inputs.values())
        grads = {}
        if test_backward:
            (out_lambda + out_h).sum().backward()
            grads = {name: tensor.grad.clone() for name, tensor in inputs.items()}
        return {"Lambda": out_lambda.detach(), "H": out_h.detach()}, grads

    # 1. Triton Forward + Backward
    outs_triton, grads_triton = run(kla_scan)

    # 2. PyTorch Reference Forward + Backward
    outs_ref, grads_ref = run(ref)

    # 3. Compare
    # Lambda and H are checked separately: comparing only their sum could
    # hide errors of opposite sign in the two outputs.
    for name, expected in outs_ref.items():
        diff = torch.max(torch.abs(outs_triton[name] - expected)).item()
        print(f"Forward {name} diff: {diff}")
    for name, expected in grads_ref.items():
        diff = torch.max(torch.abs(grads_triton[name] - expected)).item()
        print(f"Grad {name} diff: {diff}")

    for name, expected in outs_ref.items():
        assert torch.allclose(
            outs_triton[name], expected, atol=atol
        ), f"forward {name} mismatch"
    for name, expected in grads_ref.items():
        assert torch.allclose(
            grads_triton[name], expected, atol=atol
        ), f"grad {name} mismatch"

    if test_backward:
        print("Test Passed: Triton backward matches PyTorch autograd.")
    else:
        print("Test Passed: Triton forward matches PyTorch reference.")


if __name__ == "__main__":
    main()
