#!/usr/bin/env python3
"""
Unified Training Script for Transformer-Graph Models

This script provides a unified interface for training any model type
(RoBERTa, Looped Transformer, Disentangled Transformer) on various graph datasets.

Usage:
    python train.py --model_type roberta --dataset two_chains --num_nodes 32
    python train.py --model_type looped_transformer --dataset erdos_renyi --num_nodes 16
    python train.py --model_type disentangled_transformer --dataset erdos_renyi two_chains
"""

import argparse
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
import math
import os
import json
import wandb
import re
import glob
from wandb import Settings
from tqdm import tqdm
import random
import sys
from datetime import datetime

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models import create_model
from data import create_dataset, create_mixed_dataset
from utils import (
    FocalLoss,
    compute_path_length_accuracy,
    aggregate_path_length_accuracies,
    create_path_length_wandb_log,
    get_dataset_params,
    visualize_dataset_samples,
    evaluate_model_with_perm_metrics,
)

# Set environment variables for wandb
os.environ["WANDB_INSECURE"] = "true"
import certifi

os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()


def find_latest_checkpoint_dir(base_ckpt_path: str, model_type_folder: str) -> str:
    """
    Find the most recent timestamped checkpoint directory for a run.

    Looks inside ``checkpoints/<model_type_folder>`` (relative to the current
    working directory) for sub-directories named
    ``<base_ckpt_path>_YYYYMMDD_HHMMSS`` and returns the one with the latest
    timestamp.

    Args:
        base_ckpt_path: The base checkpoint directory name, without date suffix.
        model_type_folder: The model type folder (e.g., "roberta_models").

    Returns:
        Path to the latest dated checkpoint directory, or
        ``checkpoints/<model_type_folder>/<base_ckpt_path>`` when the model type
        folder is missing or contains no matching dated directory.
    """
    checkpoints_base = os.path.join("checkpoints", model_type_folder)
    fallback = os.path.join(checkpoints_base, base_ckpt_path)

    # isdir (not exists) so a stray file at this path does not make listdir fail.
    if not os.path.isdir(checkpoints_base):
        return fallback

    prefix = f"{base_ckpt_path}_"
    candidates = []
    for item in os.listdir(checkpoints_base):
        if not item.startswith(prefix):
            continue
        suffix = item[len(prefix) :]
        item_path = os.path.join(checkpoints_base, item)
        # Require an exact YYYYMMDD_HHMMSS suffix so names like "run_old" or
        # "run_2_20240101_000000" (a different base name) are ignored.
        if re.fullmatch(r"\d{8}_\d{6}", suffix) and os.path.isdir(item_path):
            candidates.append((suffix, item_path))

    if not candidates:
        # No dated directories found, return original path
        return fallback

    # Fixed-width timestamps sort lexicographically in chronological order.
    _, latest_dir = max(candidates)
    print(f"Found latest checkpoint directory: {latest_dir}")
    return latest_dir


def should_save_checkpoint(global_step, save_every):
    """
    Determine if we should save a checkpoint at the given global step.

    Saves at powers of 10 (10^k) below save_every, then at every multiple of
    save_every.

    Examples with save_every=1000:
    - Steps 1, 10, 100: Save (power-of-10 phase)
    - Step 500: Don't save (not power of 10)
    - Steps 1000, 2000, 3000: Save (regular interval phase)

    Args:
        global_step: Current training step.
        save_every: Regular save interval. ``None`` or a non-positive value
            disables step checkpoints entirely.

    Returns:
        bool: True if we should save a checkpoint.
    """
    # A non-positive interval previously raised ZeroDivisionError; treat it as "off".
    if save_every is None or save_every <= 0:
        return False

    # Regular interval phase.
    if global_step >= save_every:
        return global_step % save_every == 0

    # Power-of-10 phase. Use an exact integer test: a float log10 check with a
    # tolerance wrongly accepts near-misses such as 10**11 + 1.
    if global_step < 1:
        return False
    remaining = global_step
    while remaining % 10 == 0:
        remaining //= 10
    return remaining == 1


def _save_step_checkpoint(model, saved_models_dir, global_step, epoch, save_every):
    """Write a step checkpoint of ``model.model`` to ``saved_models_dir`` and print its path."""
    checkpoint_path = os.path.join(saved_models_dir, f"model_step_{global_step}.pt")
    inner_model = model.model
    # Unwrap DataParallel so the saved state_dict has no "module." prefixes.
    model_state = (
        inner_model.module.state_dict()
        if isinstance(inner_model, nn.DataParallel)
        else inner_model.state_dict()
    )
    checkpoint_data = {
        "model_state_dict": model_state,
        "global_step": global_step,
        "epoch": epoch,
    }
    torch.save(checkpoint_data, checkpoint_path)

    # Determine checkpoint type for logging
    checkpoint_type = "power-of-10" if global_step < save_every else "regular"
    print(
        f"\n--- Saved {checkpoint_type} checkpoint at step {global_step}: {checkpoint_path} ---"
    )


def _count_exact_graph_matches(elementwise_equal):
    """Count samples (dim 0) whose elementwise-equality tensor is True everywhere."""
    batch_size = elementwise_equal.shape[0]
    if batch_size == 0:
        return 0
    # One vectorized reduction instead of a Python loop with a host sync per graph.
    per_graph = elementwise_equal.reshape(batch_size, -1).all(dim=1)
    return int(per_graph.sum().item())


def train_epoch(
    args,
    model,
    loader,
    optimizer,
    criterion,
    epoch,
    acc_threshold,
    global_step,
    saved_models_dir,
    device,
):
    """
    Run one regular training epoch over ``loader``.

    For every batch this performs forward/backward/optimizer step (optionally
    adding an L1 penalty over all parameters of ``model.model``), saves a step
    checkpoint when ``should_save_checkpoint`` says so, and logs batch metrics
    plus windowed path-length accuracies to wandb. Epoch-level path-length
    accuracy is logged once after the last batch.

    Args:
        args: Parsed CLI namespace; reads ``num_epochs``, ``l1_lambda``,
            ``save_every`` and ``path_length_accuracy_steps``.
        model: Wrapper exposing the trainable module as ``model.model``
            (possibly ``nn.DataParallel``) and a ``forward(adj_matrix)`` method.
        loader: Iterable yielding ``(adj_matrix, connectivity_matrix)`` batches.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss callable ``criterion(pred, target)``.
        epoch: Zero-based epoch index (displayed as ``epoch + 1``).
        acc_threshold: Predictions strictly greater than this count as 1.
        global_step: Step counter at the start of the epoch.
        saved_models_dir: Directory where step checkpoints are written.
        device: Device that batches are moved to.

    Returns:
        ``(avg_loss, avg_accuracy, avg_all_correct, global_step)``: mean of the
        per-batch losses, element-wise accuracy over the epoch, fraction of
        samples predicted exactly, and the updated step counter.
    """
    correct_preds = 0
    total_preds = 0
    num_batches = 0
    total_loss = 0
    total_graphs = 0
    completely_correct_pred = 0

    # Path length accuracy tracking; max_path_length is sized from the first batch.
    max_path_length = None
    path_length_accuracies = []
    path_length_counts = []
    # A non-positive window disables windowed logging (previously a modulo-by-zero).
    window = args.path_length_accuracy_steps or 0

    train_bar = tqdm(loader, desc=f"epoch {epoch+1}/{args.num_epochs}")
    model.model.train()  # Access the underlying model
    for adj_matrix, connectivity_matrix in train_bar:
        adj_matrix = adj_matrix.float().to(device)
        connectivity_matrix = connectivity_matrix.float().to(device)

        optimizer.zero_grad()
        pred_connectivity = model.forward(adj_matrix)

        loss = criterion(pred_connectivity, connectivity_matrix)
        if args.l1_lambda > 0:
            l1_reg = sum(torch.norm(param, 1) for param in model.model.parameters())
            # Out-of-place so the criterion's output tensor is not mutated.
            loss = loss + args.l1_lambda * l1_reg
        loss.backward()
        optimizer.step()

        global_step += 1

        # Save checkpoint using power-of-10 logic up to save_every, then regular intervals
        if should_save_checkpoint(global_step, args.save_every):
            _save_step_checkpoint(
                model, saved_models_dir, global_step, epoch, args.save_every
            )

        # Fetch the scalar once; it is reused for accumulation and all logging below.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        pred = (pred_connectivity > acc_threshold).float()
        elementwise_equal = pred == connectivity_matrix
        batch_correct = elementwise_equal.sum().item()
        batch_numel = connectivity_matrix.numel()
        correct_preds += batch_correct
        total_preds += batch_numel
        current_acc = correct_preds / total_preds if total_preds > 0 else 0

        # Initialize path length tracking from first batch
        if max_path_length is None:
            num_nodes = adj_matrix.shape[1]
            max_path_length = num_nodes // 2 - 1

        # Compute path length accuracy for this batch
        batch_path_acc, batch_path_counts = compute_path_length_accuracy(
            pred_connectivity,
            connectivity_matrix,
            adj_matrix,
            max_path_length,
            acc_threshold,
        )
        path_length_accuracies.append(batch_path_acc.cpu())
        path_length_counts.append(batch_path_counts.cpu())

        # Exact-match ("all correct") counts for this batch and the epoch so far.
        batch_all_correct = _count_exact_graph_matches(elementwise_equal)
        batch_total_graphs = pred.shape[0]
        completely_correct_pred += batch_all_correct
        total_graphs += batch_total_graphs

        batch_all_correct_acc = (
            batch_all_correct / batch_total_graphs if batch_total_graphs > 0 else 0
        )
        all_correct_acc = (
            completely_correct_pred / total_graphs if total_graphs > 0 else 0
        )

        train_bar.set_description(
            f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss_value:.4f}, Acc: {current_acc:.4f}, AllCorrectAcc: {all_correct_acc:.4f}, Step: {global_step}"
        )

        # Every `window` batches, log path-length accuracy aggregated over the last `window` batches.
        if window > 0 and num_batches % window == 0:
            aggregated_accuracies, total_counts = aggregate_path_length_accuracies(
                path_length_accuracies[-window:], path_length_counts[-window:]
            )
            if aggregated_accuracies is not None:
                path_length_log = create_path_length_wandb_log(
                    aggregated_accuracies,
                    total_counts,
                    "path_length_accuracy",
                    global_step=global_step,
                )
                wandb.log(path_length_log)

        # wandb logging for every batch
        wandb.log(
            {
                "batch/loss": loss_value,
                "batch/accuracy": batch_correct / batch_numel if batch_numel > 0 else 0,
                "batch/all_correct_accuracy": batch_all_correct_acc,
                "batch/global_step": global_step,
            }
        )

    # Final epoch-level path length accuracy aggregation
    if path_length_accuracies and max_path_length is not None:
        epoch_accuracies, epoch_total_counts = aggregate_path_length_accuracies(
            path_length_accuracies, path_length_counts
        )
        if epoch_accuracies is not None:
            epoch_path_length_log = create_path_length_wandb_log(
                epoch_accuracies,
                epoch_total_counts,
                "epoch_path_length_accuracy",
                epoch=epoch,
            )
            wandb.log(epoch_path_length_log)

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_accuracy = correct_preds / total_preds if total_preds > 0 else 0
    avg_all_correct = completely_correct_pred / total_graphs if total_graphs > 0 else 0
    return avg_loss, avg_accuracy, avg_all_correct, global_step


def _resolve_train_restrict_diam(args):
    """
    Convert ``args.restrict_diam`` into the diameter bound used for ER training data.

    Returns:
        ``None`` when the restriction is disabled (``False``) or the value cannot
        be converted to an int; ``int(3 ** args.num_layers)`` when it is ``True``;
        otherwise ``int(args.restrict_diam)``.
    """
    rd = args.restrict_diam
    # rd can be False, True, or int
    if rd is False:
        return None
    if rd is True:
        # num_layers is only read in this branch; it must already be an int here.
        return int(3**args.num_layers)
    try:
        return int(rd)
    except (TypeError, ValueError, OverflowError):
        return None


def create_train_dataset(args):
    """
    Create the training dataset described by ``args``.

    With a single ``--dataset`` entry this returns ``create_dataset(...)`` for
    that type; for ``erdos_renyi`` the ``restrict_diam`` parameter is resolved
    from ``--restrict_diam`` and printed. With several entries it returns a
    shuffled ``create_mixed_dataset`` whose per-type sample counts add up to
    exactly ``args.num_samples``.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``on_the_fly``,
            ``no_on_the_fly``, ``num_samples``, ``num_nodes``,
            ``restrict_diam`` and (when needed) ``num_layers``.

    Returns:
        The dataset produced by ``create_dataset`` or ``create_mixed_dataset``.

    Raises:
        ValueError: If no dataset type is given.
    """
    # Convert single dataset to list for uniform handling
    datasets = args.dataset if isinstance(args.dataset, list) else [args.dataset]
    if not datasets:
        raise ValueError("At least one training dataset type is required")

    # --no_on_the_fly wins if both flags are passed.
    on_the_fly = args.on_the_fly and not args.no_on_the_fly

    if len(datasets) == 1:
        dataset_type = datasets[0]
        dataset_params = get_dataset_params(args, dataset_type)
        # Only restrict_diam for ER training data
        if dataset_type == "erdos_renyi":
            dataset_params["restrict_diam"] = _resolve_train_restrict_diam(args)
            print(
                f"Using restrict_diam = {dataset_params['restrict_diam']} for training data"
            )
        return create_dataset(
            dataset_type,
            num_samples=args.num_samples,
            num_nodes=args.num_nodes,
            on_the_fly=on_the_fly,
            **dataset_params,
        )

    # Mixed dataset: split num_samples as evenly as possible and give the
    # remainder to the first types, so the total equals args.num_samples
    # (plain floor division silently dropped up to len(datasets) - 1 samples).
    # NOTE(review): unlike the single-dataset path, ER's restrict_diam is not
    # resolved here; it is whatever get_dataset_params returns — confirm intent.
    base_count, remainder = divmod(args.num_samples, len(datasets))
    dataset_configs = []
    for index, dataset_type in enumerate(datasets):
        dataset_configs.append(
            (
                dataset_type,
                {
                    "num_samples": base_count + (1 if index < remainder else 0),
                    "num_nodes": args.num_nodes,
                    "on_the_fly": on_the_fly,
                    **get_dataset_params(args, dataset_type),
                },
            )
        )

    return create_mixed_dataset(dataset_configs, shuffle=True)


def create_eval_datasets(args, num_eval_samples=1024):
    """
    Create one pre-generated evaluation dataset per requested dataset type.

    The types come from ``--eval_dataset`` when given (``"all"`` expands to the
    built-in list below), otherwise from ``--dataset``. Repeated names are
    generated only once, in first-seen order. ``restrict_diam`` is always
    cleared for ``erdos_renyi`` so evaluation is not diameter-restricted.

    Args:
        args: Parsed CLI namespace; reads ``eval_dataset``, ``dataset`` and
            ``num_nodes`` (plus whatever ``get_dataset_params`` reads).
        num_eval_samples: Number of graphs per evaluation dataset (default 1024).

    Returns:
        dict mapping ``"eval_<dataset_type>"`` to the created dataset.
    """
    # Determine which datasets to use for evaluation
    if args.eval_dataset is not None:
        requested = (
            args.eval_dataset
            if isinstance(args.eval_dataset, list)
            else [args.eval_dataset]
        )
        # "all" overrides any other names that were also passed.
        if "all" in requested:
            requested = [
                "erdos_renyi",
                "two_chains",
                "two_variable_chains",
                "two_trees",
                "two_stars",
                "sbm",
                # "erdos_renyi_two_graphs",
                # "erdos_renyi_medium",
                "erdos_renyi_hard",
                "tree_forest",
                "star_forest",
                "two_cliques",
                "caveman",
                "one_circle",
                "two_degree_3_chains",
            ]
    else:
        # Use same as training datasets
        requested = args.dataset if isinstance(args.dataset, list) else [args.dataset]

    eval_datasets = {}
    # dict.fromkeys de-duplicates while keeping order, so a repeated name is not
    # generated twice (the second copy used to silently overwrite the first).
    for dataset_type in dict.fromkeys(requested):
        dataset_params = get_dataset_params(args, dataset_type)
        # Do NOT restrict_diam for eval data
        if dataset_type == "erdos_renyi" and "restrict_diam" in dataset_params:
            dataset_params["restrict_diam"] = None
        eval_datasets[f"eval_{dataset_type}"] = create_dataset(
            dataset_type,
            num_samples=num_eval_samples,
            num_nodes=args.num_nodes,
            on_the_fly=False,  # Use pre-generated for evaluation for consistent results
            **dataset_params,
        )

    return eval_datasets


def parse_restrict_diam(val):
    """Interpret the raw value given to ``--restrict_diam``.

    Accepts real booleans, boolean-like words, or integers.

    Returns:
        - ``False`` for ``None``, ``False``, or a falsy word
          ("false", "f", "0", "no", "n", case-insensitive).
        - ``True`` for ``True`` or a truthy word
          ("true", "t", "1", "yes", "y", case-insensitive); the caller picks
          the default bound later.
        - an ``int`` for any other integer string.
        - ``False`` when the value cannot be interpreted at all.
    """
    if isinstance(val, bool):
        return val
    if val is None:
        return False

    text = str(val).strip()
    lowered = text.lower()
    truthy_words = ("true", "t", "1", "yes", "y")
    falsy_words = ("false", "f", "0", "no", "n")

    if lowered in truthy_words:
        return True
    if lowered in falsy_words:
        return False

    try:
        return int(text)
    except ValueError:
        # Unparseable input deliberately degrades to "no restriction".
        return False


def parse_arguments(argv=None):
    """
    Build the command-line parser for this script and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list is useful for
            tests or programmatic use.

    Returns:
        argparse.Namespace holding every model, dataset, and training option.
    """
    parser = argparse.ArgumentParser(
        description="Train any transformer model on graph data with unified interface."
    )

    # Model arguments
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["roberta", "looped_transformer", "disentangled_transformer"],
        default="roberta",
        help="Type of model to train",
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset",
        type=str,
        nargs="+",  # Allow multiple datasets
        default=["two_chains"],
        help="Dataset type(s) to train on. Single string or list of strings. "
        "If multiple datasets provided, creates mixed dataset. "
        "Choices: erdos_renyi, two_chains, sbm, erdos_renyi_two_graphs, erdos_renyi_medium, erdos_renyi_hard, tree_forest, star_forest, two_cliques, one_circle, two_degree_3_chains",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        nargs="+",  # Allow multiple datasets
        default=None,
        # Evaluation builds one separate dataset per listed type (not a mixed one).
        help="Dataset type(s) to evaluate on. If not specified, uses same as --dataset. "
        "Single string or list of strings. If multiple datasets provided, each is evaluated separately. "
        "Use 'all' to evaluate on all available dataset types. "
        "Choices: erdos_renyi, two_chains, sbm, erdos_renyi_two_graphs, erdos_renyi_medium, erdos_renyi_hard, tree_forest, star_forest, two_cliques, one_circle, two_degree_3_chains, all",
    )

    parser.add_argument(
        "--num_samples", type=int, default=1_000_000, help="Number of graphs in dataset"
    )
    parser.add_argument(
        "--num_nodes", type=int, default=16, help="Number of nodes per graph"
    )
    parser.add_argument(
        "--on_the_fly",
        action="store_true",
        help="Generate graphs on-the-fly during training (memory efficient)",
    )
    parser.add_argument(
        "--no_on_the_fly",
        action="store_true",
        help="Pre-generate all graphs and store in memory (faster access)",
    )

    # Erdos-Renyi specific parameters
    parser.add_argument(
        "--fixed_p", type=float, default=None, help="Fixed edge creation probability"
    )
    parser.add_argument(
        "--sample_p",
        action="store_true",
        help="Sample p instead of using a fixed value",
    )
    parser.add_argument(
        "--p_range",
        type=float,
        nargs=2,
        default=(0.02, 0.2),
        help="Range for sampling p if sample_p is True",
    )

    # Two chains specific parameters
    parser.add_argument(
        "--k", type=int, default=None, help="Chain length for two_chains dataset"
    )
    parser.add_argument(
        "--add_isolated_nodes",
        action="store_true",
        help="Add isolated nodes to two_chains dataset (default: False, use all nodes for chains)",
    )

    # SBM specific parameters
    parser.add_argument(
        "--p_intra",
        type=float,
        default=0.4,
        help="Intra-community connection probability for SBM",
    )
    parser.add_argument(
        "--p_inter",
        type=float,
        default=0.05,
        help="Inter-community connection probability for SBM",
    )
    parser.add_argument(
        "--num_communities", type=int, default=4, help="Number of communities for SBM"
    )

    # Model-specific parameters
    parser.add_argument(
        "--num_attention_heads",
        type=int,
        default=1,
        help="Number of attention heads per layer",
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=None,
        help="Number of transformer layers (default: log3(num_nodes))",
    )
    parser.add_argument(
        "--hidden_size",
        type=int,
        default=128,
        help="Hidden dimension of the transformer model",
    )

    # RoBERTa specific parameters
    parser.add_argument(
        "--roberta_type",
        type=str,
        default="relu",
        choices=["relu", "softmax", "tie_qk"],
        help="Type of RoBERTa model to use",
    )
    parser.add_argument(
        "--layer_norm_type",
        type=str,
        default="post",
        choices=["pre", "post"],
        help="Type of layer normalization to use in the model",
    )
    parser.add_argument(
        "--roberta_attention_only",
        action="store_true",
        help="Use attention-only RoBERTa models (removes feed-forward layers, only supports layer_norm_type='pre')",
    )

    # Looped Transformer specific parameters
    parser.add_argument(
        "--read_in_method",
        type=str,
        default="linear",
        choices=["linear", "zero_pad"],
        help="Read-in method for looped transformer",
    )
    parser.add_argument(
        "--tie_qk",
        action="store_true",
        help="Tie query and key weights",
    )

    # Disentangled Transformer specific parameters
    parser.add_argument(
        "--heads",
        type=int,
        nargs="+",
        default=None,
        help="Number of heads per layer for disentangled transformer (e.g., --heads 8 4 2)",
    )
    parser.add_argument(
        "--init_type",
        type=str,
        default="randn",
        choices=["randn", "zeros", "eye", "psd", "sym"],
        help="Initialization type for disentangled transformer",
    )
    parser.add_argument(
        "--readout_type",
        type=str,
        default="linear",
        choices=["linear", "sum", "last", "sum_clamp_to_prob"],
        help="Readout type for disentangled transformer",
    )
    parser.add_argument(
        "--disentangled_final_activation",
        type=str,
        default="tanh",
        choices=["tanh", "relu", "sigmoid", "clamp_to_prob", "phi_logits", "none"],
        help="Final activation function for disentangled transformer",
    )

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-2, help="Learning rate"
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        choices=["SGD", "Adam", "AdamW"],
        default="AdamW",
        help="Optimizer to use",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-4, help="Weight decay for optimizer"
    )
    parser.add_argument(
        "--l1_lambda", type=float, default=0.0, help="L1 regularization lambda"
    )
    parser.add_argument(
        "--criterion_type",
        type=str,
        choices=["bce", "mse", "focal"],
        default="bce",
        help="Type of loss criterion to use (bce: BCE loss, mse: MSELoss, focal: FocalLoss). Use --preds_are_probs to control probabilities vs logits for bce and focal.",
    )
    parser.add_argument(
        "--focal_alpha",
        type=float,
        default=0.25,
        help="Alpha parameter for focal loss (weighting factor for rare class)",
    )
    parser.add_argument(
        "--focal_gamma",
        type=float,
        default=2.0,
        help="Gamma parameter for focal loss (focusing parameter for hard examples)",
    )
    parser.add_argument(
        "--preds_are_probs",
        action="store_true",
        help="Whether model predictions are probabilities (default: False, assumes logits). Applies to both BCE and focal loss.",
    )
    parser.add_argument(
        "--num_epochs", type=int, default=100, help="Number of training epochs"
    )
    parser.add_argument(
        "--ckpt_path", type=str, default=None, help="Path to save checkpoints"
    )
    parser.add_argument(
        "--wandb_run_name", type=str, default=None, help="Name of the wandb run"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=None,
        help="Save model checkpoint every N steps (default: None, save only at epoch end). "
        "Uses power-of-10 checkpoints (10^k) up to N, then regular N-step intervals.",
    )
    parser.add_argument(
        "--eval_before_train",
        action="store_true",
        help="Evaluate model before training starts (useful to see initial performance)",
    )
    parser.add_argument(
        "--path_length_accuracy_steps",
        type=int,
        default=20,
        help="Number of steps to aggregate path length accuracy over (default: 20)",
    )
    parser.add_argument(
        "--restrict_diam",
        type=parse_restrict_diam,
        default=False,
        help=(
            "Restrict ER training graphs to diameter <= restrict_diam. "
            "Accepts an integer or boolean-like value. "
            "If True is passed, uses 3^{num_layers}. Default: False (no restriction)."
        ),
    )

    parser.add_argument(
        "--visualize_samples",
        action="store_true",
        help="Visualize and save sample graphs from training and evaluation datasets",
    )

    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume training from existing checkpoint if available (default: False, start fresh)",
    )

    parser.add_argument(
        "--steps_per_epoch_estimate",
        type=int,
        default=None,
        help="Estimated steps per epoch for resume calculation from step-based checkpoints (default: None, auto-calculate from num_samples/batch_size)",
    )

    # Multi-GPU training arguments
    parser.add_argument(
        "--multi_gpu",
        action="store_true",
        help="Enable multi-GPU training using DataParallel (default: False, single GPU)",
    )
    parser.add_argument(
        "--gpu_ids",
        type=int,
        nargs="+",
        default=None,
        help="List of GPU IDs to use for training (e.g., --gpu_ids 0 1 2). If not specified, uses all available GPUs when --multi_gpu is enabled.",
    )

    parser.add_argument(
        "--num_workers",
        type=int,
        default=None,
        help="Number of DataLoader workers (default: auto; lower this if hitting host RAM limits, especially with --multi_gpu).",
    )

    parser.add_argument(
        "--model_type_suffix",
        type=str,
        default="",
        help="Suffix to add to the model type folder name (e.g., '_experiment1'). Model folder becomes {model_type}_models{suffix}",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def evaluate_model(model, dataloader, criterion, acc_threshold=0.0, device=None):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Puts ``model.model`` in eval mode (the caller is responsible for switching
    back to train mode) and thresholds predictions at ``acc_threshold``.

    Args:
        model: Wrapper exposing the module as ``model.model`` and a
            ``forward(adj_matrix)`` method.
        dataloader: Iterable of ``(adj_matrix, connectivity_matrix)`` batches,
            or ``None`` to skip evaluation. It does not need ``__len__``.
        criterion: Loss callable ``criterion(pred, target)``.
        acc_threshold: Predictions strictly greater than this count as 1.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(avg_loss, avg_accuracy, avg_all_correct)``: mean of per-batch losses,
        element-wise accuracy, and fraction of samples predicted exactly; each is
        ``0`` when there is no data. ``(None, None, None)`` if ``dataloader`` is None.
    """
    if dataloader is None:
        return None, None, None

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.model.eval()
    total_loss = 0
    num_batches = 0
    correct_preds = 0
    total_preds = 0
    all_correct = 0
    graph_count = 0

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in dataloader:
            adj_matrix = adj_matrix.float().to(device)
            connectivity_matrix = connectivity_matrix.float().to(device)
            pred_connectivity = model.forward(adj_matrix)
            loss = criterion(pred_connectivity, connectivity_matrix)
            total_loss += loss.item()
            # Counted here rather than via len(dataloader), which loaders over
            # iterable datasets or plain generators do not support.
            num_batches += 1

            pred = (pred_connectivity > acc_threshold).float()
            elementwise_equal = pred == connectivity_matrix
            correct_preds += elementwise_equal.sum().item()
            total_preds += connectivity_matrix.numel()

            # Vectorized per-sample exact-match count (one device sync per batch
            # instead of one per graph).
            batch_size = pred.shape[0]
            if batch_size > 0:
                per_graph = elementwise_equal.reshape(batch_size, -1).all(dim=1)
                all_correct += int(per_graph.sum().item())
                graph_count += batch_size

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_accuracy = correct_preds / total_preds if total_preds > 0 else 0
    avg_all_correct = all_correct / graph_count if graph_count > 0 else 0
    return avg_loss, avg_accuracy, avg_all_correct


def main():
    args = parse_arguments()

    # Set default probability values for different node counts
    p_dict = {
        4: 0.5,
        8: 0.2,
        9: 0.182,
        10: 0.165,
        11: 0.153,
        12: 0.141,
        13: 0.131,
        14: 0.122,
        15: 0.114,
        16: 0.108,
        17: 0.102,
        18: 0.097,
        19: 0.092,
        20: 0.088,
        21: 0.083,
        22: 0.0795,
        23: 0.0765,
        24: 0.073,
        25: 0.0705,
        26: 0.0675,
        27: 0.065,
        28: 0.0625,
        29: 0.061,
        30: 0.059,
        31: 0.0575,
        32: 0.056,
        48: 0.040,
        54: 0.036,
        56: 0.032,
        64: 0.0278,
        80: 0.0225,
        96: 0.0183,
        128: 0.0138,
    }

    # Convert single dataset to list for uniform handling
    datasets = args.dataset if isinstance(args.dataset, list) else [args.dataset]

    # Handle eval datasets
    if args.eval_dataset is not None:
        eval_datasets_list = (
            args.eval_dataset
            if isinstance(args.eval_dataset, list)
            else [args.eval_dataset]
        )
        # Handle "all" option
        if "all" in eval_datasets_list:
            valid_datasets = [
                "erdos_renyi",
                "two_chains",
                "two_variable_chains",
                "two_trees",
                "two_stars",
                "sbm",
                # "erdos_renyi_two_graphs",
                # "erdos_renyi_medium",
                "erdos_renyi_hard",
                "tree_forest",
                "star_forest",
                "two_cliques",
                "caveman",
                "one_circle",
                "two_degree_3_chains",
            ]
            # If "all" is specified, use all datasets regardless of other specifications
            eval_datasets_list = valid_datasets
    else:
        eval_datasets_list = datasets  # Use same as training

    # Validate dataset types
    valid_datasets = [
        "erdos_renyi",
        "two_chains",
        "two_variable_chains",
        "two_trees",
        "two_stars",
        "sbm",
        "erdos_renyi_two_graphs",
        "erdos_renyi_medium",
        "erdos_renyi_hard",
        "tree_forest",
        "star_forest",
        "two_cliques",
        "one_circle",
        "two_degree_3_chains",
    ]

    # Validate training datasets
    for dataset in datasets:
        if dataset not in valid_datasets:
            raise ValueError(
                f"Invalid training dataset: {dataset}. Valid options: {valid_datasets}"
            )

    # Validate evaluation datasets (check original args before expansion)
    if args.eval_dataset is not None:
        original_eval_list = (
            args.eval_dataset
            if isinstance(args.eval_dataset, list)
            else [args.eval_dataset]
        )
        for dataset in original_eval_list:
            if dataset not in valid_datasets and dataset != "all":
                raise ValueError(
                    f"Invalid evaluation dataset: {dataset}. Valid options: {valid_datasets + ['all']}"
                )

    # Validate attention_only constraints for RoBERTa
    if (
        args.model_type == "roberta"
        and args.roberta_attention_only
        and args.layer_norm_type != "pre"
    ):
        raise ValueError(
            "roberta_attention_only=True only supports layer_norm_type='pre'"
        )

    # Check if we need fixed_p for any datasets (training or evaluation)
    all_datasets = datasets + eval_datasets_list
    if args.fixed_p is None and any(
        dataset
        in [
            "erdos_renyi",
            "erdos_renyi_two_graphs",
            "erdos_renyi_medium",
            "erdos_renyi_hard",
        ]
        for dataset in all_datasets
    ):
        if args.num_nodes in p_dict:
            args.fixed_p = p_dict[args.num_nodes]
            print(
                f"Using fixed edge creation probability p = {args.fixed_p} for {args.num_nodes} nodes"
            )
        else:
            raise ValueError(
                f"Please provide a fixed edge creation probability for {args.num_nodes} nodes or use --sample_p."
            )

    # Set random seeds
    if args.seed is None or args.seed <= 0:
        args.seed = random.randint(0, 10000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Using seed {args.seed} for reproducibility")

    # Setup device and multi-GPU configuration
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This script requires GPU support.")

    device = torch.device("cuda")

    if args.multi_gpu:
        if args.gpu_ids is not None:
            # Use specified GPU IDs
            available_gpus = args.gpu_ids
            if not all(gpu_id < torch.cuda.device_count() for gpu_id in available_gpus):
                raise ValueError(
                    f"Invalid GPU IDs: {args.gpu_ids}. Available GPUs: 0-{torch.cuda.device_count()-1}"
                )
        else:
            # Use all available GPUs
            available_gpus = list(range(torch.cuda.device_count()))

        print(f"Multi-GPU training enabled on GPUs: {available_gpus}")
        print(f"Number of GPUs: {len(available_gpus)}")

        # Set the primary GPU
        torch.cuda.set_device(available_gpus[0])
        device = torch.device(f"cuda:{available_gpus[0]}")

        # Adjust batch size for multi-GPU training
        total_batch_size = args.batch_size
        per_gpu_batch_size = args.batch_size // len(available_gpus)
        if per_gpu_batch_size * len(available_gpus) != total_batch_size:
            per_gpu_batch_size += 1
            print(
                f"Adjusting batch size: total={total_batch_size}, per-GPU={per_gpu_batch_size}"
            )
        args.batch_size = per_gpu_batch_size
    else:
        print(f"Single GPU training on GPU 0")
        torch.cuda.set_device(0)
        available_gpus = [0]

    # Set default number of layers
    if args.num_layers is None:
        if args.model_type == "disentangled_transformer":
            args.num_layers = len(args.heads)
        else:
            args.num_layers = math.ceil(math.log(args.num_nodes, 3))

    # Set default heads for disentangled transformer
    if args.model_type == "disentangled_transformer" and args.heads is None:
        args.heads = [8, 4, 2]  # Default architecture

    # Set default k for two_chains
    if args.k is None:
        if getattr(args, "add_isolated_nodes", False):
            # With isolated nodes: reserve 2 nodes as isolated
            args.k = (args.num_nodes - 2) // 2
        else:
            # Without isolated nodes: use all nodes for chains
            args.k = args.num_nodes // 2

    # Setup checkpoint paths
    model_suffix = {
        "roberta": f"roberta_{args.roberta_type}_{args.layer_norm_type}{'_attn_only' if args.roberta_attention_only else ''}",
        "looped_transformer": f"looped_{args.read_in_method}_{args.layer_norm_type}",
        "disentangled_transformer": f"disentangled_{args.init_type}_{args.readout_type}",
    }[args.model_type]

    # Create dataset name for paths
    dataset_name = "_".join(datasets) if len(datasets) > 1 else datasets[0]

    # Add eval dataset to path if different from training
    if eval_datasets_list != datasets:
        # Handle "all" case specially
        if args.eval_dataset is not None and "all" in args.eval_dataset:
            eval_dataset_name = "all"
        else:
            eval_dataset_name = (
                "_".join(eval_datasets_list)
                if len(eval_datasets_list) > 1
                else eval_datasets_list[0]
            )
        dataset_name = f"{dataset_name}_eval_{eval_dataset_name}"

    if args.ckpt_path is None:
        restrict_diam_str = (
            f"_restrictdiam={args.restrict_diam}"
            if getattr(args, "restrict_diam", False)
            else ""
        )
        if args.model_type == "roberta":
            args.ckpt_path = f"{model_suffix}_{dataset_name}_{args.num_layers}layers_h{args.num_attention_heads}_d{args.hidden_size}_n={args.num_nodes}_seed={args.seed}{restrict_diam_str}"
        else:
            args.ckpt_path = f"{model_suffix}_{dataset_name}_{args.num_layers}layers_n={args.num_nodes}_seed={args.seed}{restrict_diam_str}"
    # Create outer folder for model type and models, e.g., disentangled-transformer_models
    model_type_folder = f"{args.model_type}_models{args.model_type_suffix}"

    # Add date suffix to checkpoint path for organization
    date_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_path_with_date = f"{args.ckpt_path}_{date_suffix}"

    saved_models_dir = os.path.join(
        "checkpoints", model_type_folder, ckpt_path_with_date
    )
    os.makedirs(saved_models_dir, exist_ok=True)

    # Save configuration
    config_path = os.path.join(saved_models_dir, "config.json")
    with open(config_path, "w") as config_file:
        json.dump(vars(args), config_file, indent=2)
    print(f"Configuration saved to {config_path}")

    # Initialize wandb with project name as {model-type}-graphs
    wandb_run_name = (
        args.wandb_run_name if args.wandb_run_name is not None else args.ckpt_path
    )
    wandb_project_name = f"{args.model_type.replace('_', '-')}-graphs"
    wandb_resume = "allow" if getattr(args, "resume", False) else None
    wandb_init_kwargs = {
        "project": wandb_project_name,
        "name": wandb_run_name,
        "config": vars(args),
        "id": wandb_run_name,
    }
    if wandb_resume:
        wandb_init_kwargs["resume"] = wandb_resume
    wandb.init(**wandb_init_kwargs)

    # Create datasets
    dataset_desc = " + ".join(datasets) if len(datasets) > 1 else datasets[0]
    eval_dataset_desc = (
        " + ".join(eval_datasets_list)
        if len(eval_datasets_list) > 1
        else eval_datasets_list[0]
    )

    print(f"Creating training dataset: {dataset_desc} with {args.num_nodes} nodes...")
    if eval_datasets_list != datasets:
        print(
            f"Creating evaluation dataset: {eval_dataset_desc} with {args.num_nodes} nodes..."
        )

    train_dataset = create_train_dataset(args)
    eval_datasets = create_eval_datasets(args)

    # Auto-calculate steps_per_epoch_estimate if not provided
    if args.steps_per_epoch_estimate is None:
        # Calculate based on dataset size and batch size
        try:
            dataset_size = len(train_dataset)
            args.steps_per_epoch_estimate = math.ceil(dataset_size / args.batch_size)
            print(
                f"Auto-calculated steps_per_epoch_estimate: {args.steps_per_epoch_estimate} (dataset_size={dataset_size}, batch_size={args.batch_size})"
            )
        except (TypeError, AttributeError):
            # Fallback for on-the-fly datasets or datasets without len()
            # Use num_samples as approximation since that's what was used to create the dataset
            args.steps_per_epoch_estimate = math.ceil(
                args.num_samples / args.batch_size
            )
            print(
                f"Auto-calculated steps_per_epoch_estimate: {args.steps_per_epoch_estimate} (using num_samples={args.num_samples}, batch_size={args.batch_size})"
            )
            print(
                "  Note: Dataset length not available (likely on-the-fly), using num_samples as approximation"
            )
    else:
        print(
            f"Using provided steps_per_epoch_estimate: {args.steps_per_epoch_estimate}"
        )

    # Visualize dataset samples
    print("\n" + "=" * 60)
    print("VISUALIZING DATASET SAMPLES")
    print("=" * 60)

    # Visualize training dataset samples
    if args.visualize_samples:
        print("Visualizing training dataset samples...")
        if isinstance(train_dataset, torch.utils.data.Dataset):
            # For single or mixed datasets, we need to handle them differently
            try:
                # Try to get the dataset types from the training dataset
                if hasattr(train_dataset, "dataset_types"):
                    # Mixed dataset case
                    visualize_dataset_samples(
                        {"mixed_train": train_dataset}, "train", saved_models_dir
                    )
                else:
                    # Single dataset case - use the dataset name from args
                    dataset_name = (
                        "_".join(datasets) if len(datasets) > 1 else datasets[0]
                    )
                    visualize_dataset_samples(
                        {dataset_name: train_dataset}, "train", saved_models_dir
                    )
            except Exception as e:
                print(f"Warning: Could not visualize training dataset: {e}")

        # Visualize evaluation dataset samples
        try:
            visualize_dataset_samples(eval_datasets, "eval", saved_models_dir)
        except Exception as e:
            print(f"Warning: Could not visualize evaluation datasets: {e}")

        print("=" * 60)
        print("DATASET VISUALIZATION COMPLETE")
        print("=" * 60 + "\n")

    # Create data loaders
    auto_workers = min(16, os.cpu_count() - 4)  # at most 16 workers to avoid RAM issues
    num_workers = auto_workers if args.num_workers is None else args.num_workers
    if num_workers < 0:
        num_workers = 0
    print(
        f"Creating DataLoader for training dataset with {num_workers} workers (auto={auto_workers}, override={'yes' if args.num_workers is not None else 'no'})..."
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,  # Enable for faster CPU→GPU transfers
    )
    # Use a smaller evaluation batch size to avoid empty eval splits when train batch_size is large.
    # Also ensure drop_last=False so we don't drop the final smaller batch.
    eval_batch_size = min(args.batch_size, 1024)
    eval_dataloaders = {
        name: DataLoader(
            dataset,
            batch_size=eval_batch_size,
            shuffle=False,
            pin_memory=True,  # Enable for faster CPU→GPU transfers
            drop_last=False,
        )
        for name, dataset in eval_datasets.items()
    }

    # Create model
    print(f"Creating {args.model_type} model...")
    model_params = {
        "num_nodes": args.num_nodes,
        "hidden_size": args.hidden_size,
        "num_layers": args.num_layers,
        "num_attention_heads": args.num_attention_heads,
    }

    if args.model_type == "roberta":
        model_params.update(
            {
                "roberta_type": args.roberta_type,
                "layer_norm_type": args.layer_norm_type,
                "attention_only": args.roberta_attention_only,
            }
        )
    elif args.model_type == "looped_transformer":
        model_params.update(
            {
                "read_in_method": args.read_in_method,
                "layer_norm_type": args.layer_norm_type,
                "tie_qk": args.tie_qk,
            }
        )
    elif args.model_type == "disentangled_transformer":
        # Handle "none" string to None conversion for final_activation
        final_activation = (
            args.disentangled_final_activation
            if args.disentangled_final_activation != "none"
            else None
        )
        model_params = {
            "num_nodes": args.num_nodes,
            "heads": args.heads,
            "init_type": args.init_type,
            "readout_type": args.readout_type,
            "final_activation": final_activation,
        }

    model = create_model(args.model_type, **model_params)
    print(f"Model created: {type(model.model).__name__}")

    # Load checkpoint if resume flag is set and checkpoints are available
    offset = 0
    latest_checkpoint = None
    global_step = 0

    if args.resume:
        # When resuming, look for the latest checkpoint directory by date
        resume_checkpoint_dir = find_latest_checkpoint_dir(
            args.ckpt_path, model_type_folder
        )

        if os.path.exists(resume_checkpoint_dir):
            print("Resume flag is set. Searching for existing checkpoints...")
            print(f"Checkpoint directory: {resume_checkpoint_dir}")

            # Update saved_models_dir to point to the resume directory for loading
            resume_saved_models_dir = resume_checkpoint_dir

            # Look for epoch-based checkpoints first
            epoch_files = glob.glob(
                os.path.join(resume_saved_models_dir, "model_[0-9][0-9][0-9].pt")
            )
            step_files = glob.glob(
                os.path.join(resume_saved_models_dir, "model_step_*.pt")
            )

            print(f"Found {len(epoch_files)} epoch-based checkpoints")
            print(f"Found {len(step_files)} step-based checkpoints")

            if epoch_files:
                epochs = sorted(
                    [
                        int(re.search(r"model_(\d+)\.pt", path).group(1))
                        for path in epoch_files
                    ]
                )
                latest_epoch = epochs[-1]
                latest_checkpoint = os.path.join(
                    resume_saved_models_dir, f"model_{latest_epoch:03d}.pt"
                )
                offset = latest_epoch + 1
                print(f"Found epoch-based checkpoint: {latest_checkpoint}")
                print(f"Will resume training from epoch {latest_epoch + 1}")
            elif step_files:
                steps = sorted(
                    [
                        int(re.search(r"model_step_(\d+)\.pt", path).group(1))
                        for path in step_files
                    ]
                )
                latest_step = steps[-1]
                latest_checkpoint = os.path.join(
                    resume_saved_models_dir, f"model_step_{latest_step}.pt"
                )
                print(f"Found step-based checkpoint: {latest_checkpoint}")
                print(f"Will resume training from step {latest_step}")

                # Try to load checkpoint to get epoch information
                try:
                    temp_checkpoint = torch.load(latest_checkpoint)
                    if isinstance(temp_checkpoint, dict) and "epoch" in temp_checkpoint:
                        offset = temp_checkpoint["epoch"] + 1
                        print(
                            f"Found epoch info in checkpoint: resuming from epoch {offset}"
                        )
                    else:
                        # Estimate epoch from step count if no epoch info available
                        # Use the calculated estimate for steps per epoch
                        estimated_epoch = latest_step // args.steps_per_epoch_estimate
                        offset = estimated_epoch
                        print(f"No epoch info in checkpoint, estimated epoch: {offset}")
                        print(
                            f"  Calculation: step {latest_step} ÷ {args.steps_per_epoch_estimate} steps/epoch = epoch {estimated_epoch}"
                        )
                        # Compare with auto-calculated value if possible
                        try:
                            auto_calculated = math.ceil(
                                len(train_dataset) / args.batch_size
                            )
                            if args.steps_per_epoch_estimate != auto_calculated:
                                print(
                                    f"  Note: Using provided/calculated estimate. Direct dataset calculation would be {auto_calculated}"
                                )
                        except (TypeError, AttributeError):
                            # Dataset doesn't have len(), likely on-the-fly
                            auto_calculated = math.ceil(
                                args.num_samples / args.batch_size
                            )
                            if args.steps_per_epoch_estimate != auto_calculated:
                                print(
                                    f"  Note: Using provided estimate. num_samples-based calculation would be {auto_calculated}"
                                )
                except Exception as e:
                    print(
                        f"Warning: Could not load checkpoint to check epoch info: {e}"
                    )
                    offset = 0

        if not latest_checkpoint:
            print("No valid checkpoints found. Starting fresh training.")
    elif args.resume:
        print(
            "Resume flag is set, but no saved models directory found. Starting fresh training."
        )
    else:
        print(
            "Starting fresh training (resume=False)."
        )  # Setup optimizer and scheduler
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            model.model.parameters(),
            lr=args.learning_rate,
            momentum=0.9,
            nesterov=True,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            model.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    else:  # AdamW
        optimizer = torch.optim.AdamW(
            model.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    scheduler = CosineAnnealingLR(optimizer, T_max=args.num_epochs, eta_min=1e-8)

    # Now load checkpoint state if we found one
    if latest_checkpoint:
        print(f"Loading checkpoint: {latest_checkpoint}")
        checkpoint_data = torch.load(latest_checkpoint)

        # Handle both old format (just state_dict) and new format (dict with multiple states)
        if isinstance(checkpoint_data, dict) and "model_state_dict" in checkpoint_data:
            model.model.load_state_dict(checkpoint_data["model_state_dict"])
            print("✓ Loaded model state from checkpoint")

            # Also load optimizer and scheduler if available
            if "optimizer_state_dict" in checkpoint_data:
                optimizer.load_state_dict(checkpoint_data["optimizer_state_dict"])
                print("✓ Loaded optimizer state from checkpoint")
            else:
                print("⚠ No optimizer state in checkpoint - using fresh optimizer")

            if "scheduler_state_dict" in checkpoint_data:
                scheduler.load_state_dict(checkpoint_data["scheduler_state_dict"])
                print("✓ Loaded scheduler state from checkpoint")
            else:
                print(
                    "⚠ No scheduler state in checkpoint - adjusting scheduler for current epoch"
                )
                # Adjust scheduler to current epoch if no scheduler state is saved
                if offset > 0:
                    for _ in range(offset):
                        scheduler.step()
                    print(f"✓ Advanced scheduler to epoch {offset}")

            if "global_step" in checkpoint_data:
                global_step = checkpoint_data["global_step"]
                print(f"✓ Resuming from global step: {global_step}")
            else:
                print("⚠ No global step in checkpoint - starting from 0")

            # Verify epoch consistency
            if "epoch" in checkpoint_data:
                checkpoint_epoch = checkpoint_data["epoch"]
                if offset != checkpoint_epoch + 1:
                    print(
                        f"⚠ Epoch mismatch: calculated offset={offset}, checkpoint epoch={checkpoint_epoch}"
                    )
                    print(
                        f"  Using checkpoint epoch information: setting offset to {checkpoint_epoch + 1}"
                    )
                    offset = checkpoint_epoch + 1
        else:
            # Old format - just model state dict
            model.model.load_state_dict(checkpoint_data)
            print("✓ Loaded model state from checkpoint (old format)")
            print("⚠ Using fresh optimizer and scheduler states")
            global_step = 0  # Reset global step for old checkpoints

        print(f"\n=== RESUME SUMMARY ===")
        print(f"Checkpoint: {latest_checkpoint}")
        print(f"Starting epoch: {offset}")
        print(f"Global step: {global_step}")
        print(f"Total epochs to train: {args.num_epochs}")
        print(f"Final epoch will be: {args.num_epochs + offset - 1}")
        print(f"======================\n")

    # Set up loss criterion
    if args.criterion_type == "bce":
        if not args.preds_are_probs:
            criterion = torch.nn.BCEWithLogitsLoss()
        else:
            criterion = torch.nn.BCELoss()
    elif args.criterion_type == "mse":
        criterion = torch.nn.MSELoss()
    elif args.criterion_type == "focal":
        criterion = FocalLoss(
            alpha=args.focal_alpha,
            gamma=args.focal_gamma,
            preds_are_probs=args.preds_are_probs,
        )
    else:
        raise ValueError(f"Invalid criterion type: {args.criterion_type}")

    acc_threshold = 0.0  # Default threshold for accuracy calculation
    if args.preds_are_probs:
        # If predictions are probabilities, use 0.5 as threshold
        acc_threshold = 0.5

    # Move model to GPU and setup multi-GPU if enabled
    model.model.to(device)

    if args.multi_gpu and len(available_gpus) > 1:
        print(f"Wrapping model with DataParallel for GPUs: {available_gpus}")
        model.model = nn.DataParallel(model.model, device_ids=available_gpus)
        print(f"Model is now using {len(available_gpus)} GPUs")
    else:
        print(f"Model moved to device: {device}")

    # Save initialization checkpoint (before any training or loading of pretrained weights)
    if not args.resume:  # Only save initialization if not resuming from checkpoint
        init_checkpoint_path = os.path.join(saved_models_dir, "model_init.pt")
        # Handle DataParallel wrapper properly
        model_state = (
            model.model.module.state_dict()
            if isinstance(model.model, nn.DataParallel)
            else model.model.state_dict()
        )
        checkpoint_data = {
            "model_state_dict": model_state,
            "epoch": -1,  # Use -1 to indicate initialization
            "global_step": 0,
        }
        torch.save(checkpoint_data, init_checkpoint_path)
        print(f"Saved initialization checkpoint: {init_checkpoint_path}")

    # Evaluate before training if requested
    if args.eval_before_train:
        print(">>> [EVAL BEFORE TRAIN] Evaluating initial model performance...")
        model.model.eval()  # Set to eval mode

        eval_results = {}
        for name, dataloader in eval_dataloaders.items():
            results = evaluate_model_with_perm_metrics(
                model,
                dataloader,
                criterion,
                acc_threshold,
                device,
            )
            eval_results[name] = results

        # Print initial results
        eval_summary_lines = []
        for name, results in eval_results.items():
            line = f"{name}: accuracy {results['accuracy']:.4f} | all correct {results['all_correct']:.4f} | loss {results['loss']:.4f}"
            line += f" | perm_frob_cosine_similarity {results.get('perm_frob_cosine_similarity', 0):.4f}"
            line += f" | perm_l1_distance {results.get('perm_l1_distance', 0):.4f}"
            line += f" | spectral_dist {results.get('spectral_distance', 0):.4f}"
            line += f" | degree_corr {results.get('degree_correlation', 0):.4f}"
            eval_summary_lines.append(line)

        eval_summary = "\n".join(eval_summary_lines)
        print(f">>> [INIT] Initial metrics: \n{eval_summary}")

        # Log initial results to wandb
        initial_wandb_log = {
            "epoch/epoch": 0,  # Epoch 0 for initial evaluation
            "epoch/global_step": 0,
        }

        for name, results in eval_results.items():
            initial_wandb_log.update(
                {
                    f"{name}/loss": results["loss"],
                    f"{name}/acc": results["accuracy"],
                    f"{name}/all_corr": results["all_correct"],
                    f"{name}/perm_frob_cosine_similarity": results.get(
                        "perm_frob_cosine_similarity", 0
                    ),
                    f"{name}/perm_l1_distance": results.get("perm_l1_distance", 0),
                    f"{name}/perm_variance": results.get("perm_variance", 0),
                    f"{name}/spectral_distance": results.get("spectral_distance", 0),
                    f"{name}/degree_correlation": results.get("degree_correlation", 0),
                    f"{name}/triangle_count_error": results.get(
                        "triangle_count_error", 0
                    ),
                    f"{name}/graph_similarity": results.get("graph_similarity", 0),
                }
            )

        wandb.log(initial_wandb_log)

        # Log initial results to CSV (include eval dataset losses)
        history_path = os.path.join(saved_models_dir, "history.csv")
        with open(history_path, "w") as f:  # Create/overwrite file
            # Build header: epoch, train metrics, then for each eval dataset include loss and acc,
            # followed by permutation/spectral/degree metrics per dataset
            header_parts = ["epoch", "train_loss", "train_acc", "train_all_correct"]
            for name in eval_results.keys():
                header_parts.append(f"{name}_loss")
                header_parts.append(f"{name}_acc")

            # Permutation/frob similarity fields
            for name in eval_results.keys():
                header_parts.append(f"{name}_perm_frob_cosine_similarity")

            # Spectral distance and degree correlation fields
            for name in eval_results.keys():
                header_parts.append(f"{name}_spectral_dist")
            for name in eval_results.keys():
                header_parts.append(f"{name}_degree_corr")

            header = ",".join(header_parts)
            f.write(header + "\n")

            # Write initial row (epoch 0, no training loss/acc yet but include eval losses/accs)
            row_parts = ["0", "0.0", "0.0", "0.0"]
            for results in eval_results.values():
                row_parts.append(f"{results.get('loss', 0.0):.4f}")
                row_parts.append(f"{results['accuracy']:.4f}")

            for results in eval_results.values():
                row_parts.append(f"{results.get('perm_frob_cosine_similarity', 0):.4f}")

            for results in eval_results.values():
                row_parts.append(f"{results.get('spectral_distance', 0):.4f}")
            for results in eval_results.values():
                row_parts.append(f"{results.get('degree_correlation', 0):.4f}")

            f.write(",".join(row_parts) + "\n")

    # Training loop
    for epoch in range(offset, args.num_epochs + offset):
        avg_loss, avg_accuracy, avg_all_correct, global_step = train_epoch(
            args,
            model,
            train_dataloader,
            optimizer,
            criterion,
            epoch,
            acc_threshold,
            global_step,
            saved_models_dir,
            device,
        )

        print(
            f">>> [TRAIN] epoch {epoch}/{args.num_epochs + offset - 1} (step {global_step}) - loss: {avg_loss:.4f}, accuracy: {avg_accuracy:.4f}, all_correct: {avg_all_correct:.4f}"
        )

        # Evaluation
        print(
            f">>> [EVAL]  epoch {epoch}/{args.num_epochs + offset - 1} evaluation... ",
            end="",
        )

        eval_results = {}
        for name, dataloader in eval_dataloaders.items():
            results = evaluate_model_with_perm_metrics(
                model, dataloader, criterion, acc_threshold, device
            )
            eval_results[name] = results

        # Print results
        eval_summary_lines = []
        for name, results in eval_results.items():
            line = f"{name}: acc {results['accuracy']:.4f} | all_correct {results['all_correct']:.4f}"
            line += (
                f" | perm_frob_cos {results.get('perm_frob_cosine_similarity', 0):.3f}"
            )
            line += f" | perm_l1 {results.get('perm_l1_distance', 0):.3f}"
            line += f" | spec_dist {results.get('spectral_distance', 0):.3f}"
            line += f" | deg_corr {results.get('degree_correlation', 0):.3f}"
            eval_summary_lines.append(line)

        eval_summary = "\n".join(eval_summary_lines)
        print(f"metrics: {eval_summary}")

        # Logging to CSV
        history_path = os.path.join(saved_models_dir, "history.csv")

        # Check if this is the first epoch and file doesn't exist yet
        file_needs_header = (
            not os.path.exists(history_path) or os.path.getsize(history_path) == 0
        )

        with open(history_path, "a") as f:
            if file_needs_header:
                # Build header matching initial write: epoch, train metrics, then per-eval loss/acc and metrics
                header_parts = ["epoch", "train_loss", "train_acc", "train_all_correct"]
                for name in eval_results.keys():
                    header_parts.append(f"{name}_loss")
                    header_parts.append(f"{name}_acc")
                for name in eval_results.keys():
                    header_parts.append(f"{name}_perm_frob_cosine_similarity")
                for name in eval_results.keys():
                    header_parts.append(f"{name}_spectral_dist")
                for name in eval_results.keys():
                    header_parts.append(f"{name}_degree_corr")
                f.write(",".join(header_parts) + "\n")

            # Build row: epoch and train metrics
            row_parts = [
                f"{epoch+1}",
                f"{avg_loss:.4f}",
                f"{avg_accuracy:.4f}",
                f"{avg_all_correct:.4f}",
            ]
            # Append per-eval dataset loss and accuracy
            for results in eval_results.values():
                row_parts.append(f"{results.get('loss', 0.0):.4f}")
                row_parts.append(f"{results['accuracy']:.4f}")

            # Append permutation/spectral/degree metrics
            for results in eval_results.values():
                row_parts.append(f"{results.get('perm_frob_cosine_similarity', 0):.4f}")
            for results in eval_results.values():
                row_parts.append(f"{results.get('spectral_distance', 0):.4f}")
            for results in eval_results.values():
                row_parts.append(f"{results.get('degree_correlation', 0):.4f}")

            f.write(",".join(row_parts) + "\n")

        # Wandb logging
        wandb_log = {
            "epoch/epoch": epoch + 1,
            "epoch/train_loss": avg_loss,
            "epoch/train_accuracy": avg_accuracy,
            "epoch/train_all_correct": avg_all_correct,
            "epoch/learning_rate": scheduler.get_last_lr()[0],
            "epoch/global_step": global_step,
        }

        for name, results in eval_results.items():
            wandb_log.update(
                {
                    f"{name}/epoch": epoch + 1,
                    f"{name}/loss": results["loss"],
                    f"{name}/acc": results["accuracy"],
                    f"{name}/all_corr": results["all_correct"],
                    f"{name}/perm_frob_cosine_similarity": results.get(
                        "perm_frob_cosine_similarity", 0
                    ),
                    f"{name}/perm_variance": results.get("perm_variance", 0),
                    f"{name}/spectral_distance": results.get("spectral_distance", 0),
                    f"{name}/degree_correlation": results.get("degree_correlation", 0),
                    f"{name}/triangle_count_error": results.get(
                        "triangle_count_error", 0
                    ),
                    f"{name}/graph_similarity": results.get("graph_similarity", 0),
                }
            )

        wandb.log(wandb_log)

        # Save epoch-based checkpoint (if save_every is None, save at every epoch)
        if args.save_every is None:
            epoch_checkpoint_path = os.path.join(
                saved_models_dir, f"model_{epoch:03d}.pt"
            )
            # Handle DataParallel wrapper properly
            model_state = (
                model.model.module.state_dict()
                if isinstance(model.model, nn.DataParallel)
                else model.model.state_dict()
            )
            checkpoint_data = {
                "model_state_dict": model_state,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
            }
            torch.save(checkpoint_data, epoch_checkpoint_path)
            print(f"Saved epoch checkpoint: {epoch_checkpoint_path}")

        # Step the learning rate scheduler at the end of each epoch
        scheduler.step()

    wandb.finish()
    # save final model
    final_model_path = os.path.join(saved_models_dir, "final_model.pt")
    # Handle DataParallel wrapper properly
    model_state = (
        model.model.module.state_dict()
        if isinstance(model.model, nn.DataParallel)
        else model.model.state_dict()
    )
    torch.save(model_state, final_model_path)
    print(f"Final model saved to {final_model_path}")


# Script entry point: run the training pipeline only when this file is
# executed directly (e.g. `python train.py --model_type roberta ...`), not
# when it is imported as a module.
if __name__ == "__main__":
    main()
