# CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 QT.py --master_port 29502 --tp_size 2 --batch_size 4
# CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 QT.py --master_port 29502 --tp_size 2 --batch_size 1
import os
import argparse
import time
import random
import torch
import torch.distributed as dist
import numpy as np

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

from Policy import QT_policy as customPolicy
from transformers import GPT2Config, GPT2LMHeadModel

def parse_arguments():
    """Parse the benchmark's command-line options.

    Returns:
        argparse.Namespace holding every option declared below, plus an extra
        ``device`` attribute fixed to ``torch.device("cuda")``.
    """
    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_specs = (
        ("--parallel", dict(default='TP', type=str, help='parallel mode: DP or TP')),
        ("--tp_size", dict(default=1, type=int)),
        ("--model_name", dict(default='GPT2-L', type=str, help='Example: GPT2-B, GPT2-M, ...')),
        ("--batch_size", dict(default=2, type=int, help="batch size")),
        ("--flash", dict(action="store_true", help="Enable flash mode")),
        ("--max_seqlength", dict(default=1024, type=int)),
        ("--seed", dict(default=42, type=int)),
        ("--warmup_steps", dict(default=2, type=int)),
        ("--measure_steps", dict(default=1, type=int)),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    parsed.device = torch.device("cuda")
    return parsed

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def create_model(args):
    """Build a freshly initialised GPT2LMHeadModel for a named size preset.

    Args:
        args: namespace providing ``model_name`` (one of the preset keys
            below) and ``max_seqlength`` (used as ``n_positions``).

    Returns:
        A new ``GPT2LMHeadModel`` built from the preset's configuration.

    Raises:
        ValueError: if ``args.model_name`` is not a known preset.
    """
    # preset name -> (n_layer, n_head, n_embd)
    presets = {
        'GPT2-B': (12, 12, 768),
        'GPT2-M': (24, 16, 1024),
        'GPT2-L': (36, 16, 1280),
        'GPT2-XL': (48, 24, 1584),
        'GPT2-2.5B': (54, 24, 1920),
        'GPT2-8.3B': (72, 24, 3072),
    }
    try:
        n_layer, n_head, n_embd = presets[args.model_name]
    except KeyError:
        raise ValueError(f"Unknown model_name {args.model_name}") from None

    # Residual and attention dropout are disabled. NOTE(review): embd_pdrop is
    # not set here, so it keeps GPT2Config's default; confirm that is intended.
    config = GPT2Config(
        n_positions=args.max_seqlength,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        activation_function='gelu_new',
        resid_pdrop=0.0,
        attn_pdrop=0.0,
    )
    return GPT2LMHeadModel(config)

def measure_random_time(model, optimizer, booster, args):
    """Time forward(+loss) and backward passes on random token batches.

    Runs ``args.warmup_steps`` unrecorded steps followed by
    ``args.measure_steps`` recorded steps. Every step draws a fresh batch of
    random token ids of shape ``(args.batch_size, args.max_seqlength)`` on
    ``args.device``, reuses the ids as labels, and times the forward pass
    (including the loss) and ``booster.backward`` separately. Each timed
    region is bracketed by ``torch.cuda.synchronize()`` so asynchronously
    queued GPU work is counted. Rank 0 prints the per-step time and a final
    summary (totals and averages over the recorded steps only).

    Communication accounting relies on module-level globals. Before each step
    the counters ``fwd_all_reduce_time/_calls`` and ``bwd_all_reduce_time/_calls``
    are reset, and ``current_phase`` is set to ``'fwd'`` / ``'bwd'`` during the
    two phases (``None`` otherwise). This function never increments the
    counters itself. NOTE(review): presumably an all-reduce hook defined
    outside this function reads ``current_phase`` and updates them; confirm.

    ``optimizer.step()`` / ``optimizer.zero_grad()`` are never called here;
    the optimizer is only handed to ``booster.backward``.

    Args:
        model: boosted model, called as
            ``model(input_ids=..., attention_mask=..., labels=...)``; the
            result must expose ``.loss``.
        optimizer: boosted optimizer passed to ``booster.backward``.
        booster: ColossalAI ``Booster`` used for the backward pass.
        args: parsed options (``batch_size``, ``max_seqlength``, ``device``,
            ``warmup_steps``, ``measure_steps``).
    """
    global current_phase
    global fwd_all_reduce_time, fwd_all_reduce_calls
    global bwd_all_reduce_time, bwd_all_reduce_calls

    model.train()

    # Upper bound for random token ids. NOTE(review): presumably matches the
    # GPT2Config default vocab size (create_model does not set vocab_size).
    vocab_size = 50257
    batch_size = args.batch_size
    seq_len = args.max_seqlength

    warmup_steps = args.warmup_steps
    measure_steps = args.measure_steps
    total_steps = warmup_steps + measure_steps
    is_rank0 = dist.get_rank() == 0

    # Accumulated over the recorded (post-warmup) steps only.
    total_op_time = 0.0
    total_comm_time = 0.0
    total_comm_calls = 0

    for step in range(total_steps):
        # Per-step communication counters (see docstring).
        fwd_all_reduce_calls = 0
        fwd_all_reduce_time = 0.0
        bwd_all_reduce_calls = 0
        bwd_all_reduce_time = 0.0

        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len),
                                  dtype=torch.long, device=args.device)
        attention_mask = torch.ones_like(input_ids, device=args.device)
        labels = input_ids.clone()

        # --- 1) Forward + loss
        torch.cuda.synchronize()  # keep earlier queued GPU work out of the timing
        t0 = time.perf_counter()
        current_phase = 'fwd'
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss
        torch.cuda.synchronize()  # wait for the forward kernels to finish
        fwd_time = time.perf_counter() - t0
        current_phase = None

        # --- 2) Backward
        # The device is already idle after the synchronize above, so no
        # second synchronize is needed before starting this clock.
        t2 = time.perf_counter()
        current_phase = 'bwd'
        booster.backward(loss, optimizer)
        torch.cuda.synchronize()  # wait for the backward kernels to finish
        bwd_time = time.perf_counter() - t2
        current_phase = None

        step_op = fwd_time + bwd_time
        if is_rank0:
            print(f"[Step {step}] Fwd+Bwd time = {step_op:.6f} sec")

        if step >= warmup_steps:
            total_op_time += step_op
            total_comm_time += fwd_all_reduce_time + bwd_all_reduce_time
            # Count calls over all recorded steps, consistent with the time
            # totals; this is also defined when no step ran at all.
            total_comm_calls += fwd_all_reduce_calls + bwd_all_reduce_calls

    if is_rank0:
        measured = measure_steps
        avg_op_time = total_op_time / measured if measured > 0 else 0.0
        avg_comm_time = total_comm_time / measured if measured > 0 else 0.0
        print("\n==== All Steps Done ====")
        print(f"[Forward+Backward] total time: {total_op_time:.6f} sec over {measured} steps, average: {avg_op_time:.6f}")
        print(f"[All-Reduce] total time: {total_comm_time:.6f} sec over {measured} steps, average: {avg_comm_time:.6f} "
              f"(calls: {total_comm_calls})")


def main():
    """Build a GPT-2 model, shard it with HybridParallelPlugin and benchmark it.

    Meant to be launched once per rank through ``colossalai run``, as in the
    example commands at the top of this file.

    Raises:
        ValueError: if ``--parallel`` is anything other than ``'TP'``.
    """
    args = parse_arguments()

    colossalai.launch_from_torch()
    # Seed after launching. NOTE(review): ColossalAI's launch helpers appear to
    # apply their own default seed, which would override a seed set earlier;
    # seeding here makes --seed take effect either way. All ranks use the same
    # seed, so they generate identical weights and random inputs.
    set_seed(args.seed)
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))  # not read elsewhere in the code shown

    if args.parallel == 'TP':
        plugin = HybridParallelPlugin(
            tp_size=args.tp_size,
            pp_size=1,
            precision='bf16',
            max_norm=1.0,
            zero_stage=0,
            custom_policy=customPolicy.GPT2LMHeadModelPolicy(),
            enable_flash_attention=args.flash
        )
    else:
        raise ValueError(f"Unknown parallel mode {args.parallel}")

    model = create_model(args)
    # Count before boosting: after tensor-parallel sharding, model.parameters()
    # presumably yields only this rank's local shards, not the full model.
    param_count = sum(p.numel() for p in model.parameters())
    # Explicit lr because older torch releases have no default for it.
    # optimizer.step() is never called by this benchmark.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    booster = Booster(plugin=plugin)

    model, optimizer, _, _, _ = booster.boost(
        model=model,
        optimizer=optimizer,
        criterion=None,
        dataloader=None,
        lr_scheduler=None
    )
    model = model.to(args.device)

    if dist.get_rank() == 0:
        print(f"Model Parameter Count: {param_count/1e6:.2f}M")

    measure_random_time(model, optimizer, booster, args)

if __name__ == "__main__":
    main()
