import torch
import math
import triton
import triton.language as tl

def forward_torch(
    q: torch.Tensor,  # (b, sq, d)
    k: torch.Tensor,  # (b, sk, d)
    v: torch.Tensor,  # (b, sk, d)
    inverse_bandwidth: float,
    ridge_lambda: float = 1,
    delta_eps: float = 1e-12,
):
    """Dense PyTorch reference for the causal attention computed by the Triton kernels.

    For query i and every visible key j (j <= i, absolute positions):

        l_ij    = (q_i . k_j) * inverse_bandwidth
        w_ij    = exp(l_ij - max_j l_ij)
        z_ij    = k_j - q_i
        omega_i = sum_j w_ij
        mu_i    = sum_j w_ij z_ij
        Sigma_i = sum_j w_ij z_ij z_ij^T + ridge_lambda * I
        rho_i   = Sigma_i^{-1} mu_i
        delta_i = omega_i - mu_i . rho_i, then delta_i += delta_eps * sign(delta_i)
        s_ij    = (1 - z_ij . rho_i) * w_ij / delta_i
        out_i   = sum_j s_ij v_j

    ``torch.sign`` is 0 at 0, so an exactly-zero delta is not moved by ``delta_eps``.
    The function materialises (b, sq, sk, d) and (b, sq, d, d) intermediates, so it is
    meant for small inputs and for testing.

    Returns:
        (output, rho, max_logits, delta) with shapes (b, sq, v.shape[-1]), (b, sq, d),
        (b, sq, 1) and (b, sq, 1).
    """
    _, n_q, dim = q.shape
    n_k = k.shape[1]

    # Query i may look at key j only when i >= j (indices are not offset for sq != sk).
    q_pos = torch.arange(n_q, device=q.device)[:, None]
    k_pos = torch.arange(n_k, device=k.device)[None, :]
    visible = q_pos >= k_pos

    scores = torch.einsum("bid,bjd->bij", q, k) * inverse_bandwidth
    scores = scores.masked_fill(visible.logical_not(), float("-inf"))
    row_max = scores.amax(dim=-1, keepdim=True)
    weights = (scores - row_max).exp()
    weight_mass = weights.sum(dim=-1, keepdim=True)

    # Displacement of every key from every query: (b, sq, sk, d).
    disp = k[:, None, :, :] - q[:, :, None, :]
    first_moment = torch.einsum("bij,bijd->bid", weights, disp)
    second_moment = torch.einsum("bij,bijd,bije->bide", weights, disp, disp)
    ridge = ridge_lambda * torch.eye(dim, device=k.device)
    rho = torch.linalg.solve(second_moment + ridge[None, None], first_moment)

    delta = weight_mass - (first_moment * rho).sum(dim=-1, keepdim=True)
    delta = delta + delta_eps * delta.sign()

    projection = torch.einsum("bijd,bid->bij", disp, rho)
    smoother = (1 - projection) * weights / delta
    return smoother @ v, rho, row_max, delta

@triton.jit
def _fwd_inner_kernel_pass_one(
    Q, K,
    MU, OMEGA, MAX_LOGITS,
    inverse_bandwidth,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_mb, stride_mq, stride_md,
    stride_ob, stride_oq,
    stride_mlb, stride_mlq,
    B, SQ, SK, D,
    R_TILE_SIZE: tl.constexpr,
    C_TILE_SIZE: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Pass one: stream over key tiles and accumulate per-row weight statistics.

    Program (pid_b, pid_r) handles query rows ``pid_r * R_TILE_SIZE + [0, BLOCK_R)`` of
    batch ``pid_b``. With ``l_ij = (q_i . k_j) * inverse_bandwidth`` restricted to causal
    pairs ``j <= i`` and ``w_ij = exp(l_ij - max_j l_ij)``, it stores

        MAX_LOGITS[i] = max_j l_ij
        OMEGA[i]      = sum_j w_ij
        MU[i]         = sum_j w_ij * k_j     (raw weighted key sum, not centred on q_i)

    The running max is maintained online. Partial sums are rescaled by
    ``exp(old_max - new_max)`` whenever the max grows.

    Rows with index >= SQ are computed but never stored. Only the first BLOCK_D feature
    columns are read, so BLOCK_D must be >= D. Tiles are addressed as
    ``c_tile * C_TILE_SIZE + arange(BLOCK_C)``, so BLOCK_C must equal C_TILE_SIZE,
    otherwise keys would be counted twice.
    """
    pid_b = tl.program_id(0)  # batch
    pid_r = tl.program_id(1)  # row tile
    row_offset = pid_r * R_TILE_SIZE

    row_idx = row_offset + tl.arange(0, BLOCK_R)
    row_mask = row_idx < SQ

    omega_acc = tl.zeros([BLOCK_R, 1], dtype=tl.float32)
    max_logits_acc = tl.full([BLOCK_R, 1], value=-float('inf'), dtype=tl.float32)
    mu_acc = tl.zeros([BLOCK_R, BLOCK_D], dtype=tl.float32)

    d_idx = tl.arange(0, BLOCK_D)
    q_ptrs = Q + pid_b * stride_qb + row_idx[:, None] * stride_qq + d_idx[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=row_mask[:, None] & (d_idx[None, :] < D), other=0.0)

    # Every key index >= row_offset + BLOCK_R is above the causal diagonal for all rows of
    # this tile. Such tiles would leave the accumulators unchanged (alpha == 1, w == 0),
    # so the loop stops before them.
    col_end = tl.minimum(SK, row_offset + BLOCK_R)
    num_col_tiles = tl.cdiv(col_end, C_TILE_SIZE)
    for c_tile in range(num_col_tiles):
        col_offset = c_tile * C_TILE_SIZE
        col_idx = col_offset + tl.arange(0, BLOCK_C)
        col_mask = col_idx < SK
        causal_mask = row_idx[:, None] >= col_idx[None, :]
        mask = row_mask[:, None] & col_mask[None, :] & causal_mask

        # K tile is loaded transposed, [BLOCK_D, BLOCK_C], so q @ k gives the logits directly.
        k_ptrs = K + pid_b * stride_kb + col_idx[None, :] * stride_kk + d_idx[:, None] * stride_kd
        k = tl.load(k_ptrs, mask=col_mask[None, :] & (d_idx[:, None] < D), other=0.0)

        # logits = q @ k^T * inverse_bandwidth
        logits = tl.dot(q, k, allow_tf32=True) * inverse_bandwidth

        logits_for_max = tl.where(mask, logits, -float('inf'))
        block_max = tl.max(logits_for_max, axis=1, keep_dims=True)

        new_max = tl.maximum(max_logits_acc, block_max)
        # Rows that have seen no visible key yet keep max == -inf. Guard them so that
        # exp(-inf - -inf) never produces NaN.
        is_neginf = new_max == -float('inf')

        diff = tl.where(is_neginf, 0.0, max_logits_acc - new_max)
        alpha = tl.exp(diff)

        omega_acc = alpha * omega_acc
        mu_acc = alpha * mu_acc

        new_max_safe = tl.where(is_neginf, 0.0, new_max)
        logits_for_exp = tl.where(mask, logits, new_max_safe)
        w = tl.exp(logits_for_exp - new_max_safe)
        w = tl.where(mask, w, 0.0)

        omega_acc += tl.sum(w, axis=1, keep_dims=True)
        k_t = tl.trans(k)
        mu_acc += tl.dot(w, k_t, allow_tf32=True)

        max_logits_acc = new_max

    mu_ptrs = MU + pid_b * stride_mb + row_idx[:, None] * stride_mq + d_idx[None, :] * stride_md
    tl.store(mu_ptrs, mu_acc, mask=row_mask[:, None] & (d_idx[None, :] < D))

    omega_ptrs = OMEGA + pid_b * stride_ob + row_idx[:, None] * stride_oq
    tl.store(omega_ptrs, omega_acc, mask=row_mask[:, None])

    ml_ptrs = MAX_LOGITS + pid_b * stride_mlb + row_idx[:, None] * stride_mlq
    tl.store(ml_ptrs, max_logits_acc, mask=row_mask[:, None])

@triton.jit
def _cg_matvec_kernel(
    Q, K, P, SIGMA_P,
    OMEGA, MU, MAX_LOGITS,
    inverse_bandwidth, ridge_lambda,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_pb, stride_pq, stride_pd,
    stride_spb, stride_spq, stride_spd,
    stride_ob, stride_oq,
    stride_mb, stride_mq, stride_md,
    stride_mlb, stride_mlq,
    B, SQ, SK, D,
    C_TILE_SIZE: tl.constexpr,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Matrix-free product SIGMA_P[i] = Sigma_i @ P[i] for one (batch, query) row.

    The matrix is defined as

        Sigma_i = sum_{j <= i} w_ij (k_j - q_i)(k_j - q_i)^T + ridge_lambda * I,
        w_ij    = exp((q_i . k_j) * inverse_bandwidth - MAX_LOGITS[i])

    It is never materialised. The product is expanded as

        sum_j w_ij (p . k_j) k_j - (p . q) MU - (p . MU) q + OMEGA (p . q) q + ridge_lambda p

    This expansion is exact only if ``MU = sum_j w_ij k_j`` (uncentred) and
    ``OMEGA = sum_j w_ij``.

    One program handles one row: ``pid = batch * SQ + row``. Only the first BLOCK_D
    feature columns are read, so BLOCK_D must be >= D.
    """
    pid = tl.program_id(0)
    pid_b = pid // SQ
    pid_q = pid % SQ

    d_idx = tl.arange(0, BLOCK_D)
    d_mask = d_idx < D

    q_ptrs = Q + pid_b * stride_qb + pid_q * stride_qq + d_idx * stride_qd
    p_ptrs = P + pid_b * stride_pb + pid_q * stride_pq + d_idx * stride_pd
    mu_ptrs = MU + pid_b * stride_mb + pid_q * stride_mq + d_idx * stride_md

    q = tl.load(q_ptrs, mask=d_mask, other=0.0)
    p = tl.load(p_ptrs, mask=d_mask, other=0.0)
    mu = tl.load(mu_ptrs, mask=d_mask, other=0.0)

    omega = tl.load(OMEGA + pid_b * stride_ob + pid_q * stride_oq)
    max_logit = tl.load(MAX_LOGITS + pid_b * stride_mlb + pid_q * stride_mlq)

    sigma_p = tl.zeros([BLOCK_D], dtype=tl.float32)

    # Keys after pid_q are causally masked (w == 0), so tiles past the diagonal are
    # skipped entirely instead of being loaded and zeroed out.
    col_end = tl.minimum(SK, pid_q + 1)
    num_col_tiles = tl.cdiv(col_end, C_TILE_SIZE)
    for c_tile in range(num_col_tiles):
        col_offset = c_tile * C_TILE_SIZE
        col_idx = col_offset + tl.arange(0, BLOCK_C)
        col_mask = (col_idx < SK) & (col_idx <= pid_q)  # causal

        # K tile laid out as [BLOCK_D, BLOCK_C].
        k_ptrs = K + pid_b * stride_kb + col_idx[None, :] * stride_kk + d_idx[:, None] * stride_kd
        k = tl.load(k_ptrs, mask=d_mask[:, None] & col_mask[None, :], other=0.0)

        # NOTE(review): these logits are full-fp32 reductions. Pass one builds OMEGA/MU/
        # MAX_LOGITS with tl.dot(allow_tf32=True), so on TF32-capable GPUs the weights
        # here can differ slightly from those behind OMEGA and MU.
        qk = tl.sum(q[:, None] * k, axis=0) * inverse_bandwidth

        qk_for_exp = tl.where(col_mask, qk, max_logit)
        w = tl.exp(qk_for_exp - max_logit)
        w = tl.where(col_mask, w, 0.0)

        # sum_j w_ij (p . k_j) k_j
        pk = tl.sum(p[:, None] * k, axis=0)
        wpk = w * pk
        sigma_p += tl.sum(k * wpk[None, :], axis=1)

    # Cross and centring terms from expanding (k_j - q)(k_j - q)^T, plus the ridge.
    pq = tl.sum(p * q)
    pmu = tl.sum(p * mu)
    sigma_p = sigma_p - pq * mu - pmu * q + omega * pq * q
    sigma_p = sigma_p + ridge_lambda * p

    sp_ptrs = SIGMA_P + pid_b * stride_spb + pid_q * stride_spq + d_idx * stride_spd
    tl.store(sp_ptrs, sigma_p, mask=d_mask)

def cg_solver(q, k, r, omega, mu, max_logits, ridge_lambda, inverse_bandwidth,
              max_iters=256, atol=1e-15, rtol=1e-15):
    """Solve ``Sigma_i x_i = r_i`` for every (batch, query) row with batched conjugate gradients.

    ``Sigma_i`` is never materialised. Each iteration applies it to the current search
    directions through ``_cg_matvec_kernel``, which rebuilds the causal weights from
    ``q``, ``k`` and ``max_logits`` and combines them with ``omega``, ``mu`` and
    ``ridge_lambda``.

    Args:
        q: (b, sq, d) queries.
        k: (b, sk, d) keys.
        r: (b, sq, d) right-hand side. It is also the initial residual because x0 = 0.
            The caller's tensor is not modified.
        omega: per-row weight sums; indexed with two strides (presumably (b, sq, 1), as
            produced by pass one).
        mu: (b, sq, d) per-row uncentred weighted key sums.
        max_logits: per-row max logits; indexed with two strides (presumably (b, sq, 1)).
        ridge_lambda: value added to the diagonal of Sigma.
        inverse_bandwidth: logit scale.
        max_iters: iteration cap.
        atol, rtol: a row counts as converged once
            ``||r||^2 <= max(rtol * ||r0||^2, atol)``. Both tolerances apply to the
            *squared* residual norm.

    Returns:
        (b, sq, d) approximate solutions x.

    Raises:
        ValueError: if d > 128 (the matvec kernel loads the feature dim as one block).
    """
    b, sq, d = q.shape
    sk = k.shape[1]
    if d > 128:
        raise ValueError(f"cg_solver supports a feature dimension of at most 128, got {d}")

    # The launch configuration is identical for every iteration.
    C_TILE_SIZE = min(64, int(triton.next_power_of_2(sk)))
    BLOCK_C = C_TILE_SIZE
    BLOCK_D = int(triton.next_power_of_2(d))
    grid = (b * sq,)

    def _safe_div(num, den):
        # For an SPD Sigma, den == 0 only when the row's direction (or residual) is exactly
        # zero, in which case the correct step is 0. Exact division everywhere else keeps
        # the CG recurrence unbiased even when the residual gets tiny.
        return torch.where(den != 0, num / den, torch.zeros_like(num))

    x = torch.zeros_like(q)
    p = r.clone()
    rr = (r * r).sum(dim=-1, keepdim=True)
    threshold = torch.clamp(rr * rtol, min=atol)

    for _ in range(max_iters):
        # Every (row, feature < d) element is written by the kernel because BLOCK_D >= d.
        sigma_p = torch.empty_like(p)
        _cg_matvec_kernel[grid](
            q, k, p, sigma_p,
            omega, mu, max_logits,
            inverse_bandwidth, ridge_lambda,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            p.stride(0), p.stride(1), p.stride(2),
            sigma_p.stride(0), sigma_p.stride(1), sigma_p.stride(2),
            omega.stride(0), omega.stride(1),
            mu.stride(0), mu.stride(1), mu.stride(2),
            max_logits.stride(0), max_logits.stride(1),
            b, sq, sk, d,
            C_TILE_SIZE=C_TILE_SIZE,
            BLOCK_C=BLOCK_C,
            BLOCK_D=BLOCK_D,
        )

        psp = (p * sigma_p).sum(dim=-1, keepdim=True)
        alpha = _safe_div(rr, psp)

        x = x + alpha * p
        r = r - alpha * sigma_p
        rr_new = (r * r).sum(dim=-1, keepdim=True)

        converged = rr_new <= threshold
        if converged.all():
            break

        beta = _safe_div(rr_new, rr)
        # Freeze converged rows. A zero search direction makes psp == 0 and therefore
        # alpha == 0, so their x and r stop changing while other rows keep iterating.
        p = torch.where(converged, torch.zeros_like(p), r + beta * p)
        rr = rr_new

    return x

@triton.jit
def _fwd_inner_kernel_pass_two(
    Q, K, V, RHO,
    O, DELTA,
    OMEGA, MU, MAX_LOGITS,
    inverse_bandwidth, delta_eps,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_rb, stride_rq, stride_rd,
    stride_ob, stride_oq, stride_od,
    stride_db, stride_dq,
    stride_omega_b, stride_omega_q,
    stride_mb, stride_mq, stride_md,
    stride_mlb, stride_mlq,
    B, SQ, SK, D,
    R_TILE_SIZE: tl.constexpr,
    C_TILE_SIZE: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Pass two: compute delta and the output rows from rho and the pass-one statistics.

    For each query row i of this tile:

        mu_c    = MU[i] - OMEGA[i] * q_i          (MU is the uncentred sum_j w_ij k_j)
        delta_i = OMEGA[i] - mu_c . rho_i, pushed away from zero by delta_eps
        s_ij    = (1 - (k_j - q_i) . rho_i) * w_ij / delta_i   for causal j <= i
        O[i]    = sum_j s_ij v_j

    Here ``w_ij = exp(l_ij - MAX_LOGITS[i])`` and
    ``l_ij = (q_i . k_j) * inverse_bandwidth``. DELTA receives delta_i.

    Only the first BLOCK_D feature columns of q/k/v/rho/mu are read, so BLOCK_D must be
    >= D. V's feature columns are indexed with the same D.
    """
    pid_b = tl.program_id(0)
    pid_r = tl.program_id(1)
    row_offset = pid_r * R_TILE_SIZE

    row_idx = row_offset + tl.arange(0, BLOCK_R)
    row_mask = row_idx < SQ

    d_idx = tl.arange(0, BLOCK_D)
    d_mask = d_idx < D
    rd_mask = row_mask[:, None] & d_mask[None, :]

    q_ptrs = Q + pid_b * stride_qb + row_idx[:, None] * stride_qq + d_idx[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=rd_mask, other=0.0)

    rho_ptrs = RHO + pid_b * stride_rb + row_idx[:, None] * stride_rq + d_idx[None, :] * stride_rd
    rho = tl.load(rho_ptrs, mask=rd_mask, other=0.0)

    mu_ptrs = MU + pid_b * stride_mb + row_idx[:, None] * stride_mq + d_idx[None, :] * stride_md
    mu = tl.load(mu_ptrs, mask=rd_mask, other=0.0)

    omega_ptrs = OMEGA + pid_b * stride_omega_b + row_idx[:, None] * stride_omega_q
    omega = tl.load(omega_ptrs, mask=row_mask[:, None], other=0.0)

    ml_ptrs = MAX_LOGITS + pid_b * stride_mlb + row_idx[:, None] * stride_mlq
    max_logits = tl.load(ml_ptrs, mask=row_mask[:, None], other=-float('inf'))

    # Centre MU on q_i to get sum_j w_ij (k_j - q_i).
    mu_adj = mu - omega * q
    mu_rho = tl.sum(mu_adj * rho, axis=1, keep_dims=True)
    delta = omega - mu_rho
    # Zero counts as positive, so delta is never exactly 0 afterwards (for delta_eps > 0).
    delta = delta + delta_eps * tl.where(delta >= 0, 1.0, -1.0)

    # Per-row terms that do not depend on the key tile, computed once.
    is_neginf = max_logits == -float('inf')
    max_logits_safe = tl.where(is_neginf, 0.0, max_logits)
    rho_q = tl.sum(rho * q, axis=1, keep_dims=True)

    o_acc = tl.zeros([BLOCK_R, BLOCK_D], dtype=tl.float32)

    # Key tiles starting at or after row_offset + BLOCK_R are fully causally masked
    # (s == 0), so the loop stops before them.
    col_end = tl.minimum(SK, row_offset + BLOCK_R)
    num_col_tiles = tl.cdiv(col_end, C_TILE_SIZE)
    for c_tile in range(num_col_tiles):
        col_offset = c_tile * C_TILE_SIZE
        col_idx = col_offset + tl.arange(0, BLOCK_C)
        col_mask = col_idx < SK

        causal_mask = row_idx[:, None] >= col_idx[None, :]
        mask = row_mask[:, None] & col_mask[None, :] & causal_mask

        # K and V tiles are loaded as [BLOCK_D, BLOCK_C].
        kv_mask = col_mask[None, :] & d_mask[:, None]
        k_ptrs = K + pid_b * stride_kb + col_idx[None, :] * stride_kk + d_idx[:, None] * stride_kd
        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)

        v_ptrs = V + pid_b * stride_vb + col_idx[None, :] * stride_vk + d_idx[:, None] * stride_vd
        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

        logits = tl.dot(q, k, allow_tf32=True) * inverse_bandwidth

        logits_for_exp = tl.where(mask, logits, max_logits_safe)
        w = tl.exp(logits_for_exp - max_logits_safe)
        w = tl.where(mask, w, 0.0)

        # 1 - (k_j - q_i) . rho_i  ==  1 - rho_i . k_j + rho_i . q_i
        rho_k = tl.dot(rho, k, allow_tf32=True)
        s = (1.0 - rho_k + rho_q) * w / delta
        s = tl.where(mask, s, 0.0)

        v_t = tl.trans(v)
        o_acc += tl.dot(s, v_t, allow_tf32=True)

    o_ptrs = O + pid_b * stride_ob + row_idx[:, None] * stride_oq + d_idx[None, :] * stride_od
    tl.store(o_ptrs, o_acc, mask=rd_mask)

    delta_ptrs = DELTA + pid_b * stride_db + row_idx[:, None] * stride_dq
    tl.store(delta_ptrs, delta, mask=row_mask[:, None])

def forward_tile_triton(
    q: torch.Tensor,  # (b, sq, d)
    k: torch.Tensor,  # (b, sk, d)
    v: torch.Tensor,  # (b, sk, d)
    inverse_bandwidth: float,
    r_tile_size: int = 64,
    c_tile_size: int = 64,
    ridge_lambda: float = 1,
    delta_eps: float = 1e-12,
    cg_max_iters: int = 20,
    cg_atol: float = 1e-15,
    cg_rtol: float = 1e-15,
):
    """Tiled Triton counterpart of ``forward_torch``.

    Pipeline:
      1. ``_fwd_inner_kernel_pass_one`` computes per-row max logits, the weight sums
         ``omega`` and the uncentred weighted key sums ``mu``.
      2. ``cg_solver`` computes ``rho`` by solving ``Sigma rho = mu - omega * q``
         without materialising Sigma.
      3. ``_fwd_inner_kernel_pass_two`` computes ``delta`` and the output ``o``.

    Args:
        q, k, v: (b, sq, d), (b, sk, d), (b, sk, d) inputs.
        inverse_bandwidth: logit scale.
        r_tile_size: requested query rows per program. It is rounded up to a power of
            two, capped at next_power_of_2(sq) and raised to at least 16.
        c_tile_size: requested keys per loop step. It is rounded up to a power of two
            and clamped to [16, 64].
        ridge_lambda: ridge term added to Sigma's diagonal.
        delta_eps: value used to push delta away from zero.
        cg_max_iters, cg_atol, cg_rtol: forwarded to ``cg_solver``.

    Returns:
        (o, rho, max_logits, delta) with shapes (b, sq, d), (b, sq, d), (b, sq, 1) and
        (b, sq, 1), in the same order as ``forward_torch``.

    Raises:
        ValueError: if d > 128 (the kernels load the feature dim as a single block).
    """
    b, sq, d = q.shape
    sk = k.shape[1]
    if d > 128:
        raise ValueError(
            f"forward_tile_triton supports a feature dimension of at most 128, got {d}"
        )

    # tl.arange requires power-of-two extents and tl.dot requires every operand dim to
    # be >= 16. The kernels mask padded rows, key columns and features, so rounding up
    # is safe.
    BLOCK_R = max(16, min(triton.next_power_of_2(r_tile_size), triton.next_power_of_2(sq)))
    BLOCK_C = max(16, min(triton.next_power_of_2(c_tile_size), 64))
    BLOCK_D = max(16, triton.next_power_of_2(d))
    br = triton.cdiv(sq, BLOCK_R)

    o = torch.zeros(b, sq, d, device=q.device, dtype=q.dtype)
    max_logits = torch.zeros(b, sq, 1, device=q.device, dtype=q.dtype)
    delta = torch.zeros(b, sq, 1, device=q.device, dtype=q.dtype)
    mu = torch.zeros(b, sq, d, device=q.device, dtype=q.dtype)
    omega = torch.zeros(b, sq, 1, device=q.device, dtype=q.dtype)

    # Phase 1: softmax statistics.
    grid = (b, br)
    _fwd_inner_kernel_pass_one[grid](
        q, k,
        mu, omega, max_logits,
        inverse_bandwidth,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        mu.stride(0), mu.stride(1), mu.stride(2),
        omega.stride(0), omega.stride(1),
        max_logits.stride(0), max_logits.stride(1),
        b, sq, sk, d,
        R_TILE_SIZE=BLOCK_R,
        C_TILE_SIZE=BLOCK_C,
        BLOCK_R=BLOCK_R,
        BLOCK_C=BLOCK_C,
        BLOCK_D=BLOCK_D,
    )

    # Phase 2: rho = Sigma^{-1} (mu - omega * q). The subtraction turns pass one's
    # uncentred key sum into sum_j w_ij (k_j - q_i).
    residual = mu - omega * q
    rho = cg_solver(
        q, k, residual, omega, mu, max_logits,
        ridge_lambda, inverse_bandwidth,
        max_iters=cg_max_iters, atol=cg_atol, rtol=cg_rtol
    )

    # Phase 3: delta and output.
    _fwd_inner_kernel_pass_two[grid](
        q, k, v, rho,
        o, delta,
        omega, mu, max_logits,
        inverse_bandwidth, delta_eps,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rho.stride(0), rho.stride(1), rho.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        delta.stride(0), delta.stride(1),
        omega.stride(0), omega.stride(1),
        mu.stride(0), mu.stride(1), mu.stride(2),
        max_logits.stride(0), max_logits.stride(1),
        b, sq, sk, d,
        R_TILE_SIZE=BLOCK_R,
        C_TILE_SIZE=BLOCK_C,
        BLOCK_R=BLOCK_R,
        BLOCK_C=BLOCK_C,
        BLOCK_D=BLOCK_D,
    )

    return o, rho, max_logits, delta

if __name__ == "__main__":
    # Smoke test: compare the tiled Triton pipeline with the dense PyTorch reference on
    # random unit-norm inputs. Requires a CUDA device.
    device = torch.device("cuda")
    batch, seq_len, head_dim = 2, 256, 64
    unit_eps = 1e-12

    # Draw q, k, v in that order (same RNG consumption as three separate calls), then
    # L2-normalise each along the feature axis.
    raw = [
        torch.randn(batch, seq_len, head_dim, device=device, dtype=torch.float32)
        for _ in range(3)
    ]
    q, k, v = [torch.nn.functional.normalize(t, p=2, dim=-1, eps=unit_eps) for t in raw]
    inverse_bandwidth = 1.0 / math.sqrt(head_dim)

    reference = forward_torch(q, k, v, inverse_bandwidth)
    candidate = forward_tile_triton(q, k, v, inverse_bandwidth)

    labels = ("Output diff:", "Rho diff:", "Max logits diff:", "Delta diff:")
    for label, expected, actual in zip(labels, reference, candidate):
        print(label, (expected - actual).abs().max().item())