import torch
import triton
import triton.language as tl
from q_config import QBN
from util import get_gemm_configs, dtype_tc2tn
from torch.nn import functional as F
from prng import (
    rand_4bit_fp,
    rand_4bit_packed,
    noise_4bit_unpack_fp,
)

# weight : as-is (cast to fp16 for the gemm; no noise injected on it)
# activation : as-is (cast to fp16 for the gemm)
# gemm : (a @ w (+ bias) -> block-wise noise injection on the output)


@triton.autotune(
    configs=get_gemm_configs(bs=QBN),
    # `is_wgt_transposed` changes the weight access pattern, so it is part
    # of the tuning key: a config tuned for one layout is not reused for
    # the other
    key=["M", "N", "K", "block_size", "g_dtype", "is_wgt_transposed"],
)
@triton.jit
def linear_fwd_kernel(
    # pointer i/o
    g_act,
    g_wgt,
    g_noise,
    g_out,
    g_out_a,
    g_out_b,
    g_bias,
    # pointer shapes
    M,
    N,
    K,
    block_size: tl.constexpr,
    # scalar parameters
    g_dtype: tl.constexpr,
    is_wgt_transposed: tl.constexpr,
    # hyperparams
    BM: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    GM: tl.constexpr,
):
    """
    Fused `act @ wgt (+ bias)` with block-wise noise injection on the output.

    Each program computes one (BM * block_size, BN * block_size) output tile
    in float32. Then, for every (block_size, block_size) sub-block of it:
      - alpha = max |out| over the in-bounds elements of the sub-block
        (stored to `g_out_a`)
      - scale = alpha * 2 ** (1 - bit), with `bit` read from `g_out_b` for
        the sub-block's column block
      - out += noise * scale, noise unpacked from `g_noise`
    The result is cast to `g_dtype` and stored to `g_out`.

    - input
      - act in (M, K) row-major, tiled as (BM * block_size, BK * block_size)
      - wgt in (K, N) row-major, tiled as (BK * block_size, BN * block_size);
        (N, K) row-major instead when `is_wgt_transposed`
      - noise in (M, cdiv(N, 8)) row-major; presumably eight packed 4-bit
        values per element, unpacked by `noise_4bit_unpack_fp`
      - bias in (N,), or None
      - out_b in (1, cdiv(N, block_size)) as (1, BN)
    - output
      - out in (M, N) as (BM * block_size, BN * block_size)
      - out_a in (cdiv(M, block_size), cdiv(N, block_size)) as (BM, BN)

    The program id is decoded as a 1D grid of
    cdiv(M, BM * block_size) * cdiv(N, BN * block_size) programs.
    NOTE(review): the noise column offsets use `BN * block_size // 8`, which
    assumes `BN * block_size` is a multiple of 8 -- confirm for all configs.
    """
    # assume inputs of shape (M, K) and (K, N)
    # and thus output of shape (M, N)
    # each program handles an output tile of (BM * block_size, BN * block_size),
    # which covers BM row blocks of `block_size` rows each
    pid = tl.program_id(axis=0)
    pid_per_m = tl.cdiv(M, BM * block_size)
    pid_per_n = tl.cdiv(N, BN * block_size)

    # grouped ordering, for L2 cache optimization
    # each group handles [GM * BM * block_size, N] in column-first order

    pid_per_group = GM * pid_per_n
    group_id = pid // pid_per_group
    pid_in_group = pid % pid_per_group

    pid_off_m = group_id * GM
    group_size_m = tl.minimum(pid_per_m - pid_off_m, GM)

    pid_m = pid_off_m + (pid_in_group % group_size_m)
    pid_n = pid_in_group // group_size_m

    stride_m = K
    stride_k1 = 1
    stride_k2 = N
    stride_n = 1

    offs_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_n = pid_n * BN * block_size + tl.arange(0, BN * block_size)

    mask_m = offs_m[:, None] < M
    mask_n = offs_n[None, :] < N

    accum = tl.zeros((BM * block_size, BN * block_size), dtype=tl.float32)

    for i in tl.range(0, tl.cdiv(K, BK * block_size)):
        offs_k = i * BK * block_size + tl.arange(0, BK * block_size)

        g_act_ptrs = g_act + (
            offs_m[:, None] * stride_m + offs_k[None, :] * stride_k1
        )

        mask_k1 = offs_k[None, :] < K
        act = tl.load(g_act_ptrs, mask=mask_m & mask_k1, other=0.0)

        if is_wgt_transposed:
            # wgt in [n, k]
            g_wgt_ptrs = g_wgt + (offs_n[:, None] * K + offs_k[None, :])
            mask_n2t = offs_n[:, None] < N
            mask_k2t = offs_k[None, :] < K
            wgt = tl.trans(
                tl.load(g_wgt_ptrs, mask=mask_n2t & mask_k2t, other=0.0)
            )
        else:
            g_wgt_ptrs = g_wgt + (
                offs_k[:, None] * stride_k2 + offs_n[None, :] * stride_n
            )
            mask_k2 = offs_k[:, None] < K
            wgt = tl.load(g_wgt_ptrs, mask=mask_k2 & mask_n, other=0.0)

        act_c = tl.cast(act, dtype=tl.float16)
        wgt_c = tl.cast(wgt, dtype=tl.float16)
        accum += tl.dot(act_c, wgt_c, out_dtype=tl.float32)

    # add bias
    if g_bias is not None:
        # load bias, shaped [BN * block_size]
        offs_bias = pid_n * BN * block_size + tl.arange(0, BN * block_size)
        g_bias_ptrs = g_bias + offs_bias
        mask_bias = offs_bias < N
        bias = tl.load(g_bias_ptrs, mask=mask_bias, other=0.0)

        # add bias [BN * block_size]
        # to output [BM * block_size, BN * block_size]
        # (this also writes bias into rows >= M; they are never stored)
        accum += tl.broadcast_to(
            tl.reshape(bias, (1, BN * block_size), can_reorder=False),
            (BM * block_size, BN * block_size),
        )

    # load bit
    stride_b_m = tl.cdiv(N, block_size)
    stride_b_n = 1

    # (1, BN) from (1, cdiv(N, block_size))
    offs_b_n = pid_n * BN + tl.arange(0, BN)
    g_out_b_ptrs = g_out_b + (offs_b_n * stride_b_n)

    mask_b_n = offs_b_n < tl.cdiv(N, block_size)
    out_bit = tl.reshape(
        tl.load(g_out_b_ptrs, mask=mask_b_n, other=0.0),
        (1, BN),
        can_reorder=False,
    )

    # compute alpha
    # `alpha` is just-in-time: one absmax per (block_size, block_size)
    # sub-block of the tile.
    # Out-of-bounds elements are excluded. Rows >= M hold `bias` (the dot
    # product contributes 0 there), which would otherwise inflate the absmax
    # of a partially filled row block.
    abs_valid = tl.where(mask_m & mask_n, tl.abs(accum), 0.0)
    alpha = tl.max(
        tl.max(
            tl.reshape(
                abs_valid,
                (BM, block_size, BN, block_size),
                can_reorder=False,
            ),
            axis=1,
            keep_dims=False,
        ),
        axis=-1,
        keep_dims=False,
    )

    # store alpha
    # (BM, BN) out of (cdiv(M, block_size), cdiv(N, block_size))
    offs_b_m = pid_m * BM + tl.arange(0, BM)
    g_out_a_ptrs = g_out_a + (
        offs_b_m[:, None] * stride_b_m + offs_b_n[None, :] * stride_b_n
    )
    mask_b_m = offs_b_m[:, None] < tl.cdiv(M, block_size)
    tl.store(g_out_a_ptrs, alpha, mask=mask_b_m & mask_b_n)

    # NOTE approximated alpha / (2**(bit-1) - 1) as alpha * 2**(1-bit)
    scale = alpha * tl.broadcast_to(tl.exp2(1 - out_bit), (BM, BN))
    # expand the per-block scale back to per-element (tile-shaped)
    scale_br = tl.reshape(
        tl.broadcast_to(
            tl.reshape(scale, (BM, 1, BN, 1), can_reorder=False),
            (BM, block_size, BN, block_size),
        ),
        (BM * block_size, BN * block_size),
        can_reorder=False,
    )

    # load noise
    stride_noise_m = tl.cdiv(N, 8)
    stride_noise_n = 1
    offs_noise_m = pid_m * BM * block_size + tl.arange(0, BM * block_size)
    offs_noise_n = pid_n * BN * block_size // 8 + tl.arange(
        0, BN * block_size // 8
    )
    mask_noise_m = offs_noise_m[:, None] < M
    mask_noise_n = offs_noise_n[None, :] < tl.cdiv(N, 8)
    noise_ptrs = g_noise + (
        offs_noise_m[:, None] * stride_noise_m
        + offs_noise_n[None, :] * stride_noise_n
    )
    packed = tl.load(noise_ptrs, mask=mask_noise_m & mask_noise_n, other=0)

    # unpack as (BM * block_size, BN * block_size)
    noise = noise_4bit_unpack_fp(packed, BM * block_size, BN * block_size // 8)

    # inject noise
    accum += noise * scale_br

    # store
    stride_out_m = N
    stride_out_n = 1
    g_out_ptrs = g_out + (
        offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n
    )
    tl.store(
        g_out_ptrs,
        tl.cast(accum, dtype=g_dtype),
        mask=mask_m & mask_n,
    )
    return


def linear_fwd(
    act, wgt, out_b, bias, block_size, seed, is_wgt_transposed=False
):
    """
    Compute ``act @ wgt (+ bias)`` and inject block-wise noise into the result.

    Args:
        act: activation of shape (..., K). All leading dims are flattened
            into M rows; a 1D input is treated as a single row.
        wgt: 2D weight, (K, N), or (N, K) when `is_wgt_transposed`.
        out_b: 1D tensor of length cdiv(N, block_size). Its per-column-block
            value is used as `bit` in the noise scale alpha * 2 ** (1 - bit).
        bias: optional 1D tensor of length N, or None.
        block_size: edge length of the square output blocks over which the
            absmax alpha is taken.
        seed: seed for `rand_4bit_packed`, which generates the noise.
        is_wgt_transposed: whether `wgt` is laid out as (N, K).

    Returns:
        (out, out_a):
        - `out` has dtype `act.dtype` and shape (..., N), or (N,) for a 1D
          `act`.
        - `out_a` is float32 of shape
          (cdiv(M, block_size), cdiv(N, block_size)) and holds the per-block
          alpha written by the kernel.

    Raises:
        AssertionError: on inconsistent input shapes.
    """
    orig_shape = act.shape

    if len(orig_shape) == 1:
        act = act.view((1, *orig_shape))

    # The kernel addresses every operand as dense row-major memory, so it
    # must receive the flattened 2D activation and contiguous tensors.
    # `.contiguous()` is a no-op when the tensor already is.
    flatten_act = act.flatten(0, -2).contiguous()  # 2D (M, K)
    wgt = wgt.contiguous()
    out_b = out_b.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    assert len(wgt.shape) == 2
    assert len(out_b.shape) == 1
    if bias is not None:
        assert len(bias.shape) == 1
    M, K = flatten_act.shape
    if is_wgt_transposed:
        N, K2 = wgt.shape
    else:
        K2, N = wgt.shape
    assert K == K2, f"{K} != {K2}"
    assert triton.cdiv(N, block_size) == out_b.shape[0]
    if bias is not None:
        # the kernel reads `bias` for every valid output column
        assert bias.shape[0] == N, f"{bias.shape[0]} != {N}"

    g_out = torch.empty((M, N), dtype=act.dtype, device=act.device)
    out_a = torch.empty(
        (triton.cdiv(M, block_size), triton.cdiv(N, block_size)),
        dtype=torch.float32,
        device=act.device,
    )
    g_dtype = dtype_tc2tn(act.dtype)

    # Empty output: skip noise generation and the zero-size launch.
    if M > 0 and N > 0:
        # one program per (BM * block_size, BN * block_size) output tile;
        # noise is injected on the gemm output, not on `wgt`
        grid = lambda meta: (
            triton.cdiv(M, meta["BM"] * block_size)
            * triton.cdiv(N, meta["BN"] * block_size),
        )

        # packed noise for the gemm output, viewed as (M, cdiv(N, 8)) to
        # match the kernel's row stride of cdiv(N, 8)
        noise = rand_4bit_packed(M * triton.cdiv(N, 8), seed)

        linear_fwd_kernel[grid](
            flatten_act,
            wgt,
            noise.view((M, triton.cdiv(N, 8))),
            g_out,
            out_a,
            out_b,
            bias,
            M,
            N,
            K,
            block_size,
            g_dtype,
            is_wgt_transposed,
        )

        del noise

    if len(orig_shape) == 1:
        g_out = g_out.view(-1)
    else:
        g_out = g_out.view((*orig_shape[:-1], N))
    return g_out, out_a


if __name__ == "__main__":
    # Script entry (requires a CUDA device):
    #   1. integrity check of `linear_fwd` against a PyTorch reference
    #   2. throughput benchmark via `triton.testing`
    import math

    from prng import RomuTrio32

    romu = RomuTrio32(42)
    for _ in range(10):
        romu.next()
    torch.manual_seed(42)

    print("integrity check...")
    block_size = QBN
    num_bit = 4
    failures = []
    for M in [1, 32, 128]:
        for N in [32, 256, 1024]:
            for K in [32, 256, 1024]:
                for is_transposed in [False, True]:
                    seed = romu.next()
                    t1 = (
                        torch.randn(
                            [M, K],
                            dtype=torch.float16,
                            device="cuda",
                        )
                        * 8
                    )  # activation
                    t2 = (
                        torch.randn([K, N], dtype=torch.float16, device="cuda")
                        * 8
                    )  # weight
                    if is_transposed:
                        # store the weight as a contiguous (N, K) tensor
                        t2 = t2.T.contiguous()
                        baseline_mm = F.linear(t1, t2)
                    else:
                        baseline_mm = t1 @ t2
                    noise = rand_4bit_fp(M, N, seed)
                    NBM = triton.cdiv(M, block_size)
                    NBN = triton.cdiv(N, block_size)
                    for use_bias in [False, True]:
                        bias = None
                        ref_mm = baseline_mm
                        if use_bias:
                            bias = torch.ones(
                                N, dtype=torch.float32, device="cuda"
                            )
                            # Out-of-place, so `baseline_mm` stays bias-free
                            # regardless of loop order. The cast back matches
                            # an in-place fp16 add.
                            ref_mm = (baseline_mm + bias.view((1, -1))).to(
                                baseline_mm.dtype
                            )
                        out_b = (
                            torch.ones(
                                (triton.cdiv(N, block_size),),
                                dtype=torch.float32,
                                device="cuda",
                            )
                            * num_bit
                        )
                        # reference scale factor 2 ** (1 - bit), expanded to
                        # one value per output element
                        out_b_calc = torch.exp2(1 - out_b)
                        out_b_re = torch.zeros_like(ref_mm)
                        for m in range(NBM):
                            for n in range(NBN):
                                out_b_re[
                                    m * block_size : (m + 1) * block_size,
                                    n * block_size : (n + 1) * block_size,
                                ] = out_b_calc[n]
                        # reference per-block absmax (alpha)
                        absmax = (
                            ref_mm.abs()
                            .view(
                                (
                                    NBM,
                                    -1,  # for small M
                                    NBN,
                                    block_size,
                                )
                            )
                            .max(dim=1, keepdims=False)
                            .values.max(dim=-1, keepdims=False)
                            .values
                        )
                        alpha = torch.zeros_like(ref_mm)
                        for m in range(NBM):
                            for n in range(NBN):
                                alpha[
                                    m * block_size : (m + 1) * block_size,
                                    n * block_size : (n + 1) * block_size,
                                ] = absmax[m, n]
                        baseline = ref_mm + noise * alpha * out_b_re

                        print(
                            f"(M={M}, N={N}, K={K})",
                            end="",
                        )
                        print("+bias" if use_bias else "", end=" ")
                        print("+trans" if is_transposed else "", end=" ")
                        out, _ = linear_fwd(
                            t1,
                            t2,
                            out_b,
                            bias,
                            block_size=QBN,
                            seed=seed,
                            is_wgt_transposed=is_transposed,
                        )
                        l1_error = F.l1_loss(
                            baseline, out.to(torch.float16)
                        ).item()
                        print(f"L1 error = {l1_error}")
                        if math.isnan(l1_error) or l1_error > 0.5:
                            failures.append(
                                (M, N, K, use_bias, is_transposed)
                            )
                            print(baseline)
                            print(out)

    print(f"integrity check done: {len(failures)} failure(s)")
    for failed in failures:
        print(f"  FAILED (M, N, K, bias, trans) = {failed}")

    def get_bench_configs(plot_name="linear_fwd"):
        configs = []

        line_vals = [
            f"g{g_dtype}-t{t}" for g_dtype in ["float16"] for t in ["Yes", "No"]
        ]
        styles = [
            (color, style) for color in ["blue", "green"] for style in ["-"]
        ]
        line_names = line_vals

        configs.append(
            triton.testing.Benchmark(
                x_names=[
                    "M",
                    "N",
                    "K",
                ],  # Argument names to use as an x-axis for the plot
                x_vals=[128 * i for i in range(1, 8)]
                + [256 * i for i in range(4, 8)]
                + [512 * i for i in range(4, 8)]
                + [1024 * i for i in range(4, 9)],
                x_log=False,
                line_arg="provider",
                line_vals=line_vals,
                line_names=line_names,
                styles=styles,
                ylabel="TFLOP/s",  # Label name for the y-axis
                args={},
                plot_name=plot_name,
            )
        )
        return configs

    @triton.testing.perf_report(get_bench_configs())
    def benchmark(M, N, K, provider):
        block_size = QBN
        # provider is "g<dtype>-t<Yes|No>"
        g_part = provider.split("-")[0]
        t_part = provider.split("-")[-1]
        num_bit = 8

        print(f"(M={M}, N={N}, K={K}), {provider}")

        g_dtype = g_part[1:]
        if g_dtype == "float16":
            g_dtype = torch.float16
        elif g_dtype == "float8_e5m2":
            g_dtype = torch.float8_e5m2
        elif g_dtype == "float8_e4m3fn":
            g_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(g_part)

        is_transposed = t_part[1:]
        if is_transposed == "Yes":
            is_transposed = True
        elif is_transposed == "No":
            is_transposed = False
        else:
            raise ValueError(t_part)

        x1 = torch.randn((M, K), device="cuda", dtype=torch.float16).to(g_dtype)
        if is_transposed:
            x2 = torch.randn((N, K), device="cuda", dtype=torch.float16).to(
                g_dtype
            )
        else:
            x2 = torch.randn((K, N), device="cuda", dtype=torch.float16).to(
                g_dtype
            )

        out_b = (
            torch.ones(
                (triton.cdiv(N, block_size),),
                dtype=torch.float32,
                device="cuda",
            )
            * num_bit
        )

        bias = torch.zeros(N, dtype=torch.float32, device="cuda")

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: linear_fwd(
                x1,
                x2,
                out_b,
                bias,
                block_size=QBN,
                seed=42,
                is_wgt_transposed=is_transposed,
            ),
            quantiles=quantiles,
            rep=1000,
        )
        # 2 * M * N * K flops per gemm, reported in TFLOP/s
        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
        return perf(ms), perf(max_ms), perf(min_ms)

    benchmark.run(save_path="./")
