import torch
import triton
import triton.language as tl
from prng import rand_4bit_packed, noise_4bit_unpack_fp
from triton.language.extra import libdevice


def get_configs(force_n1=False):
    """
    Build the autotune search space for the 2-D tiled kernels.

    One ``triton.Config`` is produced for every combination of
    - num_stages in {2, 3, 4, 5}
    - num_warps in {2}
    - BM in {1, 2, 4}
    - BN in {1, 2, 4}, or only {1} when ``force_n1`` is true

    The list is ordered with num_stages varying slowest and BN fastest.
    """
    bn_choices = [1] if force_n1 else [1, 2, 4]
    return [
        triton.Config(
            {"BM": bm, "BN": bn},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in (2, 3, 4, 5)
        for warps in (2,)
        for bm in (1, 2, 4)
        for bn in bn_choices
    ]


def get_linear_configs():
    """
    Build the autotune search space for the 1-D (flat) kernels.

    One ``triton.Config`` is produced for every combination of
    - num_stages in {2, 3, 4, 5}
    - num_warps in {2}
    - BN in {128, 256, 512, 1024}

    The list is ordered with num_stages varying slowest and BN fastest.
    """
    return [
        triton.Config(
            {"BN": bn},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in (2, 3, 4, 5)
        for warps in (2,)
        for bn in (128, 256, 512, 1024)
    ]


@triton.autotune(
    configs=get_configs(force_n1=True),
    key=["M", "N"],
)
@triton.jit
def gelu_bwd_bit_kernel(
    # pointer i/o
    g_grad_out,
    g_out_alpha,
    g_out_bit,
    g_noise,
    g_grad_bit,
    # problem size (scalars, not pointers): rows and columns of grad_out
    M,
    N,
    # scalar param: edge length of one square quantization block
    block_size: tl.constexpr,
    # hyperparam: number of row / column blocks handled per tile
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """
    Accumulate the gradient w.r.t. the per-column-block ``bit`` value.

    grad_bit = -ln2 * 2**(1-bit) * sum(
        alpha * sum(grad_out * rn, blockwise),
        axis=0,
    )

    where ``rn`` is whatever noise_4bit_unpack_fp produces from g_noise.

    Each program handles BN consecutive column blocks (BN * block_size
    columns) and walks over all M rows in chunks of BM * block_size rows.

    - g_grad_out : indexed as row-major [M, N] (row stride N, col stride 1)
    - g_out_alpha : indexed with row stride cdiv(N, block_size); one value
      per (block_size x block_size) tile
    - g_out_bit : one value per column block, cdiv(N, block_size) entries
    - g_noise : indexed with row stride cdiv(N, 8); presumably eight 4-bit
      noise values are packed into each element -- TODO confirm against prng
    - g_grad_bit : output, indexed like g_out_bit
    """
    pid = tl.program_id(axis=0)

    # All strides are derived from N, i.e. every buffer is addressed as a
    # dense row-major array.
    stride_m = N
    stride_n = 1
    stride_noise_m = tl.cdiv(N, 8)
    stride_noise_n = 1
    stride_blk_m = tl.cdiv(N, block_size)
    stride_blk_n = 1

    # Column offsets in three index spaces: elements, packed noise
    # elements (8 columns each) and blocks (block_size columns each).
    offs_n = pid * BN * block_size + tl.arange(0, BN * block_size)
    offs_noise_n = pid * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    offs_blk_n = pid * BN + tl.arange(0, BN)

    # Column masks guard the ragged right edge in each index space.
    mask_n = offs_n[None, :] < N
    mask_noise_n = offs_noise_n[None, :] < tl.cdiv(N, 8)
    mask_blk_n = offs_blk_n[None, :] < tl.cdiv(N, block_size)

    # Running per-column-block sum over all row chunks, kept in fp32.
    accum = tl.zeros((BN,), dtype=tl.float32)
    for i in tl.range(0, tl.cdiv(M, BM * block_size)):
        # load grad_out tile [BM * block_size, BN * block_size];
        # out-of-range elements are read as 0 so they do not contribute
        offs_m = i * BM * block_size + tl.arange(0, BM * block_size)
        grad_out_ptrs = g_grad_out + (
            offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        )
        mask_m = offs_m[:, None] < M
        grad_out = tl.load(grad_out_ptrs, mask=mask_m & mask_n, other=0.0)

        # load packed noise for the same rows (same row offsets as offs_m)
        offs_noise_m = i * BM * block_size + tl.arange(0, BM * block_size)
        noise_ptrs = g_noise + (
            offs_noise_m[:, None] * stride_noise_m
            + offs_noise_n[None, :] * stride_noise_n
        )
        mask_noise_m = offs_noise_m[:, None] < M
        noise = tl.load(noise_ptrs, mask=mask_noise_m & mask_noise_n, other=0)
        # presumably expands each packed element into 8 floating-point
        # values so the result matches the grad_out tile -- TODO confirm
        unpacked = noise_4bit_unpack_fp(
            noise, BM * block_size, BN * block_size // 8
        )

        # blockwise sum: view the tile as (BM, block_size, BN, block_size)
        # and reduce both intra-block axes --> [BM, BN]
        blockwise = tl.sum(
            tl.sum(
                tl.reshape(
                    grad_out * unpacked,
                    (BM, block_size, BN, block_size),
                    can_reorder=False,
                ),
                axis=1,
            ),
            axis=-1,
        )

        # load alpha, one scale per tile [BM, BN]
        offs_blk_m = i * BM + tl.arange(0, BM)
        alpha_ptrs = g_out_alpha + (
            offs_blk_m[:, None] * stride_blk_m
            + offs_blk_n[None, :] * stride_blk_n
        )

        mask_blk_m = offs_blk_m[:, None] < tl.cdiv(M, block_size)
        alpha = tl.load(alpha_ptrs, mask=mask_blk_m & mask_blk_n, other=0.0)

        # row-wise sum [BM, BN] --> [BN]
        accum += tl.sum(blockwise * alpha, axis=0)

    # load bit [BN]
    bit_ptrs = g_out_bit + (offs_blk_n)
    mask_bit_n = offs_blk_n < tl.cdiv(N, block_size)
    bit = tl.load(bit_ptrs, mask=mask_bit_n, other=0.0)

    # grad_bit = sum(grad_out * noise, blockwise) *\
    #   alpha * tl.exp2(1-bit) * (-ln2)
    # The factors that depend only on the column block are applied once
    # here, after the reduction over rows.
    neg_ln2: tl.constexpr = -0.693147182464599609375
    grad_bit = accum * tl.exp2(1 - bit) * neg_ln2

    # store grad_bit
    grad_bit_ptrs = g_grad_bit + (offs_blk_n)
    tl.store(grad_bit_ptrs, grad_bit, mask=mask_bit_n)
    return


def gelu_bwd_bit(grad_out, out_alpha, out_bit, seed, block_size):
    """
    Launch gelu_bwd_bit_kernel to compute the gradient w.r.t. ``out_bit``.

    always activation
    - grad_out : [-, N] --(flatten)-> [M, N]
    - out_alpha : [cdiv(M, block_size), cdiv(N, block_size)]
    - out_bit : [cdiv(N, block_size)]
    - seed : forwarded unchanged to rand_4bit_packed
    - block_size : edge length of one square block (kernel constexpr)
    returns
    - grad_bit : same as out_bit

    Non-contiguous inputs are copied to contiguous memory first, because
    the kernel derives every stride from N and would otherwise read the
    wrong elements.
    """
    if len(grad_out.shape) == 1:
        grad_out = grad_out.view((1, -1))
    # flatten() is a no-op for an already 2-D tensor, so a transposed or
    # sliced grad_out would keep its strides; force row-major layout.
    flat = grad_out.flatten(0, -2).contiguous()
    out_alpha = out_alpha.contiguous()
    out_bit = out_bit.contiguous()
    M, N = flat.shape

    # one packed noise element per 8 columns of every row
    noise = rand_4bit_packed(M * triton.cdiv(N, 8), seed)

    def grid(meta):
        # one program per group of BN column blocks
        return (triton.cdiv(N, meta["BN"] * block_size),)

    grad_bit = torch.empty_like(out_bit)
    gelu_bwd_bit_kernel[grid](
        flat, out_alpha, out_bit, noise, grad_bit, M, N, block_size
    )
    return grad_bit


@triton.autotune(
    configs=get_linear_configs(),
    key=["N"],
)
@triton.jit
def gelu_bwd_x_kernel(
    g_grad_out,
    g_x,
    g_grad_x,
    N,
    BN: tl.constexpr,
):
    """
    Elementwise backward of exact (erf-based) GELU.

    grad_x = grad_out * (
        0.5 * (1 + erf(x / sqrt(2)))
        + x / sqrt(2 * pi) * exp(-x**2 * 0.5)
    )

    - grad_out : (BN,) of (N,)
    - x : (BN,) of (N,)
    - g_grad_x : output, same flat indexing as x

    All three buffers are addressed by flat offset, so they must share the
    same contiguous element order. The derivative is evaluated in fp32 and
    converted to the output element type only when stored.
    """
    pid = tl.program_id(axis=0)

    # load grad_out, g_x
    offs_n = pid * BN + tl.arange(0, BN)
    mask_n = offs_n < N
    grad_out = tl.load(g_grad_out + offs_n, mask=mask_n, other=0.0)
    x = tl.load(g_x + offs_n, mask=mask_n, other=0.0)

    # 1 / sqrt(2) in fp32
    inv_sqrt2: tl.constexpr = 0.707106769084930419922
    # 1 / sqrt(2 * pi) in fp32
    inv_sqrt2pi: tl.constexpr = 0.398942291736602783203

    x32 = tl.cast(x, dtype=tl.float32)
    # standard normal CDF and PDF at x
    cdf = 0.5 * (1.0 + libdevice.erf(x32 * inv_sqrt2))
    pdf = inv_sqrt2pi * tl.exp(-0.5 * x32 * x32)
    # d/dx [x * cdf(x)] = cdf(x) + x * pdf(x); the 0.5 belongs to the cdf
    # term only and must not scale the x * pdf term.
    dgelu = cdf + x32 * pdf
    grad_x = tl.cast(grad_out, dtype=tl.float32) * dgelu

    # store grad_x (tl.store converts to the element type of g_grad_x)
    tl.store(g_grad_x + offs_n, grad_x, mask=mask_n)
    return


def gelu_bwd_x(grad_out, x):
    """
    Launch gelu_bwd_x_kernel to compute the GELU input gradient.

    always activation
    - grad_out : upstream gradient, same number of elements as x
    - x : input that was passed to GELU in the forward pass
    returns
    - grad_x : contiguous tensor with the shape and dtype of x

    The kernel walks all buffers by flat element offset, so both inputs are
    made contiguous first; otherwise tensors with different memory layouts
    (e.g. a transposed grad_out) would be paired up element-wise wrongly.
    """
    x = x.contiguous()
    grad_out = grad_out.contiguous()
    N = x.numel()

    def grid(meta):
        # one program per BN-element chunk of the flattened tensor
        return (triton.cdiv(N, meta["BN"]),)

    grad_x = torch.empty_like(x)
    gelu_bwd_x_kernel[grid](grad_out, x, grad_x, N)
    return grad_x


if __name__ == "__main__":
    # Script entry point. Both stages are placeholders: only the progress
    # banners are printed, no kernel is launched yet.
    print("verify ...")
    # TODO: verify the kernels against a reference implementation
    print("benchmark ...")
    # TODO: benchmark the kernels
