# =========================
# GPT training with curriculum-defined iterations
# =========================
import os
import gc
import random
import numpy as np
import torch
import torch.distributed as dist
import wandb

from argparse import ArgumentParser
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from utils.gpt import GPT, GPTConfig
from optimizers.scion import Scion
from utils.dataloader import DistributedDataLoader
from utils.configs import load_config
from utils.seeder import seed_everything

# -------------------------------------------------------------------------
# Checkpoint helpers
def save_checkpoint(path, model, optimizer, scheduler, step, seen_tokens, curriculum):
    """Serialize the state needed to resume training.

    Args:
        path: Destination for ``torch.save``. A filesystem path (``str`` or
            ``os.PathLike``) is written atomically via a sibling ``.tmp``
            file; any other object (e.g. an open binary file) is handed to
            ``torch.save`` unchanged.
        model: Wrapper exposing the network as ``model.module``; only
            ``model.module.state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer"``.
        scheduler: Object whose ``state_dict()`` is stored under ``"scheduler"``.
        step: Step counter stored under ``"step"``.
        seen_tokens: Token counter stored under ``"seen_tokens"``.
        curriculum: Object whose ``phase_idx`` attribute is stored.

    Raises:
        Any exception raised by ``torch.save`` or ``os.replace``. When writing
        to a filesystem path, a failed save leaves an existing file at
        ``path`` untouched.
    """
    state = {
        "model": model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        "seen_tokens": seen_tokens,
        "phase_idx": curriculum.phase_idx,
        # Same constant this script passes to seed_everything.
        "seed": 42,
    }

    if not isinstance(path, (str, os.PathLike)):
        # File-like target: nothing to rename, defer entirely to torch.save.
        torch.save(state, path)
        return

    # Write next to the target and rename into place, so an interrupted or
    # failed save never truncates the previous checkpoint at ``path``.
    tmp_path = os.fspath(path) + ".tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present here if the save or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# -------------------------------------------------------------------------
# Parse args
# Command-line flags are declared as (flag, options) pairs and registered in
# order; the remaining hyperparameters come from the file given by --config.
parser = ArgumentParser()
for _flag, _options in (
    ("--config", {"type": str, "required": True}),
    ("--save_every", {"type": int, "default": 0}),
    ("--ckpt_out", {"type": str, "default": "checkpoints"}),
    ("--resume", {"type": str, "default": None}),
):
    parser.add_argument(_flag, **_options)

cli_args = parser.parse_args()
args = load_config(cli_args.config)

# -------------------------------------------------------------------------
# DDP setup
# Distributed training here is CUDA + NCCL only.
assert torch.cuda.is_available()
dist.init_process_group(backend="nccl")

# Rank layout is read from the environment variables set by the launcher.
rank, local_rank, world_size = (
    int(os.environ[key]) for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
)
device = f"cuda:{local_rank}"
torch.cuda.set_device(device)

# Rank 0 is the only process that logs and writes checkpoints.
master = (rank == 0)

seed_everything(42, rank)

# -------------------------------------------------------------------------
# Model
num_vocab = 50304

# Architecture sizes come from the loaded config; vocabulary size is fixed above.
gpt_config = GPTConfig(
    vocab_size=num_vocab,
    n_layer=args.n_layer,
    n_head=args.n_head,
    n_embd=args.n_embd,
)

# Build on the current CUDA device, compile, then wrap for data parallelism.
model = DDP(torch.compile(GPT(gpt_config).cuda()), device_ids=[local_rank])

# Shared bfloat16 autocast context used for every forward pass.
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# -------------------------------------------------------------------------
# Curriculum
class StepCurriculumDynamic:
    def __init__(self, curriculum, device_batch, world_size, rank,
                 scale_embed, scale_matrix, weight_decay, momentum, unconstrained):
        self.curriculum = curriculum
        self.device_batch = device_batch
        self.world_size = world_size
        self.rank = rank
        self.scale_embed = scale_embed
        self.scale_matrix = scale_matrix
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.unconstrained = unconstrained

        self.phase_idx = -1
        self.current_phase = None

    def next_phase(self, model, seen_tokens):
        if self.current_phase is not None:
            del self.current_phase
            gc.collect()
            torch._dynamo.reset()
            torch.compiler.reset()
            torch.cuda.empty_cache()

        self.phase_idx += 1
        if self.phase_idx >= len(self.curriculum):
            return None

        phase_cfg = self.curriculum[self.phase_idx]

        global_batch = phase_cfg["batch_size"]
        seq_len = phase_cfg["sequence_length"]
        assert global_batch % (self.device_batch * self.world_size) == 0
        accum_steps = global_batch // (self.device_batch * self.world_size)
        tokens_per_step = global_batch * seq_len

        train_loader = DistributedDataLoader(
            args.input_bin,
            self.device_batch,
            seq_len,
            self.rank,
            self.world_size,
        )
        val_loader = DistributedDataLoader(
            args.input_val_bin,
            self.device_batch,
            seq_len,
            self.rank,
            self.world_size,
        )

        train_loader.skip_tokens(seen_tokens)
        optim_groups = [
            {
                "params": model.transformer.h.parameters(),
                "norm": "Spectral",
                "norm_kwargs": {"steps": 5},
                "scale": self.scale_embed,
                "weight_decay": self.weight_decay,
            },
            {
                "params": model.lm_head.parameters(),
                "norm": "Sign",
                "norm_kwargs": {},
                "scale": self.scale_matrix,
                "weight_decay": self.weight_decay,
            },
        ]

        optimizer = Scion(
            optim_groups,
            lr=phase_cfg["lr_embed"],
            momentum=self.momentum,
            unconstrained=self.unconstrained,
        )

        def get_lr(it):
            if it <= phase_cfg["warmup_iters"]:
                ratio = it / max(1, phase_cfg["warmup_iters"])
            elif it <= phase_cfg["step_until"] - phase_cfg["warmdown_iters"]:
                ratio = 1.0
            else:
                ratio = (phase_cfg["step_until"] - it) / phase_cfg["warmdown_iters"]
            return max(ratio, 1e-8 / phase_cfg["lr_embed"])

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)

        self.current_phase = {
            "phase_cfg": phase_cfg,
            "train_loader": train_loader,
            "val_loader": val_loader,
            "accum_steps": accum_steps,
            "tokens_per_step": tokens_per_step,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "global_batch": global_batch,
        }
        return self.current_phase

# -------------------------------------------------------------------------
# Phase list and shared optimizer settings come from the config file; rank and
# world size come from the process-group setup.
curriculum = StepCurriculumDynamic(
    curriculum=args.curriculum,
    device_batch=args.device_batch_size,
    world_size=world_size,
    rank=rank,
    scale_embed=args.scale_embed,
    scale_matrix=args.scale_matrix,
    weight_decay=args.weight_decay,
    momentum=args.momentum,
    unconstrained=args.unconstrained,
)

# -------------------------------------------------------------------------
# WandB
if master:
    # Only rank 0 creates the checkpoint directory and owns the W&B run.
    ckpt_dir = cli_args.ckpt_out
    if ckpt_dir is not None:
        os.makedirs(ckpt_dir, exist_ok=True)
    wandb.init(
        project=args.project,
        name=args.run,
        config=vars(args),
    )

# -------------------------------------------------------------------------
# Resume handling
seen_tokens = 0
resume_phase_idx = -1
resume_step = 0
resume_ckpt = None
# Manual switch (TODO: automate): set to True when the checkpoint was written
# at the end of a phase and training should start from the following phase
# with fresh optimizer state. Bound unconditionally so the training loop can
# always read it.
restart_from_next_phase = False

if cli_args.resume is not None:
    if master:
        print(f"Resuming from {cli_args.resume}")
    # Tensors saved from cuda:0 are remapped onto this process's own device.
    map_location = {"cuda:0": f"cuda:{local_rank}"}
    # Bind the checkpoint to ``resume_ckpt`` only: the training loop clears
    # that name after restoring optimizer/scheduler state, which lets the
    # loaded copy of weights and optimizer state be freed instead of staying
    # referenced for the whole run.
    resume_ckpt = torch.load(cli_args.resume, map_location=map_location)
    model.module.load_state_dict(resume_ckpt["model"], strict=True)
    seen_tokens = resume_ckpt["seen_tokens"]
    resume_phase_idx = resume_ckpt["phase_idx"]
    resume_step = resume_ckpt["step"]
    if restart_from_next_phase:
        # The saved phase is finished: next_phase() will advance past it.
        curriculum.phase_idx = resume_phase_idx
        resume_step = 0
    else:
        # Resume inside the saved phase: next_phase() increments the index
        # before use, so step back by one.
        curriculum.phase_idx = resume_phase_idx - 1

# Every rank adopts rank 0's token count.
seen_tokens_tensor = torch.tensor(seen_tokens, device=device)
dist.broadcast(seen_tokens_tensor, src=0)
seen_tokens = int(seen_tokens_tensor.item())

# -------------------------------------------------------------------------
# Training loop
while True:
    # Build loaders, optimizer and scheduler for the next curriculum phase;
    # ``None`` means every phase has been consumed.
    phase = curriculum.next_phase(model.module, seen_tokens)
    if phase is None:
        break

    phase_cfg = phase["phase_cfg"]
    train_loader = phase["train_loader"]
    val_loader = phase["val_loader"]
    optimizer = phase["optimizer"]
    scheduler = phase["scheduler"]
    accum_steps = phase["accum_steps"]
    tokens_per_step = phase["tokens_per_step"]
    global_batch = phase["global_batch"]

    # Only the first phase after a resume continues from the saved step.
    if resume_ckpt is not None:
        if not restart_from_next_phase:
            optimizer.load_state_dict(resume_ckpt["optimizer"])
            scheduler.load_state_dict(resume_ckpt["scheduler"])
        start_step = resume_step + 1
        resume_ckpt = None
    else:
        start_step = 1

    for step in tqdm(range(start_step, phase_cfg["step_until"] + 1)):
        val_loss = None
        last_step = (step == phase_cfg["step_until"])

        # ---------------- Validation ----------------
        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
            val_steps = args.val_tokens // (
                args.device_batch_size * phase_cfg["sequence_length"] * world_size
            )
            if val_steps > 0:
                model.eval()
                val_loader.reset()

                # Accumulate in a tensor so all_reduce always gets a tensor.
                val_loss = torch.zeros((), device=device)
                with torch.no_grad():
                    for _ in range(val_steps):
                        x_val, y_val = val_loader.next_batch()
                        with ctx:
                            _, batch_loss = model(x_val, y_val, return_logits=False)
                        val_loss += batch_loss.detach()

                dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
                val_loss /= val_steps
                model.train()
            elif master:
                # val_tokens is smaller than one validation batch across all
                # ranks; skip rather than divide by zero.
                print(
                    f"Skipping validation at step {step}: val_tokens="
                    f"{args.val_tokens} yields zero validation batches"
                )

        # ---------------- Training ----------------
        x, y = train_loader.next_batch()
        for micro_step in range(accum_steps):
            if micro_step < accum_steps - 1:
                # The forward pass has to run inside no_sync() as well,
                # otherwise DDP still all-reduces gradients during this
                # micro-step's backward.
                with model.no_sync():
                    with ctx:
                        _, loss = model(x, y, return_logits=False)
                    loss.backward()
                x, y = train_loader.next_batch()
            else:
                # Final micro-step: gradients are synchronized across ranks.
                with ctx:
                    _, loss = model(x, y, return_logits=False)
                loss.backward()

        # Turn the summed micro-batch gradients into their mean.
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= accum_steps

        optimizer.step()
        scheduler.step()
        model.zero_grad(set_to_none=True)

        seen_tokens += tokens_per_step

        if cli_args.save_every > 0 and (step % cli_args.save_every == 0 or last_step) and master:
            save_checkpoint(
                os.path.join(cli_args.ckpt_out, f"ckpt_phase{curriculum.phase_idx}_step{step}.pt"),
                model, optimizer, scheduler, step, seen_tokens, curriculum
            )

        if master:
            wandb.log({
                "val_loss": val_loss.item() if val_loss is not None else None,
                # Loss of the last micro-batch on rank 0 only.
                "train_loss": loss.item(),
                "lr": optimizer.param_groups[0]["lr"],
                "tokens": seen_tokens,
            })

# -------------------------------------------------------------------------
# Rank 0 owns the W&B run, so it alone closes it and reports completion.
if master:
    wandb.finish()
    print("Training completed successfully.")

# Every rank tears down the process group created during DDP setup.
dist.destroy_process_group()
