# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import time
import torch
import numpy as np

def _timed_cuda_call(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_ms)``.

    The device is synchronized right before the host clock starts and again
    before it stops, so previously queued CUDA work is not charged to this
    call, and the kernels ``fn`` launches are included in the measurement.
    """
    torch.cuda.synchronize()
    started = time.perf_counter()
    result = fn(*args)
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - started) * 1000
    return result, elapsed_ms


def _benchmark_once(func, get_inputs, grad):
    """Run one forward + backward pass of ``func`` and measure it.

    Args:
        func: Callable applied to the object returned by ``get_inputs``; its
            result must support ``.backward(grad)``.
        get_inputs: Zero-argument factory producing fresh inputs for ``func``.
        grad: Upstream gradient handed to ``.backward`` on the output.

    Returns:
        ``(fwd_ms, bwd_ms, mem_gib)``: wall-clock forward and backward time in
        milliseconds, and the peak CUDA memory allocated during the pass minus
        the amount allocated just before it (i.e. after the inputs were
        built), in GiB.
    """
    inputs = get_inputs()

    # Baseline: drain pending work, return unused cached blocks, and restart
    # peak tracking so max_memory_allocated() below only covers this pass.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    baseline_bytes = torch.cuda.memory_allocated()

    output, fwd_ms = _timed_cuda_call(func, inputs)
    _, bwd_ms = _timed_cuda_call(output.backward, grad)

    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()

    return fwd_ms, bwd_ms, (peak_bytes - baseline_bytes) / (2 ** 30)


def benchmark(func, get_inputs, *, num_iters=120, num_warmup=20, num_trim=5):
    """Benchmark forward/backward latency and peak memory of ``func``.

    ``func`` is measured ``num_iters`` times with ``_benchmark_once``, each run
    on a fresh batch from ``get_inputs``. For every metric, the first
    ``num_warmup`` samples are discarded (e.g. to absorb ``torch.compile``
    compilation and allocator warm-up). The remaining samples are sorted,
    ``num_trim`` values are dropped from each end, and the rest are averaged.

    Args:
        func: Callable mapping the object returned by ``get_inputs`` to an
            output tensor that supports ``.backward``.
        get_inputs: Zero-argument factory producing fresh inputs for ``func``.
        num_iters: Total number of measured runs (default 120).
        num_warmup: Leading runs excluded from the statistics (default 20).
        num_trim: Samples dropped from each end of the sorted post-warm-up
            values before averaging (default 5).

    Returns:
        ``(fwd_avg, bwd_avg, mem_avg)``: trimmed-mean forward time (ms),
        backward time (ms), and peak extra memory (GiB).

    Raises:
        ValueError: If ``num_warmup`` or ``num_trim`` is negative, or if fewer
            than one sample would remain after warm-up and trimming.
    """
    # Validate before touching the device so bad settings fail fast instead of
    # silently producing a NaN mean from an empty slice.
    if num_warmup < 0 or num_trim < 0:
        raise ValueError(
            f"num_warmup and num_trim must be non-negative, "
            f"got num_warmup={num_warmup}, num_trim={num_trim}"
        )
    if num_iters - num_warmup - 2 * num_trim < 1:
        raise ValueError(
            f"no samples left to average: num_iters={num_iters}, "
            f"num_warmup={num_warmup}, num_trim={num_trim}"
        )

    # One gradient-free forward pass, used only to build an all-ones upstream
    # gradient with the output's shape/dtype/device; it is reused every run.
    with torch.no_grad():
        inputs = get_inputs()
        y = func(inputs)
        grad = torch.ones_like(y)
        del inputs
        del y

    fwd_list = []
    bwd_list = []
    mem_list = []
    for _ in range(num_iters):
        fwd_ms, bwd_ms, mem_gib = _benchmark_once(func, get_inputs, grad)
        fwd_list.append(fwd_ms)
        bwd_list.append(bwd_ms)
        mem_list.append(mem_gib)

    def _trimmed_mean(samples):
        kept = np.sort(np.asarray(samples)[num_warmup:])
        # Explicit upper bound: `kept[k:-k]` would be empty when k == 0.
        return kept[num_trim:len(kept) - num_trim].mean()

    fwd_avg = _trimmed_mean(fwd_list)
    bwd_avg = _trimmed_mean(bwd_list)
    mem_avg = _trimmed_mean(mem_list)
    return fwd_avg, bwd_avg, mem_avg


if __name__ == "__main__":
    # Compare a hand-written attention against F.scaled_dot_product_attention;
    # both are compiled and run under bf16 autocast on float32 inputs.
    import torch.nn.functional as F

    def get_inputs():
        """Return fresh float32 ``(q, k, v)`` CUDA tensors that require grad.

        Each has shape (batch_size=4, num_head=8, num_token=1024, head_size=128).
        """
        shape = (4, 8, 1024, 128)
        return tuple(
            torch.randn(shape, dtype=torch.float32, device="cuda", requires_grad=True)
            for _ in range(3)
        )

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True, cache_enabled=False)
    @torch.compile()
    def naive_attn(inputs):
        """Explicit attention: softmax(q @ k^T / sqrt(head_size)) @ v."""
        q, k, v = inputs
        scale = q.size(-1) ** 0.5
        weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / scale, dim=-1)
        return torch.matmul(weights, v)

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True, cache_enabled=False)
    @torch.compile()
    def sdpa_attn(inputs):
        """Same computation via PyTorch's fused scaled_dot_product_attention."""
        q, k, v = inputs
        return F.scaled_dot_product_attention(q, k, v)

    for header, attn_fn in (("\n\nNaive:", naive_attn), ("\n\nSDPA:", sdpa_attn)):
        fwd_ms, bwd_ms, mem_gib = benchmark(attn_fn, get_inputs)
        print(header)
        print(f"fwd: {fwd_ms:.3f} ms")
        print(f"bwd: {bwd_ms:.3f} ms")
        print(f"mem: {mem_gib:.3f} GiB")
