import torch
import triton
import triton.language as tl
from prng import noise_4bit_unpack_fp, rand_4bit_packed


def get_bwd_bit_configs():
    """Return the autotuning search space for ``linear_bwd_bit_kernel``.

    Sweeps ``BM`` (row blocks per tile) over the powers of two 1..32 and
    keeps ``BN`` (column blocks per program) fixed at 1.  Configs are
    ordered by ``BM`` first, then ``BN``.
    """
    row_block_choices = (1, 2, 4, 8, 16, 32)
    col_block_choices = (1,)
    return [
        triton.Config({"BM": rows, "BN": cols})
        for rows in row_block_choices
        for cols in col_block_choices
    ]


@triton.autotune(
    configs=get_bwd_bit_configs(),
    key=["M", "N", "block_size"],
)
@triton.jit
def linear_bwd_bit_kernel(
    # pointer i/o
    g_grad_out,
    g_noise,
    g_out_a,
    g_out_b,
    g_grad_bit,
    # pointer info
    M: tl.constexpr,
    N: tl.constexpr,
    block_size: tl.constexpr,
    # hyperparam
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """
    Gradient w.r.t. the per-column-block bit widths ``out_b``.

    Each program owns ``BN`` column blocks (``BN * block_size`` columns of
    ``grad_out``). It sweeps all ``M`` rows in tiles of ``BM * block_size``
    rows, so every output element is written by exactly one program.

    - input
      - grad_out in (M, N), row-major contiguous,
        tiled as (BM * block_size, BN * block_size)
      - noise: packed rows of cdiv(N, 8) elements per grad_out row; each
        element is unpacked into 8 columns by ``noise_4bit_unpack_fp``
      - out_a in (cdiv(M, block_size), cdiv(N, block_size)), row-major,
        tiled as (BM, BN)
      - out_b in (cdiv(N, block_size), ) tiled as (BN, )
    - output
      - grad_bit in (cdiv(N, block_size), )
    - compute
      - blockwise = sum_block(grad_out * rn)  # (BM, BN)
      - batchwise = sum_batch(alpha * blockwise)  # (BN, )
      - result = -ln2 * 2**(1-bit) * batchwise  # (BN, )
    - constraints
      - BM * block_size and BN * block_size must be powers of two (tl.arange)
      - BN * block_size must be a multiple of 8 (noise packing)
    - numerics
      - all reductions and the exp2 are done in fp32; the result is cast to
        grad_bit's element type on store
    """
    pid = tl.program_id(axis=0)

    # Row-major strides: the kernel assumes dense, contiguous inputs.
    stride_m = N
    stride_n = 1
    n_blocks_m = tl.cdiv(M, block_size)
    n_blocks_n = tl.cdiv(N, block_size)
    stride_b_m = n_blocks_n
    stride_b_n = 1
    # Noise is packed 8 columns per element, one packed row per grad_out row.
    n_noise_cols = tl.cdiv(N, 8)
    stride_noise_m = n_noise_cols
    stride_noise_n = 1

    # Column offsets owned by this program; fixed for the whole row sweep.
    offs_n = pid * BN * block_size + tl.arange(0, BN * block_size)
    offs_b_n = pid * BN + tl.arange(0, BN)
    offs_noise_n = pid * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )

    mask_n = offs_n[None, :] < N
    mask_b_n = offs_b_n[None, :] < n_blocks_n
    mask_noise_n = offs_noise_n[None, :] < n_noise_cols

    accum = tl.zeros((BN,), dtype=tl.float32)
    for i in tl.range(0, tl.cdiv(M, BM * block_size)):
        # Row offsets of this tile (shared by grad_out and the packed noise).
        offs_m = i * BM * block_size + tl.arange(0, BM * block_size)
        mask_m = offs_m[:, None] < M

        # load grad_out; upcast so fp16/bf16 gradients reduce in fp32
        g_grad_out_ptrs = g_grad_out + (
            offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        )
        grad_out = tl.load(g_grad_out_ptrs, mask=mask_m & mask_n, other=0.0)
        grad_out = grad_out.to(tl.float32)

        # load packed noise for the same rows
        noise_ptrs = g_noise + (
            offs_m[:, None] * stride_noise_m
            + offs_noise_n[None, :] * stride_noise_n
        )
        packed = tl.load(noise_ptrs, mask=mask_m & mask_noise_n, other=0)

        # unpack as (BM * block_size, BN * block_size)
        noise = noise_4bit_unpack_fp(
            packed, BM * block_size, BN * block_size // 8
        )

        # Reduce over each (block_size x block_size) block -> (BM, BN).
        # can_reorder=False keeps row-major element order so the 4-D view
        # lines up with the block tiling.
        blockwise = tl.sum(
            tl.sum(
                tl.reshape(
                    grad_out * noise,
                    (BM, block_size, BN, block_size),
                    can_reorder=False,
                ),
                axis=1,
            ),
            axis=-1,
        )

        # load alpha: the matching (BM, BN) tile of out_a, upcast to fp32
        offs_b_m = i * BM + tl.arange(0, BM)
        g_out_a_ptrs = g_out_a + (
            offs_b_m[:, None] * stride_b_m + offs_b_n[None, :] * stride_b_n
        )
        mask_b_m = offs_b_m[:, None] < n_blocks_m
        out_a = tl.load(g_out_a_ptrs, mask=mask_b_m & mask_b_n, other=0.0)
        out_a = out_a.to(tl.float32)

        # reduce gradient over batch
        accum += tl.sum(blockwise * out_a, axis=0)  # (BN,)

    # load bit: 1-D (BN,) slice of out_b, which is (cdiv(N, block_size),)
    mask_bit_n = offs_b_n < n_blocks_n
    out_b = tl.load(g_out_b + offs_b_n, mask=mask_bit_n, other=0.0)
    # tl.exp2 only accepts fp32/fp64; upcast so fp16/bf16 bits compile.
    out_b = out_b.to(tl.float32)
    neg_ln2: tl.constexpr = -0.693147182464599609375

    grad_bit = accum * tl.exp2(1 - out_b) * neg_ln2

    # store grad_bit (tl.store casts to the pointer's element type)
    tl.store(g_grad_bit + offs_b_n, grad_bit, mask=mask_bit_n)


def linear_bwd_bit(grad_out, out_a, out_b, seed, block_size):
    """Compute the gradient of the loss w.r.t. the bit widths ``out_b``.

    Args:
        grad_out: upstream gradient. All leading dims are flattened, giving
            an (M, N) matrix.
        out_a: (cdiv(M, block_size), cdiv(N, block_size)) per-block factor
            (``alpha`` in the kernel docstring).
        out_b: (cdiv(N, block_size),) per-column-block bit widths.
        seed: seed passed to ``rand_4bit_packed`` to generate the noise.
        block_size: quantization block side length. Must be a power of two
            and at least 8: ``tl.arange`` needs power-of-two extents and the
            noise is packed 8 columns per element.

    Returns:
        A tensor with the shape and dtype of ``out_b`` holding the gradient.

    Raises:
        AssertionError: if the shapes or ``block_size`` are inconsistent.
    """
    assert len(out_b.shape) == 1
    assert block_size >= 8 and (block_size & (block_size - 1)) == 0, (
        "block_size must be a power of two >= 8"
    )

    # The kernel indexes with hard-coded row-major strides, so every input
    # must be dense. .contiguous() is a no-op for already-contiguous tensors.
    grad_out_flatten = grad_out.flatten(0, -2).contiguous()
    out_a = out_a.contiguous()
    out_b = out_b.contiguous()

    M, N = grad_out_flatten.shape
    n_blocks_m, n_blocks_n = out_a.shape
    assert triton.cdiv(M, block_size) == n_blocks_m
    assert triton.cdiv(N, block_size) == n_blocks_n
    # The kernel reads out_b and writes grad_b at cdiv(N, block_size) slots.
    assert out_b.shape[0] == n_blocks_n

    # generate noise here: one packed row of cdiv(N, 8) elements per row
    noise = rand_4bit_packed(M * triton.cdiv(N, 8), seed)

    grid = lambda meta: (triton.cdiv(N, meta["BN"] * block_size),)

    grad_b = torch.empty_like(out_b)
    linear_bwd_bit_kernel[grid](
        grad_out_flatten, noise, out_a, out_b, grad_b, M, N, block_size
    )
    # Drop the noise buffer before returning to lower peak memory.
    del noise
    return grad_b


if __name__ == "__main__":
    # TODO: compare linear_bwd_bit against a reference PyTorch
    # implementation; until that exists, running this module fails loudly.
    print("integrity check ...")
    raise NotImplementedError
