import os
import sys
import hashlib
# Fingerprint this script's own source: the digest is stored in config.json and
# logged at startup so a run can be matched to the exact code that produced it.
with open(sys.argv[0], 'rb') as _f:
    _code_bytes = _f.read()
code_sha256 = hashlib.sha256(_code_bytes).hexdigest()
import uuid
import time
import copy
import glob
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
import argparse # Keep argparse for --unet and potentially --optimizer_mode
import json
import random 
import numpy as np 

# Set before importing torch so the CUDA caching allocator sees it when it initializes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
# Workaround: run one trivial CUDA backward pass at import time (the specific bug
# it avoids is not documented here).
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.attention.flex_attention import BlockMask, flex_attention
sys.path.append("") # Add any necessary paths here
from optimizers.MUON_new import Muon
from utils.float_compute import mm_op, backward as mm_backward_custom, setup_context as mm_setup_context_custom # Renamed


# -----------------------------------------------------------------------------

# Attach the custom backward and its setup_context hook (from utils.float_compute) to mm_op
# so autograd can differentiate through it. Presumably mm_op is a torch.library custom op — TODO confirm.
mm_op.register_autograd(mm_backward_custom, setup_context=mm_setup_context_custom) # Use renamed imports

# -----------------------------------------------------------------------------
# Seeding Function
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus every CUDA device when available).

    The chosen seed is announced on stdout and flushed immediately; this runs on every rank.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"PRINT: Set seed to {seed}", flush=True)

# -----------------------------------------------------------------------------
def _load_data_shard(file: Path):
    """Load one tokenized data shard into a pinned CPU ``torch.uint16`` tensor.

    Layout checked here: a 256 x int32 header (header[0] = magic 20240520,
    header[1] = version 1, header[2] = token count), followed at byte offset
    256 * 4 by the tokens stored as uint16.

    Raises:
        ValueError: on a bad magic number or version, or if the file holds fewer
            token bytes than the header declares.
    """
    header = torch.from_file(str(file), False, 256, dtype=torch.int32)
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    if int(header[0]) != 20240520:
        raise ValueError(f"magic number mismatch in the data .bin file: {file}")
    if int(header[1]) != 1:
        raise ValueError(f"unsupported version {int(header[1])} in {file}")
    num_tokens = int(header[2])
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
        f.seek(256 * 4)
        # Unbuffered readinto() may return fewer bytes than requested (e.g. for very
        # large reads), so keep reading until the buffer is full or EOF is reached.
        buf = memoryview(tokens.numpy()).cast("B")
        nbytes = 0
        while nbytes < len(buf):
            chunk = f.readinto(buf[nbytes:])
            if not chunk:
                break
            nbytes += chunk
    if nbytes != 2 * num_tokens:
        raise ValueError(
            f"number of tokens read does not match header in {file}: "
            f"expected {2 * num_tokens} bytes, got {nbytes}"
        )
    return tokens

def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
    """Yield this rank's ``(inputs, targets)`` CUDA tensors, one global batch at a time.

    Files matching ``filename_pattern`` are read once, in sorted order (single epoch).
    Each global batch spans ``batch_size`` tokens; rank ``r`` takes
    ``local_batch_size = batch_size // world_size`` tokens starting at
    ``r * local_batch_size`` plus one extra token, so ``targets`` is ``inputs``
    shifted left by one.

    Raises (on the first ``next()``, since this is a generator):
        ValueError: if ``batch_size`` is not divisible by ``world_size``.
        FileNotFoundError: if no file matches ``filename_pattern``.
        RuntimeError: once every matching file has been consumed.
    """
    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    if batch_size % world_size != 0:
        raise ValueError(f"batch_size ({batch_size}) must be divisible by world_size ({world_size})")
    if not files:
        raise FileNotFoundError(f"no data files match {filename_pattern!r}")
    local_batch_size = batch_size // world_size
    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training

    def _next_shard():
        # Turn exhaustion into an explicit error; a bare next() inside this generator
        # would surface as an opaque "generator raised StopIteration" RuntimeError.
        path = next(file_iter, None)
        if path is None:
            raise RuntimeError(
                f"training data exhausted: all {len(files)} file(s) matching {filename_pattern!r} were consumed"
            )
        return _load_data_shard(path)

    tokens, pos = _next_shard(), 0
    while True:
        # The last rank reads up to index pos + batch_size, so switch shards before
        # that would run past the end of the current one.
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _next_shard(), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets

# ---- Spectral metrics helper: SVD-based statistics of a weight matrix ----
def calculate_svd_metrics(matrix: torch.Tensor, *, topk: int = 10):
    """Compute spectral statistics of ``matrix`` from its singular values.

    The matrix is detached and copied to CPU as float32; singular values <= 1e-9
    are dropped. With ``p = s**2 / sum(s**2)`` (the energy share of each retained
    singular value ``s``) the returned dict holds:

      - entropy_norm: Shannon entropy H of p (natural log) divided by
        log(max(n, 2)), where n is the number of retained singular values
      - erank: effective rank, exp(H)
      - topk_energy: total energy share of the ``topk`` largest singular values
        (all of them when fewer than ``topk`` remain)
      - q75_q25: ratio of the 75th to the 25th percentile of s**2
        (``inf`` when the 25th percentile is 0)

    When nothing survives the threshold (or the total energy is 0), every metric
    is 0.0 except q75_q25, which is ``inf``.
    """
    with torch.no_grad():
        sv = torch.linalg.svdvals(matrix.detach().to('cpu', torch.float32))
        sv = sv[sv > 1e-9]
        count = sv.numel()
        energy = sv * sv
        total_energy = float(torch.sum(energy))  # 0.0 for an empty tensor
        if count == 0 or total_energy == 0.0:
            return dict(entropy_norm=0.0, erank=0.0, topk_energy=0.0, q75_q25=float('inf'))

        share = energy / total_energy
        entropy = float(torch.sum(torch.special.entr(share)))
        q25, q75 = (float(torch.quantile(energy, q)) for q in (0.25, 0.75))
        return dict(
            entropy_norm=entropy / np.log(max(count, 2)),
            erank=float(np.exp(entropy)),
            topk_energy=float(torch.topk(share, min(topk, count)).values.sum()),
            q75_q25=(q75 / q25) if q25 > 0 else float('inf'),
        )


# -----------------------------------------------------------------------------
# Main script: parse CLI arguments, then seed every RNG.
parser = argparse.ArgumentParser(description="NanoGPT Training Script with Muon")
# NOTE(review): in this file --unet is only referenced by a commented-out U-net import; confirm it is still used.
parser.add_argument("--unet", action="store_true", help="Use U-net architecture")
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
# Which matrices Muon vs Adam optimizes; the "whole" parameterization ignores it.
parser.add_argument("--optimizer_mode", type=int, default=0,
                    help="Defines how Muon is applied (ignored for --model_parameterization whole). "
                         "0: Muon(All Hidden Attn+MLP - original); "
                         "1: Muon(QK Attn)/Adam(VO Attn,MLP); "
                         "2: Muon(VO Attn)/Adam(QK Attn,MLP); "
                         "3: Muon(All Attn)/Adam(MLP); "
                         "4: Muon(MLP)/Adam(All Attn); "
                         "5: All Adam (No Muon, all applicable matrices to Adam); "
                         "6: Muon(W_2 MLP)/Adam(attn, W_1 MLP); "
                         "7: Muon(VO Attn, MLP)/Adam(QK Attn); "
                         "8: Muon(VO Attn, W_2 MLP)/Adam(QK Attn, W_1 MLP); "
                         "9: Muon(V Attn, MLP)/Adam(QK Attn, O Attn); "
                         "10: Muon(O Attn, MLP)/Adam(QK Attn, V Attn)."
                         )
parser.add_argument("--model_parameterization", type=str, default="whole", choices=["whole", "qkvo", "norope", "gated"])
parser.add_argument("--adam_lr", type=float, default=0.008,
                    help="Learning rate for all Adam parameter groups (head, embeddings, scalars and "
                         "Adam-assigned matrices); ignored for --model_parameterization whole")
parser.add_argument("--muon_lr", type=float, default=0.05,
                    help="Learning rate for Muon matrices; ignored for --model_parameterization whole")
parser.add_argument("--base_dir", type=str, default="logs/gated", help="Base directory for logs")
exp_args = parser.parse_args()
set_seed(exp_args.seed)

# --- Select the GPT implementation for --model_parameterization ---
# Only "qkvo" and "gated" have a model import in this script. The U-net
# (models.nano_GPT_unet) and "whole" (models.nano_GPT) imports were disabled, and
# "norope" never had one. For those choices GPT would be undefined, and model
# construction would die later with a NameError, so fail fast here instead.
if exp_args.model_parameterization == "qkvo":
    print("Using architecture (models.nano_GPT_qkvo) with CausalSelfAttention having q_w, k_w, v_w")
    from models.nano_GPT_qkvo import GPT
elif exp_args.model_parameterization == "gated":
    print("Using architecture (models.nano_GPT_gated) with CausalSelfAttention having q_w, k_w, v_w")
    from models.nano_GPT_gated import GPT
else:
    raise ValueError(
        f"--model_parameterization {exp_args.model_parameterization!r} has no model import in this script "
        "(available: 'qkvo', 'gated'); re-enable the matching GPT import to use it."
    )
    
@dataclass
class Hyperparameters:
    """Fixed training/evaluation settings for this script.

    Every setting is a type-annotated field, so ``@dataclass`` registers it:
    ``repr(args)`` lists all values (it is printed at startup), ``asdict(args)``
    returns them, and individual values can be overridden via the constructor.
    The defaults are the script's settings; ``Hyperparameters()`` uses them all.
    """
    # data
    train_files: str = "" # fineweb dataset
    val_files: str = "" # fineweb dataset
    val_tokens: int = 10485760
    train_seq_len: int = 48*1024 # FlexAttention sequence length
    val_seq_len: int = 4*64*1024 # FlexAttention sequence length for validation

    # optimization
    num_iterations: int = 10000 # Original: 1770
    cooldown_frac: float = 0.4  # final fraction of training spent decaying the LR
    # architecture
    vocab_size: int = 50257

    # evaluation and logging
    val_loss_every: int = 200
    save_checkpoint: bool = False
args = Hyperparameters()

# --- Distributed (DDP) setup ---
# Rank info comes from the launcher's environment variables; the defaults give a
# single-process run (rank 0 of world size 1).
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # used as this process's CUDA device index below
world_size = int(os.environ.get("WORLD_SIZE", 1))

# print(f"[Rank {rank}] Global Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}", flush=True) # Debug

# CUDA is required; bind this process to its GPU before creating the NCCL process group.
assert torch.cuda.is_available()
device = torch.device("cuda", local_rank) # Use local_rank for device
torch.cuda.set_device(device)

if not dist.is_initialized(): # Ensure DDP is initialized only once
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) # Pass rank and world_size
dist.barrier()  # wait until every rank has joined
master_process = (rank == 0)  # rank 0 owns logging, the run directory and config.json

# Run-directory bookkeeping. Only the master process creates the run directory,
# the log file and config.json; on other ranks both names below stay None.
logfile = None
run_dir_path_str = None

base_log_dir = Path(exp_args.base_dir)

if master_process:
    # Set seed again specifically for the master process (dir creation, config saving).
    # NOTE(review): all ranks were already seeded right after argument parsing; this
    # re-seed is kept so the master's RNG state and its "Set seed" print are unchanged.
    set_seed(exp_args.seed)

    # One folder per (optimizer mode, parameterization, seed); reruns reuse the folder.
    run_folder_name = f"mode_{exp_args.optimizer_mode}_param_{exp_args.model_parameterization}_seed_{exp_args.seed}"
    run_dir_path = base_log_dir / run_folder_name
    run_dir_path.mkdir(parents=True, exist_ok=True)
    run_dir_path_str = str(run_dir_path)

    # A fresh UUID per run keeps log files of reruns in the same folder apart.
    run_uuid = uuid.uuid4()
    logfile = run_dir_path / f"training_log_{run_uuid}.txt"
    print(f"Logging to: {logfile}")

    # Save configuration. Hyperparameters are read from the class namespace (public,
    # non-callable attributes), which works whether or not they are dataclass fields.
    # NOTE(review): config.json is overwritten by reruns in the same folder, so
    # run_uuid_for_log only identifies the most recent run's log file.
    config_to_save = {
        "cli_args": vars(exp_args),
        "hyperparameters": {k: v for k, v in vars(type(args)).items() if not k.startswith('__') and not callable(v)},
        "run_uuid_for_log": str(run_uuid),
        "code_sha256": code_sha256,
    }
    config_file_path = run_dir_path / "config.json"
    with open(config_file_path, "w", encoding="utf-8") as f:
        json.dump(config_to_save, f, indent=4)
    print(f"Saved configuration to: {config_file_path}")

def print0(s, console=False):
    """Log ``s`` from the master process (rank 0); other ranks do nothing.

    The message is echoed to stdout when ``console`` is true or when ``s`` starts
    with the ``"PRINT:"`` marker (the marker is stripped from the echo). When a log
    file is configured, the message is appended to it exactly once, prefixed with
    a local timestamp and the rank.
    """
    if not master_process:
        return

    # Add timestamp and rank for better log readability
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    log_message = f"[{timestamp}] [Rank {rank}] {s}"

    marker = "PRINT:"
    if console or s.startswith(marker):
        print(s[len(marker):] if s.startswith(marker) else s)

    # Append once; with no log file configured (logfile is None) only stdout is used.
    if logfile:
        with open(logfile, "a") as f:
            f.write(log_message + "\n")


# Startup banner. print0 is already a no-op on non-master ranks, so no rank guard is needed.
for _startup_line in (
    f"PRINT: --- Script Start: {time.ctime()} ---",
    f"PRINT: Parsed CLI args: {exp_args}",
    f"PRINT: Hyperparameters: {args}",
    f"PRINT: Using fixed seed: {exp_args.seed}",
    f"PRINT: Run directory: {run_dir_path_str}",
):
    print0(_startup_line, console=True)
print0(f"PRINT: code_sha256={code_sha256}")

########################################
#    Construct model and optimizer     #
########################################
print0("PRINT: Constructing model...", console=True)
# 12 layers, 6 heads, width 768; max_seq_len is the longer of the train/val sequence lengths.
model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=12, num_heads=6, model_dim=768,
                       max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
# Cast only the embedding tables to bfloat16; other modules keep their dtype.
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
print0("PRINT: Broadcasting model parameters...", console=True)
# Overwrite every rank's parameters in place with rank 0's values (the detached
# tensor shares storage with the parameter) so all replicas start identical.
for param in model.parameters():
    dist.broadcast(param.detach(), 0)
print0("PRINT: Model constructed and broadcasted.", console=True)

# --- Parameter collection and optimizer setup ---
# "qkvo"/"norope" and "gated" share one code path. The only difference is that the
# gated MLP has an extra projection (mlp.c_up), which joins the W_1 group and
# all_mlp_matrices. For the other parameterizations mlp_up_params stays empty, so
# every group matches the non-gated layout exactly.
if exp_args.model_parameterization in ("qkvo", "norope", "gated"):
    print0("PRINT: Collecting parameters for optimizers...", console=True)
    head_params = [model.lm_head.weight]
    embed_params = [model.embed.weight]
    uses_gated_mlp = exp_args.model_parameterization == "gated"

    # Granular collection for attention and MLP parts (one entry per block).
    attn_q_params = []
    attn_k_params = []
    attn_v_params = []
    attn_o_params = []  # W_O from attn.c_proj
    mlp_fc_params = []
    mlp_proj_params = []
    mlp_up_params = []  # gated MLP only (mlp.c_up)

    for block_module in model.blocks:
        if block_module.attn is not None:
            # q_w / k_w / v_w are expected on the attention module; a missing one is
            # reported and skipped rather than treated as fatal.
            for weight_name, bucket in (("q_w", attn_q_params), ("k_w", attn_k_params), ("v_w", attn_v_params)):
                if hasattr(block_module.attn, weight_name):
                    bucket.append(getattr(block_module.attn, weight_name))
                else:
                    print0(f"PRINT: Warning: {weight_name} not found in attn module of a block.", console=True)
            attn_o_params.append(block_module.attn.c_proj.weight)
        if block_module.mlp is not None:
            mlp_fc_params.append(block_module.mlp.c_fc.weight)
            mlp_proj_params.append(block_module.mlp.c_proj.weight)
            if uses_gated_mlp:
                mlp_up_params.append(block_module.mlp.c_up.weight)

    # Combine into logical groups for experiments
    attn_qk_group = attn_q_params + attn_k_params
    attn_vo_group = attn_v_params + attn_o_params
    all_attn_matrices = attn_qk_group + attn_vo_group
    mlp_w1_group = mlp_fc_params + mlp_up_params
    mlp_w2_group = mlp_proj_params
    all_mlp_matrices = mlp_fc_params + mlp_proj_params + mlp_up_params

    # Scalar parameters: everything not explicitly grouped as a matrix above.
    matrix_params_for_scalar_check = set(head_params + embed_params + all_attn_matrices + all_mlp_matrices)
    scalar_params = [p for p in model.parameters() if p not in matrix_params_for_scalar_check]
    for p_scalar in scalar_params:  # Sanity check: a leftover matrix means the grouping above missed it.
        if p_scalar.ndim >= 2:
            print0(f"PRINT: Warning - Parameter {p_scalar.shape} ended up in scalar_params but has ndim >= 2. Check grouping.", console=True)

    adam_matrix_lr = exp_args.adam_lr  # LR for every Adam group, including matrices Adam takes over
    current_optimizer_mode = exp_args.optimizer_mode
    print0(f"PRINT: Configuring optimizers for EXPERIMENT_MODE = {current_optimizer_mode}", console=True)

    # mode -> (startup description, matrices for Muon, extra matrices for Adam).
    # Head, embedding and scalar params always go to Adam, whatever the mode.
    optimizer_mode_table = {
        0: ("PRINT: Mode 0: Muon on ALL Attention (QKVO) and ALL MLP matrices.",
            all_attn_matrices + all_mlp_matrices, []),
        1: (f"PRINT: Mode 1: Muon on QK Attn. Adam on VO Attn, MLP (Adam LR: {adam_matrix_lr}).",
            attn_qk_group, attn_vo_group + all_mlp_matrices),
        2: (f"PRINT: Mode 2: Muon on VO Attn. Adam on QK Attn, MLP (Adam LR: {adam_matrix_lr}).",
            attn_vo_group, attn_qk_group + all_mlp_matrices),
        3: (f"PRINT: Mode 3: Muon on ALL Attn (QKVO). Adam on MLP (Adam LR: {adam_matrix_lr}).",
            all_attn_matrices, all_mlp_matrices),
        4: (f"PRINT: Mode 4: Muon on MLP. Adam on ALL Attn (QKVO) (Adam LR: {adam_matrix_lr}).",
            all_mlp_matrices, all_attn_matrices),
        5: (f"PRINT: Mode 5: All Adam. All Attn and MLP matrices to Adam (Adam LR: {adam_matrix_lr}).",
            [], all_attn_matrices + all_mlp_matrices),
        6: (f"PRINT: Mode 6: Muon on W_2 MLP. Adam on attn, W_1 MLP (Adam LR: {adam_matrix_lr}).",
            mlp_w2_group, all_attn_matrices + mlp_w1_group),
        7: (f"PRINT: Mode 7: Muon on VO Attn, MLP. Adam on QK Attn (Adam LR: {adam_matrix_lr}).",
            attn_vo_group + all_mlp_matrices, attn_qk_group),
        8: (f"PRINT: Mode 8: Muon on VO Attn, W_2 MLP. Adam on QK Attn, W_1 MLP (Adam LR: {adam_matrix_lr}).",
            attn_vo_group + mlp_w2_group, attn_qk_group + mlp_w1_group),
        9: (f"PRINT: Mode 9: Muon on V Attn, MLP (Adam LR: {adam_matrix_lr}).",
            attn_v_params + all_mlp_matrices, attn_o_params + attn_qk_group),
        10: (f"PRINT: Mode 10: Muon on O Attn, MLP (Adam LR: {adam_matrix_lr}).",
             attn_o_params + all_mlp_matrices, attn_v_params + attn_qk_group),
    }
    if current_optimizer_mode not in optimizer_mode_table:
        raise ValueError(f"Unsupported EXPERIMENT_MODE: {current_optimizer_mode}")
    mode_description, muon_params_target_list, adam_matrix_target_list = optimizer_mode_table[current_optimizer_mode]
    print0(mode_description, console=True)

    # Adam optimizer setup
    adam_param_groups_config = [
        dict(params=head_params, lr=adam_matrix_lr),
        dict(params=embed_params, lr=adam_matrix_lr),
        dict(params=scalar_params, lr=adam_matrix_lr),  # Scalar params always go to Adam
    ]
    # Matrices this mode hands to Adam form one extra group. The target lists are flat
    # lists of Parameters (built by list concatenation above).
    if adam_matrix_target_list:
        flat_adam_matrices = [p for p in adam_matrix_target_list if p is not None]
        if flat_adam_matrices:
            adam_param_groups_config.append(dict(params=flat_adam_matrices, lr=adam_matrix_lr))

    # Drop empty groups (e.g. when there are no scalar params).
    adam_param_groups_config = [g for g in adam_param_groups_config if g['params']]
    optimizer1 = torch.optim.Adam(adam_param_groups_config, betas=(0.8, 0.95), eps=1e-10, fused=True)
    optimizers = [optimizer1]  # Start with Adam

    # Muon optimizer setup (only when this mode assigns it at least one matrix).
    optimizer2 = None
    if muon_params_target_list:
        # De-duplicate by identity while preserving order.
        flat_unique_muon_params = []
        seen_muon_ids = set()
        for p in muon_params_target_list:
            if p is not None and id(p) not in seen_muon_ids:
                flat_unique_muon_params.append(p)
                seen_muon_ids.add(id(p))

        if flat_unique_muon_params:
            optimizer2 = Muon(flat_unique_muon_params, lr=exp_args.muon_lr, momentum=0.95, weight_decay=0.0)
            optimizers.append(optimizer2)
        else:
            print0("PRINT: Muon optimizer not created as its target parameter list was empty.", console=True)
    else:
        print0("PRINT: Muon optimizer not created as muon_params_target_list was empty (e.g. mode where Adam handles all matrices).", console=True)

    print0(f"PRINT: Optimizers configured. Total optimizers: {len(optimizers)}", console=True)
    if optimizer2 is not None:
        print0(f"PRINT: Muon optimizer is active with {len(flat_unique_muon_params)} parameters.", console=True)
elif exp_args.model_parameterization == "whole":
    # All >=2-D matrices inside model.blocks (names without "embed") go to Muon;
    # embeddings, lm_head and <2-D params go to Adam with fixed per-group LRs.
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    # init the optimizer(s)
    adam_params = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True)
    # NOTE(review): unlike the branch above, this Muon call passes rank/world_size and ignores
    # --muon_lr/--adam_lr; confirm optimizers.MUON_new.Muon accepts these keyword arguments.
    optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
    optimizers = [optimizer1, optimizer2]
else:
    raise ValueError(f"Unsupported model_parameterization: {exp_args.model_parameterization}")

# Record each param group's configured LR under "initial_lr" as the fixed base value.
# Presumably the training loop scales it by get_lr(step) — TODO confirm.
for group in (g for opt in optimizers for g in opt.param_groups):
    group["initial_lr"] = group["lr"]

# Learning-rate schedule: constant, then a linear cooldown.
def get_lr(step: int):
    """Return the LR multiplier for ``step``.

    1.0 for the first ``1 - args.cooldown_frac`` of training, then a linear decay
    that reaches 0.1 at ``args.num_iterations``. Progress is clamped to [0, 1],
    so steps past the end keep returning 0.1 and negative steps return 1.0.
    """
    progress = min(max(step / args.num_iterations, 0.0), 1.0)
    if progress < 1 - args.cooldown_frac:
        return 1.0
    # Guard against cooldown_frac == 0 (division by zero).
    decay = (1 - progress) / max(args.cooldown_frac, 1e-9)
    return decay * 1.0 + (1 - decay) * 0.1

# attention window size schedule (KEEP AS IS)
def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    """Return ``window_size // 128`` as a 0-d int32 CUDA tensor.

    Only the most recent argument is cached, so repeated calls with the same
    window reuse one device tensor instead of issuing a new host-to-device copy.
    """
    num_blocks = window_size // 128
    staged = torch.tensor(num_blocks, dtype=torch.int32, pin_memory=True)
    return staged.cuda(non_blocking=True)
def get_window_size_blocks(step: int):
    """Return the attention window for ``step`` as a cached tensor of 128-unit blocks.

    Progress ``step / args.num_iterations`` is clamped to [0, 1]. The raw window
    ``1728 * progress`` (presumably in tokens) is rounded up to a multiple of 128,
    floored at 128, and converted to a block count by
    ``get_window_size_blocks_helper``.
    """
    progress = step / args.num_iterations
    if progress < 0 or progress > 1:
        progress = min(max(progress, 0.0), 1.0)

    rounded = next_multiple_of_n(1728 * progress, n=128)
    return get_window_size_blocks_helper(max(128, rounded))

print0("PRINT: Compiling model with TorchInductor...", console=True)
# Wrap the eager `model`; everything below (warmup, validation, training,
# checkpointing) goes through `model_compiled`.
# NOTE: torch.compile is lazy. Kernels are actually generated on the first
# forward call (the warmup below), not on this line.
model_compiled: nn.Module = torch.compile(
    model,
    mode="max-autotune",
    dynamic=False,
)
print0("PRINT: Model compilation complete.", console=True)

########################################
#            Warmup kernels            #
########################################
print0("PRINT: Starting warmup...", console=True)
warmup_steps = 10
# Run a few throwaway steps on random data to warm up kernels, then roll the
# model and optimizers back to this snapshot so warmup does not affect training.
initial_state = {
    "model": copy.deepcopy(model_compiled.state_dict()),
    "optimizers": [copy.deepcopy(opt.state_dict()) for opt in optimizers],
}
for i in range(warmup_steps):
    # The same random token tensor serves as both inputs and targets.
    inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
    loss = model_compiled(inputs.to(torch.int32), targets, get_window_size_blocks(0))
    loss.backward()
    # Average gradients across ranks, mirroring the real training step.
    for param in model_compiled.parameters():
        if param.grad is None:
            continue
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    for opt in optimizers:
        opt.step()
    model_compiled.zero_grad(set_to_none=True)

# Restore the pre-warmup weights and optimizer states.
model_compiled.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del initial_state
print0("PRINT: Warmup complete.", console=True)
torch.cuda.synchronize()


# Matrices selected for analysis; stays empty for other parameterizations.
params_to_analyze = []

if exp_args.model_parameterization in ("qkvo", "gated"):
    params_to_analyze = all_attn_matrices + all_mlp_matrices
    # Only the master process gets the named SVD groups; other ranks get {}.
    matrix_groups_for_svd = (
        dict(
            attn_qk=attn_qk_group,
            attn_vo=attn_vo_group,
            mlp_w1=mlp_w1_group,
            mlp_w2=mlp_proj_params,
        )
        if master_process
        else {}
    )
elif exp_args.model_parameterization == "whole":
    params_to_analyze = [p for p in hidden_matrix_params if p.requires_grad]



########################################
#        Training and validation       #
########################################
print0("PRINT: Starting training...", console=True)
train_steps = args.num_iterations
# Each yielded batch spans world_size * train_seq_len tokens, split across ranks.
train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
# Cumulative training wall time (ms), excluding time spent in validation.
training_time_ms = 0
torch.cuda.synchronize()
t0 = time.perf_counter()

def _mean_svd_metrics(metrics):
    """Average per-matrix SVD metric dicts into one logging record.

    Each element is a dict produced by ``calculate_svd_metrics`` and is read via
    the keys 'entropy_norm', 'erank', 'topk_energy' and 'q75_q25'. Returns a dict
    with keys entropy / erank / topkE / q75_q25 holding Python floats.
    """
    return dict(
        entropy=float(np.mean([m['entropy_norm'] for m in metrics])),
        erank=float(np.mean([m['erank'] for m in metrics])),
        topkE=float(np.mean([m['topk_energy'] for m in metrics])),
        q75_q25=float(np.mean([m['q75_q25'] for m in metrics])),
    )


# Main loop: steps 0..train_steps-1 train. The extra final iteration
# (step == train_steps) only validates, optionally checkpoints, and exits.
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # --------------- VALIDATION SECTION -----------------
    # Validate at step 0 (after warmup), every `val_loss_every` steps, and at the last step.
    if step == 0 or last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        torch.cuda.synchronize()
        # Close the current training-time segment. Nothing has trained yet at step 0.
        # Validation time is excluded because t0 is reset at the end of this section.
        if step > 0:
            training_time_ms += 1000 * (time.perf_counter() - t0)

        model_compiled.eval()
        val_batch_size = world_size * args.val_seq_len
        if args.val_tokens % val_batch_size != 0:
            print0(f"PRINT: Warning: val_tokens ({args.val_tokens}) not perfectly divisible by val_batch_size ({val_batch_size}). Some tokens might be missed.", console=True)
        val_num_steps = args.val_tokens // val_batch_size

        val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
        val_loss_sum = torch.zeros(1, device=device)  # accumulate on device; no host sync per batch
        actual_val_steps = 0
        with torch.no_grad():
            for val_i in range(val_num_steps):
                try:
                    inputs, targets = next(val_loader)
                except StopIteration:
                    print0(f"PRINT: Validation data loader for '{args.val_files}' exhausted early at val_step {val_i+1}/{val_num_steps}.", console=True)
                    break  # stop if data runs out
                loss_val = model_compiled(inputs, targets, get_window_size_blocks(step))
                val_loss_sum += loss_val
                actual_val_steps += 1

        if actual_val_steps > 0:
            val_loss_avg = val_loss_sum / actual_val_steps
        else:
            # e.g. val_tokens smaller than val_batch_size, or the loader yielded nothing.
            val_loss_avg = torch.tensor(float('nan'), device=device)
            print0("PRINT: Warning: No validation steps were completed. val_loss is NaN.", console=True)

        del val_loader
        # Average the per-rank mean losses across ranks.
        dist.all_reduce(val_loss_avg, op=dist.ReduceOp.AVG)

        svd_log_str = ""
        # matrix_groups_for_svd is only defined for some parameterizations and is
        # only populated on the master process.
        if master_process and 'matrix_groups_for_svd' in locals() and matrix_groups_for_svd:
            TOPK = 10
            svd_results_by_category = {}

            with torch.no_grad():
                # Per-category metrics, averaged over the matrices in each group.
                for name, group_params in matrix_groups_for_svd.items():
                    if not group_params:
                        continue
                    svd_results_by_category[name] = _mean_svd_metrics(
                        [calculate_svd_metrics(p, topk=TOPK) for p in group_params]
                    )

                # Product W_o @ W_v for each aligned (v, o) pair as an extra category.
                vo_mets = [
                    calculate_svd_metrics(torch.matmul(w_o, w_v), topk=TOPK)
                    for w_v, w_o in zip(attn_v_params, attn_o_params)
                ]
                if vo_mets:
                    svd_results_by_category['vo_prod'] = _mean_svd_metrics(vo_mets)

            svd_log_str = " ".join(
                f"{name}:H={vals['entropy']:.4f},top{TOPK}E={vals['topkE']:.2f},eRank={vals['erank']:.1f},q75/q25={vals['q75_q25']:.2f}"
                for name, vals in svd_results_by_category.items()
            )

        # training_time_ms is 0 at step 0 and cumulative afterwards.
        avg_step_time = training_time_ms / step if step > 0 else 0
        print0(f"PRINT: step:{step}/{train_steps} val_loss:{val_loss_avg.item():.4f} svd_entropy: {svd_log_str} train_time:{training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms", console=True)

        model_compiled.train()
        torch.cuda.synchronize()
        t0 = time.perf_counter()  # start the next training-time segment

    if last_step:
        if master_process and args.save_checkpoint:
            if run_dir_path_str:
                checkpoint_parent_dir = Path(run_dir_path_str) / "checkpoints"
                checkpoint_parent_dir.mkdir(parents=True, exist_ok=True)
                checkpoint_path = checkpoint_parent_dir / f"state_step{step:06d}.pt"
                # NOTE(review): state_dict() of a torch.compile'd module is usually
                # prefixed with '_orig_mod.'; confirm this matches the checkpoint loader.
                log_checkpoint = dict(step=step, code_sha256=code_sha256, model=model_compiled.state_dict(),
                                      optimizers=[opt.state_dict() for opt in optimizers])
                torch.save(log_checkpoint, str(checkpoint_path))
                print0(f"PRINT: Saved checkpoint to {checkpoint_path}", console=True)
            else:
                print0("PRINT: Warning - run_dir_path_str not set, cannot save checkpoint.", console=True)
        break

    # --------------- TRAINING SECTION -----------------
    try:
        inputs, targets = next(train_loader)
    except StopIteration:
        # NOTE: leaving here skips the final validation/checkpoint of the last_step branch.
        print0(f"PRINT: Training data loader for '{args.train_files}' exhausted. Ending training early at step {step}.", console=True)
        break

    loss_train = model_compiled(inputs, targets, get_window_size_blocks(step))
    loss_train.backward()

    # Manual data-parallel gradient averaging across ranks.
    for param in model_compiled.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

    # Scale every param group's base LR by the schedule multiplier.
    current_lr_val = get_lr(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * current_lr_val

    # Muon momentum warmup: linear ramp 0.85 -> 0.95 over the first 300 steps
    # (only when the Muon optimizer exists).
    if optimizer2 is not None:
        frac = min(step / 300, 1)
        for group in optimizer2.param_groups:
            group["momentum"] = (1 - frac) * 0.85 + frac * 0.95

    for opt in optimizers:
        opt.step()
    model_compiled.zero_grad(set_to_none=True)

    # Periodic timing log: every 20 steps and on the final training step.
    if step > 0 and (step % 20 == 0 or step == train_steps - 1):
        # Host wall clock without a CUDA sync, so this total is approximate.
        approx_total_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
        print0(f"step:{step+1}/{train_steps} train_time:{approx_total_training_time_ms:.0f}ms step_avg:{approx_total_training_time_ms/max(1, step + 1):.2f}ms", console=True)

print0(f"PRINT: --- Training Finished: {time.ctime()} ---", console=True)
# Report peak CUDA memory in MiB (integer-truncated).
print0(
    "PRINT: Peak memory allocated: "
    f"{torch.cuda.max_memory_allocated() // (1024 * 1024)} MiB "
    f"reserved: {torch.cuda.max_memory_reserved() // (1024 * 1024)} MiB",
    console=True,
)

# Tear down the distributed process group only if one was initialized.
if dist.is_initialized():
    dist.destroy_process_group()