# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch
import time
import itertools
from model.ops.mh_router.mh_router_fwd import mh_router_fwd
from model.ops.mh_router.mh_router_bwd import mh_router_bwd


# Benchmark profiles keyed by head_size; each value is (num_heads, head_size).
# In every profile num_heads * head_size == 1024.
PROFILES = {
    64: (16, 64),
    128: (8, 128),
    256: (4, 256),
}

# Problem size shared by all profiles.
BATCH_SIZE = 8          # leading dim of X
NUM_TOKENS = 128        # second dim of X; BLOCK_N (N) must divide it
NUM_EXPERTS = 1024      # last dim of R; the forward's BLOCK_M (M) must divide it
K = 4                   # top-k width: last dim of top_logit / top_idx
USE_SIGMOID = False     # forwarded unchanged to the forward kernel's USE_SIGMOID

# Candidate values swept per kernel. Key order fixes the sweep order.
# Only the forward kernel takes a BLOCK_M (M) tile.
FWD_SEARCH_SPACE = dict(
    N=[16, 32, 64],
    M=[32, 64, 128],
    num_warps=[2, 4, 8],
    num_stages=[2, 3, 4],
)

BWD_SEARCH_SPACE = dict(
    N=[16, 32, 64],
    num_warps=[2, 4, 8],
    num_stages=[2, 3, 4],
)

# Launches per configuration: untimed warm-up first, then timed iterations.
WARMUP_ITERS = 10
BENCH_ITERS = 100


def create_tensors(num_heads, head_size, device='cuda'):
    """Build random benchmark inputs for one (num_heads, head_size) profile.

    All tensors are drawn from a standard normal, are contiguous, and are
    generated in the order listed (so a fixed seed reproduces them).

    Returns:
        ``(X, R, auxfree_bias)`` where
          * ``X`` is bfloat16 of shape (BATCH_SIZE, NUM_TOKENS, num_heads, head_size),
          * ``R`` is float32 of shape (num_heads, head_size, NUM_EXPERTS),
          * ``auxfree_bias`` is float32 of shape (num_heads, NUM_EXPERTS).
    """
    def _standard_normal(shape, dtype):
        return torch.randn(*shape, dtype=dtype, device=device).contiguous()

    X = _standard_normal((BATCH_SIZE, NUM_TOKENS, num_heads, head_size), torch.bfloat16)
    R = _standard_normal((num_heads, head_size, NUM_EXPERTS), torch.float32)
    auxfree_bias = _standard_normal((num_heads, NUM_EXPERTS), torch.float32)
    return X, R, auxfree_bias


def bench_forward(X, R, auxfree_bias, N, M, num_warps, num_stages, warmup=WARMUP_ITERS, iters=BENCH_ITERS):
    """Time one launch configuration of the ``mh_router_fwd`` kernel.

    Args:
        X: 4-D input of shape (batch_size, num_tokens, num_heads, head_size).
        R: 3-D tensor of shape (num_heads, head_size, num_experts).
        auxfree_bias: passed to the kernel with its first two strides;
            ``create_tensors`` builds it as (num_heads, num_experts).
        N: BLOCK_N; must divide num_tokens.
        M: BLOCK_M; must divide num_experts and be at least K.
        num_warps: Triton ``num_warps`` launch option.
        num_stages: Triton ``num_stages`` launch option.
        warmup: untimed launches run first, so one-time costs of a new
            configuration (e.g. kernel compilation) stay out of the timing.
        iters: number of timed launches; must be positive.

    Returns:
        ``(time_ms, None)`` where ``time_ms`` is the mean wall-clock time per
        launch in milliseconds, or ``(None, reason)`` if the configuration
        violates a tiling constraint and was not launched.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    if iters <= 0:
        # Guards against dividing by zero or reporting a negative time below.
        raise ValueError(f"iters must be positive, got {iters}")

    batch_size, num_tokens, num_heads, head_size = X.shape
    _, _, num_experts = R.shape
    device = X.device

    # Reject configurations the kernel's tiling cannot handle.
    if num_tokens % N != 0:
        return None, "num_tokens not divisible by N"
    if num_experts % M != 0:
        return None, "num_experts not divisible by M"
    if K > M:
        return None, "K > M"
    if num_experts < M:
        return None, "num_experts < M"

    out_shape = (batch_size, num_tokens, num_heads, K)
    top_logit = torch.empty(out_shape, dtype=torch.float32, device=device)
    top_idx = torch.empty(out_shape, dtype=torch.int64, device=device)

    # Every launch argument is fixed for this configuration. Building them once
    # keeps per-launch host work (stride queries, grid binding) out of the
    # timed loop and ensures warm-up and timed launches are identical.
    grid = (batch_size, num_tokens // N, num_heads)
    kernel = mh_router_fwd[grid]
    args = (
        X, R, auxfree_bias, top_logit, top_idx,
        *(X.stride(d) for d in range(4)),
        *(R.stride(d) for d in range(3)),
        *(auxfree_bias.stride(d) for d in range(2)),
        *(top_logit.stride(d) for d in range(4)),
        *(top_idx.stride(d) for d in range(4)),
    )
    meta = dict(
        head_size=head_size,
        number_of_experts=num_experts,
        number_of_tokens=num_tokens,
        USE_SIGMOID=USE_SIGMOID,
        K=K,
        BLOCK_N=N,
        BLOCK_M=M,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    for _ in range(warmup):
        kernel(*args, **meta)
    torch.cuda.synchronize()

    # Wall-clock time over back-to-back launches, closed by a device sync.
    # For very short kernels this is bounded below by host launch overhead.
    start = time.perf_counter()
    for _ in range(iters):
        kernel(*args, **meta)
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start

    return elapsed_s / iters * 1000, None  # ms


def bench_backward(X, R, top_idx, d_top_logit, N, num_warps, num_stages, warmup=WARMUP_ITERS, iters=BENCH_ITERS):
    """Time one launch configuration of the ``mh_router_bwd`` kernel.

    Args:
        X: 4-D input of shape (batch_size, num_tokens, num_heads, head_size).
        R: 3-D tensor of shape (num_heads, head_size, num_experts).
        top_idx: passed with four strides; ``sweep_backward`` builds it as an
            int64 tensor of shape (batch_size, num_tokens, num_heads, K).
        d_top_logit: passed with four strides; ``sweep_backward`` builds it as
            float32 with the same shape as ``top_idx``.
        N: BLOCK_N; must divide num_tokens.
        num_warps: Triton ``num_warps`` launch option.
        num_stages: Triton ``num_stages`` launch option.
        warmup: untimed launches run first, so one-time costs of a new
            configuration (e.g. kernel compilation) stay out of the timing.
        iters: number of timed launches; must be positive.

    Returns:
        ``(time_ms, None)`` where ``time_ms`` is the mean wall-clock time per
        launch in milliseconds, or ``(None, reason)`` if the configuration
        violates a tiling constraint and was not launched.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    if iters <= 0:
        # Guards against dividing by zero or reporting a negative time below.
        raise ValueError(f"iters must be positive, got {iters}")

    batch_size, num_tokens, num_heads, head_size = X.shape
    _, _, num_experts = R.shape
    device = X.device

    # Reject configurations the kernel's tiling cannot handle.
    if num_tokens % N != 0:
        return None, "num_tokens not divisible by N"

    dX = torch.empty(
        (batch_size, num_tokens, num_heads, head_size),
        dtype=torch.bfloat16, device=device
    )
    # dR is zero-filled once and reused by every warm-up and timed launch.
    # NOTE(review): if the kernel accumulates into dR (not visible from here),
    # its contents drift across launches. That does not affect the timing.
    dR = torch.zeros(
        (num_heads, head_size, num_experts),
        dtype=torch.float32, device=device
    )

    # Every launch argument is fixed for this configuration. Building them once
    # keeps per-launch host work (stride queries, grid binding) out of the
    # timed loop and ensures warm-up and timed launches are identical.
    grid = (batch_size, num_tokens // N, num_heads)
    kernel = mh_router_bwd[grid]
    args = (
        X, R, top_idx, d_top_logit, dX, dR,
        *(X.stride(d) for d in range(4)),
        *(R.stride(d) for d in range(3)),
        *(top_idx.stride(d) for d in range(4)),
        *(d_top_logit.stride(d) for d in range(4)),
        *(dX.stride(d) for d in range(4)),
        *(dR.stride(d) for d in range(3)),
    )
    meta = dict(
        head_size=head_size,
        number_of_experts=num_experts,
        number_of_tokens=num_tokens,
        K=K,
        BLOCK_N=N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    for _ in range(warmup):
        kernel(*args, **meta)
    torch.cuda.synchronize()

    # Wall-clock time over back-to-back launches, closed by a device sync.
    # For very short kernels this is bounded below by host launch overhead.
    start = time.perf_counter()
    for _ in range(iters):
        kernel(*args, **meta)
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start

    return elapsed_s / iters * 1000, None  # ms


def sweep_forward(profile_name, num_heads, head_size):
    """Benchmark every forward configuration in ``FWD_SEARCH_SPACE`` for one profile.

    Prints one progress line per configuration, followed by the fastest
    configuration if any ran. Each progress line shows one of:
      * the measured time,
      * ``SKIP`` with the constraint the configuration violates,
      * ``FAIL`` with the raised exception.

    Args:
        profile_name: label used only in the printed output.
        num_heads: number of heads for the generated inputs.
        head_size: per-head size for the generated inputs.

    Returns:
        List of dicts with keys ``'N'``, ``'M'``, ``'num_warps'``,
        ``'num_stages'`` and ``'time_ms'``, one per configuration that ran,
        in sweep order. Empty if none ran.
    """
    print(f"\n{'='*70}")
    print(f"Forward Sweep - Profile {profile_name}: heads={num_heads}, head_size={head_size}")
    print(f"{'='*70}")

    X, R, auxfree_bias = create_tensors(num_heads, head_size)

    # Materialize the grid once so the reported total is derived from the
    # grid itself instead of a separately maintained product of lengths.
    configs = list(itertools.product(
        FWD_SEARCH_SPACE['N'],
        FWD_SEARCH_SPACE['M'],
        FWD_SEARCH_SPACE['num_warps'],
        FWD_SEARCH_SPACE['num_stages'],
    ))
    total_configs = len(configs)

    results = []
    for config_idx, (N, M, num_warps, num_stages) in enumerate(configs, start=1):
        config_str = f"N={N:3d}, M={M:3d}, warps={num_warps}, stages={num_stages}"
        tag = f"[{config_idx:3d}/{total_configs}] {config_str}"
        try:
            time_ms, skip_reason = bench_forward(X, R, auxfree_bias, N, M, num_warps, num_stages)
        except Exception as e:
            # Best-effort sweep: report the failure and move on.
            # NOTE(review): a device-side fault may leave the CUDA context
            # unusable, in which case later configurations will FAIL too.
            print(f"{tag} -> FAIL ({e})")
            continue

        if time_ms is None:
            print(f"{tag} -> SKIP ({skip_reason})")
            continue

        results.append({
            'N': N,
            'M': M,
            'num_warps': num_warps,
            'num_stages': num_stages,
            'time_ms': time_ms,
        })
        print(f"{tag} -> {time_ms:8.4f} ms")

    if results:
        best = min(results, key=lambda r: r['time_ms'])
        print(f"\n*** Best Forward: N={best['N']}, M={best['M']}, "
              f"warps={best['num_warps']}, stages={best['num_stages']} -> {best['time_ms']:.4f} ms")

    return results


def sweep_backward(profile_name, num_heads, head_size):
    """Benchmark every backward configuration in ``BWD_SEARCH_SPACE`` for one profile.

    The kernel is fed synthetic upstream data of shape
    (BATCH_SIZE, NUM_TOKENS, num_heads, K):
      * ``top_idx`` holds uniformly random int64 expert indices in
        [0, NUM_EXPERTS),
      * ``d_top_logit`` holds float32 standard-normal values.

    Prints one progress line per configuration, followed by the fastest
    configuration if any ran. Each progress line shows one of:
      * the measured time,
      * ``SKIP`` with the constraint the configuration violates,
      * ``FAIL`` with the raised exception.

    Args:
        profile_name: label used only in the printed output.
        num_heads: number of heads for the generated inputs.
        head_size: per-head size for the generated inputs.

    Returns:
        List of dicts with keys ``'N'``, ``'num_warps'``, ``'num_stages'`` and
        ``'time_ms'``, one per configuration that ran, in sweep order. Empty
        if none ran.
    """
    print(f"\n{'='*70}")
    print(f"Backward Sweep - Profile {profile_name}: heads={num_heads}, head_size={head_size}")
    print(f"{'='*70}")

    # The backward benchmark does not take auxfree_bias, so it is discarded.
    X, R, _ = create_tensors(num_heads, head_size)

    # NOTE(review): randint can repeat an expert within one row's K entries,
    # which a real top-k output presumably would not. This is fine for timing.
    top_idx = torch.randint(
        0, NUM_EXPERTS,
        (BATCH_SIZE, NUM_TOKENS, num_heads, K),
        dtype=torch.int64, device=X.device
    )
    d_top_logit = torch.randn(
        BATCH_SIZE, NUM_TOKENS, num_heads, K,
        dtype=torch.float32, device=X.device
    ).contiguous()

    # Materialize the grid once so the reported total is derived from the
    # grid itself instead of a separately maintained product of lengths.
    configs = list(itertools.product(
        BWD_SEARCH_SPACE['N'],
        BWD_SEARCH_SPACE['num_warps'],
        BWD_SEARCH_SPACE['num_stages'],
    ))
    total_configs = len(configs)

    results = []
    for config_idx, (N, num_warps, num_stages) in enumerate(configs, start=1):
        config_str = f"N={N:3d}, warps={num_warps}, stages={num_stages}"
        tag = f"[{config_idx:3d}/{total_configs}] {config_str}"
        try:
            time_ms, skip_reason = bench_backward(X, R, top_idx, d_top_logit, N, num_warps, num_stages)
        except Exception as e:
            # Best-effort sweep: report the failure and move on.
            # NOTE(review): a device-side fault may leave the CUDA context
            # unusable, in which case later configurations will FAIL too.
            print(f"{tag} -> FAIL ({e})")
            continue

        if time_ms is None:
            print(f"{tag} -> SKIP ({skip_reason})")
            continue

        results.append({
            'N': N,
            'num_warps': num_warps,
            'num_stages': num_stages,
            'time_ms': time_ms,
        })
        print(f"{tag} -> {time_ms:8.4f} ms")

    if results:
        best = min(results, key=lambda r: r['time_ms'])
        print(f"\n*** Best Backward: N={best['N']}, "
              f"warps={best['num_warps']}, stages={best['num_stages']} -> {best['time_ms']:.4f} ms")

    return results


def _fastest(results):
    """Return the entry of *results* with the smallest ``'time_ms'``, or None if empty.

    Ties resolve to the earliest entry, as with ``min``.
    """
    if not results:
        return None
    return min(results, key=lambda r: r['time_ms'])


def _print_summary(best):
    """Print the fastest forward/backward configuration for each profile.

    *best* maps profile name -> (fastest forward dict or None,
    fastest backward dict or None).
    """
    print("\n")
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)

    for profile_name in sorted(best):
        best_fwd, best_bwd = best[profile_name]
        num_heads, head_size = PROFILES[profile_name]
        print(f"\nProfile {profile_name} (heads={num_heads}, head_size={head_size}):")

        if best_fwd is not None:
            print(f"  Forward:  N={best_fwd['N']:3d}, M={best_fwd['M']:3d}, "
                  f"warps={best_fwd['num_warps']}, stages={best_fwd['num_stages']} "
                  f"-> {best_fwd['time_ms']:.4f} ms")
        else:
            print("  Forward:  No valid configurations")

        if best_bwd is not None:
            print(f"  Backward: N={best_bwd['N']:3d}, "
                  f"warps={best_bwd['num_warps']}, stages={best_bwd['num_stages']} "
                  f"-> {best_bwd['time_ms']:.4f} ms")
        else:
            print("  Backward: No valid configurations")


def _print_recommended(best):
    """Print the fastest configurations as copy-pasteable assignments.

    *best* has the same layout as for ``_print_summary``. Profiles or passes
    with no successful configuration are omitted.
    """
    print("\n")
    print("=" * 70)
    print("RECOMMENDED CONFIGURATIONS (copy-paste ready)")
    print("=" * 70)

    for profile_name in sorted(best):
        best_fwd, best_bwd = best[profile_name]
        num_heads, head_size = PROFILES[profile_name]
        print(f"\n# Profile {profile_name}: (num_head={num_heads}, head_size={head_size})")

        if best_fwd is not None:
            print(f"# Forward: {best_fwd['time_ms']:.4f} ms")
            print(f"N = {best_fwd['N']}")
            print(f"M = {best_fwd['M']}")
            print(f"num_of_warps = {best_fwd['num_warps']}")
            print(f"num_of_stages = {best_fwd['num_stages']}")

        if best_bwd is not None:
            print(f"# Backward: {best_bwd['time_ms']:.4f} ms")
            print(f"N = {best_bwd['N']}")
            print(f"num_of_warps = {best_bwd['num_warps']}")
            print(f"num_of_stages = {best_bwd['num_stages']}")


def main():
    """Run the forward and backward sweeps for every profile and print a report.

    Prints the run parameters, one forward and one backward sweep per entry
    in ``PROFILES``, a summary of the fastest configurations, and the same
    configurations in copy-paste form.

    Raises:
        SystemExit: if no CUDA device is available.
    """
    # Fail fast with a clear message instead of crashing partway through
    # the header on torch.cuda.get_device_name.
    if not torch.cuda.is_available():
        raise SystemExit("No CUDA device available; the MH router sweep requires a GPU.")

    print("=" * 70)
    print("MH Router Hyperparameter Sweep")
    print("=" * 70)
    print(f"Batch size:    {BATCH_SIZE}")
    print(f"Num tokens:    {NUM_TOKENS}")
    print(f"Num experts:   {NUM_EXPERTS}")
    print(f"K (top-k):     {K}")
    print(f"USE_SIGMOID:   {USE_SIGMOID}")
    print(f"Warmup iters:  {WARMUP_ITERS}")
    print(f"Bench iters:   {BENCH_ITERS}")
    print(f"Device:        {torch.cuda.get_device_name(0)}")

    # Pick the fastest configuration per profile once; both reports use it.
    best = {}
    for profile_name, (num_heads, head_size) in PROFILES.items():
        fwd_results = sweep_forward(profile_name, num_heads, head_size)
        bwd_results = sweep_backward(profile_name, num_heads, head_size)
        best[profile_name] = (_fastest(fwd_results), _fastest(bwd_results))

    _print_summary(best)
    _print_recommended(best)


if __name__ == '__main__':
    # Run the sweep only when executed as a script; importing defines names only.
    main()
