import math
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice


@triton.jit
def get_causal_mask(
    row_offset,
    col_offset,
    ROW_TILE_SIZE: tl.constexpr,
    COL_TILE_SIZE: tl.constexpr,
):
    """Return a (ROW_TILE_SIZE, COL_TILE_SIZE) boolean causal-mask tile.

    Entry ``[i, j]`` is True when the global row index ``row_offset + i`` is
    greater than or equal to the global column index ``col_offset + j``,
    i.e. the lower-triangular (diagonal included) pattern shifted by the
    tile's offsets.
    """
    # Global row indices as a column vector, global column indices as a row
    # vector; the comparison broadcasts them to the full tile.
    rows = tl.arange(0, ROW_TILE_SIZE)[:, None] + row_offset
    cols = tl.arange(0, COL_TILE_SIZE)[None, :] + col_offset
    return rows >= cols


@triton.jit
def relmm(
    X,
    Qr,
    Kc,
):
    """Return ``X @ Kc^T - rowsum(X * Qr)`` for one tile.

    Args:
        X: 2-D tile used as the left operand of the matmul. Presumably of
            shape (M, D) — TODO confirm against callers.
        Qr: tile multiplied elementwise with ``X``; it must broadcast
            against ``X`` (e.g. (M, D) or (1, D) — assumption, verify).
        Kc: 2-D tile, transposed before the matmul (``tl.trans`` with no
            permutation swaps the two axes), presumably of shape (N, D).

    Returns:
        fp32 tile ``X @ Kc^T`` (shape (M, N)) minus the per-row sum of
        ``X * Qr`` (shape (M, 1), broadcast across the columns).

    Note:
        The matmul rounds both operands to bf16 (accumulating in fp32),
        while the row-sum term upcasts to fp32, so the two terms are
        computed at different input precisions.
    """
    # tl.dot requires both operands to have the same dtype. Cast Kc as well
    # as X so a non-bf16 Kc compiles instead of raising a dtype mismatch.
    # When Kc is already bf16 the cast is a no-op.
    XK = tl.dot(
        X.to(tl.bfloat16),
        tl.trans(Kc.to(tl.bfloat16)),
        out_dtype=tl.float32,
    )
    # The product of two fp32 tensors is already fp32, so no extra cast is
    # needed. keep_dims=True leaves an (M, 1) column that broadcasts over XK.
    XQ = tl.sum(X.to(tl.float32) * Qr.to(tl.float32), axis=1, keep_dims=True)
    return XK - XQ


@triton.jit
def update_active_mask(
    norm0,
    norm,
    cg_atol,
    cg_rtol,
):
    """Return a mask that is True where ``norm`` has not yet finished.

    An element counts as finished (mask False) when either:
      * ``norm <= max(cg_atol, norm0) * cg_rtol``, or
      * ``norm`` is NaN or infinite.

    NOTE(review): the threshold multiplies ``cg_atol`` by ``cg_rtol`` as well,
    so ``cg_atol`` acts as a floor on the reference norm ``norm0`` rather than
    as a standalone absolute tolerance. SciPy-style CG instead uses
    ``max(cg_atol, cg_rtol * norm0)``. Confirm this is intended.
    """
    threshold = tl.maximum(cg_atol, norm0) * cg_rtol
    reached_tol = norm <= threshold
    # NaN/inf norms are treated as finished too.
    bad_value = libdevice.isnan(norm) | libdevice.isinf(norm)  # type: ignore
    return ~(reached_tol | bad_value)

