# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import json
import torch
from model.ops.mh_router.mh_router import mh_router
from utils.benchmark import benchmark

# Keep float32 matrix multiplications at full float32 internal precision
# (no reduced-precision fast paths) so the baseline's fp32 `X @ R` is not
# degraded relative to the op under test.
torch.set_float32_matmul_precision(precision="highest")

@torch.compile()
def func_base(inputs):
    """Baseline multi-head top-K router written with plain PyTorch ops.

    Args:
        inputs: tuple ``(X, R, auxfree_bias, K, USE_SIGMOID)`` where
            X: (batch_size, num_token, num_head, head_size); bfloat16 in this
                benchmark. Need not be contiguous.
            R: (num_head, head_size, num_expert); float32.
            auxfree_bias: (num_head, num_expert); float32. Added to the router
                scores only when choosing experts; the returned values do not
                include it.
            K: number of experts kept per (head, token).
            USE_SIGMOID: must be falsy; the sigmoid variant is not implemented.

    Returns:
        (num_head, batch_size * num_token, K) float32 tensor holding the
        entries of ``X @ R`` for the K experts with the largest biased score.
        ``topk`` runs with ``sorted=False``, so the order along the last
        dimension is unspecified.

    Raises:
        NotImplementedError: if ``USE_SIGMOID`` is truthy. This is an explicit
            raise rather than an ``assert``, which ``python -O`` would strip.
    """
    X, R, auxfree_bias, K, USE_SIGMOID = inputs
    if USE_SIGMOID:
        raise NotImplementedError("Sigmoid not supported")

    batch_size, num_token, num_head, head_size = X.shape
    # Unpacking also enforces that auxfree_bias is 2-D.
    bias_heads, num_expert = auxfree_bias.shape

    # (batch_size * num_token, num_head, head_size); float32.
    # `reshape` instead of `view`: `.float()` preserves the input's strides, so
    # a non-contiguous X (e.g. a transposed view) cannot be viewed and needs a
    # copy, which `reshape` makes only when required.
    X = X.float().reshape(batch_size * num_token, num_head, head_size)
    # (num_head, batch_size * num_token, head_size); float32; non-contiguous
    X = X.transpose(0, 1)
    # Batched matmul over heads:
    # (num_head, batch_size * num_token, num_expert); float32
    router_values = X @ R
    # The bias only steers expert selection; broadcast it over all tokens.
    topk_input = router_values + auxfree_bias.view(bias_heads, 1, num_expert)
    # (num_head, batch_size * num_token, K); int64. Integer indices carry no
    # gradient; detach() makes that explicit.
    expert_assign = torch.topk(
        input=topk_input,
        k=K,
        dim=2,
        largest=True,
        sorted=False,
    ).indices.detach()
    # Release the biased scores before the gather.
    del topk_input
    # Gather the *unbiased* scores of the selected experts:
    # (num_head, batch_size * num_token, K); float32
    router_values = torch.gather(
        input=router_values,
        dim=2,
        index=expert_assign,
    )
    return router_values


def func_ours(inputs):
    """Route with the project's ``mh_router`` op and keep only its first output.

    ``inputs`` is an ``(X, R, auxfree_bias, K, USE_SIGMOID)`` tuple; every
    element is forwarded positionally to ``mh_router``, which must return
    exactly two values. The second value is discarded.
    """
    X, R, auxfree_bias, K, USE_SIGMOID = inputs
    routed = mh_router(X, R, auxfree_bias, K, USE_SIGMOID)
    values, _discarded = routed
    return values


def get_inputs_factory(B, T, H, D, E, K, *, device="cuda"):
    """Build a zero-argument callable that samples fresh router inputs.

    Args:
        B, T, H, D: batch size, tokens per sample, number of heads, head size.
        E: number of experts.
        K: number of experts selected per (head, token); must be in [0, E].
        device: device the tensors are allocated on (keyword-only; defaults
            to ``"cuda"``, which the benchmark uses).

    Returns:
        A function that, on every call, draws new random tensors and returns
        ``(X, R, auxfree_bias, K, USE_SIGMOID)`` with
            X: (B, T, H, D) bfloat16, requires_grad=True;
            R: (H, D, E) float32, requires_grad=True;
            auxfree_bias: (H, E) float32, no gradient;
            USE_SIGMOID: always False.

    Raises:
        ValueError: if K is outside [0, E]. This fails at factory creation
            instead of later inside ``torch.topk``.
    """
    if not 0 <= K <= E:
        raise ValueError(f"K must be in [0, E={E}], got K={K}")

    def get_inputs():
        X = torch.randn(B, T, H, D, dtype=torch.bfloat16, device=device, requires_grad=True)
        R = torch.randn(H, D, E, dtype=torch.float32, device=device, requires_grad=True)
        # The bias receives no gradient.
        auxfree_bias = torch.randn(H, E, dtype=torch.float32, device=device)
        USE_SIGMOID = False
        return X, R, auxfree_bias, K, USE_SIGMOID

    return get_inputs


if __name__ == "__main__":
    # The sweep covers 15 (K, E) combinations. Each one changes tensor shapes
    # and/or the Python int K that func_base sees, so each can trigger a
    # recompilation. Raise the limit so torch.compile keeps compiling instead
    # of giving up partway through the sweep.
    torch._dynamo.config.recompile_limit = 800

    # Config
    B, T, H, D = 40, 2048, 8, 128

    num_expert_active_all = [2, 4, 8]
    num_expert_all = [384, 768, 1536, 3072, 6144]

    result_path = "benchmark_result_io_aware_routing.json"

    # JSON structure
    output = {
        "config": {
            "batch_size": B,
            "num_token": T,
            "num_head": H,
            "head_size": D,
        },
        "results": []
    }

    def _save_results(payload):
        """Write the accumulated benchmark results to ``result_path`` as JSON."""
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    for K in num_expert_active_all:
        for E in num_expert_all:
            # Report K too: every num_expert value is visited once per K.
            print(f"\n\nnum_expert_active = {K}, num_expert = {E}")

            get_inputs = get_inputs_factory(B, T, H, D, E, K)
            fwd_ms_base, bwd_ms_base, mem_gib_base = benchmark(func_base, get_inputs)
            fwd_ms_ours, bwd_ms_ours, mem_gib_ours = benchmark(func_ours, get_inputs)

            output["results"].append({
                "num_expert_active": K,
                "num_expert": E,
                "fwd_ms_base": fwd_ms_base,
                "fwd_ms_ours": fwd_ms_ours,
                "bwd_ms_base": bwd_ms_base,
                "bwd_ms_ours": bwd_ms_ours,
                "mem_gib_base": mem_gib_base,
                "mem_gib_ours": mem_gib_ours,
            })

            # Checkpoint after every configuration. Memory use grows with
            # num_expert, so later iterations are the most likely to fail, and
            # a crash there must not discard the results already measured.
            _save_results(output)

    # Final save. It also writes the file when the sweep is empty.
    _save_results(output)

    print("DONE")
