import torch
import triton
import triton.language as tl
from prng import rand_4bit_packed, noise_4bit_unpack_fp


def get_configs(force_m1=False, force_n1=False):
    """
    Build the autotune search space shared by the kernels in this file.

    Every combination of num_stages in (2, 3, 4, 5), BM in (1, 2, 4) and
    BN in (1, 2, 4) is emitted, always with num_warps = 2.

    - force_m1 : restrict BM to 1
    - force_n1 : restrict BN to 1
    return
    - list of triton.Config, ordered by num_stages, then BM, then BN
    """
    warps = 2
    bm_choices = [1] if force_m1 else [1, 2, 4]
    bn_choices = [1] if force_n1 else [1, 2, 4]
    return [
        triton.Config(
            {"BM": bm, "BN": bn},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in (2, 3, 4, 5)
        for bm in bm_choices
        for bn in bn_choices
    ]


@triton.autotune(
    configs=get_configs(force_m1=True),
    key=["M", "N"],
)
@triton.jit
def nadd_bwd_bit_kernel(
    # pointers
    g_grad_out,
    g_alpha,
    g_bit,
    g_noise,
    g_grad_bit,
    # pointer info
    M: tl.constexpr,
    N: tl.constexpr,
    block_size: tl.constexpr,
    # scalar param
    is_uniform: tl.constexpr,
    # hyperparam
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """
    Per-block gradient of the bit-width (one bit value per block).

    grad_bit = -ln2 * 2**(1-bit) * alpha * sum(grad_out * rn, blockwise)

    Each program handles a tile of (BM, BN) blocks, i.e. a
    (BM * block_size, BN * block_size) tile of grad_out. Programs are
    laid out on a 1D grid, row-major over tiles.

    Memory layout implied by the indexing below:
    - g_grad_out : (M, N), row stride N
    - g_noise    : packed, (M, cdiv(N, 8)), row stride cdiv(N, 8)
    - g_alpha, g_bit, g_grad_bit :
          (cdiv(M, block_size), cdiv(N, block_size)),
          row stride cdiv(N, block_size)

    is_uniform is only forwarded to noise_4bit_unpack_fp.

    NOTE(review): tl.arange is used with BM * block_size,
    BN * block_size and BN * block_size // 8, so block_size is presumably
    a power of two and a multiple of 8 -- confirm against callers.
    """
    pid = tl.program_id(axis=0)
    tiles_per_row = tl.cdiv(N, BN * block_size)
    tile_m = pid // tiles_per_row
    tile_n = pid % tiles_per_row

    # element coordinates covered by this tile
    rows = tile_m * BM * block_size + tl.arange(0, BM * block_size)
    cols = tile_n * BN * block_size + tl.arange(0, BN * block_size)
    row_ok = rows[:, None] < M
    col_ok = cols[None, :] < N

    row_pitch = N
    grad_tile = tl.load(
        g_grad_out + (rows[:, None] * row_pitch + cols[None, :]),
        mask=row_ok & col_ok,
        other=0.0,
    )

    # noise in GMEM is packed in shape : (M, cdiv(N, 8))
    # the packed tile has the same rows and BN * block_size // 8 columns
    packed_pitch = tl.cdiv(N, 8)
    packed_cols = tile_n * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    packed_col_ok = packed_cols[None, :] < packed_pitch
    packed = tl.load(
        g_noise + (rows[:, None] * packed_pitch + packed_cols[None, :]),
        mask=row_ok & packed_col_ok,
        other=0,
    )

    # unpack as (BM * block_size, BN * block_size)
    noise = noise_4bit_unpack_fp(
        packed, BM * block_size, BN * block_size // 8, is_uniform
    )

    # sum grad_out * noise inside every (block_size, block_size) block
    split = tl.reshape(
        grad_tile * noise,
        (BM, block_size, BN, block_size),
        can_reorder=False,
    )
    per_block = tl.sum(tl.sum(split, axis=1), axis=-1)  # (BM, BN)

    # alpha, bit and grad_bit share one block-grid layout
    blk_pitch = tl.cdiv(N, block_size)
    blk_rows = tile_m * BM + tl.arange(0, BM)
    blk_cols = tile_n * BN + tl.arange(0, BN)
    blk_offs = blk_rows[:, None] * blk_pitch + blk_cols[None, :]
    blk_row_ok = blk_rows[:, None] < tl.cdiv(M, block_size)
    blk_col_ok = blk_cols[None, :] < blk_pitch
    blk_ok = blk_row_ok & blk_col_ok

    alpha = tl.load(g_alpha + blk_offs, mask=blk_ok, other=0.0)
    bit = tl.load(g_bit + blk_offs, mask=blk_ok, other=0.0)

    neg_ln2: tl.constexpr = -0.693147182464599609375  # -ln(2)
    pow2 = tl.exp2(1 - tl.cast(bit, dtype=tl.float32))
    grad_bit = per_block * alpha * pow2 * neg_ln2  # (BM, BN)

    # store result
    tl.store(g_grad_bit + blk_offs, grad_bit, mask=blk_ok)


@triton.autotune(
    configs=get_configs(force_n1=True),
    key=["M", "N"],
)
@triton.jit
def nadd_bwd_bit_reduce_kernel(
    # pointers
    g_grad_out,
    g_alpha,
    g_bit,
    g_noise,
    g_grad_bit,
    # pointer info
    M: tl.constexpr,
    N: tl.constexpr,
    block_size: tl.constexpr,
    # scalar param
    is_uniform: tl.constexpr,
    # hyperparam
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """
    Gradient of the bit-width when bit is shared along M (one bit value
    per column block).

    grad_bit = -ln2 * 2**(1-bit) * sum(
        alpha * sum(grad_out * rn, blockwise),
        axis=0,
    )

    grad_out in (M, N) : (BM * block_size, BN * block_size)
    rn in (M, N) : (BM * block_size, BN * block_size)
    alpha in (cdiv(M, block_size), cdiv(N, block_size)) : (BM, BN)
    bit in (cdiv(N, block_size),) : (BN,)

    accum grad_bit in (BM, BN) --(sum)-> (BN,)
    grad_bit in (cdiv(N, block_size),) : (BN,)

    Each program owns BN column blocks and loops over all row tiles of
    BM * block_size rows, accumulating in float32.

    Memory layout implied by the indexing below:
    - g_grad_out : (M, N), row stride N
    - g_noise    : packed, (M, cdiv(N, 8)), row stride cdiv(N, 8)
    - g_alpha    : row stride cdiv(N, block_size)
    - g_bit, g_grad_bit : 1D, indexed by column block

    is_uniform is only forwarded to noise_4bit_unpack_fp.
    """
    pid = tl.program_id(axis=0)

    stride_m = N
    stride_n = 1
    stride_noise_m = tl.cdiv(N, 8)
    stride_noise_n = 1
    stride_blk_m = tl.cdiv(N, block_size)
    stride_blk_n = 1

    # column offsets are fixed for the whole program
    offs_n = pid * BN * block_size + tl.arange(0, BN * block_size)
    offs_noise_n = pid * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    offs_blk_n = pid * BN + tl.arange(0, BN)

    mask_n = offs_n[None, :] < N
    mask_noise_n = offs_noise_n[None, :] < tl.cdiv(N, 8)
    mask_blk_n = offs_blk_n[None, :] < tl.cdiv(N, block_size)

    accum = tl.zeros((BN,), dtype=tl.float32)
    for i in tl.range(0, tl.cdiv(M, BM * block_size)):
        # load grad_out
        offs_m = i * BM * block_size + tl.arange(0, BM * block_size)

        grad_out_ptrs = g_grad_out + (
            offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        )
        mask_m = offs_m[:, None] < M
        grad_out = tl.load(grad_out_ptrs, mask=mask_m & mask_n, other=0.0)

        # load packed noise for the same rows, then unpack it to the
        # (BM * block_size, BN * block_size) tile shape
        noise_ptrs = g_noise + (
            offs_m[:, None] * stride_noise_m
            + offs_noise_n[None, :] * stride_noise_n
        )
        packed = tl.load(noise_ptrs, mask=mask_m & mask_noise_n, other=0)
        noise = noise_4bit_unpack_fp(
            packed, BM * block_size, BN * block_size // 8, is_uniform
        )

        # sum grad_out * noise inside every (block_size, block_size) block
        blockwise = tl.sum(
            tl.sum(
                tl.reshape(
                    grad_out * noise,
                    (BM, block_size, BN, block_size),
                    can_reorder=False,
                ),
                axis=1,
            ),
            axis=-1,
        )

        # load alpha
        offs_blk_m = i * BM + tl.arange(0, BM)
        alpha_ptrs = g_alpha + (
            offs_blk_m[:, None] * stride_blk_m
            + offs_blk_n[None, :] * stride_blk_n
        )
        mask_blk_m = offs_blk_m[:, None] < tl.cdiv(M, block_size)
        alpha = tl.load(alpha_ptrs, mask=mask_blk_m & mask_blk_n, other=0.0)

        # reduce-sum-accumulate, [BM, BN] --(sum)-> [BN,]
        accum += tl.sum(blockwise * alpha, axis=0)

    # load bit (BN,)
    bit_ptrs = g_bit + (offs_blk_n)
    mask_bit_n = offs_blk_n < tl.cdiv(N, block_size)
    bit = tl.load(bit_ptrs, mask=mask_bit_n, other=0.0)

    # Cast before exp2, as nadd_bwd_bit_kernel does: tl.exp2 only accepts
    # fp32/fp64 inputs, so a half-precision or integer bit tensor would
    # otherwise fail to compile.
    neg_ln2: tl.constexpr = -0.693147182464599609375  # -ln(2)
    grad_bit = accum * tl.exp2(1 - tl.cast(bit, dtype=tl.float32)) * neg_ln2

    # store
    grad_bit_ptrs = g_grad_bit + (offs_blk_n)
    tl.store(grad_bit_ptrs, grad_bit, mask=mask_bit_n)


def nadd_bwd_bit(grad_out, alpha, bit, seed, block_size, is_uniform):
    """
    Compute the gradient w.r.t. bit by launching one of the Triton kernels.

    input as activation or weight
    - grad_out : [N], [-, N] or [M, N] --(flatten)-> [M, N]
    - alpha : [cdiv(M, block_size), cdiv(N, block_size)]
    - bit : [cdiv(N, block_size)] or [cdiv(M, block_size), cdiv(N, block_size)]
    - seed, is_uniform : forwarded to rand_4bit_packed to generate the
      packed noise; is_uniform is also passed to the kernel
    - block_size : edge length of the square blocks sharing one alpha/bit
    return
    - grad_bit : same shape and dtype as bit (contiguous)

    When bit has fewer dims or fewer elements than alpha, the reducing
    kernel (sum over the M axis) is used; otherwise the per-block kernel.
    """
    # generalize on input shape : flatten to be 2D
    if grad_out.dim() == 1:
        grad_out = grad_out.reshape((1, -1))

    # The kernels index every tensor as dense row-major (row stride N for
    # grad_out, cdiv(N, block_size) for alpha/bit), so hand them contiguous
    # data. These calls are no-ops for inputs that are already contiguous.
    flat_grad_out = grad_out.flatten(0, -2).contiguous()
    alpha = alpha.contiguous()
    bit = bit.contiguous()

    repeat_b = len(alpha.shape) > len(bit.shape) or alpha.numel() > bit.numel()

    M, N = flat_grad_out.shape
    packed_cols = triton.cdiv(N, 8)

    # generate packed noise, viewed as (M, cdiv(N, 8)) for the kernels
    noise = rand_4bit_packed(M * packed_cols, seed, is_uniform=is_uniform)
    noise_2d = noise.view((M, packed_cols))

    grad_bit = torch.empty_like(bit)
    if repeat_b:
        # one program per tile of BN column blocks; the kernel loops over M
        grid = lambda meta: (triton.cdiv(N, meta["BN"] * block_size),)
        nadd_bwd_bit_reduce_kernel[grid](
            flat_grad_out,
            alpha,
            bit,
            noise_2d,
            grad_bit,
            M,
            N,
            block_size,
            is_uniform,
        )
    else:
        # one program per (BM, BN) tile of blocks, flattened to a 1D grid
        grid = lambda meta: (
            triton.cdiv(M, meta["BM"] * block_size)
            * triton.cdiv(N, meta["BN"] * block_size),
        )
        nadd_bwd_bit_kernel[grid](
            flat_grad_out,
            alpha,
            bit,
            noise_2d,
            grad_bit,
            M,
            N,
            block_size,
            is_uniform,
        )

    return grad_bit


if __name__ == "__main__":
    # Script entry point. Verification and benchmarking are not
    # implemented yet; only the stage banners are printed.
    # TODO verify
    # TODO benchmark
    for banner in ("verify ...", "benchmark ..."):
        print(banner)
