
# Modified from: https://github.com/KellerJordan/modded-nanogpt/blob/master/records/101724_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt
import os
import sys
import uuid
import time
import wandb
import torch
import torch.distributed as dist
import torch._inductor.config as config

from argparse import ArgumentParser
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from utils.gpt import GPT, GPTConfig
from optimizers.scion import Scion
from utils.seeder import seed_everything
from utils.dataloader import DistributedDataLoader
from utils.configs import load_config

# Snapshot this script's own source first, so the exact code that ran can be
# written into the run log later on.
with open(sys.argv[0]) as f:
    code = f.read()

# -----------------------------------------------------------------------------
# Command-line interface. Each row is (flag, keyword arguments for add_argument).
parser = ArgumentParser()
for flag, options in (
    ("--config", {"type": str, "required": True, "help": "Path to YAML config"}),
    ("--ckpt_in", {"type": str, "default": None, "help": "Checkpoint to load"}),
    ("--ckpt_out", {"type": str, "required": True, "help": "Checkpoint to save after phase"}),
    ("--save_step", {"type": int, "required": True, "help": "Checkpoint to save everytime"}),
):
    parser.add_argument(flag, **options)
cli_args = parser.parse_args()
# Run hyperparameters come from the config file named by --config.
args = load_config(cli_args.config)

# Distributed setup. RANK, LOCAL_RANK and WORLD_SIZE are read from the
# environment, which the launcher (e.g. torchrun) is expected to populate.
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank, ddp_local_rank, ddp_world_size = (
    int(os.environ[name]) for name in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
)

# Seed the RNGs. The rank is handed to seed_everything together with the base
# seed, which defaults to 42 when the config has no `seed` attribute.
BASE_SEED = getattr(args, "seed", 42)
seed_everything(BASE_SEED, ddp_rank)

# Bind this process to the GPU that matches its local rank.
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = ddp_rank == 0  # rank 0 does the logging and checkpointing

if master_process:
    # Dump the loaded configuration once, framed by banner lines.
    print("======== Arguments ========", args, "===========================", sep="\n")
    
# Per-device micro-batch shape handed to the data loaders: B and T.
B = args.device_batch_size
T = args.sequence_length
# Number of validation batches each rank runs per evaluation. Floor division:
# divisibility of val_tokens is not enforced, so any remainder is not evaluated.
val_steps = args.val_tokens // (B * T * ddp_world_size)
# Gradient accumulation: the global batch must divide evenly into the
# B * world_size sequences handled by one forward/backward pass over all ranks.
sequences_per_pass = B * ddp_world_size
assert args.batch_size % sequences_per_pass == 0
train_accumulation_steps = args.batch_size // sequences_per_pass
# Tokens consumed by one optimizer step (global batch size times sequence length).
tokens_per_step = args.batch_size * args.sequence_length

# Token streams for training and validation, sharded by rank / world size.
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
if master_process:
    n_train_files = len(train_loader.files)
    n_val_files = len(val_loader.files)
    print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {n_train_files} files")
    print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {n_val_files} files")
# Pre-fetch the first training batch; the training loop always works one batch ahead.
x, y = train_loader.next_batch()

# Model construction.
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
# Depth, head count and embedding width come from the loaded config.
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd))
model = model.cuda()
# Enable inductor's coordinate-descent tuning only when this torch build exposes the flag.
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
# The wrapped module stays reachable as `model.module`; the optimizer setup and
# checkpoint code below go through it.
# Forward passes run under bfloat16 autocast on CUDA.
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# Default for a fresh run; overwritten when a checkpoint is loaded.
start_step = 1

# init the optimizer(s)
# Group for the transformer blocks (`transformer.h`): 'Spectral' norm with steps=5.
# NOTE(review): this group takes `scale_embed` while the lm_head group takes
# `scale_matrix`, and the optimizer lr is `lr_embed` — confirm the naming is intended.
block_group = {
    'params': model.module.transformer.h.parameters(),
    'norm': 'Spectral',
    'norm_kwargs': {'steps': 5},
    'scale': args.scale_embed,
    'weight_decay': args.weight_decay,
}
# Group for the output head (`lm_head`): 'Sign' norm, no extra norm arguments.
head_group = {
    'params': model.module.lm_head.parameters(),
    'norm': 'Sign',
    'norm_kwargs': {},
    'scale': args.scale_matrix,
    'weight_decay': args.weight_decay,
}
optim_groups = [block_group, head_group]
optimizer1 = Scion(
    optim_groups,
    lr=args.lr_embed,
    momentum=args.momentum,
    unconstrained=args.unconstrained,
)
# Kept as a list so the scheduler / checkpoint code can iterate over optimizers.
optimizers = [optimizer1]

# learning rate decay scheduler (linear warmup, constant plateau, linear warmdown)
def get_lr(it):
    """Return the learning-rate multiplier for scheduler step ``it``.

    The schedule has three phases, all driven by the module-level ``args``:

    1. linear warmup over the first ``args.warmup_iters`` steps (skipped when
       ``warmup_iters`` is 0),
    2. a constant multiplier of 1.0,
    3. linear warmdown over the last ``args.warmdown_iters`` steps (skipped
       when ``warmdown_iters`` is 0, in which case the multiplier stays 1.0).

    The result is floored at ``1e-8 / args.lr_embed`` so the effective
    learning rate never drops below 1e-8.

    Args:
        it: Zero-based scheduler step; must not exceed ``args.num_iterations``.

    Returns:
        The multiplier applied to the optimizer's base learning rate.
    """
    assert it <= args.num_iterations
    warmup_iters = args.warmup_iters
    warmdown_iters = args.warmdown_iters
    # 1) linear warmup
    if warmup_iters > 0 and it < warmup_iters:
        ratio = (it + 1) / warmup_iters
    # 2) constant -- also covers the "no warmdown" configuration, which would
    #    otherwise divide by zero once `it` reaches num_iterations
    elif warmdown_iters <= 0 or it < args.num_iterations - warmdown_iters:
        ratio = 1.0
    # 3) linear warmdown
    else:
        ratio = (args.num_iterations - it) / warmdown_iters
    # clamp ratio so final LR = 1e-8
    min_ratio = 1e-8 / args.lr_embed
    return max(ratio, min_ratio)
# One LambdaLR per optimizer; get_lr supplies the multiplier on the base LR.
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# Optionally resume from a checkpoint: model weights, optimizer state,
# scheduler state and the position in the training token stream.
seen_tokens = 0
resumed = cli_args.ckpt_in is not None
if resumed:
    # Checkpoints are written by the master process, so their CUDA tensors are
    # presumably tagged "cuda:0". Remap them onto this process's own GPU. This
    # has to be the *local* device: the global rank is not a valid device index
    # on any node but the first in a multi-node run.
    map_location = {"cuda:0": device}
    ckpt = torch.load(cli_args.ckpt_in, map_location=map_location)
    model.module.load_state_dict(ckpt["model"], strict=True)
    # ---- optimizers ----
    for opt, opt_state in zip(optimizers, ckpt["optimizers"]):
        opt.load_state_dict(opt_state)
    # ---- schedulers ----
    for sch, sch_state in zip(schedulers, ckpt["schedulers"]):
        sch.load_state_dict(sch_state)
    # ---- dataloaders ----
    seen_tokens = ckpt["tokens"] # * ddp_world_size * train_accumulation_steps # TODO: This had to be fixed
    train_loader.skip_tokens(seen_tokens)
    # NOTE(review): `x, y` were pre-fetched before skip_tokens() ran, so the
    # first micro-batch after a resume is still the batch read before the skip.
    # Confirm whether a re-fetch is wanted here.
    # ---- step counter (no RNG state is stored in or restored from the checkpoint) ----
    start_step = ckpt["step"] + 1
    if master_process:
        print(f"Loaded checkpoint from {cli_args.ckpt_in}")

# First index of the training loop. A fresh run starts at 0 (start_step == 1).
# A resumed run must start at the step *after* the checkpointed one: that step
# was already trained, and the restored scheduler has already advanced past it.
# Re-running it would push the scheduler beyond num_iterations on the final
# step and trip get_lr's assert.
global_step = start_step if resumed else start_step - 1

# begin logging (master process only)
if master_process:
    import subprocess

    run_id = str(uuid.uuid4())
    wandb.init(project=args.project, name=args.run, config=vars(args))
    # Per-run directory plus a text log named after the same random run id.
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    separator = '=' * 100 + '\n'
    # create the log file
    with open(logfile, "w") as f:
        # The log opens with this script's own source, framed by separator lines.
        f.write(separator + code + separator)
        # Then record the software/hardware environment: the torch and CUDA
        # versions, followed by the captured stdout of `nvidia-smi`.
        f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
        result = subprocess.run(
            ['nvidia-smi'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        f.write(f'{result.stdout}\n')
        f.write(separator)

# ----------------------------
# Main loop: periodic validation, gradient-accumulated training steps,
# per-step logging and periodic checkpointing, then a final log and teardown.
from contextlib import nullcontext

# start global timer (optional, total training)
torch.cuda.synchronize()
global_start_time = time.time()
train_iter_start = time.time()  # re-armed at the start of every training iteration

# Tokens consumed per optimizer step across all ranks and micro-batches
# (loop-invariant, so computed once).
batch_tokens = B * T * ddp_world_size * train_accumulation_steps

# Defaults so the final log below is well-defined even when no training
# iteration executes (e.g. num_iterations == 0, or resuming a finished run).
train_loss = None
val_loss = None
lr = None
tokens_per_sec = None
opt_logs = {}


def _as_float(value):
    """Return ``value.item()`` for a tensor, or None when no value was recorded."""
    return None if value is None else value.item()


for step in tqdm(range(global_step, args.num_iterations + 1)):
    last_step = (step == args.num_iterations)

    # --------------- EVALUATION -----------------
    val_loss = None
    eval_due = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
    # Skip evaluation when there are no validation batches: the accumulator
    # would stay a Python float (all_reduce needs a tensor) and the mean would
    # divide by zero.
    if eval_due and val_steps > 0:
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        with torch.no_grad():
            for _ in range(val_steps):
                x_val, y_val = val_loader.next_batch()
                with ctx:
                    _, loss = model(x_val, y_val, return_logits=False)
                    val_loss += loss.detach()
                    del loss
        # Average across ranks, then across this rank's validation batches.
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # validation is excluded from the training timer (it is re-armed below)

    if last_step:
        # The final index only evaluates; no optimizer step is taken.
        break

    # --------------- TRAINING SECTION -----------------
    train_iter_start = time.time()
    model.train()

    for i in range(1, train_accumulation_steps + 1):
        # Only the final micro-step needs the gradient all-reduce. DDP requires
        # the *forward* pass to run inside no_sync() as well; otherwise the
        # gradients are still synchronized on every micro-step.
        is_final_micro_step = (i == train_accumulation_steps)
        sync_ctx = nullcontext() if is_final_micro_step else model.no_sync()
        with sync_ctx:
            with ctx:
                _, loss = model(x, y, return_logits=False)
                # Logged train loss is the loss of the last micro-batch only.
                train_loss = loss.detach()
            # advance the dataset for the next micro-batch
            x, y = train_loader.next_batch()
            loss.backward()
    # Gradients were summed over the micro-batches; turn the sum into a mean.
    for p in model.parameters():
        if p.grad is not None:  # parameters that received no gradient are skipped
            p.grad /= train_accumulation_steps
    for opt, sched in zip(optimizers, schedulers):
        # Tolerate an optimizer whose step() returns nothing to log.
        opt_logs = opt.step() or {}
        sched.step()
    model.zero_grad(set_to_none=True)

    seen_tokens += tokens_per_step

    # ----------------- TRAIN TIMING -----------------
    torch.cuda.synchronize()  # wait for queued GPU work before reading the clock
    train_iter_end = time.time()
    iter_elapsed_sec = train_iter_end - train_iter_start
    tokens_per_sec = batch_tokens / iter_elapsed_sec

    # ----------------- LOGGING -----------------
    if master_process:
        lr = sched.get_last_lr()[0]
        wandb.log({
            "train_loss": _as_float(train_loss),
            # `is not None` rather than truthiness: a loss tensor equal to 0
            # must still be logged.
            "val_loss": _as_float(val_loss),
            "learning_rate": lr,
            "tokens/sec": tokens_per_sec,
            **opt_logs
        })

    # ----------------- SAVING -----------------
    # save_step <= 0 disables periodic checkpoints instead of raising on `% 0`.
    if master_process and cli_args.save_step > 0 and step % cli_args.save_step == 0:
        ckpt = {
            "model": model.module.state_dict(),
            "optimizers": [opt.state_dict() for opt in optimizers],
            "schedulers": [sch.state_dict() for sch in schedulers],
            "step": step,
            "tokens": seen_tokens,
            "config": vars(args),
        }
        torch.save(
            ckpt,
            f"{cli_args.ckpt_out}_step{step}.pt",
        )
        print(f"Checkpoint saved to {cli_args.ckpt_out} at step {step}")

if master_process:
    # Final record: the last training metrics together with the final validation loss.
    wandb.log({
        "train_loss": _as_float(train_loss),
        "val_loss": _as_float(val_loss),
        "learning_rate": lr,
        "tokens/sec": tokens_per_sec,
        **opt_logs
    })
    wandb.finish()
    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()