# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import os
import tqdm
import time
import numpy as np
import torch
import torch.distributed as dist
from datetime import timedelta
from model.ops.all_to_all import all_to_all
from utils.synchronize import synchronize

def _trim_outliers(arr, pct=5):
    lower, upper = np.percentile(arr, [pct, 100 - pct])
    return arr[(arr >= lower) & (arr <= upper)]


# ----- #
# Configuration
# ----- #
# Common
BATCH_SIZE     = 40
NUM_TOKEN      = 2048
WORLD_SIZE     = 4     # number of ranks this configuration is sized for
EMB_SIZE       = 1024
BYTES_PER_ELEM = 2  # BF16
# HP-specific
NUM_HEAD  = 8
HEAD_SIZE = 128
# Heads must split evenly across ranks. This is an explicit raise rather than
# an `assert`, because asserts are stripped when Python runs with -O.
if NUM_HEAD % WORLD_SIZE != 0:
    raise ValueError(
        f"NUM_HEAD ({NUM_HEAD}) must be divisible by WORLD_SIZE ({WORLD_SIZE})"
    )
NUM_HEAD_PER_RANK = NUM_HEAD // WORLD_SIZE
# Benchmark
WARMUP_ITERS = 20   # leading iterations excluded from the reported statistics
TIMED_ITERS  = 100  # trailing iterations whose timings are reported
# ----- #


# ----- #
# Setup `torch.distributed`
# LOCAL_RANK is read from the environment, as set by a launcher such as torchrun.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl",
                        # Explicit CUDA device. A bare integer ordinal's device
                        # type depends on the PyTorch version.
                        device_id=torch.device("cuda", local_rank),
                        timeout=timedelta(minutes=30))
rank, world_size = dist.get_rank(), dist.get_world_size()
# The all-to-all split lists are sized from the WORLD_SIZE constant, not from
# the launched job. Fail fast instead of benchmarking with mismatched splits.
if world_size != WORLD_SIZE:
    dist.destroy_process_group()
    raise RuntimeError(
        f"Launched with world_size={world_size}, but WORLD_SIZE={WORLD_SIZE} "
        "is configured; launch with a matching number of processes."
    )
# ----- #


# One split entry per rank. Presumably these are row counts sent to / received
# from each peer (as in torch.distributed.all_to_all_single); confirm against
# model.ops.all_to_all.
hp_split_size = BATCH_SIZE * NUM_TOKEN
hp_input_splits = [hp_split_size] * WORLD_SIZE
hp_output_splits = [hp_split_size] * WORLD_SIZE

# Batch size for HP: total tokens held by each rank after the all-to-all
# (equal to sum(hp_output_splits)). Renamed from the misspelled `bath_size`.
hp_batch_size = BATCH_SIZE * NUM_TOKEN * WORLD_SIZE

if rank == 0:
    print("HP Benchmark Configuration:")
    print(f"  BATCH_SIZE={BATCH_SIZE}, NUM_TOKEN={NUM_TOKEN}, WORLD_SIZE={WORLD_SIZE}")
    print(f"  NUM_HEAD={NUM_HEAD}, HEAD_SIZE={HEAD_SIZE}, NUM_HEAD_PER_RANK={NUM_HEAD_PER_RANK}")
    print(f"  hp_split_size={hp_split_size}, hp_batch_size={hp_batch_size}")
    print()

hp_latency_all = list()     # per-iteration all-to-all latency in ms (every rank)
bytes_to_gpu0_all = list()  # per-iteration bytes received by rank 0 (rank 0 only)

# Bytes landing on rank 0 per iteration: sum(hp_output_splits) rows of
# NUM_HEAD_PER_RANK * HEAD_SIZE elements. If the splits are per-peer, this
# presumably includes the chunk rank 0 keeps for itself. The value never
# changes between iterations, so it is computed once here.
bytes_to_gpu0 = (
    sum(hp_output_splits) * NUM_HEAD_PER_RANK * HEAD_SIZE * BYTES_PER_ELEM
)

# Only rank 0 draws a progress bar; otherwise each rank prints its own.
for _ in tqdm.tqdm(range(WARMUP_ITERS + TIMED_ITERS), disable=rank != 0):

    # Fresh input each iteration, allocated before the timed region starts.
    x = torch.randn(
        size=(WORLD_SIZE * BATCH_SIZE * NUM_TOKEN, NUM_HEAD_PER_RANK * HEAD_SIZE),
        dtype=torch.bfloat16,
        device="cuda",
    )

    # NOTE(review): synchronize() presumably drains outstanding GPU work
    # (and possibly barriers across ranks) so the wall-clock bracket below
    # covers only the all-to-all; confirm against utils.synchronize.
    synchronize()
    t_start_hp = time.perf_counter()

    # Do the HP all-to-all
    x, _ = all_to_all(input=x, input_splits=hp_input_splits, output_splits=hp_output_splits)

    synchronize()
    t_stop_hp = time.perf_counter()

    hp_latency_all.append((t_stop_hp - t_start_hp) * 1000)  # seconds -> ms

    if rank == 0:
        bytes_to_gpu0_all.append(bytes_to_gpu0)


# Statistics are reported from rank 0 only, using rank 0's own timings.
if rank == 0:
    # Discard the warm-up samples; keep the last TIMED_ITERS iterations.
    timed_latency = np.asarray(hp_latency_all)[-TIMED_ITERS:]
    # The median uses every timed sample; mean and std use the trimmed set.
    hp_latency_median = np.median(timed_latency)
    hp_latency_all = _trim_outliers(timed_latency)
    hp_latency_mean = hp_latency_all.mean()
    hp_latency_std = hp_latency_all.std(ddof=1)

    print("HP Benchmark Results:")
    for label, value in (
        ("hp_latency_median:", hp_latency_median),
        ("hp_latency_mean:", hp_latency_mean),
        ("hp_latency_std:", hp_latency_std),
    ):
        print(label, round(value, 2), "ms")

    # Received volume, converted from bytes to MiB. The printed labels say
    # "bytes_to_gpu0" but the values are in MiB.
    bytes_to_gpu0_all = np.asarray(bytes_to_gpu0_all)[-TIMED_ITERS:]
    mib_to_gpu0_all = bytes_to_gpu0_all / 1024 ** 2
    mib_to_gpu0_median = np.median(mib_to_gpu0_all)
    mib_to_gpu0_mean = mib_to_gpu0_all.mean()
    mib_to_gpu0_std = mib_to_gpu0_all.std(ddof=1)
    for label, value in (
        ("bytes_to_gpu0_median:", mib_to_gpu0_median),
        ("bytes_to_gpu0_mean:", mib_to_gpu0_mean),
        ("bytes_to_gpu0_std:", mib_to_gpu0_std),
    ):
        print(label, round(value, 2), "MiB")

# Final synchronize() before tearing down the process group. Presumably this
# keeps any rank from destroying the group while others still have work in
# flight; confirm against utils.synchronize.
synchronize()
dist.destroy_process_group()
