import torch
import triton
from triton import language as tl

# Module-level side effects applied on import: disable TF32 for CUDA matmuls and cuDNN and request
# 'highest' float32 matmul precision, i.e. keep float32 math at full float32 precision.
# NOTE(review): no matmul / cuDNN op is visible in this module; these presumably protect float32
# numerics of code that imports it — confirm before removing. Statement order is kept as-is because
# the precision setting and the allow_tf32 flag share backend state.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')


@triton.jit
def tl_pow(x, a):
    """
    Elementwise |x| ** a, computed as exp(a * log(|x|)).

    The sign of x is discarded (the result is always >= 0), which is what the |error| ** norm loss needs.
    For x == 0 the log is -inf, so the result is 0 when a > 0 (and NaN when a == 0, since -inf * 0 is NaN).
    """
    return (x.abs().log() * a).exp()  # TODO: triton does not have x.pow(a) or x ** a?


@triton.jit
def tl_round(x):
    """
    Elementwise round-half-to-even, the same rule as torch.round() used by quantize().

    floor(x + .5) is not used because it rounds ties upward (2.5 -> 3, -1.5 -> -1) and because x + .5
    can itself round up (0.49999997 + .5 == 1. in fp32); either would make the Triton kernel disagree
    with the PyTorch reference implementation.
    """
    r = x.floor()
    # fractional part in [0, 1]; exact except for tiny negative x, where it can only round up to 1
    frac = x - r
    r_is_odd = (r - 2. * (r * .5).floor()) == 1.
    # round up above the midpoint, and exactly at the midpoint only when that lands on an even integer
    round_up = (frac > .5) | ((frac == .5) & r_is_odd)
    return tl.where(round_up, r + 1., r)


@triton.jit
def tl_round_fp(x, dtype):
    """
    Triton counterpart of round_fp(): round x to the precision of `dtype` and cast it back to x's dtype.

    The downcast uses round-to-nearest-even ('rtne'). When dtype is None, x is returned unchanged.
    """
    return x if dtype is None else x.cast(dtype, fp_downcast_rounding='rtne').cast(x.dtype)


@triton.jit
def tl_quantize(x, scale, qzero, maxq=None):
    """
    Triton counterpart of quantize(): map x to integer levels tl_round(x / scale + qzero).

    When maxq is given, the levels are clamped to [0, maxq]; with maxq None they are left unclamped.
    """
    x = tl_round(x / scale + qzero)
    return x if maxq is None else tl.clamp(x, 0., maxq)


@triton.jit
def tl_dequantize(qx, scale, qzero, dtype):
    """
    Triton counterpart of dequantize(): map integer levels back to values, (qx - qzero) * scale.

    The result is rounded to the precision of `dtype` via tl_round_fp (no rounding when dtype is None).
    """
    return tl_round_fp((qx - qzero) * scale, dtype)


@triton.jit
def tl_dequantize_quantized(x, scale, qzero, maxq, dtype):
    """
    Triton counterpart of dequantize_quantized(): quantize x, then dequantize it again.

    The result is the value x is represented by on the quantization grid defined by scale / qzero / maxq.
    """
    return tl_dequantize(tl_quantize(x, scale, qzero, maxq), scale, qzero, dtype)


def round_fp(x: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Round x to the precision of `dtype`, keeping x's own dtype for the result.

    x: tensor to round
    dtype: dtype to round through (e.g. fp16 / bf16); None returns x itself, untouched
    """
    if dtype is None:
        return x
    source_dtype = x.dtype
    return x.to(dtype=dtype).to(source_dtype)


def quantize(x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor = None) -> torch.Tensor:
    """
    Map x to integer quantization levels round(x / scale + qzero) (torch.round: half to even).

    maxq: when given, the levels are clamped to [0, maxq]; None leaves them unclamped.
    """
    levels = torch.round(x / scale + qzero)
    if maxq is not None:
        levels = torch.clamp(levels, torch.zeros_like(maxq), maxq)
    return levels


def dequantize(qx: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Map integer levels back to values, (qx - qzero) * scale, rounded to `dtype` via round_fp.
    """
    centered = qx - qzero
    return round_fp(centered * scale, dtype)


def dequantize_quantized(x: torch.Tensor, scale: torch.Tensor, qzero: torch.Tensor, maxq: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Quantize x and dequantize it again, i.e. return the value x is represented by on the quantization grid.
    """
    levels = quantize(x, scale, qzero, maxq)
    return dequantize(levels, scale, qzero, dtype)


def find_quantization_meta(
        x: torch.Tensor,
        bit_width: int | str,
        symmetric: bool = False,
        dtype: torch.dtype = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Find quantization metadata over dim=-1
    x: (..., C), weight
    bit_width: int number of bits, or the string 'ternary' for the three levels {-scale, 0, scale}
    symmetric: bool, whether to set qzero to the middle of the integer range (not used for 'ternary')
    dtype: torch.dtype, target scale dtype, fp16 or bf16; None leaves the scale unrounded in x.dtype
    returns: (scale, qzero, maxq); scale and qzero have shape (...), maxq is a 0-dim tensor in x.dtype
    """
    x_dtype, device = x.dtype, x.device
    epsilon: float = 1e-12
    # Smallest positive value representable in `dtype` (its smallest subnormal). 1e-12 underflows to 0
    # in fp16, so without this floor an all-zero row would get scale 0 and quantization would divide by 0.
    min_scale: float = 0. if dtype is None else torch.finfo(dtype).tiny * torch.finfo(dtype).eps

    def round_scale(raw_scale: torch.Tensor) -> torch.Tensor:
        # add epsilon, round to the target dtype, then make sure rounding did not flush the scale to 0
        rounded = round_fp(raw_scale + epsilon, dtype)
        return rounded if dtype is None else rounded.clamp_min(min_scale)

    if bit_width == 'ternary':
        maxq: torch.Tensor = torch.tensor(2, dtype=x_dtype, device=device)  # ()
        scale: torch.Tensor = round_scale(x.abs().amax(dim=-1))  # (...)
        qzero: torch.Tensor = torch.ones_like(scale)  # (...)
    else:
        maxq: torch.Tensor = torch.tensor(2 ** bit_width - 1, dtype=x_dtype, device=device)  # ()
        if symmetric:
            # qzero = (maxq + 1) / 2 leaves (maxq - 1) / 2 levels above it and (maxq + 1) / 2 below it,
            # so the scale must cover both amax / ((maxq - 1) / 2) and -amin / ((maxq + 1) / 2).
            scale_pos: torch.Tensor = x.amax(dim=-1) * (2. / (maxq - 1.))  # (...)
            scale_neg: torch.Tensor = -x.amin(dim=-1) * (2. / (maxq + 1.))  # (...)
            scale: torch.Tensor = round_scale(torch.maximum(scale_pos, scale_neg))  # (...)
            qzero: torch.Tensor = torch.full_like(scale, ((maxq + 1.) * .5).item())  # (...)
        else:
            # the range [x_min, x_max] always contains 0, so with an integer qzero 0 is exactly representable
            x_max: torch.Tensor = x.amax(dim=-1).relu()  # (...)
            x_min: torch.Tensor = -(-x.amin(dim=-1)).relu()  # (...)
            scale: torch.Tensor = round_scale((x_max - x_min) / maxq)  # (...)
            # rounding scale down to `dtype` can push -x_min / scale just above maxq; keep qzero a valid level
            qzero: torch.Tensor = torch.minimum((-x_min / scale).round(), maxq)  # (...)
    return scale, qzero, maxq


@triton.jit
def mse_scale_triton_kernel(
        x_ptr,
        p_ptr,
        scale_ptr,
        qzero_ptr,
        maxq_ptr,
        dtype_ptr,
        norm: float,
        p_size: int,
        group_size: int,
        batch_size: int,
        BLOCK_SIZE_P: tl.constexpr,
        BLOCK_SIZE_G: tl.constexpr,
        BLOCK_SIZE_B: tl.constexpr,
):
    """
    For each row of x (BLOCK_SIZE_B rows per program), try every candidate scale scale * p[j], quantize and
    dequantize the row with it, and overwrite scale in place with the candidate minimizing
    sum(|dequantized - x| ** norm). Ties go to the lowest j, like torch.argmin in mse_scale_baseline.

    x_ptr: (batch_size, group_size), row-major (the caller passes a contiguous tensor)
    p_ptr: (p_size), shrinkage factors
    scale_ptr: (batch_size), initial scales, overwritten with the chosen ones
    qzero_ptr: (batch_size), zero points
    maxq_ptr: scalar, largest quantized level
    dtype_ptr: None, or a tensor whose element type is the dtype scales are rounded to (never dereferenced)
    """
    pid = tl.program_id(axis=0)
    b_offsets = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)  # (R)
    b_mask = b_offsets < batch_size  # (R)
    g_offsets = tl.arange(0, BLOCK_SIZE_G)  # (C)
    x_offsets = b_offsets[:, None] * group_size + g_offsets[None, :]  # (R, C)
    x_mask = b_mask[:, None] & (g_offsets[None, :] < group_size)  # (R, C)
    p_offsets = tl.arange(0, BLOCK_SIZE_P)  # (P)
    p_mask = p_offsets < p_size  # (P)
    scale_ptrs = scale_ptr + b_offsets  # (R)

    # `other` gives padded lanes well-defined values (masked loads are otherwise undefined)
    x = tl.load(x_ptr + x_offsets, mask=x_mask, other=0.)[:, None, :]  # (R, 1, C)
    p = tl.load(p_ptr + p_offsets, mask=p_mask, other=1.)  # (P)
    scale = tl.load(scale_ptrs, mask=b_mask, other=1.)  # (R)
    qzero = tl.load(qzero_ptr + b_offsets, mask=b_mask, other=0.)[:, None, None]  # (R, 1, 1)
    maxq = tl.load(maxq_ptr)  # ()
    # only the pointer's element type is needed, so the (empty) buffer is never read
    dtype = None if dtype_ptr is None else dtype_ptr.dtype.element_ty

    scale_p = tl_round_fp(scale[:, None] * p, dtype)[:, :, None]  # (R, P, 1)
    q = tl_dequantize_quantized(x, scale_p, qzero, maxq, dtype)  # (R, P, C)
    # padded columns must not contribute to the loss, and padded candidates must never be selected
    err = tl.where(x_mask[:, None, :], tl_pow(q - x, norm), 0.)  # (R, P, C)
    loss = tl.where(p_mask[None, :], tl.sum(err, axis=-1), float('inf'))  # (R, P)
    best_idx = tl.argmin(loss, axis=-1, tie_break_left=True)  # (R)

    # guard the gather so a pathological (e.g. NaN) loss can never read past the end of p
    best_p = tl.load(p_ptr + best_idx, mask=b_mask & (best_idx < p_size), other=1.)  # (R)  # TODO: replace with tl.gather()
    scale = tl_round_fp(scale * best_p, dtype)  # (R)
    tl.store(scale_ptrs, scale, mask=b_mask)  # (R)


def mse_scale(
        x: torch.Tensor,
        p: torch.Tensor,
        scale: torch.Tensor,
        qzero: torch.Tensor,
        maxq: torch.Tensor,
        dtype: torch.dtype = None,
        norm: float = 2.,
        debug_mode: bool = False,
) -> torch.Tensor:
    """
    Find the optimal scale for quantization with respect to the MSE loss
    x: (..., C), weight
    p: (P), shrinkage factors
    scale: (...), initial scale, modified in-place and returned
    qzero: (...), zero points
    maxq: ()
    dtype: torch.dtype, target scale dtype, fp16 or bf16
    norm: float, norm for the loss
    debug_mode: bool, whether to use the baseline implementation without Triton
    All tensors are expected to be contiguous and on x's CUDA device when debug_mode is False.
    """
    if debug_mode:
        return mse_scale_baseline(x, p, scale, qzero, maxq, dtype, norm)

    assert x.is_contiguous() and p.is_contiguous() and scale.is_contiguous() and qzero.is_contiguous() and maxq.is_contiguous()
    # torch.Size.numel() returns a Python int, and 1 for a 1-D x (empty leading shape)
    batch_size: int = x.shape[:-1].numel()
    if batch_size == 0:
        return scale
    group_size: int = x.size(-1)
    p_size: int = p.size(-1)
    # only the dtype of this empty tensor matters: the kernel reads its pointer's element type
    dtype_tensor = None if dtype is None else torch.empty(0, dtype=dtype, device=x.device)
    grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE_B']), )
    # the context manager restores the previous current device even if the launch raises
    with torch.cuda.device(x.device):
        mse_scale_triton_kernel[grid](
            x,
            p,
            scale,
            qzero,
            maxq,
            dtype_tensor,
            norm,
            p_size,
            group_size,
            batch_size,
            BLOCK_SIZE_P=triton.next_power_of_2(p_size),
            BLOCK_SIZE_G=triton.next_power_of_2(group_size),
            BLOCK_SIZE_B=1,
        )
    return scale


def mse_scale_baseline(
        x: torch.Tensor,
        p: torch.Tensor,
        scale: torch.Tensor,
        qzero: torch.Tensor,
        maxq: torch.Tensor,
        dtype: torch.dtype = None,
        norm: float = 2.,
) -> torch.Tensor:
    """
    Pure-PyTorch reference for mse_scale: evaluate every candidate scale round_fp(scale * p[j]) and
    write into `scale` (in place) the candidate with the smallest sum(|dequantized - x| ** norm) per row.
    Ties go to the first candidate (torch.argmin). Returns `scale`.
    """
    candidates = round_fp(scale[..., None] * p, dtype)  # (..., P)
    rows = x[..., None, :]  # (..., 1, C)
    reconstructed = dequantize_quantized(rows, candidates[..., None], qzero[..., None, None], maxq, dtype)  # (..., P, C)
    loss = (reconstructed - rows).abs().pow(norm).sum(dim=-1)  # (..., P)
    winner = loss.argmin(dim=-1, keepdim=True)  # (..., 1)
    scale.copy_(candidates.gather(-1, winner).squeeze(-1))
    return scale


def _get_random_inputs(batch_size: int = 512, group_size: int = 128, seed: int = 0) -> tuple:
    """
    Build a seeded random test case on CUDA: a (batch_size, group_size) float32 weight, symmetric 4-bit
    bf16 quantization metadata for it, 100 shrinkage factors from 1.0 down to 0.2, and norm 2.4.
    Returns (weight, p, scale, qzero, maxq, dtype, norm) in the argument order of mse_scale.
    """
    torch.manual_seed(seed)
    device: torch.device = torch.device('cuda')
    dtype: torch.dtype = torch.bfloat16
    weight = torch.randn(batch_size, group_size, dtype=torch.float32, device=device)
    scale, qzero, maxq = find_quantization_meta(x=weight, bit_width=4, symmetric=True, dtype=dtype)
    max_shrink, n_grid = .8, 100
    shrink = torch.linspace(0., max_shrink, n_grid, dtype=torch.float32, device=device)
    p = 1. - shrink  # (P)
    norm = 2.4
    return weight, p, scale, qzero, maxq, dtype, norm


def _unit_test() -> None:
    """
    Print the mean |error| ** norm for the initial scale, the baseline-optimized scale and the
    Triton-optimized scale, then the max absolute difference between the two optimized scales.
    """
    weight, p, scale_0, qzero, maxq, dtype, norm = _get_random_inputs()

    def mean_error(chosen_scale: torch.Tensor) -> float:
        residual = dequantize_quantized(weight, chosen_scale[..., None], qzero[..., None], maxq, dtype) - weight
        return residual.abs().pow(norm).mean().item()

    baseline_scale = mse_scale_baseline(weight, p, scale_0.clone(), qzero, maxq, dtype, norm)
    triton_scale = mse_scale(weight, p, scale_0.clone(), qzero, maxq, dtype, norm)

    print(mean_error(scale_0), mean_error(baseline_scale), mean_error(triton_scale))
    print((triton_scale - baseline_scale).abs().max().item())
    print()


def _benchmark() -> None:
    """
    Time mse_scale_baseline and mse_scale with triton.testing.do_bench and print, for each, the runtimes
    in ms at quantiles (.5, .2, .8), i.e. median, 20th and 80th percentile.
    """
    quantiles = [.5, .2, .8]
    # Build the inputs once so only the scale search is timed (not the RNG and find_quantization_meta).
    weight, p, scale, qzero, maxq, dtype, norm = _get_random_inputs()

    # Both implementations overwrite `scale` in place, so every call starts from a fresh copy.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: mse_scale_baseline(weight, p, scale.clone(), qzero, maxq, dtype, norm), quantiles=quantiles)
    print(ms, min_ms, max_ms)

    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: mse_scale(weight, p, scale.clone(), qzero, maxq, dtype, norm), quantiles=quantiles)
    print(ms, min_ms, max_ms)


if __name__ == '__main__':
    # correctness comparison first, then timing
    for _entry_point in (_unit_test, _benchmark):
        _entry_point()
