import triton
import triton.language as tl
import random
import numpy as np
import torch
import math, time



def quant_and_pack_kcache(k: torch.FloatTensor, group_size: int, bits: int):
    """Quantize a key cache per channel over token groups and bit-pack the codes.

    ``k`` has shape (B, nh, T, D). Along T, every run of ``group_size``
    consecutive tokens of one channel forms a group with its own minimum
    ``mn`` and step ``scale = (max - mn) / (2**bits - 1)``; each value is
    encoded as ``round(clamp((x - mn) / scale, 0, 2**bits - 1))``.

    Args:
        k: 4-D floating-point tensor (B, nh, T, D); T must be divisible by
            ``group_size``.
        group_size: number of consecutive tokens sharing one (scale, mn).
        bits: code width; ``pack_tensor`` accepts 2, 4 or 8.

    Returns:
        ``(code, scale, mn)``: ``code`` is the int32 result of ``pack_tensor``
        along dim 2, shape (B, nh, T // (32 // bits), D); ``scale`` and ``mn``
        have shape (B, nh, T // group_size, 1, D).
    """
    assert len(k.shape) == 4
    shape = k.shape
    B, nh, T, D = shape
    assert T % group_size == 0
    num_groups = T // group_size
    max_int = 2**bits - 1
    # reshape (not view) so non-contiguous inputs are accepted too.
    data = k.reshape(B, nh, num_groups, group_size, D)
    mn = torch.min(data, dim=-2, keepdim=True)[0]
    mx = torch.max(data, dim=-2, keepdim=True)[0]
    scale = (mx - mn) / max_int
    # A constant group has scale == 0 and (x - mn) / 0 = 0 / 0 = NaN, whose
    # int32 cast is undefined (garbage codes). Divide by 1 there instead: the
    # codes become 0 and, because the returned scale stays 0, they still
    # dequantize to mn.
    safe_scale = scale.masked_fill(scale == 0, 1)
    data = (data - mn).div_(safe_scale)
    data = data.clamp_(0, max_int).round_().to(torch.int32)
    code = pack_tensor(data.view(shape), bits, pack_dim=2)
    return code, scale, mn

def quant_and_pack_vcache(v: torch.FloatTensor, group_size: int, bits: int):
    """Quantize a value cache in groups along its last dim and bit-pack the codes.

    The last dimension (length L) is split into L // group_size groups of
    ``group_size`` consecutive elements. Each group has its own minimum ``mn``
    and step ``scale = (max - mn) / (2**bits - 1)``. Values are encoded as
    ``round(clamp((x - mn) / scale, 0, 2**bits - 1))``.

    Args:
        v: 4-D floating-point tensor whose last dim is divisible by
            ``group_size``.
        group_size: number of consecutive elements sharing one (scale, mn).
        bits: code width; ``pack_tensor`` accepts 2, 4 or 8.

    Returns:
        ``(code, scale, mn)``:
        - ``code`` is the int32 result of ``pack_tensor`` along dim 3, with
          the last dim shrunk by ``32 // bits``.
        - ``scale`` and ``mn`` have shape ``v.shape[:-1] + (L // group_size, 1)``.
    """
    shape = v.shape
    assert len(shape) == 4
    assert v.shape[-1] % group_size == 0
    num_groups = shape[-1] // group_size
    new_shape = shape[:-1] + (num_groups, group_size)
    max_int = 2**bits - 1
    # reshape (not view) so non-contiguous inputs are accepted too.
    data = v.reshape(new_shape)
    mn = torch.min(data, dim=-1, keepdim=True)[0]
    mx = torch.max(data, dim=-1, keepdim=True)[0]
    scale = (mx - mn) / max_int
    # Constant groups give scale == 0 and 0 / 0 = NaN, whose int32 cast is
    # undefined. Divide by 1 there; the returned scale stays 0, so those
    # zero codes still dequantize to mn.
    safe_scale = scale.masked_fill(scale == 0, 1)
    data = (data - mn).div_(safe_scale)
    data = data.clamp_(0, max_int).round_().to(torch.int32)
    code = pack_tensor(data.view(shape), bits, pack_dim=3)
    return code, scale, mn


def unpack_and_dequant_kcache(
    k_code: torch.FloatTensor,
    scale: torch.FloatTensor,
    mn: torch.FloatTensor,
    group_size: int,
    bits: int,
):
    """Undo ``quant_and_pack_kcache``: unpack along dim 2, then dequantize.

    Codes are unpacked with ``unpack_tensor(..., pack_dim=2)`` and viewed as
    (B, nh, T // group_size, group_size, D). They are cast to float16 and
    mapped back as ``code * scale + mn``, with ``scale`` / ``mn``
    broadcasting over the group axis.

    Returns:
        Tensor with the unpacked shape (B, nh, T, D). The dtype follows torch
        type promotion between float16 and ``scale`` / ``mn``.
    """
    assert bits in [2, 4, 8]
    assert len(k_code.shape) == 4
    codes = unpack_tensor(k_code, bits, pack_dim=2)
    B, nh, T, D = codes.shape
    grouped = codes.view(B, nh, T // group_size, group_size, D).to(torch.float16)
    restored = grouped * scale + mn
    return restored.view(codes.shape)


def unpack_and_dequant_vcache(
    v_code: torch.FloatTensor,
    scale: torch.FloatTensor,
    mn: torch.FloatTensor,
    group_size: int,
    bits: int,
):
    """Undo ``quant_and_pack_vcache``: unpack along dim 3, then dequantize.

    The unpacked last dimension is split into groups of ``group_size``. Codes
    are cast to float16 and mapped back as ``code * scale + mn``, with
    ``scale`` / ``mn`` broadcasting over each group.

    Returns:
        Tensor with the unpacked 4-D shape. The dtype follows torch type
        promotion between float16 and ``scale`` / ``mn``.
    """
    assert bits in [2, 4, 8]
    assert len(v_code.shape) == 4
    codes = unpack_tensor(v_code, bits, pack_dim=3)
    *leading, width = codes.shape
    grouped = codes.view(*leading, width // group_size, group_size).to(torch.float16)
    restored = grouped * scale + mn
    return restored.view(codes.shape)


def pack_tensor(data, bits, pack_dim):
    """Pack ``bits``-wide integer codes into int32 words along ``pack_dim``.

    Every ``32 // bits`` consecutive entries along ``pack_dim`` become one
    int32 word. Entry ``k`` of a run occupies bits ``[bits*k, bits*(k+1))``,
    so the first entry lands in the lowest bits. Codes are assumed to lie in
    ``[0, 2**bits - 1]``; larger values would bleed into neighbouring fields.

    Args:
        data: integer tensor of codes (any integer dtype; cast to int32).
        bits: code width, one of 2, 4, 8.
        pack_dim: dimension to pack; negative values count from the end.

    Returns:
        int32 tensor shaped like ``data`` except that ``pack_dim`` is divided
        by ``32 // bits``.

    Raises:
        AssertionError: unsupported ``bits``, floating-point ``data``, or a
            ``pack_dim`` length not divisible by ``32 // bits``.
        IndexError: ``pack_dim`` out of range.
    """
    assert bits in [2, 4, 8], "Only 2, 4, 8 bits are supported"
    assert not data.is_floating_point(), "pack_tensor expects integer codes"
    ndim = data.dim()
    if not -ndim <= pack_dim < ndim:
        raise IndexError(f"pack_dim {pack_dim} out of range for {ndim}-D tensor")
    pack_dim %= ndim
    shape = tuple(data.shape)
    feat_per_int = 32 // bits
    assert (
        shape[pack_dim] % feat_per_int == 0
    ), "Dimension length must be divisible by number of features per int"
    num_words = shape[pack_dim] // feat_per_int
    before, after = shape[:pack_dim], shape[pack_dim + 1 :]
    # Cast first so the shifts below happen in 32 bits even for narrow dtypes
    # such as uint8 (where `x << 28` would overflow to 0). Split pack_dim
    # into (word, slot) so that slot k of word w is element w * feat_per_int + k.
    grouped = data.to(torch.int32).reshape(before + (num_words, feat_per_int) + after)
    code = torch.zeros(before + (num_words,) + after, dtype=torch.int32, device=data.device)
    for slot in range(feat_per_int):
        code |= grouped.select(pack_dim + 1, slot) << (bits * slot)
    return code


def unpack_tensor(v_code: torch.Tensor, bits: int, pack_dim: int):
    """Inverse of ``pack_tensor``: expand int32 words into ``bits``-wide codes.

    Each word along ``pack_dim`` holds ``32 // bits`` codes. Unpacked
    position ``p`` reads word ``p // (32 // bits)`` at bit offset
    ``bits * (p % (32 // bits))``, so the lowest bits come first.

    Args:
        v_code: integer tensor of packed words (int32 from ``pack_tensor``).
        bits: code width, one of 2, 4, 8.
        pack_dim: packed dimension; any dimension is accepted and negative
            values count from the end.

    Returns:
        int16 tensor shaped like ``v_code`` except that ``pack_dim`` is
        multiplied by ``32 // bits``; values lie in ``[0, 2**bits - 1]``.

    Raises:
        IndexError: ``pack_dim`` out of range for ``v_code``.
    """
    assert bits in [2, 4, 8]
    ndim = v_code.dim()
    if not -ndim <= pack_dim < ndim:
        raise IndexError(f"pack_dim {pack_dim} out of range for {ndim}-D tensor")
    pack_dim %= ndim
    feat_per_int = 32 // bits
    num_codes = v_code.shape[pack_dim] * feat_per_int
    positions = torch.arange(num_codes, device=v_code.device)
    word_idx = positions // feat_per_int
    # Shape the per-position shift as (num_codes, 1, ..., 1) so it broadcasts
    # along pack_dim for any rank / pack_dim.
    shifts = ((positions % feat_per_int) * bits).view(
        (-1,) + (1,) * (ndim - pack_dim - 1)
    )
    mask = 0xFF >> (8 - bits)
    words = v_code.index_select(pack_dim, word_idx)
    # Arithmetic right shift may drag sign bits in; the mask discards them.
    return (words >> shifts).to(torch.int16) & mask

def fake_quant_channel_wise(k: torch.FloatTensor, group_size: int, bits: int):
    assert len(k.shape) == 4
    shape = k.shape
    B, nh, T, D = shape
    # ================== Get Scale & Zeros ===============
    assert T % group_size == 0
    num_groups = T // group_size
    new_shape = (B, nh, num_groups, group_size, D)
    # Quantize
    max_int = 2**bits - 1
    data = k.view(new_shape)
    mn = torch.min(data, dim=-2, keepdim=True)[0]
    mx = torch.max(data, dim=-2, keepdim=True)[0]
    scale = (mx - mn) / max_int
    scale = torch.clamp(scale, min=1e-8)
    data = data - mn
    data.div_(scale)
    data = data.clamp_(0, max_int).round_()
    data = data * scale + mn
    data = data.view(shape)
    

    return data

def fake_quant_token_wise(v: torch.FloatTensor, group_size: int, bits: int):
    shape = v.shape
    # assert len(shape) == 4
    assert v.shape[-1] % group_size == 0
    num_groups = shape[-1] // group_size
    new_shape = shape[:-1] + (num_groups, group_size)
    # Quantize
    max_int = 2**bits - 1
    data = v.view(new_shape)
    mn = torch.min(data, dim=-1, keepdim=True)[0]
    mx = torch.max(data, dim=-1, keepdim=True)[0]
    scale = (mx - mn) / max_int
    scale = torch.clamp(scale, min=1e-8)
    data = data - mn
    data.div_(scale)
    data = data.clamp_(0, max_int).round_()
    data = data * scale + mn
    data = data.view(shape)
    

    return data

def quant_and_pack_ref(v: torch.FloatTensor, group_size: int, bits: int):
    """Pure-PyTorch reference for ``triton_quantize_and_pack_along_last_dim``.

    Quantizes the last dimension (length L) in groups of ``group_size``
    consecutive elements using per-group ``mn`` and
    ``scale = (max - mn) / (2**bits - 1)``. The codes are then packed with
    ``pack_tensor(..., pack_dim=3)``.

    Args:
        v: 4-D floating-point tensor whose last dim is divisible by group_size.
        group_size: elements per quantization group.
        bits: code width; ``pack_tensor`` accepts 2, 4 or 8.

    Returns:
        ``(code, scale, mn)``:
        - ``code`` is the packed int32 tensor.
        - ``scale`` and ``mn`` have shape ``v.shape[:-1] + (L // group_size,)``;
          the trailing singleton group axis is dropped.
    """
    shape = v.shape
    assert len(shape) == 4
    assert v.shape[-1] % group_size == 0
    num_groups = shape[-1] // group_size
    new_shape = shape[:-1] + (num_groups, group_size)
    max_int = 2**bits - 1
    data = v.reshape(new_shape)
    mn = torch.min(data, dim=-1, keepdim=True)[0]
    mx = torch.max(data, dim=-1, keepdim=True)[0]
    scale = (mx - mn) / max_int
    # Constant groups give scale == 0 and 0 / 0 = NaN, whose int32 cast is
    # undefined. Divide by 1 there; the returned scale stays 0.
    safe_scale = scale.masked_fill(scale == 0, 1)
    data = (data - mn).div_(safe_scale)
    data = data.clamp_(0, max_int).round_().to(torch.int32)
    code = pack_tensor(data.view(shape), bits, pack_dim=3)
    return code, scale[..., 0], mn[..., 0]

@triton.jit
def _pack_along_last_dim(
    bits: tl.constexpr,        # width of each code in bits
    intensor_ptr,              # expects int32 buffer [N, num_feats]
    code_ptr,                  # int32 buffer [N, num_feats // feat_per_int]
    N: tl.int32,               # number of rows
    num_feats: tl.int32,       # <-- runtime; codes per row (row stride of intensor)
    feat_per_int: tl.constexpr,  # codes packed into each int32 word
    BLOCK_SIZE_N: tl.constexpr   # rows handled by one program
):
    """Pack ``feat_per_int`` consecutive codes of each row into one int32 word.

    Launched on a 2-D grid. Axis 0 tiles the N rows in blocks of
    BLOCK_SIZE_N; axis 1 selects the output word ``yid`` within each row.
    Code ``i`` of a word (input column ``yid * feat_per_int + i``) is
    shifted left by ``i * bits`` and OR-ed in, so the first code lands in the
    lowest bits. Both buffers are addressed as dense row-major matrices.
    Rows >= N are masked out on load and store.

    NOTE(review): offsets such as ``offs_N * num_feats`` appear to be computed
    in 32-bit integers; inputs with >= 2**31 elements may overflow — confirm.
    """
    num_int_per_y_dim = num_feats // feat_per_int
    bid = tl.program_id(0)
    yid = tl.program_id(1)

    # Row indices of this tile and the address of the first code of word yid.
    offs_N = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    block_start = intensor_ptr + offs_N * num_feats + yid * feat_per_int
    packed = tl.zeros((BLOCK_SIZE_N,), dtype=tl.int32)

    for i in range(feat_per_int):
        ptr = block_start + i
        element = tl.load(ptr, mask=offs_N < N, other=0)      # int other
        # An explicit element.to(tl.int32) is not needed when the input is int32.
        element = element << (i * bits)
        packed = packed | element

    tl.store(code_ptr + offs_N * num_int_per_y_dim + yid, packed, mask=offs_N < N)


@triton.jit
def _minmax_along_last_dim(
    x_ptr,
    mn_ptr, mx_ptr,
    total_elements: tl.int32,  # <-- runtime
    N: tl.int32,               # rows = B*nh*D
    num_groups: tl.int32,
    group_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    """Compute the min and max of every group of ``group_size`` consecutive values.

    ``x_ptr`` is read as a dense flat buffer of N * num_groups groups, each
    of ``group_size`` contiguous elements. Each program handles BLOCK_SIZE_N
    groups as a [BLOCK_SIZE_N, group_size] tile. It writes one min and one
    max per group to ``mn_ptr`` / ``mx_ptr`` at the group's flat index.

    ``group_size`` is the extent of ``tl.arange``, which Triton requires to
    be a power of two. Out-of-range lanes load 0.0. Assuming
    total_elements == N * num_groups * group_size (as at the host call site),
    those lanes belong only to groups whose stores are masked off, so the
    padding never reaches the outputs.

    NOTE(review): offsets appear to be 32-bit; >= 2**31 elements may
    overflow — confirm.
    """
    bid = tl.program_id(0)
    # Flat group indices handled by this program.
    offsets_b = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Element offsets: row = group, column = position inside the group.
    offsets = offsets_b[:, None] * group_size + tl.arange(0, group_size)[None, :]
    mask = offsets < total_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    mx_val = tl.max(x, axis=1)
    mn_val = tl.min(x, axis=1)
    tl.store(mn_ptr + offsets_b, mn_val, mask=offsets_b < N * num_groups)
    tl.store(mx_ptr + offsets_b, mx_val, mask=offsets_b < N * num_groups)


def triton_quantize_and_pack_along_last_dim(data: torch.Tensor, group_size: int, bit: int):
    """Asymmetric min/max quantization along the last dim, bit-packed with Triton.

    Args:
        data: 4-D tensor (B, nh, D, T) on a CUDA device. Each length-T row is
            split into T // group_size groups of consecutive elements.
        group_size: elements per group; must divide T and be a power of two
            (the min/max kernel uses ``tl.arange(0, group_size)``).
        bit: bits per code. ``32 // bit`` codes are packed into each int32
            word, so T must also be divisible by ``32 // bit``.

    Returns:
        ``(code, scale, mn, error, data_dq)``:
        - ``code``: int32 (B, nh, D, T // (32 // bit)). The code for position
          t sits in word t // (32 // bit) at bit offset bit * (t % (32 // bit)).
        - ``scale``, ``mn``: (B, nh, D, T // group_size) in ``data.dtype``.
        - ``error``: ``data - data_dq``, shape (B, nh, D, T).
        - ``data_dq``: dequantized ``q * scale + mn``, shape (B, nh, D, T).
    """
    assert len(data.shape) == 4
    shape = data.shape
    B, nh, D, T = shape
    assert T % group_size == 0
    assert group_size & (group_size - 1) == 0, "group_size must be a power of two"
    feat_per_int = 32 // bit
    # Without this, T // feat_per_int would silently drop the tail codes.
    assert T % feat_per_int == 0, "T must be divisible by 32 // bit"
    num_groups = T // group_size
    num_rows = B * nh * D
    scale_mn_shape = (B, nh, D, num_groups)
    max_int = 2 ** bit - 1
    BLOCK_SIZE_N = 128

    # ---- per-group min / max (Triton) ----
    # Both kernels use flat row-major offsets, so the buffer must be dense.
    # reshape() alone can return a strided view (e.g. for x[..., :T]).
    data = data.reshape(num_rows, num_groups, group_size).contiguous()
    mx = torch.empty((num_rows, num_groups), device=data.device, dtype=data.dtype)
    mn = torch.empty_like(mx)
    grid = lambda meta: (triton.cdiv(num_rows * num_groups, BLOCK_SIZE_N),)
    with torch.cuda.device(data.device):
        _minmax_along_last_dim[grid](
            data, mn, mx,
            data.numel(), num_rows, num_groups, group_size,
            BLOCK_SIZE_N=BLOCK_SIZE_N, num_warps=8,
        )

    # ---- quantize ----
    scale = (mx - mn) / max_int
    # Constant groups have scale == 0; divide by 1 there to avoid 0 / 0 = NaN
    # (undefined when cast to int32). The returned scale stays 0.
    safe_scale = scale.masked_fill(scale == 0, 1)
    data_q = data - mn.unsqueeze(-1)
    data_q.div_(safe_scale.unsqueeze(-1))
    data_q = data_q.clamp_(0, max_int).round_().to(torch.int32)

    # ---- dequantized tensor and quantization error ----
    data_dq = data_q.to(data.dtype) * scale.unsqueeze(-1) + mn.unsqueeze(-1)
    error = (data - data_dq).view(shape)
    data_dq = data_dq.view(shape)

    # ---- pack (Triton) ----
    data_q = data_q.view(num_rows, T)
    num_words = T // feat_per_int
    code = torch.zeros((num_rows, num_words), device=data_q.device, dtype=torch.int32)
    grid = lambda meta: (triton.cdiv(num_rows, BLOCK_SIZE_N), num_words)
    with torch.cuda.device(data_q.device):
        _pack_along_last_dim[grid](
            bit, data_q, code, num_rows,
            T, feat_per_int,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            num_warps=8,
        )
    return (
        code.view(B, nh, D, -1),
        scale.reshape(scale_mn_shape),
        mn.reshape(scale_mn_shape),
        error,
        data_dq,
    )



# ------------------------------------------------------------
# Main test runner
# ------------------------------------------------------------

@torch.no_grad()
def run_one_case_triton_qpack(
    triton_quantize_and_pack_along_last_dim,
    *,
    B=2, nh=4, D=8, T=256,
    group_size=32,
    bit=2,
    dtype=torch.float32,
    device="cuda",
    seed=0,
    do_bench=True,
):
    """Validate one configuration of a Triton quantize-and-pack implementation.

    The implementation under test is passed in explicitly (it shadows the
    module-level name), so alternative versions can reuse this harness.

    Preconditions (asserted): CUDA is available, ``T % group_size == 0`` and
    ``T % (32 // bit) == 0`` so packing has no tail.

    The input is ``randn * amp + bias`` with a per-row amplitude and bias. It
    is compared against ``quant_and_pack_ref`` run on the same device:
        1) packed int32 codes must be bit-exact;
        2) per-group ``mn`` must match within a dtype-dependent absolute tol;
        3) per-group ``scale`` must match within the same tolerance.

    The first call is timed and includes Triton JIT compilation. With
    ``do_bench`` the call is repeated and the average wall-clock time printed.

    Raises:
        AssertionError: on any precondition, shape/dtype or mismatch failure.
    """
    torch.manual_seed(seed)
    assert torch.cuda.is_available()
    feat_per_int = 32 // bit
    assert T % group_size == 0, "T must be divisible by group_size"
    assert T % feat_per_int == 0, "T must be divisible by 32//bit (packing)"
    x = torch.randn(B, nh, D, T, device=device, dtype=dtype)

    # Per-row amplitude and bias stress min/max and packing.
    amp = torch.randn(B, nh, D, 1, device=device, dtype=dtype).abs() * 3.0 + 0.5
    bias = torch.randn(B, nh, D, 1, device=device, dtype=dtype) * 0.1
    x = x * amp + bias

    # First call (includes JIT compilation of both kernels).
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    code, scale, mn, _, __ = triton_quantize_and_pack_along_last_dim(x.clone(), group_size, bit)
    torch.cuda.synchronize()
    t_first = (time.perf_counter() - t0) * 1000

    # Shapes & dtypes
    assert code.dtype == torch.int32
    assert scale.shape == (B, nh, D, T // group_size)
    assert mn.shape == scale.shape

    # Reference path (pure PyTorch, same device).
    code_ref, scale_ref, mn_ref = quant_and_pack_ref(x.clone(), group_size, bit)
    if not torch.equal(code, code_ref):
        mismatch = (code != code_ref).sum().item()
        raise AssertionError(f"Packed ints mismatch re-quantized ints: {mismatch} elements differ")
    # Small dtype-aware absolute tolerance for the float parameters.
    tol = 1e-6 if dtype == torch.float32 else 1e-3
    if not torch.allclose(mn, mn_ref, rtol=0, atol=tol):
        err = (mn - mn_ref).abs().max().item()
        raise AssertionError(f"mn mismatch vs reference: max |Δ|={err:.3e}")
    if not torch.allclose(scale, scale_ref, rtol=0, atol=tol):
        err = (scale - scale_ref).abs().max().item()
        raise AssertionError(f"scale mismatch vs reference: max |Δ|={err:.3e}")

    print(f"[OK] B={B} nh={nh} D={D} T={T} bit={bit} G={group_size} dtype={dtype}  "
          f"first-call={t_first:.2f} ms (incl. JIT)")

    # Optional quick bench (repeat the full quantize + pack call).
    if do_bench:
        iters = 50
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            triton_quantize_and_pack_along_last_dim(x, group_size, bit)
        torch.cuda.synchronize()
        ms = (time.perf_counter() - t0) * 1000 / iters
        print(f"     avg pack time over {iters} iters: {ms:.3f} ms\n")

# ------------------------------------------------------------
# Example battery
# ------------------------------------------------------------
if __name__ == "__main__":
    # Configurations exercised when this file runs as a script; each dict is
    # forwarded as keyword arguments to run_one_case_triton_qpack together
    # with this module's triton_quantize_and_pack_along_last_dim. Every T is
    # a multiple of both group_size and 32 // bit, as the harness requires.
    _CASES = (
        # Small sanity check.
        dict(B=1, nh=1, D=2, T=512, group_size=128, bit=2,
             dtype=torch.float32, seed=0, do_bench=False),
        # KV-cache-like sizes: 2**30 elements per tensor, so this case needs
        # a GPU with a lot of free memory.
        dict(B=512, nh=32, D=64, T=1024, group_size=128, bit=2,
             dtype=torch.bfloat16, seed=1),
        # Other shapes / group sizes.
        dict(B=2, nh=8, D=32, T=1024, group_size=32, bit=2,
             dtype=torch.float32, seed=2),
        dict(B=1, nh=32, D=128, T=8192, group_size=64, bit=2,
             dtype=torch.float32, seed=3),
        # bfloat16 path through the Triton min/max kernel.
        dict(B=1, nh=8, D=64, T=1024, group_size=32, bit=2,
             dtype=torch.bfloat16, seed=4, do_bench=False),
    )
    for _case in _CASES:
        run_one_case_triton_qpack(triton_quantize_and_pack_along_last_dim, **_case)
