#!/usr/bin/env python3
"""
Unified Evaluation Script for Transformer-Graph Models

This script provides a unified interface for evaluating any model type
(RoBERTa, Looped Transformer, Disentangled Transformer) on various graph datasets.

Usage:
    python eval.py --ckpt_path saved_models/roberta_relu_post_two_chains_5layers_n=16_seed=42
    python eval.py --ckpt_path saved_models/looped_transformer_linear_pre_erdos_renyi_3layers_n=32_seed=42 --num_vis_examples 20
"""

import argparse
import os
import json
import torch
from torch.utils.data import DataLoader
import matplotlib

matplotlib.use("Agg")  # Use non-interactive backend to save memory
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import networkx as nx
from matplotlib.colors import TwoSlopeNorm
from tqdm import tqdm
import glob
import re
import sys
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

# Global matplotlib settings applied to every figure this script creates.
plt.rcParams["figure.max_open_warning"] = 50  # raise the "too many open figures" warning threshold to 50
plt.rcParams["axes.grid"] = False  # no axes grid lines unless a plot turns them on explicitly

# Prepend this script's own directory to sys.path so the local modules imported
# below (models, data, utils) resolve even when run from another working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models import create_model
from data import create_dataset
from utils import compute_permutation_equivariant_metrics


def parse_arguments():
    """Define and parse the evaluation script's command-line interface.

    Options are declared in a single ordered table (flag, keyword arguments for
    ``ArgumentParser.add_argument``) so the ``--help`` listing follows the
    table order.

    Returns:
        argparse.Namespace with one attribute per option; only ``--ckpt_path``
        is required, every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a transformer model on graph data with unified interface."
    )

    option_specs = [
        # --- what to load ---
        (
            "--ckpt_path",
            dict(
                type=str,
                required=True,
                help="Path to load model checkpoints and config",
            ),
        ),
        # --- how much to evaluate / visualize ---
        (
            "--num_vis_examples",
            dict(
                type=int,
                default=10,
                help="Number of examples to visualize hidden states",
            ),
        ),
        (
            "--num_eval_examples",
            dict(
                type=int,
                default=1000,
                help="Number of examples to evaluate (also used in --iclr_only mode for finding worst graphs)",
            ),
        ),
        (
            "--config_path",
            dict(
                type=str,
                default=None,
                help="Path to config file, if different from checkpoint path",
            ),
        ),
        (
            "--ckpt_id",
            dict(
                type=int,
                default=None,
                help="Specific checkpoint ID to load (e.g., 50 for model_050.pt)",
            ),
        ),
        (
            "--dataset",
            dict(
                type=str,
                nargs="+",
                default=None,
                help="Dataset type(s) to evaluate on. If not specified, uses config datasets. "
                "Choices: erdos_renyi, two_chains, two_trees, two_stars, sbm, erdos_renyi_two_graphs, erdos_renyi_hard, two_variable_chains",
            ),
        ),
        (
            "--batch_size",
            dict(
                type=int,
                default=32,
                help="Batch size for evaluation",
            ),
        ),
        # --- ICLR analysis mode ---
        (
            "--iclr_only",
            dict(
                action="store_true",
                help="Special mode for ICLR: load model, analyze training dynamics and/or find worst performing graphs with equivariant consistency analysis",
            ),
        ),
        (
            "--iclr_plot_mode",
            dict(
                type=str,
                choices=["all", "behavior", "graphs"],
                default="all",
                help="Control what to plot in ICLR analysis: 'all' (both model behavior and graph analysis), 'behavior' (only model_behavior.pdf, requires --dataset), 'graphs' (only graph analysis using Erdos-Renyi graphs)",
            ),
        ),
        (
            "--eval_edge_prob",
            dict(
                type=float,
                default=None,
                help="Edge probability for Erdos-Renyi type datasets during evaluation. If not specified, uses config fixed_p value.",
            ),
        ),
        (
            "--verbose",
            dict(
                action="store_true",
                help="Enable verbose output for equivariance metrics computation (shows detailed progress per graph)",
            ),
        ),
        (
            "--recompute",
            dict(
                action="store_true",
                help="Force recomputation of training dynamics even if training_dynamics.json already exists",
            ),
        ),
        # --- plot styling ---
        (
            "--right_panel_mode",
            dict(
                type=str,
                choices=["per-layer", "last-layer-preds"],
                default="per-layer",
                help="Right panel style: 'per-layer' shows per-layer value cosine lines; 'last-layer-preds' shows mean last-layer predictions (sigmoid) across the dataset",
            ),
        ),
        (
            "--log_scale",
            dict(
                action="store_true",
                help="Use log scale on the x-axis for both left and right panels in model_behavior.pdf. If enabled and step 0 is present, it will be plotted at a small positive value.",
            ),
        ),
        (
            "--consistency_xlim",
            dict(
                type=float,
                default=0.5,
            ),
        ),
    ]

    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def load_config(ckpt_path, config_path=None):
    """Read the run configuration JSON.

    Args:
        ckpt_path: directory whose ``config.json`` is used when *config_path*
            is not given (or is falsy).
        config_path: explicit path to a JSON config file; takes precedence.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: if the resolved config path does not exist.
    """
    if config_path:
        config_file = config_path
    else:
        config_file = os.path.join(ckpt_path, "config.json")

    # Guard clause: fail early with a clear message when the file is missing.
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"Config file not found at {config_file}")

    with open(config_file, "r") as handle:
        loaded = json.load(handle)
    print(f"Configuration loaded from {config_file}")
    return loaded


def find_latest_checkpoint(ckpt_path, ckpt_id=None):
    """Locate the checkpoint file to load from *ckpt_path*.

    Lookup order:
      1. ``ckpt_id`` given: ``model_{ckpt_id:03d}.pt`` (must exist).
      2. Epoch checkpoints ``model_<N>.pt`` where N has at least three digits
         (so ``model_1000.pt`` and later are found too): the largest N wins.
      3. Step checkpoints ``model_step_<N>.pt``: the largest N wins.
      4. ``final_model.pt``.

    File names are matched with regular expressions against ``os.listdir``
    rather than with ``glob``. Directory names containing glob metacharacters
    therefore work. The returned path uses the real file name, so zero-padded
    step numbers such as ``model_step_0020.pt`` resolve correctly.

    Args:
        ckpt_path: directory holding the checkpoint files.
        ckpt_id: optional explicit epoch number to load.

    Returns:
        Path to the selected checkpoint file.

    Raises:
        FileNotFoundError: if the requested checkpoint, or any checkpoint at
            all, cannot be found.
    """
    if ckpt_id is not None:
        checkpoint_file = os.path.join(ckpt_path, f"model_{ckpt_id:03d}.pt")
        if os.path.exists(checkpoint_file):
            return checkpoint_file
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_file}")

    try:
        file_names = os.listdir(ckpt_path)
    except (FileNotFoundError, NotADirectoryError):
        # Treat a missing directory like an empty one; the final
        # FileNotFoundError below reports it.
        file_names = []

    def _latest(pattern):
        """Return the file name matching *pattern* with the largest number, or None."""
        best_name, best_number = None, -1
        for name in file_names:
            match = pattern.fullmatch(name)
            if match is not None and int(match.group(1)) > best_number:
                best_name, best_number = name, int(match.group(1))
        return best_name

    # Epoch checkpoints take precedence over step checkpoints.
    patterns = (
        re.compile(r"model_([0-9]{3,})\.pt"),
        re.compile(r"model_step_([0-9]+)\.pt"),
    )
    for pattern in patterns:
        latest = _latest(pattern)
        if latest is not None:
            return os.path.join(ckpt_path, latest)

    # Last resort: final_model.pt
    final_file = os.path.join(ckpt_path, "final_model.pt")
    if os.path.exists(final_file):
        return final_file
    raise FileNotFoundError(f"No checkpoint found in {ckpt_path}")


def get_dataset_params(config, dataset_type, eval_edge_prob=None):
    """Collect the dataset-specific keyword arguments for *dataset_type*.

    Args:
        config: run configuration dict. Depending on the dataset it reads
            ``num_nodes``, ``k``, ``fixed_p``, ``sample_p``, ``p_range``,
            ``p_intra``, ``p_inter``, ``num_communities``, ``connect_prob``
            and ``min_component_size_frac``.
        dataset_type: dataset name.
        eval_edge_prob: when not None, replaces ``config["fixed_p"]`` as the
            edge probability ``p`` of the Erdos-Renyi style datasets.

    Returns:
        dict of parameters; empty for dataset types not handled here.
    """

    def edge_probability():
        # An explicit evaluation edge probability overrides the config value.
        if eval_edge_prob is None:
            return config.get("fixed_p")
        print(f"Using eval_edge_prob={eval_edge_prob} for {dataset_type} dataset")
        return eval_edge_prob

    if dataset_type in ("two_chains", "two_trees", "two_stars"):
        return {"k": config.get("k", config["num_nodes"] // 2)}

    if dataset_type == "sbm":
        return {
            "p_intra": config.get("p_intra", 0.4),
            "p_inter": config.get("p_inter", 0.05),
            "num_communities": config.get("num_communities", 4),
        }

    if dataset_type == "erdos_renyi":
        return {
            "p": edge_probability(),
            "sample_p": config.get("sample_p", False),
            "p_range": config.get("p_range", (0.02, 0.2)),
        }

    if dataset_type == "erdos_renyi_two_graphs":
        return {
            "p": edge_probability(),
            "connect_prob": config.get("connect_prob", 0.5),
            "min_component_size_frac": config.get("min_component_size_frac", 0.3),
        }

    if dataset_type == "erdos_renyi_hard":
        return {
            "p": edge_probability(),
            "min_component_size_frac": config.get("min_component_size_frac", 0.3),
        }

    return {}


def get_eval_device(model_type: str) -> torch.device:
    """Choose the evaluation device for *model_type*.

    ``disentangled_transformer`` always runs on CPU. Every other model type
    uses CUDA when available and CPU otherwise. Only ``roberta`` prints a
    warning about the CPU fallback.
    """
    if model_type == "disentangled_transformer":
        return torch.device("cpu")

    if torch.cuda.is_available():
        return torch.device("cuda")

    if model_type == "roberta":
        print("Warning: CUDA not available; falling back to CPU for RoBERTa.")
    return torch.device("cpu")


def create_eval_datasets(config, dataset_types, num_samples, eval_edge_prob=None):
    """Build one evaluation dataset per requested dataset type.

    Args:
        config: run configuration dict (must provide ``num_nodes`` whenever
            *dataset_types* is non-empty).
        dataset_types: iterable of dataset names.
        num_samples: number of samples per dataset.
        eval_edge_prob: optional edge-probability override forwarded to
            ``get_dataset_params``.

    Returns:
        dict mapping each dataset name to the object returned by
        ``create_dataset``.
    """

    def build(dataset_type):
        params = get_dataset_params(config, dataset_type, eval_edge_prob)
        if dataset_type == "two_chains":
            # Evaluation always splits two_chains into halves, overriding any
            # config-provided "k" that get_dataset_params picked up.
            params["k"] = config["num_nodes"] // 2
        return create_dataset(
            dataset_type,
            num_samples=num_samples,
            num_nodes=config["num_nodes"],
            **params,
        )

    return {dataset_type: build(dataset_type) for dataset_type in dataset_types}


def compute_clustering_permutation(adj_matrix):
    """
    Compute a node ordering that groups nodes by connected component.

    Components (connectivity treated as undirected) are emitted largest first.
    Within a component, nodes are ordered by their row sum in ``adj_matrix``
    (degree), highest first. Ties keep the original relative order: components
    in order of their first node, nodes in index order.

    Args:
        adj_matrix: numpy array of shape (num_nodes, num_nodes)

    Returns:
        permutation: integer numpy array of length num_nodes; entry i is the
        original index of the node placed at position i. For a 0-node graph an
        empty integer array is returned, so it can still be used as an index.
    """
    adj_matrix = np.asarray(adj_matrix)
    if adj_matrix.shape[0] == 0:
        return np.zeros(0, dtype=int)

    _, labels = connected_components(csr_matrix(adj_matrix), directed=False)

    # Bucket nodes by component label; dict insertion order follows each
    # component's first node.
    components = {}
    for node_idx, component_id in enumerate(labels):
        components.setdefault(component_id, []).append(node_idx)

    # Row sums (degrees), computed once for all nodes.
    degrees = adj_matrix.sum(axis=1)

    permutation = []
    for component in sorted(components.values(), key=len, reverse=True):
        # sorted() stays stable with reverse=True, so equal-degree nodes keep index order.
        permutation.extend(
            sorted(component, key=lambda node: degrees[node], reverse=True)
        )

    return np.asarray(permutation, dtype=int)


def plot_hidden_states(
    adj_matrix, hidden_states, example_idx, save_dir, config, model=None, device=None
):
    """Plot hidden-state diagnostics and adjacency-power / prediction grids.

    The first graph of the batch is reordered with
    ``compute_clustering_permutation``, and the model is re-run on the permuted
    adjacency matrix. All figures therefore show the permuted node order.

    Outputs, under ``<save_dir>/example_<example_idx>/``:
      * ``layer_<i>.png`` for every layer returned by ``model.get_hidden_states``.
        Each shows H, its n x n Gram matrix, and (except for the disentangled
        transformer) the readout applied to H.
      * ``matrix_analysis.png``. The top row shows the boolean patterns of
        A^0 .. A^(m-1), with m = min(num_nodes, 10). The bottom row shows
        sigmoid predictions thresholded at m evenly spaced values in
        [0.05, 0.95].

    Args:
        adj_matrix: batched adjacency matrix (torch tensor or array-like);
            only element 0 is used.
        hidden_states: unused. Hidden states are recomputed on the permuted
            input; the parameter is kept for interface compatibility.
        example_idx: index used in the output directory name.
        save_dir: root directory for the figures.
        config: dict providing ``num_nodes`` and ``model_type``.
        model: wrapper exposing ``get_hidden_states``, ``forward`` and
            ``.model``. Required despite the ``None`` default.
        device: optional device the permuted inputs are moved to.

    Raises:
        ValueError: if *model* is None.
    """
    if model is None:
        raise ValueError(
            "plot_hidden_states requires a model: hidden states and predictions "
            "are recomputed on the permuted graph"
        )

    num_nodes = config["num_nodes"]
    model_type = config["model_type"]
    is_disentangled = model_type == "disentangled_transformer"

    # Adjacency matrix of the first graph in the batch, as numpy array
    if isinstance(adj_matrix, torch.Tensor):
        adj_np = adj_matrix[0].cpu().numpy()
    else:
        adj_np = adj_matrix[0]

    # Reorder nodes: components largest first, then by degree
    permutation = compute_clustering_permutation(adj_np)
    perm_matrix = np.zeros((num_nodes, num_nodes))
    perm_matrix[np.arange(num_nodes), permutation] = 1

    # P @ A @ P^T: row/column i of the result corresponds to node permutation[i]
    adj_permuted = perm_matrix @ adj_np @ perm_matrix.T

    adj_matrix_permuted = torch.tensor(adj_permuted, dtype=torch.float32).unsqueeze(0)
    if device is not None:
        adj_matrix_permuted = adj_matrix_permuted.to(device)

    with torch.no_grad():
        hidden_states_permuted = model.get_hidden_states(adj_matrix_permuted)

    print(f"Permutation for example {example_idx}: {permutation}")
    print(f"Original adjacency matrix shape: {adj_np.shape}")

    # Create the output directory once, before anything is saved. The
    # matrix-analysis figure below is written even if no layer is plotted.
    example_dir = os.path.join(save_dir, f"example_{example_idx}")
    os.makedirs(example_dir, exist_ok=True)

    def _heatmap(ax, data):
        # Create a fresh TwoSlopeNorm per image; sharing one norm instance
        # would couple the colour scales of different panels.
        if is_disentangled:
            return ax.imshow(data, cmap="viridis")
        return ax.imshow(data, cmap="RdBu_r", norm=TwoSlopeNorm(0))

    for layer_idx, states_perm in enumerate(hidden_states_permuted):
        # First (only) graph in the batch; presumably (num_nodes, hidden_dim),
        # which the Gram-matrix branch below checks.
        H = states_perm[0]
        if isinstance(H, torch.Tensor):
            H = H.cpu().numpy()

        # The disentangled transformer has no readout panel.
        if is_disentangled:
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        else:
            fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # 1. Hidden states H
        im1 = _heatmap(axes[0], H)
        axes[0].set_title(f"Layer {layer_idx} - H")
        axes[0].set_xlabel("Hidden Dimension")
        axes[0].set_ylabel("Nodes")
        fig.colorbar(im1, ax=axes[0])

        # 2. Gram matrix: pick the product that yields n x n
        if H.shape[0] == num_nodes:
            H_gram, gram_title = H @ H.T, "H H^T"
        else:
            H_gram, gram_title = H.T @ H, "H^T H"
        im2 = _heatmap(axes[1], H_gram)
        axes[1].set_title(f"Layer {layer_idx} - {gram_title}")
        axes[1].set_xlabel("Nodes")
        axes[1].set_ylabel("Nodes")
        fig.colorbar(im2, ax=axes[1])

        # 3. readout(H), only for non-disentangled models
        if not is_disentangled:
            with torch.no_grad():
                H_tensor = torch.tensor(H, dtype=torch.float32).unsqueeze(0)
                if device is not None:
                    H_tensor = H_tensor.to(device)

                if model_type in ["roberta", "looped_transformer"]:
                    if hasattr(model.model, "_read_out"):
                        final_output = model.model._read_out(H_tensor).cpu().numpy()[0]
                    elif hasattr(model.model, "read_out"):
                        final_output = model.model.read_out(H_tensor).cpu().numpy()[0]
                    else:
                        # No readout module exposed: show the full forward
                        # pass instead (this does not depend on the layer).
                        final_output = (
                            model.forward(adj_matrix_permuted).cpu().numpy()[0]
                        )
                else:
                    final_output = model.forward(adj_matrix_permuted).cpu().numpy()[0]

            im3 = axes[2].imshow(final_output, cmap="RdBu_r", norm=TwoSlopeNorm(0))
            axes[2].set_title(f"Layer {layer_idx} - readout(H)")
            axes[2].set_xlabel("Nodes")
            axes[2].set_ylabel("Nodes")
            fig.colorbar(im3, ax=axes[2])

        fig.savefig(
            os.path.join(example_dir, f"layer_{layer_idx}.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
        plt.close(fig)  # free memory before the next layer

    # Adjacency powers (top row) and thresholded predictions (bottom row)
    num_pows = min(num_nodes, 10)
    fig = plt.figure(figsize=(num_pows * 4, 8))

    for k in range(num_pows):
        # Non-zero pattern of A^k (A^0 = I); for a non-negative A this marks
        # node pairs joined by a walk of length exactly k.
        mat_pow = np.linalg.matrix_power(adj_permuted, k) > 0
        ax = fig.add_subplot(2, num_pows, k + 1)
        ax.imshow(mat_pow, cmap="viridis")
        ax.set_title(f"A^{k}", fontsize=10)
        ax.axis("off")

    with torch.no_grad():
        pred_connectivity_perm = torch.sigmoid(
            model.forward(adj_matrix_permuted)
        ).cpu()

    thresholds = np.linspace(0.05, 0.95, num_pows)
    for i, thresh in enumerate(thresholds):
        ax = fig.add_subplot(2, num_pows, num_pows + i + 1)
        pred_matrix_perm = (pred_connectivity_perm[0] > thresh).float().numpy()
        ax.imshow(pred_matrix_perm, cmap="viridis")
        # Two decimals: with ".1f", thresholds such as 0.05 and 0.15 both
        # rendered as "0.1".
        ax.set_title(f"Pred (t={thresh:.2f})", fontsize=10)
        ax.axis("off")

    fig.savefig(
        os.path.join(example_dir, "matrix_analysis.png"),
        bbox_inches="tight",
        pad_inches=0,
        dpi=300,
    )
    plt.close(fig)  # free memory


def find_graph_with_both_errors(
    model, dataloader, criterion, threshold=0.5, device=None
):
    """Find the graph whose mistakes are both numerous and mixed in kind.

    A graph qualifies only if it has at least one false positive (sigmoid
    probability above *threshold* where the target is 0) and at least one
    false negative (probability not above *threshold* where the target is 1).
    Qualifying graphs are scored as
    ``min(FP, FN) / max(FP, FN) * (FP + FN) * criterion(logits, target)``.
    The highest score is kept; the first graph wins on ties.

    Returns:
        None if no graph qualifies. Otherwise ``((adj, connectivity), logits,
        info)``: CPU tensors with a leading batch dimension of 1, and an
        ``info`` dict holding the error counts, loss and score components.
    """

    def _count_errors(probs, target):
        """Return (false_positives, false_negatives) for a single graph."""
        predicted = (probs > threshold).float()
        fp = ((predicted == 1) & (target == 0)).sum().item()
        fn = ((predicted == 0) & (target == 1)).sum().item()
        return fp, fn

    model.model.eval()
    best_candidate = None
    best_score = -1  # any qualifying score (>= 0 for a non-negative loss) beats this

    with torch.no_grad():
        for adj_batch, conn_batch in tqdm(
            dataloader, desc="Finding graph with both FP and FN"
        ):
            if device is None:
                device = next(model.model.parameters()).device
            adj_batch = adj_batch.float().to(device)
            conn_batch = conn_batch.float().to(device)

            logits = model.forward(adj_batch)
            probs = torch.sigmoid(logits)

            for i in range(adj_batch.shape[0]):
                window = slice(i, i + 1)
                adj_one = adj_batch[window]
                conn_one = conn_batch[window]
                logits_one = logits[window]

                fp, fn = _count_errors(probs[window], conn_one)
                if fp == 0 or fn == 0:
                    continue

                loss = criterion(logits_one, conn_one).item()
                balance = min(fp, fn) / max(fp, fn)
                total = fp + fn
                score = balance * total * loss

                if score > best_score:
                    best_score = score
                    best_candidate = (
                        (adj_one.cpu(), conn_one.cpu()),
                        logits_one.cpu(),
                        {
                            "false_positives": fp,
                            "false_negatives": fn,
                            "loss": loss,
                            "error_balance": balance,
                            "total_errors": total,
                        },
                    )

    return best_candidate


def find_worst_performing_graphs(model, dataloader, criterion, top_k=5, device=None):
    """Return the *top_k* graphs with the highest individual loss.

    Every graph in every batch is scored on its own with
    ``criterion(logits, target)``. After each insertion, a running list is
    stably sorted by loss (descending) and cut back to *top_k* entries, so
    among equal losses the graph seen first ranks higher.

    Returns:
        list of ``(loss, (adj, connectivity), logits)`` tuples, worst first,
        with CPU tensors that keep a leading batch dimension of 1.
    """
    model.model.eval()
    ranked = []

    with torch.no_grad():
        for adj_batch, conn_batch in tqdm(
            dataloader, desc=f"Finding top {top_k} worst graphs"
        ):
            if device is None:
                device = next(model.model.parameters()).device
            adj_batch = adj_batch.float().to(device)
            conn_batch = conn_batch.float().to(device)

            logits = model.forward(adj_batch)

            for idx in range(adj_batch.shape[0]):
                adj_one = adj_batch[idx : idx + 1]
                conn_one = conn_batch[idx : idx + 1]
                logits_one = logits[idx : idx + 1]

                loss_value = criterion(logits_one, conn_one).item()
                ranked.append(
                    (loss_value, (adj_one.cpu(), conn_one.cpu()), logits_one.cpu())
                )

                ranked.sort(key=lambda entry: entry[0], reverse=True)
                if len(ranked) > top_k:
                    del ranked[top_k:]

    return ranked


def create_component_permutation(adj_matrix):
    """
    Create permutation that orders nodes by component size (largest first).

    Delegates to ``compute_clustering_permutation``, which implements exactly
    this ordering: undirected connected components largest first, and nodes
    within each component by degree (row sum), highest first. Both call sites
    (hidden-state plots and ICLR graph figures) therefore always agree on the
    node order.

    Args:
        adj_matrix: numpy array of shape (num_nodes, num_nodes)

    Returns:
        permutation: numpy array of node indices in component-ordered arrangement
    """
    return compute_clustering_permutation(adj_matrix)


def visualize_iclr_graph(adj_matrix, connectivity_matrix, pred_matrix, save_path):
    """
    Create ICLR visualization: 2x2 plot with NetworkX graph, adjacency matrix,
    connectivity matrix, and prediction matrix, all ordered by components.
    """
    # Convert to numpy if needed
    if isinstance(adj_matrix, torch.Tensor):
        adj_np = adj_matrix.squeeze().cpu().numpy()
    else:
        adj_np = adj_matrix.squeeze()

    if isinstance(connectivity_matrix, torch.Tensor):
        conn_np = connectivity_matrix.squeeze().cpu().numpy()
    else:
        conn_np = connectivity_matrix.squeeze()

    if isinstance(pred_matrix, torch.Tensor):
        pred_np = torch.sigmoid(pred_matrix).squeeze().cpu().numpy()
    else:
        pred_np = pred_matrix.squeeze()

    # Convert predictions to binary (threshold at 0.5)
    pred_np = (pred_np > 0.5).astype(float)

    # Create component-based permutation
    permutation = create_component_permutation(adj_np)

    # Apply permutation to all matrices
    perm_matrix = np.zeros((len(permutation), len(permutation)))
    perm_matrix[np.arange(len(permutation)), permutation] = 1

    adj_permuted = perm_matrix @ adj_np @ perm_matrix.T
    conn_permuted = perm_matrix @ conn_np @ perm_matrix.T
    pred_permuted = perm_matrix @ pred_np @ perm_matrix.T

    # Create the 2x2 visualization with modern styling
    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    fig.patch.set_facecolor("#ffffff")

    # Define Anthropic-inspired modern color palette
    colors = {
        "primary": "#FF6B35",  # Vibrant orange
        "secondary": "#004E89",  # Deep blue
        "accent": "#00C896",  # Teal green
        "neutral": "#6B7280",  # Modern gray
        "background": "#F8FAFC",  # Light background
        "dark": "#1F2937",  # Dark text
        "prediction": "#da7756",  # Custom color for model predictions
    }

    # Top left: NetworkX visualization (remove self-loops)
    ax_graph = axes[0, 0]
    ax_graph.set_facecolor(colors["background"])

    # Remove self-loops for graph visualization
    adj_no_self_loops = adj_permuted.copy()
    np.fill_diagonal(adj_no_self_loops, 0)

    G = nx.from_numpy_array(adj_no_self_loops)

    # Debug information
    print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

    # Use uniform node colors for all nodes
    node_colors = [colors["primary"]] * len(permutation)
    # Use a more contrasting dark color for edges
    edge_colors = ["#1F2937"] * G.number_of_edges()  # Dark gray for better visibility

    # Create component-aware layout for better visualization
    # First, identify components in the permuted graph
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    sparse_adj = csr_matrix(adj_no_self_loops)
    n_components, component_labels = connected_components(sparse_adj, directed=False)

    # Group nodes by component, separating real components from isolated nodes
    components = {}
    isolated_nodes = []

    for node_idx, component_id in enumerate(component_labels):
        if component_id not in components:
            components[component_id] = []
        components[component_id].append(node_idx)

    # Separate real components (size > 1) from isolated nodes (size = 1)
    real_components = []
    for component in components.values():
        if len(component) > 1:
            real_components.append(component)
        else:
            isolated_nodes.extend(component)

    # Sort real components by size (largest first)
    sorted_real_components = sorted(real_components, key=len, reverse=True)

    # Create optimized component-aware layout with minimal white space
    pos = {}

    # Calculate component centers using intelligent packing
    import math

    def calculate_component_size(component):
        """Estimate the space needed for a component (only for real components)"""
        if len(component) <= 3:
            return 0.8  # Larger space for tiny components
        elif len(component) <= 6:
            return 1.2  # More space for medium components
        else:
            return 1.6  # Much larger space for big components

    def pack_components_efficiently(components):
        """Pack components in a compact arrangement"""
        if len(components) == 0:
            return []
        elif len(components) == 1:
            return [(0, 0)]
        elif len(components) == 2:
            # Place side by side with appropriate gap for larger components
            size1 = calculate_component_size(components[0])
            size2 = calculate_component_size(components[1])
            gap = 0.4  # Larger gap to accommodate bigger components
            return [(-size1 / 2 - gap / 2, 0), (size2 / 2 + gap / 2, 0)]
        elif len(components) <= 4:
            # Arrange in a 2x2 grid with more spacing for larger components
            positions = []
            sizes = [calculate_component_size(comp) for comp in components]
            max_size = max(sizes) if sizes else 0.8
            spacing = (
                max_size + 0.8
            )  # Increased spacing for better component separation
            grid_positions = [
                (-spacing / 2, spacing / 2),
                (spacing / 2, spacing / 2),
                (-spacing / 2, -spacing / 2),
                (spacing / 2, -spacing / 2),
            ]
            return grid_positions[: len(components)]
        else:
            # For many components, use a circle with better spacing for larger components
            positions = []
            max_size = max(calculate_component_size(comp) for comp in components)
            # Calculate radius to fit larger components without overlap
            radius = max(
                0.8, max_size * len(components) / (2 * math.pi) * 2.0
            )  # Increased multiplier for better spacing
            for i in range(len(components)):
                angle = 2 * math.pi * i / len(components)
                x = radius * math.cos(angle)
                y = radius * math.sin(angle)
                positions.append((x, y))
            return positions

    # Layout real components first
    component_centers = pack_components_efficiently(sorted_real_components)

    # Layout each real component separately with adaptive scaling
    for i, component in enumerate(sorted_real_components):
        # Multiple nodes - use spring layout within component
        subgraph = G.subgraph(component)
        # Use appropriate k value based on component size
        k_val = max(0.3, 0.8 / math.sqrt(len(component)))
        component_pos = nx.spring_layout(subgraph, seed=42, k=k_val, iterations=50)

        # Calculate adaptive scaling based on component size
        component_size = calculate_component_size(component)
        scale = component_size * 1.0  # Use full allocated space for larger components

        # Translate to component center
        center_x, center_y = component_centers[i]
        for node in component:
            x, y = component_pos[node]
            # Use adaptive scaling that uses allocated space efficiently
            pos[node] = (center_x + x * scale, center_y + y * scale)

    # Handle isolated nodes separately - place them around the periphery
    if isolated_nodes:
        # Find the bounding box of the existing components
        if pos:
            existing_x = [p[0] for p in pos.values()]
            existing_y = [p[1] for p in pos.values()]
            max_radius = (
                max(math.sqrt(x**2 + y**2) for x, y in pos.values()) if pos else 0
            )
        else:
            max_radius = 0

        # Place isolated nodes in a circle around the components
        isolation_radius = max_radius + 0.5  # Some distance from components
        for i, node in enumerate(isolated_nodes):
            if len(isolated_nodes) == 1:
                # Single isolated node - place it to the right
                pos[node] = (isolation_radius, 0)
            else:
                # Multiple isolated nodes - arrange in a circle
                angle = 2 * math.pi * i / len(isolated_nodes)
                x = isolation_radius * math.cos(angle)
                y = isolation_radius * math.sin(angle)
                pos[node] = (x, y)

    # Draw the graph with enhanced edge visibility
    nx.draw(
        G,
        pos,
        ax=ax_graph,
        node_color=node_colors,
        edge_color=edge_colors,
        node_size=300,  # Slightly smaller nodes to show edges better
        width=3.0,  # Thicker edges for better visibility
        with_labels=True,
        font_size=14,  # Increased font size for better readability
        font_weight="bold",
        font_color="white",
    )

    # Set consistent bounds and aspect ratio to match heatmaps
    if pos:
        all_x = [p[0] for p in pos.values()]
        all_y = [p[1] for p in pos.values()]
        margin = 0.3  # Small margin around the graph

        # Calculate bounds to maintain square aspect ratio like heatmaps
        x_range = max(all_x) - min(all_x)
        y_range = max(all_y) - min(all_y)
        max_range = max(x_range, y_range) + 2 * margin

        # Center the view and make it square
        center_x = (max(all_x) + min(all_x)) / 2
        center_y = (max(all_y) + min(all_y)) / 2

        ax_graph.set_xlim(center_x - max_range / 2, center_x + max_range / 2)
        ax_graph.set_ylim(center_y - max_range / 2, center_y + max_range / 2)
    else:
        # Default square bounds if no nodes
        ax_graph.set_xlim(-1, 1)
        ax_graph.set_ylim(-1, 1)

    # Remove axis labels and ticks for consistency with heatmaps
    ax_graph.set_xticks([])
    ax_graph.set_yticks([])
    ax_graph.set_aspect("equal")  # Make aspect ratio same as heatmaps

    # Set title with same parameters as other subplots for alignment
    ax_graph.set_title(
        "Graph Structure",
        fontsize=18,  # Increased font size for better readability
        fontweight="bold",
        color=colors["dark"],
        pad=20,
    )

    # Top right: Adjacency matrix with modern colormap
    ax_adj = axes[0, 1]
    # Create custom colormap from white to primary color
    adj_colors = ["#ffffff", colors["secondary"]]
    adj_cmap = LinearSegmentedColormap.from_list("adj", adj_colors)

    # Get number of nodes for tick labels
    num_nodes = adj_permuted.shape[0]

    im_adj = ax_adj.imshow(adj_permuted, cmap=adj_cmap, vmin=0, vmax=1)
    ax_adj.set_title(
        "Adjacency Matrix (Self-Loop Augmented)",
        fontsize=18,  # Increased font size for better readability
        fontweight="bold",
        color=colors["dark"],
        pad=20,
    )
    # Add node indices as ticks but remove axis labels (every 5 nodes)
    tick_indices = list(range(0, num_nodes, 5))
    ax_adj.set_xticks(tick_indices)
    ax_adj.set_yticks(tick_indices)
    ax_adj.set_xticklabels(tick_indices)
    ax_adj.set_yticklabels(tick_indices)
    ax_adj.tick_params(
        colors=colors["neutral"], labelsize=16
    )  # Increased font size for better readability
    # Make tick labels bold
    for label in ax_adj.get_xticklabels() + ax_adj.get_yticklabels():
        label.set_fontweight("bold")

    # Bottom left: Connectivity matrix (target) with modern colormap
    ax_conn = axes[1, 0]
    conn_colors = ["#ffffff", colors["accent"]]
    conn_cmap = LinearSegmentedColormap.from_list("conn", conn_colors)

    im_conn = ax_conn.imshow(conn_permuted, cmap=conn_cmap, vmin=0, vmax=1)
    ax_conn.set_title(
        "Target Connectivity Matrix",
        fontsize=18,  # Increased font size for better readability
        fontweight="bold",
        color=colors["dark"],
        pad=20,
    )
    # Add node indices as ticks but remove axis labels (every 5 nodes)
    ax_conn.set_xticks(tick_indices)
    ax_conn.set_yticks(tick_indices)
    ax_conn.set_xticklabels(tick_indices)
    ax_conn.set_yticklabels(tick_indices)
    ax_conn.tick_params(
        colors=colors["neutral"], labelsize=16
    )  # Increased font size for better readability
    # Make tick labels bold
    for label in ax_conn.get_xticklabels() + ax_conn.get_yticklabels():
        label.set_fontweight("bold")

    # Bottom right: Model prediction with modern colormap
    ax_pred = axes[1, 1]
    pred_colors = ["#ffffff", colors["prediction"]]
    pred_cmap = LinearSegmentedColormap.from_list("pred", pred_colors)

    im_pred = ax_pred.imshow(pred_permuted, cmap=pred_cmap, vmin=0, vmax=1)
    ax_pred.set_title(
        "Model Prediction",
        fontsize=18,  # Increased font size for better readability
        fontweight="bold",
        color=colors["dark"],
        pad=20,
    )
    # Add node indices as ticks but remove axis labels (every 5 nodes)
    ax_pred.set_xticks(tick_indices)
    ax_pred.set_yticks(tick_indices)
    ax_pred.set_xticklabels(tick_indices)
    ax_pred.set_yticklabels(tick_indices)
    ax_pred.tick_params(
        colors=colors["neutral"], labelsize=16
    )  # Increased font size for better readability
    # Make tick labels bold
    for label in ax_pred.get_xticklabels() + ax_pred.get_yticklabels():
        label.set_fontweight("bold")

    # Apply modern styling to all subplots
    for ax in axes.flat:
        ax.grid(False)
        ax.spines["top"].set_color(colors["neutral"])
        ax.spines["right"].set_color(colors["neutral"])
        ax.spines["bottom"].set_color(colors["neutral"])
        ax.spines["left"].set_color(colors["neutral"])
        ax.spines["top"].set_linewidth(0.5)
        ax.spines["right"].set_linewidth(0.5)
        ax.spines["bottom"].set_linewidth(0.5)
        ax.spines["left"].set_linewidth(0.5)

    plt.tight_layout(pad=3.0)
    plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)

    return permutation


def _extract_state_dict(checkpoint):
    """Return the weights stored in a loaded checkpoint object.

    Checkpoints may wrap the weights under ``"model_state_dict"`` or
    ``"state_dict"``; otherwise the loaded object itself is used as the
    state dict.
    """
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def _reuse_cached_dynamics(config, args, dynamics_path, diameter_dynamics_path):
    """Re-plot previously saved dynamics, or return None to request recomputation.

    Returns None when ``training_dynamics.json`` is missing, when
    ``args.recompute`` is set, or when no path-length (diameter) data is
    available from either the main file or ``diameter_dynamics.json``.
    Otherwise the loaded (and merged) dynamics dict is plotted and returned.
    """
    if not os.path.exists(dynamics_path) or args.recompute:
        return None

    print(f"Found existing training_dynamics.json. Loading from {dynamics_path}")
    with open(dynamics_path, "r") as f:
        dynamics_data = json.load(f)

    # Path-length data is saved in its own file; merge it back so the
    # plotting helpers receive a single dict.
    diameter_dynamics_loaded = False
    if os.path.exists(diameter_dynamics_path):
        print(
            f"Found existing diameter_dynamics.json. Loading from {diameter_dynamics_path}"
        )
        with open(diameter_dynamics_path, "r") as f:
            diameter_dynamics_data = json.load(f)
        dynamics_data["path_length_accuracies"] = diameter_dynamics_data.get(
            "path_length_accuracies", []
        )
        dynamics_data["max_diameter"] = diameter_dynamics_data.get("max_diameter", None)
        dynamics_data["num_layers"] = diameter_dynamics_data.get("num_layers", None)
        diameter_dynamics_loaded = True

    if not diameter_dynamics_loaded and not dynamics_data.get("path_length_accuracies"):
        print(
            "Path length accuracy data not found in existing dynamics. Recomputing..."
        )
        return None

    print("Use --recompute to force recomputation.")

    # num_layers is needed for the capacity line in the plots.
    if dynamics_data.get("num_layers") is None:
        dynamics_data["num_layers"] = config["num_layers"]

    dataset_type = args.dataset[0] if isinstance(args.dataset, list) else args.dataset
    create_modern_dynamics_plot(
        dynamics_data,
        args.ckpt_path,
        dataset_type,
        args.right_panel_mode,
        args.log_scale,
    )
    create_path_length_heatmap(dynamics_data, args.ckpt_path, dataset_type)
    return dynamics_data


def _discover_checkpoints(args):
    """Collect ``(step, path)`` pairs for every checkpoint to analyze, sorted by step.

    Step checkpoints (``model_step_<N>.pt``) are preferred. Three-digit epoch
    checkpoints (``model_<NNN>.pt``) are used only when no step checkpoints
    exist. If ``final_model.pt`` is present it is appended with an estimated
    step number. In ``--iclr_only`` mode, entries beyond 1e6 steps are dropped.
    Returns an empty list when no usable checkpoint file is found.
    """
    step_files = glob.glob(os.path.join(args.ckpt_path, "model_step_*.pt"))
    if step_files:
        print(f"Found {len(step_files)} step checkpoints")
    else:
        print("No step checkpoints found. Looking for epoch checkpoints...")
        step_files = glob.glob(
            os.path.join(args.ckpt_path, "model_[0-9][0-9][0-9].pt")
        )
        if not step_files:
            print("No checkpoints found for dynamics analysis")
            return []
        print(f"Found {len(step_files)} epoch checkpoints")

    checkpoint_data = []
    for file_path in step_files:
        filename = os.path.basename(file_path)
        pattern = r"model_step_(\d+)\.pt" if "step_" in filename else r"model_(\d+)\.pt"
        match = re.search(pattern, filename)
        if match is None:
            # The "*" glob also matches e.g. model_step_final.pt; such files
            # carry no step number, so skip them instead of crashing.
            print(f"Warning: could not parse step number from {filename}; skipping")
            continue
        checkpoint_data.append((int(match.group(1)), file_path))

    if not checkpoint_data:
        print("No checkpoints found for dynamics analysis")
        return []
    checkpoint_data.sort(key=lambda item: item[0])

    final_model_path = os.path.join(args.ckpt_path, "final_model.pt")
    if os.path.exists(final_model_path):
        print("Found final_model.pt")
        last_step = checkpoint_data[-1][0]
        if len(checkpoint_data) >= 2:
            # Assume the final model follows the last checkpoint at the usual spacing.
            step_increment = last_step - checkpoint_data[-2][0]
            estimated_final_step = last_step + step_increment
            print(
                f"Estimated final model step: {estimated_final_step} (based on increment of {step_increment})"
            )
        else:
            estimated_final_step = last_step * 2  # Simple heuristic
            print(
                f"Estimated final model step: {estimated_final_step} (doubled from single checkpoint)"
            )
        checkpoint_data.append((estimated_final_step, final_model_path))
    else:
        print("No final_model.pt found")

    # ICLR step limiting is applied after the final-model estimate so the
    # final model is filtered out too when it lies beyond the cut-off.
    if getattr(args, "iclr_only", False):
        original_count = len(checkpoint_data)
        checkpoint_data = [
            (step, path) for step, path in checkpoint_data if step <= 1e6
        ]
        num_filtered = original_count - len(checkpoint_data)
        if num_filtered > 0:
            print(
                f"ICLR mode: Filtered {num_filtered} checkpoints beyond 1e6 steps (including final model if applicable)"
            )

    print(
        f"Will analyze {len(checkpoint_data)} checkpoints (including final model if present)"
    )
    return checkpoint_data


def _build_model_params(config):
    """Translate a saved training config into ``create_model`` keyword arguments."""
    model_type = config["model_type"]
    if model_type == "disentangled_transformer":
        # This architecture uses its own parameter set, not the shared one.
        return {
            "num_nodes": config["num_nodes"],
            "heads": config["heads"],
            "init_type": config["init_type"],
            "readout_type": config["readout_type"],
        }

    model_params = {
        "num_nodes": config["num_nodes"],
        "hidden_size": config["hidden_size"],
        "num_layers": config["num_layers"],
        "num_attention_heads": config["num_attention_heads"],
    }
    if model_type == "roberta":
        model_params.update(
            {
                "roberta_type": config["roberta_type"],
                "layer_norm_type": config["layer_norm_type"],
                "attention_only": config.get("roberta_attention_only", False),
            }
        )
    elif model_type == "looped_transformer":
        model_params.update(
            {
                "read_in_method": config["read_in_method"],
                "layer_norm_type": config["layer_norm_type"],
                "tie_qk": config.get("tie_qk", False),
            }
        )
    return model_params


def _build_accuracy_dataloaders(config, num_samples):
    """Create one DataLoader per dataset family used for the accuracy/F1 curves.

    A family whose dataset cannot be created maps to None (with a warning), so
    the remaining families are still evaluated.
    """
    dataset_configs = {
        "erdos_renyi_train": {
            "type": "erdos_renyi",
            "params": get_dataset_params(config, "erdos_renyi", None),
        },
        "two_chains": {
            "type": "two_chains",
            "params": get_dataset_params(config, "two_chains", None),
        },
        "two_cliques": {
            "type": "two_cliques",
            "params": get_dataset_params(config, "two_cliques", None),
        },
    }

    accuracy_datasets = {}
    for dataset_name, dataset_config in dataset_configs.items():
        try:
            accuracy_dataset = create_dataset(
                dataset_config["type"],
                num_samples=num_samples,
                num_nodes=config["num_nodes"],
                **dataset_config["params"],
            )
            accuracy_datasets[dataset_name] = DataLoader(
                accuracy_dataset, batch_size=32, shuffle=False, drop_last=False
            )
            print(f"Created {dataset_name} dataset with {num_samples} samples")
        except Exception as e:
            print(f"Warning: Could not create {dataset_name} dataset: {e}")
            accuracy_datasets[dataset_name] = None
    return accuracy_datasets


def _append_checkpoint_metrics(
    dynamics_data,
    step,
    metrics,
    equivariance_metrics,
    accuracy_metrics,
    f1_metrics,
    max_diameter,
):
    """Append one checkpoint's metrics to the per-step lists in ``dynamics_data``."""
    dynamics_data["steps"].append(step)
    dynamics_data["value_frob_cosine"].append(
        equivariance_metrics["value_frob_cosine"]
    )
    dynamics_data["value_per_layer_cosine"].append(
        equivariance_metrics.get("per_layer_cosine", [])
    )
    dynamics_data["losses"].append(metrics["loss"])
    dynamics_data["mean_pred_prob"].append(metrics.get("mean_pred_prob", 0.0))
    for key in (
        "all_correct_accuracy",
        "edge_precision",
        "edge_recall",
        "edge_f1_score",
        "edge_accuracy",
    ):
        dynamics_data[key].append(metrics[key])
    for name in ("erdos_renyi_train", "two_chains", "two_cliques"):
        dynamics_data[f"{name}_accuracy"].append(accuracy_metrics[name])
        dynamics_data[f"{name}_f1"].append(f1_metrics[name])

    path_length_accuracies = metrics.get("path_length_accuracies")
    if path_length_accuracies is not None:
        dynamics_data["path_length_accuracies"].append(path_length_accuracies.tolist())
    else:
        # Keep one row per step so the path-length heatmap stays rectangular.
        dynamics_data["path_length_accuracies"].append([0.0] * (max_diameter + 1))


def _save_dynamics(dynamics_data, ckpt_path, dataset_type):
    """Write training_dynamics.json, with path-length data split into diameter_dynamics.json."""
    if dynamics_data.get("path_length_accuracies"):
        diameter_dynamics_data = {
            "steps": dynamics_data["steps"],
            "path_length_accuracies": dynamics_data["path_length_accuracies"],
            "max_diameter": dynamics_data.get("max_diameter", None),
            "num_layers": dynamics_data.get("num_layers", None),  # For capacity line
            "dataset_type": dataset_type,
            "ckpt_path": ckpt_path,
        }
        diameter_dynamics_path = os.path.join(ckpt_path, "diameter_dynamics.json")
        with open(diameter_dynamics_path, "w") as f:
            json.dump(diameter_dynamics_data, f, indent=2)
        print(f"Diameter dynamics data saved to: {diameter_dynamics_path}")

        # Keep path-length data out of the main dynamics file.
        main_dynamics = {
            key: value
            for key, value in dynamics_data.items()
            if key not in ("path_length_accuracies", "max_diameter")
        }
    else:
        main_dynamics = dynamics_data

    dynamics_path = os.path.join(ckpt_path, "training_dynamics.json")
    with open(dynamics_path, "w") as f:
        json.dump(main_dynamics, f, indent=2)
    print(f"Training dynamics data saved to: {dynamics_path}")


def analyze_training_dynamics(config, args):
    """Analyze equivariance, accuracy and path-length metrics across saved checkpoints.

    If ``training_dynamics.json`` (plus path-length data) already exists and
    ``args.recompute`` is not set, the cached data is re-plotted and returned.
    Otherwise, step 0 (``model_init.pt`` if present, else a seeded random
    init) and every discovered checkpoint are evaluated, the dynamics plots
    are drawn, and the results are written to ``training_dynamics.json`` and
    ``diameter_dynamics.json`` in ``args.ckpt_path``.

    Args:
        config: Training config dict loaded alongside the checkpoints.
        args: Parsed CLI namespace (reads ckpt_path, recompute, dataset,
            eval_edge_prob, num_eval_examples, verbose, right_panel_mode,
            log_scale, and optionally iclr_only).

    Returns:
        The dynamics dict, or None when no checkpoints are found.

    Raises:
        ValueError: If recomputation is needed but ``args.dataset`` is None.
    """
    print("Analyzing training dynamics across all checkpoints...")

    dynamics_path = os.path.join(args.ckpt_path, "training_dynamics.json")
    diameter_dynamics_path = os.path.join(args.ckpt_path, "diameter_dynamics.json")

    cached = _reuse_cached_dynamics(config, args, dynamics_path, diameter_dynamics_path)
    if cached is not None:
        return cached

    if os.path.exists(dynamics_path) and args.recompute:
        print(
            "Existing training_dynamics.json found but --recompute specified. Recomputing..."
        )

    checkpoint_data = _discover_checkpoints(args)
    if not checkpoint_data:
        return None

    model_params = _build_model_params(config)

    if args.dataset is None:
        raise ValueError(
            "Training dynamics analysis requires --dataset to be specified"
        )

    dataset_type = args.dataset[0] if isinstance(args.dataset, list) else args.dataset
    params = get_dataset_params(config, dataset_type, args.eval_edge_prob)
    if dataset_type == "two_chains":
        params["k"] = config["num_nodes"] // 2  # Ensure full chain length

    print("Creating evaluation datasets...")

    # Small dataset for loss / edge-level / path-length metrics.
    dynamics_examples = min(50, args.num_eval_examples)
    dataset = create_dataset(
        dataset_type,
        num_samples=dynamics_examples,
        num_nodes=config["num_nodes"],
        **params,
    )
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)

    # Equivariance dataset of the requested --dataset type, same small size.
    print(f"Creating {dataset_type} dataset for equivariance metrics...")
    equivariance_dataset = create_dataset(
        dataset_type,
        num_samples=dynamics_examples,
        num_nodes=config["num_nodes"],
        **dict(params),
    )
    equivariance_dataloader = DataLoader(
        equivariance_dataset, batch_size=32, shuffle=False, drop_last=False
    )
    print(
        f"Created {dataset_type} equivariance dataset with {dynamics_examples} samples"
    )

    accuracy_datasets = _build_accuracy_dataloaders(config, args.num_eval_examples)

    # Use the normalised dataset_type: args.dataset may be a list, so it
    # would never compare equal to "two_chains".
    if dataset_type == "two_chains":
        max_diameter = config["num_nodes"] // 2 - 1  # Chain graph diameter
    else:
        max_diameter = min(
            config["num_nodes"] - 1, 3 ** config["num_layers"] + 5
        )  # Reasonable default for most graph types

    dynamics_data = {
        "steps": [],
        "value_frob_cosine": [],
        "value_per_layer_cosine": [],
        "mean_pred_prob": [],
        "losses": [],
        "all_correct_accuracy": [],
        # Edge-level precision/recall metrics
        "edge_precision": [],
        "edge_recall": [],
        "edge_f1_score": [],
        "edge_accuracy": [],
        # Multiple accuracy datasets
        "erdos_renyi_train_accuracy": [],
        "two_chains_accuracy": [],
        "two_cliques_accuracy": [],
        # F1 scores for multiple datasets
        "erdos_renyi_train_f1": [],
        "two_chains_f1": [],
        "two_cliques_f1": [],
        # Path length (diameter) accuracies - one list per step
        "path_length_accuracies": [],
        "max_diameter": max_diameter,
        "num_layers": config["num_layers"],
    }

    # Step 0: model_init.pt if available, otherwise a (seeded) random init.
    init_checkpoint_path = os.path.join(args.ckpt_path, "model_init.pt")
    print("Preparing initialization (step 0) metrics...")
    has_init_checkpoint = os.path.exists(init_checkpoint_path)
    if not has_init_checkpoint and "seed" in config:
        # Seed BEFORE building the model; seeding afterwards would leave the
        # random initial weights unreproducible.
        torch.manual_seed(config["seed"])
        np.random.seed(config["seed"])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config["seed"])

    init_model = create_model(config["model_type"], **model_params)
    device = get_eval_device(config["model_type"])  # select device per rule
    init_model.model.to(device)
    init_model.model.eval()

    if has_init_checkpoint:
        try:
            init_checkpoint = torch.load(init_checkpoint_path, map_location="cpu")
            init_model.model.load_state_dict(_extract_state_dict(init_checkpoint))
            print(f"Loaded initialization checkpoint: {init_checkpoint_path}")
        except Exception as e:
            print(f"Warning: Failed to load model_init.pt ({e}). Using random init.")

    init_metrics = compute_checkpoint_metrics(
        init_model,
        dataloader,
        args.verbose,
        compute_path_length=True,
        max_diameter=max_diameter,
        device=device,
    )
    init_equivariance_metrics = compute_equivariance_metrics_only(
        init_model,
        equivariance_dataloader,
        args.verbose,
        device=device,
    )
    init_accuracy_metrics, init_f1_metrics = compute_multi_dataset_accuracy_and_f1(
        init_model, accuracy_datasets, device=device
    )
    _append_checkpoint_metrics(
        dynamics_data,
        0,
        init_metrics,
        init_equivariance_metrics,
        init_accuracy_metrics,
        init_f1_metrics,
        max_diameter,
    )
    print(
        f"Initialization (step 0) - Value Cosine: {init_equivariance_metrics['value_frob_cosine']:.4f}, "
        f"Precision: {init_metrics['edge_precision']:.4f}, Recall: {init_metrics['edge_recall']:.4f}, "
        f"F1: {init_metrics['edge_f1_score']:.4f}"
    )

    # Release the init model before loading the checkpoints.
    del init_model
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        torch.cuda.empty_cache()

    for step_num, checkpoint_path in tqdm(
        checkpoint_data, desc="Analyzing checkpoints"
    ):
        is_final_model = "final_model.pt" in checkpoint_path
        checkpoint_name = "final model" if is_final_model else f"step {step_num}"
        print(f"Analyzing checkpoint at {checkpoint_name}...")

        model = create_model(config["model_type"], **model_params)
        try:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.model.load_state_dict(_extract_state_dict(checkpoint))
            model.model.eval()
            device = get_eval_device(config["model_type"])  # select device per rule
            model.model.to(device)
        except Exception as e:
            print(f"Error loading checkpoint {checkpoint_path}: {e}")
            print(f"Skipping {checkpoint_name}")
            del model
            continue

        metrics = compute_checkpoint_metrics(
            model,
            dataloader,
            args.verbose,
            compute_path_length=True,
            max_diameter=max_diameter,
            device=device,
        )
        equivariance_metrics = compute_equivariance_metrics_only(
            model,
            equivariance_dataloader,
            args.verbose,
            device=device,
        )
        accuracy_metrics, f1_metrics = compute_multi_dataset_accuracy_and_f1(
            model, accuracy_datasets, device=device
        )
        _append_checkpoint_metrics(
            dynamics_data,
            step_num,
            metrics,
            equivariance_metrics,
            accuracy_metrics,
            f1_metrics,
            max_diameter,
        )
        print(
            f"{checkpoint_name} - Value Cosine: {equivariance_metrics['value_frob_cosine']:.4f}, "
            f"Loss: {metrics['loss']:.4f}, Precision: {metrics['edge_precision']:.4f}, "
            f"Recall: {metrics['edge_recall']:.4f}, F1: {metrics['edge_f1_score']:.4f}"
        )

        # Free the checkpoint's memory before loading the next one.
        del model
        if torch.cuda.is_available() and str(device).startswith("cuda"):
            torch.cuda.empty_cache()

    create_modern_dynamics_plot(
        dynamics_data,
        args.ckpt_path,
        dataset_type,
        args.right_panel_mode,
        args.log_scale,
    )
    create_path_length_heatmap(dynamics_data, args.ckpt_path, dataset_type)

    _save_dynamics(dynamics_data, args.ckpt_path, dataset_type)

    return dynamics_data


def compute_equivariance_metrics_only(model, dataloader, verbose=False, device=None):
    """Average the value-level permutation-equivariance cosine over a dataloader.

    Each graph in each batch is scored individually with
    ``compute_permutation_equivariant_metrics`` (128 permutations,
    ``type="value"``). Per-graph scores are then averaged.

    Args:
        model: Wrapper exposing ``.model`` (a torch module) and ``.forward``.
        dataloader: Yields ``(adj_matrix, connectivity_matrix)`` batches.
        verbose: Forwarded to the metric helper.
        device: Target device; when None, the device of the model's first
            parameter is used.

    Returns:
        dict with ``"value_frob_cosine"`` (mean over graphs, 0.0 if none) and
        ``"per_layer_cosine"`` (per-layer mean over graphs; graphs reporting
        fewer layers are NaN-padded and ignored for the missing layers).
    """
    model.model.eval()

    graph_scores = []
    graph_layer_scores = []

    with torch.no_grad():
        batches = tqdm(dataloader, desc="Computing equivariance metrics", leave=False)
        for adj_batch, conn_batch in batches:
            if device is None:
                device = next(model.model.parameters()).device
            adj_batch = adj_batch.float().to(device)
            conn_batch = conn_batch.float().to(device)

            # Batch forward pass; the metric helper itself derives per-layer
            # outputs from the model for each graph.
            pred_batch = model.forward(adj_batch)

            for idx in range(adj_batch.shape[0]):
                one_graph = slice(idx, idx + 1)  # keep the batch dimension
                result = compute_permutation_equivariant_metrics(
                    pred_batch[one_graph],
                    conn_batch[one_graph],
                    adj_batch[one_graph],
                    num_permutations=128,
                    type="value",
                    model=model,
                    verbose=verbose,
                )
                graph_scores.append(result["perm_frob_cosine_similarity"])
                graph_layer_scores.append(result.get("per_layer_cosine", []))

    value_mean = 0.0
    if graph_scores:
        value_mean = float(np.mean(graph_scores))

    # Ragged per-layer lists -> NaN-padded matrix -> column-wise nanmean.
    longest = max(map(len, graph_layer_scores), default=0)
    layer_means = []
    if longest > 0:
        padded = np.full((len(graph_layer_scores), longest), np.nan, dtype=float)
        for row, scores in zip(padded, graph_layer_scores):
            row[: len(scores)] = scores
        layer_means = np.nanmean(padded, axis=0).tolist()

    return {
        "value_frob_cosine": value_mean,
        "per_layer_cosine": layer_means,
    }


def compute_checkpoint_metrics(
    model,
    dataloader,
    verbose=False,
    compute_path_length=False,
    max_diameter=None,
    device=None,
):
    """Evaluate a checkpoint on ``dataloader`` and return aggregate metrics.

    Args:
        model: wrapper exposing ``model.model`` (a ``torch.nn.Module``, used for
            ``eval()`` and device discovery) and ``model.forward(adj)`` returning
            connectivity logits shaped like the target matrices.
        dataloader: iterable of ``(adj_matrix, connectivity_matrix)`` batches.
        verbose: unused by this function; kept for interface compatibility.
        compute_path_length: when True and ``max_diameter`` is given, also compute
            per-path-length accuracies via ``utils.compute_path_length_accuracy``.
        max_diameter: largest path length passed to the path-length metrics.
        device: device for the inputs; defaults to the device of the model's
            first parameter.

    Returns:
        dict with:
          * ``loss``: BCE-with-logits averaged over every matrix entry in the
            dataset (0.0 for an empty dataloader),
          * ``all_correct_accuracy``: fraction of graphs predicted exactly,
          * ``edge_precision``, ``edge_recall``, ``edge_f1_score``,
            ``edge_accuracy``: entry-level metrics with logits thresholded at 0,
          * ``mean_pred_prob``: mean sigmoid of the logits over all entries,
          * ``path_length_accuracies`` / ``path_length_counts`` (numpy arrays or
            None) when path-length metrics were requested and computed.
    """
    model.model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()

    def _ratio(numerator, denominator):
        # Division that reports 0.0 instead of raising when nothing was counted.
        return numerator / denominator if denominator > 0 else 0.0

    track_path_lengths = compute_path_length and max_diameter is not None
    if track_path_lengths:
        from utils import compute_path_length_accuracy, aggregate_path_length_accuracies

        path_length_accuracies = []
        path_length_counts = []

    total_loss = 0.0  # sum of per-entry BCE over the whole dataset
    completely_correct_pred = 0
    total_graphs = 0

    # Entry-level confusion counts (targets equal to 1 / 0 only).
    total_true_positives = 0
    total_false_positives = 0
    total_false_negatives = 0
    total_true_negatives = 0

    # Running sum of sigmoid(logits) for the mean last-layer prediction.
    sum_probs = 0.0
    total_elems = 0

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in tqdm(
            dataloader, desc="Computing metrics", leave=False
        ):
            if device is None:
                device = next(model.model.parameters()).device
            adj_matrix = adj_matrix.float().to(device)
            connectivity_matrix = connectivity_matrix.float().to(device)

            pred_connectivity = model.forward(adj_matrix)
            num_elems = connectivity_matrix.numel()

            # The criterion returns this batch's mean; re-weight by the batch's
            # entry count so a smaller final batch is not over-weighted.
            loss = criterion(pred_connectivity, connectivity_matrix)
            total_loss += loss.item() * num_elems

            sum_probs += torch.sigmoid(pred_connectivity).sum().item()
            total_elems += num_elems

            # Threshold logits at 0 (i.e. probability 0.5).
            pred = (pred_connectivity > 0.0).float()

            # Exact-match accuracy: one reduction per batch instead of a host
            # sync per graph.
            graph_matches = (pred == connectivity_matrix).flatten(start_dim=1).all(dim=1)
            completely_correct_pred += int(graph_matches.sum().item())
            total_graphs += pred.shape[0]

            pred_pos = pred == 1
            pred_neg = pred == 0
            true_pos = connectivity_matrix == 1
            true_neg = connectivity_matrix == 0
            total_true_positives += int((pred_pos & true_pos).sum().item())
            total_false_positives += int((pred_pos & true_neg).sum().item())
            total_false_negatives += int((pred_neg & true_pos).sum().item())
            total_true_negatives += int((pred_neg & true_neg).sum().item())

            if track_path_lengths:
                batch_path_acc, batch_path_counts = compute_path_length_accuracy(
                    pred_connectivity,
                    connectivity_matrix,
                    adj_matrix,
                    max_diameter,
                    acc_threshold=0.0,
                )
                path_length_accuracies.append(batch_path_acc.cpu())
                path_length_counts.append(batch_path_counts.cpu())

    avg_loss = _ratio(total_loss, total_elems)
    avg_all_correct = completely_correct_pred / total_graphs if total_graphs > 0 else 0

    precision = _ratio(
        total_true_positives, total_true_positives + total_false_positives
    )
    recall = _ratio(
        total_true_positives, total_true_positives + total_false_negatives
    )
    f1_score = _ratio(2 * (precision * recall), precision + recall)
    edge_accuracy = _ratio(
        total_true_positives + total_true_negatives,
        total_true_positives
        + total_false_positives
        + total_false_negatives
        + total_true_negatives,
    )

    result = {
        "loss": avg_loss,
        "all_correct_accuracy": avg_all_correct,
        "edge_precision": precision,
        "edge_recall": recall,
        "edge_f1_score": f1_score,
        "edge_accuracy": edge_accuracy,
        "mean_pred_prob": _ratio(sum_probs, total_elems),
    }

    if track_path_lengths and path_length_accuracies:
        aggregated_accuracies, total_counts = aggregate_path_length_accuracies(
            path_length_accuracies, path_length_counts
        )
        result["path_length_accuracies"] = (
            aggregated_accuracies.numpy() if aggregated_accuracies is not None else None
        )
        result["path_length_counts"] = (
            total_counts.numpy() if total_counts is not None else None
        )

    return result


def compute_accuracy_and_f1(model, dataloader, device=None):
    """Return ``(exact_match_accuracy, edge_f1)`` for ``model`` on ``dataloader``.

    Logits from ``model.forward(adj)`` are thresholded at 0. A graph counts as
    correct only if every entry matches the target. F1 is computed from
    entry-level counts pooled over the whole dataset; targets equal to 1 are
    positives and targets equal to 0 are negatives.

    Args:
        model: wrapper exposing ``model.model`` (a ``torch.nn.Module``) and
            ``model.forward(adj)``.
        dataloader: iterable of ``(adj_matrix, connectivity_matrix)`` batches.
        device: device for the inputs; defaults to the device of the model's
            first parameter.

    Returns:
        ``(accuracy, f1_score)``. Accuracy is ``0`` for an empty dataloader;
        F1 is ``0.0`` when precision and recall are both zero.
    """
    model.model.eval()
    completely_correct_pred = 0
    total_graphs = 0

    total_true_positives = 0
    total_false_positives = 0
    total_false_negatives = 0

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in dataloader:
            if device is None:
                device = next(model.model.parameters()).device
            adj_matrix = adj_matrix.float().to(device)
            connectivity_matrix = connectivity_matrix.float().to(device)

            pred = (model.forward(adj_matrix) > 0.0).float()

            # Batch-level reductions: a handful of host syncs per batch instead
            # of several per graph.
            graph_matches = (pred == connectivity_matrix).flatten(start_dim=1).all(dim=1)
            completely_correct_pred += int(graph_matches.sum().item())
            total_graphs += pred.shape[0]

            pred_pos = pred == 1
            true_pos = connectivity_matrix == 1
            total_true_positives += int((pred_pos & true_pos).sum().item())
            total_false_positives += int(
                (pred_pos & (connectivity_matrix == 0)).sum().item()
            )
            total_false_negatives += int(((pred == 0) & true_pos).sum().item())

    accuracy = completely_correct_pred / total_graphs if total_graphs > 0 else 0

    predicted_positive = total_true_positives + total_false_positives
    actual_positive = total_true_positives + total_false_negatives
    precision = total_true_positives / predicted_positive if predicted_positive > 0 else 0.0
    recall = total_true_positives / actual_positive if actual_positive > 0 else 0.0
    f1_score = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    return accuracy, f1_score


def compute_accuracy_only(model, dataloader, device=None):
    """Return the fraction of graphs whose thresholded prediction matches exactly.

    Logits from ``model.forward(adj)`` are thresholded at 0, and a graph counts
    as correct only if every entry equals the target.

    Args:
        model: wrapper exposing ``model.model`` (a ``torch.nn.Module``) and
            ``model.forward(adj)``.
        dataloader: iterable of ``(adj_matrix, connectivity_matrix)`` batches.
        device: device for the inputs; defaults to the device of the model's
            first parameter.

    Returns:
        Exact-match accuracy, or ``0`` for an empty dataloader.
    """
    model.model.eval()
    completely_correct_pred = 0
    total_graphs = 0

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in dataloader:
            if device is None:
                device = next(model.model.parameters()).device
            adj_matrix = adj_matrix.float().to(device)
            connectivity_matrix = connectivity_matrix.float().to(device)

            pred = (model.forward(adj_matrix) > 0.0).float()

            # One reduction per batch instead of one host sync per graph.
            graph_matches = (pred == connectivity_matrix).flatten(start_dim=1).all(dim=1)
            completely_correct_pred += int(graph_matches.sum().item())
            total_graphs += pred.shape[0]

    return completely_correct_pred / total_graphs if total_graphs > 0 else 0


def compute_multi_dataset_accuracy_and_f1(model, accuracy_datasets, device=None):
    """Evaluate exact-match accuracy and edge F1 on several named dataloaders.

    Args:
        model: model wrapper forwarded to ``compute_accuracy_and_f1``.
        accuracy_datasets: mapping of dataset name -> dataloader (or ``None``).
        device: optional device forwarded to ``compute_accuracy_and_f1``.

    Returns:
        ``(accuracy_by_name, f1_by_name)``, two dicts keyed like
        ``accuracy_datasets``. A ``None`` dataloader scores 0.0 for both. So does
        a dataloader whose evaluation raises; in that case a warning is printed
        instead of the exception propagating.
    """
    accuracy_by_name = {}
    f1_by_name = {}

    for name, loader in accuracy_datasets.items():
        acc_value, f1_value = 0.0, 0.0
        if loader is not None:
            try:
                acc_value, f1_value = compute_accuracy_and_f1(
                    model, loader, device=device
                )
            except Exception as err:
                # Best-effort: report and keep the 0.0 defaults.
                print(f"Warning: Could not compute metrics for {name}: {err}")
        accuracy_by_name[name] = acc_value
        f1_by_name[name] = f1_value

    return accuracy_by_name, f1_by_name


def compute_multi_dataset_accuracy(model, accuracy_datasets, device=None):
    """Evaluate exact-match accuracy on several named dataloaders.

    Args:
        model: model wrapper forwarded to ``compute_accuracy_only``.
        accuracy_datasets: mapping of dataset name -> dataloader (or ``None``).
        device: optional device forwarded to ``compute_accuracy_only``.

    Returns:
        dict keyed like ``accuracy_datasets``. A ``None`` dataloader scores 0.0.
        So does a dataloader whose evaluation raises; in that case a warning is
        printed instead of the exception propagating.
    """
    scores = {}

    for name, loader in accuracy_datasets.items():
        score = 0.0
        if loader is not None:
            try:
                score = compute_accuracy_only(model, loader, device=device)
            except Exception as err:
                # Best-effort: report and keep the 0.0 default.
                print(f"Warning: Could not compute accuracy for {name}: {err}")
        scores[name] = score

    return scores


def _export_value_cosine_csv(save_dir, steps, layer_mat, layer_indices):
    """Write plotted per-layer cosine values to ``value_cosine_per_layer.csv``.

    One row per step, with columns ``step`` and ``layer_<i>`` for each index in
    ``layer_indices``. NaN entries (padding for shorter per-layer lists) are
    written as empty cells. Failures print a warning instead of raising.
    """
    import csv

    try:
        csv_path = os.path.join(save_dir, "value_cosine_per_layer.csv")
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["step"] + [f"layer_{i}" for i in layer_indices])
            for idx, step in enumerate(steps):
                row_values = [
                    None if np.isnan(layer_mat[idx, j]) else float(layer_mat[idx, j])
                    for j in layer_indices
                ]
                writer.writerow([int(step)] + row_values)
        print(f"Per-layer value cosine CSV saved to: {csv_path}")
    except Exception as e:
        print(f"Warning: could not save per-layer CSV: {e}")


def _label_zero_step_tick(ax, epsilon_value):
    """Label the epsilon stand-in for step 0 as "0" on a log-scaled x-axis.

    Adds a tick at ``epsilon_value`` if the current ticks lack one. That tick is
    labelled "0"; every other tick becomes a bold ``10^k`` (one decimal when
    ``k`` is not an integer). Tick locations become fixed, and ``set_xticks``
    may widen the view to include out-of-range ticks, so callers should
    re-apply their x-limits afterwards.
    """
    ticks = list(ax.get_xticks())
    if not any(np.isclose(t, epsilon_value, rtol=1e-6, atol=1e-12) for t in ticks):
        ticks.append(epsilon_value)
    ticks = sorted(ticks)
    ax.set_xticks(ticks)

    labels = []
    for t in ticks:
        if np.isclose(t, epsilon_value, rtol=1e-6, atol=1e-12):
            labels.append("0")
            continue
        exp = np.log10(t) if t > 0 else 0.0
        if np.isclose(exp, round(exp)):
            labels.append(f"$\\mathbf{{10^{{{int(round(exp))}}}}}$")
        else:
            labels.append(f"$\\mathbf{{10^{{{exp:.1f}}}}}$")
    for txt in ax.set_xticklabels(labels):
        txt.set_fontweight("bold")


def create_modern_dynamics_plot(
    dynamics_data,
    save_dir,
    dataset_type,
    right_panel_mode="per-layer",
    log_scale=False,
    x_max=1.02e6,
):
    """Create ``model_behavior.pdf`` with two side-by-side panels.

    The left panel plots exact-match accuracy against training step for each of
    ``erdos_renyi_train_accuracy``, ``two_chains_accuracy`` and
    ``two_cliques_accuracy`` present in ``dynamics_data``. The right panel
    depends on ``right_panel_mode``:

    * ``"per-layer"``: one line per layer index >= 1 of
      ``dynamics_data["value_per_layer_cosine"]`` (layer 0 is dropped). The
      plotted values are also exported to ``value_cosine_per_layer.csv`` in
      ``save_dir``.
    * anything else (``"last-layer-preds"``): the ``mean_pred_prob`` series.

    Right-panel series are skipped unless they have one entry per step.

    Args:
        dynamics_data: dict with a ``"steps"`` list plus the optional series above.
        save_dir: directory for the PDF (and, in per-layer mode, the CSV).
        dataset_type: currently unused; kept for backward compatibility.
        right_panel_mode: ``"per-layer"`` or ``"last-layer-preds"``.
        log_scale: use a log x-axis on both panels. Steps equal to 0 are drawn at
            a small positive epsilon whose tick is labelled ``"0"``.
        x_max: right x-limit shared by both panels (default 1.02e6).

    Returns:
        Path of the saved PDF.
    """
    neutral_color = "#6b7280"

    def _bold_legend(legend):
        for txt in legend.get_texts():
            txt.set_fontweight("bold")

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16, 8))
    fig.patch.set_facecolor("white")

    steps = (
        np.array(dynamics_data["steps"])
        if len(dynamics_data["steps"]) > 0
        else np.array([1])
    )
    # A log axis cannot show step 0, so draw it at a small positive epsilon.
    # This only affects plot positions; the CSV keeps the real step values.
    plot_steps = steps.copy()
    epsilon_value = None
    if log_scale and np.any(plot_steps == 0):
        positive_steps = plot_steps[plot_steps > 0]
        min_pos = np.min(positive_steps) if positive_steps.size else 1.0
        epsilon_value = max(min_pos / 10.0, 1e-2)
        plot_steps = np.where(plot_steps == 0, epsilon_value, plot_steps)

    # ---------------- Left panel: exact-match accuracy ----------------
    left_series = (
        ("erdos_renyi_train_accuracy", "#2563eb", "o", "Erdos-Renyi (train)"),
        ("two_chains_accuracy", "#16a34a", "s", "Two Chains"),
        ("two_cliques_accuracy", "#f59e0b", "^", "Two Cliques"),
    )
    for key, color, marker, label in left_series:
        if key in dynamics_data:
            ax_left.plot(
                plot_steps,
                dynamics_data[key],
                color=color,
                linewidth=3.0,
                marker=marker,
                markersize=5,
                label=label,
            )
    ax_left.set_title("Exact Match Accuracy", fontsize=20, fontweight="bold")
    ax_left.set_xlabel("Training Step", fontsize=16, fontweight="bold")
    ax_left.set_ylabel("Accuracy", fontsize=16, fontweight="bold")
    if log_scale:
        ax_left.set_xscale("log")
    ax_left.set_facecolor("white")
    ax_left.set_ylim(-0.02, 1.02)
    ax_left.grid(True, linestyle="--", alpha=0.3)
    _bold_legend(ax_left.legend(frameon=False, fontsize=18, loc="upper left", ncol=1))

    # ---------------- Right panel ----------------
    layer_mat = None  # set only when per-layer data is plotted
    if right_panel_mode == "per-layer":
        per_layer_series = dynamics_data.get("value_per_layer_cosine", [])
        max_layers = max((len(x) for x in per_layer_series), default=0)
        if max_layers > 0 and len(per_layer_series) == len(steps):
            # Pad shorter per-step lists with NaN so every layer has a column.
            layer_mat = np.full((len(steps), max_layers), np.nan, dtype=float)
            for t, layer_vals in enumerate(per_layer_series):
                layer_mat[t, : len(layer_vals)] = layer_vals

            palette = [
                "#2563eb", "#16a34a", "#f59e0b", "#ef4444", "#8b5cf6",
                "#06b6d4", "#f472b6", "#10b981", "#f97316", "#94a3b8",
                "#22c55e", "#0ea5e9", "#e11d48", "#7c3aed", "#14b8a6",
                "#a855f7", "#60a5fa", "#f43f5e", "#6366f1", "#eab308",
            ]
            # Layer 0 is dropped; plot layers starting from index 1.
            plotted_indices = list(range(1, max_layers))
            for layer_idx in plotted_indices:
                ax_right.plot(
                    plot_steps,
                    layer_mat[:, layer_idx],
                    color=palette[layer_idx % len(palette)],
                    linewidth=3.5,
                    alpha=0.95,
                    marker="o",
                    markersize=5,
                    label=f"Layer {layer_idx}",
                )

            _export_value_cosine_csv(save_dir, steps, layer_mat, plotted_indices)

            ax_right.set_title(
                "Model Equivariance Consistency",
                fontsize=20,
                fontweight="bold",
            )
            if plotted_indices:
                _bold_legend(
                    ax_right.legend(
                        frameon=False,
                        fontsize=18,
                        ncol=1,
                        loc="lower right",
                    )
                )
    else:  # last-layer-preds
        mean_probs = dynamics_data.get("mean_pred_prob", [])
        if len(mean_probs) == len(steps) and len(mean_probs) > 0:
            ax_right.plot(
                plot_steps,
                mean_probs,
                color="#2563eb",
                linewidth=3.5,
                marker="o",
                markersize=6,
                label="Mean Sigmoid Prediction",
            )
        # NOTE(review): this title is shared with per-layer mode; confirm it is
        # intended for the mean-prediction panel.
        ax_right.set_title(
            "Model Equivariance Consistency",
            fontsize=20,
            fontweight="bold",
        )

    ax_right.set_xlabel("Training Step", fontsize=16, fontweight="bold")
    if right_panel_mode == "per-layer":
        ax_right.set_ylabel("Cosine Similarity", fontsize=16, fontweight="bold")
    else:
        ax_right.set_ylabel("Mean Sigmoid Prediction", fontsize=16, fontweight="bold")
    if log_scale:
        ax_right.set_xscale("log")
    ax_right.set_facecolor("white")
    ax_right.grid(True, linestyle="--", alpha=0.3)

    # y-range [min(0.5, lowest plotted value - 0.02), 1.02]. NaN padding is
    # ignored, so it cannot collapse the lower bound to 0.5 and clip real data.
    if layer_mat is not None and layer_mat.shape[1] > 1:
        plotted_values = layer_mat[:, 1:]
        finite_values = plotted_values[np.isfinite(plotted_values)]
        if finite_values.size:
            ax_right.set_ylim(min(0.5, float(finite_values.min()) - 0.02), 1.02)

    # ---------------- Shared x-limits (and "0" tick on log axes) ----------------
    if not log_scale:
        x_left = 0
    elif epsilon_value is not None:
        x_left = epsilon_value
    elif np.any(steps > 0):
        x_left = np.min(steps[steps > 0]) * 0.9
    else:
        x_left = 1e-2

    for ax in (ax_left, ax_right):
        ax.set_xlim(left=x_left, right=x_max)
    if epsilon_value is not None:  # only set when log_scale is on
        for ax in (ax_left, ax_right):
            _label_zero_step_tick(ax, epsilon_value)
            # set_xticks may have widened the view; restore the intended range.
            ax.set_xlim(left=x_left, right=x_max)

    # Legend only for last-layer-preds mode
    if right_panel_mode == "last-layer-preds":
        _bold_legend(
            ax_right.legend(frameon=True, ncol=1, loc="lower right", fontsize=24)
        )

    # Common styling
    for ax in (ax_left, ax_right):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_color(neutral_color)
        ax.spines["bottom"].set_color(neutral_color)
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)
        ax.tick_params(colors=neutral_color, labelsize=14)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontweight("bold")

    fig.tight_layout(pad=0.8)
    save_path = os.path.join(save_dir, "model_behavior.pdf")
    fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"Training dynamics plot saved to: {save_path}")
    return save_path


def _reliable_path_length_thresholds(path_length_data, max_diameter, threshold_accuracy):
    """Per step, return the largest reliably predicted path length, or -1.

    For each row of ``path_length_data`` (one row per step, column ``j`` =
    accuracy for path length ``j``), the result is the largest
    ``k <= max_diameter`` such that every path length ``1..k`` has accuracy
    ``>= threshold_accuracy``. Path length 0 only needs accuracy ``> 0.1``; rows
    failing that yield -1. NaN accuracies count as failures. ``k`` never
    exceeds the last available column.
    """
    thresholds = np.full(path_length_data.shape[0], -1, dtype=int)
    for step_idx, step_accuracies in enumerate(
        path_length_data[:, 0 : max_diameter + 1]
    ):
        if not step_accuracies[0] > 0.1:
            continue
        # Indices (into lengths 1..) of the first path length below threshold.
        failing = np.flatnonzero(~(step_accuracies[1:] >= threshold_accuracy))
        thresholds[step_idx] = failing[0] if failing.size else step_accuracies.size - 1
    return thresholds


def create_path_length_heatmap(dynamics_data, save_dir, dataset_type):
    """Plot path-length (diameter) accuracy against training step.

    Writes two PDFs into ``save_dir``:

    * ``diameter_dynamics_heatmap.pdf``: accuracy for each (path length, step).
    * ``diameter_dynamics_line.pdf``: per step, the largest path length ``k``
      such that path lengths ``1..k`` all reach accuracy >= 0.99 (and path
      length 0 exceeds 0.1). Steps with no such ``k`` are marked with a cross
      at 0.

    Args:
        dynamics_data: dict with ``"steps"`` and ``"path_length_accuracies"``
            (one row per step, one column per path length starting at 0).
            Optional keys:
              * ``"max_diameter"``: largest path length shown; clamped to the
                available columns.
              * ``"num_layers"``: draws a ``3 ** num_layers`` reference line.
        save_dir: output directory.
        dataset_type: string used in the heatmap title.

    Returns:
        ``(heatmap_save_path, line_save_path)``, or ``None`` when no path-length
        data is available.
    """
    path_length_rows = dynamics_data.get("path_length_accuracies")
    if path_length_rows is None or len(path_length_rows) == 0:
        print("No path length accuracy data found, skipping heatmap creation")
        return None

    steps = np.array(dynamics_data["steps"])
    path_length_data = np.array(path_length_rows)  # (num_steps, num_path_lengths)
    available_diameter = path_length_data.shape[1] - 1
    max_diameter = dynamics_data.get("max_diameter")
    # Clamp so the y-ticks match the heatmap rows. This also stops the threshold
    # search from vacuously "passing" path lengths that were never measured.
    if max_diameter is None or max_diameter > available_diameter:
        max_diameter = available_diameter

    # ===== HEATMAP FIGURE =====
    fig1, ax1 = plt.subplots(1, 1, figsize=(12, 8))
    fig1.patch.set_facecolor("white")

    # Path lengths on the y-axis, steps on the x-axis:
    # shape (max_diameter + 1, num_steps).
    heatmap_data = path_length_data[:, 0 : max_diameter + 1].T

    im = ax1.imshow(
        heatmap_data,
        cmap="magma",
        aspect="auto",
        vmin=0,
        vmax=1,
        interpolation="nearest",
        origin="lower",  # path length 0 at the bottom
    )

    ax1.set_xlabel("Training Step", fontsize=14, fontweight="bold")
    ax1.set_ylabel("Path Length", fontsize=14, fontweight="bold")
    ax1.set_title(
        f'Path Length Accuracy Heatmap - {dataset_type.replace("_", " ").title()}',
        fontsize=16,
        fontweight="bold",
        pad=20,
    )

    # Columns are step indices; label at most 10 of them with the real step.
    num_ticks = min(len(steps), 10)
    tick_indices = np.linspace(0, len(steps) - 1, num_ticks, dtype=int)
    ax1.set_xticks(tick_indices)
    ax1.set_xticklabels([f"{steps[i]:.0f}" for i in tick_indices])

    diameter_range = np.arange(0, max_diameter + 1)
    ax1.set_yticks(np.arange(len(diameter_range)))
    ax1.set_yticklabels(diameter_range)

    cbar = fig1.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
    cbar.set_label("Accuracy", fontsize=12, fontweight="bold")
    cbar.ax.tick_params(labelsize=10)

    ax1.tick_params(labelsize=12)
    for label in ax1.get_xticklabels() + ax1.get_yticklabels():
        label.set_fontweight("bold")

    # White cell borders via a minor-tick grid placed between cells.
    ax1.set_xticks(np.arange(-0.5, len(steps), 1), minor=True)
    ax1.set_yticks(np.arange(-0.5, len(diameter_range), 1), minor=True)
    ax1.grid(which="minor", color="white", linestyle="-", linewidth=0.5, alpha=0.7)

    fig1.tight_layout()
    heatmap_save_path = os.path.join(save_dir, "diameter_dynamics_heatmap.pdf")
    fig1.savefig(
        heatmap_save_path,
        dpi=300,
        bbox_inches="tight",
        facecolor="white",
        edgecolor="none",
    )
    plt.close(fig1)

    # ===== RELIABLE-THRESHOLD LINE PLOT =====
    fig2, ax2 = plt.subplots(1, 1, figsize=(12, 8))
    fig2.patch.set_facecolor("white")

    threshold_accuracy = 0.99
    reliable_thresholds = _reliable_path_length_thresholds(
        path_length_data, max_diameter, threshold_accuracy
    )

    # -1 marks steps where no reliable threshold exists.
    mask = reliable_thresholds >= 0

    if np.any(mask):
        ax2.plot(
            steps[mask],
            reliable_thresholds[mask],
            color="#FF6B35",
            linewidth=4.0,
            marker="o",
            markersize=7,
            alpha=0.95,
            label=f"Max Perfect Path Length (Acc≥{threshold_accuracy:.2f})",
        )
        ax2.fill_between(
            steps[mask],
            0,
            reliable_thresholds[mask],
            color="#FF6B35",
            alpha=0.2,
            label="Perfect Prediction Region",
        )

    no_reliable_mask = ~mask
    if np.any(no_reliable_mask):
        ax2.scatter(
            steps[no_reliable_mask],
            np.zeros(np.sum(no_reliable_mask)),
            color="#D32F2F",
            marker="x",
            s=100,
            alpha=0.7,
            label="No Reliable Threshold",
            zorder=5,
        )

    if "num_layers" in dynamics_data and dynamics_data["num_layers"] is not None:
        theoretical_capacity = 3 ** dynamics_data["num_layers"]
        ax2.axhline(
            y=theoretical_capacity,
            color="#1E3A8A",
            linestyle="--",
            linewidth=2.5,
            alpha=0.9,
            label=f"3^{dynamics_data['num_layers']} (capacity)",
        )

    ax2.set_xlabel("Training Step", fontsize=14, fontweight="bold")
    ax2.set_ylabel("Maximum Perfect Path Length", fontsize=14, fontweight="bold")
    ax2.set_title(
        "Progression of Perfect Path Length Prediction",
        fontsize=20,
        fontweight="bold",
        pad=20,
    )
    # NOTE: a log x-axis cannot show steps equal to 0.
    ax2.set_xscale("log")
    ax2.grid(True, alpha=0.3, linestyle="--")
    ax2.set_facecolor("#f8fafc")

    ax2.set_ylim(-0.5, max_diameter + 0.5)
    ax2.set_yticks(np.arange(0, max_diameter + 1, max(1, max_diameter // 10)))

    ax2.legend(frameon=True, fancybox=True, shadow=True, fontsize=12, loc="upper right")

    ax2.tick_params(labelsize=12)
    for label in ax2.get_xticklabels() + ax2.get_yticklabels():
        label.set_fontweight("bold")

    fig2.tight_layout()
    line_save_path = os.path.join(save_dir, "diameter_dynamics_line.pdf")
    fig2.savefig(
        line_save_path,
        dpi=300,
        bbox_inches="tight",
        facecolor="white",
        edgecolor="none",
    )
    plt.close(fig2)

    print(f"Diameter dynamics heatmap saved to: {heatmap_save_path}")
    print(f"Diameter dynamics line plot saved to: {line_save_path}")

    return heatmap_save_path, line_save_path


def _print_iclr_banner(title, width=60):
    """Print ``title`` framed by a leading blank line and two rules of ``=``."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


def _iclr_value_consistency(
    model, adj_matrix, connectivity_matrix, pred, num_permutations, device, verbose
):
    """Compute value-type permutation-equivariance metrics for a single graph.

    The graph tensors are converted to float and moved to ``device`` before
    being passed to ``compute_permutation_equivariant_metrics``.
    """
    adj = adj_matrix.float().to(device)
    conn = connectivity_matrix.float().to(device)
    pred = pred.float().to(device)
    return compute_permutation_equivariant_metrics(
        pred,
        conn,
        adj,
        num_permutations=num_permutations,
        type="value",
        model=model,
        verbose=verbose,
    )


def run_iclr_analysis(model, config, args, device=None):
    """Run the ICLR-specific analysis.

    Which parts run is selected by ``args.iclr_plot_mode``:

    * ``"behavior"`` or ``"all"``: training-dynamics analysis via
      ``analyze_training_dynamics(config, args)``. This requires ``--dataset``.
    * ``"graphs"`` or ``"all"``: graph analysis on freshly created Erdos-Renyi
      graphs. ``--dataset`` is only printed here, not used. The analysis:

      - looks for one graph with both false positives and false negatives,
        plots it to ``graph.pdf`` and writes its metrics to
        ``graph_metrics.json``;
      - plots the top-5 highest-loss graphs to ``graph_XXX.pdf`` and writes
        their metrics to ``all_graphs_metrics.json``.

      All outputs go to ``<ckpt_path>/iclr_analysis``.

    Args:
        model: model wrapper exposing ``.model`` (an ``nn.Module``).
        config: loaded training config. ``num_nodes`` is read, and the whole
            dict is passed to ``get_dataset_params``.
        args: parsed CLI arguments.
        device: torch device. Defaults to the device of the model's parameters.

    Raises:
        ValueError: if behavior analysis is requested but ``--dataset`` is unset.
    """
    print("Running ICLR analysis...")
    print(f"Plot mode: {args.iclr_plot_mode}")

    plot_behavior = args.iclr_plot_mode in ["all", "behavior"]
    plot_graphs = args.iclr_plot_mode in ["all", "graphs"]

    if plot_behavior:
        _print_iclr_banner("TRAINING DYNAMICS ANALYSIS")
        if args.dataset is None:
            raise ValueError(
                "ICLR behavior analysis requires --dataset to be specified"
            )
        # The returned dynamics data is not used by the graph analysis below.
        analyze_training_dynamics(config, args)

    if not plot_graphs:
        print("Skipping graph analysis as per --iclr_plot_mode setting")
        return

    _print_iclr_banner("GRAPH ANALYSIS SETUP")

    # Graph analysis always uses Erdos-Renyi graphs, regardless of --dataset.
    print(
        "Using Erdos-Renyi graphs for graph analysis (override --dataset for this part)"
    )
    graph_dataset_type = "erdos_renyi"
    if args.dataset is not None:
        original_dataset = (
            args.dataset[0] if isinstance(args.dataset, list) else args.dataset
        )
        print(f"Original --dataset setting: {original_dataset}")
    print(f"Graph analysis dataset: {graph_dataset_type}")
    print(f"Number of evaluation examples: {args.num_eval_examples}")

    params = get_dataset_params(config, graph_dataset_type, args.eval_edge_prob)
    dataset = create_dataset(
        graph_dataset_type,
        num_samples=args.num_eval_examples,
        num_nodes=config["num_nodes"],
        **params,
    )
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
    criterion = torch.nn.BCEWithLogitsLoss()

    vis_dir = os.path.join(args.ckpt_path, "iclr_analysis")
    os.makedirs(vis_dir, exist_ok=True)

    # Resolve the device once, so both graph searches and all metric
    # computations below receive the same concrete device.
    if device is None:
        device = next(model.model.parameters()).device

    _print_iclr_banner("WORST GRAPH ANALYSIS")

    # First, try to find a graph with both false positives and false negatives.
    print("Looking for a graph with both false positives and false negatives...")
    both_errors_graph = find_graph_with_both_errors(
        model, dataloader, criterion, device=device
    )

    if both_errors_graph is not None:
        (adj_matrix, connectivity_matrix), pred_data, error_info = both_errors_graph

        print("Found graph with both error types!")
        print(f"  False positives: {error_info['false_positives']}")
        print(f"  False negatives: {error_info['false_negatives']}")
        print(f"  Loss: {error_info['loss']:.4f}")
        print(f"  Error balance: {error_info['error_balance']:.3f}")

        save_path = os.path.join(vis_dir, "graph.pdf")
        visualize_iclr_graph(adj_matrix, connectivity_matrix, pred_data, save_path)
        print(f"Graph with both error types saved to: {save_path}")

        print(
            "Computing equivariant consistency metrics for the balanced error graph..."
        )
        num_permutations = 4096
        value_metrics = _iclr_value_consistency(
            model,
            adj_matrix,
            connectivity_matrix,
            pred_data,
            num_permutations,
            device,
            args.verbose,
        )

        special_metrics = {
            "graph_type": "both_errors",
            "error_info": error_info,
            "value_consistency": value_metrics,
            "num_permutations": num_permutations,
        }
        special_metrics_path = os.path.join(vis_dir, "graph_metrics.json")
        with open(special_metrics_path, "w") as f:
            json.dump(special_metrics, f, indent=2)
        print(f"Special graph metrics saved to: {special_metrics_path}")

        _print_iclr_banner("BALANCED ERROR GRAPH METRICS")
        print(f"False Positives: {error_info['false_positives']}")
        print(f"False Negatives: {error_info['false_negatives']}")
        print(
            f"Value Frobenius cosine similarity: {value_metrics['perm_frob_cosine_similarity']:.6f}"
        )
        print("=" * 60)
    else:
        print("No graph found with both false positives and false negatives.")
        print("Falling back to worst performing graph analysis...")

    top_k = 5
    print(f"Finding top {top_k} worst performing graphs...")
    worst_graphs = find_worst_performing_graphs(
        model, dataloader, criterion, top_k=top_k, device=device
    )

    if not worst_graphs:
        print("No graphs found for analysis!")
        return

    print(f"Found {len(worst_graphs)} worst performing graphs")
    for i, (loss, _, _) in enumerate(worst_graphs):
        print(f"  Graph {i+1}: Loss = {loss:.4f}")

    # Per-graph permutation budget for the worst graphs. It is much smaller
    # than the 4096 used for the single balanced-error graph above.
    num_permutations = 16
    all_metrics = []
    for graph_id, (worst_loss, (adj_matrix, connectivity_matrix), worst_pred) in enumerate(
        worst_graphs, start=1
    ):
        save_path = os.path.join(vis_dir, f"graph_{graph_id:03d}.pdf")
        visualize_iclr_graph(adj_matrix, connectivity_matrix, worst_pred, save_path)
        print(f"Graph {graph_id} visualization saved to: {save_path}")

        print(f"Computing equivariant consistency metrics for graph {graph_id}...")
        value_metrics = _iclr_value_consistency(
            model,
            adj_matrix,
            connectivity_matrix,
            worst_pred,
            num_permutations,
            device,
            args.verbose,
        )

        all_metrics.append(
            {
                "graph_id": graph_id,
                "loss": worst_loss,
                "value_consistency": value_metrics,
                "num_permutations": num_permutations,
            }
        )

        print(f"\nGraph {graph_id} Results:")
        print(f"  Loss: {worst_loss:.4f}")
        print(
            f"  Value Frobenius cosine similarity: {value_metrics['perm_frob_cosine_similarity']:.6f}"
        )

    # Report the actual count; fewer than top_k graphs may have been returned.
    _print_iclr_banner(
        f"SUMMARY OF TOP {len(all_metrics)} WORST PERFORMING GRAPHS", width=80
    )
    for metrics in all_metrics:
        print(f"Graph {metrics['graph_id']}:")
        print(f"  Loss: {metrics['loss']:.4f}")
        print(
            f"  Value consistency (Frobenius cosine): {metrics['value_consistency']['perm_frob_cosine_similarity']:.6f}"
        )
        print()
    print("=" * 80)

    metrics_path = os.path.join(vis_dir, "all_graphs_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(all_metrics, f, indent=2)
    print(f"All metrics saved to: {metrics_path}")


def evaluate_model(model, dataloader, criterion, acc_threshold=0.0, device=None):
    """Evaluate a model on a dataloader of ``(adj_matrix, connectivity_matrix)``.

    Predictions are binarized as ``logits > acc_threshold`` and compared
    entry-by-entry with the target connectivity matrices.

    Args:
        model: wrapper exposing ``.model`` (an ``nn.Module``, switched to eval
            mode) and ``.forward(adj_matrix)`` returning logits.
        dataloader: yields ``(adj_matrix, connectivity_matrix)`` batches.
        criterion: loss called as ``criterion(logits, targets)``.
        acc_threshold: logit threshold above which an entry is predicted 1.
        device: torch device. Defaults to the device of the model's parameters.

    Returns:
        A tuple ``(avg_loss, accuracy, all_correct, precision, recall, f1)``.

        * ``avg_loss``: the per-batch loss averaged with batch-size weights, so
          a smaller final batch does not skew it.
        * ``accuracy``: the fraction of matching entries.
        * ``all_correct``: the fraction of graphs predicted exactly.
        * precision/recall/F1: computed from entry-level confusion counts
          summed over all graphs.

        Every value is 0 when the dataloader yields nothing.
    """
    model.model.eval()
    if device is None:
        device = next(model.model.parameters()).device

    total_loss = 0.0
    correct_preds = 0
    total_preds = 0
    completely_correct_pred = 0
    total_graphs = 0

    # Entry-level confusion counts accumulated over every graph.
    total_true_positives = 0
    total_false_positives = 0
    total_false_negatives = 0
    total_true_negatives = 0

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in tqdm(dataloader, desc="Evaluating"):
            adj_matrix = adj_matrix.float().to(device)
            target = connectivity_matrix.float().to(device)
            batch_size = target.shape[0]

            logits = model.forward(adj_matrix)
            loss = criterion(logits, target)
            # Weight each batch's loss by its size so a partial last batch
            # contributes proportionally to the overall mean.
            total_loss += loss.item() * batch_size

            pred = (logits > acc_threshold).float()
            matches = pred == target
            correct_preds += matches.sum().item()
            total_preds += target.numel()

            # A graph counts as completely correct only if every entry matches.
            completely_correct_pred += (
                matches.reshape(batch_size, -1).all(dim=1).sum().item()
            )
            total_graphs += batch_size

            # Vectorized over the whole batch: summing per-graph counts is the
            # same as counting over the batch at once.
            pred_pos, pred_neg = pred == 1, pred == 0
            true_pos, true_neg = target == 1, target == 0
            total_true_positives += (pred_pos & true_pos).sum().item()
            total_false_positives += (pred_pos & true_neg).sum().item()
            total_false_negatives += (pred_neg & true_pos).sum().item()
            total_true_negatives += (pred_neg & true_neg).sum().item()

    avg_loss = total_loss / total_graphs if total_graphs > 0 else 0
    avg_accuracy = correct_preds / total_preds if total_preds > 0 else 0
    avg_all_correct = completely_correct_pred / total_graphs if total_graphs > 0 else 0

    predicted_positives = total_true_positives + total_false_positives
    actual_positives = total_true_positives + total_false_negatives
    precision = (
        total_true_positives / predicted_positives if predicted_positives > 0 else 0.0
    )
    recall = total_true_positives / actual_positives if actual_positives > 0 else 0.0
    f1_score = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    return avg_loss, avg_accuracy, avg_all_correct, precision, recall, f1_score


def _weights_to_numpy(tensor):
    """Detach a tensor and return it as a CPU numpy array."""
    return tensor.detach().cpu().numpy()


def _save_weight_heatmap(matrix, title, save_path, figsize=None, aspect=None):
    """Save ``matrix`` as a single viridis heatmap with a colorbar.

    ``figsize=None`` and ``aspect=None`` fall back to matplotlib's defaults.
    The figure is always closed, even if plotting or saving raises.
    """
    fig = plt.figure(figsize=figsize)
    try:
        plt.imshow(matrix, cmap="viridis", aspect=aspect)
        plt.title(title)
        plt.colorbar()
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)


def _qk_interaction_matrix(query_weight, key_weight):
    """Return ``W_q^T @ W_k`` for linear weights stored as (out, in).

    PyTorch ``nn.Linear`` computes ``y = x @ W.T``, so ``q = W_q x`` and
    ``k = W_k x``. The unscaled attention score is therefore
    ``q_i . k_j = x_i^T (W_q^T W_k) x_j``. ``W_q @ W_k.T`` would not appear in
    the score at all.
    """
    return np.dot(query_weight.T, key_weight)


def _visualize_roberta_weights(model, save_dir):
    """Save per-layer query/key/QK heatmaps and the read-out weights."""
    for layer_idx, block in enumerate(model.model._backbone.encoder.layer):
        self_attn = block.attention.self
        query_weight = _weights_to_numpy(self_attn.query.weight)
        key_weight = _weights_to_numpy(self_attn.key.weight)
        qk_matrix = _qk_interaction_matrix(query_weight, key_weight)

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            ax[0].imshow(query_weight, cmap="viridis")
            ax[0].set_title(f"Layer {layer_idx} Query Weight Matrix")
            ax[1].imshow(key_weight, cmap="viridis")
            ax[1].set_title(f"Layer {layer_idx} Key Weight Matrix")
            ax[2].imshow(qk_matrix, cmap="RdBu_r")
            ax[2].set_title(f"Layer {layer_idx} QK Dot Product")
            fig.savefig(
                os.path.join(save_dir, f"layer_{layer_idx}_weights.png"),
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)

    _save_weight_heatmap(
        _weights_to_numpy(model.model._read_out.weight),
        "Output Layer Weight Matrix",
        os.path.join(save_dir, "output_layer_weights.png"),
    )


def _visualize_looped_weights(model, save_dir):
    """Save heatmaps for the single looped layer and the read-in/read-out layers."""
    inner = model.model
    attn_layer = inner.layer.attention.self
    query_weight = _weights_to_numpy(attn_layer.query.weight)
    key_weight = _weights_to_numpy(attn_layer.key.weight)
    value_weight = _weights_to_numpy(attn_layer.value.weight)

    # Panels fill the 2x2 grid row by row.
    panels = [
        (query_weight, "Query Weight Matrix", "viridis"),
        (key_weight, "Key Weight Matrix", "viridis"),
        (value_weight, "Value Weight Matrix", "viridis"),
        (_qk_interaction_matrix(query_weight, key_weight), "QK Dot Product", "RdBu_r"),
    ]
    fig, ax = plt.subplots(2, 2, figsize=(12, 10))
    try:
        for axis, (matrix, title, cmap) in zip(ax.flat, panels):
            image = axis.imshow(matrix, cmap=cmap)
            axis.set_title(title)
            plt.colorbar(image, ax=axis)
        fig.savefig(
            os.path.join(save_dir, "attention_weights.png"),
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)

    _save_weight_heatmap(
        _weights_to_numpy(inner.layer.intermediate.dense.weight),
        "Feed-Forward Network Dense Layer Weight Matrix",
        os.path.join(save_dir, "ffn_weights.png"),
        figsize=(10, 6),
        aspect="auto",
    )

    read_in = getattr(inner, "read_in", None)
    if read_in is not None:
        _save_weight_heatmap(
            _weights_to_numpy(read_in.weight),
            "Read-in Layer Weight Matrix",
            os.path.join(save_dir, "read_in_weights.png"),
            figsize=(8, 6),
        )

    _save_weight_heatmap(
        _weights_to_numpy(inner.read_out.weight),
        "Read-out Layer Weight Matrix",
        os.path.join(save_dir, "read_out_weights.png"),
        figsize=(8, 6),
    )


def _visualize_disentangled_weights(model, save_dir):
    """Save per-head attention matrices per layer, the readout, and a summary."""
    inner = model.model
    # Each entry is indexed as [head] -> 2-D matrix below (original note: [n_head, d, d]).
    layer_weights = [_weights_to_numpy(A_layer) for A_layer in inner.A]

    for layer_idx, A_weight in enumerate(layer_weights):
        n_heads = A_weight.shape[0]
        fig, axes = plt.subplots(1, n_heads, figsize=(5 * n_heads, 4))
        if n_heads == 1:
            axes = [axes]  # plt.subplots returns a bare Axes for a single panel
        try:
            for head_idx in range(n_heads):
                image = axes[head_idx].imshow(A_weight[head_idx], cmap="RdBu_r")
                axes[head_idx].set_title(f"Layer {layer_idx} Head {head_idx}")
                plt.colorbar(image, ax=axes[head_idx])
                axes[head_idx].axis("off")
            plt.suptitle(f"Disentangled Attention Matrices - Layer {layer_idx}")
            fig.savefig(
                os.path.join(save_dir, f"layer_{layer_idx}_attention_matrices.png"),
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)

    readout = getattr(inner, "W", None)
    if readout is not None:
        _save_weight_heatmap(
            _weights_to_numpy(readout.weight),
            "Output/Readout Layer Weight Matrix",
            os.path.join(save_dir, "output_weights.png"),
            figsize=(10, 6),
            aspect="auto",
        )

    num_layers = len(layer_weights)
    if num_layers == 0:
        return  # plt.subplots(1, 0) would raise; nothing to summarize

    fig, axes = plt.subplots(1, num_layers, figsize=(4 * num_layers, 4))
    if num_layers == 1:
        axes = [axes]
    try:
        for layer_idx, A_weight in enumerate(layer_weights):
            # Average across heads for a compact per-layer view.
            image = axes[layer_idx].imshow(np.mean(A_weight, axis=0), cmap="RdBu_r")
            axes[layer_idx].set_title(f"Layer {layer_idx}\n(avg across heads)")
            plt.colorbar(image, ax=axes[layer_idx])
            axes[layer_idx].axis("off")
        plt.suptitle("Disentangled Transformer - All Layers (Head-Averaged)")
        fig.savefig(
            os.path.join(save_dir, "all_layers_summary.png"),
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)


def visualize_model_weights(model, save_dir, model_type):
    """Save heatmaps of a model's learned weights as PNG files in ``save_dir``.

    Output files per ``model_type``:

    * ``"roberta"``: ``layer_{i}_weights.png`` for each encoder layer, showing
      query, key, and the QK interaction matrix ``W_q^T W_k``. Also
      ``output_layer_weights.png``.
    * ``"looped_transformer"``: ``attention_weights.png``, ``ffn_weights.png``
      and ``read_out_weights.png``. ``read_in_weights.png`` is added when a
      read-in layer exists.
    * ``"disentangled_transformer"``: ``layer_{i}_attention_matrices.png`` for
      each layer and ``all_layers_summary.png``. ``output_weights.png`` is
      added when ``W`` exists.

    For any other ``model_type`` no files are written and a note is printed.
    """
    if model_type == "roberta":
        _visualize_roberta_weights(model, save_dir)
    elif model_type == "looped_transformer":
        _visualize_looped_weights(model, save_dir)
    elif model_type == "disentangled_transformer":
        _visualize_disentangled_weights(model, save_dir)
    else:
        print(f"No weight visualization implemented for model_type={model_type!r}")


def _build_eval_model_params(config):
    """Return the ``create_model`` keyword arguments for ``config["model_type"]``.

    Raises:
        ValueError: if a RoBERTa config sets ``roberta_attention_only`` with a
            ``layer_norm_type`` other than ``"pre"``.
    """
    model_type = config["model_type"]
    if model_type == "disentangled_transformer":
        # The disentangled transformer takes its own, disjoint argument set.
        return {
            "num_nodes": config["num_nodes"],
            "heads": config["heads"],
            "init_type": config["init_type"],
            "readout_type": config["readout_type"],
        }

    model_params = {
        "num_nodes": config["num_nodes"],
        "hidden_size": config["hidden_size"],
        "num_layers": config["num_layers"],
        "num_attention_heads": config["num_attention_heads"],
    }
    if model_type == "roberta":
        attention_only = config.get("roberta_attention_only", False)
        model_params.update(
            roberta_type=config["roberta_type"],
            layer_norm_type=config["layer_norm_type"],
            attention_only=attention_only,
        )
        if attention_only and config["layer_norm_type"] != "pre":
            raise ValueError(
                "roberta_attention_only=True only supports layer_norm_type='pre'"
            )
    elif model_type == "looped_transformer":
        model_params.update(
            read_in_method=config["read_in_method"],
            layer_norm_type=config["layer_norm_type"],
            tie_qk=config.get("tie_qk", False),
        )
    return model_params


def _resolve_eval_dataset_types(args, config):
    """Return the dataset names to evaluate on, as a list.

    Precedence: ``--dataset`` if given, else a non-null
    ``config["eval_dataset"]``, else ``config["dataset"]``. A single name is
    wrapped in a list.
    """

    def as_list(value):
        return value if isinstance(value, list) else [value]

    if args.dataset is not None:
        return as_list(args.dataset)
    if config.get("eval_dataset") is not None:
        return as_list(config["eval_dataset"])
    return as_list(config["dataset"])


def _visualize_eval_hidden_states(
    model, eval_datasets, num_vis_examples, vis_dir, config, device
):
    """Plot hidden states for ``num_vis_examples`` examples split across datasets.

    The budget is split as evenly as possible, and the first datasets get one
    extra example when it does not divide. This way the total matches the
    request, and a budget smaller than the number of datasets still produces
    plots. Each dataset's plots go to ``vis_dir/<dataset_name>``.
    """
    if not eval_datasets:
        return
    base, extra = divmod(num_vis_examples, len(eval_datasets))

    for ds_idx, (dataset_name, dataset) in enumerate(eval_datasets.items()):
        examples_per_dataset = base + (1 if ds_idx < extra else 0)
        if examples_per_dataset == 0:
            continue

        print(f"Visualizing {dataset_name}...")
        dataset_vis_dir = os.path.join(vis_dir, dataset_name)
        os.makedirs(dataset_vis_dir, exist_ok=True)

        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
        for batch_idx, (adj_matrix, _) in enumerate(dataloader):
            if batch_idx >= examples_per_dataset:
                break

            with torch.no_grad():
                hidden_states = model.get_hidden_states(adj_matrix.float().to(device))

            plot_hidden_states(
                adj_matrix,
                hidden_states,
                example_idx=batch_idx,
                save_dir=dataset_vis_dir,
                config=config,
                model=model,
                device=device,
            )

            # Release per-example memory.
            del hidden_states, adj_matrix

            # Periodically close any figures left open to bound memory use.
            # plt.clf()/plt.cla() would instead create a new figure and axes
            # when none is open.
            if batch_idx % 5 == 0:
                plt.close("all")

        plt.close("all")
        if torch.cuda.is_available() and str(device).startswith("cuda"):
            torch.cuda.empty_cache()


def main():
    """Evaluate a saved checkpoint and write results and visualizations.

    Steps:

    1. Parse CLI args and load the config.
    2. Build the model and load the checkpoint chosen by
       ``find_latest_checkpoint(args.ckpt_path, args.ckpt_id)``.
    3. With ``--iclr_only``, run the ICLR analysis and return.
    4. Otherwise evaluate on each resolved dataset and write
       ``eval_results.json``.
    5. Plot model weights and, if ``--num_vis_examples > 0``, hidden states.

    All outputs go under ``--ckpt_path``.
    """
    args = parse_arguments()
    config = load_config(args.ckpt_path, args.config_path)

    print(f"Loaded config: {config}")

    print(f"Creating {config['model_type']} model...")
    model_params = _build_eval_model_params(config)
    model = create_model(config["model_type"], **model_params)
    print(f"Model created: {type(model.model).__name__}")

    checkpoint_file = find_latest_checkpoint(args.ckpt_path, args.ckpt_id)
    print(f"Loading model from {checkpoint_file}...")
    state_dict = torch.load(checkpoint_file, map_location="cpu")["model_state_dict"]
    model.model.load_state_dict(state_dict)
    model.model.eval()
    device = get_eval_device(config["model_type"])  # select device per rule
    model.model.to(device)
    print(f"Model loaded from {checkpoint_file}")

    if args.iclr_only:
        run_iclr_analysis(model, config, args, device=device)
        return

    eval_dataset_types = _resolve_eval_dataset_types(args, config)
    print(f"Evaluating on datasets: {eval_dataset_types}")

    eval_datasets = create_eval_datasets(
        config, eval_dataset_types, args.num_eval_examples, args.eval_edge_prob
    )
    # Keep the final partial batch so every generated example is evaluated.
    # With drop_last=True, num_eval_examples < batch_size would produce an
    # empty loader and silently report all-zero metrics.
    eval_dataloaders = {
        name: DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
        )
        for name, dataset in eval_datasets.items()
    }

    weights_dir = os.path.join(args.ckpt_path, "weights_visualization")
    os.makedirs(weights_dir, exist_ok=True)
    visualize_model_weights(model, weights_dir, config["model_type"])

    criterion = torch.nn.BCEWithLogitsLoss()
    acc_threshold = 0.0

    results = {}
    print("\nEvaluating on all datasets...")
    for name, dataloader in eval_dataloaders.items():
        print(f"\nEvaluating on {name}...")
        loss, accuracy, all_correct, precision, recall, f1_score = evaluate_model(
            model, dataloader, criterion, acc_threshold, device=device
        )
        results[name] = {
            "loss": loss,
            "accuracy": accuracy,
            "all_correct_accuracy": all_correct,
            "edge_precision": precision,
            "edge_recall": recall,
            "edge_f1_score": f1_score,
        }
        print(
            f"{name} - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}, All Correct: {all_correct:.4f}, "
            f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1_score:.4f}"
        )

    results_path = os.path.join(args.ckpt_path, "eval_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {results_path}")

    if args.num_vis_examples > 0:
        print(f"\nVisualizing hidden states for {args.num_vis_examples} examples...")
        vis_dir = os.path.join(args.ckpt_path, "hidden_states_visualization")
        os.makedirs(vis_dir, exist_ok=True)
        _visualize_eval_hidden_states(
            model, eval_datasets, args.num_vis_examples, vis_dir, config, device
        )

    print("\nEvaluation completed!")

    # Final cleanup.
    plt.close("all")
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        torch.cuda.empty_cache()


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
