
import sys
import numpy as np
import torch
import os
import triton
from scipy.linalg import hadamard

import qutlass
from qutlass import matmul_nvf4_bf16_tn, fusedQuantizeNv
from qutlass.utils import to_blocked

import flashinfer
from flashinfer import (
    SfLayout,
    autotune,
    mm_fp4,
    nvfp4_quantize,
    mxfp4_quantize,
)

PROVIDER_CFGS = {
    "torch-bf16":    dict(                   enabled=True),
    "nvfp4-cudnn":   dict(backend="cudnn",   enabled=True),
    "nvfp4-cutlass": dict(backend="cutlass", enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return the scaled ``group_size x group_size`` Hadamard matrix as a tensor.

    The +/-1 matrix produced by ``scipy.linalg.hadamard`` is multiplied by
    ``group_size ** -0.5`` so that its rows have unit norm.

    Args:
        group_size: Order of the Hadamard matrix (``scipy.linalg.hadamard``
            only accepts powers of two).
        dtype: dtype of the returned tensor.
        device: Device on which the returned tensor is created.

    Returns:
        A ``(group_size, group_size)`` tensor of the requested dtype/device.
    """
    # Normalization factor 1 / sqrt(group_size).
    scale = group_size**-0.5
    scaled_matrix = hadamard(group_size) * scale
    return torch.tensor(scaled_matrix, dtype=dtype, device=device)

def _quant_weight_nvfp4(b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, global_scale: torch.Tensor, device: str):
    """Quantize the weight ``b`` once and re-lay-out its scale tensor.

    Args:
        b: Weight tensor handed to ``fusedQuantizeNv``.
        forward_hadamard_matrix: Hadamard matrix forwarded to ``fusedQuantizeNv``.
        global_scale: Global scale tensor forwarded to ``fusedQuantizeNv``.
        device: Accepted for signature compatibility; not read by this function.

    Returns:
        A pair ``(values, blocked_scales)``: the first output of
        ``fusedQuantizeNv`` unchanged, and its second output after
        ``to_blocked(..., True)``.
    """
    # NOTE(review): presumably FP4 (E2M1) codes plus per-block scales -- confirm
    # against the qutlass documentation.
    quantized_values, raw_scales = fusedQuantizeNv(b, forward_hadamard_matrix, global_scale)
    blocked_scales = to_blocked(raw_scales, True)
    return quantized_values, blocked_scales

def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    """Build a zero-argument callable performing one quantize + NVFP4 GEMM step.

    The weight ``b`` is quantized once, up front. The returned closure
    re-quantizes the activation ``a`` on every call and then runs the GEMM, so
    timing the closure measures activation quantization plus the matmul.

    Args:
        cfg: Provider configuration. ``cfg["backend"] == "cudnn"`` selects
            flashinfer's ``mm_fp4`` with the cuDNN backend; any other value
            uses ``qutlass.matmul_nvf4_bf16_tn``.
        a: 2-D activation tensor of shape ``(m, k)``.
        b: Weight tensor whose first dimension is ``n``.
        forward_hadamard_matrix: Hadamard matrix forwarded to ``fusedQuantizeNv``.
        dtype: Accepted for signature compatibility; not read. The cuDNN path
            always requests a ``torch.bfloat16`` output.
        device: Forwarded to ``_quant_weight_nvfp4``.

    Returns:
        A callable taking no arguments that returns the GEMM result.
    """
    m, k = a.shape
    n = b.shape[0]

    # Keep the scalar operands on the same device as the activations instead
    # of a hard-coded "cuda" so they can never end up on a different GPU.
    operand_device = a.device
    global_scale = torch.tensor([1.0], device=operand_device)
    alpha = torch.tensor([1.0], device=operand_device)

    # The weight is static, so quantize it a single time outside the timed closure.
    weight_e2m1, weight_scale_block = _quant_weight_nvfp4(
        b, forward_hadamard_matrix, global_scale, device
    )

    def _quantize_activation():
        """Quantize ``a`` and return ``(values, blocked_scales)``."""
        act_e2m1, act_scales = fusedQuantizeNv(a, forward_hadamard_matrix, global_scale)
        return act_e2m1, to_blocked(act_scales, True)

    if cfg["backend"] != "cudnn":

        def run():
            act_e2m1, act_scale_block = _quantize_activation()
            return qutlass.matmul_nvf4_bf16_tn(
                act_e2m1, weight_e2m1, act_scale_block, weight_scale_block, alpha
            )

        return run

    # Only the cuDNN path writes into a preallocated buffer; allocating it here
    # avoids a large, unused (m, n) allocation on the other path.
    out = torch.empty([m, n], device=operand_device, dtype=torch.bfloat16)
    # Scale tensors are viewed with k // 16 columns, matching block_size=16.
    scale_cols = k // 16

    def _cudnn_gemm(act_e2m1, act_scale_block):
        """Run flashinfer's FP4 GEMM into ``out`` and return it."""
        mm_fp4(
            act_e2m1,
            weight_e2m1.T,
            act_scale_block.view(-1, scale_cols),
            weight_scale_block.view(-1, scale_cols).T,
            alpha,
            torch.bfloat16,
            out,
            block_size=16,
            use_8x4_sf_layout=False,
            backend="cudnn",
            use_nvfp4=True,
        )
        return out

    # Warm-up: run one GEMM under flashinfer's autotuner before anything is
    # timed, then wait for it to finish.
    warm_e2m1, warm_scale_block = _quantize_activation()
    with autotune(True):
        _cudnn_gemm(warm_e2m1, warm_scale_block)
    torch.cuda.synchronize()

    def run():
        return _cudnn_gemm(*_quantize_activation())

    return run

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 24576, 32768],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    """Time one ``(batch_size, K) x (N, K)^T`` GEMM for ``provider``.

    Args:
        batch_size: Number of activation rows (the GEMM ``M`` dimension).
        provider: Key of ``PROVIDER_CFGS``; ``"torch-bf16"`` runs
            ``torch.nn.functional.linear``, anything else goes through
            ``build_nvfp4_runner``.
        N: Number of weight rows.
        K: Shared inner dimension.
        had_size: Order of the Hadamard matrix passed to the NVFP4 runner.

    Returns:
        ``(tflops_at_median, tflops_at_80th_pct_time, tflops_at_20th_pct_time)``,
        i.e. the central throughput followed by its lower and upper bound.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    # Built for every provider, although only the NVFP4 runners consume it.
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    # Pick the callable to time; the measurement itself is shared below.
    if provider == "torch-bf16":

        def timed_fn():
            return torch.nn.functional.linear(a, b)

    else:
        timed_fn = build_nvfp4_runner(
            PROVIDER_CFGS[provider], a, b, forward_hadamard_matrix, dtype, device
        )

    # Quantiles are (median, 20th percentile, 80th percentile) of the runtime.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        timed_fn, rep=200, quantiles=[0.5, 0.2, 0.8]
    )

    def to_tflops(t_ms):
        # 2*M*N*K floating-point operations per GEMM; t_ms is in milliseconds.
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    # A longer time means lower throughput, so max_ms yields the lower bound.
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


# Linear-layer shapes to sweep, keyed by model name. Every entry is a (K, N)
# pair, unpacked in that order by the sweep loop at the bottom of the file.
MODELS = {
    # Disabled configurations, kept for reference (uncomment to include them):
    # "Llama7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
    # "Llama13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
    # "Llama33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
    # "Llama65B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
    # "Qwen3-0.6B": ((1024, 2048), (1024, 1024), (1024, 6144), (3072, 1024)),
    # "Qwen3-1.7B": ((2048, 4096), (2048, 2048), (2048, 12288), (6144, 2048)),
    # "Qwen3-4B": ((2560, 2560), (2560, 2560), (2560, 19456), (9728, 2560)),
    # "Qwen3-8B": ((4096, 4096), (4096, 4096), (4096, 24576), (12288, 4096)),
    # "Qwen3-14B": [(5120, 5120), (5120, 5120), (5120, 34816), (17408, 5120)],
    "Qwen3-32B": [
        (5120, 5120),
        (5120, 51200),
        (25600, 5120),
    ],
    "Llama-3.1-70B": [
        (8192, 8192),
        (8192, 57344),
        (28672, 8192),
    ],
}

# Run the full sweep: every layer shape of every model, once per Hadamard size.
for model, layers in MODELS.items():
    for K, N in layers:
        for had_size in (16, 32, 64, 128):
            # Announce the configuration before its result table is printed.
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")

            # Each configuration gets its own output directory.
            save_path = f"bench_nvfp4_res_n{N}_k{K}_had{had_size}"
            os.makedirs(save_path, exist_ok=True)

            benchmark.run(
                N=N,
                K=K,
                had_size=had_size,
                save_path=save_path,
                print_data=True,
                show_plots=True,
            )