import torch
import triton
import triton.language as tl
from time import perf_counter


@triton.jit
def combine_fn(
    left_alpha,
    left_beta,
    left_gamma,
    left_delta,
    right_alpha,
    right_beta,
    right_gamma,
    right_delta,
):
    """
    Associative combine for the Mobius (fractional linear) scan.

    Step t is the 2x2 matrix M_t = [[alpha_t, beta_t], [gamma_t, delta_t]] acting as
        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t).
    Applying the ``left`` (earlier) transform first and then the ``right``
    (later) one equals the matrix product M_right @ M_left; its four entries are
    returned in (alpha, beta, gamma, delta) order.

    NOTE: products are not renormalized, so long sequences of large/small
    entries can overflow or underflow. Only the ratio of entries is used by the
    forward kernel's final projection, so a common scale factor is harmless.
    """
    # First row of M_right against both columns of M_left.
    top_left = right_alpha * left_alpha + right_beta * left_gamma
    top_right = right_alpha * left_beta + right_beta * left_delta
    # Second row of M_right against both columns of M_left.
    bottom_left = right_gamma * left_alpha + right_delta * left_gamma
    bottom_right = right_gamma * left_beta + right_delta * left_delta
    return top_left, top_right, bottom_left, bottom_right


@triton.jit
def mobius_scan_forward_kernel(
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    delta_ptr,
    out_ptr,
    n_scans,  # Total independent sequences (B * V * Q); not read by the kernel
    SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute Lambda_t for one whole sequence per program via a parallel scan.

    All buffers are flattened row-major as (n_scans, SEQ_LEN); program ``pid``
    handles row ``pid``. BLOCK_SIZE must be a power of two (tl.arange) and
    >= SEQ_LEN, because the entire sequence is scanned inside one block.
    The per-step Mobius matrices are composed with ``combine_fn`` and then
    projected with the initial state Lambda_{-1} = 0.
    """
    # Widen before multiplying: in int32, pid * SEQ_LEN overflows as soon as
    # the total element count n_scans * SEQ_LEN reaches 2**31.
    pid = tl.program_id(0).to(tl.int64)

    # Flattened offset: (n_scans, SEQ_LEN)
    block_start = pid * SEQ_LEN
    lane = tl.arange(0, BLOCK_SIZE)
    offsets = block_start + lane
    mask = lane < SEQ_LEN

    # Padding lanes load the Mobius identity (alpha=1, beta=0, gamma=0, delta=1).
    # They come after the last valid step, so they never alter a valid prefix.
    alpha = tl.load(alpha_ptr + offsets, mask=mask, other=1.0)
    beta = tl.load(beta_ptr + offsets, mask=mask, other=0.0)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
    delta = tl.load(delta_ptr + offsets, mask=mask, other=1.0)

    # Inclusive prefix composition of the step matrices.
    res_alpha, res_beta, res_gamma, res_delta = tl.associative_scan(
        (alpha, beta, gamma, delta), axis=0, combine_fn=combine_fn
    )

    # Final projection: Lambda_t = (alpha*h_0 + beta)/(gamma*h_0 + delta) with h_0 = 0
    Lambda_out = res_beta / res_delta

    tl.store(out_ptr + offsets, Lambda_out, mask=mask)


@triton.jit
def combine_linear_scan(a1, b1, a2, b2):
    """
    Associative combine for the affine recurrence used by the backward pass.

    Each scan element (a, b) represents the affine map x -> a * x + b.
    Combining an earlier element (a1, b1) with a later element (a2, b2) applies
    the earlier map first, then the later one:
        (a1, b1) o (a2, b2) -> (a1 * a2, a2 * b1 + b2)
    so the inclusive scan's second component satisfies
        y_i = a_i * y_{i-1} + b_i   (with y_{-1} = 0).

    In mobius_scan_backward_kernel the lanes run in reverse time order, so
    lane i-1 is time t+1 and this realizes g_t = W_{t+1} * g_{t+1} + dout_t.
    """
    return a1 * a2, a2 * b1 + b2


@triton.jit
def mobius_scan_backward_kernel(
    dout_ptr,
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    delta_ptr,
    out_ptr,
    dalpha_ptr,
    dbeta_ptr,
    dgamma_ptr,
    ddelta_ptr,
    SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward pass via reverse linear scan.

    One program per sequence, using the same flattened (n_scans, SEQ_LEN)
    layout as the forward kernel; ``out_ptr`` holds the forward outputs Lambda_t.

    With g_t = dLoss/dLambda_t, the adjoint recurrence is
        g_t = dout_t + W_{t+1} * g_{t+1},   W_{t+1} = dLambda_{t+1}/dLambda_t,
    starting from g_{L-1} = dout_{L-1}. Parameter gradients are g_t times the
    local partials of Lambda_t with respect to alpha_t, beta_t, gamma_t, delta_t.
    """
    # int64 so pid * SEQ_LEN cannot overflow for very large inputs.
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * SEQ_LEN

    # Lane i handles time t = SEQ_LEN - 1 - i, so a forward scan over lanes
    # walks time backwards. Lanes with i >= SEQ_LEN are padding.
    rng = tl.arange(0, BLOCK_SIZE)
    offsets_t = block_start + (SEQ_LEN - 1) - rng
    mask_t = rng < SEQ_LEN

    # --- 1. Recurrence weights W_{t+1} = dLambda_{t+1} / dLambda_t ---
    # Needs params at t+1 and the state at t. For t = L-1 (and padding lanes)
    # the t+1 loads are masked; the fill values give det = 0 and denom = 1,
    # hence W = 0, which terminates the recurrence at the sequence end.
    offsets_tp1 = offsets_t + 1
    mask_tp1 = mask_t & (offsets_tp1 < (block_start + SEQ_LEN))

    alpha_tp1 = tl.load(alpha_ptr + offsets_tp1, mask=mask_tp1, other=0.0)
    beta_tp1 = tl.load(beta_ptr + offsets_tp1, mask=mask_tp1, other=0.0)
    gamma_tp1 = tl.load(gamma_ptr + offsets_tp1, mask=mask_tp1, other=0.0)
    delta_tp1 = tl.load(delta_ptr + offsets_tp1, mask=mask_tp1, other=1.0)

    Lambda_t = tl.load(out_ptr + offsets_t, mask=mask_t, other=0.0)

    # d/dx [(a x + b) / (c x + d)] = (a d - b c) / (c x + d)^2
    denom_tp1 = gamma_tp1 * Lambda_t + delta_tp1
    det_tp1 = alpha_tp1 * delta_tp1 - beta_tp1 * gamma_tp1
    W_tp1 = det_tp1 / (denom_tp1 * denom_tp1)

    # --- 2. Accumulate dLoss/dLambda_t via scan ---
    # Padding lanes contribute W = 0 and dout = 0 after all valid lanes.
    dout_t = tl.load(dout_ptr + offsets_t, mask=mask_t, other=0.0)

    # Inclusive affine scan: D_Lambda_t = dout_t + W_{t+1} * D_Lambda_{t+1}
    _, D_Lambda_t = tl.associative_scan(
        (W_tp1, dout_t), axis=0, combine_fn=combine_linear_scan
    )

    # --- 3. Local partials of Lambda_t w.r.t. the step-t params ---
    # Lambda_{t-1}; for t = 0 the load is masked and the initial state 0 is used.
    offsets_tm1 = offsets_t - 1
    mask_tm1 = mask_t & (offsets_tm1 >= block_start)
    Lambda_tm1 = tl.load(out_ptr + offsets_tm1, mask=mask_tm1, other=0.0)

    # Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / denom_t. Only gamma_t and
    # delta_t are needed: the numerator is recovered as Lambda_t * denom_t.
    gamma_t = tl.load(gamma_ptr + offsets_t, mask=mask_t, other=0.0)
    delta_t = tl.load(delta_ptr + offsets_t, mask=mask_t, other=1.0)

    denom_t = gamma_t * Lambda_tm1 + delta_t
    inv_denom_t = 1.0 / denom_t

    # Partials: dLambda_t / d(param)
    dl_dalpha = Lambda_tm1 * inv_denom_t
    dl_dbeta = inv_denom_t
    # -(numerator) * Lambda_{t-1} / denom^2 == -Lambda_t * Lambda_{t-1} / denom
    dl_dgamma = -1.0 * Lambda_t * Lambda_tm1 * inv_denom_t
    # -(numerator) / denom^2 == -Lambda_t / denom
    dl_ddelta = -1.0 * Lambda_t * inv_denom_t

    # Chain rule to parameter gradients
    tl.store(dalpha_ptr + offsets_t, D_Lambda_t * dl_dalpha, mask=mask_t)
    tl.store(dbeta_ptr + offsets_t, D_Lambda_t * dl_dbeta, mask=mask_t)
    tl.store(dgamma_ptr + offsets_t, D_Lambda_t * dl_dgamma, mask=mask_t)
    tl.store(ddelta_ptr + offsets_t, D_Lambda_t * dl_ddelta, mask=mask_t)


class MobiusScanFn(torch.autograd.Function):
    """
    Autograd wrapper around the Triton Mobius scan kernels.

    Expects four tensors of identical shape (B, L, V, Q). Each of the
    B * V * Q sequences of length L is scanned independently along dim 1.
    """

    @staticmethod
    def forward(ctx, alpha, beta, gamma, delta):
        """Run the forward scan; returns Lambda in the input layout (B, L, V, Q)."""
        B, L, V, Q = alpha.shape
        # Both kernels address all four buffers using alpha's geometry, so a
        # shape mismatch would otherwise read out of bounds silently.
        if not (alpha.shape == beta.shape == gamma.shape == delta.shape):
            raise ValueError(
                "alpha, beta, gamma and delta must share one shape; got "
                f"{tuple(alpha.shape)}, {tuple(beta.shape)}, "
                f"{tuple(gamma.shape)}, {tuple(delta.shape)}"
            )

        # Transpose to (B, V, Q, L) so each sequence is contiguous along the scan dim.
        alpha_in = alpha.permute(0, 2, 3, 1).contiguous()
        beta_in = beta.permute(0, 2, 3, 1).contiguous()
        gamma_in = gamma.permute(0, 2, 3, 1).contiguous()
        delta_in = delta.permute(0, 2, 3, 1).contiguous()

        out = torch.empty_like(alpha_in)

        n_scans = B * V * Q
        # One program scans a whole sequence, so the block must cover L.
        BLOCK_SIZE = max(triton.next_power_of_2(L), 32)
        num_warps = 8 if BLOCK_SIZE >= 2048 else 4

        mobius_scan_forward_kernel[(n_scans,)](
            alpha_in,
            beta_in,
            gamma_in,
            delta_in,
            out,
            n_scans,
            L,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        out_reshaped = out.permute(0, 3, 1, 2).contiguous()

        # Keep the scan-layout copies so backward can reuse them directly.
        ctx.save_for_backward(alpha_in, beta_in, gamma_in, delta_in, out)
        ctx.dims = (n_scans, L, BLOCK_SIZE, num_warps)

        return out_reshaped

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for (alpha, beta, gamma, delta), each shaped (B, L, V, Q)."""
        alpha, beta, gamma, delta, out = ctx.saved_tensors
        n_scans, L, BLOCK_SIZE, num_warps = ctx.dims

        # Incoming gradient arrives as (B, L, V, Q); move it to the scan layout.
        dout_in = dout.permute(0, 2, 3, 1).contiguous()

        dalpha = torch.empty_like(alpha)
        dbeta = torch.empty_like(beta)
        dgamma = torch.empty_like(gamma)
        ddelta = torch.empty_like(delta)

        mobius_scan_backward_kernel[(n_scans,)](
            dout_in,
            alpha,
            beta,
            gamma,
            delta,
            out,
            dalpha,
            dbeta,
            dgamma,
            ddelta,
            L,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Back from (B, V, Q, L) to the caller's (B, L, V, Q) layout.
        return (
            dalpha.permute(0, 3, 1, 2).contiguous(),
            dbeta.permute(0, 3, 1, 2).contiguous(),
            dgamma.permute(0, 3, 1, 2).contiguous(),
            ddelta.permute(0, 3, 1, 2).contiguous(),
        )


def mobius_scan(alpha, beta, gamma, delta):
    """
    Differentiable Mobius scan along dim 1 of (B, L, V, Q)-laid-out tensors.

    Thin functional entry point for ``MobiusScanFn``; gradients flow to all
    four inputs.
    """
    scan = MobiusScanFn.apply
    return scan(alpha, beta, gamma, delta)


def ref(alpha, beta, gamma, delta):
    """
    Sequential PyTorch reference for the Mobius scan.

    Walks dim 1 one step at a time, applying
        Lambda_t = (alpha_t * Lambda_{t-1} + beta_t) / (gamma_t * Lambda_{t-1} + delta_t)
    from the initial state Lambda_{-1} = 0, and returns every Lambda_t stacked
    back along dim 1.
    """
    state = torch.zeros_like(beta.select(1, 0))
    history = []
    per_step = zip(alpha.unbind(1), beta.unbind(1), gamma.unbind(1), delta.unbind(1))
    for a_t, b_t, g_t, d_t in per_step:
        state = (a_t * state + b_t) / (g_t * state + d_t)
        history.append(state)
    return torch.stack(history, dim=1)


def _sync_device(tensor):
    """Wait for queued CUDA work on ``tensor``'s device; no-op for CPU tensors."""
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def _run_and_sync(step, iters, sync_on):
    """Call ``step()`` ``iters`` times, then synchronize ``sync_on``'s device."""
    for _ in range(iters):
        step()
    _sync_device(sync_on)


def benchmark_mobius_scan(
    name,
    func,
    alpha,
    beta,
    gamma,
    delta,
    warmup_iters=20,
    test_iters=100,
    backward=True,
):
    """
    Benchmark ``func`` for forward and (optionally) forward+backward passes.

    Args:
        name: Label used in the printed timing lines.
        func: Callable ``func(alpha, beta, gamma, delta)`` returning a tensor.
        alpha, beta, gamma, delta: Inputs passed to ``func`` unchanged.
        warmup_iters: Untimed iterations run before each measurement.
        test_iters: Timed iterations; must be >= 1.
        backward: Also time ``func(...).sum().backward()`` when true.

    Side effect: when ``backward`` is true every input is switched to
    ``requires_grad=True`` and its ``.grad`` is left as ``None``.

    Raises:
        ValueError: if ``test_iters`` < 1 (the per-iteration mean is undefined).
    """
    if test_iters < 1:
        raise ValueError(f"test_iters must be >= 1, got {test_iters}")

    inputs = (alpha, beta, gamma, delta)

    def forward_step():
        with torch.no_grad():
            func(*inputs)

    def train_step():
        func(*inputs).sum().backward()
        # Drop accumulated grads (plain tensors, not module params) so every
        # iteration does the same amount of work.
        for tensor in inputs:
            tensor.grad = None

    # --- Forward Benchmark ---
    _run_and_sync(forward_step, warmup_iters, alpha)
    start_time = perf_counter()
    # Synchronizing before reading the clock makes async CUDA work count.
    _run_and_sync(forward_step, test_iters, alpha)
    elapsed_ms = (perf_counter() - start_time) * 1000 / test_iters
    print(f"{name} forward time: {elapsed_ms:.6f} milliseconds per iteration")

    if not backward:
        return

    # --- Backward Benchmark ---
    for tensor in inputs:
        tensor.requires_grad_(True)
        tensor.grad = None

    _run_and_sync(train_step, warmup_iters, alpha)
    start_time = perf_counter()
    _run_and_sync(train_step, test_iters, alpha)
    elapsed_ms = (perf_counter() - start_time) * 1000 / test_iters
    print(
        f"{name} forward+backward time: {elapsed_ms:.6f} milliseconds per iteration"
    )


if __name__ == "__main__":
    # Setup
    torch.manual_seed(0)
    batch_size, seq_len, v_dim, qk_dim = 4, 128, 512, 16
    device = "cuda"
    test_backward = True
    shape = (batch_size, seq_len, v_dim, qk_dim)
    param_names = ("alpha", "beta", "gamma", "delta")

    # Leaf inputs drawn from N(0, 1); leaf tensors keep .grad after backward,
    # so retain_grad() is unnecessary. gamma/delta are unconstrained, so some
    # denominators gamma * Lambda + delta can land near zero. A better
    # conditioned alternative is (torch.rand(*shape, device=device) * 0.01 + 0.5)
    # followed by .requires_grad_() to keep it a leaf.
    alpha = torch.randn(*shape, device=device, requires_grad=True)
    beta = torch.randn(*shape, device=device, requires_grad=True)
    gamma = torch.randn(*shape, device=device, requires_grad=True)
    delta = torch.randn(*shape, device=device, requires_grad=True)
    inputs = (alpha, beta, gamma, delta)

    # ==========================================
    # 1. Correctness Check
    # ==========================================

    # 1. Triton Forward + Backward
    print("Running mobius Scan Test...")
    print("Running forward pass...")
    out_triton = mobius_scan(*inputs)
    print("Forward pass completed.")

    if out_triton.isnan().any():
        print("Error: Triton output contains NaNs.")
    if out_triton.isinf().any():
        print("Error: Triton output contains Infs.")

    grads_triton = {}
    if test_backward:
        print("Running backward pass...")
        out_triton.sum().backward()
        print("Backward pass completed.")
        grads_triton = {
            pname: t.grad.detach().clone() for pname, t in zip(param_names, inputs)
        }

    # Reset Grads
    for t in inputs:
        t.grad = None

    # 2. PyTorch Reference Forward + Backward
    print("Running mobius Scan Reference Test...")

    ref_inputs = tuple(t.detach().clone().requires_grad_(True) for t in inputs)
    print("Running forward pass...")
    out_ref = ref(*ref_inputs)
    print("Forward pass completed.")

    if out_ref.isnan().any():
        print("Error: Reference output contains NaNs.")
    if out_ref.isinf().any():
        print("Error: Reference output contains Infs.")

    grads_ref = {}
    if test_backward:
        print("Running backward pass...")
        out_ref.sum().backward()
        print("Backward pass completed.")
        grads_ref = {
            pname: t.grad.detach().clone()
            for pname, t in zip(param_names, ref_inputs)
        }

    # 3. Compare (diagnostics only; no autograd graph needed)
    with torch.no_grad():
        print(f"Forward diff: {torch.max(torch.abs(out_triton - out_ref)).item()}")
        # grads_triton is empty when test_backward is False.
        for pname in grads_triton:
            diff = torch.max(torch.abs(grads_triton[pname] - grads_ref[pname])).item()
            print(f"Grad {pname} diff: {diff}")

        if torch.allclose(out_triton, out_ref, atol=1e-4):
            print("Test Passed: Triton forward matches PyTorch reference.")
        else:
            print("Test Failed: Triton forward does not match PyTorch reference.")

        for pname in grads_triton:
            if torch.allclose(grads_triton[pname], grads_ref[pname], atol=1e-4):
                print(f"Test Passed: Grad {pname} matches.")
            else:
                print(f"Test Failed: Grad {pname} does not match.")

    print("-" * 50)

    # ==========================================
    # 2. Performance Benchmarks
    # ==========================================

    print("Starting Performance Benchmarks...")

    # Fresh leaf copies so the benchmarks start without any autograd history.
    alpha_bench, beta_bench, gamma_bench, delta_bench = (
        t.detach().clone().requires_grad_(True) for t in inputs
    )

    # Benchmark Triton
    benchmark_mobius_scan(
        name="Triton Implementation",
        func=mobius_scan,
        alpha=alpha_bench,
        beta=beta_bench,
        gamma=gamma_bench,
        delta=delta_bench,
        warmup_iters=5,
        test_iters=20,
        backward=test_backward,
    )

    print("-" * 30)

    # Benchmark Reference (fewer iterations: the Python loop is much slower)
    benchmark_mobius_scan(
        name="PyTorch Reference",
        func=ref,
        alpha=alpha_bench,
        beta=beta_bench,
        gamma=gamma_bench,
        delta=delta_bench,
        warmup_iters=1,
        test_iters=3,
        backward=test_backward,
    )
