import torch
import triton
import triton.language as tl
import triton.language.core as tlc
import triton.language.math as math
import triton.profiler as proton
from triton.language.extra import libdevice

# Presumably the round count for a Philox-style generator, annotated as
# tl.constexpr so jitted code could read it as a compile-time constant.
# NOTE(review): nothing in the visible part of this file references it —
# confirm whether it is still needed.
PHILOX_NROUND: tl.constexpr = 7


class RomuTrio32:
    """Pure-Python reference implementation of the RomuTrio32 generator.

    The generator keeps three 32-bit state words. They are derived from
    ``seed`` by three chained splitmix32 steps; a word that comes out as
    zero is replaced by 1. All arithmetic is reduced modulo 2**32.

    ``rotate_left`` and ``split_mix`` are static helpers: they can be called
    on the class (``RomuTrio32.rotate_left(x, 6)``) or on an instance.
    """

    # Width of a state word in bits.
    PRIMITIVE = 32

    @staticmethod
    def rotate_left(x, lrot):
        """Rotate the 32-bit word ``x`` left by ``lrot`` bits.

        ``lrot`` is taken modulo 32 and ``x`` is truncated to its low 32 bits
        first, so stray high bits (or a negative Python int) cannot leak
        into the result.
        """
        x &= 0xFFFF_FFFF
        lrot = lrot % RomuTrio32.PRIMITIVE
        return (x << lrot) & 0xFFFF_FFFF | (x >> (RomuTrio32.PRIMITIVE - lrot))

    @staticmethod
    def split_mix(state):
        """Advance a splitmix32 state and return ``(output, new_state)``."""
        # [splitmix32a](https://github.com/joelkp/ranoise)
        state = (state + 2654435769) & 0xFFFF_FFFF
        x = state
        x ^= x >> 15
        x = (x * 0x85EBCA6B) & 0xFFFF_FFFF
        x ^= x >> 13
        x = (x * 0xC2B2AE35) & 0xFFFF_FFFF
        x ^= x >> 16
        return (x, state)

    def __init__(self, seed):
        """Seed the three state words from the integer ``seed``."""
        x, state = RomuTrio32.split_mix(seed)
        y, state = RomuTrio32.split_mix(state)
        z, _ = RomuTrio32.split_mix(state)

        # Zero words are replaced by 1 so the state never starts at zero.
        self.state = [
            x if x != 0 else 1,
            y if y != 0 else 1,
            z if z != 0 else 1,
        ]

    def next(self, freeze=False):
        """Return the current first state word and advance the generator.

        With ``freeze=True`` the value is returned without advancing, so
        repeated calls keep yielding the same word.
        """
        if freeze:
            return self.state[0]
        x = self.state[0]
        y = self.state[1]
        z = self.state[2]

        self.state[0] = (3323815723 * z) & 0xFFFF_FFFF
        self.state[1] = RomuTrio32.rotate_left((y - x) & 0xFFFF_FFFF, 6)
        self.state[2] = RomuTrio32.rotate_left((z - y) & 0xFFFF_FFFF, 22)
        return x


def get_prng_config():
    """Build the autotune candidates for a 32x32 tile.

    One ``triton.Config`` is produced per pipeline depth (3, 4, 5), each
    using 2 warps and ``BM = BN = 32``.
    """
    stage_options = [3, 4, 5]
    warp_options = [2]
    tile_options = [(32, 32)]
    return [
        triton.Config(
            {"BM": tile_m, "BN": tile_n},
            num_warps=warps,
            num_stages=stages,
        )
        for stages in stage_options
        for warps in warp_options
        for tile_m, tile_n in tile_options
    ]


@triton.jit
def pack_4bit(a, b, c, d):
    """Combine four random words into one word of eight 4-bit noise codes.

    Every byte of the result holds two codes: the low nibble is built from
    ``a`` and ``b``, the high nibble from ``c`` and ``d``. Inside a nibble,
    bit 3 is the sign, bit 1 flags magnitude 2 and bit 0 flags magnitude 1;
    bit 2 is never set. ``one`` is cleared wherever ``two`` is set, so a
    nibble never encodes magnitude 3.

    The probabilities quoted in the comments assume independent, uniformly
    random input bits.
    NOTE(review): the masks are 32 bits wide, so the inputs are presumably
    uint32 tensors of identical shape — confirm against callers.
    """
    # pack
    # AND-reduce all eight bits of each byte of `a` into that byte's bit 0.
    a = a & (a >> 1)
    a = a & (a >> 2)
    a = a & (a >> 4)  # 0000_000x  : (x) 2**-8

    # Same reduction for `c`, but the result lands in bit 4 of each byte.
    c = c & (c >> 1)
    c = c & (c >> 2)
    c = c & (c << 4)  # 000x_0000  : (x) 2**-8

    # OR of neighbouring bit pairs of `b`, then AND of two such pairs.
    x1h = b | (b >> 1)  # -x-x_-x-x  : (x) 3/4
    x1q = x1h & (x1h >> 2)  # ---x_-x-x  : (x) 9/16

    # Same for `d`; the second step shifts left so the pairs line up for the
    # high nibble.
    x3h = d | (d >> 1)  # -x-x_-x-x  : (x) 3/4
    x3q = x3h & (x3h << 2)  # -x-x_-x--  : (x) 9/16

    # use all of a, c
    # use bit {1, 0} of b
    # use bit {5, 4} of d
    # `&` binds tighter than `|`: (low-nibble term) | (high-nibble term).
    two = a & x1h & 0x0101_0101 | c & x3h & 0x1010_1010

    # use bit {6}, {5, 4, 3, 2} of b
    # use bit {6}, {3, 2, 1, 0} of d
    one = (b >> 6) & (x1q >> 2) & 0x0101_0101 | (d >> 2) & (
        x3q << 2
    ) & 0x1010_1010

    # suppress case of `3`
    one = one & ~two

    # use bit {7} of b
    # use bit {7} of d
    # b's bit 7 moves to bit 3 (low-nibble sign); d's bit 7 stays in place
    # (high-nibble sign).
    sign = (b >> 4) & 0x0808_0808 | d & 0x8080_8080

    # `two` moves up one bit so that it occupies bit 1 of its nibble.
    packed = sign | (two << 1) | one
    return packed


@triton.jit
def uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    32-bit inputs (uint32/int32) are reinterpreted as int32 and 64-bit inputs
    (uint64/int64) as int64; any other dtype fails the static assert.
    Negative values are folded onto the non-negative range with ``-x - 1``
    and the result is multiplied by a scale chosen per bit width.
    """
    # TODO: fix frontend issues and cleanup
    # conditions can be simplified
    # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
    if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
        # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
        x = x.to(tl.int32, bitcast=True)
        scale = 4.6566127342e-10
    else:
        tl.static_assert(
            tl.constexpr(x.dtype == tl.uint64)
            or tl.constexpr(x.dtype == tl.int64)
        )
        x = x.to(tl.int64, bitcast=True)
        scale = 1.0842020432385337e-19
    # Fold the negative half onto [0, MAX_INT] so the product is non-negative.
    x = tl.where(x < 0, -x - 1, x)
    return x * scale


@triton.jit
def pair_uniform_to_normal(u1, u2):
    """Box-Muller transform: map two uniform samples to two normal samples.

    ``u1`` is clamped from below at 1e-7 before the logarithm is taken;
    ``u2`` selects the angle. Returns the (cosine, sine) pair.
    """
    radial_input = tl.maximum(1.0e-7, u1)
    radius = math.sqrt(-2.0 * math.log(radial_input))
    angle = 6.283185307179586 * u2
    first = radius * math.cos(angle)
    second = radius * math.sin(angle)
    return first, second


@triton.jit
def pack_4bit_bm(a, b, c, d, e, f, g, h):
    """Pack eight Box-Muller normal samples into one word of 4-bit codes.

    The eight random words are converted to uniforms, paired into normals,
    rounded to the nearest integer and bit-cast to uint32. From each float
    pattern the sign bit (bit 31) and exponent bits 30/29 are kept; the
    exponent bits are shifted down by one so that every value occupies one
    nibble laid out as sign (bit 3), "two" (bit 1), "one" (bit 0). Value
    ``k`` (1-based) is then shifted right by ``4 * (k - 1)`` bits, so the
    first value ends up in the top nibble.
    """
    FP32_SIGN_MASK: tl.constexpr = 0x8000_0000
    FP32_MAG_MASK: tl.constexpr = 0x6000_0000

    # uniform pairs -> normal pairs
    n1, n2 = pair_uniform_to_normal(
        uint_to_uniform_float(a), uint_to_uniform_float(b)
    )
    n3, n4 = pair_uniform_to_normal(
        uint_to_uniform_float(c), uint_to_uniform_float(d)
    )
    n5, n6 = pair_uniform_to_normal(
        uint_to_uniform_float(e), uint_to_uniform_float(f)
    )
    n7, n8 = pair_uniform_to_normal(
        uint_to_uniform_float(g), uint_to_uniform_float(h)
    )

    # round to nearest integer and reinterpret the float bits as uint32
    q1 = tl.cast(libdevice.rint(n1), tl.uint32, bitcast=True)
    q2 = tl.cast(libdevice.rint(n2), tl.uint32, bitcast=True)
    q3 = tl.cast(libdevice.rint(n3), tl.uint32, bitcast=True)
    q4 = tl.cast(libdevice.rint(n4), tl.uint32, bitcast=True)
    q5 = tl.cast(libdevice.rint(n5), tl.uint32, bitcast=True)
    q6 = tl.cast(libdevice.rint(n6), tl.uint32, bitcast=True)
    q7 = tl.cast(libdevice.rint(n7), tl.uint32, bitcast=True)
    q8 = tl.cast(libdevice.rint(n8), tl.uint32, bitcast=True)

    # one nibble per value, still sitting in the top four bits
    nib1 = (q1 & FP32_SIGN_MASK) | ((q1 & FP32_MAG_MASK) >> 1)
    nib2 = (q2 & FP32_SIGN_MASK) | ((q2 & FP32_MAG_MASK) >> 1)
    nib3 = (q3 & FP32_SIGN_MASK) | ((q3 & FP32_MAG_MASK) >> 1)
    nib4 = (q4 & FP32_SIGN_MASK) | ((q4 & FP32_MAG_MASK) >> 1)
    nib5 = (q5 & FP32_SIGN_MASK) | ((q5 & FP32_MAG_MASK) >> 1)
    nib6 = (q6 & FP32_SIGN_MASK) | ((q6 & FP32_MAG_MASK) >> 1)
    nib7 = (q7 & FP32_SIGN_MASK) | ((q7 & FP32_MAG_MASK) >> 1)
    nib8 = (q8 & FP32_SIGN_MASK) | ((q8 & FP32_MAG_MASK) >> 1)

    # slide every nibble into its own slot and merge
    packed = (
        nib1
        | (nib2 >> 4)
        | (nib3 >> 8)
        | (nib4 >> 12)
        | (nib5 >> 16)
        | (nib6 >> 20)
        | (nib7 >> 24)
        | (nib8 >> 28)
    )
    return packed


@triton.jit
def split_mix32(state):
    """Advance a splitmix32 state; return ``(mixed_output, new_state)``.

    Additions and multiplications are issued with ``sanitize_overflow=False``
    so they are allowed to wrap.
    """
    # [splitmix32a](https://github.com/joelkp/ranoise)
    INCREMENT: tl.constexpr = 0x9E3779B9
    FIRST_MUL: tl.constexpr = 0x85EBCA6B
    SECOND_MUL: tl.constexpr = 0xC2B2AE35

    advanced = tlc.add(state, INCREMENT, sanitize_overflow=False)

    mixed = advanced ^ (advanced >> 15)
    mixed = tlc.mul(mixed, FIRST_MUL, sanitize_overflow=False)
    mixed = mixed ^ (mixed >> 13)
    mixed = tlc.mul(mixed, SECOND_MUL, sanitize_overflow=False)
    mixed = mixed ^ (mixed >> 16)
    return (mixed, advanced)


@triton.jit
def init_seed_trio(seed):
    """Derive three state words from one uint32 seed.

    Three `split_mix32` steps are chained (each consumes the state returned
    by the previous one). A word that comes out as zero is bumped to 1, the
    same fix-up the host-side `RomuTrio32.__init__` applies.
    NOTE(review): the `if` tests compare whole values, so `seed` is
    presumably a scalar rather than a block — confirm against callers.
    """
    tl.static_assert(seed.dtype == tl.uint32)
    x, state = split_mix32(seed)
    y, state = split_mix32(state)
    z, _ = split_mix32(state)

    # Replace zero words by 1 (0 + 1) while keeping the value's dtype.
    if x == 0:
        x = tlc.add(x, 1, sanitize_overflow=False)
    if y == 0:
        y = tlc.add(y, 1, sanitize_overflow=False)
    if z == 0:
        z = tlc.add(z, 1, sanitize_overflow=False)
    return x, y, z


@triton.jit
def rotate_left32(x, lrot):
    """Rotate the uint32 value(s) `x` left by `lrot` bits.

    Unlike the host-side `RomuTrio32.rotate_left`, `lrot` is not reduced
    modulo 32 here (the reduction is commented out). The visible callers
    pass 6 and 22.
    NOTE(review): `lrot == 0` would shift right by 32 — presumably
    unsupported; confirm before calling with other amounts.
    """
    tl.static_assert(x.dtype == tl.uint32)
    # lrot = lrot % 32
    return (x << lrot) | (x >> tlc.sub(32, lrot, sanitize_overflow=False))


@triton.jit
def RomuTrio32_next(x, y, z, n_rounds):
    """Advance a RomuTrio32 state ``(x, y, z)`` by ``n_rounds`` steps.

    Each step computes, from the previous words:
    ``x = MUL * z``, ``y = rotl(y - x, 6)``, ``z = rotl(z - y, 22)``,
    with wrapping arithmetic. All three inputs must be uint32. Returns the
    updated ``(x, y, z)``.
    """
    ROMU_TRIO_MUL: tl.constexpr = 0xC61D672B
    tl.static_assert(x.dtype == tl.uint32)
    tl.static_assert(y.dtype == tl.uint32)
    tl.static_assert(z.dtype == tl.uint32)
    for _ in tl.range(n_rounds):
        # snapshot of the state entering this step
        prev_x = x
        prev_y = y
        prev_z = z

        x = tlc.mul(ROMU_TRIO_MUL, prev_z, sanitize_overflow=False)
        y = rotate_left32(
            tlc.sub(prev_y, prev_x, sanitize_overflow=False), 6
        )
        z = rotate_left32(
            tlc.sub(prev_z, prev_y, sanitize_overflow=False), 22
        )
    return x, y, z


@triton.jit
def noise_4bit_gen(seed, offs):
    """Generate packed 4-bit noise with RomuTrio32, one word per offset.

    Returns a block shaped like ``offs``; each word carries eight 4-bit
    codes produced by `pack_4bit`. The three seed words from
    `init_seed_trio` are offset by ``offs`` so every position runs its own
    stream; the state is advanced 6 times before four consecutive outputs
    are drawn and packed.
    """
    INIT_ROUNDS: tl.constexpr = 6

    # per-position initial state
    seed_x, seed_y, seed_z = init_seed_trio(seed.to(tl.uint32))
    tl.static_assert(seed_x.dtype == tl.uint32)
    sx = tlc.add(seed_x, offs, sanitize_overflow=False)
    sy = tlc.add(seed_y, offs, sanitize_overflow=False)
    sz = tlc.add(seed_z, offs, sanitize_overflow=False)

    # warm-up rounds
    sx, sy, sz = RomuTrio32_next(sx, sy, sz, n_rounds=INIT_ROUNDS)

    # four consecutive outputs, each chained from the previous one
    w0, sy, sz = RomuTrio32_next(sx, sy, sz, n_rounds=1)
    w1, sy, sz = RomuTrio32_next(w0, sy, sz, n_rounds=1)
    w2, sy, sz = RomuTrio32_next(w1, sy, sz, n_rounds=1)
    w3, sy, sz = RomuTrio32_next(w2, sy, sz, n_rounds=1)

    return pack_4bit(w0, w1, w2, w3)


@triton.jit
def noise_4bit_gen_fp(seed, offs, M: tl.constexpr, N: tl.constexpr):
    """Generate an ``[M, N]`` block of floating-point noise in one call.

    ``offs`` must have shape ``(M, N // 8)``: one packed word is generated
    per offset by `noise_4bit_gen` and then expanded into eight values by
    `noise_4bit_unpack_fp`. The values lie in {-2, -1, 0, 1, 2}.

    `noise_4bit_gen` uses the `pack_4bit` encoding, so the unpack step is
    called with ``is_uniform=False`` (no uniform rescaling). That argument
    is required by `noise_4bit_unpack_fp`; omitting it makes this function
    impossible to compile.
    """
    # one packed word per offset: [M, N // 8]
    packed = noise_4bit_gen(seed, offs)

    # expand every word into eight values: [M, N]
    return noise_4bit_unpack_fp(packed, M, N // 8, is_uniform=False)


@triton.jit
def noise_4bit_unpack_fp(
    packed, M: tl.constexpr, N: tl.constexpr, is_uniform: tl.constexpr
):
    """Expand packed 4-bit noise codes into floating-point values.

    Each nibble of `packed` is turned into one 8-bit pattern holding the
    nibble's sign (bit 3 -> bit 7) and its two magnitude bits (bits 1..0),
    which is bit-cast to float8e5, converted to float16 and rescaled.
    `is_uniform` only selects the final scale factor.
    """
    # input : i32 [M, N] packed with 8 elements each
    # returns [M, N * 8]; values are converted to float16 before scaling,
    # in {-2, -1, 0, 1, 2} (magnitude 3 is also representable by the format)

    # generalizable until
    # * (-7, 7) for fp8e5
    # * (-15, 15) for fp8e4

    N8: tl.constexpr = N * 8

    # total ~5 ops per output element

    # shifted tensors reused once (used 2 times)
    # 8 ops =~ 1 op per output element
    # The int8 cast keeps only the low byte of each shifted word.
    packed_l4 = tl.cast(packed << 4, dtype=tl.int8)
    packed_l0 = tl.cast(packed, dtype=tl.int8)
    packed_r4 = tl.cast(packed >> 4, dtype=tl.int8)
    packed_r8 = tl.cast(packed >> 8, dtype=tl.int8)
    packed_r12 = tl.cast(packed >> 12, dtype=tl.int8)
    packed_r16 = tl.cast(packed >> 16, dtype=tl.int8)
    packed_r20 = tl.cast(packed >> 20, dtype=tl.int8)
    packed_r24 = tl.cast(packed >> 24, dtype=tl.int8)
    packed_r28 = tl.cast(packed >> 28, dtype=tl.int8)

    # bit 7 = sign, bits 1..0 = magnitude of the 8-bit pattern
    mask_sign: tl.constexpr = 0x80
    mask_mag: tl.constexpr = 0x03

    # 3 ops each
    # For nibble k the sign comes from the byte shifted by 4*k - 4 (where the
    # nibble's bit 3 sits at bit 7) and the magnitude from the byte shifted
    # by 4*k (where the nibble's low bits sit at bits 1..0).
    sign_0 = packed_l4 & mask_sign
    mag_0 = packed_l0 & mask_mag
    val_0 = tl.cast(sign_0 | mag_0, dtype=tl.float8e5, bitcast=True)

    sign_1 = packed_l0 & mask_sign
    mag_1 = packed_r4 & mask_mag
    val_1 = tl.cast(sign_1 | mag_1, dtype=tl.float8e5, bitcast=True)

    sign_2 = packed_r4 & mask_sign
    mag_2 = packed_r8 & mask_mag
    val_2 = tl.cast(sign_2 | mag_2, dtype=tl.float8e5, bitcast=True)

    sign_3 = packed_r8 & mask_sign
    mag_3 = packed_r12 & mask_mag
    val_3 = tl.cast(sign_3 | mag_3, dtype=tl.float8e5, bitcast=True)

    sign_4 = packed_r12 & mask_sign
    mag_4 = packed_r16 & mask_mag
    val_4 = tl.cast(sign_4 | mag_4, dtype=tl.float8e5, bitcast=True)

    sign_5 = packed_r16 & mask_sign
    mag_5 = packed_r20 & mask_mag
    val_5 = tl.cast(sign_5 | mag_5, dtype=tl.float8e5, bitcast=True)

    sign_6 = packed_r20 & mask_sign
    mag_6 = packed_r24 & mask_mag
    val_6 = tl.cast(sign_6 | mag_6, dtype=tl.float8e5, bitcast=True)

    sign_7 = packed_r24 & mask_sign
    mag_7 = packed_r28 & mask_mag
    val_7 = tl.cast(sign_7 | mag_7, dtype=tl.float8e5, bitcast=True)

    # Pairwise joins build [M, N, 2, 2, 2], flattened below to [M, N * 8].
    # NOTE(review): each join adds a new innermost axis, so the eight values
    # of a word presumably do not come out in nibble order — harmless for
    # i.i.d. noise, but confirm if ordering ever matters.
    val_01 = tl.join(val_0, val_1)
    val_23 = tl.join(val_2, val_3)
    val_45 = tl.join(val_4, val_5)
    val_67 = tl.join(val_6, val_7)

    val_03 = tl.join(val_01, val_23)
    val_47 = tl.join(val_45, val_67)

    val_07 = tl.reshape(tl.join(val_03, val_47), (M, N8), can_reorder=False)

    # With all exponent bits zero, the two magnitude bits sit in the fp8e5
    # mantissa, i.e. a tiny subnormal value (presumably mag * 2**-16 for the
    # usual e5m2 layout); multiplying by 2**8 twice restores the integer.
    # recon_const: tl.constexpr = 2.0**16
    recon_const_h: tl.constexpr = 2.0**8
    # 1 op
    # alpha=0.154296875  for uniform with 7-points
    # alpha=0.2236328125 for uniform with 5-points
    # alpha=1.0          for normal
    alpha_scale: tl.constexpr = 0.154296875 if is_uniform else 1.0
    return val_07.to(tl.float16) * recon_const_h * recon_const_h * alpha_scale


def decode_4bit_torch(packed):
    """Decode packed 4-bit codes into an int32 ``[M, N * 8]`` tensor on the host.

    ``packed`` is ``[M, N]`` with eight nibbles per word. Nibble ``i`` of
    every word is written to columns ``[i * N, (i + 1) * N)`` of the result.
    The two low bits of a nibble are the magnitude; bit 3 selects the sign:
    a set bit gives ``+1``, a cleared bit gives ``-1``.
    NOTE(review): `noise_4bit_unpack_fp` maps the same bit to the float sign
    bit (set = negative) and interleaves values differently — confirm which
    convention is intended before comparing the two outputs element-wise.

    The extraction runs inside a proton "unpack" scope; the final
    multiplication happens after the scope has closed.
    """
    rows, words = packed.shape
    out_shape = [rows, words * 8]
    with proton.scope("unpack", metrics={"bytes": rows * words * 8 * 4}):

        sign = torch.empty(out_shape, dtype=torch.int32, device=packed.device)
        mag = torch.empty(out_shape, dtype=torch.int32, device=packed.device)
        words_i32 = packed.to(torch.int32)
        for nibble in range(8):
            cols = slice(nibble * words, (nibble + 1) * words)
            # bit 3 of the nibble -> 0 or 2 -> -1 or +1
            sign[:, cols] = ((words_i32 >> (2 + 4 * nibble)) & 0x0000_0002) - 1
            # bits 1..0 of the nibble
            mag[:, cols] = (words_i32 >> (4 * nibble)) & 0x0000_0003

    return sign * mag


def _get_unpack_configs():
    """Build the autotune candidates for the unpack kernel.

    Tiles are multiples of 32 in both dimensions (currently only 32x32),
    run with 2 warps and 2 to 5 pipeline stages.
    """
    base_tile = 32
    return [
        triton.Config(
            {"BM": base_tile * mult_m, "BN": base_tile * mult_n},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in [2, 3, 4, 5]
        for warps in [2]
        for mult_m in [1]
        for mult_n in [1]
    ]


@triton.autotune(
    configs=_get_unpack_configs(),
    key=["M", "N"],
)
@triton.jit
def rand_4bit_fp_kernel(
    g_packed,
    g_unpacked,
    M: tl.constexpr,
    N: tl.constexpr,
    is_uniform: tl.constexpr,
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """Unpack one ``[BM, BN]`` tile of 4-bit noise into ``g_unpacked``.

    ``g_packed`` is addressed as a row-major ``[M, cdiv(N, 8)]`` array of
    packed words and ``g_unpacked`` as a row-major ``[M, N]`` array. The
    1-D program id is split into a row-tile and a column-tile index.
    """
    pid = tl.program_id(axis=0)
    tiles_n = tl.cdiv(N, BN)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    # rows are shared by the packed input and the unpacked output
    rows = pid_m * BM + tl.arange(0, BM)
    rows_ok = rows[:, None] < M

    # load packed noise [BM, BN // 8] as part of [M, cdiv(N, 8)]
    words_per_row = tl.cdiv(N, 8)
    word_cols = pid_n * BN // 8 + tl.arange(0, BN // 8)
    in_ptrs = g_packed + (rows[:, None] * words_per_row + word_cols[None, :])
    in_mask = rows_ok & (word_cols[None, :] < words_per_row)
    packed = tl.load(in_ptrs, mask=in_mask, other=0)

    # [BM, BN // 8] --> [BM, BN]
    unpacked = noise_4bit_unpack_fp(packed, BM, BN // 8, is_uniform=is_uniform)

    # store unpacked noise [BM, BN] as part of [M, N]
    out_cols = pid_n * BN + tl.arange(0, BN)
    out_ptrs = g_unpacked + (rows[:, None] * N + out_cols[None, :])
    out_mask = rows_ok & (out_cols[None, :] < N)
    tl.store(out_ptrs, unpacked, mask=out_mask)


def rand_4bit_fp(M, N, seed=42, is_uniform=False):
    """Return an ``[M, N]`` CUDA tensor of 4-bit-quantized random noise.

    Packed words are generated first (``cdiv(N, 8)`` per row) and then
    expanded by `rand_4bit_fp_kernel` into a tensor of torch's default
    floating dtype. ``is_uniform`` is forwarded to both stages.
    """
    words_per_row = triton.cdiv(N, 8)

    # generate packed noise [M, cdiv(N, 8)]
    packed = rand_4bit_packed(M * words_per_row, seed, is_uniform=is_uniform)

    # unpack noise, one program per [BM, BN] tile
    noise = torch.empty((M, N), device="cuda")

    def grid(meta):
        return (triton.cdiv(M, meta["BM"]) * triton.cdiv(N, meta["BN"]),)

    rand_4bit_fp_kernel[grid](
        packed.view((M, words_per_row)), noise, M, N, is_uniform=is_uniform
    )

    return noise


def _get_romu_config():
    """Build the autotune candidates for the packed-noise generator kernels.

    Each candidate fixes ``BN`` (words per store) and ``GN`` (stores per
    program, ``data_per_sm // BN``) and is tried with 2 to 5 stages.
    """
    # (num_warps, data_per_sm, BN)
    launch_shapes = [
        (2, 4096, 512),  # ~88 Gelem/s
        (2, 1024, 256),  # ~87 Gelem/s
    ]
    return [
        triton.Config(
            {"BN": block, "GN": data_per_sm // block},
            num_stages=stages,
            num_warps=warps,
        )
        for stages in [2, 3, 4, 5]
        for warps, data_per_sm, block in launch_shapes
    ]


# implement RomuTrio32 in Triton
@triton.autotune(
    configs=_get_romu_config(),
    key=["N"],
)
@triton.jit
def randint_4bit_packed_kernel(
    # pointer i/o
    g_out,
    # pointer params
    N,
    # scalar params
    seed,
    # hyperparams
    BN: tl.constexpr,
    GN: tl.constexpr,
):
    """Fill ``g_out[0:N]`` with packed 4-bit noise words (`pack_4bit` coding).

    Every program owns ``BN * GN`` consecutive words. Each of its ``BN``
    lanes seeds its own RomuTrio32 stream from the lane's first output
    offset, runs 6 warm-up rounds and then emits one packed word per
    iteration, ``GN`` times, drawing four generator outputs per word.
    """
    INIT_ROUNDS: tl.constexpr = 6

    pid = tl.program_id(axis=0)
    lanes = tl.arange(0, BN)
    block_start = pid * BN * GN
    stream_ids = block_start + lanes

    # initialize state
    s0, s1, s2 = init_seed_trio(seed.to(tl.uint32))
    tl.static_assert(s0.dtype == tl.uint32)
    x = tlc.add(s0, stream_ids, sanitize_overflow=False)
    y = tlc.add(s1, stream_ids, sanitize_overflow=False)
    z = tlc.add(s2, stream_ids, sanitize_overflow=False)
    x_last, y, z = RomuTrio32_next(x, y, z, n_rounds=INIT_ROUNDS)

    for step in tl.range(0, GN):
        # generate four outputs, each chained from the previous one
        w0, y, z = RomuTrio32_next(x_last, y, z, n_rounds=1)
        w1, y, z = RomuTrio32_next(w0, y, z, n_rounds=1)
        w2, y, z = RomuTrio32_next(w1, y, z, n_rounds=1)
        x_last, y, z = RomuTrio32_next(w2, y, z, n_rounds=1)

        packed = pack_4bit(w0, w1, w2, x_last)

        # store, masking the tail beyond N
        offs_n = block_start + step * BN + lanes
        tl.store(g_out + offs_n, packed, mask=offs_n < N)


# implement Triton default via Box-Muller
@triton.autotune(
    configs=_get_romu_config(),
    key=["N"],
)
@triton.jit
def randint_4bit_packed_bm_kernel(
    # pointer i/o
    g_out,
    # pointer params
    N,
    # scalar params
    seed,
    # hyperparams
    BN: tl.constexpr,
    GN: tl.constexpr,
):
    """Fill ``g_out[0:N]`` with packed 4-bit noise built via Box-Muller.

    Same stream layout as `randint_4bit_packed_kernel`, but each packed
    word consumes eight generator outputs, which `pack_4bit_bm` turns into
    eight rounded normal samples.
    """
    INIT_ROUNDS: tl.constexpr = 6

    pid = tl.program_id(axis=0)
    lanes = tl.arange(0, BN)
    block_start = pid * BN * GN
    stream_ids = block_start + lanes

    # initialize state
    s0, s1, s2 = init_seed_trio(seed.to(tl.uint32))
    tl.static_assert(s0.dtype == tl.uint32)
    x = tlc.add(s0, stream_ids, sanitize_overflow=False)
    y = tlc.add(s1, stream_ids, sanitize_overflow=False)
    z = tlc.add(s2, stream_ids, sanitize_overflow=False)
    x_last, y, z = RomuTrio32_next(x, y, z, n_rounds=INIT_ROUNDS)

    for step in tl.range(0, GN):
        # generate eight outputs, each chained from the previous one
        w0, y, z = RomuTrio32_next(x_last, y, z, n_rounds=1)
        w1, y, z = RomuTrio32_next(w0, y, z, n_rounds=1)
        w2, y, z = RomuTrio32_next(w1, y, z, n_rounds=1)
        w3, y, z = RomuTrio32_next(w2, y, z, n_rounds=1)
        w4, y, z = RomuTrio32_next(w3, y, z, n_rounds=1)
        w5, y, z = RomuTrio32_next(w4, y, z, n_rounds=1)
        w6, y, z = RomuTrio32_next(w5, y, z, n_rounds=1)
        x_last, y, z = RomuTrio32_next(w6, y, z, n_rounds=1)

        packed = pack_4bit_bm(w0, w1, w2, w3, w4, w5, w6, x_last)

        # store, masking the tail beyond N
        offs_n = block_start + step * BN + lanes
        tl.store(g_out + offs_n, packed, mask=offs_n < N)


# implement discrete uniform
@triton.autotune(
    configs=_get_romu_config(),
    key=["N"],
)
@triton.jit
def randint_4bit_packed_uniform_kernel(
    # pointer i/o
    g_out,
    # pointer params
    N,
    # scalar params
    seed,
    # hyperparams
    BN: tl.constexpr,
    GN: tl.constexpr,
):
    """Fill ``g_out[0:N]`` with raw RomuTrio32 words used as packed noise.

    Same stream layout as `randint_4bit_packed_kernel`, but every generator
    output is stored unchanged: one round per stored word.
    """
    INIT_ROUNDS: tl.constexpr = 6

    pid = tl.program_id(axis=0)
    lanes = tl.arange(0, BN)
    block_start = pid * BN * GN
    stream_ids = block_start + lanes

    # initialize state
    s0, s1, s2 = init_seed_trio(seed.to(tl.uint32))
    tl.static_assert(s0.dtype == tl.uint32)
    x = tlc.add(s0, stream_ids, sanitize_overflow=False)
    y = tlc.add(s1, stream_ids, sanitize_overflow=False)
    z = tlc.add(s2, stream_ids, sanitize_overflow=False)
    x, y, z = RomuTrio32_next(x, y, z, n_rounds=INIT_ROUNDS)

    for step in tl.range(0, GN):
        # generate
        x, y, z = RomuTrio32_next(x, y, z, n_rounds=1)

        # case 1: {-3, -2, -1, -0, 0, 1, 2, 3} where alpha=0.154296875 (FP16)
        packed = x

        # case 2 : {-2, -1, -0, 0, 1, 2} where alpha=0.2236328125 (FP16)
        # three = x & (x >> 1) & 0x1111_1111
        # packed = x & (~three)

        # store, masking the tail beyond N
        offs_n = block_start + step * BN + lanes
        tl.store(g_out + offs_n, packed, mask=offs_n < N)


def rand_4bit_packed(N, seed, is_uniform):
    """Return ``N`` uint32 words of packed 4-bit noise on the CUDA device.

    ``is_uniform`` selects the raw-word (uniform) generator; otherwise the
    `pack_4bit` generator is used.
    """
    # generate [N] u32
    noise = torch.empty([N], dtype=torch.uint32, device="cuda")

    def grid(meta):
        # every program writes BN * GN words
        return (triton.cdiv(N, meta["BN"] * meta["GN"]),)

    # alternative for the non-uniform path: randint_4bit_packed_bm_kernel
    kernel = (
        randint_4bit_packed_uniform_kernel
        if is_uniform
        else randint_4bit_packed_kernel
    )
    kernel[grid](noise, N, seed)
    return noise


def get_bench_configs():
    """Describe the throughput benchmark: one plot, two providers.

    The x-axis drives both ``M`` and ``N``; the lines are the packed and
    the unpacked 4-bit generators.
    """
    providers = ["packed-4bit", "unpacked-4bit"]
    line_styles = [
        ("red", "-"),
        ("red", ":"),
        ("cyan", "-"),
        ("cyan", ":"),
    ]
    sizes = (
        [128 * i for i in range(1, 8)]
        + [256 * i for i in range(4, 8)]
        + [512 * i for i in range(4, 8)]
        + [1024 * i for i in range(4, 9)]
    )

    prng_benchmark = triton.testing.Benchmark(
        # Argument names to use as an x-axis for the plot
        x_names=["M", "N"],
        x_vals=sizes,
        y_log=True,
        x_log=False,
        line_arg="provider",
        line_vals=providers,
        line_names=providers,
        styles=line_styles,
        # Label name for the y-axis
        ylabel="Gelem/s",
        args={},
        plot_name="prng",
    )
    return [prng_benchmark]


@triton.testing.perf_report(get_bench_configs())
def benchmark(M, N, provider):
    """Measure generator throughput for ``M * N`` noise elements.

    ``provider`` is ``"<kind>-<bits>bit"``, e.g. ``"packed-4bit"``. Both
    providers produce the same amount of noise so their numbers are
    comparable: the unpacked path builds an ``[M, N]`` tensor, the packed
    path builds the ``M * cdiv(N, 8)`` words holding the same elements.

    Returns the throughput in Gelem/s at the median, 80th-percentile and
    20th-percentile run time, in that order.

    Raises:
        ValueError: if the provider is not a known 4-bit variant.
    """
    q_part = provider.split("-")[0]
    b_part = provider.split("-")[-1]
    B = int(b_part[:1])

    print(f"(M={M}, N={N}, {B}-bit, {q_part}), {provider}")

    # Same seed as the default of rand_4bit_fp.
    seed = 42
    if q_part == "unpacked" and B == 4:

        def run():
            return rand_4bit_fp(M, N, seed=seed)

    elif q_part == "packed" and B == 4:
        # rand_4bit_packed takes (word count, seed, is_uniform); calling it
        # with (M, N) raises TypeError and would not match the M * N
        # element count used for the throughput below.
        n_words = M * triton.cdiv(N, 8)

        def run():
            return rand_4bit_packed(n_words, seed, is_uniform=False)

    else:
        raise ValueError(q_part, B)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        run,
        quantiles=quantiles,
        rep=1000,
    )

    def perf(elapsed_ms):
        # elements per second, scaled to Gelem/s
        return M * N * 1e-9 / (elapsed_ms * 1e-3)

    return perf(ms), perf(max_ms), perf(min_ms)


if __name__ == "__main__":
    # Script entry point. Verification and benchmarking are still TODO:
    # for now only the two progress messages are printed.
    size = 4096  # NOTE(review): not used by anything below yet
    print("verify ...")
    # TODO verify

    print("benchmark ...")
    # TODO benchmark
