import torch
import triton
import triton.language as tl
from time import perf_counter


@triton.jit
def logaddexp(x, y):
    """
    Numerically stable logaddexp: log(exp(x) + exp(y)).

    Computes max(x, y) + log(1 + exp(min(x, y) - max(x, y))), so the
    exponent is never positive and exp() cannot overflow.

    When both operands are equal the gap is forced to exactly 0. For finite
    values this changes nothing; for two equal infinities it avoids the
    `inf - inf = NaN` subtraction, so logaddexp(-inf, -inf) is -inf and
    logaddexp(+inf, +inf) is +inf. NaN inputs still propagate.
    """
    max_val = tl.maximum(x, y)
    min_val = tl.minimum(x, y)
    # Equal operands (including equal infinities) have a zero gap by definition.
    gap = tl.where(min_val == max_val, 0.0, min_val - max_val)
    return max_val + tl.log(1.0 + tl.exp(gap))


@triton.jit
def combine_fn(
    left_alpha,
    left_beta,
    left_gamma,
    left_delta,
    right_alpha,
    right_beta,
    right_gamma,
    right_delta,
):
    """
    Associative combine for the forward scan: log-space 2x2 matrix product.

    Treating (alpha, beta, gamma, delta) as the log of the entries
    [[alpha, beta], [gamma, delta]], returns the log of RIGHT @ LEFT.
    Each output entry is the logaddexp of the two log-products that would be
    summed in an ordinary matrix multiplication.
    """
    # Top row of RIGHT against both columns of LEFT.
    top_left_1 = right_alpha + left_alpha
    top_left_2 = right_beta + left_gamma
    top_right_1 = right_alpha + left_beta
    top_right_2 = right_beta + left_delta

    # Bottom row of RIGHT against both columns of LEFT.
    bottom_left_1 = right_gamma + left_alpha
    bottom_left_2 = right_delta + left_gamma
    bottom_right_1 = right_gamma + left_beta
    bottom_right_2 = right_delta + left_delta

    out_alpha = logaddexp(top_left_1, top_left_2)
    out_beta = logaddexp(top_right_1, top_right_2)
    out_gamma = logaddexp(bottom_left_1, bottom_left_2)
    out_delta = logaddexp(bottom_right_1, bottom_right_2)
    return out_alpha, out_beta, out_gamma, out_delta


@triton.jit
def log_mobius_scan_forward_kernel(
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    delta_ptr,
    acc_alpha_ptr,
    acc_beta_ptr,
    acc_gamma_ptr,
    acc_delta_ptr,
    out_ptr,
    n_scans,
    SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Forward scan kernel: each program handles one row of SEQ_LEN elements.

    Loads the four log-space matrix entries of the row, runs an inclusive
    associative scan with `combine_fn`, stores the cumulative entries to the
    acc_* buffers and writes
    logaddexp(A, B) - logaddexp(C, D) of the cumulative matrix to out_ptr.

    Lanes at or beyond SEQ_LEN are masked; they load (0, -inf, -inf, 0) and
    are never stored. `n_scans` is not read inside the kernel.
    """
    row = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < SEQ_LEN
    idx = row * SEQ_LEN + lanes

    # Per-step matrices; masked lanes get the padding values described above.
    a_in = tl.load(alpha_ptr + idx, mask=in_bounds, other=0.0)
    b_in = tl.load(beta_ptr + idx, mask=in_bounds, other=-float("inf"))
    c_in = tl.load(gamma_ptr + idx, mask=in_bounds, other=-float("inf"))
    d_in = tl.load(delta_ptr + idx, mask=in_bounds, other=0.0)

    # Inclusive scan of log-space matrix products along the row.
    cum_a, cum_b, cum_c, cum_d = tl.associative_scan(
        (a_in, b_in, c_in, d_in), axis=0, combine_fn=combine_fn
    )

    # Projection of the cumulative matrix (the "init is ones" variant):
    # log of (top-row sum / bottom-row sum).
    numerator = logaddexp(cum_a, cum_b)
    denominator = logaddexp(cum_c, cum_d)
    tl.store(out_ptr + idx, numerator - denominator, mask=in_bounds)

    # Keep the cumulative state; the backward kernel reads it back.
    tl.store(acc_alpha_ptr + idx, cum_a, mask=in_bounds)
    tl.store(acc_beta_ptr + idx, cum_b, mask=in_bounds)
    tl.store(acc_gamma_ptr + idx, cum_c, mask=in_bounds)
    tl.store(acc_delta_ptr + idx, cum_d, mask=in_bounds)
        

@triton.jit
def combine_linear_scan(
    l_w1_aa,
    l_w1_ac,
    l_w1_ca,
    l_w1_cc,
    l_w2_bb,
    l_w2_bd,
    l_w2_db,
    l_w2_dd,
    l_b_a,
    l_b_b,
    l_b_c,
    l_b_d,
    r_w1_aa,
    r_w1_ac,
    r_w1_ca,
    r_w1_cc,
    r_w2_bb,
    r_w2_bd,
    r_w2_db,
    r_w2_dd,
    r_b_a,
    r_b_b,
    r_b_c,
    r_b_d,
):
    """
    Associative combine for the backward scan (affine-map composition).

    Each element carries two independent 2x2 matrices, W1 acting on the
    (a, c) components and W2 acting on the (b, d) components, plus a
    4-component bias. Composition is:

        W_new = W_R * W_L
        b_new = W_R * b_L + b_R

    Returns the 12 values in the same order as the inputs:
    W1 (aa, ac, ca, cc), W2 (bb, bd, db, dd), bias (a, b, c, d).
    """
    # Bias first: push the left bias through the right matrices, then add
    # the right bias.
    out_b_a = r_b_a + (r_w1_aa * l_b_a + r_w1_ac * l_b_c)
    out_b_b = r_b_b + (r_w2_bb * l_b_b + r_w2_bd * l_b_d)
    out_b_c = r_b_c + (r_w1_ca * l_b_a + r_w1_cc * l_b_c)
    out_b_d = r_b_d + (r_w2_db * l_b_b + r_w2_dd * l_b_d)

    # W1 block: rows of W_R times columns of W_L.
    out_w1_aa = r_w1_aa * l_w1_aa + r_w1_ac * l_w1_ca
    out_w1_ac = r_w1_aa * l_w1_ac + r_w1_ac * l_w1_cc
    out_w1_ca = r_w1_ca * l_w1_aa + r_w1_cc * l_w1_ca
    out_w1_cc = r_w1_ca * l_w1_ac + r_w1_cc * l_w1_cc

    # W2 block, same pattern on the (b, d) components.
    out_w2_bb = r_w2_bb * l_w2_bb + r_w2_bd * l_w2_db
    out_w2_bd = r_w2_bb * l_w2_bd + r_w2_bd * l_w2_dd
    out_w2_db = r_w2_db * l_w2_bb + r_w2_dd * l_w2_db
    out_w2_dd = r_w2_db * l_w2_bd + r_w2_dd * l_w2_dd

    return (
        out_w1_aa,
        out_w1_ac,
        out_w1_ca,
        out_w1_cc,
        out_w2_bb,
        out_w2_bd,
        out_w2_db,
        out_w2_dd,
        out_b_a,
        out_b_b,
        out_b_c,
        out_b_d,
    )


@triton.jit
def log_mobius_scan_backward_kernel(
    dout_ptr,
    dalpha_ptr,
    dbeta_ptr,
    dgamma_ptr,
    ddelta_ptr,
    alpha_ptr,
    beta_ptr,
    gamma_ptr,
    delta_ptr,
    acc_alpha_ptr,
    acc_beta_ptr,
    acc_gamma_ptr,
    acc_delta_ptr,
    n_scans,
    SEQ_LEN,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward scan kernel: each program handles one row of SEQ_LEN elements.

    The row is traversed in reverse: lane i addresses position
    SEQ_LEN - 1 - i. For every position t the kernel
      1. builds softmax-style weights linking the cumulative state at t to
         the state at t + 1 (zero for the last position) and transposes them,
      2. forms the direct gradient of the output w.r.t. the cumulative state
         at t from `dout`,
      3. runs an associative scan with `combine_linear_scan` over these
         (matrix, bias) pairs to accumulate the gradient w.r.t. the
         cumulative state at every t,
      4. multiplies by the local weights linking the inputs at t to the
         state at t and stores dalpha, dbeta, dgamma, ddelta.

    `n_scans` is not read inside the kernel.
    NOTE(review): the row base is pid * SEQ_LEN; confirm the index width is
    sufficient for very large n_scans * SEQ_LEN.
    """
    pid = tl.program_id(0)
    block_start = pid * SEQ_LEN

    # Reverse Scan Setup: lane i maps to position SEQ_LEN - 1 - i of the row.
    range_offs = tl.arange(0, BLOCK_SIZE)
    scan_mask = range_offs < SEQ_LEN
    rev_offsets = block_start + (SEQ_LEN - 1) - range_offs

    # Load Inputs M_t and Accumulators P_t (Current Step t)
    # Masked lanes receive (0, -inf, -inf, 0).
    alpha = tl.load(alpha_ptr + rev_offsets, mask=scan_mask, other=0.0)
    beta = tl.load(beta_ptr + rev_offsets, mask=scan_mask, other=-float("inf"))
    gamma = tl.load(gamma_ptr + rev_offsets, mask=scan_mask, other=-float("inf"))
    delta = tl.load(delta_ptr + rev_offsets, mask=scan_mask, other=0.0)

    acc_a = tl.load(acc_alpha_ptr + rev_offsets, mask=scan_mask, other=0.0)
    acc_b = tl.load(acc_beta_ptr + rev_offsets, mask=scan_mask, other=-float("inf"))
    acc_c = tl.load(acc_gamma_ptr + rev_offsets, mask=scan_mask, other=-float("inf"))
    acc_d = tl.load(acc_delta_ptr + rev_offsets, mask=scan_mask, other=0.0)

    # Load Previous Accumulators P_{t-1}
    # Position 0 has no predecessor, so it falls back to (0, -inf, -inf, 0).
    prev_rev_offsets = rev_offsets - 1
    has_prev = (SEQ_LEN - 1 - range_offs) > 0
    mask_prev = scan_mask & has_prev

    prev_a = tl.load(acc_alpha_ptr + prev_rev_offsets, mask=mask_prev, other=0.0)
    prev_b = tl.load(
        acc_beta_ptr + prev_rev_offsets, mask=mask_prev, other=-float("inf")
    )
    prev_c = tl.load(
        acc_gamma_ptr + prev_rev_offsets, mask=mask_prev, other=-float("inf")
    )
    prev_d = tl.load(acc_delta_ptr + prev_rev_offsets, mask=mask_prev, other=0.0)

    # --- Step 1: Prepare Inputs for Associative Scan (Transition Matrix J_{t+1}) ---
    # Lane 0 is the last position of the row and has no successor (j_mask false).
    j_off = rev_offsets + 1
    j_mask = scan_mask & (range_offs > 0)

    j_alpha = tl.load(alpha_ptr + j_off, mask=j_mask, other=0.0)
    j_beta = tl.load(beta_ptr + j_off, mask=j_mask, other=-float("inf"))
    j_gamma = tl.load(gamma_ptr + j_off, mask=j_mask, other=-float("inf"))
    j_delta = tl.load(delta_ptr + j_off, mask=j_mask, other=0.0)

    j_next_a = tl.load(acc_alpha_ptr + j_off, mask=j_mask, other=0.0)
    j_next_b = tl.load(acc_beta_ptr + j_off, mask=j_mask, other=-float("inf"))
    j_next_c = tl.load(acc_gamma_ptr + j_off, mask=j_mask, other=-float("inf"))
    j_next_d = tl.load(acc_delta_ptr + j_off, mask=j_mask, other=0.0)

    # Compute Softmax Weights for t+1
    # Each weight is exp(log-term - log-sum of the entry at t+1 it feeds).
    # Block A/C
    jt_term_aa = j_alpha + acc_a  # alpha_{t+1} + A_t
    jt_term_ac = j_beta + acc_c  # beta_{t+1} + C_t

    w_next_aa = tl.exp(jt_term_aa - j_next_a)
    w_next_ac = tl.exp(jt_term_ac - j_next_a)

    jt_term_ca = j_gamma + acc_a  # gamma_{t+1} + A_t
    jt_term_cc = j_delta + acc_c  # delta_{t+1} + C_t
    w_next_ca = tl.exp(jt_term_ca - j_next_c)
    w_next_cc = tl.exp(jt_term_cc - j_next_c)

    # Block B/D
    jt_term_bb = j_alpha + acc_b  # alpha_{t+1} + B_t
    jt_term_bd = j_beta + acc_d  # beta_{t+1} + D_t
    w_next_bb = tl.exp(jt_term_bb - j_next_b)
    w_next_bd = tl.exp(jt_term_bd - j_next_b)

    jt_term_db = j_gamma + acc_b  # gamma_{t+1} + B_t
    jt_term_dd = j_delta + acc_d  # delta_{t+1} + D_t
    w_next_db = tl.exp(jt_term_db - j_next_d)
    w_next_dd = tl.exp(jt_term_dd - j_next_d)

    # Mask out: lanes without a successor contribute a zero matrix, which also
    # discards any non-finite values computed from the padding loads above.
    w_next_aa = tl.where(j_mask, w_next_aa, 0.0)
    w_next_ac = tl.where(j_mask, w_next_ac, 0.0)
    w_next_ca = tl.where(j_mask, w_next_ca, 0.0)
    w_next_cc = tl.where(j_mask, w_next_cc, 0.0)
    w_next_bb = tl.where(j_mask, w_next_bb, 0.0)
    w_next_bd = tl.where(j_mask, w_next_bd, 0.0)
    w_next_db = tl.where(j_mask, w_next_db, 0.0)
    w_next_dd = tl.where(j_mask, w_next_dd, 0.0)

    # Transpose Matrices for the Scan: J^T
    # (the off-diagonal entries of each 2x2 block are swapped)
    t_waa, t_wac, t_wca, t_wcc = w_next_aa, w_next_ca, w_next_ac, w_next_cc
    t_wbb, t_wbd, t_wdb, t_wdd = w_next_bb, w_next_db, w_next_bd, w_next_dd

    # Load Direct Gradients (Bias)
    dout = tl.load(dout_ptr + rev_offsets, mask=scan_mask, other=0.0)
    # Assuming init is zeros
    # grad_bias_a = tl.zeros_like(dout)
    # grad_bias_b = dout
    # grad_bias_c = tl.zeros_like(dout)
    # grad_bias_d = -dout
    # Assuming init is ones
    # Derivative of logaddexp(A, B) - logaddexp(C, D) w.r.t. A, B, C, D.
    grad_bias_a = dout * tl.sigmoid(acc_a - acc_b)
    grad_bias_b = dout * tl.sigmoid(acc_b - acc_a)
    grad_bias_c = -dout * tl.sigmoid(acc_c - acc_d)
    grad_bias_d = -dout * tl.sigmoid(acc_d - acc_c)

    # --- Step 2: Perform Reverse Scan ---
    # The scan returns 12 arrays. The first 8 (accumulated matrices) are not
    # needed; the last 4 are the accumulated biases, i.e. the gradients
    # w.r.t. the cumulative state at each position.
    _m1, _m2, _m3, _m4, _m5, _m6, _m7, _m8, total_da, total_db, total_dc, total_dd = (
        tl.associative_scan(
            (
                t_waa,
                t_wac,
                t_wca,
                t_wcc,
                t_wbb,
                t_wbd,
                t_wdb,
                t_wdd,
                grad_bias_a,
                grad_bias_b,
                grad_bias_c,
                grad_bias_d,
            ),
            axis=0,
            combine_fn=combine_linear_scan,
        )
    )

    # --- Step 3: Compute Gradients w.r.t Inputs ---
    # Local weights: exp(input_t + previous state entry - current state entry).
    # A/C Block Weights
    term_aa = alpha + prev_a
    term_ac = beta + prev_c

    w_curr_aa = tl.exp(term_aa - acc_a)
    w_curr_ac = tl.exp(term_ac - acc_a)

    term_ca = gamma + prev_a
    term_cc = delta + prev_c
    w_curr_ca = tl.exp(term_ca - acc_c)
    w_curr_cc = tl.exp(term_cc - acc_c)

    # B/D Block Weights
    term_bb = alpha + prev_b
    term_bd = beta + prev_d
    w_curr_bb = tl.exp(term_bb - acc_b)
    w_curr_bd = tl.exp(term_bd - acc_b)

    term_db = gamma + prev_b
    term_dd = delta + prev_d
    w_curr_db = tl.exp(term_db - acc_d)
    w_curr_dd = tl.exp(term_dd - acc_d)

    # Gradients w.r.t Inputs
    # alpha and beta feed the A and B entries; gamma and delta feed C and D.
    d_alpha = total_da * w_curr_aa + total_db * w_curr_bb
    d_beta = total_da * w_curr_ac + total_db * w_curr_bd
    d_gamma = total_dc * w_curr_ca + total_dd * w_curr_db
    d_delta = total_dc * w_curr_cc + total_dd * w_curr_dd

    # Store (reversed offsets put each gradient back at its own position)
    tl.store(dalpha_ptr + rev_offsets, d_alpha, mask=scan_mask)
    tl.store(dbeta_ptr + rev_offsets, d_beta, mask=scan_mask)
    tl.store(dgamma_ptr + rev_offsets, d_gamma, mask=scan_mask)
    tl.store(ddelta_ptr + rev_offsets, d_delta, mask=scan_mask)


class LogMobiusScanFn(torch.autograd.Function):
    """Autograd wrapper around the Triton forward/backward scan kernels.

    The inputs are unpacked as 4-D tensors whose dim 1 is the scan axis.
    Every other index combination becomes one row of a 2-D
    (n_scans, seq_len) buffer, and one kernel program is launched per row.
    """

    @staticmethod
    def forward(ctx, alpha, beta, gamma, delta):
        """Run the forward scan and stash what backward needs on `ctx`.

        Returns a contiguous tensor with the same 4-D layout as `alpha`.
        """
        batch, seq_len, v_dim, q_dim = alpha.shape
        rows = batch * v_dim * q_dim

        # Move the scan axis last and flatten the remaining axes into rows.
        alpha_in, beta_in, gamma_in, delta_in = (
            t.permute(0, 2, 3, 1).contiguous().view(rows, seq_len)
            for t in (alpha, beta, gamma, delta)
        )

        out = torch.empty_like(alpha_in)
        acc_alpha, acc_beta, acc_gamma, acc_delta = (
            torch.empty_like(t) for t in (alpha_in, beta_in, gamma_in, delta_in)
        )

        # One block spans a whole row: power of two, at least 32 lanes.
        block = max(triton.next_power_of_2(seq_len), 32)
        warps = 4 if block < 2048 else 8

        launch_grid = (rows,)
        log_mobius_scan_forward_kernel[launch_grid](
            alpha_in,
            beta_in,
            gamma_in,
            delta_in,
            acc_alpha,
            acc_beta,
            acc_gamma,
            acc_delta,
            out,
            rows,
            seq_len,
            BLOCK_SIZE=block,
            num_warps=warps,
        )

        # Backward needs the row-layout inputs and the cumulative state.
        ctx.save_for_backward(
            alpha_in,
            beta_in,
            gamma_in,
            delta_in,
            acc_alpha,
            acc_beta,
            acc_gamma,
            acc_delta,
        )
        ctx.dims = (rows, seq_len, block, warps)
        ctx.orig_shape = (batch, seq_len, v_dim, q_dim)

        restored = out.view(batch, v_dim, q_dim, seq_len).permute(0, 3, 1, 2)
        return restored.contiguous()

    @staticmethod
    def backward(ctx, dout):
        """Run the backward kernel and return one gradient per forward input."""
        saved = ctx.saved_tensors
        alpha, beta, gamma, delta = saved[:4]
        acc_alpha, acc_beta, acc_gamma, acc_delta = saved[4:]
        rows, seq_len, block, warps = ctx.dims
        batch, _, v_dim, q_dim = ctx.orig_shape

        # Same row layout as the tensors saved by forward.
        dout_in = dout.permute(0, 2, 3, 1).contiguous().view(rows, seq_len)

        grads = [torch.empty_like(t) for t in (alpha, beta, gamma, delta)]
        dalpha, dbeta, dgamma, ddelta = grads

        launch_grid = (rows,)
        log_mobius_scan_backward_kernel[launch_grid](
            dout_in,
            dalpha,
            dbeta,
            dgamma,
            ddelta,
            alpha,
            beta,
            gamma,
            delta,
            acc_alpha,
            acc_beta,
            acc_gamma,
            acc_delta,
            rows,
            seq_len,
            BLOCK_SIZE=block,
            num_warps=warps,
        )

        # Undo the row layout so each gradient matches its input's 4-D shape.
        return tuple(
            g.view(batch, v_dim, q_dim, seq_len).permute(0, 3, 1, 2).contiguous()
            for g in grads
        )


def log_mobius_scan(alpha, beta, gamma, delta):
    """Differentiable entry point: apply `LogMobiusScanFn` to the four inputs."""
    result = LogMobiusScanFn.apply(alpha, beta, gamma, delta)
    return result


def ref(alpha, beta, gamma, delta):
    """PyTorch reference implementation of the log-space Mobius scan.

    The four inputs hold, in log space, the entries of a 2x2 matrix
    ``M_t = [[alpha, beta], [gamma, delta]]`` for every step ``t`` along
    dim 1. Starting from the identity matrix, the running product
    ``P_t = M_t @ P_{t-1}`` is accumulated with ``torch.logaddexp`` and each
    step emits ``logaddexp(A_t, B_t) - logaddexp(C_t, D_t)``, where
    ``A..D`` are the log-space entries of ``P_t``.

    Args:
        alpha, beta, gamma, delta: tensors with the scan axis at dim 1.

    Returns:
        Tensor with the per-step outputs stacked along dim 1. An empty scan
        axis yields an empty tensor shaped like ``alpha``.
    """
    if alpha.shape[1] == 0:
        # torch.stack rejects an empty list; nothing to scan means nothing out.
        return torch.empty_like(alpha)

    neg_inf = -float("inf")

    # Log-space identity matrix: log(1) on the diagonal, log(0) elsewhere.
    acc_alpha = torch.zeros_like(alpha[:, 0])
    acc_beta = torch.full_like(beta[:, 0], neg_inf)
    acc_gamma = torch.full_like(gamma[:, 0], neg_inf)
    acc_delta = torch.zeros_like(delta[:, 0])

    preds = []
    steps = zip(alpha.unbind(1), beta.unbind(1), gamma.unbind(1), delta.unbind(1))
    for a_t, b_t, c_t, d_t in steps:
        # Log-space matmul M_t @ P_{t-1}; the right-hand side is fully
        # evaluated from the old accumulator before any name is rebound.
        acc_alpha, acc_beta, acc_gamma, acc_delta = (
            torch.logaddexp(a_t + acc_alpha, b_t + acc_gamma),
            torch.logaddexp(a_t + acc_beta, b_t + acc_delta),
            torch.logaddexp(c_t + acc_alpha, d_t + acc_gamma),
            torch.logaddexp(c_t + acc_beta, d_t + acc_delta),
        )

        # Project: log(top-row sum) - log(bottom-row sum) of P_t.
        preds.append(
            torch.logaddexp(acc_alpha, acc_beta)
            - torch.logaddexp(acc_gamma, acc_delta)
        )

    return torch.stack(preds, dim=1)


def benchmark_mobius_scan(
    name,
    func,
    alpha,
    beta,
    gamma,
    delta,
    warmup_iters=20,
    test_iters=100,
    backward=True,
):
    """
    Benchmarks the Mobius scan function for forward and forward+backward passes.

    Args:
        name: label used in the printed report.
        func: callable taking (alpha, beta, gamma, delta) and returning a tensor.
        alpha, beta, gamma, delta: input tensors passed to `func`.
        warmup_iters: untimed iterations run before each measurement.
        test_iters: timed iterations; the report is the mean per iteration.
        backward: when true, also time forward + `out.sum().backward()`.

    Side effects:
        Prints the timings. When `backward` is true the inputs are switched to
        `requires_grad=True` in place and their `.grad` is left as None.

    The device is synchronized around the timed region only when CUDA is
    available, so the helper also runs on CPU-only machines.
    """
    inputs = (alpha, beta, gamma, delta)
    sync_needed = torch.cuda.is_available()

    def sync():
        # Kernel launches are asynchronous; wait so wall-clock time is meaningful.
        if sync_needed:
            torch.cuda.synchronize()

    def clear_grads():
        # Plain tensors, not module parameters, so grads are reset by hand.
        for t in inputs:
            t.grad = None

    def forward_only():
        with torch.no_grad():
            func(*inputs)

    def forward_backward():
        func(*inputs).sum().backward()
        clear_grads()

    def mean_ms(step):
        """Warm up, then return the mean wall-clock milliseconds of `step`."""
        for _ in range(warmup_iters):
            step()
        sync()

        start_time = perf_counter()
        for _ in range(test_iters):
            step()
        sync()
        end_time = perf_counter()
        return (end_time - start_time) * 1000 / test_iters

    # --- Forward Benchmark ---
    print(
        f"{name} forward time: {mean_ms(forward_only):.6f} milliseconds per iteration"
    )

    # --- Backward Benchmark ---
    if backward:
        # Ensure inputs require grad and start from a clean slate.
        for t in inputs:
            t.requires_grad_(True)
        clear_grads()

        print(
            f"{name} forward+backward time: {mean_ms(forward_backward):.6f} milliseconds per iteration"
        )


if __name__ == "__main__":
    # Setup: fixed seed so both implementations see reproducible inputs.
    torch.manual_seed(0)
    # The sequence length is kept modest because the reference loops over it
    # step by step in Python.
    batch_size, seq_len, v_dim, qk_dim = 4, 128, 512, 16
    device = "cuda"
    test_backward = True
    input_names = ("alpha", "beta", "gamma", "delta")

    def make_inputs():
        """Draw fresh leaf tensors (alpha, beta, gamma, delta), in that order."""
        # Standard normal is fine for log-space (values can be negative/positive)
        return tuple(
            torch.randn(
                batch_size, seq_len, v_dim, qk_dim, device=device, requires_grad=True
            )
            for _ in input_names
        )

    def run_case(label, func, inputs):
        """Run forward (and optionally backward); return (output, grads or None)."""
        print(f"Running Log Mobius Scan {label} Test...")
        print("Running forward pass...")
        out = func(*inputs)
        print("Forward pass completed.")

        if out.isnan().any():
            print(f"Error: {label} output contains NaNs.")
        if out.isinf().any():
            print(f"Error: {label} output contains Infs.")

        grads = None
        if test_backward:
            print("Running backward pass...")
            out.sum().backward()
            print("Backward pass completed.")
            grads = [t.grad.clone() for t in inputs]

        # Reset grads so the next run does not accumulate into them.
        for t in inputs:
            t.grad = None
        return out, grads

    inputs = make_inputs()

    # ==========================================
    # Correctness Check
    # ==========================================

    # 1. Triton Forward + Backward
    out_triton, grads_triton = run_case("Triton", log_mobius_scan, inputs)

    # 2. PyTorch Reference Forward + Backward
    out_ref, grads_ref = run_case("Reference", ref, inputs)

    # 3. Compare
    print(f"Forward diff: {torch.max(torch.abs(out_triton - out_ref)).item()}")
    if test_backward:
        for grad_name, g_tri, g_ref in zip(input_names, grads_triton, grads_ref):
            print(
                f"Grad {grad_name} diff: {torch.max(torch.abs(g_tri - g_ref)).item()}"
            )

    # Note: Log-space ops can be slightly less precise due to exp/log round trips.
    if torch.allclose(out_triton, out_ref, atol=1e-3):
        print("Test Passed: Triton forward matches PyTorch reference.")
    else:
        print("Test Failed: Triton forward does not match PyTorch reference.")

    if test_backward:
        for grad_name, g_tri, g_ref in zip(input_names, grads_triton, grads_ref):
            if torch.allclose(g_tri, g_ref, atol=1e-3):
                print(f"Test Passed: Grad {grad_name} matches.")
            else:
                print(f"Test Failed: Grad {grad_name} does not match.")

    print("-" * 50)

    # ==========================================
    # Performance Benchmarks
    # ==========================================

    print("Starting Performance Benchmarks...")

    # Same shape as the correctness check; fresh tensors are used so no
    # autograd history from the runs above is carried along.
    alpha_bench, beta_bench, gamma_bench, delta_bench = make_inputs()

    # Benchmark Triton
    benchmark_mobius_scan(
        name="Triton Implementation",
        func=log_mobius_scan,
        alpha=alpha_bench,
        beta=beta_bench,
        gamma=gamma_bench,
        delta=delta_bench,
        warmup_iters=5,
        test_iters=20,
        backward=test_backward,
    )

    print("-" * 30)

    # Benchmark Reference
    # Reference is O(L) Python loop, so it will be slow. Reduced iters.
    benchmark_mobius_scan(
        name="PyTorch Reference",
        func=ref,
        alpha=alpha_bench,
        beta=beta_bench,
        gamma=gamma_bench,
        delta=delta_bench,
        warmup_iters=1,
        test_iters=3,
        backward=test_backward,
    )
