# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import json
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te
from torch.nn.attention.flex_attention import flex_attention
from utils.benchmark import benchmark
from model.modules.ffwd._get_block_mask import _get_block_mask
from model.modules.ffwd._score_mod_gelu import _score_mod_gelu


def factory_base(batch_size, num_token, num_head, head_size, num_expert, num_expert_active, expert_size):
    """
    Create the baseline expert computation function and its input factory.

    The baseline is a two-layer expert FFN built from Transformer Engine
    ``GroupedLinear`` layers with ``num_expert * num_head`` GEMMs each
    (``head_size -> expert_size -> head_size``, no bias, float32 parameters),
    with GELU in between, executed under bfloat16 autocast. Every GEMM is given
    the same number of rows (a perfectly balanced split).

    Args:
        batch_size, num_token, num_head, head_size: input geometry.
        num_expert: number of experts (per head).
        num_expert_active: number of active experts per token; the row count
            scales with it.
        expert_size: hidden width of each expert.

    Returns:
        ``(func_base, get_inputs_base)``. ``get_inputs_base()`` returns a fresh
        random bfloat16 CUDA tensor of shape
        ``(batch_size * num_token * num_expert_active * num_head, head_size)``
        with ``requires_grad=True``; ``func_base(x)`` maps it through
        fc_1 -> GELU -> fc_2 and returns a tensor with the same row count.

    Raises:
        ValueError: if the rows cannot be split evenly across the
            ``num_expert * num_head`` GEMMs.
    """
    num_rows = batch_size * num_token * num_expert_active * num_head
    num_gemms = num_expert * num_head

    # Validate with a real exception (``assert`` is stripped under ``python -O``)
    # and do it before allocating any weights so a bad config fails fast.
    if num_gemms <= 0 or num_rows % num_gemms != 0:
        raise ValueError(
            f"batch_size * num_token * num_expert_active * num_head ({num_rows}) must be "
            f"divisible by num_expert * num_head ({num_gemms})"
        )
    group_sizes = [num_rows // num_gemms] * num_gemms

    grouped_fc_1 = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=head_size,
        out_features=expert_size,
        bias=False,
        params_dtype=torch.float32,
    ).cuda()
    grouped_fc_2 = te.GroupedLinear(
        num_gemms=num_gemms,
        in_features=expert_size,
        out_features=head_size,
        bias=False,
        params_dtype=torch.float32,
    ).cuda()

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True, cache_enabled=False)
    def func_base(x):
        x = grouped_fc_1(x, group_sizes)
        x = F.gelu(x)
        x = grouped_fc_2(x, group_sizes)
        return x

    def get_inputs_base():
        x = torch.randn((num_rows, head_size), dtype=torch.bfloat16, device="cuda", requires_grad=True)
        return x

    return func_base, get_inputs_base







































@torch.compile()
def _stage_flex_attention_computation(q, k, v, score_mod, block_mask):
    """
    Run FlexAttention with ``scale=1.0`` and return both the output and its log-sum-exp.

    ``scale=1.0`` disables the usual 1/sqrt(head_size) query scaling. q, k and v
    must share one dtype (FlexAttention rejects mixed dtypes). ``factory_ours``
    therefore feeds bfloat16 k/v, cast from float32 leaf tensors.

    In:  (1, num_head, batch_size * num_token * num_expert_active, head_size); bfloat16; contiguous
         (1, num_head, num_expert * expert_size, head_size); bfloat16; contiguous
         (1, num_head, num_expert * expert_size, head_size); bfloat16; contiguous
         score_mod block_mask
    Out: (1, num_head, batch_size * num_token * num_expert_active, head_size); bfloat16; contiguous
         (1, num_head, batch_size * num_token * num_expert_active); float32; contiguous
    """
    o, lse = flex_attention(
        query=q,
        key=k,
        value=v,
        scale=1.0,
        block_mask=block_mask,
        score_mod=score_mod,
        return_lse=True,
    )
    return o, lse


@torch.compile()
def _stage_reversal_trick(o, lse, v, expert_assign, num_head, head_size, num_expert, expert_size):
    """
    Undo FlexAttention's softmax normalisation and subtract each row's per-expert value sum.

    With ``o = sum_j softmax(score)_j v_j`` and ``lse = log sum_j exp(score_j)``,
    Stage 1 forms ``o * exp(lse) = sum_j exp(score_j) v_j``. Stage 2 subtracts, for
    every query row, the sum of the ``expert_size`` value rows of the expert that
    ``expert_assign`` maps the row to. If the block mask exposes exactly that
    expert's keys (the mask is built elsewhere; assumption to confirm), the result
    is ``sum_j (exp(score_j) - 1) v_j``.

    Both stages run in float32 and are rounded to bfloat16 once at the end. The
    subtraction cancels two large, nearly equal terms, so rounding ``exp(lse)``,
    the product, or the offsets to bfloat16 first would amplify their rounding
    error in the (small) difference.

    In:  (1, num_head, batch_size * num_token * num_expert_active, head_size); bfloat16; contiguous
         (1, num_head, batch_size * num_token * num_expert_active); float32; contiguous
         (1, num_head, num_expert * expert_size, head_size); bfloat16; contiguous
         (num_head, batch_size * num_token * num_expert_active); int64; contiguous; detached
         int int int int
    Out: (1, num_head, batch_size * num_token * num_expert_active, head_size); bfloat16; contiguous
    """
    # Stage 1: multiply by the softmax denominator exp(lse); (1, H, N) -> (1, H, N, 1)
    # so it broadcasts over head_size.
    denominator = lse.unsqueeze(-1).float().exp()
    unnormalized = o.float() * denominator

    # Stage 2: per-expert sum of value rows, (1, H, num_expert, head_size), accumulated in float32.
    per_expert_sum = v.view(1, num_head, num_expert, expert_size, head_size).sum(dim=3, dtype=torch.float32)

    # Select each query row's expert: index (H, N) -> (1, H, N, head_size).
    index = expert_assign[None, :, :, None].expand(-1, -1, -1, head_size)
    offsets = torch.gather(input=per_expert_sum, dim=2, index=index)

    return (unnormalized - offsets).to(torch.bfloat16)


def factory_ours(batch_size, num_token, num_head, head_size, num_expert, num_expert_active, expert_size):
    """
    Create our IO-aware expert computation function and input factory.

    Ours uses FlexAttention with the reversal trick. The query axis holds the
    ``batch_size * num_token * num_expert_active`` routed rows. Keys/values hold
    ``num_expert * expert_size`` rows per head, one ``expert_size``-row slab per
    expert. Routing is drawn per block of ``block_size`` query rows and turned into
    a block mask by ``_get_block_mask``. The attention output is post-processed by
    ``_stage_reversal_trick``.

    Returns:
        ``(ours_fn, make_inputs)``. ``make_inputs()`` returns ``(q, k, v)`` on CUDA.
        ``ours_fn((q, k, v))`` returns a bfloat16 tensor shaped like ``q``. Every
        call draws a fresh random routing and rebuilds the block mask, so anything
        that times ``ours_fn`` includes that work.

    Raises:
        ValueError: if ``batch_size * num_token * num_expert_active`` is not a
            multiple of the routing block size (128).
    """
    block_size = 128
    num_query = batch_size * num_token * num_expert_active

    # Routing is drawn per query block and expanded back to rows with
    # repeat_interleave. A remainder would leave expert_assign shorter than the
    # query axis and break the offset subtraction in _stage_reversal_trick.
    if num_query % block_size != 0:
        raise ValueError(
            f"batch_size * num_token * num_expert_active ({num_query}) must be a multiple "
            f"of block_size ({block_size})"
        )
    num_block_q = num_query // block_size

    def ours_fn(inputs):
        q, k, v = inputs
        # One random expert per (head, query block), then expanded to per-row assignment.
        block_level_expert_assign = torch.randint(0, num_expert, (num_head, num_block_q), dtype=torch.int64, device="cuda")
        expert_assign = block_level_expert_assign.repeat_interleave(block_size, dim=1)

        block_mask = _get_block_mask(
            block_level_expert_assign=block_level_expert_assign,
            num_expert=num_expert,
            expert_size=expert_size,
            block_size=block_size,
        )
        o, lse = _stage_flex_attention_computation(
            q, k, v,
            score_mod=_score_mod_gelu,
            block_mask=block_mask,
        )

        o = _stage_reversal_trick(
            o=o, lse=lse, v=v, expert_assign=expert_assign,
            num_head=num_head, head_size=head_size, num_expert=num_expert, expert_size=expert_size,
        )
        return o

    def make_inputs():
        q = torch.randn(1, num_head, num_query, head_size, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        k = torch.randn(1, num_head, num_expert * expert_size, head_size, dtype=torch.float32, device="cuda", requires_grad=True)
        v = torch.randn(1, num_head, num_expert * expert_size, head_size, dtype=torch.float32, device="cuda", requires_grad=True)

        # Explicit cast to enable FP32 accumulation: the leaves stay float32 while
        # FlexAttention receives bfloat16 k/v matching q's dtype.
        k = k.to(torch.bfloat16)
        v = v.to(torch.bfloat16)
        return q, k, v

    return ours_fn, make_inputs


if __name__ == "__main__":
    # Every (expert_size, num_expert) pair presents new shapes to the
    # torch.compile'd stages, so allow plenty of recompilations.
    torch._dynamo.config.recompile_limit = 800

    # Config
    batch_size = 4
    num_token = 512
    num_head = 8
    head_size = 128
    num_expert_active = 4

    # Ideally we want to showcase 1024 (or even 2048) experts; the sweep currently stops at 512.
    # Required: (batch_size * num_token * num_expert_active) % num_expert == 0
    expert_size_all = [128, 256, 512]
    num_expert_all = [8, 16, 32, 64, 128, 256, 512]

    output_path = "benchmark_result_io_aware_expert_computation.json"

    # JSON structure
    output = {
        "config": {
            "batch_size": batch_size,
            "num_token": num_token,
            "num_head": num_head,
            "head_size": head_size,
            "num_expert_active": num_expert_active,
        },
        "results": []
    }

    for expert_size in expert_size_all:
        for num_expert in num_expert_all:
            print(f"\n\nexpert_size = {expert_size}, num_expert = {num_expert}")
            factory_args = (batch_size, num_token, num_head, head_size, num_expert, num_expert_active, expert_size)

            # Benchmark each method in isolation. func_base's closure keeps both
            # GroupedLinear layers (and any .grad tensors they hold) alive, so drop
            # it before measuring ours. Otherwise an absolute peak-memory reading
            # for ours could include the baseline's weights.
            func_base, get_inputs_base = factory_base(*factory_args)
            fwd_ms_base, bwd_ms_base, mem_gib_base = benchmark(func_base, get_inputs_base)
            del func_base, get_inputs_base

            func_ours, get_inputs_ours = factory_ours(*factory_args)
            fwd_ms_ours, bwd_ms_ours, mem_gib_ours = benchmark(func_ours, get_inputs_ours)
            del func_ours, get_inputs_ours

            output["results"].append({
                "expert_size": expert_size,
                "num_expert": num_expert,
                "fwd_ms_base": fwd_ms_base,
                "fwd_ms_ours": fwd_ms_ours,
                "bwd_ms_base": bwd_ms_base,
                "bwd_ms_ours": bwd_ms_ours,
                "mem_gib_base": mem_gib_base,
                "mem_gib_ours": mem_gib_ours,
            })

            # Save after every configuration so a late failure (e.g. an OOM at the
            # largest sizes) does not discard the results already collected.
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(output, f, indent=2)

    print("DONE")
