# colossalai run --nproc_per_node 1 Allreduce_GPT2.py --model_name "GPT2-L" --master_port 29502 --tp_size 1 --batch_size 4
import os
import argparse
import time
import random
import torch
import torch.distributed as dist
import numpy as np

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

from transformers import GPT2Config, GPT2LMHeadModel

# Per-phase all-reduce bookkeeping, updated by timed_all_reduce:
# accumulated latency in seconds and number of calls, kept separately for
# the forward and the backward phase.
fwd_all_reduce_time, fwd_all_reduce_calls = 0.0, 0
bwd_all_reduce_time, bwd_all_reduce_calls = 0.0, 0

# Which phase is currently running: 'fwd', 'bwd', or None. Calls made while
# this is None are timed but not attributed to either phase.
current_phase = None

# Handle to the unpatched collective; the timing wrapper delegates to it.
_original_all_reduce = dist.all_reduce

def timed_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
    """
    Drop-in wrapper around the original ``dist.all_reduce`` that times each call.

    The call is bracketed by a pair of CUDA events with a device-wide
    synchronize before the first and after the second one. If the underlying
    call returns a work handle, it is waited on before the closing event is
    recorded, so asynchronous calls are effectively completed inside the
    wrapper. The elapsed time (in seconds) and a call count are added to the
    forward or backward counters depending on ``current_phase``; calls made
    while the phase is neither 'fwd' nor 'bwd' are not counted.

    Returns whatever the original ``all_reduce`` returned.
    """
    global fwd_all_reduce_time, fwd_all_reduce_calls
    global bwd_all_reduce_time, bwd_all_reduce_calls
    global current_phase

    tic, toc = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    # Drain pending GPU work first so it is not billed to the collective.
    torch.cuda.synchronize()
    tic.record()
    work = _original_all_reduce(tensor, op=op, group=group, async_op=async_op)
    if work is not None:
        work.wait()
    toc.record()
    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize()

    # elapsed_time reports milliseconds; the counters are kept in seconds.
    seconds = tic.elapsed_time(toc) / 1000.0

    phase = current_phase
    if phase == 'fwd':
        fwd_all_reduce_time, fwd_all_reduce_calls = (
            fwd_all_reduce_time + seconds,
            fwd_all_reduce_calls + 1,
        )
    elif phase == 'bwd':
        bwd_all_reduce_time, bwd_all_reduce_calls = (
            bwd_all_reduce_time + seconds,
            bwd_all_reduce_calls + 1,
        )

    return work

def parse_arguments():
    """
    Parse the benchmark's command-line options.

    Returns the argparse namespace with one extra attribute, ``device``,
    set to ``torch.device("cuda")``.
    """
    # (flag, add_argument keyword arguments) in registration order.
    option_table = (
        ("--parallel", dict(default='TP', type=str, help='parallel mode: DP or TP')),
        ("--tp_size", dict(default=1, type=int)),
        ("--model_name", dict(default='GPT2-XL', type=str, help='Example: GPT2-B, GPT2-M, ...')),
        ("--batch_size", dict(default=2, type=int, help="batch size")),
        ("--flash", dict(action="store_true", help="Enable flash mode")),
        ("--max_seqlength", dict(default=1024, type=int)),
        ("--seed", dict(default=42, type=int)),
        ("--warmup_steps", dict(default=5, type=int)),
        ("--measure_steps", dict(default=2, type=int)),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    namespace = cli.parse_args()
    namespace.device = torch.device("cuda")
    return namespace

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch global random generators with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def create_model(args):
    """
    Build a GPT2LMHeadModel for the preset named by ``args.model_name``.

    ``args.max_seqlength`` is used as the number of positions. Dropout on
    the residual and attention paths is set to 0.0 and the activation is
    'gelu_new'.

    Raises:
        ValueError: if ``args.model_name`` is not one of the known presets.
    """
    # preset name -> (number of layers, number of heads, hidden size)
    presets = {
        'GPT2-B': (12, 12, 768),
        'GPT2-M': (24, 16, 1024),
        'GPT2-L': (36, 16, 1280),
        'GPT2-XL': (48, 24, 1584),
        'GPT2-2.5B': (54, 24, 1920),
        'GPT2-8.3B': (72, 24, 3072),
    }

    preset = presets.get(args.model_name)
    if preset is None:
        raise ValueError(f"Unknown model_name {args.model_name}")

    depth, heads, width = preset
    return GPT2LMHeadModel(
        GPT2Config(
            n_positions=args.max_seqlength,
            n_embd=width,
            n_layer=depth,
            n_head=heads,
            activation_function='gelu_new',
            resid_pdrop=0.0,
            attn_pdrop=0.0,
        )
    )

def measure_random_time(model, optimizer, booster, args):
    """
    Time forward+loss and backward on random batches and print the results.

    - Only Forward+loss and Backward are timed; no optimizer step is run here.
    - Communication time for Forward/Backward all_reduce is collected
      separately through the module-level counters, which are reset at the
      start of every step. ``current_phase`` is set to 'fwd' / 'bwd' around
      each phase so the patched ``dist.all_reduce`` can attribute its time.
    - After ``args.warmup_steps`` warmup iterations, the next
      ``args.measure_steps`` iterations are accumulated. Rank 0 prints one
      line per step and a final summary: total/average op time,
      total/average all-reduce time, and the number of all-reduce calls made
      during the measured steps.

    Args:
        model: callable taking ``input_ids``, ``attention_mask`` and
            ``labels`` and returning an object with a ``loss`` attribute.
        optimizer: only forwarded to ``booster.backward``.
        booster: object exposing ``backward(loss, optimizer)``.
        args: namespace providing ``batch_size``, ``max_seqlength``,
            ``warmup_steps``, ``measure_steps`` and ``device``.
    """
    global current_phase
    global fwd_all_reduce_time, fwd_all_reduce_calls
    global bwd_all_reduce_time, bwd_all_reduce_calls

    model.train()

    # Exclusive upper bound for the random token ids.
    # NOTE(review): presumably the model's vocabulary size - confirm.
    vocab_size = 50257
    batch_size = args.batch_size
    seq_len = args.max_seqlength

    warmup_steps = args.warmup_steps
    measure_steps = args.measure_steps
    total_steps = warmup_steps + measure_steps

    total_op_time = 0.0
    total_comm_time = 0.0
    # The global call counters are reset every step, so the number of calls
    # in the measured window has to be accumulated here.
    total_comm_calls = 0

    for step in range(total_steps):
        fwd_all_reduce_calls = 0
        fwd_all_reduce_time = 0.0
        bwd_all_reduce_calls = 0
        bwd_all_reduce_time = 0.0

        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len),
                                  dtype=torch.long, device=args.device)
        attention_mask = torch.ones_like(input_ids, device=args.device)
        labels = input_ids.clone()

        # --- 1) Train Forward + loss
        # Synchronize on both sides so the wall-clock interval covers the
        # GPU work of this phase only. perf_counter is monotonic, unlike
        # time.time, so clock adjustments cannot distort the interval.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        current_phase = 'fwd'
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        current_phase = None

        # --- 2) Backward
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        current_phase = 'bwd'
        booster.backward(loss, optimizer)
        torch.cuda.synchronize()
        t3 = time.perf_counter()
        current_phase = None

        step_op = (t1 - t0) + (t3 - t2)
        if dist.get_rank() == 0:
            print(f"[Step {step}] Fwd+Bwd time = {step_op:.6f} sec")

        # Warmup iterations are printed above but excluded from the totals.
        if step >= warmup_steps:
            total_op_time += step_op
            total_comm_time += (fwd_all_reduce_time + bwd_all_reduce_time)
            total_comm_calls += (fwd_all_reduce_calls + bwd_all_reduce_calls)

    if dist.get_rank() == 0:
        measured = measure_steps
        avg_op_time = total_op_time / measured if measured > 0 else 0.0
        avg_comm_time = total_comm_time / measured if measured > 0 else 0.0
        print("\n==== All Steps Done ====")
        print(f"[Forward+Backward] total time: {total_op_time:.6f} sec over {measured} steps, average: {avg_op_time:.6f}")
        print(f"[All-Reduce] total time: {total_comm_time:.6f} sec over {measured} steps, average: {avg_comm_time:.6f} "
              f"(calls: {total_comm_calls})")

def main():
    """
    Entry point: parse options, initialise the distributed runtime, build and
    boost the GPT-2 model, then run the forward/backward timing loop.

    ``dist.all_reduce`` is replaced by ``timed_all_reduce`` before the plugin
    and model are created. Only ``--parallel TP`` is accepted; any other
    value raises ``ValueError``.
    """
    args = parse_arguments()
    set_seed(args.seed)

    colossalai.launch_from_torch()
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))

    # From here on, code that looks up dist.all_reduce gets the timing wrapper.
    dist.all_reduce = timed_all_reduce

    if args.parallel != 'TP':
        raise ValueError(f"Unknown parallel mode {args.parallel}")

    plugin_options = dict(
        tp_size=args.tp_size,
        pp_size=1,
        precision='bf16',
        max_norm=1.0,
        zero_stage=0,
        enable_flash_attention=args.flash,
    )
    plugin = HybridParallelPlugin(**plugin_options)

    net = create_model(args)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3)

    booster = Booster(plugin=plugin)

    # Only the wrapped model and optimizer are needed; the other three
    # returned slots are discarded.
    net, sgd, _criterion, _loader, _scheduler = booster.boost(
        model=net,
        optimizer=sgd,
        criterion=None,
        dataloader=None,
        lr_scheduler=None,
    )
    net = net.to(args.device)

    if dist.get_rank() == 0:
        n_params = sum(weight.numel() for weight in net.parameters())
        print(f"Model Parameter Count: {n_params/1e6:.2f}M")

    measure_random_time(net, sgd, booster, args)


# Run the benchmark only when this file is executed as a script (for example
# through the launcher command shown at the top of the file), not on import.
if __name__ == "__main__":
    main()
