import os
import argparse
import time
import random
import torch
import torch.distributed as dist
import numpy as np

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

from transformers import GPT2Config, GPT2LMHeadModel


# Accumulated dist.all_reduce latency (in seconds) and number of calls, kept
# separately for the forward and backward phases. timed_all_reduce updates
# whichever pair matches the phase that is active when a call happens.
fwd_all_reduce_time = 0.0
fwd_all_reduce_calls = 0
bwd_all_reduce_time = 0.0
bwd_all_reduce_calls = 0

# Phase tag read by timed_all_reduce: 'fwd', 'bwd', or None. With any other
# value the call is still timed but not added to either pair of counters.
current_phase = None


# Handle to the unpatched collective, captured at import time so that
# timed_all_reduce can still delegate to it after main() rebinds dist.all_reduce.
_original_all_reduce = dist.all_reduce

def timed_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
    """
    Drop-in replacement for dist.all_reduce that times every invocation.

    The call is bracketed by two CUDA events and device-wide synchronizations.
    The measured latency (converted to seconds) and a call count are added to
    the forward or backward module-level counters according to the value of
    ``current_phase``; with any other phase value nothing is accumulated.

    Arguments are forwarded unchanged to the original all_reduce, and its
    return value is handed back to the caller. If a work handle is returned it
    has already been waited on by the time this function returns.
    """
    global fwd_all_reduce_time, fwd_all_reduce_calls
    global bwd_all_reduce_time, bwd_all_reduce_calls

    tick, tock = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    # Synchronize first so that work queued earlier is finished before the
    # start event is recorded.
    torch.cuda.synchronize()
    tick.record()
    handle = _original_all_reduce(tensor, op=op, group=group, async_op=async_op)
    if handle is not None:
        # A handle was returned: block on it so the end event is recorded
        # only after the wait has returned.
        handle.wait()
    tock.record()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; the counters are kept in seconds.
    seconds = tick.elapsed_time(tock) / 1000.0

    phase = current_phase
    if phase == 'fwd':
        fwd_all_reduce_calls += 1
        fwd_all_reduce_time += seconds
    elif phase == 'bwd':
        bwd_all_reduce_calls += 1
        bwd_all_reduce_time += seconds

    return handle


def parse_arguments():
    """
    Parse the benchmark's command-line options.

    Returns the argparse namespace with one extra attribute, ``device``,
    set to ``torch.device("cuda")``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--parallel", dict(default='TP', type=str, help='parallel mode: DP or TP')),
        ("--tp_size", dict(default=1, type=int)),
        ("--model_name", dict(default='GPT2-XL', type=str, help='Example: GPT2-B, GPT2-M, ...')),
        ("--batch_size", dict(default=2, type=int, help="batch size")),
        ("--flash", dict(action="store_true", help="Enable flash mode")),
        ("--max_seqlength", dict(default=1024, type=int)),
        ("--seed", dict(default=42, type=int)),
        ("--warmup_steps", dict(default=5, type=int)),
        ("--measure_steps", dict(default=2, type=int)),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    parsed.device = torch.device("cuda")
    return parsed

def set_seed(seed):
    """Seed torch, NumPy and Python's ``random`` module with the same value."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

def create_model(args):
    """
    Build a GPT2LMHeadModel from one of the named size presets.

    Reads ``args.model_name`` to pick the preset (layer count, head count,
    hidden width) and ``args.max_seqlength`` for ``n_positions``.

    Raises:
        ValueError: if ``args.model_name`` is not a known preset.
    """
    presets = {
        'GPT2-B':  {'num_layer': 12, 'num_head': 12, 'hidden_dim': 768},
        'GPT2-M':  {'num_layer': 24, 'num_head': 16, 'hidden_dim': 1024},
        'GPT2-L':  {'num_layer': 36, 'num_head': 16, 'hidden_dim': 1280},
        'GPT2-XL': {'num_layer': 48, 'num_head': 24, 'hidden_dim': 1584},
        'GPT2-2.5B': {'num_layer': 54, 'num_head': 24, 'hidden_dim': 1920},
        'GPT2-8.3B': {'num_layer': 72, 'num_head': 24, 'hidden_dim': 3072},
    }

    name = args.model_name
    preset = presets.get(name)
    if preset is None:
        raise ValueError(f"Unknown model_name {name}")

    # Both dropout probabilities set here are zero.
    return GPT2LMHeadModel(
        GPT2Config(
            n_positions=args.max_seqlength,
            n_embd=preset['hidden_dim'],
            n_layer=preset['num_layer'],
            n_head=preset['num_head'],
            activation_function='gelu_new',
            resid_pdrop=0.0,
            attn_pdrop=0.0,
        )
    )


def measure_random_time(model, optimizer, booster, args):
    """
    Benchmark the no-grad forward pass on random token batches.

    Runs ``args.warmup_steps + args.measure_steps`` iterations. Each iteration
    builds a fresh random batch of shape (args.batch_size, args.max_seqlength),
    calls ``model`` under ``torch.no_grad()`` with ``input_ids``,
    ``attention_mask`` and ``labels``, and prints the wall-clock forward time
    on rank 0. Only iterations after the warm-up are accumulated. At the end
    rank 0 prints the total and per-step average of the forward time and of
    the forward all_reduce communication time taken from the module-level
    ``fwd_all_reduce_*`` counters.

    Args:
        model: callable accepting ``input_ids`` / ``attention_mask`` /
            ``labels`` keyword arguments; switched to eval mode here.
        optimizer: unused; kept so existing callers keep working.
        booster: unused; kept so existing callers keep working.
        args: namespace providing ``batch_size``, ``max_seqlength``,
            ``warmup_steps``, ``measure_steps`` and ``device``.
    """
    global current_phase, fwd_all_reduce_calls, fwd_all_reduce_time
    model.eval()

    # Upper bound for the random token ids (50257 is the GPT2Config default
    # vocabulary size); it must not exceed the model's embedding table.
    vocab_size = 50257
    batch_size = args.batch_size
    seq_len = args.max_seqlength

    warmup_steps = args.warmup_steps
    measure_steps = args.measure_steps
    total_steps = warmup_steps + measure_steps

    total_fwd_comm = 0.0
    total_pure_fwd_time = 0.0

    for step in range(total_steps):
        # The global counters are reset every step so that each step's
        # communication time can be read back individually below.
        fwd_all_reduce_calls = 0
        fwd_all_reduce_time = 0.0

        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len),
                                  dtype=torch.long, device=args.device)
        # ones_like already inherits dtype and device from input_ids.
        attention_mask = torch.ones_like(input_ids)
        labels = input_ids.clone()

        with torch.no_grad():
            # Finish pending GPU work so it is not billed to this step.
            torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the system clock is adjusted.
            t0 = time.perf_counter()
            current_phase = 'fwd'
            try:
                model(input_ids=input_ids,
                      attention_mask=attention_mask,
                      labels=labels)
                torch.cuda.synchronize()
                t1 = time.perf_counter()
            finally:
                # Always clear the phase tag, even if the forward pass raised,
                # so later all_reduce calls are not attributed to 'fwd'.
                current_phase = None

        pure_fwd_time = t1 - t0

        if dist.get_rank() == 0:
            print(f"[Step {step} / Pure-Fwd] time = {pure_fwd_time:.6f} sec")

        if step >= warmup_steps:
            total_fwd_comm += fwd_all_reduce_time
            total_pure_fwd_time += pure_fwd_time

    if dist.get_rank() == 0:
        measured = measure_steps
        avg_fwd_comm = total_fwd_comm / measured if measured > 0 else 0.0
        avg_pure_fwd = total_pure_fwd_time / measured if measured > 0 else 0.0
        print("\n==== All Steps Done ====")
        # NOTE: fwd_all_reduce_calls is reset at the top of every step, so the
        # value printed here is the call count of the last executed step only.
        print(f"[Forward all_reduce] total time: {total_fwd_comm:.6f} sec over {measured} steps → "
              f"avg {avg_fwd_comm:.6f} sec/step (calls: {fwd_all_reduce_calls})")
        print(f"[Pure Forward] total time: {total_pure_fwd_time:.6f} sec over {measured} steps → "
              f"avg {avg_pure_fwd:.6f} sec/step")

def main():
    """
    Script entry point.

    Parses the CLI, seeds the RNGs, initialises ColossalAI from the torch
    launcher environment, installs the timing wrapper over dist.all_reduce,
    builds and boosts the model with a tensor-parallel HybridParallelPlugin,
    and finally runs the forward-pass benchmark.

    Raises:
        ValueError: if ``--parallel`` is anything other than 'TP'.
    """
    args = parse_arguments()
    set_seed(args.seed)

    colossalai.launch_from_torch()
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))

    # From here on every dist.all_reduce lookup resolves to the timing wrapper.
    dist.all_reduce = timed_all_reduce

    if args.parallel != 'TP':
        raise ValueError(f"Unknown parallel mode {args.parallel}")

    plugin_options = dict(
        tp_size=args.tp_size,
        pp_size=1,
        precision='bf16',
        max_norm=1.0,
        zero_stage=0,
        enable_flash_attention=args.flash,
    )
    plugin = HybridParallelPlugin(**plugin_options)

    model = create_model(args)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    booster = Booster(plugin=plugin)

    # boost() is unpacked as five values; only model and optimizer are used.
    model, optimizer, _criterion, _dataloader, _scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        criterion=None,
        dataloader=None,
        lr_scheduler=None,
    )
    model = model.to(args.device)

    if dist.get_rank() == 0:
        param_count = 0
        for param in model.parameters():
            param_count += param.numel()
        print(f"Model Parameter Count: {param_count/1e6:.2f}M")

    measure_random_time(model, optimizer, booster, args)

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
