import os
import random
import torch
import numpy as np
import networkx as nx
from torch_geometric.data import Data
from typing import List, Dict
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix, add_self_loops, remove_self_loops
import scipy.sparse as sp
import numpy as np
from torch_geometric.data import Data
import copy
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import time
import torch
import psutil
import numpy as np
from torch_geometric.data import Data
from typing import List, Dict, Tuple
import gc
import contextlib
import torch
# NOTE(review): this enables autograd anomaly detection for the whole process as an
# import-time side effect. PyTorch documents this mode as a debugging aid that slows
# execution; consider making it opt-in.
torch.autograd.set_detect_anomaly(True)

def generate_synthetic_anomaly_graphs(
    num_graphs: int,
    num_nodes: int,
    num_edges: int,
    feature_dim: int,
    anomaly_ratio_node: float,
    anomaly_ratio_edge: float,
    anomaly_ratio_graph: float,
    dataset_name: str,
    save_dir: str = "./synthetic_graph_datasets"
) -> List[Data]:
    """
    Generate a synthetic graph-level anomaly detection dataset and save it.

    For each of ``num_graphs`` graphs:

    * topology: ``nx.gnm_random_graph(num_nodes, num_edges, seed=gid)``; each
      networkx edge is stored once (one direction) in a (2, E) ``edge_index``,
      with E possibly 0;
    * features: N(0, 1) noise of size (num_nodes, feature_dim). A random subset
      of ``int(anomaly_ratio_node * num_nodes)`` nodes receives an extra N(5, 1)
      shift and gets ``node_label`` 1;
    * ``edge_label``: 1 for a random ``int(anomaly_ratio_edge * E)`` edges
      (labels only; the edges themselves are not altered);
    * ``y``: 1 with probability ``anomaly_ratio_graph``, else 0.

    NOTE(review): ``y`` is drawn independently of the injected node/edge
    anomalies, so every graph receives node anomalies regardless of its label.

    Python's ``random``, NumPy's and PyTorch's global RNGs are reseeded with 42
    on every call. The list is written with ``torch.save`` to
    ``{save_dir}/{dataset_name}.pt`` and also returned.
    """
    # Reproducibility
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    os.makedirs(save_dir, exist_ok=True)
    graph_list = []

    for gid in range(num_graphs):
        G = nx.gnm_random_graph(num_nodes, num_edges, seed=gid)
        features = np.random.normal(0.0, 1.0, size=(num_nodes, feature_dim))

        # Node anomalies: shifted features + positive node label.
        num_anomalous_nodes = int(anomaly_ratio_node * num_nodes)
        anomalous_nodes = np.random.choice(num_nodes, num_anomalous_nodes, replace=False)
        features[anomalous_nodes] += np.random.normal(5.0, 1.0, size=(num_anomalous_nodes, feature_dim))
        node_labels = np.zeros(num_nodes, dtype=int)
        node_labels[anomalous_nodes] = 1

        # Edges. reshape(-1, 2) keeps a (2, 0) shape for edgeless graphs; the
        # plain np.array(G.edges).T is 1-D there and shape[1] raised IndexError.
        edge_index = np.array(list(G.edges), dtype=np.int64).reshape(-1, 2).T
        num_edges_graph = edge_index.shape[1]
        num_anomalous_edges = int(anomaly_ratio_edge * num_edges_graph)
        anomalous_edges = np.random.choice(num_edges_graph, num_anomalous_edges, replace=False)
        edge_labels = np.zeros(num_edges_graph, dtype=int)
        edge_labels[anomalous_edges] = 1

        # Graph-level label
        graph_label = int(random.random() < anomaly_ratio_graph)

        # PyG object
        data = Data(
            x=torch.tensor(features, dtype=torch.float),
            edge_index=torch.tensor(edge_index, dtype=torch.long),
            y=torch.tensor([graph_label], dtype=torch.long),
            node_label=torch.tensor(node_labels, dtype=torch.long),
            edge_label=torch.tensor(edge_labels, dtype=torch.long),
            gid=torch.tensor([gid], dtype=torch.long)
        )
        graph_list.append(data)

    # Save to file
    save_path = os.path.join(save_dir, f"{dataset_name}.pt")
    torch.save(graph_list, save_path)
    print(f"✅ Saved {dataset_name} with {num_graphs} graphs to: {save_path}")
    return graph_list


def get_prdigy_dataset_configs() -> Dict[str, Dict]:
    """Return synthetic generator settings modelled on PRoDIGY dataset statistics.

    Keys are dataset names. Each value maps the parameter names of
    ``generate_synthetic_anomaly_graphs`` (all except ``dataset_name`` and
    ``save_dir``) to the values used for that dataset.
    """
    fields = (
        "num_graphs", "num_nodes", "num_edges", "feature_dim",
        "anomaly_ratio_node", "anomaly_ratio_edge", "anomaly_ratio_graph",
    )
    # name, graphs, nodes, edges, feat_dim, node_ratio, edge_ratio, graph_ratio
    table = [
        ("BM-MS",   700,  14,   43,   1,  0.3199, 0.0, 0.1429),
        ("BM-MN",   700,  18,   57,   1,  0.4891, 0.0, 0.1429),
        ("BM-MT",   700,  17,   45,   1,  0.3449, 0.0, 0.1429),
        ("MUTAG",   2951, 30,   61,   14, 0.0481, 0.0, 0.3440),
        ("MNIST0",  1000, 70,   90,   5,  0.3546, 0.0, 0.0986),
        ("MNIST1",  1000, 70,   90,   5,  0.3546, 0.0, 0.1125),
        ("T-Group", 1000, 300,  1200, 10, 0.0064, 0.0, 0.0426),
    ]
    return {name: dict(zip(fields, values)) for name, *values in table}


def compute_normalized_laplacian(edge_index, num_nodes):
    """Compute the symmetric normalized Laplacian L = I - D^{-1/2} (A + I) D^{-1/2}.

    ``A`` is the binary, symmetric adjacency built from ``edge_index`` (a
    (2, E) integer tensor):
    * an edge given in either direction connects both endpoints;
    * duplicate edges are collapsed;
    * every node gets exactly one self-loop, whether or not ``edge_index``
      already contained one.

    ``D`` is the degree matrix of ``A + I``, so every degree is >= 1 and
    isolated nodes are safe (their row of ``L`` is zero).

    Returns:
        Dense ``(num_nodes, num_nodes)`` float32 tensor on the CPU.
    """
    row, col = edge_index.detach().cpu().numpy()
    adj = sp.coo_matrix(
        (np.ones(row.shape[0]), (row, col)), shape=(num_nodes, num_nodes)
    ).tocsr()
    # Symmetrize: graphs in this file store each undirected edge once, and the
    # formula above is only the normalized Laplacian for a symmetric A.
    adj = adj + adj.T
    adj.data[:] = 1.0  # collapse duplicates / reciprocal pairs to weight 1
    # Replace any existing self-loops with exactly one per node.
    adj = adj - sp.diags(adj.diagonal()) + sp.eye(num_nodes)
    degrees = np.asarray(adj.sum(axis=1)).ravel()
    D_inv_sqrt = sp.diags(1.0 / np.sqrt(degrees))
    L = sp.eye(num_nodes) - D_inv_sqrt @ adj @ D_inv_sqrt
    return torch.tensor(L.toarray(), dtype=torch.float32)

def node_level_augmentation(data: Data, gamma=0.2, sigma=0.1) -> Data:
    """Perturb node features with Laplacian-propagated and Gaussian noise.

    Works on a deep copy of ``data``; the input graph is not modified. The copy's
    features become ``x + (gamma * L) @ x + sigma * eps``, where ``L`` comes from
    ``compute_normalized_laplacian`` and ``eps`` is elementwise standard-normal
    noise shaped like ``x``.
    """
    augmented = copy.deepcopy(data)
    feats = augmented.x
    laplacian = compute_normalized_laplacian(augmented.edge_index, augmented.num_nodes)
    structural = (gamma * laplacian) @ feats
    gaussian = sigma * torch.randn_like(feats)
    augmented.x = feats + (structural + gaussian)
    return augmented

def edge_level_augmentation(data: Data, perturb_ratio=0.1) -> Data:
    """Edge-level noise: drop a random fraction of edges and add as many new ones.

    Steps, applied to a deep copy of ``data``:
      1. Strip existing self-loops from ``edge_index``.
      2. Remove ``int(perturb_ratio * E)`` randomly chosen edges, where E is the
         edge count after step 1.
      3. Add the same number of new directed edges ``(src, dst)`` with
         ``src != dst`` that are not already present. They are sampled
         uniformly without replacement; fewer are added if not enough free pairs
         exist.
      4. Append one self-loop per node.

    NOTE(review): only ``edge_index`` is replaced. Per-edge attributes on the
    copy (e.g. an ``edge_label`` tensor, if present) keep their old length and
    no longer align with the new edges.
    """
    data = copy.deepcopy(data)
    edge_index, _ = remove_self_loops(data.edge_index)
    num_edges = edge_index.size(1)
    num_nodes = data.num_nodes
    num_perturb = int(perturb_ratio * num_edges)

    # Randomly remove edges
    keep = torch.ones(num_edges, dtype=torch.bool)
    keep[torch.randperm(num_edges)[:num_perturb]] = False
    edge_index = edge_index[:, keep]

    # Randomly add edges. Enumerating the free off-diagonal slots once replaces
    # the old rejection loop, which:
    #   * could pick src == dst (duplicating the self-loop added below);
    #   * could spin forever when too few free pairs remained;
    #   * scanned every edge per attempt.
    # The dense n x n mask matches the dense Laplacian this file already builds.
    if num_perturb > 0:
        occupied = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=edge_index.device)
        occupied[edge_index[0], edge_index[1]] = True
        occupied.fill_diagonal_(True)
        free_pairs = (~occupied).nonzero(as_tuple=False)  # (F, 2) rows of (src, dst)
        picks = torch.randperm(free_pairs.size(0))[:num_perturb]
        edge_index = torch.cat([edge_index, free_pairs[picks].t()], dim=1)

    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
    data.edge_index = edge_index
    return data

def graph_level_augmentation(data: Data, gamma=0.2, sigma=0.1) -> Data:
    """Shift every node of a copy of ``data`` by one shared, graph-wide noise row.

    The shift is ``gamma * mean_over_nodes(L @ x) + sigma * eps``, where ``L``
    comes from ``compute_normalized_laplacian`` and ``eps`` is a single
    standard-normal row. The same row is added to every node. The input graph
    is not modified.
    """
    augmented = copy.deepcopy(data)
    laplacian = compute_normalized_laplacian(augmented.edge_index, augmented.num_nodes)
    pooled = augmented.x.mean(dim=0, keepdim=True)  # only its shape/dtype matter: it sizes eps
    structural = (laplacian @ augmented.x).mean(dim=0, keepdim=True)
    shift = gamma * structural + sigma * torch.randn_like(pooled)
    augmented.x += shift.repeat(augmented.num_nodes, 1)
    return augmented

class GCNEncoder(nn.Module):
    """Stack of ``GCNConv`` layers with a ReLU between consecutive layers.

    Args:
        input_dim: Node feature dimension.
        hidden_dim: Width of every intermediate layer (unused when ``num_layers == 1``).
        output_dim: Dimension of the returned node embeddings.
        num_layers: Total number of ``GCNConv`` layers; must be >= 1.

    Raises:
        ValueError: If ``num_layers`` < 1.
    """

    def __init__(self, input_dim, hidden_dim=64, output_dim=64, num_layers=2):
        super(GCNEncoder, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        # Widths: input -> hidden (num_layers - 1 times) -> output. The previous
        # construction always built at least two layers, so num_layers=1 silently
        # produced a 2-layer encoder.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, data: Data):
        """Return node embeddings from ``data.x`` / ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        *hidden, last = self.layers
        for layer in hidden:
            x = torch.relu(layer(x, edge_index))
        return last(x, edge_index)  # no activation on the output layer

    def encode_graph(self, data: Data):
        """Mean-pool node embeddings into graph embeddings and L2-normalize them.

        When ``data.batch`` is set, nodes are pooled per graph
        (``global_mean_pool``). Otherwise all nodes are treated as one graph and a
        single row is returned.
        """
        x = self.forward(data)
        batch = getattr(data, 'batch', None)
        out = x.mean(dim=0, keepdim=True) if batch is None else global_mean_pool(x, batch)
        return F.normalize(out, p=2, dim=-1)

class TeacherWithCheckpoints:
    """Teacher encoder plus snapshots of its weights taken during training.

    ``checkpoints`` maps an integer key ``t`` to
    ``{'state_dict': <encoder weights>, 'embedding': <detached reference embedding>}``.
    """

    def __init__(self, encoder: GCNEncoder):
        self.encoder = encoder
        self.checkpoints = {}  # t -> {'state_dict': ..., 'embedding': ...}

    def save_checkpoint(self, t: int, embedding: torch.Tensor):
        """Store the current encoder weights and ``embedding`` under key ``t``.

        An existing entry with the same key is overwritten.
        """
        self.checkpoints[t] = {
            # state_dict() returns references to the live tensors, so copy them.
            'state_dict': copy.deepcopy(self.encoder.state_dict()),
            'embedding': embedding.detach().clone()
        }

    def get_embedding(self, data: Data, checkpoint: int = None) -> torch.Tensor:
        """Graph embedding of ``data`` from a stored checkpoint or the live encoder.

        For a known ``checkpoint`` key, a temporary copy of the encoder is loaded
        with that checkpoint's weights and evaluated without autograd. Otherwise
        (``None`` or an unknown key) the live encoder is used and the result stays
        attached to the autograd graph.
        """
        if checkpoint is None or checkpoint not in self.checkpoints:
            return self.encoder.encode_graph(data)
        snapshot = copy.deepcopy(self.encoder)
        snapshot.load_state_dict(self.checkpoints[checkpoint]['state_dict'])
        with torch.no_grad():
            return snapshot.encode_graph(data)

    def get_best_checkpoint(self, student_emb: torch.Tensor, level: str) -> int:
        """Key of the checkpoint whose stored embedding is most cosine-similar to ``student_emb``.

        Row-wise similarities are averaged into a Python float. This lets
        multi-row (batched) embeddings work instead of failing on an ambiguous
        tensor comparison. The first key in dict order wins ties. Returns
        ``None`` when no checkpoints exist. ``level`` is accepted for API
        compatibility but is not used.
        """
        best_t, best_sim = None, -float('inf')
        for t, ckpt in self.checkpoints.items():
            sim = cosine_similarity(student_emb, ckpt['embedding']).mean().item()
            if sim > best_sim:
                best_t, best_sim = t, sim
        return best_t

class StudentGCN(nn.Module):
    """Student network: a ``GCNEncoder`` constructed with ``num_layers=1``.

    Both the hidden and output widths of the wrapped encoder are ``output_dim``.
    The two methods simply delegate to that encoder.
    """

    def __init__(self, input_dim, output_dim=64):
        super().__init__()
        self.encoder = GCNEncoder(
            input_dim,
            hidden_dim=output_dim,
            output_dim=output_dim,
            num_layers=1,
        )

    def forward(self, data):
        """Node embeddings produced by the wrapped encoder."""
        return self.encoder.forward(data)

    def encode_graph(self, data):
        """Pooled, L2-normalized graph embedding from the wrapped encoder."""
        return self.encoder.encode_graph(data)

def cosine_similarity(a: torch.Tensor, b: torch.Tensor, eps=1e-8):
    """Cosine similarity along the last dimension.

    Each input is scaled by ``1 / (||v|| + eps)`` before the dot product, so a
    zero vector yields 0 rather than NaN. Inputs broadcast as in ``a * b``, and
    the last dimension is reduced away.
    """
    def _unit(v):
        return v / (v.norm(dim=-1, keepdim=True) + eps)

    return torch.sum(_unit(a) * _unit(b), dim=-1)


def bidirectional_reverse_contrastive_loss(
    H_S_k: torch.Tensor,
    H_C_k: torch.Tensor,
    H_N_k: torch.Tensor,
    alpha_k: float = 1.0,
    beta: float = 0.5,  # Teacher regularization strength parameter
    epsilon: float = 1e-5
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Bidirectional reverse contrastive loss for a single structural level k.

    With ``cos`` denoting ``cosine_similarity``::

        student_loss = alpha_k * (1 - cos(S, C)) / (1 - cos(S, N) + epsilon)
        teacher_reg  = alpha_k * beta * cos(C, N) / (cos(S, C) + epsilon)

    Args:
        H_S_k: Student embedding (S).
        H_C_k: Clean teacher embedding (C).
        H_N_k: Noisy teacher embedding (N).
        alpha_k: Weight for this structural level.
        beta: Teacher regularization strength.
        epsilon: Small constant added to both denominators.

    Returns:
        Tuple ``(student_loss, teacher_regularization)``.

    NOTE(review): the teacher term's denominator is not clamped. When cos(S, C)
    is near 0 or negative, the term becomes very large or changes sign.
    """
    sim_sc = cosine_similarity(H_S_k, H_C_k)
    sim_sn = cosine_similarity(H_S_k, H_N_k)
    sim_cn = cosine_similarity(H_C_k, H_N_k)

    # Student: move toward the clean teacher, away from the noisy one.
    student_term = alpha_k * ((1 - sim_sc) / (1 - sim_sn + epsilon))
    # Teacher: penalise clean/noisy agreement, scaled by student/clean agreement.
    teacher_term = alpha_k * beta * (sim_cn / (sim_sc + epsilon))
    return student_term, teacher_term


class GCNDecoder(nn.Module):
    """MLP mapping student embeddings toward clean teacher embeddings.

    Despite the name it contains no graph convolution. It is
    ``Linear(input_dim, 2 * input_dim) -> ReLU -> Linear(2 * input_dim, output_dim)``,
    applied over the last dimension.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widened = input_dim * 2
        self.decoder = nn.Sequential(
            nn.Linear(input_dim, widened),
            nn.ReLU(),
            nn.Linear(widened, output_dim),
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.decoder(x)


class LearningWeights(nn.Module):
    """Learnable mixing weights over ``num_levels`` structural levels.

    Holds unconstrained logits, initialised to ``1 / num_levels`` each so the
    first softmax is uniform. Calling the module returns their softmax, which
    is positive and sums to 1.
    """

    def __init__(self, num_levels=3):
        super().__init__()
        initial = torch.ones(num_levels) / num_levels
        self.weights = nn.Parameter(initial)

    def forward(self):
        """Softmax of the logits over the level dimension."""
        return self.weights.softmax(dim=0)


def train_student_on_graph(original, teacher, student_model, optimizer):
    """
    Legacy single-graph student update, kept for backward compatibility.

    A simplified ``train_recodistill``: no decoder, no learned level weights,
    no checkpoint selection and no teacher update.

    For each of the node-, edge- and graph-level augmented views it computes
    ``(1 - cos(student(view), teacher(original))) / (1 - cos(student(view), teacher(view)) + 1e-5)``.
    It then takes one ``optimizer`` step on the sum of the three terms.

    Args:
        original: Clean input graph.
        teacher: Object exposing ``.encoder.encode_graph`` (e.g. ``TeacherWithCheckpoints``).
        student_model: Student exposing ``encode_graph``.
        optimizer: Optimizer over the student's parameters.

    Returns:
        The summed loss as a Python float.
    """
    augmenters = (node_level_augmentation, edge_level_augmentation, graph_level_augmentation)
    views = [augment(copy.deepcopy(original)) for augment in augmenters]

    # Teacher targets carry no gradient.
    with torch.no_grad():
        teacher_noisy = [teacher.encoder.encode_graph(view).detach() for view in views]

    student_views = [student_model.encode_graph(view) for view in views]

    with torch.no_grad():
        teacher_clean = teacher.encoder.encode_graph(original).detach()

    # Reverse contrastive ratio per level (no bidirectional/teacher component).
    total_loss = sum(
        (1 - cosine_similarity(s_emb, teacher_clean)) / (1 - cosine_similarity(s_emb, t_emb) + 1e-5)
        for s_emb, t_emb in zip(student_views, teacher_noisy)
    )

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item()

# Level tags passed to TeacherWithCheckpoints.get_best_checkpoint, in view order.
_LEVELS = ("N", "E", "G")


def _reverse_contrastive_ratio(student_emb, clean_emb, noisy_emb, eps=1e-5):
    """Per-level ratio (1 - cos(student, clean)) / (1 - cos(student, noisy) + eps)."""
    return (1 - cosine_similarity(student_emb, clean_emb)) / (
        1 - cosine_similarity(student_emb, noisy_emb) + eps
    )


def _recodistill_targets(teacher, student_model, original, views):
    """Detached per-level tensors computed from the current models.

    Returns four lists aligned with ``views``:
    * student view embeddings;
    * teacher view embeddings;
    * the best checkpoint key per level (``None`` when the teacher has no
      checkpoints);
    * clean-graph teacher embeddings from that checkpoint.
    """
    with torch.no_grad():
        student = [student_model.encode_graph(v).detach() for v in views]
        noisy = [teacher.encoder.encode_graph(v).detach() for v in views]
        ckpts = [teacher.get_best_checkpoint(s, level=lvl) for s, lvl in zip(student, _LEVELS)]
        clean = [teacher.get_embedding(original, checkpoint=c).detach() for c in ckpts]
    return student, noisy, ckpts, clean


def train_recodistill(original, teacher, student_model, decoder, weights_model,
                      optimizer_student, optimizer_teacher, optimizer_weights,
                      lambda_recon=0.1, update_teacher=True, beta=0.5):
    """
    Run one ReCoDistill optimisation step on a single graph.

    Three phases, each with its own autograd graph:

    1. Student + decoder: reverse contrastive loss against the clean teacher
       embedding of the best-matching checkpoint, plus an MSE reconstruction
       loss (decoder(student) vs. clean teacher) weighted by ``lambda_recon``.
    2. Teacher (only if ``update_teacher``): bidirectional regularisation
       ``alpha_k * beta * cos(clean, noisy) / (cos(student, clean) + 1e-5)``.
    3. Level weights: weighted reverse contrastive loss, back-propagated into
       ``weights_model`` only.

    Args:
        original: Clean input graph (not modified).
        teacher: ``TeacherWithCheckpoints`` wrapper.
        student_model: Student exposing ``encode_graph``.
        decoder: Maps student embeddings to the teacher embedding space.
        weights_model: Module returning the three level weights [N, E, G].
        optimizer_student: Optimiser over the student (and decoder) parameters.
        optimizer_teacher: Optimiser over the teacher encoder parameters.
        optimizer_weights: Optimiser over ``weights_model`` parameters.
        lambda_recon: Weight of the reconstruction loss.
        update_teacher: Whether to run phase 2.
        beta: Teacher regularisation strength (previously hard-coded to 0.5).

    Returns:
        Dict with ``student_loss``, ``teacher_loss`` (0.0 when phase 2 is
        skipped), ``recon_loss`` and the ``level_weights`` used in phases 1-2.
    """
    # Node-, edge- and graph-level views; each augmentation deep-copies its input.
    views = (
        node_level_augmentation(original),
        edge_level_augmentation(original),
        graph_level_augmentation(original),
    )

    # The level weights are only trained in phase 3, whose
    # optimizer_weights.zero_grad() discards any gradient from phases 1-2.
    # Using them as constants here keeps the phases truly separate and removes
    # the need for retain_graph=True (which kept the whole student graph alive).
    alpha = weights_model().detach()

    # ---- Phase 1: student + decoder --------------------------------------
    _, noisy, ckpts, clean = _recodistill_targets(teacher, student_model, original, views)
    student_embs = [student_model.encode_graph(v) for v in views]

    student_terms = [
        _reverse_contrastive_ratio(s, c, n) * alpha[i]
        for i, (s, c, n) in enumerate(zip(student_embs, clean, noisy))
    ]
    recon_loss = sum(
        alpha[i] * F.mse_loss(decoder(s), c)
        for i, (s, c) in enumerate(zip(student_embs, clean))
    )
    student_loss = sum(student_terms) + lambda_recon * recon_loss

    optimizer_student.zero_grad()
    student_loss.backward()
    optimizer_student.step()

    # ---- Phase 2: teacher (optional) ---------------------------------------
    teacher_loss_value = 0.0
    if update_teacher:
        teacher_terms = []
        for i, (view, ckpt) in enumerate(zip(views, ckpts)):
            clean_t = teacher.get_embedding(original, checkpoint=ckpt)
            noisy_t = teacher.encoder.encode_graph(view)
            with torch.no_grad():
                # Uses the student as updated in phase 1.
                student_t = student_model.encode_graph(view).detach()
            teacher_terms.append(
                alpha[i] * beta * (cosine_similarity(clean_t, noisy_t)
                                   / (cosine_similarity(student_t, clean_t) + 1e-5))
            )
        teacher_loss = sum(teacher_terms)
        teacher_loss_value = teacher_loss.item()

        optimizer_teacher.zero_grad()
        teacher_loss.backward()
        optimizer_teacher.step()

    # ---- Phase 3: level weights --------------------------------------------
    # All targets are recomputed (and detached) from the updated models.
    student_f, noisy_f, _, clean_f = _recodistill_targets(teacher, student_model, original, views)
    fresh_alpha = weights_model()
    weight_loss = sum(
        fresh_alpha[i] * _reverse_contrastive_ratio(s, c, n)
        for i, (s, c, n) in enumerate(zip(student_f, clean_f, noisy_f))
    )

    optimizer_weights.zero_grad()
    weight_loss.backward()
    optimizer_weights.step()

    return {
        'student_loss': student_loss.item(),
        'teacher_loss': teacher_loss_value,
        'recon_loss': recon_loss.item(),
        'level_weights': alpha.cpu().numpy()
    }


def train_recodistill_pipeline(
    graph_list,
    input_dim,
    num_epochs=50,
    teacher_update_freq=5,
    checkpoint_epochs=(10, 20, 30),
    lambda_recon=0.1,
    beta=0.5,
    seed=42
):
    """
    Complete training pipeline for ReCoDistill.

    The first 80% of ``graph_list`` (in the given order) is used for training;
    the rest is passed to ``test_anomaly_detection``.

    Args:
        graph_list: List of graphs for training and testing.
        input_dim: Node feature dimension.
        num_epochs: Total number of epochs.
        teacher_update_freq: Update the teacher every this many epochs, only
            during the first half of training; must be >= 1.
        checkpoint_epochs: Epochs after which teacher checkpoints are saved
            (any container supporting ``in``).
        lambda_recon: Weight for the reconstruction loss.
        beta: Teacher regularisation strength, forwarded to ``train_recodistill``.
        seed: Seed for Python's ``random``, NumPy and PyTorch RNGs.

    Returns:
        Dict with the trained models, training metrics and test metrics/results.

    Raises:
        ValueError: If ``teacher_update_freq`` < 1 or the split leaves no training graphs.
    """
    if teacher_update_freq < 1:
        raise ValueError(f"teacher_update_freq must be >= 1, got {teacher_update_freq}")

    # random.sample() below draws from Python's RNG, so it must be seeded as well.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Split data into train/test (no shuffling).
    train_size = int(0.8 * len(graph_list))
    train_graphs = graph_list[:train_size]
    test_graphs = graph_list[train_size:]
    if not train_graphs:
        raise ValueError("graph_list is too small: the 80/20 split leaves no training graphs")

    # Initialize models
    teacher_model = GCNEncoder(input_dim=input_dim, num_layers=3)
    student_model = StudentGCN(input_dim=input_dim)
    decoder = GCNDecoder(input_dim=64, output_dim=64)  # encoders above default to 64-dim embeddings
    weights_model = LearningWeights(num_levels=3)  # Node, edge, graph levels
    teacher = TeacherWithCheckpoints(teacher_model)

    # Initialize optimizers
    optimizer_student = torch.optim.Adam(
        list(student_model.parameters()) + list(decoder.parameters()),
        lr=0.01
    )
    optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=0.005)
    optimizer_weights = torch.optim.Adam(weights_model.parameters(), lr=0.001)

    # Training metrics
    metrics = {
        'student_losses': [],
        'teacher_losses': [],
        'recon_losses': [],
        'level_weights': []
    }

    # NOTE(review): this "pre-training" never changes the teacher.
    # mse_loss(encoded, encoded) is identically zero and no backward/optimizer
    # step is taken; the loop only logs. It runs without autograd so no graphs
    # are built.
    print("🔄 Pre-training teacher model...")
    pretrain_graphs = train_graphs[:10]  # Use subset for efficiency
    with torch.no_grad():
        for epoch in range(1, 11):  # Pre-training epochs
            total_loss = 0
            for graph in pretrain_graphs:
                encoded = teacher_model.encode_graph(graph)
                total_loss += F.mse_loss(encoded, encoded)  # Identity objective
            avg_loss = total_loss / len(pretrain_graphs)
            if epoch % 5 == 0:
                print(f"Teacher pre-training epoch {epoch:02d} | Loss: {avg_loss:.4f}")

    # Save initial checkpoints 0..2, one per the first three training graphs.
    print("💾 Saving teacher checkpoints for progressive distillation...")
    with torch.no_grad():
        for t, graph in enumerate(train_graphs[:3]):
            teacher.save_checkpoint(t, teacher_model.encode_graph(graph))
            print(f"Saved checkpoint {t} using graph {t}")

    # Main training loop - multi-phase curriculum
    print("🚀 Starting ReCoDistill training...")

    # Curriculum phases; "focus_alpha" is informational only and not read below.
    phases = [
        {"name": "Node-level", "epochs": (1, 15), "focus_alpha": 0},
        {"name": "Edge-level", "epochs": (16, 35), "focus_alpha": 1},
        {"name": "Graph-level", "epochs": (36, 50), "focus_alpha": 2}
    ]

    batch_size = min(5, len(train_graphs))
    current_phase = 0
    for epoch in range(1, num_epochs + 1):
        # Check if we need to move to next curriculum phase
        if current_phase < len(phases) - 1 and epoch > phases[current_phase]["epochs"][1]:
            current_phase += 1
            print(f"\n📚 Moving to {phases[current_phase]['name']} phase")

        # Update the teacher periodically, only during the first half of training.
        update_teacher = (epoch % teacher_update_freq == 0) and (epoch < num_epochs / 2)

        batch_graphs = random.sample(train_graphs, batch_size)

        total_metrics = {
            'student_loss': 0,
            'teacher_loss': 0,
            'recon_loss': 0,
            'level_weights': np.zeros(3)
        }

        # Train on each graph in batch
        for graph in batch_graphs:
            batch_metrics = train_recodistill(
                graph, teacher, student_model, decoder, weights_model,
                optimizer_student, optimizer_teacher, optimizer_weights,
                lambda_recon=lambda_recon, update_teacher=update_teacher, beta=beta
            )

            # Accumulate metrics (scalar losses are averaged on the fly)
            for k, v in batch_metrics.items():
                if k == 'level_weights':
                    total_metrics[k] += v
                else:
                    total_metrics[k] += v / batch_size

        total_metrics['level_weights'] /= batch_size

        # Save metrics
        metrics['student_losses'].append(total_metrics['student_loss'])
        metrics['teacher_losses'].append(total_metrics['teacher_loss'])
        metrics['recon_losses'].append(total_metrics['recon_loss'])
        metrics['level_weights'].append(total_metrics['level_weights'])

        # Save checkpoints under keys epoch, epoch+1, epoch+2.
        # NOTE(review): these keys can overwrite the initial keys 0..2, or each
        # other, when checkpoint_epochs contains small or adjacent values.
        if epoch in checkpoint_epochs:
            with torch.no_grad():
                for i, graph in enumerate(train_graphs[:3]):
                    teacher.save_checkpoint(epoch + i, teacher_model.encode_graph(graph))
            print(f"💾 Saved checkpoints at epoch {epoch}")

        # Print progress
        if epoch % 5 == 0 or epoch == 1:
            phase_name = phases[current_phase]["name"]
            weights = total_metrics['level_weights']
            print(f"Epoch {epoch:02d} ({phase_name}) | "
                  f"Student Loss: {total_metrics['student_loss']:.4f} | "
                  f"Teacher Loss: {total_metrics['teacher_loss']:.4f} | "
                  f"Recon Loss: {total_metrics['recon_loss']:.4f} | "
                  f"Weights: N={weights[0]:.2f}, E={weights[1]:.2f}, G={weights[2]:.2f}")

    print("✅ Training complete!")

    # Test the trained model
    results, test_metrics = test_anomaly_detection(
        test_graphs, teacher, student_model,
        decoder=decoder, weights_model=weights_model
    )

    return {
        'teacher': teacher,
        'student': student_model,
        'decoder': decoder,
        'weights_model': weights_model,
        'training_metrics': metrics,
        'test_metrics': test_metrics,
        'test_results': results
    }


def detect_anomalies_recodistill(
    graph_list: List[Data],
    teacher,
    student_model,
    decoder=None,
    weights_model=None,
    threshold: float = None,
    return_scores: bool = False
):
    """
    Score and label node-, edge- and graph-level anomalies with ReCoDistill.

    For every graph, the student encodes a node-, edge- and graph-level
    augmented view. Each view embedding is compared (``cosine_similarity``)
    against the teacher embedding of ``graph`` taken from the checkpoint that
    ``teacher.get_best_checkpoint`` selects for that level.

    Node and edge scores compare the student's and the teacher encoder's
    per-node outputs on the unaugmented graph. An edge is represented by the
    mean of its two endpoint embeddings. When ``decoder`` is given, every
    similarity is averaged with the similarity between the decoded student
    embedding and the teacher embedding.

    Args:
        graph_list: PyG ``Data`` objects. Each must carry ``node_label``,
            ``edge_label`` and a single-element ``y``.
        teacher: Trained TeacherWithCheckpoints model.
        student_model: Trained StudentGCN model.
        decoder: Optional trained decoder used for reconstruction similarity.
        weights_model: Module whose call returns the three level weights
            (N, E, G). If None, equal weights of 1/3 are used.
        threshold: One threshold applied to node, edge and graph scores.
            If None, the node and edge thresholds are the 80th percentile of
            the current graph's own scores, and the graph threshold is 0.5.
            Ground-truth labels are NOT used.
        return_scores: If True, return scores, predictions and labels.
            Otherwise return only the predictions.

    Returns:
        Dict of lists. With ``return_scores=True`` the keys are
        ``{node,edge,graph}_{scores,preds,labels}``; otherwise only the
        three ``*_preds`` keys are returned.
    """
    results = {
        'node_scores': [], 'node_preds': [], 'node_labels': [],
        'edge_scores': [], 'edge_preds': [], 'edge_labels': [],
        'graph_scores': [], 'graph_preds': [], 'graph_labels': []
    }

    # Level weights alpha = (alpha_N, alpha_E, alpha_G), learned or uniform.
    if weights_model is not None:
        alpha = weights_model().detach().cpu().numpy()
    else:
        alpha = np.array([1/3, 1/3, 1/3])  # Equal weights

    def _combined_sim(student_emb, teacher_emb):
        """Cosine similarity, averaged with the decoded-student similarity when a decoder is set."""
        direct_sim = cosine_similarity(student_emb, teacher_emb)
        if decoder is None:
            return direct_sim
        recon_sim = cosine_similarity(decoder(student_emb), teacher_emb)
        return (direct_sim + recon_sim) / 2

    with torch.no_grad():
        for graph in graph_list:
            # Ground-truth labels are only collected for reporting, never used for scoring.
            node_labels = graph.node_label.cpu().numpy()
            edge_labels = graph.edge_label.cpu().numpy()
            graph_label = graph.y.item()

            # Level-specific augmented views; deep copies keep `graph` itself unmodified.
            gn_view = node_level_augmentation(copy.deepcopy(graph))
            ge_view = edge_level_augmentation(copy.deepcopy(graph))
            gg_view = graph_level_augmentation(copy.deepcopy(graph))

            # Student embeddings of each augmented view
            emb_gn_s = student_model.encode_graph(gn_view)
            emb_ge_s = student_model.encode_graph(ge_view)
            emb_gg_s = student_model.encode_graph(gg_view)

            # Teacher checkpoint selected per level (progressive distillation)
            best_ckpt_n = teacher.get_best_checkpoint(emb_gn_s, level='N')
            best_ckpt_e = teacher.get_best_checkpoint(emb_ge_s, level='E')
            best_ckpt_g = teacher.get_best_checkpoint(emb_gg_s, level='G')

            # Clean teacher embeddings of the unaugmented graph at those checkpoints
            emb_n_clean = teacher.get_embedding(graph, checkpoint=best_ckpt_n)
            emb_e_clean = teacher.get_embedding(graph, checkpoint=best_ckpt_e)
            emb_g_clean = teacher.get_embedding(graph, checkpoint=best_ckpt_g)

            # Eq. (8): s(x) = sum_{k in {N,E,G}} alpha_k * (1 - sim(H_S^(k), H_C^(t_k)))
            graph_score = (
                alpha[0] * (1 - _combined_sim(emb_gn_s, emb_n_clean))
                + alpha[1] * (1 - _combined_sim(emb_ge_s, emb_e_clean))
                + alpha[2] * (1 - _combined_sim(emb_gg_s, emb_g_clean))
            )

            # Per-node embeddings: one forward pass per model per graph, instead of
            # one per node and four per edge.
            student_nodes = student_model.forward(graph)
            teacher_nodes = teacher.encoder.forward(graph)

            # Node-level scores
            node_scores = []
            for node_idx in range(graph.num_nodes):
                sim = _combined_sim(
                    student_nodes[node_idx].unsqueeze(0),
                    teacher_nodes[node_idx].unsqueeze(0),
                )
                node_scores.append((alpha[0] * (1 - sim)).item())

            # Edge-level scores; an edge is the mean of its endpoint embeddings.
            edge_scores = []
            for src, dst in graph.edge_index.t().tolist():
                edge_emb_student = (student_nodes[src] + student_nodes[dst]) / 2
                edge_emb_teacher = (teacher_nodes[src] + teacher_nodes[dst]) / 2
                sim = _combined_sim(
                    edge_emb_student.unsqueeze(0),
                    edge_emb_teacher.unsqueeze(0),
                )
                edge_scores.append((alpha[1] * (1 - sim)).item())

            # Eq. (9): s(G) = 1/(|V| + |E|) * (sum_{v in V} s(v) + sum_{e in E} s(e)),
            # averaged with the multi-level graph score computed above.
            num_items = len(node_scores) + len(edge_scores)
            if num_items > 0:
                aggregated_score = (sum(node_scores) + sum(edge_scores)) / num_items
                final_graph_score = (graph_score.item() + aggregated_score) / 2
            else:
                # Nothing to aggregate: fall back to the multi-level score alone.
                final_graph_score = graph_score.item()

            if threshold is None:
                # Per-graph thresholds: scores above this graph's own 80th percentile
                # are flagged. Empty score lists produce no predictions, so the
                # placeholder value is never compared.
                node_threshold = np.percentile(node_scores, 80) if node_scores else 0.0
                edge_threshold = np.percentile(edge_scores, 80) if edge_scores else 0.0
                graph_threshold = 0.5
            else:
                node_threshold = edge_threshold = graph_threshold = threshold

            # Binary predictions
            node_preds = [1 if s > node_threshold else 0 for s in node_scores]
            edge_preds = [1 if s > edge_threshold else 0 for s in edge_scores]
            graph_pred = 1 if final_graph_score > graph_threshold else 0

            results['node_scores'].extend(node_scores)
            results['node_preds'].extend(node_preds)
            results['node_labels'].extend(node_labels.tolist())

            results['edge_scores'].extend(edge_scores)
            results['edge_preds'].extend(edge_preds)
            results['edge_labels'].extend(edge_labels.tolist())

            results['graph_scores'].append(final_graph_score)
            results['graph_preds'].append(graph_pred)
            results['graph_labels'].append(graph_label)

    if return_scores:
        return results
    return {
        'node_preds': results['node_preds'],
        'edge_preds': results['edge_preds'],
        'graph_preds': results['graph_preds']
    }


def evaluate_anomaly_detection(results):
    """
    Compute AUROC, AUPRC and macro-F1 at the node, edge and graph levels.

    Args:
        results: Output of ``detect_anomalies_recodistill(..., return_scores=True)``.
            For each level in node/edge/graph it must provide
            ``{level}_scores``, ``{level}_labels`` and ``{level}_preds``.

    Returns:
        Dict with keys ``{level}_auroc``, ``{level}_auprc`` and ``{level}_f1``.
        A level is omitted when it has no scores, or when its labels contain
        a single class (the ranking metrics are undefined then).
    """
    from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

    metrics = {}
    for level in ('node', 'edge', 'graph'):
        raw_scores = results[f'{level}_scores']
        if len(raw_scores) == 0:
            continue

        labels = np.array(results[f'{level}_labels'])
        scores = np.array(raw_scores)
        preds = np.array(results[f'{level}_preds'])

        # Both positives and negatives are required for AUROC/AUPRC.
        if len(np.unique(labels)) < 2:
            continue

        metrics[f'{level}_auroc'] = roc_auc_score(labels, scores)
        metrics[f'{level}_auprc'] = average_precision_score(labels, scores)
        metrics[f'{level}_f1'] = f1_score(labels, preds, average='macro')

    return metrics


def test_anomaly_detection(graph_list, teacher, student_model, decoder=None, weights_model=None):
    """
    Run ReCoDistill detection on ``graph_list``, print the metrics and return them.

    Every graph in ``graph_list`` is scored. The input is not split again here,
    so callers should pass held-out data.

    Args:
        graph_list: List of PyG Data objects to analyze.
        teacher: Trained TeacherWithCheckpoints model.
        student_model: Trained StudentGCN model.
        decoder: Trained decoder model (optional).
        weights_model: Model with learned level weights (optional).

    Returns:
        ``(results, metrics)``: the per-level scores/predictions/labels from
        ``detect_anomalies_recodistill`` and the dict produced by
        ``evaluate_anomaly_detection``.
    """
    print("🔍 Testing ReCoDistill anomaly detection...")

    results = detect_anomalies_recodistill(
        graph_list,
        teacher,
        student_model,
        decoder=decoder,
        weights_model=weights_model,
        threshold=None,  # Auto-determine threshold
        return_scores=True
    )

    metrics = evaluate_anomaly_detection(results)

    print("\n📊 Evaluation Results:")
    for level in ('node', 'edge', 'graph'):
        # A level is missing when it had a single class (see evaluate_anomaly_detection).
        if f'{level}_auroc' not in metrics:
            continue
        name = level.capitalize()
        print(f"{name}-level AUROC: {metrics[level + '_auroc']:.4f}")
        print(f"{name}-level AUPRC: {metrics[level + '_auprc']:.4f}")
        print(f"{name}-level Macro F1: {metrics[level + '_f1']:.4f}")

    return results, metrics


def train_zero_shot_recodistill(
    normal_graphs,
    input_dim,
    num_epochs=20,
    seed=42
):
    """
    Train a fresh ReCoDistill teacher/student pair on normal graphs only.

    The teacher is updated during the first half of training
    (``epoch < num_epochs / 2``); afterwards only the student side is updated
    via ``train_recodistill``. The average student loss is printed every
    5 epochs.

    Args:
        normal_graphs: Non-empty list of normal graphs used for training.
        input_dim: Node feature dimension.
        num_epochs: Number of training epochs.
        seed: Seed applied to ``random``, ``numpy`` and ``torch``.

    Returns:
        Dict with keys ``teacher``, ``student``, ``decoder`` and ``weights_model``.

    Raises:
        ValueError: If ``normal_graphs`` is empty.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    num_graphs = len(normal_graphs)
    if num_graphs == 0:
        raise ValueError("No normal graphs provided for training")

    # Models are built in a fixed order so parameter initialisation is reproducible.
    teacher_model = GCNEncoder(input_dim=input_dim, num_layers=3)
    student_model = StudentGCN(input_dim=input_dim)
    teacher = TeacherWithCheckpoints(teacher_model)
    decoder = GCNDecoder(input_dim=64, output_dim=64)
    weights_model = LearningWeights(num_levels=3)

    student_params = list(student_model.parameters()) + list(decoder.parameters())
    optimizer_student = torch.optim.Adam(student_params, lr=0.01)
    optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=0.005)
    optimizer_weights = torch.optim.Adam(weights_model.parameters(), lr=0.001)

    # Seed checkpoints 0..2 from the first graphs (reusing the last one if fewer than 3).
    last_idx = num_graphs - 1
    for ckpt_id in range(3):
        clean_emb = teacher_model.encode_graph(normal_graphs[min(ckpt_id, last_idx)])
        teacher.save_checkpoint(ckpt_id, clean_emb)

    teacher_phase_end = num_epochs / 2
    for epoch in range(1, num_epochs + 1):
        update_teacher = epoch < teacher_phase_end
        epoch_loss = 0
        for graph in normal_graphs:
            step = train_recodistill(
                graph, teacher, student_model, decoder, weights_model,
                optimizer_student, optimizer_teacher, optimizer_weights,
                lambda_recon=0.1, update_teacher=update_teacher
            )
            epoch_loss += step['student_loss']

        if epoch % 5 == 0:
            print(f"Epoch {epoch:02d} | Avg. Loss: {epoch_loss / num_graphs:.4f}")

    return {
        'teacher': teacher,
        'student': student_model,
        'decoder': decoder,
        'weights_model': weights_model
    }


def evaluate_zero_shot(
    train_graphs: List[Data],
    test_graphs: List[Data],
    input_dim: int,
    num_epochs: int = 20,
    seed: int = 42
):
    """
    Train ReCoDistill on normal samples only and evaluate on unseen anomalies (zero-shot).

    ``train_graphs`` is filtered to graphs with ``y == 0``. If none remain, the
    first graph is used regardless of its label (a warning is printed).

    Args:
        train_graphs: Graphs for training; anomalous ones are filtered out.
        test_graphs: Test graphs containing both normal and anomalous samples.
        input_dim: Input dimension for the models.
        num_epochs: Number of training epochs.
        seed: Random seed for reproducibility.

    Returns:
        Dict of metrics from ``evaluate_anomaly_detection``.

    Raises:
        ValueError: If ``train_graphs`` is empty.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f"🔄 Running zero-shot evaluation (seed: {seed})...")
    print(f"Training on {len(train_graphs)} normal graphs, testing on {len(test_graphs)} graphs")

    # Keep only graphs labelled normal
    normal_train_graphs = [g for g in train_graphs if g.y.item() == 0]
    print(f"Using {len(normal_train_graphs)} confirmed normal graphs for training")

    if len(normal_train_graphs) == 0:
        if len(train_graphs) == 0:
            # Previously this fell through to train_graphs[0] and raised an opaque IndexError.
            raise ValueError("train_graphs is empty; zero-shot training needs at least one graph")
        print("⚠️ No normal graphs found in training set! Using first graph and ignoring its label.")
        normal_train_graphs = [train_graphs[0]]

    # Note: train_zero_shot_recodistill re-applies the same seed before building models.
    trained_models = train_zero_shot_recodistill(
        normal_train_graphs,
        input_dim=input_dim,
        num_epochs=num_epochs,
        seed=seed
    )

    print("🔍 Zero-shot evaluation on test set...")
    results = detect_anomalies_recodistill(
        test_graphs,
        trained_models['teacher'],
        trained_models['student'],
        decoder=trained_models['decoder'],
        weights_model=trained_models['weights_model'],
        threshold=None,
        return_scores=True
    )

    metrics = evaluate_anomaly_detection(results)

    print("\n📊 Zero-Shot Evaluation Results:")
    for level in ('node', 'edge', 'graph'):
        if f'{level}_auroc' not in metrics:
            continue
        name = level.capitalize()
        print(f"{name}-level AUROC: {metrics[level + '_auroc']:.4f}")
        print(f"{name}-level AUPRC: {metrics[level + '_auprc']:.4f}")

    return metrics


def zero_shot_experiment(datasets_config, num_trials=5):
   """
   Run zero-shot experiments across multiple datasets and trials.
   
   Args:
       datasets_config: Dictionary with dataset configurations
       num_trials: Number of trials to run with different seeds
   
   Returns:
       Dictionary with aggregated results
   """
   all_seeds = [42, 123, 456, 789, 101][:num_trials]
   all_results = {}
   
   for name, params in datasets_config.items():
       print(f"\n{'='*50}")
       print(f"🔧 Dataset: {name} (Zero-Shot Evaluation)")
       print(f"{'='*50}")
       
       # Initialize result containers for this dataset
       dataset_results = {
           'node_auroc': [], 'node_auprc': [], 'node_f1': [],
           'edge_auroc': [], 'edge_auprc': [], 'edge_f1': [],
           'graph_auroc': [], 'graph_auprc': [], 'graph_f1': []
       }
       
       for trial, seed in enumerate(all_seeds):
           print(f"\n🔄 Running zero-shot trial {trial+1}/{num_trials}...")
           
           # Generate dataset with the current seed
           graph_list = generate_synthetic_anomaly_graphs(dataset_name=f"{name}_zero_shot_trial{trial+1}", **params)
           
           # Important for zero-shot: Separate normal graphs for training
           normal_graphs = [g for g in graph_list if g.y.item() == 0]
           # Ensure that we have enough normal graphs
           if len(normal_graphs) < 10:
               print(f"⚠️ Only {len(normal_graphs)} normal graphs available. Using all for training.")
               train_size = len(normal_graphs)
           else:
               train_size = min(50, int(0.5 * len(normal_graphs)))
           
           train_graphs = normal_graphs[:train_size]
           
           # Test set includes both normal and anomalous graphs not used in training
           all_test_graphs = [g for g in graph_list if g not in train_graphs]
           test_size = min(len(all_test_graphs), 200)  # Limit test set size for efficiency
           test_graphs = all_test_graphs[:test_size]
           
           print(f"Training on {len(train_graphs)} normal graphs, testing on {len(test_graphs)} graphs")
           
           # Get input dimension
           input_dim = graph_list[0].x.shape[1]
           
           # Run zero-shot evaluation
           metrics = evaluate_zero_shot(
               train_graphs=train_graphs,
               test_graphs=test_graphs,
               input_dim=input_dim,
               num_epochs=20,
               seed=seed
           )
           
           # Save metrics from this trial
           for metric_name, value in metrics.items():
               if metric_name in dataset_results:
                   dataset_results[metric_name].append(value)
       
       # Calculate mean and std for each metric across trials
       mean_results = {}
       std_results = {}
       
       for metric_name, values in dataset_results.items():
           if values:  # Check if we have values for this metric
               mean_results[metric_name] = np.mean(values)
               std_results[metric_name] = np.std(values)
       
       # Store aggregated results for this dataset
       all_results[name] = {
           'mean': mean_results,
           'std': std_results,
           'trials': dataset_results
       }
       
       # Print dataset summary
       print(f"\n📊 Zero-Shot Summary for {name} (averaged over {num_trials} trials):")
       print(f"Node-level AUROC: {mean_results.get('node_auroc', float('nan')):.4f} ± {std_results.get('node_auroc', float('nan')):.4f}")
       print(f"Edge-level AUROC: {mean_results.get('edge_auroc', float('nan')):.4f} ± {std_results.get('edge_auroc', float('nan')):.4f}")
       print(f"Graph-level AUROC: {mean_results.get('graph_auroc', float('nan')):.4f} ± {std_results.get('graph_auroc', float('nan')):.4f}")
   
   # Print grand summary of zero-shot results
   print("\n" + "="*80)
   print("📊 SUMMARY OF ZERO-SHOT RESULTS ACROSS ALL DATASETS")
   print("="*80)
   print(f"{'Dataset':<10} | {'Node AUROC':<20} | {'Edge AUROC':<20} | {'Graph AUROC':<20}")
   print("-"*80)
   
   for name, results in all_results.items():
       mean_metrics = results['mean']
       std_metrics = results['std']
       
       node_result = f"{mean_metrics.get('node_auroc', float('nan')):.4f} ± {std_metrics.get('node_auroc', float('nan')):.4f}"
       edge_result = f"{mean_metrics.get('edge_auroc', float('nan')):.4f} ± {std_metrics.get('edge_auroc', float('nan')):.4f}"
       graph_result = f"{mean_metrics.get('graph_auroc', float('nan')):.4f} ± {std_metrics.get('graph_auroc', float('nan')):.4f}"
       
       print(f"{name:<10} | {node_result:<20} | {edge_result:<20} | {graph_result:<20}")
   
   print("="*80)
   
   # Save zero-shot results
   import json
   with open("prodigy_zero_shot_results.json", "w") as f:
       # Convert NumPy values to native Python types
       serializable_results = {}
       for dataset, results in all_results.items():
           serializable_results[dataset] = {
               'mean': {k: float(v) for k, v in results['mean'].items()},
               'std': {k: float(v) for k, v in results['std'].items()},
               'trials': {k: [float(val) for val in v] for k, v in results['trials'].items() if v}
           }
       json.dump(serializable_results, f, indent=2)
   
   print("\n✅ Zero-shot results saved to prodigy_zero_shot_results.json")
   
   return all_results


def _sync_cuda():
    """Wait for queued CUDA work so wall-clock timings cover it (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _time_inference_ms(encode_fn, graphs, device):
    """Return mean milliseconds per graph spent running ``encode_fn`` under ``torch.no_grad``."""
    _sync_cuda()
    start = time.perf_counter()
    with torch.no_grad():
        for graph in graphs:
            encode_fn(graph.to(device))
    _sync_cuda()
    return (time.perf_counter() - start) * 1000 / len(graphs)


def _rss_mb():
    """Resident set size of the current process in MiB (host memory only)."""
    return psutil.Process().memory_info().rss / (1024 * 1024)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``inf`` when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else float('inf')


def evaluate_model_performance(
    graph_list: List[Data],
    input_dim: int,
    num_epochs: int = 20,
    batch_size: int = 32,
    seed: int = 42
) -> Dict[str, Dict[str, float]]:
    """
    Benchmark training time, inference latency, memory and size of the ReCoDistill teacher vs. student.

    ``graph_list[0]`` is used both for training and for single-graph inference.
    The teacher is not optimized here: its "training time" is only the cost of
    building three clean checkpoints. The student (with decoder) is trained with
    ``train_recodistill`` for ``num_epochs`` epochs on that one graph.

    Memory figures are the change in host-process RSS (via ``psutil``) across
    each phase. They are not true peaks, and GPU allocations are not included.
    Timings use ``time.perf_counter`` with CUDA synchronization around them.

    Args:
        graph_list: Non-empty list of graphs to benchmark on.
        input_dim: Node feature dimension for the models.
        num_epochs: Number of student training epochs.
        batch_size: Number of graphs timed one-by-one for the per-sample
            "batch" latency. If ``graph_list`` is shorter than this, the
            single-graph latency is reused.
        seed: Seed for ``torch`` and ``numpy``.

    Returns:
        ``{'teacher': {...}, 'student': {...}, 'comparison': {...}}``. Each model
        entry has ``train_time_seconds``, ``inference_time_ms_single``,
        ``inference_time_ms_batch``, ``memory_usage_mb`` and ``parameter_count``.
        The student count includes the decoder. ``comparison`` holds
        teacher/student ratios, which are ``inf`` when the student value is not
        positive.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Start from a clean memory state
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    test_graph = graph_list[0].to(device)
    batch_graphs = graph_list[:batch_size] if len(graph_list) >= batch_size else None

    metrics = {
        'teacher': {},
        'student': {},
        'comparison': {}
    }

    # ----- Teacher Model Evaluation -----
    print("\n🔍 Evaluating Teacher Model Performance...")

    teacher_model = GCNEncoder(input_dim=input_dim, num_layers=3).to(device)
    teacher = TeacherWithCheckpoints(teacher_model)

    memory_before = _rss_mb()
    train_start = time.perf_counter()
    # Clean teacher checkpoints (progressive distillation)
    for t in [0, 1, 2]:
        emb_clean = teacher_model.encode_graph(test_graph)
        teacher.save_checkpoint(t, emb_clean)
    teacher_train_time = time.perf_counter() - train_start
    teacher_memory = _rss_mb() - memory_before

    teacher_inference_time_single = _time_inference_ms(teacher_model.encode_graph, [test_graph], device)
    if batch_graphs is not None:
        teacher_inference_time_batch = _time_inference_ms(teacher_model.encode_graph, batch_graphs, device)
    else:
        teacher_inference_time_batch = teacher_inference_time_single

    teacher_params = sum(p.numel() for p in teacher_model.parameters())

    metrics['teacher'] = {
        'train_time_seconds': teacher_train_time,
        'inference_time_ms_single': teacher_inference_time_single,
        'inference_time_ms_batch': teacher_inference_time_batch,
        'memory_usage_mb': teacher_memory,
        'parameter_count': teacher_params
    }

    # Free teacher-phase objects before measuring the student
    del teacher_model, teacher
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ----- Student Model Evaluation -----
    print("\n🔍 Evaluating Student Model Performance...")

    teacher_model = GCNEncoder(input_dim=input_dim, num_layers=3).to(device)
    teacher = TeacherWithCheckpoints(teacher_model)
    student_model = StudentGCN(input_dim=input_dim).to(device)
    decoder = GCNDecoder(input_dim=64, output_dim=64).to(device)
    weights_model = LearningWeights(num_levels=3).to(device)

    # Teacher checkpoints the student distills from
    for t in [0, 1, 2]:
        emb_clean = teacher_model.encode_graph(test_graph)
        teacher.save_checkpoint(t, emb_clean)

    optimizer_student = torch.optim.Adam(
        list(student_model.parameters()) + list(decoder.parameters()),
        lr=0.01
    )
    optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=0.005)
    optimizer_weights = torch.optim.Adam(weights_model.parameters(), lr=0.001)

    memory_before = _rss_mb()
    train_start = time.perf_counter()
    for epoch in range(1, num_epochs + 1):
        train_recodistill(
            test_graph, teacher, student_model, decoder, weights_model,
            optimizer_student, optimizer_teacher, optimizer_weights,
            lambda_recon=0.1, update_teacher=(epoch < num_epochs/2)
        )
    student_train_time = time.perf_counter() - train_start
    student_memory = _rss_mb() - memory_before

    student_inference_time_single = _time_inference_ms(student_model.encode_graph, [test_graph], device)
    if batch_graphs is not None:
        student_inference_time_batch = _time_inference_ms(student_model.encode_graph, batch_graphs, device)
    else:
        student_inference_time_batch = student_inference_time_single

    student_params = (
        sum(p.numel() for p in student_model.parameters())
        + sum(p.numel() for p in decoder.parameters())
    )

    metrics['student'] = {
        'train_time_seconds': student_train_time,
        'inference_time_ms_single': student_inference_time_single,
        'inference_time_ms_batch': student_inference_time_batch,
        'memory_usage_mb': student_memory,
        'parameter_count': student_params
    }

    # ----- Comparison Metrics -----
    # All ratios are guarded: a sub-resolution timing or non-positive RSS delta
    # yields inf instead of raising ZeroDivisionError.
    metrics['comparison'] = {
        'speed_up_train': _safe_ratio(teacher_train_time, student_train_time),
        'speed_up_inference': _safe_ratio(teacher_inference_time_single, student_inference_time_single),
        'memory_reduction': _safe_ratio(teacher_memory, student_memory),
        'parameter_reduction': _safe_ratio(teacher_params, student_params)
    }

    print("\n📊 Model Performance Comparison:")
    print(f"{'Metric':<25} | {'Teacher':<15} | {'Student':<15} | {'Improvement':<15}")
    print("-" * 75)

    print(f"{'Training time (s)':<25} | {teacher_train_time:<15.4f} | {student_train_time:<15.4f} | {metrics['comparison']['speed_up_train']:.2f}x")
    print(f"{'Inference time (ms)':<25} | {teacher_inference_time_single:<15.4f} | {student_inference_time_single:<15.4f} | {metrics['comparison']['speed_up_inference']:.2f}x")
    print(f"{'Memory usage (MB)':<25} | {teacher_memory:<15.4f} | {student_memory:<15.4f} | {metrics['comparison']['memory_reduction']:.2f}x")
    print(f"{'Parameter count':<25} | {teacher_params:<15,d} | {student_params:<15,d} | {metrics['comparison']['parameter_reduction']:.2f}x")

    return metrics


def run_performance_evaluation(graph_list, input_dim, trials=3):
    """
    Run multiple trials of performance evaluation and report average metrics.

    Trial ``i`` calls ``evaluate_model_performance`` with ``seed=42 + i``. Every
    entry of the returned 'teacher', 'student' and 'comparison' sub-dicts is then
    averaged arithmetically across trials. Consequently the 'comparison' ratios
    are the mean of the per-trial ratios, not the ratio of the averaged values.

    Args:
        graph_list: List of graphs for evaluation
        input_dim: Input dimension for the models
        trials: Number of evaluation trials (must be >= 1)

    Returns:
        Dictionary with averaged performance metrics (keys 'teacher', 'student',
        'comparison'). It is also written to prodigy_performance_metrics.json in
        the current working directory.

    Raises:
        ValueError: if ``trials`` is smaller than 1 (nothing to average).
    """
    import json

    # Without this guard, trials=0 would surface as an IndexError on all_metrics[0].
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")

    all_metrics = []

    for trial in range(trials):
        print(f"\n{'='*50}")
        print(f"Running performance evaluation trial {trial+1}/{trials}")
        print(f"{'='*50}")

        metrics = evaluate_model_performance(
            graph_list=graph_list,
            input_dim=input_dim,
            seed=42+trial
        )

        all_metrics.append(metrics)

    # Average every metric of every group across trials. The float() coercion
    # keeps json.dump happy even if a metric came back as a NumPy scalar.
    avg_metrics = {
        group: {
            key: float(sum(m[group][key] for m in all_metrics) / trials)
            for key in all_metrics[0][group]
        }
        for group in ('teacher', 'student', 'comparison')
    }
    teacher_avg = avg_metrics['teacher']
    student_avg = avg_metrics['student']
    comparison_avg = avg_metrics['comparison']

    # Print final averaged results
    print("\n" + "="*80)
    print("📊 FINAL PERFORMANCE METRICS (Averaged over multiple trials)")
    print("="*80)

    print(f"{'Metric':<25} | {'Teacher':<15} | {'Student':<15} | {'Improvement':<15}")
    print("-" * 75)

    print(f"{'Training time (s)':<25} | {teacher_avg['train_time_seconds']:<15.4f} | {student_avg['train_time_seconds']:<15.4f} | {comparison_avg['speed_up_train']:.2f}x")
    print(f"{'Inference time (ms)':<25} | {teacher_avg['inference_time_ms_single']:<15.4f} | {student_avg['inference_time_ms_single']:<15.4f} | {comparison_avg['speed_up_inference']:.2f}x")
    print(f"{'Memory usage (MB)':<25} | {teacher_avg['memory_usage_mb']:<15.4f} | {student_avg['memory_usage_mb']:<15.4f} | {comparison_avg['memory_reduction']:.2f}x")

    # Averaged parameter counts are floats, hence ",.0f" rather than ",d".
    print(f"{'Parameter count':<25} | {teacher_avg['parameter_count']:<15,.0f} | {student_avg['parameter_count']:<15,.0f} | {comparison_avg['parameter_reduction']:.2f}x")

    # Save the performance metrics to a file
    with open("prodigy_performance_metrics.json", "w", encoding="utf-8") as f:
        json.dump(avg_metrics, f, indent=2)

    print("\n✅ Performance metrics saved to prodigy_performance_metrics.json")

    return avg_metrics


def to_json_serializable(obj):
    """Recursively convert NumPy values inside ``obj`` into JSON-native Python types.

    - Dicts are rebuilt with converted values; keys are left untouched.
    - Lists and tuples become lists.
    - ndarrays become nested lists via ``tolist()``.
    - NumPy scalars (any ``np.generic``: ints, floats, bools, ...) are unwrapped
      with ``.item()``.
    - Everything else is returned unchanged.

    ``np.generic`` is used instead of concrete aliases such as ``np.float_``,
    which no longer exists in NumPy >= 2.0.
    """
    if isinstance(obj, dict):
        return {k: to_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def main_with_comprehensive_evaluation():
    """Run the standard, zero-shot and performance evaluations and compare them.

    Steps:
    1. ``main()`` runs the standard per-dataset evaluation.
    2. ``zero_shot_experiment`` runs on the same configs.
    3. ``run_performance_evaluation`` runs on a synthetic "MUTAG"-config dataset.
    4. The graph-level AUROC of the standard and zero-shot results is compared.

    Everything is written to prodigy_comprehensive_results.json.

    Returns:
        Dict with keys 'standard_evaluation', 'zero_shot_evaluation',
        'performance_metrics' and 'comparative_analysis'.
    """
    import json

    configs = get_prdigy_dataset_configs()
    comprehensive_results = {
        "standard_evaluation": {},
        "zero_shot_evaluation": {},
        "performance_metrics": {},
        "comparative_analysis": {}
    }

    # Run standard evaluation
    print("\n🔄 Running standard evaluation...")
    standard_results = main()
    comprehensive_results["standard_evaluation"] = standard_results

    # Run zero-shot evaluation
    print("\n🔄 Running zero-shot evaluation...")
    zero_shot_results = zero_shot_experiment(configs)
    comprehensive_results["zero_shot_evaluation"] = zero_shot_results

    # Run performance evaluation on a representative dataset
    print("\n🔄 Running performance evaluation...")
    dataset_name = "MUTAG"  # Representative dataset
    params = configs[dataset_name]
    graph_list = generate_synthetic_anomaly_graphs(dataset_name=dataset_name, **params)
    input_dim = graph_list[0].x.shape[1]
    performance_metrics = run_performance_evaluation(graph_list=graph_list, input_dim=input_dim, trials=3)
    comprehensive_results["performance_metrics"] = performance_metrics

    # Compile comparative analysis
    print("\n" + "="*80)
    print("📊 COMPARISON: STANDARD VS. ZERO-SHOT GRAPH-LEVEL AUROC")
    print("="*80)
    print(f"{'Dataset':<10} | {'Standard AUROC':<20} | {'Zero-Shot AUROC':<20} | {'Difference':<10}")
    print("-"*80)

    comparative_analysis = {}
    for name in configs:
        if name not in standard_results or name not in zero_shot_results:
            continue
        std_auroc = standard_results[name]['mean'].get('graph_auroc', float('nan'))
        zs_auroc = zero_shot_results[name]['mean'].get('graph_auroc', float('nan'))
        difference = std_auroc - zs_auroc

        comparative_analysis[name] = {
            "standard_auroc": float(std_auroc),
            "zero_shot_auroc": float(zs_auroc),
            "difference": float(difference),
            "relative_difference_percent": float(difference / std_auroc * 100) if std_auroc != 0 else float('nan')
        }

        # Pad each value to the header's column width so the table lines up.
        print(f"{name:<10} | {std_auroc:<20.4f} | {zs_auroc:<20.4f} | {difference:<10.4f}")

    comprehensive_results["comparative_analysis"] = comparative_analysis
    print("="*80)

    # Convert before opening the file so a conversion error cannot leave a
    # truncated JSON file behind.
    serializable_results = to_json_serializable(comprehensive_results)
    with open("prodigy_comprehensive_results.json", "w", encoding="utf-8") as f:
        json.dump(serializable_results, f, indent=2)

    print("\n✅ Comprehensive results saved to prodigy_comprehensive_results.json")

    return comprehensive_results


def _run_recodistill_trial(name, params, trial, seed):
    """Generate one dataset instance, train ReCoDistill on it and evaluate it.

    Args:
        name: Dataset name; used for the generated dataset and the checkpoint file name.
        params: Keyword arguments forwarded to generate_synthetic_anomaly_graphs.
        trial: Zero-based trial index.
        seed: Seed applied to the ``random``, NumPy and torch RNGs for this trial.

    Returns:
        The metrics dict returned by test_anomaly_detection on the held-out graphs.
        The trained models are also saved to ./models/.

    Raises:
        ValueError: if the generator returns no graphs.
    """
    graph_list = generate_synthetic_anomaly_graphs(dataset_name=f"{name}_trial{trial+1}", **params)
    if not graph_list:
        raise ValueError(f"Dataset {name} produced no graphs")

    # generate_synthetic_anomaly_graphs re-seeds random / np.random / torch with a
    # fixed value internally. Seeding before calling it (as this code used to)
    # was therefore overwritten, and every trial started model init and training
    # from the same RNG state. Seed afterwards so the trial seed takes effect.
    # NOTE(review): for the same reason, the generated graphs are identical
    # across trials; only model init and training vary.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Train on up to 50 graphs (20% of the data), always at least one, since the
    # student-teacher stage below needs a training graph. The rest is the test set.
    train_size = max(1, min(50, int(0.2 * len(graph_list))))
    train_graphs = graph_list[:train_size]
    test_graphs = graph_list[train_size:]

    print(f"Dataset split: {train_size} graphs for training, {len(test_graphs)} graphs for testing")

    # Student-teacher training only uses the first training graph.
    original = train_graphs[0]
    input_dim = original.x.shape[1]

    # Initialize models for bidirectional ReCoDistill
    teacher_model = GCNEncoder(input_dim=input_dim, num_layers=3)
    student_model = StudentGCN(input_dim=input_dim)
    decoder = GCNDecoder(input_dim=64, output_dim=64)
    weights_model = LearningWeights(num_levels=3)
    teacher = TeacherWithCheckpoints(teacher_model)

    # Initialize optimizers
    optimizer_student = torch.optim.Adam(
        list(student_model.parameters()) + list(decoder.parameters()),
        lr=0.01
    )
    optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=0.005)
    optimizer_weights = torch.optim.Adam(weights_model.parameters(), lr=0.001)

    # Save clean teacher checkpoints (progressive distillation)
    for t in [0, 1, 2]:
        emb_clean = teacher_model.encode_graph(original)
        teacher.save_checkpoint(t, emb_clean)

    print("🚀 Training with bidirectional reverse contrastive loss...")
    for epoch in range(1, 21):
        train_metrics = train_recodistill(
            original, teacher, student_model, decoder, weights_model,
            optimizer_student, optimizer_teacher, optimizer_weights,
            lambda_recon=0.1, update_teacher=(epoch < 10)
        )
        if epoch % 5 == 0:  # Print only every 5 epochs to reduce output
            print(f"Epoch {epoch:02d} | Student Loss: {train_metrics['student_loss']:.4f} | Teacher Loss: {train_metrics['teacher_loss']:.4f}")

    # Run anomaly detection on test set
    print("🔍 Running anomaly detection on test set...")
    _, test_metrics = test_anomaly_detection(
        test_graphs, teacher, student_model,
        decoder=decoder, weights_model=weights_model
    )

    # Save the trained models for this trial
    os.makedirs("./models", exist_ok=True)
    torch.save({
        'teacher': teacher_model.state_dict(),
        'student': student_model.state_dict(),
        'decoder': decoder.state_dict(),
        'weights_model': weights_model.state_dict(),
        'teacher_checkpoints': teacher.checkpoints
    }, f"./models/recodistill_{name}_trial{trial+1}_models.pt")

    return test_metrics


def main():
    """Run ReCoDistill on every configured dataset over several seeded trials.

    For each dataset in get_prdigy_dataset_configs(), one trial is run per seed
    in ``all_seeds`` (via _run_recodistill_trial). Node-, edge- and graph-level
    AUROC, AUPRC and F1 are collected; their mean and population std
    (``np.std``, ddof=0) across trials are printed. Everything is written to
    recodistill_results_summary.json.

    Returns:
        Dict mapping dataset name to {'mean': {...}, 'std': {...}, 'trials': {...}}.
        'trials' holds the per-trial values of every tracked metric.
    """
    import json

    configs = get_prdigy_dataset_configs()
    all_seeds = [42, 123, 456, 789, 101]  # One seed per trial
    num_trials = len(all_seeds)  # Tied to the seed list so the two cannot drift apart
    tracked_metrics = (
        'node_auroc', 'node_auprc', 'node_f1',
        'edge_auroc', 'edge_auprc', 'edge_f1',
        'graph_auroc', 'graph_auprc', 'graph_f1',
    )

    # Dictionary to store results across all datasets and trials
    all_results = {}

    for name, params in configs.items():
        print(f"\n{'='*50}")
        print(f"🔧 Dataset: {name}")
        print(f"{'='*50}")

        dataset_results = {metric_name: [] for metric_name in tracked_metrics}

        for trial, seed in enumerate(all_seeds):
            print(f"\n🔄 Running trial {trial+1}/{num_trials}...")
            test_metrics = _run_recodistill_trial(name, params, trial, seed)

            # Keep only the metrics we aggregate; ignore anything else returned.
            for metric_name, value in test_metrics.items():
                if metric_name in dataset_results:
                    dataset_results[metric_name].append(value)

        # Mean / std across trials, only for metrics that produced values.
        mean_results = {k: np.mean(v) for k, v in dataset_results.items() if v}
        std_results = {k: np.std(v) for k, v in dataset_results.items() if v}

        all_results[name] = {
            'mean': mean_results,
            'std': std_results,
            'trials': dataset_results
        }

        # Print dataset summary
        print(f"\n📊 Summary for {name} (averaged over {num_trials} trials):")
        print(f"Node-level AUROC: {mean_results.get('node_auroc', float('nan')):.4f} ± {std_results.get('node_auroc', float('nan')):.4f}")
        print(f"Edge-level AUROC: {mean_results.get('edge_auroc', float('nan')):.4f} ± {std_results.get('edge_auroc', float('nan')):.4f}")
        print(f"Graph-level AUROC: {mean_results.get('graph_auroc', float('nan')):.4f} ± {std_results.get('graph_auroc', float('nan')):.4f}")

    # Print grand summary across all datasets
    print("\n" + "="*80)
    print("📊 SUMMARY OF RESULTS ACROSS ALL DATASETS")
    print("="*80)
    print(f"{'Dataset':<10} | {'Node AUROC':<20} | {'Edge AUROC':<20} | {'Graph AUROC':<20}")
    print("-"*80)

    for name, results in all_results.items():
        mean_metrics = results['mean']
        std_metrics = results['std']

        node_result = f"{mean_metrics.get('node_auroc', float('nan')):.4f} ± {std_metrics.get('node_auroc', float('nan')):.4f}"
        edge_result = f"{mean_metrics.get('edge_auroc', float('nan')):.4f} ± {std_metrics.get('edge_auroc', float('nan')):.4f}"
        graph_result = f"{mean_metrics.get('graph_auroc', float('nan')):.4f} ± {std_metrics.get('graph_auroc', float('nan')):.4f}"

        print(f"{name:<10} | {node_result:<20} | {edge_result:<20} | {graph_result:<20}")

    print("="*80)

    # Convert NumPy values to native Python types before writing the JSON file.
    serializable_results = {
        dataset: {
            'mean': {k: float(v) for k, v in results['mean'].items()},
            'std': {k: float(v) for k, v in results['std'].items()},
            'trials': {k: [float(val) for val in v] for k, v in results['trials'].items() if v}
        }
        for dataset, results in all_results.items()
    }
    with open("recodistill_results_summary.json", "w", encoding="utf-8") as f:
        json.dump(serializable_results, f, indent=2)

    print("\n✅ Results saved to recodistill_results_summary.json")
    return all_results


def main_with_zero_shot():
    """Run the standard and the zero-shot evaluations and print their graph AUROC side by side.

    Only datasets present in both result dicts are compared. A missing
    'graph_auroc' entry is shown as nan.

    Returns:
        Dict with keys 'standard_evaluation' and 'zero_shot_evaluation' holding
        the raw results of ``main()`` and ``zero_shot_experiment(configs)``.
        Callers that ignored the previous ``None`` return are unaffected.
    """
    configs = get_prdigy_dataset_configs()

    # First run standard evaluation
    print("\n🔄 Running standard evaluation...")
    standard_results = main()

    # Then run zero-shot evaluation
    print("\n🔄 Running zero-shot evaluation...")
    zero_shot_results = zero_shot_experiment(configs)

    # Compare results
    print("\n" + "="*80)
    print("📊 COMPARISON: STANDARD VS. ZERO-SHOT GRAPH-LEVEL AUROC")
    print("="*80)
    print(f"{'Dataset':<10} | {'Standard AUROC':<20} | {'Zero-Shot AUROC':<20}")
    print("-"*80)

    for name in configs:
        if name not in standard_results or name not in zero_shot_results:
            continue
        std_auroc = standard_results[name]['mean'].get('graph_auroc', float('nan'))
        zs_auroc = zero_shot_results[name]['mean'].get('graph_auroc', float('nan'))

        # Pad to the header's column widths so the table lines up.
        print(f"{name:<10} | {std_auroc:<20.4f} | {zs_auroc:<20.4f}")

    print("="*80)

    return {
        "standard_evaluation": standard_results,
        "zero_shot_evaluation": zero_shot_results,
    }


if __name__ == "__main__":
    # Make sure the checkpoint directory exists before any experiment runs;
    # `os` is already imported at the top of the module.
    os.makedirs("./models", exist_ok=True)
    main_with_comprehensive_evaluation()