from sparse_gemv import *
from compile_wrapper import *
import torch
import time
import numpy as np

# Global switch for CUDA Graph benchmarking. When it is False,
# benchmark_function_with_cuda_graph raises ValueError, which the call sites
# in this script catch in order to fall back to benchmark_function.
use_cuda_graph = True

def benchmark_function_with_cuda_graph(func, *args, warmup=1000, repeat=1000, **kwargs):
    """
    Benchmark a function using CUDA Graph to eliminate kernel launch overhead.

    The function is called once eagerly, then captured into a
    ``torch.cuda.CUDAGraph``; every warmup and timed iteration is a replay of
    that captured graph, timed with CUDA events.

    Args:
        func: Function to benchmark
        *args: Positional arguments to pass to the function
        warmup: Number of untimed graph replays before measurement
        repeat: Number of timed graph replays
        **kwargs: Keyword arguments to pass to the function

    Returns:
        mean_time: Mean replay time in milliseconds
        std_time: Standard deviation of the replay time in milliseconds
        graph_result: Value returned by ``func`` during graph capture

    Raises:
        ValueError: If the module-level ``use_cuda_graph`` switch is disabled.
    """
    if not use_cuda_graph:
        # A non-empty message matters: callers print the exception text when
        # they fall back to the non-graph benchmark.
        raise ValueError("CUDA Graph benchmarking is disabled (use_cuda_graph is False)")

    # Eager call before capture; its return value is not needed.
    # NOTE(review): presumably this lets lazy initialisation / compilation
    # happen outside the capture region - confirm.
    func(*args, **kwargs)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()

    with torch.cuda.graph(graph):
        graph_result = func(*args, **kwargs)

    # Untimed replays.
    for _ in range(warmup):
        graph.replay()
        torch.cuda.synchronize()

    times = []
    for _ in range(repeat):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        graph.replay()
        end.record()

        # Both events must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    return np.mean(times), np.std(times), graph_result

def benchmark_function(func, *args, warmup=100, repeat=100, **kwargs):
    """
    Benchmark a function with warmup runs and multiple repeats.

    Each timed call is bracketed by CUDA events and followed by a device
    synchronize before the elapsed time is read.

    Args:
        func: Function to benchmark
        *args: Positional arguments to pass to the function
        warmup: Number of untimed calls before measurement
        repeat: Number of timed calls
        **kwargs: Keyword arguments to pass to the function

    Returns:
        mean_time: Mean execution time in milliseconds (NaN if ``repeat`` is 0)
        std_time: Standard deviation of execution time in milliseconds
            (NaN if ``repeat`` is 0)
        result: Return value of the last call to ``func``, or None if ``func``
            was never called (``warmup`` and ``repeat`` both 0)
    """
    # Bound up front so the final return cannot hit an unbound local when
    # both loops run zero times.
    result = None

    # Warmup runs
    for _ in range(warmup):
        result = func(*args, **kwargs)
        torch.cuda.synchronize()

    # Timing runs
    times = []
    for _ in range(repeat):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        result = func(*args, **kwargs)
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    if not times:
        # No samples: report NaN directly rather than letting numpy warn
        # about the mean of an empty list.
        return float("nan"), float("nan"), result

    return np.mean(times), np.std(times), result

# Problem size: the weight below is (output_dim, input_dim) and the input
# holds input_dim elements.
input_dim = 8192
output_dim = 28672

# Random fp16 weight matrix on the GPU.
weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.float16)

# Random fp16 input of shape (1, 1, input_dim) on the GPU.
x = torch.randn(1, 1, input_dim, device="cuda", dtype=torch.float16)
input_vec = x.contiguous()  # NOTE(review): not referenced again in the visible code

# Sum of the squared input elements; passed to the sparse operator as
# `norm_squared` and used to normalise the thresholds in the sweep below.
x_norm_squared = torch.sum(torch.pow(x, 2))

print("=== Performance Benchmarking ===")

print("\n1. Benchmarking dense PyTorch matmul with CUDA Graph:")


def dense_func(x, weight):
    """Dense baseline: apply ``weight`` to ``x`` via torch.nn.functional.linear."""
    return torch.nn.functional.linear(x, weight)


# Try the CUDA Graph timing path first; on any failure, report the reason
# and time the same function with the plain event-based benchmark instead.
try:
    dense_time_mean, dense_time_std, dense_output = (
        benchmark_function_with_cuda_graph(dense_func, x, weight)
    )
    print(f"   Dense matmul (CUDA Graph): {dense_time_mean:.3f} ± {dense_time_std:.3f} ms")
except Exception as e:
    print(f"   CUDA Graph failed for dense matmul: {e}")
    dense_time_mean, dense_time_std, dense_output = (
        benchmark_function(dense_func, x, weight)
    )
    print(f"   Dense matmul (standard): {dense_time_mean:.3f} ± {dense_time_std:.3f} ms")

print("\n2. Benchmarking sparse GEMV implementation with CUDA Graph:")
# Re-lay the weight out in memory so that dim 0 has unit stride while its
# logical (output_dim, input_dim) shape stays the same.
weight = weight.transpose(0, 1).contiguous().transpose(0, 1)
assert weight.stride(0) == 1

# Build the sparse GEMV operator for the weight's device type, then compile it.
sparse_gemv_kernel = SparseGEMV.initialize("sparse_gemv", weight.device.type)
sparse_gemv = sparse_gemv_kernel.operator(True)
splitk_sparse_gemv_compiled = torch.compile(sparse_gemv, fullgraph=True)


def sparse_func(x, weight, threshold_ratio, norm_squared):
    """Forward the arguments by keyword to the compiled sparse GEMV operator."""
    return splitk_sparse_gemv_compiled(
        x=x,
        weight=weight,
        threshold_ratio=threshold_ratio,
        norm_squared=norm_squared,
    )


# Threshold ratio 0 for the baseline sparse timing. CUDA Graph path first,
# plain event-based timing if it fails.
try:
    sparse_time_mean, sparse_time_std, sparse_output = (
        benchmark_function_with_cuda_graph(sparse_func, x, weight, 0, x_norm_squared)
    )
    print(f"   Sparse GEMV (CUDA Graph): {sparse_time_mean:.3f} ± {sparse_time_std:.3f} ms")
except Exception as e:
    print(f"   CUDA Graph failed for sparse GEMV: {e}")
    sparse_time_mean, sparse_time_std, sparse_output = (
        benchmark_function(sparse_func, x, weight, 0, x_norm_squared)
    )
    print(f"   Sparse GEMV (standard): {sparse_time_mean:.3f} ± {sparse_time_std:.3f} ms")

# Ratio of mean dense time to mean sparse time (> 1 means sparse is faster).
speedup = dense_time_mean / sparse_time_mean
print(f"\n3. Performance comparison:")
print(f"   Speedup: {speedup:.2f}x ({'-' if speedup < 1 else '+'}{abs(1-speedup)*100:.1f}%)")

print("\n4. Output verification:")
print(f"   Output shape: {sparse_output.shape}")
# Largest element-wise absolute deviation between the sparse and dense outputs.
max_diff = torch.max(torch.abs(sparse_output - dense_output)).item()
print(f"   Max difference: {max_diff}")

print("\n5. Testing different sparsity thresholds with CUDA Graph:")
# Each entry is the fraction of input elements (ordered by squared value,
# smallest first) used to pick the threshold for one run.
sparsities = [0.25, 0.5, 0.75, 1]
x_squared = torch.pow(x, 2)
x_squared_sorted = torch.sort(x_squared.flatten())[0]
results = []

for sparsity in sparsities:
    # Threshold = k-th smallest squared element relative to the squared norm.
    # NOTE(review): presumably the operator skips inputs whose squared value,
    # relative to norm_squared, falls below threshold_ratio - confirm against
    # the SparseGEMV implementation.
    k = max(1, int(sparsity * input_dim))
    threshold = (x_squared_sorted[k - 1] / x_norm_squared).item()

    try:
        sparse_time_mean, sparse_time_std, sparse_output = benchmark_function_with_cuda_graph(
            sparse_func, x, weight, threshold, x_norm_squared
        )
        method = "CUDA Graph"
    except Exception as e:
        # Report why the CUDA Graph path failed instead of falling back silently.
        print(f"   CUDA Graph failed for threshold {threshold:.8f}: {e}")
        sparse_time_mean, sparse_time_std, sparse_output = benchmark_function(
            sparse_func, x, weight, threshold, x_norm_squared
        )
        method = "standard"

    # Speedup is measured against the dense baseline timed earlier.
    speedup = dense_time_mean / sparse_time_mean
    results.append((sparsity, threshold, sparse_time_mean, sparse_time_std, speedup, method))

# Print the summary only after every configuration has been measured.
for sparsity, threshold, sparse_time_mean, sparse_time_std, speedup, method in results:
    print(f"   Threshold {threshold:.8f}: {sparse_time_mean:.3f} ± {sparse_time_std:.3f} ms ({method}), "
          f"Sparsity: {sparsity*100:.1f}%, Speedup: {speedup:.2f}x")



