import time

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    """Time the int8 / fp8 quantization op on a random CUDA tensor.

    A ``(num_tokens, hidden_size)`` tensor of ``dtype`` is created on the
    CUDA device and quantized repeatedly; the mean wall-clock time per call
    is printed in microseconds.

    Args:
        num_tokens: Number of rows of the input tensor.
        hidden_size: Number of columns of the input tensor.
        static_scale: If true, a random ``(1, 1)`` float32 scale tensor is
            passed to the op. Otherwise ``None`` is passed as the scale
            (presumably selecting dynamic scaling inside the op -- confirm
            against ``vllm._custom_ops``).
        quant_dtype: ``torch.int8`` selects ``ops.scaled_int8_quant``; any
            other value selects ``ops.scaled_fp8_quant``.
        dtype: dtype of the input tensor.
        seed: Seed forwarded to ``current_platform.seed_everything``.
        do_profile: If true, run a single iteration bracketed by
            ``cudaProfilerStart`` / ``cudaProfilerStop`` instead of the
            regular ``num_iters`` timing loop.
        num_warmup_iters: Untimed iterations run before measuring. Warm-up
            is skipped when this is not positive.
        num_iters: Timed iterations; ignored when ``do_profile`` is set.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        """Return the mean seconds per quantization call over ``num_iters``."""
        # Drain previously queued GPU work so it is not attributed to the
        # timed region.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        # Wait for the launched kernels to finish before reading the clock.
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Close the profiling range opened above. The previous code
            # called cudaProfilerStart() here, leaving the range open.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    if num_warmup_iters > 0:
        # Skipping avoids a ZeroDivisionError in the helper when the caller
        # asks for no warm-up iterations.
        run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":

    def to_torch_dtype(name):
        """Map a ``--quant-dtype`` string ("int8" or "fp8") to a torch dtype.

        Raises:
            ValueError: If ``name`` is neither "int8" nor "fp8".
        """
        if name not in ("int8", "fp8"):
            raise ValueError(f"Unsupported dtype: {name}")
        # The fp8 dtype attribute is only touched when fp8 is requested.
        return torch.int8 if name == "int8" else torch.float8_e4m3fn

    cli = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")

    # Input tensor shape.
    cli.add_argument("--num-tokens", default=4096, type=int)
    cli.add_argument("--hidden-size", default=8192, type=int)

    # Quantization configuration.
    cli.add_argument("--static-scale", action="store_true")
    cli.add_argument("--quant-dtype",
                     default="int8",
                     choices=["fp8", "int8"],
                     type=str)
    cli.add_argument("--dtype",
                     default="half",
                     choices=["half", "bfloat16", "float"],
                     type=str)

    # Run control.
    cli.add_argument("--seed", default=0, type=int)
    cli.add_argument("--profile", action="store_true")
    cli.add_argument("--num-warmup-iters", default=5, type=int)
    cli.add_argument("--num-iters",
                     default=100,
                     type=int,
                     help="Number of benchmark iterations. "
                     "If --profile is set, this number is ignored")

    cli_args = cli.parse_args()
    print(cli_args)

    # Resolve the CLI strings to torch dtypes, then hand everything to main().
    benchmark_kwargs = {
        "num_tokens": cli_args.num_tokens,
        "hidden_size": cli_args.hidden_size,
        "static_scale": cli_args.static_scale,
        "quant_dtype": to_torch_dtype(cli_args.quant_dtype),
        "dtype": STR_DTYPE_TO_TORCH_DTYPE[cli_args.dtype],
        "seed": cli_args.seed,
        "do_profile": cli_args.profile,
        "num_warmup_iters": cli_args.num_warmup_iters,
        "num_iters": cli_args.num_iters,
    }
    main(**benchmark_kwargs)
