#!/usr/bin/env python3
"""
Utility functions for transformer-graph training and evaluation.

This module contains helper functions for:
- Path length computation and accuracy metrics
- Dataset parameter handling
- Visualization utilities
- Loss functions
"""

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import os


def compute_path_lengths_batch(adj_matrices):
    """
    Compute all-pairs shortest path lengths (hop counts) for a batch of graphs.

    Args:
        adj_matrices: Tensor of shape (batch_size, num_nodes, num_nodes). An entry
            equal to 1 at (b, i, j) is treated as an edge from i to j in graph b.
            Diagonal entries (self-loops) are ignored.

    Returns:
        Long tensor of shape (batch_size, num_nodes, num_nodes) where entry (b, i, j)
        is the shortest path length from node i to node j in graph b. The diagonal is 0
        and unreachable pairs hold the sentinel value num_nodes, which is larger than
        any real shortest path.
    """
    n_graphs, n, _ = adj_matrices.shape
    dev = adj_matrices.device

    # Boolean masks: the diagonal, and every off-diagonal edge.
    on_diagonal = torch.eye(n, dtype=torch.bool, device=dev)
    has_edge = (adj_matrices == 1) & ~on_diagonal

    # Start from "unreachable" everywhere, then fill in the known distances.
    dist = torch.full((n_graphs, n, n), n, dtype=torch.long, device=dev)
    dist = dist.masked_fill(has_edge, 1)
    dist = dist.masked_fill(on_diagonal, 0)

    # Floyd-Warshall: allow each node in turn to act as an intermediate stop.
    for pivot in range(n):
        to_pivot = dist[:, :, pivot].unsqueeze(2)
        from_pivot = dist[:, pivot, :].unsqueeze(1)
        dist = torch.minimum(dist, to_pivot + from_pivot)

    return dist


def compute_accuracy_for_path_length(
    pred_binary, true_connectivity, path_lengths, target_path_length
):
    """
    Measure prediction accuracy restricted to node pairs at one path length.

    Args:
        pred_binary: Binary predictions (batch_size, num_nodes, num_nodes)
        true_connectivity: Ground truth connectivity (batch_size, num_nodes, num_nodes)
        path_lengths: Path lengths matrix (batch_size, num_nodes, num_nodes)
        target_path_length: Only pairs whose path length equals this value are scored

    Returns:
        accuracy: Fraction of selected pairs where prediction equals ground truth
        count: Number of selected pairs
        Both are the Python float 0.0 when no pair matches, and 0-dim float
        tensors otherwise.
    """
    selected = path_lengths == target_path_length
    n_selected = selected.sum()

    # Nothing at this distance: report zeros rather than a NaN mean.
    if n_selected == 0:
        return 0.0, 0.0

    hits = pred_binary[selected] == true_connectivity[selected]
    return hits.float().mean(), n_selected.float()


def compute_path_length_accuracy(
    pred_connectivity,
    true_connectivity,
    adj_matrices,
    max_path_length,
    acc_threshold=0.0,
):
    """
    Compute accuracy grouped by path length between nodes.

    Note: Index 0 of the returned tensors is the bucket for disconnected node
    pairs (no path between them); node-to-itself pairs are not counted anywhere.
    Indices 1 through max_path_length hold connected pairs at that exact
    shortest-path length.

    Args:
        pred_connectivity: Model predictions (batch_size, num_nodes, num_nodes)
        true_connectivity: Ground truth connectivity (batch_size, num_nodes, num_nodes)
        adj_matrices: Adjacency matrices (batch_size, num_nodes, num_nodes)
        max_path_length: Maximum path length to consider
        acc_threshold: Predictions strictly greater than this are treated as 1
            (default: 0.0)

    Returns:
        accuracy_by_path_length: Tensor of accuracies for indices 0 to max_path_length
        count_by_path_length: Tensor of pair counts for indices 0 to max_path_length
        Both live on the device of pred_connectivity; buckets with no pairs stay 0.
    """
    device = pred_connectivity.device
    num_nodes = adj_matrices.shape[1]
    path_lengths = compute_path_lengths_batch(adj_matrices)

    # Binarize predictions with the caller-supplied threshold
    pred_binary = (pred_connectivity > acc_threshold).float()

    # One slot per path length, plus slot 0 for disconnected pairs
    accuracy_by_path_length = torch.zeros(max_path_length + 1, device=device)
    count_by_path_length = torch.zeros(max_path_length + 1, device=device)

    # Disconnected pairs are marked with the sentinel value num_nodes
    disconnected_mask = path_lengths == num_nodes
    # Exclude diagonal (node to itself)
    diag_indices = torch.arange(num_nodes, device=device)
    disconnected_mask[:, diag_indices, diag_indices] = False

    if disconnected_mask.sum() > 0:
        pred_subset = pred_binary[disconnected_mask]
        true_subset = true_connectivity[disconnected_mask]
        correct = (pred_subset == true_subset).float()
        accuracy_by_path_length[0] = correct.mean()
        count_by_path_length[0] = disconnected_mask.sum().float()

    # A real shortest path visits each node at most once, so it has at most
    # num_nodes - 1 hops. Stopping there keeps the "disconnected" sentinel
    # (== num_nodes) from being re-counted as a connected bucket when
    # max_path_length >= num_nodes; longer buckets simply stay at zero.
    longest_real_path = min(max_path_length, num_nodes - 1)
    for path_len in range(1, longest_real_path + 1):
        accuracy, count = compute_accuracy_for_path_length(
            pred_binary, true_connectivity, path_lengths, path_len
        )
        accuracy_by_path_length[path_len] = accuracy
        count_by_path_length[path_len] = count

    return accuracy_by_path_length, count_by_path_length


def aggregate_path_length_accuracies(accuracies_list, counts_list):
    """
    Aggregate path length accuracies over multiple batches using weighted averaging.

    Each batch's accuracy is weighted by its pair count, so the result equals the
    accuracy over all pairs pooled together.

    Args:
        accuracies_list: List of per-batch accuracy tensors (one entry per path length)
        counts_list: List of per-batch count tensors, aligned with accuracies_list

    Returns:
        aggregated_accuracies: Weighted average accuracies (CPU tensor); entries
            whose total count is zero stay 0
        total_counts: Total counts for each path length (CPU tensor)
        Returns (None, None) when accuracies_list is empty.
    """
    if not accuracies_list:
        return None, None

    num_buckets = accuracies_list[0].shape[0]  # includes index 0 (disconnected pairs)
    aggregated_accuracies = torch.zeros(num_buckets)
    total_counts = torch.zeros(num_buckets)
    accum_device = total_counts.device

    # Weighted sum. Per-batch tensors may live on another device (e.g. CUDA);
    # bring them to the accumulator device so the in-place adds cannot fail
    # with a device-mismatch error.
    for acc, counts in zip(accuracies_list, counts_list):
        acc = torch.as_tensor(acc, device=accum_device)
        counts = torch.as_tensor(counts, device=accum_device)
        aggregated_accuracies += acc * counts
        total_counts += counts

    # Avoid division by zero for buckets that never received any pairs
    mask = total_counts > 0
    aggregated_accuracies[mask] /= total_counts[mask]

    return aggregated_accuracies, total_counts


def create_path_length_wandb_log(accuracies, total_counts, prefix, **extra_fields):
    """
    Build a flat dictionary of per-path-length accuracies for wandb logging.

    Index i of both tensors corresponds to path length i, with index 0 being
    the bucket for disconnected nodes. Buckets with a zero count are omitted.

    Args:
        accuracies: Tensor of accuracies for path lengths 0 to N
        total_counts: Tensor of counts for path lengths 0 to N
        prefix: Prefix for wandb keys (e.g., "path_length_accuracy", "epoch_path_length_accuracy")
        **extra_fields: Additional fields, logged as "{prefix}/{key}"

    Returns:
        Dictionary ready for wandb.log()
    """
    # Per-length accuracies, only for lengths that were actually observed.
    log_dict = {
        f"{prefix}/length_{length}": accuracies[length].item()
        for length in range(accuracies.shape[0])
        if total_counts[length] > 0
    }

    # Extra fields are added last, under the same prefix.
    log_dict.update(
        {f"{prefix}/{name}": value for name, value in extra_fields.items()}
    )

    return log_dict


class FocalLoss(nn.Module):
    """
    Binary focal loss.

    Down-weights well-classified examples so training concentrates on hard ones:
    FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

    Args:
        alpha (float): Weight applied to positive targets; negatives get 1 - alpha
        gamma (float): Focusing exponent applied to (1 - p_t)
        preds_are_probs (bool): Whether inputs are probabilities (True) or logits (False)
        reduction (str): "mean" or "sum"; any other value returns the
            unreduced element-wise loss
    """

    def __init__(self, alpha=0.25, gamma=2.0, preds_are_probs=False, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.preds_are_probs = preds_are_probs
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of `inputs` against binary `targets`."""
        if self.preds_are_probs:
            probs = inputs
            # Clamp away from 0 and 1 so the cross-entropy log stays finite.
            eps = 1e-8
            bce_loss = F.binary_cross_entropy(
                torch.clamp(probs, eps, 1 - eps), targets, reduction="none"
            )
        else:
            probs = torch.sigmoid(inputs)
            bce_loss = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction="none"
            )

        negatives = 1 - targets

        # p_t: probability the model assigns to the true class.
        p_t = probs * targets + (1 - probs) * negatives
        # alpha_t: class-balancing weight for each element.
        alpha_t = self.alpha * targets + (1 - self.alpha) * negatives

        focal_loss = alpha_t * (1 - p_t) ** self.gamma * bce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss


def get_dataset_params(args, dataset_type):
    """
    Collect the generator keyword arguments for one dataset type.

    Args:
        args: Namespace-like object holding the run configuration. `fixed_p`,
            `sample_p` and `p_range` are read directly (so must exist) for the
            Erdos-Renyi variants that use them; every other attribute is
            optional and falls back to the default shown below.
        dataset_type: Name of the dataset family (e.g. "erdos_renyi", "sbm").

    Returns:
        Dict of parameters for that dataset type. Types that need nothing
        extra, and unrecognized types, yield an empty dict.
    """

    def optional(name, default=None):
        # Attribute lookup that tolerates args lacking the option entirely.
        return getattr(args, name, default)

    # The paired-structure generators share one signature; k=None lets the
    # generator pick k itself.
    if dataset_type in (
        "two_chains",
        "two_trees",
        "two_stars",
        "two_degree_3_chains",
    ):
        return {
            "k": optional("k"),
            "add_isolated_nodes": optional("add_isolated_nodes", False),
        }

    if dataset_type == "erdos_renyi":
        return {
            "p": args.fixed_p,
            "sample_p": args.sample_p,
            "p_range": args.p_range,
            "restrict_diam": optional("restrict_diam"),
        }

    if dataset_type == "sbm":
        return {
            "p_intra": optional("p_intra", 0.4),
            "p_inter": optional("p_inter", 0.05),
            "num_communities": optional("num_communities", 4),
        }

    if dataset_type == "erdos_renyi_two_graphs":
        return {
            "p": args.fixed_p,
            "connect_prob": optional("connect_prob", 0.5),
            "min_component_size_frac": optional("min_component_size_frac", 0.3),
        }

    if dataset_type == "erdos_renyi_medium":
        return {
            "p": args.fixed_p,
            "sample_p": args.sample_p,
            "p_range": args.p_range,
            "connect_prob": optional("connect_prob", 0.0),
            "min_component_size_frac": optional("min_component_size_frac", 0.33),
        }

    if dataset_type == "erdos_renyi_hard":
        return {
            "p": args.fixed_p,
            "sample_p": args.sample_p,
            "p_range": args.p_range,
            "connect_prob": optional("connect_prob", 0.5),
        }

    if dataset_type == "tree_forest":
        return {
            "min_tree_size": optional("min_tree_size", 3),
            "max_tree_size": optional("max_tree_size"),
        }

    if dataset_type == "two_cliques":
        return {
            "connect_prob": optional("connect_prob", 0.0),
            "size_variation": optional("size_variation", 0.1),
        }

    # "star_forest", "one_circle" and anything unrecognized need no extras.
    return {}


def _choose_graph_layout(graph, num_nodes):
    """Return node positions for drawing `graph`, trying low-crossing layouts first.

    Candidates are tried in order (Graphviz circo, circular for tiny graphs,
    Graphviz dot for trees, planar, circular for dense graphs, bipartite);
    any candidate that raises is skipped and a seeded spring layout is the
    final fallback.
    """
    # Graphviz "circo" (circular) layout; needs pygraphviz, so it may fail.
    try:
        if graph.number_of_edges() > 0:  # Only for graphs with edges
            return nx.nx_agraph.graphviz_layout(graph, prog="circo")
    except Exception:
        print("Graphviz not available, falling back to other layouts.")

    # For very small graphs, a circular layout is readable enough
    if num_nodes <= 6:
        try:
            return nx.circular_layout(graph)
        except Exception:
            pass

    # For trees, use a hierarchical Graphviz layout rooted at the highest-degree node
    try:
        if nx.is_tree(graph) and graph.number_of_nodes() > 0:
            degrees = dict(graph.degree())
            root = max(degrees, key=degrees.get) if degrees else 0
            return nx.nx_agraph.graphviz_layout(graph, prog="dot", root=root)
    except Exception:
        # Graphviz missing (or not a usable tree): try the next candidate
        pass

    # For small planar graphs, a planar embedding has no edge crossings at all
    if num_nodes <= 20:
        try:
            if nx.is_planar(graph):
                return nx.planar_layout(graph)
        except Exception:
            pass

    # For small complete or near-complete graphs, use circular layout
    if num_nodes <= 12 and nx.density(graph) > 0.6:
        try:
            return nx.circular_layout(graph)
        except Exception:
            pass

    # For bipartite graphs, place the two sides of an actual two-coloring in
    # separate columns (splitting by node-index parity would generally put
    # adjacent nodes in the same column).
    try:
        if nx.is_bipartite(graph):
            coloring = nx.bipartite.color(graph)
            top_nodes = {node for node, side in coloring.items() if side == 0}
            return nx.bipartite_layout(graph, top_nodes)
    except Exception:
        pass

    # Default: seeded spring layout, with more spacing/iterations for small graphs
    if num_nodes <= 10:
        return nx.spring_layout(
            graph, seed=42, k=2.0 / np.sqrt(num_nodes), iterations=100
        )
    return nx.spring_layout(
        graph, seed=42, k=1.0 / np.sqrt(num_nodes), iterations=50
    )


def _plot_matrix_heatmap(ax, matrix, cmap, title, num_nodes):
    """Draw `matrix` on `ax` as a [0, 1] heatmap with cell grid, colorbar and, for small matrices, cell values."""
    image = ax.imshow(matrix, cmap=cmap, vmin=0, vmax=1)
    ax.set_title(title)
    ax.set_xlabel("Node")
    ax.set_ylabel("Node")

    # Minor ticks on the cell boundaries give a thin grid between cells
    cell_edges = np.arange(-0.5, num_nodes, 1)
    ax.set_xticks(cell_edges, minor=True)
    ax.set_yticks(cell_edges, minor=True)
    ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5, alpha=0.3)

    plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    # Text annotations only fit when the matrix is small
    if num_nodes <= 8:
        for r in range(num_nodes):
            for c in range(num_nodes):
                ax.text(
                    c,
                    r,
                    f"{matrix[r,c]:.0f}",
                    ha="center",
                    va="center",
                    fontsize=8,
                )


def visualize_graph_samples(
    dataset, dataset_name, split_name, saved_models_dir, num_samples=5
):
    """
    Visualize and save sample graphs from a dataset using NetworkX.

    For each sampled item a three-panel PNG is written: the drawn graph, the
    adjacency matrix heatmap and the connectivity matrix heatmap. Files go to
    `<saved_models_dir>/data_samples/<split_name>/<dataset_name>/`.

    Args:
        dataset: Indexable dataset whose items are
            (adjacency_matrix, connectivity_matrix) pairs, as tensors or arrays
        dataset_name: Name of the dataset (e.g., 'erdos_renyi', 'two_chains')
        split_name: 'train' or 'eval'
        saved_models_dir: Directory to save visualization
        num_samples: Number of samples to visualize (capped at len(dataset))
    """
    # Create visualization directory
    viz_dir = os.path.join(saved_models_dir, "data_samples", split_name, dataset_name)
    os.makedirs(viz_dir, exist_ok=True)

    # Sample distinct items from the dataset
    indices = np.random.choice(
        len(dataset), min(num_samples, len(dataset)), replace=False
    )

    for i, idx in enumerate(indices):
        adj_matrix, connectivity_matrix = dataset[idx]

        # Convert to numpy if tensor (detach/cpu so GPU or grad tensors also work)
        if torch.is_tensor(adj_matrix):
            adj_matrix = adj_matrix.detach().cpu().numpy()
        if torch.is_tensor(connectivity_matrix):
            connectivity_matrix = connectivity_matrix.detach().cpu().numpy()

        num_nodes = adj_matrix.shape[0]

        # Three panels: graph drawing, adjacency matrix, connectivity matrix
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

        # Remove self-loops for the drawing by zeroing the diagonal. Zeroing
        # (rather than subtracting the identity) is also correct when the
        # diagonal is already 0, where subtraction would leave -1 entries that
        # NetworkX turns into self-loop edges.
        adj_matrix_for_nx = np.array(adj_matrix, dtype=float)
        np.fill_diagonal(adj_matrix_for_nx, 0)

        # 1. Graph visualization using NetworkX
        G = nx.from_numpy_array(adj_matrix_for_nx)
        pos = _choose_graph_layout(G, num_nodes)

        nx.draw_networkx_nodes(
            G, pos, ax=ax1, node_color="lightblue", node_size=300, alpha=0.8
        )
        nx.draw_networkx_edges(G, pos, ax=ax1, edge_color="gray", width=1, alpha=0.6)
        nx.draw_networkx_labels(G, pos, ax=ax1, font_size=8, font_weight="bold")

        ax1.set_title(f"Graph Structure\n{dataset_name} Sample {i+1}")
        ax1.axis("off")

        # 2. Plot adjacency matrix (original, self-loops included)
        _plot_matrix_heatmap(
            ax2,
            adj_matrix,
            "Blues",
            f"Adjacency Matrix\n{dataset_name} Sample {i+1}",
            num_nodes,
        )

        # 3. Plot connectivity matrix
        _plot_matrix_heatmap(
            ax3,
            connectivity_matrix,
            "Reds",
            f"Connectivity Matrix\n{dataset_name} Sample {i+1}",
            num_nodes,
        )

        plt.tight_layout()

        # Save figure, then close it explicitly so figures do not accumulate
        save_path = os.path.join(viz_dir, f"{dataset_name}_{i+1:03d}.png")
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)

    print(f"Saved {len(indices)} {dataset_name} {split_name} samples to {viz_dir}")


def visualize_dataset_samples(datasets_dict, split_name, saved_models_dir):
    """
    Visualize samples from multiple datasets.

    Calls visualize_graph_samples once per entry, using the dataset's name
    with a leading 'eval_' removed as the display/directory name.

    Args:
        datasets_dict: Dictionary of dataset_name -> dataset
        split_name: 'train' or 'eval'
        saved_models_dir: Directory to save visualizations
    """
    print(f"\n>>> Visualizing {split_name} dataset samples...")

    eval_prefix = "eval_"
    for dataset_name, dataset in datasets_dict.items():
        # Strip only the leading prefix; later occurrences of "eval_" inside
        # the name are part of the name and must be kept.
        if dataset_name.startswith(eval_prefix):
            clean_name = dataset_name[len(eval_prefix):]
        else:
            clean_name = dataset_name

        visualize_graph_samples(dataset, clean_name, split_name, saved_models_dir)


def compute_permutation_equivariant_metrics(
    pred_connectivity,
    true_connectivity,
    adj_matrices,
    num_permutations=10,
    type="value",
    model=None,
    verbose=False,
):
    """
    Compute value-level permutation equivariance consistency using cosine similarity
    at every hidden layer (not just the final layer).

    For each trial, every graph in the batch gets its own random node
    permutation. The model's per-layer outputs on the permuted graphs are
    un-permuted and compared (flattened cosine similarity per sample) with the
    outputs on the original graphs.

    Args:
        pred_connectivity: Not read by this implementation (kept for interface
            compatibility)
        true_connectivity: Not read by this implementation
        adj_matrices: Adjacency matrices (batch_size, num_nodes, num_nodes)
        num_permutations: Number of random permutation trials
        type: Not read by this implementation
        model: Required. Must provide `get_hidden_states(adj_batch)` returning a
            tensor or a sequence of per-layer tensors. If `model.model` (or the
            model itself) has `_read_out` and `hidden_size`, the readout is
            applied to layers whose last dimension equals `hidden_size`.
        verbose: Not read by this implementation

    Returns a dict with:
      - perm_frob_cosine_similarity: average of the per-layer means (primary scalar)
      - per_layer_cosine: list[float] of average cosine similarities per layer;
        a layer that was skipped in every trial reports 0.0

    Raises:
        ValueError: if `model` is None.
    """
    if model is None:
        raise ValueError("model is required to compute per-layer equivariance metrics")

    device = adj_matrices.device
    batch_size, num_nodes, _ = adj_matrices.shape

    # Optional learned readout, looked up once on the wrapped model if present
    inner_model = getattr(model, "model", model)
    readout = getattr(inner_model, "_read_out", None)
    hidden_sz = getattr(inner_model, "hidden_size", None)

    def to_probs(x: torch.Tensor) -> torch.Tensor:
        # Heuristic: values outside [0, 1] are taken to be logits and squashed.
        try:
            x_max = x.detach().amax().item()
            x_min = x.detach().amin().item()
        except Exception:
            # e.g. empty tensor: treat as logits
            x_max, x_min = float("inf"), float("-inf")
        return torch.sigmoid(x) if (x_max > 1.0 or x_min < 0.0) else x

    def to_device_tensor(x) -> torch.Tensor:
        # Accept tensors or array-likes; similarity math needs floating point
        t = torch.as_tensor(x, device=device)
        if not torch.is_floating_point(t):
            t = t.float()
        return t

    def layer_outputs(adj_batch):
        # Per-layer outputs for a batch: hidden states -> optional readout -> probs
        with torch.no_grad():
            hidden = model.get_hidden_states(adj_batch)
            # Some models may return a single tensor; normalize to list
            if isinstance(hidden, torch.Tensor):
                hidden = [hidden]
            outputs = []
            for state in hidden:
                t = to_device_tensor(state)
                if (
                    readout is not None
                    and hidden_sz is not None
                    and t.dim() >= 3
                    and t.shape[-1] == hidden_sz
                ):
                    t = readout(t)
                outputs.append(to_probs(t))
        return outputs

    orig_layers = layer_outputs(adj_matrices)
    num_layers = len(orig_layers)

    # Sums and counts per layer, so the mean is taken over contributing samples only
    layer_cos_sum = torch.zeros(num_layers, device=device, dtype=torch.float64)
    layer_cos_count = torch.zeros(num_layers, device=device, dtype=torch.float64)

    eps = 1e-12

    for _ in range(num_permutations):
        # One independent permutation per graph, plus its inverse
        perms = torch.stack(
            [torch.randperm(num_nodes, device=device) for _ in range(batch_size)]
        )
        inv_perms = torch.argsort(perms, dim=1)

        # Relabel nodes: permute rows and columns of each adjacency matrix
        adj_perm_batch = torch.stack(
            [adj_matrices[b][perms[b]][:, perms[b]] for b in range(batch_size)]
        )

        perm_layers = layer_outputs(adj_perm_batch)

        # Undo the permutation on every axis whose size equals num_nodes.
        # NOTE(review): a non-node axis that happens to have size num_nodes
        # (e.g. a feature dimension) would be permuted too - confirm for your model.
        unperm_layers = []
        for layer in perm_layers:  # layer shape: (B, ...)
            restored = []
            for b in range(batch_size):
                x = layer[b]
                inv_p = inv_perms[b]
                if inv_p.device != x.device:
                    inv_p = inv_p.to(x.device)
                for ax, sz in enumerate(x.shape):
                    if sz == num_nodes:
                        x = x.index_select(ax, inv_p)
                restored.append(x)
            unperm_layers.append(torch.stack(restored))

        # Per-sample cosine similarity between original and un-permuted outputs
        for li, (A, B) in enumerate(zip(orig_layers, unperm_layers)):
            if A.shape != B.shape:
                # Not comparable element-for-element: skip this layer for this
                # trial instead of crashing or comparing misaligned values.
                continue
            A = A.reshape(batch_size, -1)
            B = B.reshape(batch_size, -1)
            dot = (A * B).sum(dim=1)
            denom = A.norm(dim=1) * B.norm(dim=1) + eps
            cos = (dot / denom).clamp(-1.0, 1.0)  # (B,)
            layer_cos_sum[li] += cos.double().sum()
            layer_cos_count[li] += cos.numel()

    # Mean per-layer cosine (eps keeps never-updated layers at 0 instead of NaN)
    per_layer_cosine = (
        (layer_cos_sum / (layer_cos_count + eps)).detach().cpu().numpy().tolist()
    )
    # Aggregate across layers for the primary scalar
    avg_cosine = float(np.mean(per_layer_cosine)) if len(per_layer_cosine) > 0 else 0.0

    return {
        "perm_frob_cosine_similarity": avg_cosine,
        "per_layer_cosine": per_layer_cosine,
    }


def evaluate_model_with_perm_metrics(
    model, dataloader, criterion, acc_threshold=0.0, device=None, num_permutations=5
):
    """
    Enhanced evaluation function that includes permutation equivariant metrics.

    Args:
        model: The model to evaluate. Must provide `forward(adj)` and whatever
            compute_permutation_equivariant_metrics needs from it; if it wraps
            a network in `model.model`, that network is switched to eval mode,
            otherwise the model itself is.
        dataloader: DataLoader yielding (adjacency, connectivity) batches
        criterion: Loss criterion, called as criterion(prediction, connectivity)
        acc_threshold: Predictions strictly above this count as 1
        device: Device to move batches to (defaults to CUDA when available, else CPU)
        num_permutations: Number of permutations to test for equivariance

    Returns:
        None if dataloader is None; otherwise a dict with
          - "loss": mean per-batch loss
          - "accuracy": fraction of matrix entries predicted correctly
          - "all_correct": fraction of graphs with every entry correct
          - the permutation metrics, averaged over batches (batches whose
            value contains NaN are left out of that metric's average)
    """
    if dataloader is None:
        return None

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Same wrapper-or-plain-model convention the permutation metric uses
    getattr(model, "model", model).eval()

    # Standard metrics
    total_loss, correct_preds, total_preds, all_correct, graph_count = 0, 0, 0, 0, 0

    # One permutation-metric dict per batch
    perm_metrics_list = []

    with torch.no_grad():
        for adj_matrix, connectivity_matrix in dataloader:
            adj_matrix = adj_matrix.float().to(device)
            connectivity_matrix = connectivity_matrix.float().to(device)
            pred_connectivity = model.forward(adj_matrix)

            # Standard evaluation
            loss = criterion(pred_connectivity, connectivity_matrix)
            total_loss += loss.item()
            pred = (pred_connectivity > acc_threshold).float()
            matches = pred == connectivity_matrix
            correct_preds += matches.sum().item()
            total_preds += connectivity_matrix.numel()

            # A graph is "all correct" when every one of its entries matches
            per_graph_ok = matches.reshape(matches.shape[0], -1).all(dim=1)
            all_correct += int(per_graph_ok.sum().item())
            graph_count += matches.shape[0]

            # Compute permutation equivariant metrics for this batch
            perm_metrics = compute_permutation_equivariant_metrics(
                pred_connectivity,
                connectivity_matrix,
                adj_matrix,
                num_permutations,
                model=model,
            )
            perm_metrics_list.append(perm_metrics)

    # Aggregate standard metrics
    avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0
    avg_accuracy = correct_preds / total_preds if total_preds > 0 else 0
    avg_all_correct = all_correct / graph_count if graph_count > 0 else 0

    # Aggregate permutation equivariant metrics across batches
    aggregated_perm_metrics = {}
    if perm_metrics_list:
        for key in perm_metrics_list[0].keys():
            # Keep only batches whose value for this metric is NaN-free
            collected = []
            for batch_metrics in perm_metrics_list:
                values = np.asarray(batch_metrics[key], dtype=float)
                if np.any(np.isnan(values)):
                    continue
                collected.append(values)

            if not collected:
                # Everything was filtered out: scalar metrics fall back to 0.0,
                # list-valued metrics to an empty list
                aggregated_perm_metrics[key] = (
                    0.0 if np.isscalar(perm_metrics_list[0][key]) else []
                )
                continue

            # Scalars average to a single float; arrays/lists average elementwise
            if np.ndim(collected[0]) == 0:
                aggregated_perm_metrics[key] = float(np.mean(collected))
            else:
                stacked = np.stack(collected, axis=0)
                # Convert back to list for JSON/printing friendliness
                aggregated_perm_metrics[key] = np.mean(stacked, axis=0).tolist()

    return {
        "loss": avg_loss,
        "accuracy": avg_accuracy,
        "all_correct": avg_all_correct,
        **aggregated_perm_metrics,
    }
