import torch
from typing import Tuple


def relmm(
    X: torch.Tensor,
    Q: torch.Tensor,
    K: torch.Tensor,
) -> torch.Tensor:
    """Return ``X @ K^T`` minus the row-wise dot product of ``X`` and ``Q``.

    For inputs whose shapes are compatible, entry ``(..., i, j)`` of the
    result equals ``X[..., i, :] . (K[..., j, :] - Q[..., i, :])``.

    Args:
        X: Left operand; its last axis is contracted in both terms.
        Q: Tensor multiplied element-wise with ``X`` (must broadcast with it).
        K: Tensor whose last two axes are swapped before the matrix product.

    Returns:
        ``X @ K.transpose(-1, -2) - (X * Q).sum(dim=-1, keepdim=True)``; the
        second term keeps a trailing singleton axis so it broadcasts across
        the columns of the first.
    """
    # One scalar per row of X: <X_i, Q_i>, kept as a column for broadcasting.
    row_dot = torch.sum(X * Q, dim=-1, keepdim=True)
    # All pairwise products <X_i, K_j>.
    pairwise = torch.matmul(X, K.transpose(-2, -1))
    return pairwise - row_dot


def lla_naive(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qk_scale: float,
    ridge_lambda: float | torch.Tensor,
    delta_eps: float,
    lla_block_size: int, # lla_block_size=head_dim recovers full LLA
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    sq, sk = q.shape[-2], k.shape[-2]
    batch_size, _, dim = q.shape
    assert dim % lla_block_size == 0 and lla_block_size <= dim, f"Dimension must be divisible by block size; Got dim={dim} and block size={lla_block_size}"
    assert q.dtype == k.dtype == v.dtype, f"Dtypes must match; Got q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}"
    assert q.dtype in [torch.float32, torch.float64], f"Only float32 and float64 are supported; Got q.dtype={q.dtype}"
    row_offset = torch.arange(sq, device=q.device).view(-1, 1)
    col_offset = torch.arange(sk, device=k.device).view(1, -1)
    attention_mask = row_offset >= col_offset
    qk = (q @ k.transpose(-1, -2)) * qk_scale
    qk = qk.masked_fill(~attention_mask, float("-inf"))
    m = qk.max(dim=-1, keepdim=True).values
    weight = torch.exp(qk - m)
    omega = weight.sum(dim=-1, keepdim=True)
    tilde_mu = torch.einsum("bij,bjd->bid", weight, k)
    mu = tilde_mu - omega * q
    q = q.contiguous().view(batch_size, sq, dim//lla_block_size, lla_block_size)
    k = k.contiguous().view(batch_size, sk, dim//lla_block_size, lla_block_size)
    mu = mu.contiguous().view(batch_size, sq, dim//lla_block_size, lla_block_size)
    tilde_mu = tilde_mu.contiguous().view(batch_size, sq, dim//lla_block_size, lla_block_size)
    tilde_sigma = torch.einsum("bij,bjtd,bjte->bitde", weight, k, k)
    A = torch.einsum("bitd,bite->bitde", tilde_mu, q)
    B = torch.einsum("bitd,bite->bitde", q, tilde_mu)
    C = omega.unsqueeze(-1).unsqueeze(-1) * torch.einsum("bitd,bite->bitde", q, q)
    sigma = tilde_sigma - A - B + C
    eye = torch.eye(lla_block_size, device=q.device, dtype=q.dtype).view(1, 1, lla_block_size, lla_block_size)

    if isinstance(ridge_lambda, torch.Tensor):
        if ridge_lambda.ndim == 1:
            ridge_lambda = ridge_lambda.expand(batch_size, sq).unsqueeze(-1)

    sigma = sigma + eye * (ridge_lambda * omega).unsqueeze(-1).unsqueeze(-1)
    rho = torch.linalg.solve(sigma, mu)
    rho = rho.contiguous().view(batch_size, sq, dim)
    q = q.contiguous().view(batch_size, sq, dim)
    k = k.contiguous().view(batch_size, sk, dim)
    mu = mu.contiguous().view(batch_size, sq, dim)
    denom = omega - (mu * rho).sum(dim=-1, keepdim=True)
    denom = denom + delta_eps * torch.sign(denom)
    numer = 1 - torch.einsum("bid,bjd->bij", rho, k) + (rho * q).sum(dim=-1, keepdim=True)
    S = weight * numer / denom
    out = torch.einsum("bij,bjd->bid", S, v)
    return out, rho, denom, m
