import torch
import triton
import triton.language as tl
from prng import noise_4bit_unpack_fp, rand_4bit_packed


def get_configs(force_m1=False):
    """
    Build the autotuning search space for the noise-add kernel.

    Every combination of pipeline depth (num_stages in 2..5), warp count (2),
    tile height BM (in blocks) and tile width BN (in blocks, 1/2/4) becomes one
    ``triton.Config``. With ``force_m1=True`` the tile height is pinned to
    BM=1; otherwise BM ranges over 1/2/4.

    Returns a list of ``triton.Config``, ordered with num_stages varying
    slowest and BN fastest.
    """
    block_m_options = [1] if force_m1 else [1, 2, 4]
    return [
        triton.Config(
            {"BM": bm, "BN": bn},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in (2, 3, 4, 5)
        for warps in (2,)
        for bm in block_m_options
        for bn in (1, 2, 4)
    ]


@triton.autotune(
    # force_m1=True: only tile heights of BM=1 block are searched.
    configs=get_configs(force_m1=True),
    # Re-tune whenever the problem shape or the block size changes.
    key=["M", "N", "block_size"],
)
@triton.jit
def nadd_fwd_kernel(
    # pointers
    g_x,  # input x, addressed as row-major (M, N) with strides (N, 1)
    g_x_b,  # per-block bit-widths, row stride cdiv(N, block_size)
    g_noise,  # packed noise, row-major (M, cdiv(N, 8))
    g_out,  # output, same layout as g_x
    g_x_a,  # output alpha, (cdiv(M, block_size), cdiv(N, block_size))
    # pointer info
    M: tl.constexpr,
    N: tl.constexpr,
    block_size: tl.constexpr,
    # scalar param
    repeat_b: tl.constexpr,  # read only row 0 of g_x_b and share it
    is_uniform: tl.constexpr,  # forwarded to noise_4bit_unpack_fp
    # hyperparam
    BM: tl.constexpr,  # tile height, in blocks
    BN: tl.constexpr,  # tile width, in blocks
):
    """
    out = x + rn * max(|x|, blockwise) * 2**(1-bit)

    x in (M, N)
    rn in (M, N)  (read from g_noise packed as (M, cdiv(N, 8)), then unpacked)
    alpha in (cdiv(M, block_size), cdiv(N, block_size))
    bit in (cdiv(M, block_size), cdiv(N, block_size))

    Each program handles one tile of BM x BN blocks, i.e. a
    (BM * block_size, BN * block_size) slab of x. Tiles are numbered
    row-major with cdiv(N, BN * block_size) tiles per tile-row. Besides
    writing the noisy x to g_out, the per-block alpha = max|x| is written
    to g_x_a.

    NOTE(review): tl.arange needs power-of-two extents, so block_size, BM,
    BN and BN * block_size // 8 presumably must all be powers of two, with
    BN * block_size >= 8. Confirm against callers.
    """
    # map the 1-D program id onto (tile-row, tile-col), row-major
    pid = tl.program_id(axis=0)
    pid_m = pid // tl.cdiv(N, BN * block_size)
    pid_n = pid % tl.cdiv(N, BN * block_size)

    # x / out are assumed contiguous row-major
    stride_m = N
    stride_n = 1
    offs_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_n = pid_n * BN * block_size + tl.arange(0, BN * block_size)

    g_x_ptrs = g_x + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
    mask_m = offs_m[:, None] < M
    mask_n = offs_n[None, :] < N
    # out-of-range elements load as 0, which cannot raise a block's max|x|
    x = tl.load(g_x_ptrs, mask=mask_m & mask_n, other=0.0)

    # block-row indices covered by this tile (used when storing alpha)
    offs_blk_m = pid_m * BM + tl.arange(0, BM)
    if repeat_b:
        # x_b has a single block-row: read row 0 only, mask the rest
        offs_bit_m = tl.arange(0, BM)
        mask_bit_m = offs_bit_m[:, None] < 1
    else:
        offs_bit_m = pid_m * BM + tl.arange(0, BM)
        mask_bit_m = offs_bit_m[:, None] < tl.cdiv(M, block_size)
    stride_b_m = tl.cdiv(N, block_size)
    stride_b_n = 1
    offs_b_n = pid_n * BN + tl.arange(0, BN)
    g_x_b_ptrs = g_x_b + (
        offs_bit_m[:, None] * stride_b_m + offs_b_n[None, :] * stride_b_n
    )
    mask_b_n = offs_b_n[None, :] < tl.cdiv(N, block_size)
    x_b = tl.load(g_x_b_ptrs, mask=mask_bit_m & mask_b_n, other=0.0)

    if repeat_b:
        # broadcast x_b (1, BN) to (BM, BN)
        # Rows 1.. were masked to 0.0, so the max over axis 0 recovers row 0
        # for non-negative bit-widths.
        # NOTE(review): a negative x_b entry would be clamped to 0 here.
        x_b = tl.broadcast_to(tl.max(x_b, axis=0, keep_dims=True), (BM, BN))

    # calculate alpha
    # view the slab as (BM, block_size, BN, block_size) and reduce both
    # intra-block axes: alpha[i, j] = max |x| over block (i, j) -> (BM, BN)
    alpha = tl.max(
        tl.max(
            tl.abs(
                tl.reshape(
                    x, (BM, block_size, BN, block_size), can_reorder=False
                )
            ),
            axis=1,
        ),
        axis=-1,
    )

    # NOTE approximate alpha / (2**(b-1) - 1)
    # (2**(1 - b) == 1 / 2**(b-1), i.e. the "- 1" in the denominator is dropped)
    scale = alpha * tl.exp2(1 - tl.cast(x_b, dtype=tl.float32))

    # expand the per-block scale (BM, BN) back to per-element
    # (BM * block_size, BN * block_size)
    scale_br = tl.reshape(
        tl.broadcast_to(
            tl.reshape(scale, (BM, 1, BN, 1), can_reorder=False),
            (BM, block_size, BN, block_size),
        ),
        (BM * block_size, BN * block_size),
        can_reorder=False,
    )

    # load packed noise from GMEM
    # (BM * block_size, BN * block_size // 8) as part of (M, cdiv(N, 8))
    # NOTE(review): column offsets assume BN * block_size is a multiple of 8
    stride_noise_m = tl.cdiv(N, 8)
    stride_noise_n = 1
    offs_noise_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_noise_n = pid_n * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    mask_noise_m = offs_noise_m[:, None] < M
    mask_noise_n = offs_noise_n[None, :] < tl.cdiv(N, 8)
    noise_ptrs = g_noise + (
        offs_noise_m[:, None] * stride_noise_m
        + offs_noise_n[None, :] * stride_noise_n
    )
    packed = tl.load(noise_ptrs, mask=mask_noise_m & mask_noise_n, other=0)

    # unpack as (BM * block_size, BN * block_size)
    noise = noise_4bit_unpack_fp(
        packed, BM * block_size, BN * block_size // 8, is_uniform
    )

    # inject noise
    x = x + noise * scale_br

    # write in other storage
    g_out_ptrs = g_out + (
        offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    )
    tl.store(
        g_out_ptrs, tl.cast(x, g_out.dtype.element_ty), mask=mask_m & mask_n
    )

    # store `x_a`
    # alpha goes to this tile's (BM, BN) window of g_x_a, which shares x_b's
    # row stride cdiv(N, block_size)
    x_a_ptrs = g_x_a + (
        offs_blk_m[:, None] * stride_b_m + offs_b_n[None, :] * stride_b_n
    )
    mask_blk_m = offs_blk_m[:, None] < tl.cdiv(M, block_size)
    tl.store(x_a_ptrs, alpha, mask=mask_blk_m & mask_b_n)

    return


def nadd_fwd(x, x_b, block_size, seed, is_uniform):
    """
    Add blockwise-scaled packed 4-bit noise to ``x`` via ``nadd_fwd_kernel``.

    noisy_x = x + rn * alpha * 2**(1 - x_b), where alpha is max(|x|) over each
    (block_size x block_size) block of the flattened 2D view of x.

    input as activation or weight
    - x : [..., N] or [M, N] --(flatten)-> [flat_M, N]. Any strides are
      accepted; a contiguous copy is made when needed.
    - x_b : [cdiv(N, block_size)] (a single row shared by every block row) or
      [cdiv(flat_M, block_size), cdiv(N, block_size)]
    - block_size : side length of the square blocks.
      NOTE(review): the kernel uses tl.arange over extents derived from
      block_size, so this presumably must be a power of two >= 8. Confirm.
    - seed, is_uniform : forwarded to ``rand_4bit_packed``. ``is_uniform`` is
      also passed to the kernel, where it is used when unpacking the noise.
    returns
    - noisy_x : same shape as x (contiguous)
    - x_alpha : [cdiv(flat_M, block_size), cdiv(N, block_size)], dtype of x
    """
    # generalize on input shape : flatten to be 2D
    orig_shape = x.shape
    if x.dim() == 1:
        x = x.view((1, -1))
    # The kernel hard-codes row-major strides: (N, 1) for x/out and
    # (cdiv(N, block_size), 1) for x_b. Without .contiguous() a transposed
    # or sliced input would be read from the wrong addresses, and empty_like
    # would carry its non-contiguous strides over to the output.
    # .contiguous() does not copy when the tensor is already contiguous.
    flat_x = x.flatten(0, -2).contiguous()
    x_b = x_b.contiguous()

    M, N = flat_x.shape
    noisy_x = torch.empty_like(flat_x)
    x_alpha = torch.empty(
        (triton.cdiv(M, block_size), triton.cdiv(N, block_size)),
        dtype=x.dtype,
        device=x.device,
    )

    # Nothing to compute: skip noise generation and an empty-grid launch.
    if flat_x.numel() == 0:
        return noisy_x.view(orig_shape), x_alpha

    # x_b covers fewer blocks than x_alpha -> the kernel reads only its first
    # row and shares it across every block row.
    repeat_b = (
        len(x_alpha.shape) > len(x_b.shape) or x_alpha.numel() > x_b.numel()
    )

    # generate packed noise (M * cdiv(N, 8),): 8 noise values per element
    noise = rand_4bit_packed(M * triton.cdiv(N, 8), seed, is_uniform=is_uniform)

    # one program per (BM * block_size, BN * block_size) tile
    grid = lambda meta: (
        triton.cdiv(M, meta["BM"] * block_size)
        * triton.cdiv(N, meta["BN"] * block_size),
    )

    nadd_fwd_kernel[grid](
        flat_x,
        x_b,
        noise.view((M, triton.cdiv(N, 8))),
        noisy_x,
        x_alpha,
        M,
        N,
        block_size,
        repeat_b,
        is_uniform,
    )
    return noisy_x.view(orig_shape), x_alpha


if __name__ == "__main__":
    # Script entry point: both stages are still placeholders.
    # TODO verify
    # TODO benchmark
    for stage in ("verify", "benchmark"):
        print(stage + " ...")
