import os
import sys
import hashlib
# Fingerprint the running script: read the file named by sys.argv[0] as raw
# bytes and keep its SHA-256 hex digest in `code_sha256`, so a run can record
# which version of the code produced it without storing the source itself.
with open(sys.argv[0], 'rb') as _f:
    _code_bytes = _f.read()
code_sha256 = hashlib.sha256(_code_bytes).hexdigest()
import uuid
import time
import copy
import glob
import math
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
import argparse # Keep argparse for --unet and potentially --optimizer_mode
import json
import random 
import numpy as np 
import itertools
from itertools import cycle
from transformers import GPT2Tokenizer
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from tqdm import tqdm
import re


# 

# The allocator option is exported before torch is imported — presumably so the
# CUDA caching allocator picks it up when it initialises (NOTE(review): confirm
# against the PyTorch version in use).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
# Run one trivial backward pass on the GPU at import time; requires a CUDA device.
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
# use of FlexAttention
from torch.nn.attention.flex_attention import BlockMask, flex_attention
#sys.path.append("") 
from optimizers.MUON_new import Muon
from utils.float_compute import mm_op, backward as mm_backward_custom, setup_context as mm_setup_context_custom # Renamed


# -----------------------------------------------------------------------------

# Attach the backward function and the context-setup hook (both imported from
# utils.float_compute) to mm_op. Presumably mm_op is a torch.library custom op
# and this is what makes it differentiable — TODO confirm in utils.float_compute.
mm_op.register_autograd(mm_backward_custom, setup_context=mm_setup_context_custom) # Use renamed imports

# -----------------------------------------------------------------------------
# Seeding Function
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds, in order, Python's ``random``, NumPy's global RNG and PyTorch's
    CPU RNG; when CUDA is available all CUDA devices are seeded as well.
    A confirmation line is printed (and flushed) by every process that
    calls this, not just the master rank.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"PRINT: Set seed to {seed}", flush=True) # Print immediately for all ranks



# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader (KEEP AS IS)
def _load_data_shard(file: Path):
    """Load one token shard (.bin) into a pinned uint16 tensor.

    The file starts with a header of 256 int32 words, of which the first
    three are checked/used here: magic number (20240520), format version (1)
    and the number of tokens. The tokens follow the header as 2-byte values.

    Args:
        file: Path of the shard to read.

    Returns:
        A 1-D ``torch.uint16`` tensor in pinned host memory holding the tokens.

    Raises:
        AssertionError: if the magic number or version is wrong, or if the
            number of bytes read does not match the token count in the header.
    """
    header_words = 256
    header = torch.from_file(str(file), False, header_words, dtype=torch.int32)
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    token_count = int(header[2])
    # Unbuffered read straight into the tensor's memory avoids an extra copy.
    with file.open("rb", buffering=0) as raw:
        shard = torch.empty(token_count, dtype=torch.uint16, pin_memory=True)
        raw.seek(header_words * 4)  # skip the int32 header
        bytes_read = raw.readinto(shard.numpy())
        assert bytes_read == 2 * token_count, "number of tokens read does not match header"
    return shard

def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
    """Yield this rank's (inputs, targets) slices of successive global batches, forever.

    Shards matching ``filename_pattern`` are visited in sorted order and cycled
    indefinitely. Each global batch of ``batch_size`` tokens is split evenly
    across ranks; this rank takes ``batch_size // world_size`` tokens starting
    at offset ``rank * local_batch_size`` within the batch, plus one extra
    token so that targets are the inputs shifted by one position.

    Args:
        filename_pattern: Glob pattern for the shard files.
        batch_size: Global batch size in tokens; must be divisible by ``world_size``.
        rank: Index of this process among ``world_size`` processes.
        world_size: Total number of processes sharing each batch.

    Yields:
        ``(inputs, targets)``: int32 inputs and int64 targets, both moved to
        the CUDA device with ``non_blocking=True``.

    Raises:
        AssertionError: if ``batch_size`` is not divisible by ``world_size``.
        FileNotFoundError: if no file matches ``filename_pattern`` (raised on
            the first ``next()``; previously this surfaced as an opaque
            "generator raised StopIteration" RuntimeError).
    """
    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    assert batch_size % world_size == 0
    if not files:
        raise FileNotFoundError(f"no data shards match pattern: {filename_pattern!r}")
    local_batch_size = batch_size // world_size
    file_iter = cycle(files) # wraps around after the last shard, i.e. multi-epoch training
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        # Move to the next shard when the current one cannot supply a full
        # global batch plus the one look-ahead token needed for the targets.
        if pos + batch_size + 1 >= len(tokens):
            tokens, pos = _load_data_shard(next(file_iter)), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets





# -----------------------------------------------------------------------------
# int main
# Command-line interface. Parsing happens at import time, and the seed is applied
# immediately afterwards on every rank.
parser = argparse.ArgumentParser(description="NanoGPT Training Script with Muon")
# NOTE: --unet is still accepted, but the import it used to select is commented out in this file.
parser.add_argument("--unet", action="store_true", help="Use U-net architecture")
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
# --- MODIFICATION: Add optimizer_mode as a CLI argument ---
# Every fragment of the help text ends with a separator so that the implicitly
# concatenated string stays readable in --help output.
parser.add_argument("--optimizer_mode", type=int, default=0,
                    help="Defines how Muon is applied. "
                         "0: Muon(All Hidden Attn+MLP - original); "
                         "1: Muon(QK Attn)/Adam(VO Attn,MLP); "
                         "2: Muon(VO Attn)/Adam(QK Attn,MLP); "
                         "3: Muon(All Attn)/Adam(MLP); "
                         "4: Muon(MLP)/Adam(All Attn); "
                         "5: All Adam (No Muon, all applicable matrices to Adam); "
                         "6: Muon(W_2 MLP)/Adam(attn, W_1 MLP); "
                         "7: Muon(VO Attn, MLP)/Adam(QK Attn); "
                         "8: Muon(VO Attn, W_2 MLP)/Adam(QK Attn, W_1 MLP); "
                         "9: SGD+momentum on all parameters (see --sgd_lr)."
                         )
parser.add_argument("--model_parameterization", type=str, default="whole",choices=["whole","qkvo"])
parser.add_argument("--per_group_k", type=int, default=100, help="Number of samples per group")
parser.add_argument("--muon_lr", type=float, default=0.01, help="Learning rate for Muon optimizer.")
parser.add_argument("--adam_lr", type=float, default=1e-3, help="Base learning rate for Adam optimizer groups.")
parser.add_argument("--sgd_lr", type=float, default=0.01, help="Learning rate for SGD optimizer (used in mode 9).")
parser.add_argument("--m_val", type=int, default=15,
                    help="Power-law exponent m used by the dataset generator.")
parser.add_argument("--qa_jsonl_path", type=str,
                    default="./data/qa_tail_m15.jsonl",
                    help="Path to the QA jsonl used for evaluation (fixed eval set).")


exp_args = parser.parse_args()
set_seed(exp_args.seed)

# Module-level aliases for the dataset/evaluation related CLI values.
M_FOR_POWERLAW: int = exp_args.m_val
QA_JSONL_PATH: str = exp_args.qa_jsonl_path
PER_GROUP_K: int = exp_args.per_group_k

# --- MODIFICATION: Import correct GPT model based on --unet flag ---
#if exp_args.unet:
#    print("Using U-net architecture")
#    from models.nano_GPT_unet import GPT
if exp_args.model_parameterization == "qkvo":
    print("Using architecture (models.nano_gpt_qkvo) with CausalSelfAttention having q_w, k_w, v_w")
    from models.nano_GPT_qkvo import GPT
else:
    # No other model class is imported by this script (the import for the
    # "whole" parameterization, kept below for reference, is commented out).
    # Fail here with an actionable message rather than with a NameError on
    # `GPT` once model construction is reached.
    #    print("Using original architecture")
    #    from models.nano_GPT import GPT
    raise NotImplementedError(
        f"--model_parameterization={exp_args.model_parameterization!r} is not supported by this script; "
        "pass --model_parameterization qkvo"
    )
    
@dataclass
class Hyperparameters:
    """Fixed (non-CLI) training and evaluation settings.

    Every attribute is annotated so that it is a real dataclass field: the
    values then appear in ``repr(args)`` and ``dataclasses.asdict(args)``.
    The defaults also remain plain class attributes, so code that reads them
    through ``Hyperparameters.__dict__`` keeps working.
    """
    # data
    #train_files = "" # synthetic qa dataset
    #val_files = ""   # synthetic qa dataset
    # An earlier assignment (val_tokens = 1966080) was immediately overridden by
    # this value and has been dropped; the effective value is unchanged.
    val_tokens: int = 10485760
    train_seq_len: int = 12*1024
    val_seq_len: int = 4*16*1024

    # optimization
    num_iterations: int = 5000 #1770 # Original: 1770
    cooldown_frac: float = 0.8
    # architecture
    vocab_size: int = 50257
    # evaluation and logging
    val_loss_every: int = 500 # Original: 125
    save_checkpoint: bool = False # Original: False
args = Hyperparameters()

# DDP setup (KEEP AS IS, but ensure rank and world_size are correctly used)
# Process topology is read from the environment; when the variables are absent
# the defaults describe a single process (rank 0, local rank 0, world size 1).
# Presumably these are exported by the launcher (e.g. torchrun) — verify.
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Used for device setting
world_size = int(os.environ.get("WORLD_SIZE", 1))

# print(f"[Rank {rank}] Global Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}", flush=True) # Debug

# This script is GPU-only: bind the process to the CUDA device indexed by its local rank.
assert torch.cuda.is_available()
device = torch.device("cuda", local_rank) # Use local_rank for device
torch.cuda.set_device(device)

if not dist.is_initialized(): # Ensure DDP is initialized only once
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) # Pass rank and world_size
# Wait until every rank has reached this point before continuing.
dist.barrier()
# Only global rank 0 creates directories, writes logs and saves the config.
master_process = (rank == 0)

# Logging setup (KEEP AS IS, but maybe add optimizer_mode to filename)
# `logfile` stays None on every non-master rank; print0 relies on that.
logfile = None

base_log_dir = Path("./logs_bios/qa") # Base log directory for bioS mixed training

# Folder name derived from the run configuration and seed. It is built exactly
# once so that every rank sees the same run directory path and string.
run_folder_name = f"mode_{exp_args.optimizer_mode}_param_{exp_args.model_parameterization}_lr_{exp_args.adam_lr}_seed_{exp_args.seed}"
run_dir_path = base_log_dir / run_folder_name
run_dir_path_str = str(run_dir_path)

if master_process:
    # Set seed again specifically for master process for operations like dir creation, config saving
    set_seed(exp_args.seed)

    run_dir_path.mkdir(parents=True, exist_ok=True)

    # A fresh UUID per launch keeps repeated runs of one configuration from
    # appending to the same log file.
    run_uuid = uuid.uuid4()
    logfile = run_dir_path / f"training_log_{run_uuid}.txt"
    print(f"Logging to: {logfile}")

    # Save configuration (anonymized; store only hash, not source)
    config_to_save = {
        "cli_args": vars(exp_args),
        # Hyperparameter defaults are read from the class namespace, skipping
        # dunder entries and callables.
        "hyperparameters": {k: v for k, v in args.__class__.__dict__.items() if not k.startswith('__') and not callable(v)},
        "run_uuid_for_log": str(run_uuid),
        "code_sha256": code_sha256
    }
    config_file_path = run_dir_path / "config.json"
    with open(config_file_path, "w") as f:
        json.dump(config_to_save, f, indent=4)
    print(f"Saved configuration to: {config_file_path}")

def print0(s, console=False):
    """Log a message from the master process only; a no-op on other ranks.

    The message is appended to ``logfile`` (when one is configured) prefixed
    with a timestamp and the rank. It is also echoed to stdout when
    ``console`` is true or when the message starts with ``"PRINT:"``; in the
    latter case those six characters are stripped from the echoed text (the
    log file keeps the full message).

    Each message is written to the log file exactly once. Previously a second,
    unguarded write duplicated every line and raised when ``logfile`` was None.

    Args:
        s: Message text (a ``str``).
        console: Force echoing to stdout even without the ``"PRINT:"`` prefix.
    """
    if not master_process:
        return

    # Add timestamp and rank for better log readability
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    log_message = f"[{timestamp}] [Rank {rank}] {s}"

    # Print to console if requested or if it's a specific "PRINT:" message
    has_print_tag = s.startswith("PRINT:")
    if console or has_print_tag:
        print(s[6:] if has_print_tag else s) # Print to stdout for master process

    if logfile:
        with open(logfile, "a") as f:
            f.write(log_message + "\n")


# Startup banner. print0 itself is a no-op on non-master ranks, so no explicit
# master check is needed around any of these lines.
for _banner_line in (
    f"PRINT: --- Script Start: {time.ctime()} ---",
    f"PRINT: Parsed CLI args: {exp_args}",
    f"PRINT: Hyperparameters: {args}",
    f"PRINT: Using fixed seed: {exp_args.seed}",
    f"PRINT: Run directory: {run_dir_path_str}",
    f"PRINT: code_sha256={code_sha256}",
):
    print0(_banner_line, console=True)
# ... (other initial logs)



# -----------------------------------------------------------------------------

def generate_powerlaw_selection_counts(m: int):
    """Build per-class sample counts following a power-law over class groups.

    Groups are numbered ``0..m``. Group 0 holds one class; group ``g >= 1``
    holds ``2 ** (g - 1)`` classes. Every class in group ``g`` is assigned
    ``2 ** (m - g)`` samples. Class ids are handed out consecutively from 0
    in group order.

    Args:
        m: Largest group index (the power-law exponent).

    Returns:
        ``(selection_counts, class_groups)`` where ``selection_counts`` maps
        class id -> samples for that class and ``class_groups[class_id]`` is
        the group that class belongs to.
    """
    counts_by_class = {}
    group_of_class = []
    for group in range(m + 1):
        per_class = 2 ** (m - group)
        if per_class < 1:
            continue
        width = 1 if group == 0 else 2 ** (group - 1)
        # The next free class id always equals the number of classes emitted so far.
        first_id = len(group_of_class)
        counts_by_class.update((first_id + offset, per_class) for offset in range(width))
        group_of_class.extend([group] * width)
    return counts_by_class, group_of_class


def run_detailed_evaluation(model, tokenizer, qa_data_path, device, m_val, class_to_group_map, fixed_indices=None):
    """
    In a single evaluation, compute Per-Class Loss, Per-Class FTA, Total Loss, and Total FTA.

    Each sample is run through the model once. From the logits two things are
    derived: the cross-entropy loss over the sequence, and first-token
    accuracy (FTA), i.e. whether the arg-max prediction at the last prompt
    position equals the first token of ``' ' + answer``. Results are
    accumulated per group, as given by ``class_to_group_map``.

    Args:
        model: Called as ``model(input_seq, target_seq=None,
            sliding_window_num_blocks=...)``; expected to return logits, or a
            tuple whose first element is the logits.
        tokenizer: Provides ``encode(text, add_special_tokens=False)``,
            ``eos_token_id`` and ``pad_token_id``.
        qa_data_path: JSONL file; every record needs ``'text'`` and ``'class_id'``.
        device: Device on which the input tensors are created.
        m_val: Unused here; kept for interface compatibility.
        class_to_group_map: Maps ``class_id`` -> group id. Samples whose class
            is not in the map are skipped.
        fixed_indices: Optional mapping whose values are iterables of 0-based
            line indices into ``qa_data_path``. When given, only those lines
            are evaluated; otherwise the whole file is used.

    Returns:
        dict with ``per_class_loss`` / ``per_class_acc`` (keyed by
        ``str(group_id)``), ``total_loss`` (sample-weighted),
        ``total_acc_weighted`` (sample-weighted), ``total_acc_unweighted``
        (mean of the per-group accuracies) and ``total_acc`` (same value as
        ``total_acc_unweighted``).

    Notes:
        * The model is switched to eval mode and is NOT switched back; the
          caller must call ``model.train()`` before resuming training.
        * Samples whose text does not match the ``<question>? Answer: <answer>``
          pattern are excluded from the loss statistics as well as from FTA.
        * NOTE(review): FTA assumes the tokenization of the prompt alone is a
          prefix of the tokenization of the full text — confirm for the
          tokenizer in use.
    """
    print0("\n--- Starting Detailed Evaluation (Loss & FTA) ---", console=True)
    model.eval()

    # 1. Load evaluation samples
    qa_data = []
    if fixed_indices is not None:
        needed = set()
        for arr in fixed_indices.values():
            needed.update(arr)
        malformed_lines = 0
        with open(qa_data_path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx not in needed:
                    continue
                try:
                    qa_data.append(json.loads(line))
                except json.JSONDecodeError:
                    # Keep going (best effort), but do not hide the problem.
                    malformed_lines += 1
        if malformed_lines:
            print0(f"PRINT: WARNING: skipped {malformed_lines} malformed JSON line(s) in the fixed-eval set.", console=True)
        print0(f"PRINT: Fixed-eval set loaded with {len(qa_data)} samples.", console=True)
    else:
        with open(qa_data_path, 'r', encoding='utf-8') as f:
            qa_data = [json.loads(line) for line in f]
        print0(f"PRINT: WARNING: fixed_indices is None; using all {len(qa_data)} samples.", console=True)

    # 2. Initialize counters
    group_losses = defaultdict(float)
    group_loss_counts = defaultdict(int)  # For loss sample count
    group_correct = defaultdict(int)
    group_total_fta = defaultdict(int)    # For FTA sample count

    # Loop invariants, computed once instead of per sample.
    BLOCK_SIZE = 128       # sequences are padded up to a multiple of this
    max_eval_len = 4096    # hard cap on the padded sequence length
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    qa_pattern = re.compile(r'^(.*?\?)\s*Answer\s*:\s*(.*)$', re.IGNORECASE)

    # 3. Evaluation loop
    with torch.no_grad():
        for item in tqdm(qa_data, desc="Detailed Evaluation", disable=(not master_process)):
            if not item or 'text' not in item or not item['text']:
                continue

            group_id = class_to_group_map.get(item['class_id'])
            if group_id is None:
                continue

            # --- Data prep for Loss ---
            tokens = tokenizer.encode(item['text'], add_special_tokens=False)
            tokens.append(tokenizer.eos_token_id)
            original_len = len(tokens)
            if original_len < 2:
                continue

            padded_len = ((original_len + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE
            padded_len = min(padded_len, max_eval_len)

            final_tokens = tokens[:padded_len]
            padded_input = final_tokens + [pad_token_id] * (padded_len - len(final_tokens))
            input_seq = torch.tensor(padded_input, dtype=torch.long, device=device)

            # Targets are the inputs shifted left by one; positions past the
            # real sequence are filled with -100 so cross_entropy ignores them.
            target_seq_list = (tokens[1:] + [pad_token_id])[:padded_len]
            target_seq_list += [-100] * (padded_len - len(target_seq_list))
            target_seq = torch.tensor(target_seq_list, dtype=torch.long, device=device)

            window_blocks = torch.tensor(padded_len // BLOCK_SIZE, device=device, dtype=torch.int32)

            # --- Data prep for FTA ---
            match = qa_pattern.search(item['text'])
            if not match:
                continue
            prompt, answer = match.groups()
            prompt, answer = prompt.strip(), answer.strip()
            if not answer:
                continue

            try:
                expected_token = tokenizer.encode(' ' + answer, add_special_tokens=False)[0]
            except IndexError:
                continue

            # --- Model call (once only) ---
            logits = model(input_seq, target_seq=None, sliding_window_num_blocks=window_blocks)
            if isinstance(logits, tuple):
                logits = logits[0]

            # --- Compute Loss ---
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), ignore_index=-100)
            if not torch.isnan(loss):
                group_losses[group_id] += loss.item()
                group_loss_counts[group_id] += 1

            # --- Compute FTA ---
            prompt_tokens_len = len(tokenizer.encode(prompt, add_special_tokens=False))
            if prompt_tokens_len > 0 and prompt_tokens_len <= padded_len:
                # The logits at the last prompt position predict the first answer token.
                last_token_logits = logits.squeeze(0)[prompt_tokens_len - 1, :]
                predicted_token = torch.argmax(last_token_logits).item()

                if predicted_token == expected_token:
                    group_correct[group_id] += 1
                group_total_fta[group_id] += 1

    # 4. Aggregate results
    avg_group_loss = {str(g): group_losses[g] / group_loss_counts[g] for g in group_loss_counts if group_loss_counts[g] > 0}
    avg_group_acc = {str(g): group_correct[g] / group_total_fta[g] for g in group_total_fta if group_total_fta[g] > 0}

    loss_sample_count = sum(group_loss_counts.values())
    total_loss = sum(group_losses.values()) / loss_sample_count if loss_sample_count > 0 else 0

    # Two methods for calculating total accuracy
    fta_sample_count = sum(group_total_fta.values())
    total_acc_weighted = sum(group_correct.values()) / fta_sample_count if fta_sample_count > 0 else 0  # Original method: weighted by samples
    total_acc_unweighted = sum(avg_group_acc.values()) / len(avg_group_acc) if avg_group_acc else 0  # New method: simple average across groups

    print0("--- Detailed Evaluation Complete ---", console=True)
    return {
        'per_class_loss': avg_group_loss,
        'per_class_acc': avg_group_acc,
        'total_loss': total_loss,
        'total_acc_weighted': total_acc_weighted,      # Sample-weighted total accuracy
        'total_acc_unweighted': total_acc_unweighted,  # Simple average total accuracy across groups
        'total_acc': total_acc_unweighted              # Primarily use simple average method
    }

def plot_curves(history, output_path, title, y_label, y_lim=None):
    """Plot a metric history and save it as an image.

    Args:
        history: Either a flat mapping ``{epoch_str: value}`` or a per-group
            mapping ``{group_str: {epoch_str: value}}``. All keys are strings
            that parse as integers; which form is used is decided from the
            type of the first value.
        output_path: Where the figure is written (300 dpi).
        title: Figure title; also used in the log message.
        y_label: Label of the y axis.
        y_lim: Optional explicit y-axis limits. When omitted, limits are
            derived from the data with a 5% margin on each side.

    Nothing is plotted (and no figure is created) when ``history`` is empty.
    """
    if not history:
        print0(f"Warning: No history data for {y_label}, cannot plot.", console=True)
        return

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(8, 6))

    is_per_class = isinstance(next(iter(history.values())), dict)

    if is_per_class:
        group_ids = sorted(int(g) for g in history)
        cmap = plt.get_cmap("viridis")
        # Colour encodes the group id across the range of groups present.
        norm = Normalize(vmin=min(group_ids), vmax=max(group_ids))
        for group_id_int in group_ids:
            epoch_data = history[str(group_id_int)]
            epochs = sorted(int(e) for e in epoch_data)
            values = [epoch_data[str(e)] for e in epochs]
            ax.plot(epochs, values, color=cmap(norm(group_id_int)), linewidth=2.0, label=f'Group {group_id_int}')
        ax.legend(title="Class Group", bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        epochs = sorted(int(e) for e in history)
        values = [history[str(e)] for e in epochs]
        ax.plot(epochs, values, linewidth=2.5)

    ax.set_xlabel("Epoch", fontsize=14)
    ax.set_ylabel(y_label, fontsize=14)
    ax.set_title(title, fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=12)

    if y_lim:
        ax.set_ylim(y_lim)
    else:
        all_values = []
        if is_per_class:
            for group_data in history.values():
                all_values.extend(group_data.values())
        else:
            all_values = list(history.values())
        if all_values:
            min_val, max_val = min(all_values), max(all_values)
            # A 5% margin that always moves away from the data: scaling a
            # negative bound by 0.95 / 1.05 the "usual" way would cut into it.
            lower = min_val * 0.95 if min_val >= 0 else min_val * 1.05
            upper = max_val * 1.05 if max_val >= 0 else max_val * 0.95
            if lower == upper:
                # Degenerate range (e.g. every value is 0): open a small window.
                lower, upper = lower - 0.05, upper + 0.05
            ax.set_ylim(lower, upper)

    ax.grid(True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    print0(f"{title} curve updated and saved to: {output_path}", console=True)
    plt.close(fig)



def evaluate_per_class_loss(model, tokenizer, qa_data_path, device, m_val, num_samples=None):
    """Compute the average loss per class group over a QA jsonl file.

    Args:
        model: Called as ``model(input_seq, target_seq, window_blocks)``; the
            return value is treated as a scalar loss.
        tokenizer: Provides ``encode(text, add_special_tokens=False)``,
            ``eos_token_id`` and ``pad_token_id``.
        qa_data_path: JSONL file; each record needs ``'text'`` and ``'class_id'``.
        device: Device on which the input tensors are created.
        m_val: Exponent passed to ``generate_powerlaw_selection_counts`` to
            derive the class -> group mapping.
        num_samples: When positive and smaller than the dataset size, a
            per-class stratified subsample of roughly this many records is
            evaluated (at least one record per class). Sampling uses the
            global ``random`` module.

    Returns:
        dict mapping ``str(group_id)`` to the mean loss of that group.

    The model is put in eval mode and left there.
    """
    model.eval()

    with open(qa_data_path, 'r', encoding='utf-8') as f:
        qa_data = [json.loads(line) for line in f]

    wants_subsample = num_samples is not None and num_samples > 0 and len(qa_data) > num_samples
    if wants_subsample:
        print0(f"Using stratified sampling to extract ~{num_samples} samples for evaluation...", console=True)
        records_by_class = defaultdict(list)
        for record in qa_data:
            records_by_class[record['class_id']].append(record)
        keep_fraction = num_samples / len(qa_data)
        subsample = []
        for members in records_by_class.values():
            quota = max(1, int(len(members) * keep_fraction))
            subsample.extend(random.sample(members, min(len(members), quota)))
        qa_data = subsample
        print0(f"Evaluation set size after sampling: {len(qa_data)}", console=True)

    # Class id -> group id, derived from the power-law layout.
    selection_counts, class_groups = generate_powerlaw_selection_counts(m_val)
    class_to_group_map = dict(zip(selection_counts.keys(), class_groups))

    loss_sum_by_group = defaultdict(float)
    samples_by_group = defaultdict(int)

    with torch.no_grad():
        for record in tqdm(qa_data, desc="Detailed Evaluation", disable=not master_process):
            if not record or 'text' not in record or not record['text']:
                continue
            group_id = class_to_group_map.get(record['class_id'])
            if group_id is None:
                continue

            tokens = tokenizer.encode(record['text'], add_special_tokens=False)
            tokens.append(tokenizer.eos_token_id)

            token_count = len(tokens)
            if token_count < 2:
                continue

            # Round the length up to a whole number of 128-token blocks, capped at 4096.
            BLOCK_SIZE = 128
            max_eval_len = 4096
            whole_blocks = (token_count + BLOCK_SIZE - 1) // BLOCK_SIZE
            padded_len = min(whole_blocks * BLOCK_SIZE, max_eval_len)

            if tokenizer.pad_token_id is not None:
                pad_token_id = tokenizer.pad_token_id
            else:
                pad_token_id = tokenizer.eos_token_id

            model_inputs = tokens[:padded_len]
            model_inputs = model_inputs + [pad_token_id] * (padded_len - len(model_inputs))
            input_seq = torch.tensor(model_inputs, dtype=torch.long, device=device)

            # Shifted-by-one targets; the tail is filled with -100.
            shifted = (tokens[1:] + [pad_token_id])[:padded_len]
            shifted += [-100] * (padded_len - len(shifted))
            target_seq = torch.tensor(shifted, dtype=torch.long, device=device)

            window_blocks = torch.tensor(padded_len // BLOCK_SIZE, device=device, dtype=torch.int32)

            loss = model(input_seq, target_seq, window_blocks)

            if loss is None or torch.isnan(loss):
                continue
            loss_sum_by_group[group_id] += loss.item()
            samples_by_group[group_id] += 1

    avg_group_losses = {
        str(group): total / samples_by_group[group]
        for group, total in loss_sum_by_group.items()
        if samples_by_group[group] > 0
    }

    print0("--- Per-Class Loss Evaluation Complete ---", console=True)
    return avg_group_losses

def plot_loss_curves(loss_history, output_path, plot_title="Per-Class Loss"):
    """Draw one loss-vs-step line per class group and save the figure.

    Args:
        loss_history: ``{group_str: {step_str: loss}}``; all keys are strings
            that parse as integers.
        output_path: Destination of the saved image (300 dpi).
        plot_title: Title placed above the axes.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(8, 6))
    if not loss_history:
        print0("Warning: Loss history is empty. Cannot plot.", console=True)
        plt.close()
        return

    ordered_groups = sorted(int(key) for key in loss_history.keys())
    palette = plt.get_cmap("viridis")
    # Non-empty here, so the first and last entries are the min and max group id.
    colour_scale = Normalize(vmin=ordered_groups[0], vmax=ordered_groups[-1])

    for gid in ordered_groups:
        series = loss_history[str(gid)]
        steps = sorted(int(step) for step in series.keys())
        ax.plot(
            steps,
            [series[str(step)] for step in steps],
            color=palette(colour_scale(gid)),
            linewidth=2.0,
            label=f'Group {gid}',
        )

    ax.set_xlabel("Step", fontsize=14)
    ax.set_ylabel("Per-Class Loss", fontsize=14)
    ax.set_title(plot_title, fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=12)

    # Scale the y axis to the data with a 5% multiplicative margin.
    every_loss = []
    for series in loss_history.values():
        every_loss.extend(series.values())
    if every_loss:
        ax.set_ylim(min(every_loss) * 0.95, max(every_loss) * 1.05)

    ax.legend(title="Class Group")
    ax.grid(True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print0(f"Per-Class Loss curve updated and saved to: {output_path}", console=True)
    plt.close()



########################################
#    Construct model and optimizer     #
########################################

print0("PRINT: Constructing model...", console=True)
# Build the network on the current CUDA device. The maximum sequence length is
# sized to cover both the training and the validation sequence length.
model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=12, num_heads=6, model_dim=768,
                       max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
# Cast every embedding module (and only those) to bfloat16.
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
print0("PRINT: Broadcasting model parameters...", console=True)
# Overwrite each parameter on every rank with rank 0's values so that all
# processes start from identical weights.
for param in model.parameters():
    dist.broadcast(param.detach(), 0)
print0("PRINT: Model constructed and broadcasted.", console=True)


# Best-effort forward-pass smoke test, run on the master process only. Failures
# are logged, not raised. The model is always put back into training mode after
# the forward call, even when that call raises, so a failed smoke test cannot
# leave the model in eval mode for the training that follows.
if master_process:
    print0("PRINT: Testing model forward function:", console=True)
    try:
        test_input = torch.randint(0, 1000, (128,), device=device, dtype=torch.int32)
        test_blocks = torch.tensor(1, device=device)
        model.eval()
        try:
            with torch.no_grad():
                result = model(test_input, None, test_blocks)
        finally:
            model.train()

        print0(f"PRINT: Model test - Result type: {type(result)}", console=True)
        if isinstance(result, tuple):
            print0(f"PRINT: Model test - Tuple length: {len(result)}", console=True)
            if len(result) >= 2:
                print0(f"PRINT: Model test - First element (loss): {result[0]}", console=True)
                print0(f"PRINT: Model test - Second element shape (logits): {result[1].shape if hasattr(result[1], 'shape') else 'No shape'}", console=True)
        else:
            print0(f"PRINT: Model test - Single result shape: {result.shape if hasattr(result, 'shape') else 'No shape'}", console=True)
    except Exception as e:
        print0(f"PRINT: Model test failed: {e}", console=True)


# Alias of the same model object (not a copy), kept for inference use.
model_for_inference = model
print0("PRINT: Saved original model reference for inference.", console=True)


# Second smoke test: same call with target_seq=None, this time reporting
# whether the result is a (loss, logits) pair.
if master_process:
    print0("PRINT: Testing model with target_seq=None...", console=True)
    try:
        test_input = torch.randint(0, 1000, (128,), device=device, dtype=torch.int32)
        test_blocks = torch.tensor(1, device=device)
        model.eval()
        try:
            with torch.no_grad():
                result = model(test_input, None, test_blocks)  # target_seq=None
        finally:
            model.train()

        if isinstance(result, tuple) and len(result) == 2:
            loss, logits = result
            print0(f"PRINT: SUCCESS! Model returns (loss={loss}, logits.shape={logits.shape})", console=True)
        else:
            print0(f"PRINT: Model returns: {type(result)}", console=True)
    except Exception as e:
        print0(f"PRINT: Model test still fails: {e}", console=True)



# --- START MODIFIED PARAMETER COLLECTION AND OPTIMIZER SETUP ---
if exp_args.model_parameterization == "qkvo":
    print0("PRINT: Collecting parameters for optimizers...", console=True)
    head_params = [model.lm_head.weight]
    embed_params = [model.embed.weight]

    # Granular collection for attention and MLP parts
    attn_q_params = []
    attn_k_params = []
    attn_v_params = []
    attn_o_params = [] # W_O from c_proj
    mlp_fc_params = []
    mlp_proj_params = []

    for block_module in model.blocks:
        if block_module.attn is not None:
            # These attributes (q_w, k_w, v_w) MUST exist in your CausalSelfAttention class
            if hasattr(block_module.attn, 'q_w'): attn_q_params.append(block_module.attn.q_w)
            else: print0(f"PRINT: Warning: q_w not found in attn module of a block.", console=True)
            if hasattr(block_module.attn, 'k_w'): attn_k_params.append(block_module.attn.k_w)
            else: print0(f"PRINT: Warning: k_w not found in attn module of a block.", console=True)
            if hasattr(block_module.attn, 'v_w'): attn_v_params.append(block_module.attn.v_w)
            else: print0(f"PRINT: Warning: v_w not found in attn module of a block.", console=True)
            attn_o_params.append(block_module.attn.c_proj.weight)
        if block_module.mlp is not None:
            mlp_fc_params.append(block_module.mlp.c_fc.weight)
            mlp_proj_params.append(block_module.mlp.c_proj.weight)

    # Combine into logical groups for experiments
    attn_qk_group = attn_q_params + attn_k_params
    attn_vo_group = attn_v_params + attn_o_params
    all_attn_matrices = attn_qk_group + attn_vo_group
    mlp_w1_group = mlp_fc_params
    mlp_w2_group = mlp_proj_params
    all_mlp_matrices = mlp_fc_params + mlp_proj_params

    # Scalar parameters (all others not explicitly grouped as matrices)
    matrix_params_for_scalar_check = set(head_params + embed_params + all_attn_matrices + all_mlp_matrices)
    scalar_params = [p for n, p in model.named_parameters() if p not in matrix_params_for_scalar_check]
    for p_scalar in scalar_params: # Sanity check
        if p_scalar.ndim >=2:
            print0(f"PRINT: Warning - Parameter {p_scalar.shape} ended up in scalar_params but has ndim >= 2. Check grouping.", console=True)


    # Determine parameter distribution based on optimizer_mode
    muon_params_target_list = []
    adam_matrix_target_list = [] # Matrices that Adam will handle specifically
    adam_matrix_lr = exp_args.adam_lr  # LR for matrices if Adam handles them (can be tuned)
    muon_lr = exp_args.muon_lr

    current_optimizer_mode = exp_args.optimizer_mode
    print0(f"PRINT: Configuring optimizers for EXPERIMENT_MODE = {current_optimizer_mode}", console=True)

    if current_optimizer_mode == 0: # Original behavior: Muon on all "hidden_matrix_params"
        print0(f"PRINT: Mode 0: Muon on ALL Attention (QKVO) and ALL MLP matrices.", console=True)
        muon_params_target_list = all_attn_matrices + all_mlp_matrices
        # Adam handles embeds, head, scalars by default. No extra matrices for Adam here.
    elif current_optimizer_mode == 1: # Muon on QK, Adam on VO and MLP
        print0(f"PRINT: Mode 1: Muon on QK Attn. Adam on VO Attn, MLP (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = attn_qk_group
        adam_matrix_target_list = attn_vo_group + all_mlp_matrices
    elif current_optimizer_mode == 2: # Muon on VO, Adam on QK and MLP
        print0(f"PRINT: Mode 2: Muon on VO Attn. Adam on QK Attn, MLP (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = attn_vo_group
        adam_matrix_target_list = attn_qk_group + all_mlp_matrices
    elif current_optimizer_mode == 3: # Muon on All Attn (QKVO), Adam on MLP
        print0(f"PRINT: Mode 3: Muon on ALL Attn (QKVO). Adam on MLP (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = all_attn_matrices
        adam_matrix_target_list = all_mlp_matrices
    elif current_optimizer_mode == 4: # Muon on MLP, Adam on All Attn (QKVO)
        print0(f"PRINT: Mode 4: Muon on MLP. Adam on ALL Attn (QKVO) (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = all_mlp_matrices
        adam_matrix_target_list = all_attn_matrices
    elif current_optimizer_mode == 5: # NEW MODE 5 - All Adam
        print0(f"PRINT: Mode 5: All Adam. All Attn and MLP matrices to Adam (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = [] 
        adam_matrix_target_list = all_attn_matrices + all_mlp_matrices # All matrices to Adam
    elif current_optimizer_mode == 6: # Muon on W_2 MLP, Adam on attn, W_1 MLP
        print0(f"PRINT: Mode 6: Muon on W_2 MLP. Adam on attn, W_1 MLP (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = mlp_w2_group
        adam_matrix_target_list = all_attn_matrices + mlp_w1_group
    elif current_optimizer_mode == 7: # Muon on VO Attn, MLP, Adam on QK Attn
        print0(f"PRINT: Mode 7: Muon on VO Attn, MLP. Adam on QK Attn (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = attn_vo_group + all_mlp_matrices
        adam_matrix_target_list = attn_qk_group
    elif current_optimizer_mode == 8: # Muon on VO Attn, W_2 MLP, Adam on QK Attn, W_1 MLP
        print0(f"PRINT: Mode 8: Muon on VO Attn, W_2 MLP. Adam on QK Attn, W_1 MLP (Adam LR: {adam_matrix_lr}).", console=True)
        muon_params_target_list = attn_vo_group + mlp_w2_group
        adam_matrix_target_list = attn_qk_group + mlp_w1_group
    elif current_optimizer_mode == 9: # sgd + momentum
        # This mode uses SGD with momentum for all parameters, no Muon or Adam
        print0(f"PRINT: Mode 9: Using pure SGD+Momentum (lr={exp_args.sgd_lr}).", console=True)
        all_params = list(model.parameters())
        sgd_lr = exp_args.sgd_lr  # Use learning rate from command line argument
        optimizer1 = torch.optim.SGD(all_params, lr=sgd_lr, momentum=0.9, weight_decay=1e-4)
        optimizer2 = None
        optimizers = [optimizer1]
        
        print0(f"PRINT: SGD optimizer configured with lr={sgd_lr}, momentum=0.9, weight_decay=1e-4", console=True)
    else:
        raise ValueError(f"Unsupported EXPERIMENT_MODE: {current_optimizer_mode}")

    # Mode 9 built its single SGD optimizer in the branch above; every other
    # mode gets Adam (always) and Muon (only when it was given matrices).
    if current_optimizer_mode != 9:
        # Head, embedding and scalar parameters always go to Adam at the shared LR.
        adam_param_groups_config = [
            {"params": group_params, "lr": exp_args.adam_lr}
            for group_params in (head_params, embed_params, scalar_params)
        ]

        # Matrices routed to Adam by this experiment mode form their own group
        # with a separate LR. Entries may be Parameters or lists of Parameters.
        if adam_matrix_target_list:
            flat_adam_matrices = []
            for entry in adam_matrix_target_list:
                members = entry if isinstance(entry, list) else [entry]
                flat_adam_matrices.extend(m for m in members if m is not None)
            if flat_adam_matrices:
                adam_param_groups_config.append({"params": flat_adam_matrices, "lr": adam_matrix_lr})

        # Drop groups with no parameters before handing the config to Adam.
        adam_param_groups_config = [cfg for cfg in adam_param_groups_config if cfg['params']]
        # NOTE: no weight decay is configured for Adam here.
        optimizer1 = torch.optim.Adam(adam_param_groups_config, betas=(0.8, 0.95), eps=1e-10, fused=True)
        optimizers = [optimizer1]

        # Muon stays None unless it actually receives parameters.
        optimizer2 = None
        if muon_params_target_list:
            # Flatten the targets and de-duplicate by object identity, keeping first-seen order.
            flat_unique_muon_params = []
            seen_muon_ids = set()
            for entry in muon_params_target_list:
                for candidate in (entry if isinstance(entry, list) else [entry]):
                    if candidate is None or id(candidate) in seen_muon_ids:
                        continue
                    seen_muon_ids.add(id(candidate))
                    flat_unique_muon_params.append(candidate)

            if flat_unique_muon_params:
                optimizer2 = Muon(flat_unique_muon_params, lr=muon_lr, momentum=0.95, nesterov=False, ns_steps=5, rank=rank, world_size=world_size)
                optimizers.append(optimizer2)
            else:
                print0("PRINT: Muon optimizer not created as its target parameter list was empty.", console=True)
        else:
            print0("PRINT: Muon optimizer not created as muon_params_target_list was empty (e.g. mode where Adam handles all matrices).", console=True)

    print0(f"PRINT: Optimizers configured. Total optimizers: {len(optimizers)}", console=True)
    if optimizer2:
        print0(f"PRINT: Muon optimizer is active with {len(flat_unique_muon_params)} parameters.", console=True)
    # --- END MODIFIED PARAMETER COLLECTION AND OPTIMIZER SETUP ---
elif exp_args.model_parameterization == "whole":
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    # init the optimizer(s)
    adam_params = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    # small adam epsilon for fixing the world_size dependence
    optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True)
    optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
    optimizers = [optimizer1, optimizer2]

# Remember each param group's configured LR under "initial_lr"; the training
# loop rescales "lr" from this base value on every step.
for opt in optimizers:
    for group in opt.param_groups:
        group.update(initial_lr=group["lr"])

# learning rate schedule: stable then decay (KEEP AS IS, but check assert)
def get_lr(step: int):
    """Return the learning-rate multiplier for training step ``step``.

    The multiplier is 1.0 until the final ``args.cooldown_frac`` fraction of
    ``args.num_iterations``, then falls linearly to 0.1 at the last step.
    Steps outside ``[0, args.num_iterations]`` are clamped to that range.
    """
    # Training progress clamped into [0, 1]; in-range values pass through unchanged.
    progress = min(max(step / args.num_iterations, 0.0), 1.0)

    if progress < 1 - args.cooldown_frac:
        return 1.0

    # Fraction of the cooldown still remaining (1 at its start, 0 at the end).
    # The divisor is floored so cooldown_frac == 0 cannot divide by zero.
    remaining = (1 - progress) / max(args.cooldown_frac, 1e-9)
    return remaining * 1.0 + (1 - remaining) * 0.1


# attention window size schedule (KEEP AS IS)
def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    """Return ``window_size // 128`` as an int32 CUDA tensor.

    The value is created in pinned host memory and copied to the GPU with a
    non-blocking transfer. Only the most recent ``window_size`` is cached.
    """
    num_blocks = window_size // 128
    pinned = torch.tensor(num_blocks, dtype=torch.int32, pin_memory=True)
    return pinned.cuda(non_blocking=True)
def get_window_size_blocks(step: int):
    """Return the attention window for ``step`` as a count of 128-token blocks.

    The window is ``1728 * progress`` tokens rounded up to a multiple of 128
    and never below 128, where ``progress`` is ``step / args.num_iterations``
    clamped to ``[0, 1]``. The result comes from
    ``get_window_size_blocks_helper``.
    """
    # In-range progress values pass through the clamp unchanged.
    progress = min(max(step / args.num_iterations, 0.0), 1.0)

    rounded_window = next_multiple_of_n(1728 * progress, n=128)
    window_size = max(128, rounded_window)
    return get_window_size_blocks_helper(window_size)

print0("PRINT: Compiling model with TorchInductor...", console=True)

# Compile the eager `model`; warmup, training and validation below all call
# the compiled wrapper `model_compiled`.
model_compiled: nn.Module = torch.compile(
    model,
    mode="max-autotune",
    dynamic=False,
)
print0("PRINT: Model compilation complete.", console=True)

########################################
# Warmup kernels
########################################
print0("PRINT: Starting warmup...", console=True)
warmup_steps = 10
# Snapshot weights and optimizer state so the throwaway warmup updates can be undone.
initial_state = {
    "model": copy.deepcopy(model_compiled.state_dict()),
    "optimizers": [copy.deepcopy(opt.state_dict()) for opt in optimizers],
}

for i in range(warmup_steps):
    # The same random token ids are used as both inputs and targets.
    inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
    loss = model_compiled(inputs.to(torch.int32), targets, get_window_size_blocks(0))
    loss.backward()
    # Average gradients across ranks, skipping parameters that received none.
    for param in model_compiled.parameters():
        if param.grad is None:
            continue
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    # SGD mode (9) clips gradients in the real loop, so do the same here.
    if exp_args.optimizer_mode == 9:
        torch.nn.utils.clip_grad_norm_(model_compiled.parameters(), max_norm=1.0)
    for opt in optimizers:
        opt.step()
    model_compiled.zero_grad(set_to_none=True)
    # Roll model and optimizers back to the snapshot after every warmup step.
    model_compiled.load_state_dict(initial_state["model"])
    for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
        opt.load_state_dict(opt_state)

del initial_state
print0("PRINT: Warmup complete.", console=True)
torch.cuda.synchronize()

########################################
# Training and validation
########################################
print0("PRINT: Starting training...", console=True)

# Each loader batch spans world_size * train_seq_len tokens.
train_loader = distributed_data_generator(
    args.train_files,
    world_size * args.train_seq_len,
    rank,
    world_size,
)

# Running train-loss accumulators; the loop resets them after every validation pass.
train_loss_sum = torch.zeros(1, device=device)
train_step_count = torch.zeros(1, device=device)

# Wall-clock bookkeeping: drain pending GPU work before starting the timer.
training_time_ms = 0
torch.cuda.synchronize()
t0 = time.perf_counter()
train_steps = args.num_iterations



if master_process:
    # Master-rank-only state for the detailed QA evaluation run during validation.
    tokenizer_for_eval = GPT2Tokenizer.from_pretrained('gpt2')

    # Metric history. Totals are keyed by step (as a string); per-class
    # entries are keyed by group id first, then by step.
    history = {
        'per_class_loss': defaultdict(dict),
        'per_class_acc': defaultdict(dict),
        'total_loss': {},
        'total_acc': {}
    }

    # The fixed evaluation subset is cached here so reruns in the same run
    # directory evaluate on the same samples.
    FIXED_VAL_INDEX_PATH = run_dir_path / "fixed_eval_indices.json"

    def _is_valid_qa_text_for_fta(text: str) -> bool:
        """Return True if ``text`` looks like ``"<question>? Answer: <answer>"``.

        Used as a quick filter when building the fixed eval set. Non-string
        input is rejected; matching is case-insensitive.
        """
        if not isinstance(text, str):
            return False
        return re.search(r'^(.*?\?)\s*Answer\s*:\s*(.+)$', text, re.IGNORECASE) is not None

    def build_fixed_eval_indices(jsonl_path, class_to_group_map, per_group_k, seed=2025):
        """Pick a reproducible per-group sample of line indices from a JSONL file.

        Args:
            jsonl_path: Path to a JSONL file with one record per line; each
                record is read for its ``"class_id"`` and ``"text"`` fields.
            class_to_group_map: Mapping from class id to group id. Records
                whose class id is not in the map are ignored.
            per_group_k: Maximum number of indices kept per group.
            seed: Seed for the private RNG used for sampling.

        Returns:
            Dict mapping ``str(group_id)`` to a list of 0-based line indices.
            Groups with at most ``per_group_k`` valid records keep all of them.
        """
        rng = random.Random(seed)

        buckets = defaultdict(list)
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # Blank or malformed line: skip it, but let unrelated errors surface.
                    continue
                if not isinstance(item, dict):
                    # Valid JSON that is not an object has no fields to read.
                    continue
                gid = class_to_group_map.get(item.get("class_id"))
                if gid is None:
                    continue
                if not _is_valid_qa_text_for_fta(item.get("text", "")):
                    continue
                buckets[gid].append(i)

        fixed = {}
        for gid, arr in buckets.items():
            if len(arr) <= per_group_k:
                fixed[str(gid)] = arr[:]  # Take all if fewer than K samples
            else:
                fixed[str(gid)] = rng.sample(arr, per_group_k)
        return fixed

    selection_counts, class_groups_list = generate_powerlaw_selection_counts(M_FOR_POWERLAW)
    class_to_group_map_global = {cid: gid for cid, gid in zip(selection_counts.keys(), class_groups_list)}

    # Build the fixed index set once; otherwise reuse the cached file.
    if not FIXED_VAL_INDEX_PATH.exists():
        fixed_idx = build_fixed_eval_indices(QA_JSONL_PATH, class_to_group_map_global, PER_GROUP_K)
        with open(FIXED_VAL_INDEX_PATH, "w", encoding="utf-8") as f:
            json.dump(fixed_idx, f)
        print0(f"PRINT: Built fixed eval set. Saved to {FIXED_VAL_INDEX_PATH}", console=True)
    else:
        print0(f"PRINT: Using existing fixed eval set: {FIXED_VAL_INDEX_PATH}", console=True)
        with open(FIXED_VAL_INDEX_PATH, "r", encoding="utf-8") as f:
            fixed_idx = json.load(f)

    


# Main loop. Runs train_steps + 1 iterations so the final iteration still
# triggers validation. Each iteration optionally validates (with detailed QA
# evaluation and checkpointing on the master rank), then performs one
# optimizer step on the next training batch.
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # --------- VALIDATION SECTION ---------
    if step == 0 or last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        # Stop the training clock: fold the elapsed segment into the total.
        torch.cuda.synchronize()
        if step > 0:
            current_run_time = 1000 * (time.perf_counter() - t0)
            training_time_ms += current_run_time

        model_compiled.eval()
        val_batch_size = world_size * args.val_seq_len
        if args.val_tokens % val_batch_size != 0:
            print0(f"PRINT: Warning: val_tokens ({args.val_tokens}) not perfectly divisible by val_batch_size ({val_batch_size}). Some tokens might be missed.", console=True)

        val_num_steps = args.val_tokens // val_batch_size
        val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
        val_loss_sum = torch.zeros(1, device=device)
        actual_val_steps = 0

        with torch.no_grad():
            for val_i in range(val_num_steps):
                # Only the loader fetch is guarded, so a StopIteration raised
                # anywhere else is not mistaken for loader exhaustion.
                try:
                    inputs, targets = next(val_loader)
                except StopIteration:
                    print0(f"PRINT: Validation data loader for '{args.val_files}' exhausted early at val_step {val_i+1}/{val_num_steps}.", console=True)
                    break
                loss_val = model_compiled(inputs, targets, get_window_size_blocks(step))
                val_loss_sum += loss_val
                actual_val_steps += 1

        if actual_val_steps > 0:
            val_loss_avg = val_loss_sum / actual_val_steps
        else:
            val_loss_avg = torch.tensor(float('nan'), device=device)
            print0("PRINT: Warning: No validation steps were completed. val_loss is NaN.", console=True)

        del val_loader
        dist.all_reduce(val_loss_avg, op=dist.ReduceOp.AVG)

        # Mean train loss since the previous validation pass, averaged over ranks.
        if train_step_count > 0:
            avg_train_loss = train_loss_sum / train_step_count
            dist.all_reduce(avg_train_loss, op=dist.ReduceOp.AVG)
            avg_train_loss = avg_train_loss.item()
        else:
            avg_train_loss = float('nan')

        avg_step_time = training_time_ms / step if step > 0 else 0

        if step == 0:
            print0(f"PRINT: step:{step}/{train_steps} val_loss:{val_loss_avg.item():.4f}  train_time:{training_time_ms:.0f}ms", console=True)
        else:
            print0(f"PRINT: step:{step}/{train_steps} train_loss:{avg_train_loss:.4f} val_loss:{val_loss_avg.item():.4f} train_time:{training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms", console=True)

            # Detailed per-group QA evaluation runs on the master rank only
            # (step > 0 is already guaranteed in this branch).
            if master_process:
                selection_counts, class_groups_list = generate_powerlaw_selection_counts(M_FOR_POWERLAW)
                class_to_group_map = {cid: gid for cid, gid in zip(selection_counts.keys(), class_groups_list)}

                # Copy the current weights into the separate inference model.
                model_for_inference.load_state_dict(model.state_dict())

                eval_results = run_detailed_evaluation(
                    model=model_for_inference,
                    tokenizer=tokenizer_for_eval,
                    qa_data_path=QA_JSONL_PATH,
                    device=device,
                    m_val=M_FOR_POWERLAW,
                    class_to_group_map=class_to_group_map,
                    fixed_indices=fixed_idx
                )

                print0("--- Detailed Evaluation Results (This Step) ---", console=True)
                print0(f"  Total Loss: {eval_results['total_loss']:.4f}", console=True)
                print0(f"  Total FTA (Unweighted): {eval_results['total_acc_unweighted']:.4f}", console=True)
                print0(f"  Total FTA (Weighted):   {eval_results['total_acc_weighted']:.4f}", console=True)
                for group_id, loss in sorted(eval_results['per_class_loss'].items(), key=lambda item: int(item[0])):
                    print0(f"  Group {group_id} Loss: {loss:.4f}", console=True)
                for group_id, acc in sorted(eval_results['per_class_acc'].items(), key=lambda item: int(item[0])):
                    print0(f"  Group {group_id} FTA: {acc:.4f}", console=True)

                # Record this step's metrics and redraw all curves.
                current_step_str = str(step)
                history['total_loss'][current_step_str] = eval_results['total_loss']
                history['total_acc'][current_step_str] = eval_results['total_acc_unweighted']  # Use simple average method
                for group_id, loss in eval_results['per_class_loss'].items():
                    history['per_class_loss'][group_id][current_step_str] = loss
                for group_id, acc in eval_results['per_class_acc'].items():
                    history['per_class_acc'][group_id][current_step_str] = acc

                plot_curves(history['per_class_loss'], run_dir_path / "per_class_loss_curves.png", "Per-Class Loss", "Loss")
                plot_curves(history['per_class_acc'], run_dir_path / "per_class_acc_curves.png", "Per-Class FTA", "Accuracy", y_lim=[0, 1])
                plot_curves(history['total_loss'], run_dir_path / "total_loss_curve.png", "Total Detailed Loss", "Loss")
                plot_curves(history['total_acc'], run_dir_path / "total_acc_curve.png", "Total Detailed FTA", "Accuracy", y_lim=[0, 1])

            # Other ranks wait here while the master rank evaluates.
            if world_size > 1:
                dist.barrier()

        if master_process and args.save_checkpoint and step > 0:
            if run_dir_path_str:
                checkpoint_parent_dir = Path(run_dir_path_str) / "checkpoints"
                checkpoint_parent_dir.mkdir(parents=True, exist_ok=True)

                checkpoint_path = checkpoint_parent_dir / f"ckpt_epoch_{step}.pt"

                log_checkpoint = dict(
                    step=step,
                    code_sha256=code_sha256,
                    model=model_compiled.state_dict(),
                    optimizers=[opt.state_dict() for opt in optimizers]
                )

                torch.save(log_checkpoint, str(checkpoint_path))
                print0(f"PRINT: Saved checkpoint to {checkpoint_path}", console=True)
            else:
                print0("PRINT: Warning - run_dir_path_str not set, cannot save checkpoint.", console=True)

        # Reset the train-loss accumulators and restart the training clock.
        train_loss_sum = torch.zeros(1, device=device)
        train_step_count = torch.zeros(1, device=device)
        model_compiled.train()
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    # --------- TRAINING SECTION ---------
    try:
        inputs, targets = next(train_loader)
    except StopIteration:
        print0(f"PRINT: Training data loader for '{args.train_files}' exhausted. Ending training early at step {step}.", console=True)
        break

    loss_train = model_compiled(inputs, targets, get_window_size_blocks(step))
    loss_train.backward()
    train_loss_sum += loss_train.detach()/ args.train_seq_len
    train_step_count += 1

    # Average gradients across ranks.
    for param in model_compiled.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

    # Add gradient clipping for SGD mode to prevent gradient explosion
    if exp_args.optimizer_mode == 9:
        torch.nn.utils.clip_grad_norm_(model_compiled.parameters(), max_norm=1.0)

    # Apply the LR schedule relative to each group's recorded base LR.
    current_lr_val = get_lr(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * current_lr_val

    # Ramp optimizer2's momentum linearly from 0.85 to 0.95 over the first 300 steps.
    if optimizer2 is not None:
        for group in optimizer2.param_groups:
            frac = min(step / 300, 1)
            group["momentum"] = (1 - frac) * 0.85 + frac * 0.95

    for opt in optimizers:
        opt.step()

    model_compiled.zero_grad(set_to_none=True)

    # Periodic timing log. It deliberately avoids loss_train.item(), which
    # would force a blocking host sync for a value that is never printed.
    if step > 0 and (step % 20 == 0 or step == train_steps - 1):
        current_segment_time_ms = 1000 * (time.perf_counter() - t0)
        approx_total_training_time_ms = training_time_ms + current_segment_time_ms
        print0(f"step:{step+1}/{train_steps} train_time:{approx_total_training_time_ms:.0f}ms step_avg:{approx_total_training_time_ms/max(1, step + 1):.2f}ms", console=True)

print0(f"PRINT: --- Training Finished: {time.ctime()} ---", console=True)

# Peak CUDA memory statistics, converted from bytes to whole MiB.
peak_allocated_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
peak_reserved_mib = torch.cuda.max_memory_reserved() // 1024 // 1024
print0(f"PRINT: Peak memory allocated: {peak_allocated_mib} MiB "
       f"reserved: {peak_reserved_mib} MiB", console=True)

# Tear down the process group only if one was initialized.
if dist.is_initialized():
    dist.destroy_process_group()
