"""
Train model. From root directory of the project, run as:

python -m scripts.base_train

or distributed as:

torchrun --nproc_per_node=8 -m scripts.base_train

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --sequence_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import warnings
# Silence FutureWarning about deprecated torch.cuda.amp.custom_fwd (used by mamba-ssm, causal-conv1d, mad, etc.)
# The warning is emitted by torch itself, so we can't filter by the offending module name.
warnings.filterwarnings("ignore", message=".*custom_fwd.*", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*custom_bwd.*", category=FutureWarning)
import argparse
import yaml
import time
import shutil
from tqdm import tqdm
from contextlib import nullcontext

import wandb
import torch
from torch.profiler import profile, record_function, ProfilerActivity

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state, tokenized_distributed_data_loader, tokenized_distributed_data_loader_with_state
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.report import get_report
from scripts.base_eval import evaluate_model
print_banner()

def sample(model, orig_model, tokenizer):
    """Greedy-decode a fixed set of prompts and print each completion (via print0).

    Args:
        model: the (possibly compiled) training model. It is put in eval mode for the
            duration of sampling and is always put back in train mode, even if
            generation raises.
        orig_model: the uncompiled model; generation runs on it so that sampling
            does not trigger torch.compile recompilation.
        tokenizer: text tokenizer. When None (the MAD data path has no tokenizer),
            sampling is skipped entirely.

    Uses the module-level ``autocast_ctx`` for mixed precision during generation.
    """
    if tokenizer is None:
        return
    prompts = [
        "The capital of France is",
        "The chemical symbol of gold is",
        "If yesterday was Friday, then tomorrow will be",
        "The opposite of hot is",
        "The planets of the solar system are:",
        "My favorite color is",
        "If 5*x + 3 = 13, then x is",
    ]
    model.eval()
    try:
        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                completions, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(completions[0]))
    finally:
        # restore training mode no matter what happened during generation
        model.train()

# -----------------------------------------------------------------------------
# CLI arguments

def _str2bool(value):
    """Parse a boolean command-line value.

    argparse with ``type=bool`` is a trap: ``bool("False")`` and ``bool("0")`` are
    both True, so every non-empty value would switch the flag on. This accepts the
    usual true/false spellings and rejects anything else. Non-string defaults
    (e.g. booleans coming from a YAML config) are not passed through ``type`` by
    argparse, but a bool is accepted here too for safety.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")

parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# Model architecture
parser.add_argument("--config", type=str, default=None, help="name of a yaml config in configs/ (without .yaml); its values replace the built-in defaults")
parser.add_argument("--depth", type=int, default=12, help="depth of the Transformer model")
parser.add_argument("--aspect_ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--model_dim", type=int, default=512, help="override model dimension (n_embd). If >0, ignores depth*aspect_ratio")
parser.add_argument("--head_dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--sequence_len", type=int, default=1024, help="max context length")
parser.add_argument("--vocab_size", type=int, default=265, help="vocabulary size")

parser.add_argument("--kla_blocks", nargs="+", type=int, default=[], help="kla blocks (e.g. [1,3,5]) (-1 = all blocks)")
parser.add_argument("--mamba_blocks", nargs="+", type=int, default=[], help="mamba blocks (e.g. [1,3,5]) (-1 = all blocks)")
parser.add_argument("--gdn_blocks", nargs="+", type=int, default=[], help="gated delta net blocks (e.g. [1,3,5]) (-1 = all blocks)")
parser.add_argument("--gla_blocks", nargs="+", type=int, default=[], help="gated linear attention blocks (e.g. [1,3,5]) (-1 = all blocks)")
parser.add_argument("--d_state", type=int, default=16, help="state dimension for SSM layers")
parser.add_argument("--mamba_params", type=_str2bool, default=True, help="KLA only - whether to use mamba-like parametrisation")
parser.add_argument("--kla_kernel", type=_str2bool, default=True, help="use kernel (new_kla)?")
parser.add_argument("--mimo_rank", type=int, default=1, help="MIMO rank for kla blocks")
parser.add_argument("--skip_around_kla", type=_str2bool, default=True, help="residual connection around block?")
parser.add_argument("--decoder_mlp", type=_str2bool, default=False, help="whether to use decoder MLP in kla blocks")
parser.add_argument("--use_reparametrisation_trick", type=_str2bool, default=True, help="whether to use reparametrisation trick for kla blocks")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num_iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target_flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target_param_data_ratio", type=int, default=20, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
parser.add_argument("--data_dir", type=str, default="base_data", help="directory containing data files")
# Optimization
parser.add_argument("--device_batch_size", type=int, default=64, help="per-device batch size")
parser.add_argument("--total_batch_size", type=int, default=524288, help="total batch size in tokens")
parser.add_argument("--use_muon", type=_str2bool, default=True, help="whether to use Muon optimizer for linear layers")
parser.add_argument("--embedding_lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--grad_clip", type=float, default=1.0, help="gradient clipping norm")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--adam_beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
parser.add_argument("--adam_beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown_ratio", type=float, default=0.4, help="ratio of iterations for LR warmdown")
parser.add_argument("--final_lr_frac", type=float, default=0.0, help="final LR as fraction of initial LR")
parser.add_argument("--resume_from_step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval_every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval_tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core_metric_every", type=int, default=-1, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core_metric_max_per_task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample_every", type=int, default=250, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save_every", type=int, default=1000, help="save checkpoints every N steps (-1 = only at end)")
parser.add_argument("--push_checkpoints_to_hub", type=_str2bool, default=True, help="whether to push checkpoints to HuggingFace hub")
# Profiling
parser.add_argument("--use_profiler", type=_str2bool, default=False, help="whether to use PyTorch profiler for training step")
parser.add_argument("--profile_step", type=int, default=2, help="which optimizer step to profile (-1 = disable)")
parser.add_argument("--profile_micro_step", type=int, default=0, help="which grad-accum micro-step to profile")
parser.add_argument("--memory_history_max_entries", type=int, default=10000, help="max CUDA alloc/free events to keep for memory snapshot")
# Output
parser.add_argument("--model_tag", type=str, default=None, help="override model tag for checkpoint directory name")
# First pass only to discover --config; unknown flags are tolerated here.
args, left_over_args = parser.parse_known_args()
if args.config:
    # --config names configs/<name>.yaml. Its values become the new parser
    # defaults, so flags given explicitly on the command line still win.
    with open(f"configs/{args.config}.yaml", 'r', encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}  # an empty YAML file loads as None
    parser.set_defaults(**config)
args = parser.parse_args()
# Dataset-specific overrides. They depend only on --data_dir, so they must apply
# even when --model_tag is given explicitly; otherwise a MAD run would keep the
# text-tokenizer vocab size and sequence/batch settings.
if args.data_dir == "mad_data":
    args.vocab_size = 16
    args.sequence_len = 256
    args.total_batch_size = args.sequence_len * args.device_batch_size * 1  # assume single GPU for MAD experiments
    args.eval_tokens = args.sequence_len * args.device_batch_size * 8
    args.target_param_data_ratio = 1000

if args.model_tag is None:
    # Tag by sequence-mixer type; if several block lists are non-empty, the last
    # matching entry in this list wins.
    model_tag = "gpt"
    for blocks, mixer_tag in (
        (args.kla_blocks, "kla"),
        (args.mamba_blocks, "mamba"),
        (args.gdn_blocks, "gdn"),
        (args.gla_blocks, "gla"),
    ):
        if len(blocks) > 0:
            model_tag = mixer_tag
    if args.data_dir == "tinystories_data":
        model_tag += "_ts"
    elif args.data_dir == "mad_data":
        model_tag += "_mad"
    # Use the same effective width the model is built with (depth * aspect_ratio
    # when --model_dim is not positive), so the tag never reads e.g. "_12x0".
    tag_dim = args.model_dim if args.model_dim and args.model_dim > 0 else args.depth * args.aspect_ratio
    model_tag += f"_{args.depth}x{tag_dim}"
    if args.use_reparametrisation_trick:
        model_tag += "_sample"
    if args.run == "dummy":
        model_tag = f"{model_tag}_dummy"

    args.model_tag = model_tag

# The wandb run name is the model tag. wandb is disabled when "dummy" appears in
# args.run, so keep that marker when --run dummy was requested, even if an
# explicit --model_tag without "dummy" was supplied.
requested_dummy_run = args.run == "dummy"
args.run = args.model_tag # for wandb run name
if requested_dummy_run and "dummy" not in args.run:
    args.run = f"{args.run}_dummy"


# Shallow copy of the final CLI configuration for logging; later assignments to
# args attributes are not reflected in it.
user_config = dict(vars(args))  # for logging
# Inside a Kubernetes pod the service-account token is mounted; record the pod name.
on_kubernetes = os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token')
if on_kubernetes:
    pod_name = os.environ.get('HOSTNAME')
    if pod_name:
        user_config['pod_name'] = pod_name
# -----------------------------------------------------------------------------

# Compute init: resolve the device (empty string = autodetect) and set up (DDP) ranks.
if args.device_type == "":
    device_type = autodetect_device_type()
else:
    device_type = args.device_type
args.device_type = device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
# bf16 autocast and CUDA timing/memory helpers on CUDA; no-op stand-ins elsewhere.
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    synchronize = torch.cuda.synchronize
    get_max_memory = torch.cuda.max_memory_allocated
else:
    autocast_ctx = nullcontext()
    synchronize = lambda: None
    get_max_memory = lambda: 0

# wandb logging: only the master process logs, and never for "dummy" runs.
use_dummy_wandb = not master_process or "dummy" in args.run
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat", name=args.run, config=user_config, entity="***")

# Tokenizer (used for evaluation/sampling) and per-token byte counts (used for bpb).
if args.data_dir == "mad_data":
    # MAD uses a 16-entry vocabulary (see the vocab_size override) and no text tokenizer;
    # the last two ids get a byte weight of 0 (presumably special tokens — confirm).
    tokenizer = None
    token_bytes = torch.tensor([1] * 14 + [0, 0], dtype=torch.int32, device=device)
else:
    tokenizer = get_tokenizer(vocab_size=args.vocab_size)
    token_bytes = get_token_bytes(device=device, vocab_size=args.vocab_size)
print0(f"Vocab size: {args.vocab_size:,}")

# Model shape: depth sets the layer count; width is --model_dim when positive,
# otherwise depth * aspect_ratio.
num_layers = args.depth
if args.model_dim and args.model_dim > 0:
    model_dim = args.model_dim
else:
    model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
    """Pick a head count that divides ``model_dim`` exactly.

    The search starts at ``round(model_dim / target_head_dim)`` (at least 1) and
    walks outward, trying ``center + k`` before ``center - k`` for k = 0, 1, ...
    up to ``model_dim - 1``. It returns the first positive divisor found, or 1 if
    none is found (e.g. when ``model_dim`` is 0).
    """
    center = max(1, round(model_dim / target_head_dim))
    search_order = (
        candidate
        for step in range(model_dim)
        for candidate in (center + step, center - step)
    )
    return next((c for c in search_order if c > 0 and model_dim % c == 0), 1)
num_heads = find_num_heads(model_dim, args.head_dim)
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")

# Optimizer / data / training length related hyperparameters
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = args.device_batch_size * args.sequence_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
# Validate explicitly rather than with `assert`, which `python -O` strips. A total
# batch that is not a positive multiple of the per-micro-step world batch would
# give a wrong (or zero) number of accumulation steps.
if args.total_batch_size <= 0 or args.total_batch_size % world_tokens_per_fwdbwd != 0:
    raise ValueError(
        f"total_batch_size ({args.total_batch_size:,}) must be a positive multiple of "
        f"device_batch_size * sequence_len * world_size ({world_tokens_per_fwdbwd:,})"
    )
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.sequence_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
batch_lr_scale = 1.0
reference_batch_size = 2**19
batch_ratio = args.total_batch_size / reference_batch_size
if batch_ratio != 1.0:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard
    # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
    batch_lr_scale = batch_ratio ** 0.5
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")

# -----------------------------------------------------------------------------
# Initialize the Model

# Model config = every CLI arg plus the derived shape parameters.
model_config_kwargs = dict(
    **vars(args),
    n_layer=num_layers,
    n_head=num_heads,
    n_kv_head=num_kv_heads,
    n_embd=model_dim,
)
# Build on the meta device first: tensors get shapes/dtypes but no storage.
with torch.device("meta"):
    model = GPT(GPTConfig(**model_config_kwargs))
model.to_empty(device=device)  # allocate real (uninitialized) storage on the target device
model.init_weights()  # now fill every tensor with its initial values

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
resuming = args.resume_from_step != -1
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
# For fresh runs, add an 8-char hash from the wandb run id for uniqueness.
# Only the master rank holds a real wandb run (non-master ranks use DummyWandb
# because use_dummy_wandb includes `not master_process`). Under DDP, rank 0
# therefore broadcasts the suffix so every rank uses the same checkpoint_dir.
# save/load_checkpoint take rank=ddp_rank, so presumably each rank writes its own files.
run_suffix = f"_{wandb_run.id[:8]}" if not use_dummy_wandb and not resuming else ""
if ddp:
    # Collective call: every rank must reach it (ddp is expected to be identical on all ranks).
    suffix_holder = [run_suffix]
    torch.distributed.broadcast_object_list(suffix_holder, src=0)
    run_suffix = suffix_holder[0]
output_dirname += run_suffix
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
if resuming:
    print0(f"Resuming optimization from step {args.resume_from_step}")
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # free up this memory after the copy

orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (whose input shapes vary)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
num_scaling_params = orig_model.num_scaling_params()
print0(f"\nNumber of parameters: {num_params:,} (scaling: {num_scaling_params:,})")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Add profiling hooks for activations (only if profiling enabled)
if args.use_profiler:
    # Largest activation L2 norm seen per leaf-module class name since the last flush.
    activation_norms = {}

    def activation_hook(mod, inputs, output):
        # Only non-empty tensor outputs are tracked.
        if not isinstance(output, torch.Tensor) or output.numel() == 0:
            return
        key = mod.__class__.__name__
        activation_norms[key] = max(activation_norms.get(key, 0), output.norm().item())

    # Attach the hook to every leaf module, i.e. one without child modules.
    for name, module in model.named_modules():
        if next(module.children(), None) is None:
            module.register_forward_hook(activation_hook)

# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
if args.num_iterations > 0:
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
    target_tokens = args.target_param_data_ratio * num_scaling_params
    num_iterations = target_tokens // args.total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    # explicit raise (instead of an assert, which `python -O` would strip)
    raise ValueError("No training horizon specified")
# Eval at least ~5 times during training. An explicit --eval_every <= 0 keeps
# evaluation disabled, as the flag documents. The interval is clamped to >= 1 so
# short runs (num_iterations < 5) do not get eval_every == 0, which would silently
# disable evaluation, including the final one.
if args.eval_every > 0:
    args.eval_every = max(1, min(args.eval_every, num_iterations // 5))
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_tokens / num_scaling_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

report_data = [
    user_config, # CLI args
    { # stats about the training setup
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        # same definition as the printed ratio above (scaling params, not all params)
        "Tokens : Params ratio": total_tokens / num_scaling_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": args.warmup_ratio,
        "warmdown_ratio": args.warmdown_ratio,
        "final_lr_frac": args.final_lr_frac,
    },
]

wandb_run.log({"report/" + k: v for k, v in report_data[1].items()})

# -----------------------------------------------------------------------------
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
adam_betas = (args.adam_beta1, args.adam_beta2)
optimizers = model.setup_optimizers(
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=args.weight_decay,
    adam_betas=adam_betas,
    use_muon=args.use_muon,
)
adamw_optimizer, muon_optimizer = optimizers

if resuming:
    # zip() would silently drop states when the counts differ, leaving some
    # optimizers freshly initialized after a "successful" resume; fail loudly instead.
    if len(optimizer_data) != len(optimizers):
        raise ValueError(
            f"Checkpoint holds {len(optimizer_data)} optimizer states but {len(optimizers)} optimizers were built; "
            "check that optimizer-related flags match the run being resumed"
        )
    for opt, dat in zip(optimizers, optimizer_data):
        opt.load_state_dict(dat)
    del optimizer_data # free up the memory

# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
tokens_dir = os.path.join(base_dir, "tokenized_data")
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
if args.data_dir == "mad_data":
    print0("Using MAD dataset for training")
    train_loader = tokenized_distributed_data_loader_with_state(args.device_batch_size, args.sequence_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict, data_path="mad_data")
    build_val_loader = lambda: tokenized_distributed_data_loader(args.device_batch_size, args.sequence_len, split="val", device=device, data_path="mad_data")
else:
    train_loader = tokenizing_distributed_data_loader_with_state(args.device_batch_size, args.sequence_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict, tokenizer_vocab_size=args.vocab_size, data_dir=args.data_dir)
    build_val_loader = lambda: tokenizing_distributed_data_loader(args.device_batch_size, args.sequence_len, split="val", device=device, tokenizer_vocab_size=args.vocab_size, data_dir=args.data_dir)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers

# Learning rate scheduler
def get_lr_multiplier(it):
    """Return the learning-rate multiplier for optimizer step ``it``.

    The schedule has three phases:
    - a linear warmup over the first ``args.warmup_ratio`` of training;
    - a flat 1.0;
    - a linear warmdown over the last ``args.warmdown_ratio``, reaching
      ``args.final_lr_frac`` at step ``num_iterations``.

    Reads the module-level ``args`` and ``num_iterations``.
    """
    warmup_iters = round(args.warmup_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    warmdown_start = num_iterations - warmdown_iters
    if it <= warmdown_start:
        return 1.0
    remaining = (num_iterations - it) / warmdown_iters  # 1 -> 0 across the warmdown
    return remaining * 1.0 + (1 - remaining) * args.final_lr_frac

# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Muon momentum for step ``it``: linear ramp from 0.85 to 0.95 over the first 300 steps, then constant."""
    ramp = min(it / 300, 1)
    return (1 - ramp) * 0.85 + ramp * 0.95

# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)

if not resuming:
    step = 0
    val_bpb = None # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]


# -----------------------------------------------------------------------------
# Training loop

# Start the bar at the current step so a resumed run shows its true progress.
pbar = tqdm(total=num_iterations, initial=step, desc="Training", disable=not master_process,
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")

# Initial sample, gated like the in-loop sampling: master rank only, only when
# sampling is enabled (--sample_every > 0), and only with a text tokenizer
# (tokenizer is None for the MAD data).
if args.sample_every > 0 and master_process and tokenizer is not None:
    sample(model, orig_model, tokenizer)

while True:
    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # once in a while: evaluate the val bpb (all ranks participate)
    if args.eval_every > 0 and (last_step or step % args.eval_every == 0) and step > 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.sequence_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the CORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    results = {}
    if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)) and tokenizer is not None:
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # once in a while: sample from the model (only on master process)
    # use the original uncompiled model because the inputs keep changing shape
    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)) and tokenizer is not None:
        sample(model, orig_model, tokenizer)

    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
    if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(), # model parameters
            [opt.state_dict() for opt in optimizers], # optimizer states
            { # metadata saved as json
                "step": step,
                "val_bpb": val_bpb, # loss at last step
                "model_config": model_config_kwargs,
                "user_config": user_config, # inputs to the training script
                "device_batch_size": args.device_batch_size,
                "sequence_len": args.sequence_len,
                "dataloader_state_dict": dataloader_state_dict,
                "loop_state": { # all loop state (other than step) so that we can resume training
                    "min_val_bpb": min_val_bpb,
                    "smooth_train_loss": smooth_train_loss,
                    "total_training_time": total_training_time,
                },
            },
            rank=ddp_rank,
        )

    # termination conditions (TODO: possibly also add loss explosions etc.)
    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()

    do_profile = (
        bool(args.use_profiler)
        and master_process
        and args.profile_step >= 0
        and step == args.profile_step
    )
    profile_dir = None
    if do_profile:
        profile_dir = os.path.join(checkpoint_dir, "profiles")
        print0(f"Profiler enabled for step={step}, micro_step={args.profile_micro_step} (fwd+bwd only)")

    for micro_step in range(grad_accum_steps):
        profiling_this_micro = bool(do_profile and micro_step == args.profile_micro_step)

################################################################################ profile block

        if profiling_this_micro:
            assert profile_dir is not None
            os.makedirs(profile_dir, exist_ok=True)
            # Use the PyTorch profiler trace suffix that W&B recognizes for interactive viewing.
            # (W&B writes its own trace files under wandb.run.dir/pytorch_traces/*.pt.trace.json.)
            trace_path = os.path.join(profile_dir, f"trace_step{step:06d}_micro{micro_step:02d}.pt.trace.json")
            snapshot_path = os.path.join(profile_dir, f"cuda_memory_step{step:06d}_micro{micro_step:02d}.pkl")

            if device_type == "cuda":
                # CUDA memory snapshot (pickle) - captures alloc/free history.
                torch.cuda.memory._record_memory_history(
                    enabled="all",
                    context="all",
                    stacks="all",
                    max_entries=args.memory_history_max_entries,
                    device=None,
                    clear_history=True,
                    compile_context=False,
                )

            activities = [ProfilerActivity.CPU]
            if device_type == "cuda":
                activities.append(ProfilerActivity.CUDA)

            on_trace_ready = None
            # W&B trace UI integration (TensorBoard trace handler writes into wandb.run.dir).
            if not use_dummy_wandb:
                on_trace_ready = wandb.profiler.torch_trace_handler()

            with profile(
                activities=activities,
                schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
                on_trace_ready=on_trace_ready,
                record_shapes=True,
                profile_memory=True,
                with_stack=False,
                with_flops=True,
            ) as prof:
                with record_function("forward"):
                    with autocast_ctx:
                        loss = model(x, y)
                train_loss = loss.detach() # for logging
                loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
                with record_function("backward"):
                    loss.backward()
                prof.step()

            # Ensure we have a JSON trace file saved under checkpoint_dir for convenient retrieval.
            if use_dummy_wandb:
                # No W&B handler => export a chrome trace ourselves.
                try:
                    prof.export_chrome_trace(trace_path)
                    print0(f"Wrote profiler trace: {trace_path}")
                except Exception as e:
                    print0(f"Failed to export profiler trace (ignored): {e}")
            else:
                # W&B handler writes traces under wandb.run.dir/pytorch_traces; copy the newest one.
                try:
                    trace_dir = os.path.join(wandb.run.dir, "pytorch_traces")  # type: ignore[attr-defined]
                    if os.path.isdir(trace_dir):
                        candidates = [
                            os.path.join(trace_dir, f)
                            for f in os.listdir(trace_dir)
                            if os.path.isfile(os.path.join(trace_dir, f))
                        ]
                        if candidates:
                            latest = max(candidates, key=os.path.getmtime)
                            shutil.copy2(latest, trace_path)
                            print0(f"Copied profiler trace: {latest} -> {trace_path}")
                        else:
                            print0(f"No profiler traces found in {trace_dir}")
                    else:
                        print0(f"Profiler trace dir missing: {trace_dir}")
                except Exception as e:
                    print0(f"Failed to copy W&B profiler trace (ignored): {e}")

            if device_type == "cuda":
                try:
                    torch.cuda.memory._dump_snapshot(snapshot_path)
                    print0(f"Wrote CUDA memory snapshot: {snapshot_path}")
                except Exception as e:
                    print0(f"Failed to dump CUDA memory snapshot (ignored): {e}")
                finally:
                    try:
                        torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
                    except Exception:
                        pass

            # Upload raw files to W&B as an artifact for easy download.
            if not use_dummy_wandb:
                try:
                    artifact = wandb.Artifact(
                        name=f"profile-{args.run}-step{step:06d}-rank{ddp_rank}",
                        type="profiler",
                    )
                    if os.path.exists(trace_path):
                        artifact.add_file(trace_path)
                    
                    wandb_run.log_artifact(artifact)
                except Exception as e:
                    print0(f"Failed to upload profiler outputs to W&B (ignored): {e}")
                try:
                    artifact = wandb.Artifact(
                        name=f"cuda-memory-{args.run}-step{step:06d}-rank{ddp_rank}",
                        type="cuda-memory-snapshot",
                    )
                    if os.path.exists(snapshot_path):
                        artifact.add_file(snapshot_path)
                    wandb_run.log_artifact(artifact)
                except Exception as e:
                    print0(f"Failed to upload CUDA memory snapshot to W&B (ignored): {e}")

            # Prefetch after profiling, but outside the profiler context.
            x, y, dataloader_state_dict = next(train_loader)
            continue

############################################################################## end of profile block

        else:
            with record_function("forward"):
                with autocast_ctx:
                    loss = model(x, y)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
        with record_function("backward"):
            loss.backward()
        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward

############################################################################### profile block

    # Profiling diagnostics, logged to W&B by the master process only when --use_profiler is set:
    #   - every 10 steps: the max activation norms gathered in `activation_norms` (cleared after
    #     logging) and the max/mean gradient norm per layer type;
    #   - every 100 steps: the max/mean weight norm per layer type.
    # A parameter's "layer type" is the class name of the module that directly owns it.
    if args.use_profiler and step % 10 == 0 and master_process:
        # Resolve owner-module class names from a single named_modules() pass instead of walking
        # getattr chains per parameter under a bare `except:`. named_parameters() builds its names
        # from these same module paths (a top-level parameter maps to the root path ""), so the
        # lookup hits; "Unknown" is kept only as a defensive fallback.
        module_type_by_path = {
            mod_path: mod.__class__.__name__ for mod_path, mod in model.named_modules()
        }

        # Activations
        act_log = {f"act_norm_max/{k}": v for k, v in activation_norms.items()}
        wandb_run.log(act_log, step=step)
        activation_norms.clear()

        # Grads: norms of whatever is currently accumulated in .grad (this runs before clipping).
        grad_norms_by_type = {}
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            layer_type = module_type_by_path.get(name.rpartition(".")[0], "Unknown")
            grad_norms_by_type.setdefault(layer_type, []).append(param.grad.norm().item())
        grad_log = {}
        for layer_type, norms in grad_norms_by_type.items():
            grad_log[f"grad_norm_max/{layer_type}"] = max(norms)
            grad_log[f"grad_norm_mean/{layer_type}"] = sum(norms) / len(norms)
        wandb_run.log(grad_log, step=step)

        # Weight norms every 100 steps (every multiple of 100 is also a multiple of 10).
        if step % 100 == 0:
            weight_norms_by_type = {}
            with torch.no_grad():  # logging only: don't record autograd history for the norms
                for name, param in model.named_parameters():
                    layer_type = module_type_by_path.get(name.rpartition(".")[0], "Unknown")
                    weight_norms_by_type.setdefault(layer_type, []).append(param.norm().item())
            weight_log = {}
            for layer_type, norms in weight_norms_by_type.items():
                weight_log[f"weight_norm_max/{layer_type}"] = max(norms)
                weight_log[f"weight_norm_mean/{layer_type}"] = sum(norms) / len(norms)
            wandb_run.log(weight_log, step=step)

######################################################################### end of profile block

    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    # Detect gradient spikes and handle them
    grad_norm_value = grad_norm.item()
    spike_threshold = 15.0  # Adjust based on your model's typical norms (e.g., 5-20x the average)
    if grad_norm_value > spike_threshold:
        print0(f"Gradient spike detected at step {step}: norm={grad_norm_value:.2f}. Skipping optimizer step to prevent instability.")
        model.zero_grad(set_to_none=True)  # Reset grads without updating
        continue  # Skip to next iteration
    # step the optimizers
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # logging
    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
    if step > 10:
        total_training_time += dt # only count the time after the first 10 steps
    # Calculate ETA based on average time per step (excluding first 10 steps)
    steps_done = step - 10
    if steps_done > 0:
        avg_time_per_step = total_training_time / steps_done
        remaining_steps = num_iterations - step
        eta_seconds = remaining_steps * avg_time_per_step
        eta_str = f" | eta: {eta_seconds/60:.1f}m"
    else:
        eta_str = ""
    
    # Update tqdm progress bar with metrics
    pbar.set_postfix({
        "loss": f"{debiased_smooth_loss:.4f}",
        "lrm": f"{lrm:.2f}",
        "tok/s": f"{tok_per_sec:,}",
        "mfu": f"{mfu:.1f}",
        "time": f"{total_training_time/60:.1f}m"
    })
    pbar.update(1)

    if step % (8/ddp_world_size) == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/grad_norm": grad_norm.item(),
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        }
        wandb_run.log(log_data)

    # state update
    step += 1

# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
    print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Stats about training outcomes: appended to the report and mirrored to W&B under "report/".
training_outcome_stats = {
    "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
    "Final validation bpb": val_bpb,
    "CORE metric estimate": results.get("core_metric", None),
    "MFU %": f"{mfu:.2f}%",
    "Total training flops": f"{flops_so_far:e}",
    "Total training time": f"{total_training_time/60:.2f}m",
    "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
report_data += [training_outcome_stats]

# Log exactly the dict built above, rather than indexing report_data at a hard-coded position
# that silently picks the wrong section (or raises) if the number of earlier sections changes.
wandb_run.log({"report/" + k: v for k, v in training_outcome_stats.items()})

# Log to report
get_report().log(section="Base model training", data=report_data)

# cleanup: finish the W&B run, collect garbage first so freed tensors can actually be released by
# empty_cache(), wait for queued CUDA work, then destroy the process group.
import gc
wandb_run.finish() # wandb run finish
gc.collect()
if device_type == "cuda":
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
compute_cleanup()
