import torch
import triton
import triton.language as tl
from prng import rand_4bit_packed, noise_4bit_unpack_fp
from triton.language.extra import libdevice


def get_configs():
    """Return the autotuning search space for ``gelu_fwd_kernel``.

    Enumerates every combination of pipeline depth (``num_stages``), warp
    count, and tile shape (``BM`` x ``BN`` blocks per program). Stages vary
    slowest and ``BN`` fastest, matching a nested-loop enumeration.
    """
    stage_choices = (2, 3, 4, 5)
    warp_choices = (2,)
    bm_choices = (1,)
    bn_choices = (1,)
    return [
        triton.Config(
            {"BM": bm, "BN": bn},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in stage_choices
        for warps in warp_choices
        for bm in bm_choices
        for bn in bn_choices
    ]


@triton.autotune(
    configs=get_configs(),
    key=["M", "N"],
)
@triton.jit
def gelu_fwd_kernel(
    # pointer i/o
    g_x,
    g_bit,
    g_alpha,
    g_noise,
    g_out,
    # pointer param
    M,
    N,
    # scalar param
    block_size: tl.constexpr,
    # hyperparam
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """Compute ``out = gelu(x) + noise * scale`` for one tile of a row-major [M, N] input.

    Each program owns a (BM * block_size) x (BN * block_size) tile, i.e.
    BM x BN square blocks of side ``block_size``. For each block it:

    1. computes GELU(x) in its erf form;
    2. stores ``alpha = max |GELU(x)|`` over the block into ``g_alpha``,
       a row-major [cdiv(M, block_size), cdiv(N, block_size)] array;
    3. forms ``scale = alpha * 2 ** (1 - bit)``, where ``bit`` is
       ``g_bit[block column]`` (one value per block column, shared by
       every block row);
    4. adds the unpacked noise times ``scale`` and stores the result to
       ``g_out``, which has the same layout as ``g_x``.

    ``g_noise`` is read as a row-major [M, cdiv(N, 8)] array in which each
    element covers 8 consecutive columns of x. How those 8 values are decoded
    is determined by ``noise_4bit_unpack_fp``.

    ``block_size`` must be a power of two >= 8: ``tl.arange`` requires
    power-of-two extents, and the noise tile is ``block_size // 8`` wide.
    """
    pid = tl.program_id(axis=0)
    # Tiles are enumerated row-major over a
    # (cdiv(M, BM * block_size), cdiv(N, BN * block_size)) grid.
    num_pid_n = tl.cdiv(N, BN * block_size)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # ---- load x tile ----
    stride_m = N
    stride_n = 1
    offs_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_n = pid_n * BN * block_size + tl.arange(0, BN * block_size)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # 64-bit row offsets: offs_m * N overflows int32 once M * N >= 2**31.
    x_offs = offs_m.to(tl.int64)[:, None] * stride_m + offs_n[None, :] * stride_n
    x = tl.load(g_x + x_offs, mask=mask, other=0.0)

    # ---- GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) ----
    # Evaluate the whole expression in fp32 and round to the input dtype once,
    # rather than finishing the arithmetic in (possibly) fp16/bf16 after erf.
    # Masked lanes hold x = 0, so gelu = 0 there and they cannot raise alpha.
    inv_sqrt2: tl.constexpr = 0.707106769084930419922  # 1 / sqrt(2) in fp32
    x32 = tl.cast(x, dtype=tl.float32)
    gelu_x = tl.cast(
        0.5 * x32 * (1.0 + libdevice.erf(x32 * inv_sqrt2)), dtype=x.dtype
    )

    # ---- alpha: per-block max |gelu(x)|, shape (BM, BN) ----
    alpha = tl.max(
        tl.max(
            tl.abs(
                tl.reshape(
                    gelu_x, (BM, block_size, BN, block_size), can_reorder=False
                )
            ),
            axis=1,
            keep_dims=False,
        ),
        axis=-1,
        keep_dims=False,
    )

    # store alpha (required for bwd)
    stride_blk_m = tl.cdiv(N, block_size)
    stride_blk_n = 1
    offs_blk_m = pid_m * BM + tl.arange(0, BM)
    offs_blk_n = pid_n * BN + tl.arange(0, BN)
    alpha_ptrs = g_alpha + (
        offs_blk_m[:, None] * stride_blk_m + offs_blk_n[None, :] * stride_blk_n
    )
    mask_blk_m = offs_blk_m[:, None] < tl.cdiv(M, block_size)
    mask_blk_n = offs_blk_n[None, :] < tl.cdiv(N, block_size)
    tl.store(alpha_ptrs, alpha, mask_blk_m & mask_blk_n)

    # ---- bit: one value per block column, (1, BN) broadcast to (BM, BN) ----
    bit = tl.load(g_bit + offs_blk_n[None, :], mask=mask_blk_n, other=0.0)
    bit = tl.broadcast_to(bit, (BM, BN))

    # ---- scale = alpha * 2^(1 - bit), expanded to element granularity ----
    scale = alpha * tl.exp2(1 - bit)
    scale_br = tl.reshape(
        tl.broadcast_to(
            tl.reshape(scale, (BM, 1, BN, 1), can_reorder=False),
            (BM, block_size, BN, block_size),
        ),
        (BM * block_size, BN * block_size),
        can_reorder=False,
    )

    # ---- load packed noise: row-major [M, cdiv(N, 8)], 8 columns per word ----
    stride_noise_m = tl.cdiv(N, 8)
    stride_noise_n = 1
    offs_noise_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_noise_n = pid_n * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    noise_ptrs = g_noise + (
        offs_noise_m[:, None] * stride_noise_m
        + offs_noise_n[None, :] * stride_noise_n
    )
    mask_noise_m = offs_noise_m[:, None] < M
    mask_noise_n = offs_noise_n[None, :] < tl.cdiv(N, 8)
    noise = tl.load(noise_ptrs, mask=mask_noise_m & mask_noise_n, other=0)
    unpacked = noise_4bit_unpack_fp(
        noise, BM * block_size, BN * block_size // 8
    )

    # ---- store gelu(x) + blockwise-scaled noise (same layout as x) ----
    out = gelu_x + unpacked * scale_br
    tl.store(g_out + x_offs, out, mask=mask)


def gelu_fwd(x, out_b, block_size, seed):
    """Fused GELU with blockwise-scaled 4-bit noise (forward pass).

    Intended for activation tensors.

    - x : [..., N] --(flatten)-> [flat_M, N]; a 1-D x is treated as [1, N]
    - out_b : [cdiv(N, block_size)], one value per block column. The kernel
      reads it as ``bit`` in the noise scale ``alpha * 2 ** (1 - bit)``.
    - block_size : side of the square block; must be a power of two >= 8
      (a kernel tiling requirement)
    - seed : forwarded to ``rand_4bit_packed``

    returns
    - out : same shape as x (dense, row-major)
    - out_alpha : [cdiv(flat_M, block_size), cdiv(N, block_size)] in x.dtype,
      the per-block max |gelu(x)|, which the backward pass needs

    raises
    - ValueError : if ``block_size`` is not a power of two >= 8, or if
      ``out_b`` has fewer than cdiv(N, block_size) elements
    """
    if block_size < 8 or block_size & (block_size - 1):
        raise ValueError(
            f"block_size must be a power of two >= 8, got {block_size}"
        )

    orig_shape = x.shape
    # The kernel addresses x and out as dense row-major [M, N] (row stride N),
    # so pass it exactly that. Flattening may return a strided view, so force
    # a contiguous layout. This is a no-op for inputs that are already dense.
    flat_x = (x.flatten(0, -2) if x.dim() > 1 else x.unsqueeze(0)).contiguous()
    M, N = flat_x.shape
    num_blk_n = triton.cdiv(N, block_size)

    # The kernel reads out_b[j] for j < cdiv(N, block_size) with unit stride.
    out_b = out_b.contiguous()
    if out_b.numel() < num_blk_n:
        raise ValueError(
            f"out_b must have at least cdiv(N, block_size) = {num_blk_n} "
            f"elements, got {out_b.numel()}"
        )

    # flat_x is contiguous, so empty_like yields a contiguous buffer too.
    out = torch.empty_like(flat_x)
    out_alpha = torch.empty(
        (triton.cdiv(M, block_size), num_blk_n),
        dtype=x.dtype,
        device=x.device,
    )
    if M == 0 or N == 0:
        # Nothing to compute: skip noise generation and the empty launch.
        return out.view(orig_shape), out_alpha

    # generate `M * cdiv(N, 8)` u32
    noise = rand_4bit_packed(M * triton.cdiv(N, 8), seed)

    def grid(meta):
        # One program per (BM * block_size) x (BN * block_size) tile.
        return (
            triton.cdiv(M, meta["BM"] * block_size)
            * triton.cdiv(N, meta["BN"] * block_size),
        )

    gelu_fwd_kernel[grid](flat_x, out_b, out_alpha, noise, out, M, N, block_size)
    return out.view(orig_shape), out_alpha


if __name__ == "__main__":
    # Script entry point. Both stages are still placeholders:
    # TODO verify    -- compare against a reference GELU + noise implementation
    # TODO benchmark -- time gelu_fwd
    for banner in ("verify ...", "benchmark ..."):
        print(banner)
