# =========================
# GPT training with curriculum-defined iterations
# =========================
import os
import torch
import torch.distributed as dist
import wandb
import gc
import random
import numpy as np

from utils.seeder import seed_everything
from argparse import ArgumentParser
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from utils.gpt import GPT, GPTConfig
from optimizers.scion import Scion
from utils.dataloader import DistributedDataLoader
from utils.configs import load_config

import torch._dynamo
torch._dynamo.config.suppress_errors = True

# -------------------------------------------------------------------------
# Seeding
def seed_everything(seed: int, rank: int):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs for this process.

    The value actually applied is ``seed + rank``, so every rank ends up with
    a different random stream.

    NOTE: this definition rebinds the ``seed_everything`` name that the
    module also imports from ``utils.seeder``; calls below use this one.

    Args:
        seed: Base seed shared by all ranks.
        rank: Rank of the calling process, added to ``seed``.
    """
    rank_seed = seed + rank  # IMPORTANT: different seed per rank

    # Same seeding calls, in the same order, applied uniformly.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(rank_seed)

# -------------------------------------------------------------------------
# Parse args
# The only CLI flag is the path to the config file; every other setting is
# read from the object returned by load_config().
parser = ArgumentParser()
parser.add_argument("--config", required=True, type=str)
cli_args = parser.parse_args()
args = load_config(cli_args.config)

# -------------------------------------------------------------------------
# DDP setup
assert torch.cuda.is_available()
dist.init_process_group("nccl")

# RANK / LOCAL_RANK / WORLD_SIZE are read from the environment, in that order.
rank, local_rank, world_size = (
    int(os.environ[env_key]) for env_key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
)
device = f"cuda:{local_rank}"
torch.cuda.set_device(device)
seed_everything(42, rank)
master = (rank == 0)  # rank 0 is the process that logs

# -------------------------------------------------------------------------
# Model: GPT -> GPU -> torch.compile -> DDP wrapper
num_vocab = 50304
model = GPT(
    GPTConfig(vocab_size=num_vocab, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)
).cuda()
model = DDP(torch.compile(model), device_ids=[local_rank])
# Forward passes run under bfloat16 autocast via this context manager.
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# -------------------------------------------------------------------------
# Step-based curriculum manager (owns optimizer)
class StepCurriculumDynamic:
    """Step-based curriculum manager that rebuilds per-phase training state.

    Each entry of ``curriculum`` is a mapping read with the keys
    ``batch_size``, ``sequence_length``, ``lr_embed``, ``warmup_iters``,
    ``warmdown_iters`` and ``step_until``. ``next_phase`` advances to the next
    entry and builds fresh data loaders, a fresh ``Scion`` optimizer and a
    ``LambdaLR`` scheduler for it.

    Note: ``next_phase`` reads the dataset paths from the module-level
    ``args`` object (``args.input_bin`` / ``args.input_val_bin``).
    """

    def __init__(self, curriculum, device_batch, world_size, rank, scale_embed, scale_matrix, weight_decay, momentum, unconstrained):
        """Store the settings shared by every phase.

        Args:
            curriculum: Sequence of per-phase config mappings (see class doc).
            device_batch: Batch size handed to each ``DistributedDataLoader``.
            world_size: Number of participating processes.
            rank: Rank of this process.
            scale_embed: ``scale`` of the first optimizer group.
            scale_matrix: ``scale`` of the second optimizer group.
            weight_decay: ``weight_decay`` of both optimizer groups.
            momentum: Passed to ``Scion``.
            unconstrained: Passed to ``Scion``.
        """
        self.curriculum = curriculum
        self.device_batch = device_batch
        self.world_size = world_size
        self.rank = rank
        self.scale_embed = scale_embed
        self.scale_matrix = scale_matrix
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.unconstrained = unconstrained

        self.phase_idx = -1  # index of the active phase; -1 before the first one
        self.current_phase = None  # dict built by next_phase()

    @staticmethod
    def _lr_ratio(it, warmup_iters, warmdown_iters, step_until, base_lr):
        """Return the LR multiplier for scheduler step ``it``.

        Linear warmup over the first ``warmup_iters`` steps, constant 1.0 in
        the middle, and linear warmdown over the last ``warmdown_iters`` steps
        before ``step_until``. The result is floored so that the effective
        learning rate never falls below 1e-5.
        """
        if warmup_iters > 0 and it <= warmup_iters:
            # Warmup only applies when it is actually configured (> 0).
            ratio = it / warmup_iters
        elif warmdown_iters <= 0 or it <= step_until - warmdown_iters:
            # No warmdown configured, or still on the plateau.
            ratio = 1.0
        else:
            ratio = (step_until - it) / warmdown_iters
        # Guard the floor against a zero base learning rate.
        min_ratio = 1e-5 / base_lr if base_lr > 0 else 0.0
        return max(ratio, min_ratio)

    def next_phase(self, model, seen_tokens):
        """Advance to the next curriculum phase and build its training state.

        Args:
            model: Module exposing ``transformer.h`` and ``lm_head``; their
                parameters form the two optimizer groups.
            seen_tokens: Tokens consumed so far; the new train loader is
                advanced past them.

        Returns:
            A dict with keys ``phase_cfg``, ``train_loader``, ``val_loader``,
            ``accum_steps``, ``tokens_per_step``, ``optimizer``,
            ``global_batch`` and ``scheduler``, or ``None`` once the
            curriculum is exhausted.

        Raises:
            ValueError: If the phase batch size is not a multiple of
                ``device_batch * world_size``.
        """
        # Drop the previous phase and reset compile / allocator caches before
        # building state for the next one.
        if self.current_phase is not None:
            self.current_phase = None
            gc.collect()
            torch._dynamo.reset()
            torch.compiler.reset()
            torch.cuda.empty_cache()

        self.phase_idx += 1
        if self.phase_idx >= len(self.curriculum):
            return None

        phase_cfg = self.curriculum[self.phase_idx]

        # Gradient-accumulation steps needed to reach the global batch size.
        global_batch = phase_cfg["batch_size"]
        seq_len = phase_cfg["sequence_length"]
        batch_per_pass = self.device_batch * self.world_size
        if global_batch % batch_per_pass != 0:
            raise ValueError(
                f"batch_size {global_batch} must be a multiple of "
                f"device_batch * world_size = {batch_per_pass}"
            )
        accum_steps = global_batch // batch_per_pass
        tokens_per_step = global_batch * seq_len

        # Build DataLoaders dynamically (sequence length can change per phase).
        train_loader = DistributedDataLoader(
            args.input_bin,
            self.device_batch,
            seq_len,
            self.rank,
            self.world_size,
        )
        val_loader = DistributedDataLoader(
            args.input_val_bin,
            self.device_batch,
            seq_len,
            self.rank,
            self.world_size,
        )

        # Skip tokens (to start where we left off).
        # Assumes each next_batch() call consumes device_batch * world_size *
        # seq_len tokens globally -- TODO confirm against utils.dataloader.
        batches_to_skip = seen_tokens // (batch_per_pass * seq_len)
        if self.rank == 0:
            print(f"Starting training from: {seen_tokens} ({batches_to_skip} batches)")
        for _ in range(batches_to_skip):
            train_loader.next_batch()

        # Build optimizer dynamically.
        # NOTE(review): scale_embed is applied to the transformer blocks and
        # scale_matrix to lm_head, and lr_embed is the base LR of both groups
        # -- confirm this pairing is intended.
        optim_groups = [
            {
                "params": model.transformer.h.parameters(),
                "norm": "Spectral",
                "norm_kwargs": {"steps": 5},
                "scale": self.scale_embed,
                "weight_decay": self.weight_decay,
            },
            {
                "params": model.lm_head.parameters(),
                "norm": "Sign",
                "norm_kwargs": {},
                "scale": self.scale_matrix,
                "weight_decay": self.weight_decay,
            },
        ]
        optimizer = Scion(
            optim_groups,
            lr=phase_cfg["lr_embed"],
            momentum=self.momentum,
            unconstrained=self.unconstrained,
        )

        # Learning rate decay scheduler (linear warmup and warmdown).
        # Bind the plain function so the closure does not hold on to self.
        lr_ratio = self._lr_ratio

        def get_lr(it):
            return lr_ratio(
                it,
                phase_cfg["warmup_iters"],
                phase_cfg["warmdown_iters"],
                phase_cfg["step_until"],
                phase_cfg["lr_embed"],
            )

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)

        # Store current phase info
        self.current_phase = {
            "phase_cfg": phase_cfg,
            "train_loader": train_loader,
            "val_loader": val_loader,
            "accum_steps": accum_steps,
            "tokens_per_step": tokens_per_step,
            "optimizer": optimizer,
            "global_batch": global_batch,
            "scheduler": scheduler,
        }
        return self.current_phase

# -------------------------------------------------------------------------
# Initialize curriculum
curriculum = StepCurriculumDynamic(
    args.curriculum,
    device_batch=args.device_batch_size,
    world_size=world_size,
    rank=rank,
    scale_embed=args.scale_embed,
    scale_matrix=args.scale_matrix,
    weight_decay=args.weight_decay,
    momentum=args.momentum,
    unconstrained=args.unconstrained,
)
# -------------------------------------------------------------------------
# WandB logging (rank 0 only)
if master:
    wandb.init(project=args.project, name=args.run, config=vars(args))


# -------------------------------------------------------------------------
# Training loop: one pass over every curriculum phase
def _forward_backward(x, y):
    """Run one micro-batch: forward under autocast, then backward.

    Returns the loss tensor of that micro-batch.
    """
    with ctx:
        _, loss = model(x, y, return_logits=False)
    loss.backward()
    return loss


seen_tokens = 0

while True:
    phase = curriculum.next_phase(model.module, seen_tokens)
    if phase is None:
        break  # curriculum exhausted

    phase_cfg = phase["phase_cfg"]
    train_loader = phase["train_loader"]
    val_loader = phase["val_loader"]
    accum_steps = phase["accum_steps"]
    tokens_per_step = phase["tokens_per_step"]
    optimizer = phase["optimizer"]
    scheduler = phase["scheduler"]
    global_batch = phase["global_batch"]

    if master:
        print(f"\n=== PHASE {curriculum.phase_idx} ===")
        print(f"Batch size : {global_batch:,}, Seq length : {phase_cfg['sequence_length']}")

    # Number of validation batches per rank; constant within a phase, so it is
    # computed once here instead of on every step.
    val_steps = args.val_tokens // (args.device_batch_size * phase_cfg["sequence_length"] * world_size)

    # Only rank 0 shows a progress bar; the other ranks iterate silently.
    for step in tqdm(range(1, phase_cfg["step_until"] + 1), disable=not master):
        last_step = (step == phase_cfg["step_until"])

        # ----------------- Validation -----------------
        # Runs before this step's update. Skipped when val_steps is 0, which
        # would otherwise all-reduce a Python float and divide by zero.
        val_loss = None
        val_due = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if val_due and val_steps > 0:
            model.eval()
            val_loader.reset()
            val_loss = 0.0  # becomes a tensor on the first accumulation

            with torch.no_grad():
                for _ in range(val_steps):
                    x_val, y_val = val_loader.next_batch()
                    with ctx:
                        _, loss = model(x_val, y_val, return_logits=False)
                        val_loss += loss.detach()

            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
            val_loss /= val_steps
            model.train()

        # ----------------- Training -----------------
        x, y = train_loader.next_batch()
        for i in range(accum_steps):
            if i < accum_steps - 1:
                # The forward pass must also run inside no_sync(); PyTorch
                # documents that otherwise gradients are still synchronized
                # on this micro-step.
                with model.no_sync():
                    loss = _forward_backward(x, y)
                # Fetch the next micro-batch after backward has been issued.
                x, y = train_loader.next_batch()
            else:
                # Last micro-step: DDP all-reduces the accumulated gradients.
                loss = _forward_backward(x, y)

        # Gradients were summed over micro-steps; turn the sum into a mean.
        if accum_steps > 1:
            for p in model.parameters():
                if p.grad is not None:
                    p.grad /= accum_steps

        optimizer.step()
        scheduler.step()
        model.zero_grad(set_to_none=True)
        seen_tokens += tokens_per_step

        # Logging (train_loss is the loss of the last micro-batch on rank 0)
        if master:
            wandb.log({
                "train_loss": loss.item(),
                "val_loss": val_loss.item() if val_loss is not None else None,
                "lr_embed": optimizer.param_groups[0]['lr'],
                "lr_matrix": optimizer.param_groups[1]['lr'],
                "tokens": seen_tokens,
                "batch_size": global_batch,
                "seq_len": phase_cfg['sequence_length'],
            })

# -------------------------------------------------------------------------
# Cleanup
if master:
    wandb.finish()
    print("Training completed successfully.")

dist.destroy_process_group()
