# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import os
import json
import tqdm
import time
import numpy as np
import torch
import torch.distributed as dist
from datetime import timedelta
from model.ops.all_to_all import all_to_all
from utils.synchronize import synchronize

def _trim_outliers(arr, pct=5):
    lower, upper = np.percentile(arr, [pct, 100 - pct])
    return arr[(arr >= lower) & (arr <= upper)]


# ----- #
# Configuration
# ----- #
# Common
BATCH_SIZE     = 40
NUM_TOKEN      = 2048
# Number of ranks the experts are sharded over. The launcher (torchrun, as the
# use of LOCAL_RANK below suggests) exports WORLD_SIZE to every worker; read it
# so the expert layout matches the actual launch, falling back to 4 when unset.
WORLD_SIZE     = int(os.environ.get("WORLD_SIZE", 4))
EMB_SIZE       = 1024
BYTES_PER_ELEM = 2  # BF16
# MoE
NUM_EXPERT = 768
# Experts are split evenly across ranks. An explicit raise is used instead of
# `assert` so the check survives `python -O`.
if NUM_EXPERT % WORLD_SIZE != 0:
    raise ValueError(
        f"NUM_EXPERT ({NUM_EXPERT}) must be divisible by WORLD_SIZE ({WORLD_SIZE})"
    )
NUM_EXPERT_PER_RANK = NUM_EXPERT // WORLD_SIZE
# HP-specific
NUM_HEAD  = 8
HEAD_SIZE = 128
# Benchmark
# NOTE(review): K_VALUES and SKEWNESS_VALUES are not read by the sweep loop in
# this file, which hard-codes its own expert-count and skew lists — confirm
# which set is intended before relying on these.
K_VALUES = [2, 4, 8]
SKEWNESS_VALUES = np.linspace(0.0, 2.0, 10).tolist()
WARMUP_ITERS = 20
TIMED_ITERS  = 100
# Misc
# Split sizes for the metadata all-to-all: each peer exchanges exactly one row
# of per-expert token counts.
metadata_a2a_split_size = [1] * WORLD_SIZE
# ----- #


# ----- #
# Setup `torch.distributed`
# LOCAL_RANK is provided by the launcher; it selects this process's GPU.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl",
                        # Bind the NCCL group to this process's CUDA device,
                        # reusing the already-parsed local rank.
                        device_id=torch.device("cuda", local_rank),
                        timeout=timedelta(minutes=30))
rank, world_size = dist.get_rank(), dist.get_world_size()
# The expert layout (`NUM_EXPERT_PER_RANK`) and the metadata all-to-all split
# sizes are derived from WORLD_SIZE. Fail fast with a clear message if the
# actual launch disagrees, instead of failing later inside the collective.
if world_size != WORLD_SIZE:
    dist.destroy_process_group()
    raise RuntimeError(
        f"Launched with world_size={world_size}, but the configuration "
        f"expects WORLD_SIZE={WORLD_SIZE}"
    )
# ----- #


# One summary dict per (SKEW, NUM_EXPERT_ACTIVE) configuration; only rank 0
# appends to it.
all_results = []


def _zipf_probs(num_expert, skew):
    """Return Zipf(``skew``) selection probabilities for ``num_expert`` experts.

    Expert ``i`` (0-based) gets weight ``1 / (i + 1) ** skew``; the weights are
    normalized to sum to 1, so ``skew == 0`` yields the uniform distribution.
    """
    ranks = np.arange(1, num_expert + 1)
    weights = 1.0 / np.power(ranks, skew)
    return weights / weights.sum()


def _median_mean_std(values, trim=False):
    """Return ``(median, mean, sample std)`` of ``values``.

    The median always uses every value. When ``trim`` is true, the mean and
    the standard deviation (``ddof=1``) are computed only after
    ``_trim_outliers`` has removed the tails.
    """
    values = np.asarray(values)
    median = np.median(values)
    if trim:
        values = _trim_outliers(values)
    return median, values.mean(), values.std(ddof=1)


# NOTE(review): the skew and expert-count sweeps below are hard-coded and differ
# from SKEWNESS_VALUES / K_VALUES in the configuration section.
for SKEW in [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]:
    # The routing distribution depends only on SKEW, so build it once per skew
    # rather than once per NUM_EXPERT_ACTIVE.
    probs = _zipf_probs(NUM_EXPERT, SKEW)
    # Probability that a single expert draw targets an expert hosted on rank 0.
    # Experts are laid out contiguously, NUM_EXPERT_PER_RANK per rank, matching
    # the `.view(WORLD_SIZE, NUM_EXPERT_PER_RANK)` used to build the splits.
    p_gpu0 = probs[:NUM_EXPERT_PER_RANK].sum()

    for NUM_EXPERT_ACTIVE in [2, 4, 8, 16]:
        if rank == 0:
            print("P(GPU_0):", p_gpu0)

        ep_latency_all = []
        bytes_to_gpu0_all = []
        # WARMUP_ITERS iterations followed by TIMED_ITERS iterations; only the
        # last TIMED_ITERS samples are summarized. Progress is shown on rank 0
        # only so bars from different ranks do not interleave.
        for _ in tqdm.tqdm(range(WARMUP_ITERS + TIMED_ITERS), disable=rank != 0):

            # Each rank draws its own routing. Every (batch, token, slot) entry
            # is sampled independently with replacement.
            # NOTE(review): one token may therefore pick the same expert more
            # than once, unlike a top-k router.
            # (batch_size, num_token, num_expert_active,); int64
            expert_assign_np = np.random.choice(
                NUM_EXPERT,
                size=(BATCH_SIZE, NUM_TOKEN, NUM_EXPERT_ACTIVE),
                p=probs,
            )
            expert_assign = torch.from_numpy(expert_assign_np).to(dtype=torch.int64, device="cuda")

            # Number of token slots routed to each expert.
            # (num_expert,); int64; contiguous; detached
            expert_bincount = torch.bincount(expert_assign.view(-1), minlength=NUM_EXPERT)

            # One row per routed token slot.
            # (batch_size * num_token * num_expert_active, emb_size); bfloat16; contiguous
            x = torch.randn(
                size=(BATCH_SIZE * NUM_TOKEN * NUM_EXPERT_ACTIVE, EMB_SIZE),
                dtype=torch.bfloat16,
                device="cuda",
            )

            # Prepare metadata: row r holds the per-expert counts this rank
            # sends to rank r.
            # (world_size, num_expert_per_rank); int64; contiguous; detached
            chunk_sizes_distribute = expert_bincount.view(WORLD_SIZE, NUM_EXPERT_PER_RANK)
            # Exchange rows so that row r of `chunk_sizes_collect` holds the
            # counts rank r sends to this rank.
            # (world_size, num_expert_per_rank); int64; contiguous; detached
            chunk_sizes_collect = torch.empty_like(chunk_sizes_distribute)
            dist.all_to_all_single(
                output=chunk_sizes_collect,
                input=chunk_sizes_distribute,
                output_split_sizes=metadata_a2a_split_size,
                input_split_sizes=metadata_a2a_split_size,
            )

            # Timed region. It is bracketed by `synchronize()` (presumably a
            # device sync / barrier; confirm in utils.synchronize). It includes
            # the host-side split computation (`.tolist()` copies to the CPU)
            # as well as the token all-to-all itself.
            synchronize()
            t_start_ep = time.perf_counter()

            # (world_size,); python list of integers
            input_splits = chunk_sizes_distribute.sum(dim=1).tolist()
            # (world_size,); python list of integers
            output_splits = chunk_sizes_collect.sum(dim=1).tolist()
            # (dyn_pool_size, emb_size); bfloat16; contiguous
            x, _ = all_to_all(input=x, input_splits=input_splits, output_splits=output_splits)

            synchronize()
            t_stop_ep = time.perf_counter()

            ep_latency_all.append((t_stop_ep - t_start_ep) * 1000)

            # Payload rank 0 receives in this all-to-all. It covers tokens from
            # every rank, including rank 0's own local share.
            if rank == 0:
                tokens_to_gpu0 = sum(output_splits)
                bytes_to_gpu0_all.append(tokens_to_gpu0 * EMB_SIZE * BYTES_PER_ELEM)

        if rank == 0:
            # Summarize the timed iterations only (warmup samples are dropped).
            # Latency mean/std are computed on outlier-trimmed samples.
            ep_latency_median, ep_latency_mean, ep_latency_std = _median_mean_std(
                ep_latency_all[-TIMED_ITERS:], trim=True
            )
            print(f"NUM_EXPERT_ACTIVE={NUM_EXPERT_ACTIVE}, SKEW={SKEW}")
            print("ep_latency_median:", round(ep_latency_median, 2), "ms")
            print("ep_latency_mean:", round(ep_latency_mean, 2), "ms")
            print("ep_latency_std:", round(ep_latency_std, 2), "ms")

            # Report bytes sent to GPU0, in MiB, without trimming.
            mib_to_gpu0_all = np.asarray(bytes_to_gpu0_all[-TIMED_ITERS:]) / (1024 ** 2)
            mib_to_gpu0_median, mib_to_gpu0_mean, mib_to_gpu0_std = _median_mean_std(
                mib_to_gpu0_all
            )
            print("bytes_to_gpu0_median:", round(mib_to_gpu0_median, 2), "MiB")
            print("bytes_to_gpu0_mean:", round(mib_to_gpu0_mean, 2), "MiB")
            print("bytes_to_gpu0_std:", round(mib_to_gpu0_std, 2), "MiB")

            all_results.append({
                "num_expert_active": NUM_EXPERT_ACTIVE,
                "skew": SKEW,
                "p_gpu0": round(p_gpu0, 4),
                "ep_latency_median_ms": round(ep_latency_median, 2),
                "ep_latency_mean_ms": round(ep_latency_mean, 2),
                "ep_latency_std_ms": round(ep_latency_std, 2),
                "bytes_to_gpu0_median_mib": round(mib_to_gpu0_median, 2),
                "bytes_to_gpu0_mean_mib": round(mib_to_gpu0_mean, 2),
                "bytes_to_gpu0_std_mib": round(mib_to_gpu0_std, 2),
            })

# Final rendezvous before results are written and the group is torn down.
# `synchronize()` comes from utils.synchronize; presumably a device sync /
# barrier across ranks (confirm there).
synchronize()

if rank == 0:
    # Only rank 0 accumulated summaries; persist them as indented JSON.
    output_path = "benchmark_result_ep.json"
    with open(output_path, mode="w") as result_file:
        json.dump(all_results, fp=result_file, indent=2)

dist.destroy_process_group()
