import os
import pandas as pd
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import time
import math
import matplotlib.pyplot as plt

#==================================================================================================
# Data Processing Functions
#==================================================================================================
def process_data(folder_path):
    """Load the rating table and attach (encoded) genre labels.

    Looks for four Excel files inside ``folder_path``. Rows of every file that
    carries a ``rating`` column are accumulated into the ground-truth rating
    table; ``movie_genres.xlsx`` provides a movieID -> genreID lookup that is
    mapped onto the ratings and then label-encoded for classification.
    A message is printed for every file that is missing.

    Args:
        folder_path: Directory that contains the ``*.xlsx`` input files.

    Returns:
        Tuple ``(ground_truth_ratings, label_encoder)``.
        ``ground_truth_ratings`` is a DataFrame, or ``None`` when no rating
        file was found. ``label_encoder`` is the fitted ``LabelEncoder``, or
        ``None`` when the ratings or the genre mapping are unavailable and no
        encoding took place.
    """
    file_columns = {
        'user_movies.xlsx': ['userID', 'movieID', 'rating'],
        'movie_directors.xlsx': ['movieID', 'directorID'],
        'movie_actors.xlsx': ['movieID', 'actorID'],
        'movie_genres.xlsx': ['movieID', 'genreID']
    }

    ground_truth_ratings = None
    # Bound up front: the encoder is only created when both the ratings and
    # the genre mapping exist, and the function must still be able to return.
    label_encoder = None
    genre_mapping = {}

    for file_name, columns in file_columns.items():
        file_path = os.path.join(folder_path, file_name)
        if not os.path.exists(file_path):
            print(f"File not found: {file_name}")
            continue

        df = pd.read_excel(file_path, usecols=columns)
        if file_name == 'movie_genres.xlsx':
            # With duplicate movieIDs, to_dict() keeps the last row's genre.
            genre_mapping = df.set_index('movieID')['genreID'].to_dict()
        elif 'rating' in columns:
            if ground_truth_ratings is None:
                ground_truth_ratings = df
            else:
                ground_truth_ratings = pd.concat([ground_truth_ratings, df], ignore_index=True)

    if ground_truth_ratings is not None and genre_mapping:
        ground_truth_ratings['genreID'] = ground_truth_ratings['movieID'].map(genre_mapping)

        # Encode genreID to integer labels for classification
        label_encoder = LabelEncoder()
        ground_truth_ratings['genreID_encoded'] = label_encoder.fit_transform(ground_truth_ratings['genreID'])

    return ground_truth_ratings, label_encoder

def create_heterogeneous_graph(folder_path):
    """Build one undirected graph from the user/director/actor relation files.

    Every row of a relation file contributes two nodes named
    ``"<column>:<value>"`` (each tagged with a ``type`` attribute equal to the
    column name) and an edge between them. User-movie edges additionally carry
    the row's rating as the ``weight`` attribute. Files that do not exist in
    ``folder_path`` are silently skipped.

    Args:
        folder_path: Directory that contains the ``*.xlsx`` relation files.

    Returns:
        The populated ``networkx.Graph`` (empty when no file was found).
    """
    # (file name, first endpoint column, second endpoint column, weight column)
    relations = (
        ('user_movies.xlsx', 'userID', 'movieID', 'rating'),
        ('movie_directors.xlsx', 'movieID', 'directorID', None),
        ('movie_actors.xlsx', 'movieID', 'actorID', None),
    )

    G = nx.Graph()

    for file_name, first_col, second_col, weight_col in relations:
        file_path = os.path.join(folder_path, file_name)
        if not os.path.exists(file_path):
            continue

        columns = [first_col, second_col]
        if weight_col is not None:
            columns.append(weight_col)
        df = pd.read_excel(file_path, usecols=columns)

        # NOTE(review): iterrows() yields each row as a single-dtype Series, so
        # integer IDs sitting next to a float rating column are formatted as
        # e.g. "movieID:5.0" while the director/actor files give "movieID:5".
        # Kept as-is because node names are part of the output - confirm
        # whether the rating column is integral in the real data.
        for _, row in df.iterrows():
            first_node = f"{first_col}:{row[first_col]}"
            second_node = f"{second_col}:{row[second_col]}"

            # Re-adding an existing node only re-applies the same `type`.
            G.add_node(first_node, type=first_col)
            G.add_node(second_node, type=second_col)

            if weight_col is None:
                G.add_edge(first_node, second_node)
            else:
                G.add_edge(first_node, second_node, weight=row[weight_col])

    return G

#==================================================================================================
# Hypergraph Generation Functions
#==================================================================================================
def hypergraph_MU(folder_path):
    """Build the movie-user hypergraph from ``user_movies.xlsx``.

    Each movie becomes one hyperedge whose members are the users that appear
    with it in the file (one entry per row, so repeated rows repeat the user).
    A short summary is printed before returning.

    Args:
        folder_path: Directory that contains ``user_movies.xlsx``.

    Returns:
        Tuple ``(hyper_MU, att_MU)``: ``hyper_MU`` maps ``"movieID:<id>"`` to
        the list of ``"userID:<id>"`` members; ``att_MU`` maps every node name
        to ``{'type': 'userID'}`` or ``{'type': 'movieID'}``. Both are empty
        when the file is missing.
    """
    hyper_MU = {}
    att_MU = {}

    file_path = os.path.join(folder_path, 'user_movies.xlsx')
    if os.path.exists(file_path):
        # 'rating' stays in usecols even though it is not used here: the
        # columns that are read determine the dtype iterrows() hands back,
        # and therefore how the IDs are formatted into node names.
        df = pd.read_excel(file_path, usecols=['userID', 'movieID', 'rating'])

        for _, row in df.iterrows():
            movie_node = f"movieID:{row['movieID']}"
            user_node = f"userID:{row['userID']}"

            hyper_MU.setdefault(movie_node, []).append(user_node)

            att_MU[user_node] = {'type': 'userID'}
            att_MU[movie_node] = {'type': 'movieID'}

    num_edges = sum(len(nodes) for nodes in hyper_MU.values())

    print("Hypergraph information of MU:")
    print("Number of hyperedges of MU (nodes):", len(hyper_MU))
    print("Number of edges of MU:", num_edges)

    return hyper_MU, att_MU

def hypergraph_MD(folder_path):
    """Build the movie-director hypergraph from ``movie_directors.xlsx``.

    Each movie becomes one hyperedge whose members are the directors listed
    with it in the file (one entry per row). A short summary is printed before
    returning.

    Args:
        folder_path: Directory that contains ``movie_directors.xlsx``.

    Returns:
        Tuple ``(hyper_MD, att_MD)``: ``hyper_MD`` maps ``"movieID:<id>"`` to
        the list of ``"directorID:<id>"`` members; ``att_MD`` maps every node
        name to ``{'type': 'directorID'}`` or ``{'type': 'movieID'}``. Both are
        empty when the file is missing.
    """
    hyper_MD = {}
    att_MD = {}

    file_path = os.path.join(folder_path, 'movie_directors.xlsx')
    if os.path.exists(file_path):
        df = pd.read_excel(file_path, usecols=['movieID', 'directorID'])

        for _, row in df.iterrows():
            movie_node = f"movieID:{row['movieID']}"
            director_node = f"directorID:{row['directorID']}"

            hyper_MD.setdefault(movie_node, []).append(director_node)

            att_MD[director_node] = {'type': 'directorID'}
            att_MD[movie_node] = {'type': 'movieID'}

    num_edges = sum(len(nodes) for nodes in hyper_MD.values())

    print("Hypergraph information of MD:")
    print("Number of hyperedges of MD (nodes):", len(hyper_MD))
    print("Number of edges of MD:", num_edges)

    return hyper_MD, att_MD

def hypergraph_MA(folder_path):
    """Build the movie-actor hypergraph from ``movie_actors.xlsx``.

    Each movie becomes one hyperedge whose members are the actors listed with
    it in the file (one entry per row). A short summary is printed before
    returning.

    Args:
        folder_path: Directory that contains ``movie_actors.xlsx``.

    Returns:
        Tuple ``(hyper_MA, att_MA)``: ``hyper_MA`` maps ``"movieID:<id>"`` to
        the list of ``"actorID:<id>"`` members; ``att_MA`` maps every node name
        to ``{'type': 'actorID'}`` or ``{'type': 'movieID'}``. Both are empty
        when the file is missing.
    """
    hyper_MA = {}
    att_MA = {}

    file_path = os.path.join(folder_path, 'movie_actors.xlsx')
    if os.path.exists(file_path):
        df = pd.read_excel(file_path, usecols=['movieID', 'actorID'])

        for _, row in df.iterrows():
            movie_node = f"movieID:{row['movieID']}"
            actor_node = f"actorID:{row['actorID']}"

            hyper_MA.setdefault(movie_node, []).append(actor_node)

            att_MA[actor_node] = {'type': 'actorID'}
            att_MA[movie_node] = {'type': 'movieID'}

    num_edges = sum(len(nodes) for nodes in hyper_MA.values())

    print("Hypergraph information of MA:")
    print("Number of hyperedges of MA (nodes):", len(hyper_MA))
    print("Number of edges of MA:", num_edges)

    return hyper_MA, att_MA

#==================================================================================================
# Incidence Matrix Generation
#==================================================================================================
def generate_incidence_matrices_MU(hyper_MU, att_MU):
    """Turn the movie-user hypergraph into a dense user x movie 0/1 matrix.

    Args:
        hyper_MU: Mapping of movie node -> list of member user nodes.
        att_MU: Mapping of node name -> attribute dict with a ``'type'`` key.

    Returns:
        Tuple ``(incidence_matrix_MU, user_index_map, movie_index_map)``.
        Rows are users, columns are movies; both are numbered in the order in
        which the nodes appear in ``att_MU``.
    """
    # Number the nodes of each type in att_MU iteration order.
    user_index_map = {}
    movie_index_map = {}
    for node, attrs in att_MU.items():
        node_type = attrs['type']
        if node_type == 'movieID':
            movie_index_map[node] = len(movie_index_map)
        elif node_type == 'userID':
            user_index_map[node] = len(user_index_map)

    incidence_matrix_MU = np.zeros((len(user_index_map), len(movie_index_map)), dtype=float)

    # Mark every (user, movie) membership; unknown nodes are ignored.
    for movie_node, members in hyper_MU.items():
        column = movie_index_map.get(movie_node)
        if column is None:
            continue
        rows = [user_index_map[member] for member in members if member in user_index_map]
        incidence_matrix_MU[rows, column] = 1

    print("incidence_matrix_MU shape:", incidence_matrix_MU.shape)
    return incidence_matrix_MU, user_index_map, movie_index_map

def generate_incidence_matrices_MD(hyper_MD, att_MD):
    """Turn the movie-director hypergraph into a dense director x movie 0/1 matrix.

    Args:
        hyper_MD: Mapping of movie node -> list of member director nodes.
        att_MD: Mapping of node name -> attribute dict with a ``'type'`` key.

    Returns:
        Tuple ``(incidence_matrix_MD, director_index_map, movie_index_map)``.
        Rows are directors, columns are movies; both are numbered in the order
        in which the nodes appear in ``att_MD``.
    """
    # Number the nodes of each type in att_MD iteration order.
    director_index_map = {}
    movie_index_map = {}
    for node, attrs in att_MD.items():
        node_type = attrs['type']
        if node_type == 'movieID':
            movie_index_map[node] = len(movie_index_map)
        elif node_type == 'directorID':
            director_index_map[node] = len(director_index_map)

    incidence_matrix_MD = np.zeros((len(director_index_map), len(movie_index_map)), dtype=float)

    # Mark every (director, movie) membership; unknown nodes are ignored.
    for movie_node, members in hyper_MD.items():
        column = movie_index_map.get(movie_node)
        if column is None:
            continue
        rows = [director_index_map[member] for member in members if member in director_index_map]
        incidence_matrix_MD[rows, column] = 1

    print("incidence_matrix_MD shape:", incidence_matrix_MD.shape)
    return incidence_matrix_MD, director_index_map, movie_index_map

def generate_incidence_matrices_MA(hyper_MA, att_MA):
    """Turn the movie-actor hypergraph into a dense actor x movie 0/1 matrix.

    Args:
        hyper_MA: Mapping of movie node -> list of member actor nodes.
        att_MA: Mapping of node name -> attribute dict with a ``'type'`` key.

    Returns:
        Tuple ``(incidence_matrix_MA, actor_index_map, movie_index_map)``.
        Rows are actors, columns are movies; both are numbered in the order in
        which the nodes appear in ``att_MA``.
    """
    # Number the nodes of each type in att_MA iteration order.
    actor_index_map = {}
    movie_index_map = {}
    for node, attrs in att_MA.items():
        node_type = attrs['type']
        if node_type == 'movieID':
            movie_index_map[node] = len(movie_index_map)
        elif node_type == 'actorID':
            actor_index_map[node] = len(actor_index_map)

    incidence_matrix_MA = np.zeros((len(actor_index_map), len(movie_index_map)), dtype=float)

    # Mark every (actor, movie) membership; unknown nodes are ignored.
    for movie_node, members in hyper_MA.items():
        column = movie_index_map.get(movie_node)
        if column is None:
            continue
        rows = [actor_index_map[member] for member in members if member in actor_index_map]
        incidence_matrix_MA[rows, column] = 1

    print("incidence_matrix_MA shape:", incidence_matrix_MA.shape)
    return incidence_matrix_MA, actor_index_map, movie_index_map

#==================================================================================================
# Utility Functions for Matrix Operations
#==================================================================================================
def pad_matrix(matrix, target_rows, target_cols):
    """Zero-pad a 2-D matrix on the bottom/right up to the target size.

    Non-tensor input is converted to a float32 tensor; a tensor that already
    meets the target size is returned unchanged. The matrix is never cropped.
    """
    n_rows, n_cols = matrix.shape
    extra_rows = max(0, target_rows - n_rows)
    extra_cols = max(0, target_cols - n_cols)

    if isinstance(matrix, torch.Tensor):
        tensor = matrix
    else:
        tensor = torch.tensor(matrix, dtype=torch.float32)

    if extra_rows == 0 and extra_cols == 0:
        return tensor

    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(tensor, (0, extra_cols, 0, extra_rows), "constant", 0)

def compute_degree_matrices(incidence_matrix):
    """Return the diagonal vertex-degree and hyperedge-degree matrices.

    Vertex degrees are the row sums of the incidence matrix and hyperedge
    degrees are its column sums. Non-tensor input is converted to float32 and
    a leading singleton batch dimension on 3-D input is squeezed away.

    Returns:
        Tuple ``(D_v, D_e)`` of square diagonal matrices.
    """
    H = incidence_matrix
    if not isinstance(H, torch.Tensor):
        H = torch.tensor(H, dtype=torch.float32)
    if H.dim() == 3:
        H = H.squeeze(0)

    def _to_diagonal(degrees, fallback_size):
        # Empty degree vectors fall back to an identity of the matching size.
        if degrees.numel() > 0:
            result = torch.diag(degrees)
        else:
            result = torch.eye(fallback_size)
        # torch.diag on a 2-D input returns a vector; lift it back to a matrix.
        if len(result.shape) == 1:
            result = torch.diag(result)
        return result

    D_v = _to_diagonal(torch.sum(H, dim=1), H.size(0))
    D_e = _to_diagonal(torch.sum(H, dim=0), H.size(1))
    return D_v, D_e

def compute_laplacian_matrix(incidence_matrix):
    """Compute the normalized hypergraph Laplacian.

    ``L = I - D_v^(-1/2) * H * D_e^(-1) * H^T * D_v^(-1/2)`` where ``H`` is the
    incidence matrix (rows = vertices, columns = hyperedges), ``D_v`` holds the
    row sums and ``D_e`` the column sums of ``H``.

    Args:
        incidence_matrix: 2-D array-like or tensor. Non-tensor input is
            converted to float32; a leading singleton batch dimension on 3-D
            input is squeezed away.

    Returns:
        Square tensor with one row/column per vertex, on the same device as
        the incidence matrix. Vertices with degree 0 get a 1 on the diagonal
        and zeros elsewhere.
    """
    if not isinstance(incidence_matrix, torch.Tensor):
        incidence_matrix = torch.tensor(incidence_matrix, dtype=torch.float32)

    # Accept the same (1, V, E) layout that compute_degree_matrices accepts.
    if incidence_matrix.dim() == 3:
        incidence_matrix = incidence_matrix.squeeze(0)

    H = incidence_matrix
    node_degrees = torch.sum(H, dim=1)
    hyperedge_degrees = torch.sum(H, dim=0)

    # Clamp before inverting so zero degrees give a large finite factor; it
    # only ever multiplies the all-zero row/column that produced it.
    inv_sqrt_dv = 1.0 / torch.sqrt(torch.clamp(node_degrees, min=1e-12))
    inv_de = 1.0 / torch.clamp(hyperedge_degrees, min=1e-12)

    # Multiplying by a diagonal matrix is a row/column scaling, so broadcast
    # instead of materialising dense V x V and E x E diagonal matrices.
    left = inv_sqrt_dv.unsqueeze(1) * H * inv_de.unsqueeze(0)
    right = H.t() * inv_sqrt_dv.unsqueeze(0)
    theta = torch.mm(left, right)

    # Build the identity where theta lives; a default-device torch.eye would
    # fail for a GPU-resident incidence matrix.
    identity = torch.eye(H.size(0), dtype=theta.dtype, device=theta.device)
    return identity - theta

def pad_to_match_shape(matrix, target_shape):
    """Zero-pad the last two dimensions of a tensor up to ``target_shape``.

    Padding is added at the bottom and right only; dimensions that already
    meet or exceed the target are left alone (nothing is cropped). Leading
    batch dimensions are kept, and batched results are viewed as
    ``(*batch, target_rows, target_cols)``.
    """
    *leading_dims, n_rows, n_cols = matrix.shape
    want_rows, want_cols = target_shape

    extra_rows = max(0, want_rows - n_rows)
    extra_cols = max(0, want_cols - n_cols)

    result = matrix
    if extra_rows or extra_cols:
        # F.pad order for the last two dims: (left, right, top, bottom).
        result = F.pad(matrix, (0, extra_cols, 0, extra_rows), "constant", 0)

    # Batched input is explicitly viewed back to the batched target shape.
    if leading_dims:
        result = result.view(*leading_dims, want_rows, want_cols)

    return result

#==================================================================================================
# CuCoDistill Implementation - Improved Version
#==================================================================================================

# 1. Hypergraph Triple Attention (HTA) Teacher Model
class HypergraphTripleAttention(nn.Module):
    """Teacher model that mixes three attention views over a hypergraph.

    The forward pass derives a node-node adjacency ``A = H @ H^T`` from the
    incidence matrix and computes three attention matrices over the nodes:
    a node-level one (cosine similarity masked by ``A``), a hyperedge-aware
    one (feature similarities plus row-normalised ``A``) and a global one
    (similarity of features propagated through the normalised adjacency).
    Their softmax-weighted combination updates the embeddings and is also
    exposed as ``A_HTA`` for distillation.

    NOTE(review): the residual ``E + ...`` and ``output_layer`` (which expects
    ``hidden_dim`` features) only line up when ``input_dim == hidden_dim``;
    ``pad_to_match_shape`` zero-pads but never crops. Confirm callers always
    construct the model that way.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the projection weights, mixing weights and output layer.

        Args:
            input_dim: Feature size the input embeddings are padded/truncated to.
            hidden_dim: Feature size produced by the attention projections.
            output_dim: Feature size of the returned embeddings.
        """
        super(HypergraphTripleAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Output projection layer (identity when no change of size is needed)
        self.output_layer = nn.Linear(hidden_dim, output_dim) if hidden_dim != output_dim else nn.Identity()

        # Transformation matrices for different attention mechanisms
        # (uninitialised here; filled by reset_parameters below)
        self.W_n = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_e = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_g = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.W_g_prime = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.W_r = nn.Parameter(torch.Tensor(input_dim, hidden_dim))

        # Projection matrix used for node-level similarity; dropout is applied
        # to this weight matrix itself in forward(), not to activations
        self.P_d = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.dropout = nn.Dropout(0.2)

        # Learnable importance weights for the three attention mechanisms
        # (softmax-normalised in forward())
        self.omega_n = nn.Parameter(torch.ones(1))
        self.omega_e = nn.Parameter(torch.ones(1))
        self.omega_g = nn.Parameter(torch.ones(1))

        # alpha feeds a sigmoid gate on the residual update (zero init makes
        # the gate start at 0.5); lambda_param and beta weight the terms of
        # the hyperedge-aware attention score
        self.alpha = nn.Parameter(torch.zeros(1))
        self.lambda_param = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.ones(1))

        # Fixed (non-learnable) temperature for the node-level similarity
        self.tau_n = 0.1

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize all weight matrices using Xavier uniform initialization."""
        nn.init.xavier_uniform_(self.W_n)
        nn.init.xavier_uniform_(self.W_e)
        nn.init.xavier_uniform_(self.W_g)
        nn.init.xavier_uniform_(self.W_g_prime)
        nn.init.xavier_uniform_(self.W_r)
        nn.init.xavier_uniform_(self.P_d)

    def forward(self, incidence_matrix, X):
        """
        Forward pass of the Hypergraph Triple Attention.

        2-D inputs get a leading batch dimension of size 1. The feature
        dimension of X is zero-padded or truncated to ``input_dim``.

        Args:
            incidence_matrix: Hypergraph incidence matrix [batch_size, num_nodes, num_edges]
            X: Node features [batch_size, num_nodes, feature_dim]

        Returns:
            E_new: Updated node embeddings (after the output projection)
            A_HTA: Combined attention matrix for knowledge distillation; a
                detached copy that is also stored on ``self.A_HTA``
        """
        # Ensure proper batch dimension
        batch_size = X.size(0) if X.dim() > 2 else 1
        if X.dim() == 2:
            X = X.unsqueeze(0)  # Add batch dimension

        if incidence_matrix.dim() == 2:
            incidence_matrix = incidence_matrix.unsqueeze(0)  # Add batch dimension

        # Initial embeddings
        E = X

        # Ensure E has the right feature dimension
        if E.size(2) != self.input_dim:
            if E.size(2) < self.input_dim:
                # Pad with zeros
                padding = torch.zeros(E.size(0), E.size(1), self.input_dim - E.size(2), device=E.device)
                E = torch.cat([E, padding], dim=2)
            else:
                # Truncate
                E = E[:, :, :self.input_dim]

        # Compute node-node adjacency from the incidence matrix: A[i, j] is
        # the sum over hyperedges of H[i, e] * H[j, e]
        H = incidence_matrix
        H_T = H.transpose(1, 2)
        A = torch.matmul(H, H_T)  # Adjacency matrix

        # 1. Node-Level Attention
        # Apply dropout to projection matrix
        P_d_with_dropout = self.dropout(self.P_d)
        # Project features for attention calculation
        E_tilde = torch.matmul(E, P_d_with_dropout)

        # Calculate cosine similarity for node-level attention
        E_norm = F.normalize(E_tilde, p=2, dim=2)
        similarity_n = torch.matmul(E_norm, E_norm.transpose(1, 2))

        # Apply temperature scaling
        similarity_n = similarity_n / self.tau_n

        # Create mask based on adjacency
        mask = (A > 0).float()
        # Mask attention scores, setting non-connections to large negative
        similarity_n = similarity_n.masked_fill(mask == 0, -1e9)
        # Apply softmax to get final attention weights (a row that is fully
        # masked comes out uniform, since every entry is the same constant)
        alpha_n = F.softmax(similarity_n, dim=2)

        # Node-level update
        E_node = torch.matmul(alpha_n, torch.matmul(E, self.W_n))

        # 2. Hyperedge-Aware Attention
        # R is a learned linear projection of E (E @ W_r). An earlier note
        # called this "SVD-transformed features", but no SVD is computed here.
        R = torch.matmul(E, self.W_r)

        # Row-normalise the adjacency as the structural (co-membership) term
        edge_sum = torch.sum(A, dim=2, keepdim=True) + 1e-10  # Add small epsilon for stability
        HE = A / edge_sum

        # Compute edge-aware attention scores (combining feature similarities and structure)
        # E_norm is rebound here to the normalised raw embeddings
        R_norm = F.normalize(R, p=2, dim=2)
        E_norm = F.normalize(E, p=2, dim=2)
        sim_r = torch.matmul(R_norm, R_norm.transpose(1, 2))
        sim_e = torch.matmul(E_norm, E_norm.transpose(1, 2))

        # Combined attention score with learnable weights (not masked by A)
        edge_attention = sim_r + self.lambda_param * sim_e + self.beta * HE
        alpha_e = F.softmax(edge_attention, dim=2)

        # Edge-aware update
        E_edge = torch.matmul(alpha_e, torch.matmul(E, self.W_e))

        # 3. Global-Structure Attention
        # Identity matrix for Laplacian computation
        I = torch.eye(A.size(1), device=A.device).unsqueeze(0).expand(batch_size, -1, -1)

        # Calculate degree matrix for normalization (degrees clamped to avoid
        # division by zero)
        degree = torch.sum(A, dim=2, keepdim=True)
        D_sqrt_inv = torch.diag_embed(1.0 / torch.sqrt(torch.clamp(degree.squeeze(2), min=1e-12)))

        # Symmetrically normalized adjacency D^(-1/2) A D^(-1/2)
        A_norm = torch.matmul(torch.matmul(D_sqrt_inv, A), D_sqrt_inv)

        # L_approx = I - A_norm, so 2*I - L_approx equals I + A_norm: each
        # node's projected features plus the adjacency-propagated ones
        L_approx = I - A_norm
        Z = torch.matmul(2*I - L_approx, torch.matmul(E, self.W_g))

        # Compute global attention scores
        Z_norm = F.normalize(Z, p=2, dim=2)
        sim_z = torch.matmul(Z_norm, Z_norm.transpose(1, 2))
        sim_z = F.leaky_relu(sim_z)  # Apply non-linearity
        alpha_g = F.softmax(sim_z, dim=2)

        # Global update
        E_global = torch.matmul(alpha_g, torch.matmul(Z, self.W_g_prime))

        # Unified attention (weighted combination)
        # Softmax keeps the three importance weights positive and summing to 1
        omega_cat = torch.cat([self.omega_n, self.omega_e, self.omega_g])
        omega_softmax = F.softmax(omega_cat, dim=0)
        omega_n, omega_e, omega_g = omega_softmax[0], omega_softmax[1], omega_softmax[2]

        # Final update rule with gating mechanism
        # Zero-pad the three updates up to E's (num_nodes, input_dim) shape
        E_node = pad_to_match_shape(E_node, E.shape[1:])
        E_edge = pad_to_match_shape(E_edge, E.shape[1:])
        E_global = pad_to_match_shape(E_global, E.shape[1:])

        # Residual update: E plus the gated, weighted mix of the three views
        E_new = E + torch.sigmoid(self.alpha) * (
            omega_n * E_node +
            omega_e * E_edge +
            omega_g * E_global
        )

        # Project to output dimension if needed
        E_new = self.output_layer(E_new)

        # Store attention matrices for knowledge distillation
        # (the target is alpha_n's own node x node shape, which the other two
        # attention matrices already share)
        alpha_n_pad = pad_to_match_shape(alpha_n, (alpha_n.size(1), alpha_n.size(2)))
        alpha_e_pad = pad_to_match_shape(alpha_e, (alpha_n.size(1), alpha_n.size(2)))
        alpha_g_pad = pad_to_match_shape(alpha_g, (alpha_n.size(1), alpha_n.size(2)))

        # Combined attention matrix, detached so it carries no gradient
        self.A_HTA = (omega_n * alpha_n_pad + omega_e * alpha_e_pad + omega_g * alpha_g_pad).detach().clone()

        return E_new, self.A_HTA

# 2. Lightweight Student Model
class LightweightStudentModel(nn.Module):
    """Student model that aggregates each node from its top-K neighbours.

    Neighbours are chosen either from a teacher attention matrix or, when none
    is given, from scaled dot-product self-attention over the inputs. The
    aggregated features are then passed through a single weight matrix (or a
    low-rank ``C @ D @ C^T`` factorisation of it) and an optional output layer.

    Note: on the low-rank path the transformed features have ``input_dim``
    columns (``C`` is ``input_dim x low_rank_dim``), whereas the optional
    output layer expects ``hidden_dim`` columns.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, k=10, low_rank_dim=None):
        """Create the transformation weights.

        Args:
            input_dim: Feature size of the node inputs.
            hidden_dim: Feature size produced by the dense weight ``W``.
            output_dim: Feature size after the optional output layer.
            k: Number of top neighbours aggregated per node.
            low_rank_dim: When given, use the ``C @ D @ C^T`` factorisation
                with this inner rank instead of a dense ``W``.
        """
        super(LightweightStudentModel, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.k = k  # Number of top neighbors to consider

        # Low-rank factorization for efficiency
        if low_rank_dim is not None:
            self.use_low_rank = True
            self.low_rank_dim = low_rank_dim
            self.C = nn.Parameter(torch.Tensor(input_dim, low_rank_dim))
            self.D = nn.Parameter(torch.Tensor(low_rank_dim, low_rank_dim))
            # Weight matrix is implicitly C * D * C^T
            nn.init.xavier_uniform_(self.C)
            nn.init.xavier_uniform_(self.D)
        else:
            self.use_low_rank = False
            self.W = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
            nn.init.xavier_uniform_(self.W)

        # Output projection if needed
        if hidden_dim != output_dim:
            self.output_layer = nn.Linear(hidden_dim, output_dim)

    def _aggregate_top_k(self, X, attention, renormalize):
        """Weighted sum of each node's top-K neighbours, for all nodes at once.

        Args:
            X: Node features [batch_size, num_nodes, feature_dim].
            attention: Attention scores [batch_size, num_nodes, num_nodes].
            renormalize: Apply a softmax over the K selected scores first.

        Returns:
            Aggregated features with the same shape and dtype as ``X``.

        Raises:
            IndexError: If ``attention`` has fewer rows than ``X`` has nodes.
        """
        num_nodes = X.size(1)
        if attention.size(1) < num_nodes:
            raise IndexError(
                f"attention has {attention.size(1)} rows but X has {num_nodes} nodes"
            )
        # Only the rows that correspond to nodes of X are aggregated.
        attention = attention[:, :num_nodes]

        k = min(self.k, attention.size(-1))
        # topk already returns the scores at the selected indices.
        scores, neighbor_idx = torch.topk(attention, k, dim=-1)
        if renormalize:
            scores = F.softmax(scores, dim=-1)

        # Gather neighbour features: result is [batch, num_nodes, k, feature].
        batch_idx = torch.arange(X.size(0), device=X.device).view(-1, 1, 1)
        neighbor_embeddings = X[batch_idx, neighbor_idx]

        aggregated = torch.sum(scores.unsqueeze(-1) * neighbor_embeddings, dim=2)
        return aggregated.to(X.dtype)

    def forward(self, X, teacher_attention=None):
        """
        Forward pass through lightweight student model.

        Args:
            X: Node features [batch_size, num_nodes, feature_dim]; a 2-D input
                gets a leading batch dimension of size 1.
            teacher_attention: Teacher's attention matrix for guidance
                (optional), [batch_size, num_nodes, num_nodes] or 2-D. The
                selected top-K scores are softmax-renormalised. Without it,
                scaled dot-product self-attention over X is used and the
                selected scores are used as they are.

        Returns:
            E_student: Student embeddings [batch_size, num_nodes, out_features]
        """
        if X.dim() == 2:
            X = X.unsqueeze(0)  # Add batch dimension

        if teacher_attention is not None:
            attention = teacher_attention
            # Mirror the handling of X so a plain (N, N) matrix is accepted.
            if attention.dim() == 2:
                attention = attention.unsqueeze(0)
            renormalize = True
        else:
            # If no teacher attention, use a simple self-attention mechanism
            scale_factor = math.sqrt(X.size(-1))
            sim = torch.matmul(X, X.transpose(1, 2)) / scale_factor
            attention = F.softmax(sim, dim=2)
            renormalize = False

        E_student = self._aggregate_top_k(X, attention, renormalize)

        # Apply transformation (either with low-rank factorization or direct)
        if self.use_low_rank:
            # Efficient computation: E_student -> C -> D -> C^T without
            # materializing the full weight matrix
            E_student = torch.matmul(E_student, self.C)
            E_student = torch.matmul(E_student, self.D)
            E_student = torch.matmul(E_student, self.C.t())
        else:
            E_student = torch.matmul(E_student, self.W)

        # Apply output projection if needed
        if hasattr(self, 'output_layer'):
            E_student = self.output_layer(E_student)

        return E_student

# 3. Adaptive Knowledge-Guided Edge Dropping (AKED)
class AdaptiveKnowledgeGuidedEdgeDropping:
    """Attention-guided stochastic dropping used to build an augmented view.

    Drop probabilities are computed from a teacher attention matrix and a
    student attention matrix and are then used to sample Bernoulli keep-masks.
    The class holds only three scalar coefficients and no learnable state.
    """

    def __init__(self, theta_drop=1.0, gamma=-0.5, delta=2.0):
        self.theta_drop = theta_drop  # Weight of the inverse attention term
        self.gamma = gamma            # Weight of the direct attention term (negative by default)
        self.delta = delta            # Weight of the knowledge disparity term

    def compute_drop_probabilities(self, A_HTA, A_Student):
        """
        Compute drop probabilities from teacher and student attention.

        Args:
            A_HTA: Teacher's attention matrix
            A_Student: Student's attention matrix (must broadcast against A_HTA)

        Returns:
            P_drop: Tensor of drop probabilities in [0, 1], same shape as the
                broadcast of the two inputs
        """
        # L1-normalise along the last axis so both matrices are on a comparable scale
        A_HTA_norm = F.normalize(A_HTA, p=1, dim=-1)
        A_Student_norm = F.normalize(A_Student, p=1, dim=-1)

        # Knowledge disparity: element-wise teacher/student disagreement
        D = torch.abs(A_HTA_norm - A_Student_norm)

        # 1. Inverse attention term: raises the drop score where teacher attention is low
        inverse_attention_term = self.theta_drop * (1 - A_HTA_norm)
        # 2. Direct attention term: with the default negative gamma, lowers the
        #    drop score where teacher attention is high
        direct_attention_term = self.gamma * A_HTA_norm
        # 3. Knowledge disparity term: subtracted, so disagreement lowers the drop score
        knowledge_disparity_term = self.delta * D

        # Squash the combined score into a probability
        return torch.sigmoid(
            inverse_attention_term + direct_attention_term - knowledge_disparity_term
        )

    def apply_edge_dropping(self, adj_matrix, P_drop):
        """
        Zero out entries of a matrix according to sampled Bernoulli keep-masks.

        Args:
            adj_matrix: Matrix whose entries may be dropped
            P_drop: Drop probabilities (must broadcast against adj_matrix)

        Returns:
            adj_matrix_dropped: adj_matrix with the dropped entries set to zero
        """
        # Binary mask: 1 = keep, 0 = drop (each entry kept with probability 1 - P_drop)
        keep_mask = torch.bernoulli(1 - P_drop)
        return adj_matrix * keep_mask

    def generate_augmented_view(self, incidence_matrix, A_HTA, A_Student=None):
        """
        Generate an augmented view of the hypergraph by dropping whole rows
        of the incidence matrix.

        The drop probabilities are averaged over the last axis of the attention
        matrix, giving one keep/drop decision per row; the sampled mask is then
        broadcast across the columns of the incidence matrix. (Callers in this
        file build ``H @ H^T`` as a node adjacency, so rows presumably index
        nodes -- verify against the data pipeline.)

        Earlier revisions also built ``H @ H^T`` and a dropped adjacency here
        without ever using the result; that dead work has been removed, so
        this method now consumes one Bernoulli draw instead of two.

        Args:
            incidence_matrix: Original incidence matrix
            A_HTA: Teacher's attention matrix
            A_Student: Student's attention matrix (optional; zeros if omitted)

        Returns:
            incidence_matrix_aug: Augmented incidence matrix
        """
        # Without student attention the disparity term reduces to the
        # normalised teacher attention itself
        if A_Student is None:
            A_Student = torch.zeros_like(A_HTA)

        P_drop = self.compute_drop_probabilities(A_HTA, A_Student)

        # One drop probability per row, kept as a trailing singleton axis so
        # the sampled mask broadcasts over every column of that row
        row_drop_prob = P_drop.mean(dim=-1, keepdim=True)
        keep_mask = torch.bernoulli(1 - row_drop_prob)

        return incidence_matrix * keep_mask

# 4. Knowledge-Aware Dual-View Learning (KDV)
class KnowledgeAwareDualViewLearning(nn.Module):
    """Blend clean/augmented embeddings and score them with a weighted contrastive loss.

    The only learnable state is ``alpha``, a scalar gate (initialised to 0, so
    the sigmoid gate starts at 0.5) that mixes the two views.
    """

    def __init__(self, temperature=0.1, beta=0.5):
        super().__init__()
        self.temperature = temperature             # Divides the cosine similarities
        self.beta = beta                           # Strength of the attention-similarity weighting
        self.alpha = nn.Parameter(torch.zeros(1))  # Learnable gate (pre-sigmoid)

    def forward(self, E_clean, E_aug, A_HTA):
        """
        Mix the clean and augmented embeddings with the learned gate.

        Args:
            E_clean: Embeddings from clean view
            E_aug: Embeddings from augmented view
            A_HTA: Teacher's attention matrix (accepted but not used here)

        Returns:
            E_integrated: sigmoid(alpha) * E_clean + (1 - sigmoid(alpha)) * E_aug
        """
        gate = torch.sigmoid(self.alpha)  # keeps the mixing coefficient inside (0, 1)
        return gate * E_clean + (1 - gate) * E_aug

    def contrastive_loss(self, E_integrated, incidence_matrix=None, A_HTA=None):
        """
        Contrastive loss over node pairs with optional structure and attention weighting.

        Args:
            E_integrated: Embeddings, indexed as (batch, node, feature)
            incidence_matrix: Hypergraph incidence matrix (optional); rows that
                share a non-zero product in H @ H^T are positive pairs
            A_HTA: Teacher's attention matrix (optional); positive pairs are
                re-weighted by 1 + beta * cosine similarity of attention rows

        Returns:
            loss: Scalar tensor, the mean negative log-ratio of weighted
                positive mass to total mass
        """
        n_batches, n_nodes = E_integrated.size(0), E_integrated.size(1)

        # Temperature-scaled cosine similarity between every pair of nodes
        unit = F.normalize(E_integrated, p=2, dim=2)
        logits = torch.matmul(unit, unit.transpose(1, 2)) / self.temperature

        # Positive-pair mask
        if incidence_matrix is None:
            # No structure given: each node is only its own positive
            positives = torch.eye(n_nodes, device=E_integrated.device)
            positives = positives.unsqueeze(0).expand(n_batches, -1, -1)
        else:
            if incidence_matrix.dim() > 2:
                transposed = incidence_matrix.transpose(1, 2)
            else:
                transposed = incidence_matrix.t()
            co_membership = torch.matmul(incidence_matrix, transposed)
            positives = (co_membership > 0).float()

        # Per-pair weights
        if A_HTA is None:
            pair_weights = torch.ones_like(positives)
        else:
            attn_unit = F.normalize(A_HTA, p=2, dim=2)
            attn_similarity = torch.matmul(attn_unit, attn_unit.transpose(1, 2))
            pair_weights = 1 + self.beta * attn_similarity

        weighted_positives = positives * pair_weights

        # Shift by the row maximum before exponentiating for numerical stability
        row_max = torch.max(logits, dim=2, keepdim=True)[0]
        exp_logits = torch.exp(logits - row_max)

        # Small epsilon keeps both logs finite when a row has no positive mass
        log_denominator = torch.log(torch.sum(exp_logits, dim=2) + 1e-10)
        log_numerator = torch.log(torch.sum(exp_logits * weighted_positives, dim=2) + 1e-10)

        return -torch.mean(log_numerator - log_denominator)

# 5. Hierarchical Attention Distillation (HAD)
class HierarchicalAttentionDistillation:
    """Teacher-to-student distillation losses at embedding, attention and feature level.

    The three ``lambda_*`` coefficients weight the respective terms in
    ``compute_total_loss``; the class has no learnable state.
    """

    def __init__(self, lambda_embed=1.0, lambda_attn=0.5, lambda_feat=0.2):
        self.lambda_embed = lambda_embed  # Weight for embedding-level distillation
        self.lambda_attn = lambda_attn    # Weight for attention-level distillation
        self.lambda_feat = lambda_feat    # Weight for feature-level distillation

    def embedding_level_distillation(self, E_Student, E_Teacher):
        """
        Compute embedding-level distillation loss with adaptive weighting.

        Args:
            E_Student: Student embeddings, indexed as (batch, node, feature)
            E_Teacher: Teacher embeddings, same shape

        Returns:
            loss: Mean over nodes of sigmoid(d) * d**2, where d is the L2
                distance between the student and teacher embedding of a node
        """
        embedding_diff = torch.norm(E_Student - E_Teacher, dim=2)

        # Nodes where the student is further from the teacher get a larger weight
        weights = torch.sigmoid(embedding_diff)

        return torch.mean(weights * (embedding_diff ** 2))

    def attention_level_distillation(self, A_Student, A_Teacher, top_k=None):
        """
        Compute attention-level distillation loss as KL(teacher || student).

        Both attention matrices are treated as unnormalised scores and passed
        through a softmax over the last axis. The student's log-probabilities
        come from ``log_softmax`` rather than ``softmax(...).log()`` so that
        widely separated scores cannot underflow to ``log(0) = -inf``.

        Args:
            A_Student: Student attention matrix, indexed as (batch, node, node)
            A_Teacher: Teacher attention matrix, same shape
            top_k: If given and smaller than the number of nodes, only each
                node's top-K entries (ranked by teacher attention) are compared

        Returns:
            loss: Scalar tensor. The top-K branch sums the KL over every node
                of every batch element; the full branch uses
                ``reduction='batchmean'`` (sum divided by the batch size).
                The two agree in scale for a batch size of 1.
        """
        num_nodes = A_Teacher.size(1)

        if top_k is not None and top_k < num_nodes:
            # Teacher scores of each node's K strongest entries, plus their positions
            teacher_scores, top_k_indices = torch.topk(
                A_Teacher, k=min(top_k, A_Teacher.size(-1)), dim=-1
            )
            # Student scores at the same positions
            student_scores = torch.gather(A_Student, -1, top_k_indices)

            teacher_dist = F.softmax(teacher_scores, dim=-1)
            student_log_dist = F.log_softmax(student_scores, dim=-1)

            # Summing over all rows at once equals accumulating one
            # 'sum'-reduced KL per (batch, node) pair
            return F.kl_div(student_log_dist, teacher_dist, reduction='sum')

        # Full attention rows
        teacher_dist = F.softmax(A_Teacher, dim=2)
        student_log_dist = F.log_softmax(A_Student, dim=2)
        return F.kl_div(student_log_dist, teacher_dist, reduction='batchmean')

    def feature_level_distillation(self, F_Student, F_Teacher):
        """
        Compute feature-level distillation loss using Frobenius norm.

        Args:
            F_Student: Student feature maps
            F_Teacher: Teacher feature maps

        Returns:
            loss: Squared Frobenius norm of the difference (not averaged)
        """
        return torch.norm(F_Student - F_Teacher, p='fro') ** 2

    def compute_total_loss(self, E_Student, E_Teacher, A_Student, A_Teacher, F_Student=None, F_Teacher=None, top_k=None):
        """
        Compute total hierarchical distillation loss combining all three levels.

        Args:
            E_Student: Student embeddings
            E_Teacher: Teacher embeddings
            A_Student: Student attention matrix
            A_Teacher: Teacher attention matrix
            F_Student: Student feature maps (optional)
            F_Teacher: Teacher feature maps (optional)
            top_k: Number of top neighbors to consider

        Returns:
            loss: lambda_embed * embedding loss + lambda_attn * attention loss,
                plus lambda_feat * feature loss when both feature maps are given
        """
        embed_loss = self.embedding_level_distillation(E_Student, E_Teacher)
        attn_loss = self.attention_level_distillation(A_Student, A_Teacher, top_k)

        total_loss = self.lambda_embed * embed_loss + self.lambda_attn * attn_loss

        # The feature term only applies when both sides supplied feature maps
        if F_Student is not None and F_Teacher is not None:
            feat_loss = self.feature_level_distillation(F_Student, F_Teacher)
            total_loss = total_loss + self.lambda_feat * feat_loss

        return total_loss

# 6. Integrated Curriculum Distillation (ICD)
class IntegratedCurriculumDistillation:
    """Step-dependent curriculum: loss weights, difficulty thresholds and masks.

    All schedules are functions of ``current_step / total_steps``. The class
    keeps no step counter of its own; callers pass the current step in.
    """

    def __init__(self, alpha_0=0.9, beta_0=0.1, gamma_alpha=0.5, gamma_beta=0.5, gamma_g=0.5, total_steps=1000):
        # Initial and scaling parameters for curriculum thresholds
        self.alpha_0 = alpha_0            # Initial quantile for contrastive difficulty
        self.beta_0 = beta_0              # Initial quantile for distillation difficulty
        self.gamma_alpha = gamma_alpha    # Exponent of the contrastive quantile schedule
        self.gamma_beta = gamma_beta      # Exponent of the distillation quantile schedule
        self.gamma_g = gamma_g            # Exponent of the distillation weight schedule
        self.total_steps = total_steps    # Total number of training steps

    @staticmethod
    def _quantile_threshold(values, quantile):
        """Return the element at position ``quantile`` of the sorted values (0.0 if empty)."""
        ordered, _ = torch.sort(values.reshape(-1))
        count = len(ordered)
        if count == 0:
            return 0.0
        # Clamp the index so quantile == 1.0 maps to the largest element
        return ordered[min(int(quantile * count), count - 1)]

    def get_difficulty_thresholds(self, current_step, D_cont, D_dist):
        """
        Compute adaptive difficulty thresholds based on current training step.

        Args:
            current_step: Current training step
            D_cont: Contrastive difficulty values
            D_dist: Distillation difficulty values

        Returns:
            tau_cont: Contrastive threshold (0-d tensor, or 0.0 if D_cont is empty)
            tau_dist: Distillation threshold (0-d tensor, or 0.0 if D_dist is empty)
        """
        progress = current_step / self.total_steps

        # Once current_step exceeds total_steps, (1 - progress) turns negative;
        # raising a negative float to a fractional power yields a complex
        # number in Python and the clamp below would then raise TypeError.
        # Flooring at 0 keeps the contrastive quantile at 0 after the schedule ends.
        remaining = max(1 - progress, 0.0)

        alpha_t = self.alpha_0 * remaining ** self.gamma_alpha
        beta_t = self.beta_0 * (1 + progress) ** self.gamma_beta

        # Ensure alpha_t and beta_t are in valid range [0, 1]
        alpha_t = min(max(alpha_t, 0.0), 1.0)
        beta_t = min(max(beta_t, 0.0), 1.0)

        tau_cont = self._quantile_threshold(D_cont, alpha_t)
        tau_dist = self._quantile_threshold(D_dist, beta_t)

        return tau_cont, tau_dist

    def compute_contrastive_difficulty(self, E_clean, E_aug):
        """
        Compute contrastive learning difficulty as dissimilarity between views.

        Args:
            E_clean: Embeddings from clean view, indexed as (batch, node, feature)
            E_aug: Embeddings from augmented view, same shape

        Returns:
            D_cont: 1 - cosine similarity between the two views of each node
        """
        E_clean_norm = F.normalize(E_clean, p=2, dim=2)
        E_aug_norm = F.normalize(E_aug, p=2, dim=2)

        # Row-wise dot product of unit vectors = cosine similarity
        similarity = torch.sum(E_clean_norm * E_aug_norm, dim=2)

        return 1 - similarity

    def compute_distillation_difficulty(self, E_Teacher, E_Student):
        """
        Compute distillation difficulty as distance between teacher and student.

        Args:
            E_Teacher: Teacher embeddings, indexed as (batch, node, feature)
            E_Student: Student embeddings, same shape

        Returns:
            D_dist: L2 distance between teacher and student embedding of each node
        """
        return torch.norm(E_Teacher - E_Student, dim=2)

    def curriculum_weights(self, current_step, E_Teacher, E_Student, E_clean, E_aug):
        """
        Compute curriculum-based weights for different losses.

        Args:
            current_step: Current training step
            E_Teacher: Teacher embeddings
            E_Student: Student embeddings
            E_clean: Embeddings from clean view
            E_aug: Embeddings from augmented view

        Returns:
            lambda_LCL: Weight for local contrastive loss
            lambda_HGCL: Weight for hypergraph contrastive loss
            lambda_HAD: Weight for hierarchical attention distillation
            tau_cont: Contrastive threshold
            tau_dist: Distillation threshold
        """
        # Progress capped at 1 so the weights stay fixed after total_steps
        t_norm = min(current_step / self.total_steps, 1.0)

        D_cont = self.compute_contrastive_difficulty(E_clean, E_aug)
        D_dist = self.compute_distillation_difficulty(E_Teacher, E_Student)

        tau_cont, tau_dist = self.get_difficulty_thresholds(current_step, D_cont, D_dist)

        angle = t_norm * np.pi/2
        # LCL weight: cos^2, starts at 1 and decays to 0
        lambda_LCL = np.cos(angle) ** 2
        # HGCL weight: sin * cos, zero at both ends and largest mid-schedule
        lambda_HGCL = np.sin(angle) * np.cos(angle)
        # HAD weight: sin^2, starts at 0 and rises to 1
        lambda_HAD = np.sin(angle) ** 2

        return lambda_LCL, lambda_HGCL, lambda_HAD, tau_cont, tau_dist

    def get_difficulty_masks(self, D_cont, D_dist, tau_cont, tau_dist, current_step):
        """
        Generate masks for curriculum learning.

        Args:
            D_cont: Contrastive difficulty values
            D_dist: Distillation difficulty values
            tau_cont: Contrastive threshold
            tau_dist: Distillation threshold (accepted but not used here)
            current_step: Current training step

        Returns:
            v_cont: Float 0/1 mask, 1 where D_cont <= tau_cont
            w_dist: sigmoid(D_dist * g_t), with g_t growing with the step
        """
        # Only examples at or below the difficulty threshold take part
        v_cont = (D_cont <= tau_cont).float()

        # g_t grows with training progress, sharpening the weight on hard examples
        g_t = (1 + current_step / self.total_steps) ** self.gamma_g
        w_dist = torch.sigmoid(D_dist * g_t)

        return v_cont, w_dist

# 7. Complete CuCoDistill Framework
# Combines the teacher, student, edge-dropping, dual-view, distillation and curriculum modules.
class CuCoDistill(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, k=10, low_rank_dim=None, 
                 dropout=0.2, temperature=0.1, total_steps=1000, device='cpu'):
        """
        Build the teacher, the student and every auxiliary module.

        Args:
            input_dim: Forwarded to the teacher and student constructors
            hidden_dim: Forwarded to the teacher and student constructors
            output_dim: Forwarded to the teacher and student constructors; also
                both the input and output width of the classifier
            k: Forwarded to the student model as ``k``
            low_rank_dim: Forwarded to the student model as ``low_rank_dim``
            dropout: Dropout probability applied before the classifier
            temperature: Temperature of the dual-view contrastive loss
            total_steps: Length of the curriculum schedule
            device: Device the sub-modules are moved to
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.current_step = 0
        self.total_steps = total_steps
        self.device = device
        self.dropout = dropout

        # Teacher model (HypergraphTripleAttention)
        self.teacher = HypergraphTripleAttention(input_dim, hidden_dim, output_dim).to(device)

        # Student model (LightweightStudentModel)
        self.student = LightweightStudentModel(input_dim, hidden_dim, output_dim,
                                               k=k, low_rank_dim=low_rank_dim).to(device)

        # Edge dropping module (plain object, no parameters)
        self.aked = AdaptiveKnowledgeGuidedEdgeDropping(theta_drop=1.0, gamma=-0.5, delta=2.0)

        # Dual-view learning module (owns the learnable gate)
        self.kdv = KnowledgeAwareDualViewLearning(temperature=temperature, beta=0.5).to(device)

        # Hierarchical distillation module (plain object, no parameters)
        self.had = HierarchicalAttentionDistillation(lambda_embed=1.0, lambda_attn=0.5, lambda_feat=0.2)

        # Curriculum distillation module (plain object, no parameters)
        self.icd = IntegratedCurriculumDistillation(total_steps=total_steps)

        # Classification head. It used to be constructed twice, with the second
        # copy's freshly initialised parameters re-wrapped in new nn.Parameter
        # objects; a new nn.Linear already owns independent leaf parameters,
        # so a single construction is equivalent.
        self.classifier = nn.Linear(output_dim, output_dim).to(device)

        # Dropout applied to student embeddings before classification
        self.dropout_layer = nn.Dropout(dropout)

        # Attention caches; None until the first forward pass fills them
        self.teacher_attention = None
        self.student_attention = None

    # Forward pass: coordinates both teacher views, the student, and the curriculum-weighted losses.
    def forward(self, incidence_matrix, X, labels=None):
        """
        Run one pass through teacher and student and assemble the combined loss.

        Side effects: stores detached copies of the clean-view teacher attention
        and of the student attention on the instance (they drive the next
        call's augmented view) and increments ``self.current_step`` on every call.

        Args:
            incidence_matrix: Hypergraph incidence matrix; a 2-D input gets a
                leading batch dimension
            X: Node features; a 2-D input gets a leading batch dimension
            labels: Node labels (optional). 1-D or (N, 1) labels use
                cross-entropy; any other shape is treated as multi-label and
                uses binary cross-entropy with logits

        Returns:
            student_embeddings: Embeddings from student model
            integrated_teacher_embeddings: Gated mix of the teacher's clean-view
                and augmented-view embeddings
            total_loss: Combined loss tensor
            losses: Dictionary of loss components as Python floats
        """
        # Ensure tensors have batch dimension
        if incidence_matrix.dim() == 2:
            incidence_matrix = incidence_matrix.unsqueeze(0)
        if X.dim() == 2:
            X = X.unsqueeze(0)

        # Step 1: Create augmented view using AKED
        cached_teacher_attention = getattr(self, 'teacher_attention', None)
        cached_student_attention = getattr(self, 'student_attention', None)
        if cached_teacher_attention is not None and cached_student_attention is not None:
            # Use attention matrices stored by the previous call
            augmented_view = self.aked.generate_augmented_view(
                incidence_matrix, cached_teacher_attention, cached_student_attention)
        else:
            # First call: no attention yet, keep each entry with probability 0.8
            mask = torch.bernoulli(torch.ones_like(incidence_matrix) * 0.8)
            augmented_view = incidence_matrix * mask

        # Step 2: Forward pass through teacher model with both views
        teacher_embeddings_clean, teacher_attention_clean = self.teacher(incidence_matrix, X)
        teacher_embeddings_aug, teacher_attention_aug = self.teacher(augmented_view, X)

        # Store teacher attention for next iteration
        self.teacher_attention = teacher_attention_clean.detach().clone()

        # Step 3: Apply Knowledge-Aware Dual-View Learning
        # (called as a module rather than via .forward so registered hooks run)
        integrated_teacher_embeddings = self.kdv(
            teacher_embeddings_clean, teacher_embeddings_aug, teacher_attention_clean)

        # Step 4: Forward pass through student model (only clean view)
        student_embeddings = self.student(X, teacher_attention_clean)

        # Student attention is approximated by the cosine similarity between
        # student embeddings
        student_unit = F.normalize(student_embeddings, p=2, dim=2)
        student_attention = torch.matmul(student_unit, student_unit.transpose(1, 2))

        # Store student attention for next iteration
        self.student_attention = student_attention.detach().clone()

        # Step 5: Compute curriculum weights based on current training step
        lambda_LCL, lambda_HGCL, lambda_HAD, tau_cont, tau_dist = self.icd.curriculum_weights(
            self.current_step, integrated_teacher_embeddings, student_embeddings,
            teacher_embeddings_clean, teacher_embeddings_aug)

        # Step 6: Compute difficulties
        D_cont = self.icd.compute_contrastive_difficulty(teacher_embeddings_clean, teacher_embeddings_aug)
        D_dist = self.icd.compute_distillation_difficulty(integrated_teacher_embeddings, student_embeddings)

        # Step 7: Get difficulty masks
        v_cont, w_dist = self.icd.get_difficulty_masks(D_cont, D_dist, tau_cont, tau_dist, self.current_step)

        # Step 8: Compute contrastive learning losses with curriculum weighting
        # lambda_LCL(t) * L_LCL: contrastive loss on the clean-view teacher embeddings
        contrastive_loss_LCL = self.kdv.contrastive_loss(
            teacher_embeddings_clean, incidence_matrix, teacher_attention_clean) * lambda_LCL

        # lambda_HGCL(t) * L_HGCL: similarity between clean and augmented views
        E_clean_norm = F.normalize(teacher_embeddings_clean, p=2, dim=2)
        E_aug_norm = F.normalize(teacher_embeddings_aug, p=2, dim=2)
        sim_clean_aug = torch.matmul(E_clean_norm, E_aug_norm.transpose(1, 2)) / self.kdv.temperature

        # v_cont has one entry per (batch, node) while the similarity is
        # (batch, node, node). The mask used to be multiplied in by implicit
        # trailing-axis broadcasting, which only lines up when the batch size
        # is 1. unsqueeze(1) produces exactly that result for batch size 1 and
        # keeps larger batches valid.
        # NOTE(review): this masks along the augmented-view (last) axis -- confirm
        # that is the intended axis rather than the clean-view one.
        curriculum_mask = v_cont.unsqueeze(1)
        contrastive_loss_HGCL = -torch.mean(
            torch.log(torch.exp(sim_clean_aug) + 1e-10) * curriculum_mask
        ) * lambda_HGCL

        # Step 9: Compute hierarchical distillation loss with curriculum weighting
        # Teacher tensors are detached so this term only trains the student
        distillation_loss = self.had.compute_total_loss(
            student_embeddings,
            integrated_teacher_embeddings.detach(),
            student_attention,
            teacher_attention_clean.detach(),
            top_k=self.student.k
        ) * lambda_HAD

        # Additional squared-distance distillation term weighted per node by w_dist
        weighted_distillation_loss = torch.mean(
            w_dist * torch.norm(student_embeddings - integrated_teacher_embeddings.detach(), dim=2) ** 2
        )

        # Step 10: Compute classification loss (if labels provided)
        lambda_class = 1.0  # Fixed weight of the classification loss
        classification_loss = 0.0

        if labels is not None:
            # Get node embeddings for classification
            node_embeddings = student_embeddings.squeeze(0)  # Remove batch dimension

            # Apply dropout before classification
            node_embeddings = self.dropout_layer(node_embeddings)

            # Apply classification layer
            logits = self.classifier(node_embeddings)

            # Check for single-label case (1D tensor)
            if labels.dim() == 1 or (labels.dim() == 2 and labels.size(1) == 1):
                # Reshape labels if needed
                if labels.dim() > 1:
                    labels = labels.reshape(-1)

                # Truncate both sides to the shorter length when they disagree
                min_size = min(logits.size(0), labels.size(0))
                if min_size < logits.size(0) or min_size < labels.size(0):
                    logits = logits[:min_size]
                    labels = labels[:min_size]

                classification_loss = F.cross_entropy(logits, labels)
            else:
                # For multi-label classification
                if labels.dim() == 3:
                    labels = labels.squeeze(0)  # Remove batch dimension

                # Truncate to the common sample/class counts when shapes disagree
                if logits.size(0) != labels.size(0) or logits.size(1) != labels.size(1):
                    min_samples = min(logits.size(0), labels.size(0))
                    min_classes = min(logits.size(1), labels.size(1))
                    classification_loss = F.binary_cross_entropy_with_logits(
                        logits[:min_samples, :min_classes],
                        labels[:min_samples, :min_classes].float()
                    )
                else:
                    classification_loss = F.binary_cross_entropy_with_logits(logits, labels.float())

        # Step 11: Compute L2 regularization loss over every parameter of this module
        lambda_reg = 0.0001  # Fixed weight of the regularization term
        l2_reg = 0.0
        for param in self.parameters():
            l2_reg += torch.norm(param) ** 2
        l2_reg *= lambda_reg

        # Step 12: Compute total loss:
        # L_total = lambda_LCL(t) * L_LCL + lambda_HGCL(t) * L_HGCL + lambda_HAD(t) * L_HAD
        #           + weighted distillation + lambda_class * L_class + lambda_reg * ||theta||^2
        total_loss = (
            contrastive_loss_LCL +          # lambda_LCL(t) * L_LCL
            contrastive_loss_HGCL +         # lambda_HGCL(t) * L_HGCL
            distillation_loss +             # lambda_HAD(t) * L_HAD
            weighted_distillation_loss +    # Additional weighted distillation
            l2_reg                          # lambda_reg * ||theta||^2
        )

        # Add classification loss to total loss if labels provided
        if labels is not None:
            total_loss += lambda_class * classification_loss  # lambda_class * L_class

        # Store losses for monitoring
        losses = {
            'contrastive_LCL': contrastive_loss_LCL.item(),
            'contrastive_HGCL': contrastive_loss_HGCL.item(),
            'distillation': distillation_loss.item(),
            'weighted_distillation': weighted_distillation_loss.item(),
            'l2_reg': l2_reg.item(),
            'total': total_loss.item()
        }

        if labels is not None:
            losses['classification'] = classification_loss.item()

        # Increment step counter
        self.current_step += 1

        # Return embeddings and losses
        return student_embeddings, integrated_teacher_embeddings, total_loss, losses

    # Inference: predictions come from the student model followed by the classification layer.

    def predict(self, incidence_matrix, X):
        """
        Generate predictions using the student model and classification layer.
        
        Args:
            incidence_matrix: Hypergraph incidence matrix
            X: Node features
            
        Returns:
            student_embeddings: Student embeddings
            logits: Classification logits
            predictions: Predicted class indices
        """
        # Ensure tensors have batch dimension
        if incidence_matrix.dim() == 2:
            incidence_matrix = incidence_matrix.unsqueeze(0)
        if X.dim() == 2:
            X = X.unsqueeze(0)
            
        # Debug information
        # print(f"Prediction input - Incidence matrix: {incidence_matrix.shape}, Features: {X.shape}")
        
        # Use the teacher's attention if available, otherwise None
        teacher_attention = getattr(self, 'teacher_attention', None)
        
        # Forward pass through student model
        with torch.no_grad():
            try:
                # Get student embeddings
                student_embeddings = self.student(X, teacher_attention)
                
                # Get node embeddings for classification
                node_embeddings = student_embeddings.squeeze(0)  # Remove batch dimension
                
                # Apply classification layer
                logits = self.classifier(node_embeddings)
                
                # Get predicted classes
                predictions = torch.argmax(logits, dim=1)
                
                # Important fix: reshape predictions to match label shape if needed
                if predictions.dim() != 1:
                    predictions = predictions.reshape(-1)
                    
                return student_embeddings, logits, predictions
            
            except Exception as e:
                print(f"Error during prediction: {str(e)}")
                import traceback
                traceback.print_exc()
                
                # Return placeholder values in case of error
                batch_size = X.size(0)
                num_nodes = X.size(1)
                
                # Placeholder embeddings with same shape as input features
                placeholder_embeddings = torch.zeros_like(X)
                
                # Placeholder logits with num_classes dimensions
                placeholder_logits = torch.zeros((num_nodes, self.output_dim), device=X.device)
                
                # Placeholder predictions (all class 0)
                placeholder_predictions = torch.zeros(num_nodes, dtype=torch.long, device=X.device)
                
                return placeholder_embeddings, placeholder_logits, placeholder_predictions

#==================================================================================================
# Dataset and Training Classes
#==================================================================================================
class HypergraphDataset(Dataset):
    """Single-sample dataset that hands out one whole hypergraph at a time.

    The dataset always reports a length of one; every lookup returns the
    stored objects themselves (no copy, no slicing), regardless of the index.
    """

    def __init__(self, incidence_matrices, features, labels=None):
        """
        Store the graph components.

        Args:
            incidence_matrices: Dictionary of incidence matrices for different relations
            features: Node features
            labels: Node labels (optional)
        """
        self.incidence_matrices, self.features, self.labels = (
            incidence_matrices,
            features,
            labels,
        )

    def __len__(self):
        # The full graph is treated as one batch.
        return 1

    def __getitem__(self, idx):
        # ``idx`` is ignored: there is only one item, the whole graph.
        sample = (self.incidence_matrices, self.features)
        if self.labels is None:
            return sample
        return sample + (self.labels,)

class HypergraphClassificationTrainer:
    """Training and evaluation driver for hypergraph node classification.

    Runs whole-graph batches through the model, records loss/accuracy
    history, and applies patience-based early stopping while keeping an
    independent snapshot of the best weights seen so far.
    """

    def __init__(self, model, optimizer, device, num_epochs=50, patience=5):
        """
        Trainer for hypergraph node classification.

        Args:
            model: CuCoDistill model. It is called as
                ``model(incidence, features, labels)`` and must return
                ``(student_embeddings, teacher_embeddings, loss, loss_components)``;
                ``model.predict(incidence, features)`` must return
                ``(embeddings, logits, predictions)``.
            optimizer: PyTorch optimizer
            device: GPU or CPU device
            num_epochs: Maximum number of epochs
            patience: Early stopping patience
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.num_epochs = num_epochs
        self.patience = patience
        # Lower is better. Holds ``1 - accuracy`` when validation accuracy is
        # available and the training loss otherwise.
        self.best_loss = float('inf')
        self.no_improve_count = 0
        self.history = {
            'train_loss': [],
            'val_accuracy': [],
            'loss_components': []
        }

    def _snapshot_weights(self):
        """Return an independent deep copy of the model's current state dict.

        ``state_dict()`` hands back tensors that share storage with the live
        parameters, so a shallow ``.copy()`` would keep changing as training
        continues and restoring it would be a no-op.
        """
        import copy

        return copy.deepcopy(self.model.state_dict())

    def train(self, dataloader):
        """Train for one epoch.

        Args:
            dataloader: Iterable of ``(incidence_matrices, features, labels)``
                batches, where ``incidence_matrices`` is a dict containing at
                least the ``'movie_user'`` key.

        Returns:
            Tuple ``(avg_loss, loss_components, student_embeddings)``: the mean
            loss over the batches, the per-component means, and the student
            embeddings produced by the last batch.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        loss_components = {}
        student_embeddings = None
        num_batches = 0

        for incidence_matrices, features, labels in dataloader:
            # Move data to device
            incidence_matrices = {k: v.to(self.device) for k, v in incidence_matrices.items()}
            features = features.to(self.device)
            if labels is not None:
                labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass with labels for classification loss
            student_embeddings, _, loss, batch_loss_components = self.model(
                incidence_matrices['movie_user'], features, labels)

            # NOTE(review): retain_graph=True is kept from the original code,
            # where it was added as a workaround. It keeps the autograd graph
            # alive after backward; confirm whether the model really reuses
            # the graph across steps before removing it.
            loss.backward(retain_graph=True)
            self.optimizer.step()

            total_loss += loss.item()

            # Accumulate plain Python numbers so no autograd graph is kept
            # alive through the running sums.
            for k, v in batch_loss_components.items():
                value = v.item() if torch.is_tensor(v) else v
                loss_components[k] = loss_components.get(k, 0.0) + value

            num_batches += 1

        if num_batches == 0:
            raise ValueError("Cannot train on an empty dataloader")

        # Average over the batches that were actually processed
        for k in loss_components:
            loss_components[k] /= num_batches

        avg_loss = total_loss / num_batches
        return avg_loss, loss_components, student_embeddings

    def evaluate(self, dataloader):
        """
        Evaluate the model with proper handling of size mismatches.

        Batches that raise during prediction are reported and skipped.

        Args:
            dataloader: DataLoader containing validation/test data. Batches
                are either ``(incidence_matrices, features, labels)`` or
                ``(incidence_matrices, features)``.

        Returns:
            predictions: Concatenated model predictions, or None if no batch
                could be evaluated
            accuracy: Fraction of matching predictions, or None when no
                labels were available
        """
        import traceback

        self.model.eval()
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                if len(batch) == 3:
                    incidence_matrices, features, labels = batch
                else:
                    incidence_matrices, features = batch
                    labels = None

                # Move data to device
                incidence_matrices = {k: v.to(self.device) for k, v in incidence_matrices.items()}
                features = features.to(self.device)
                if labels is not None:
                    labels = labels.to(self.device)

                try:
                    # Forward pass (prediction mode)
                    _, _, predictions = self.model.predict(incidence_matrices['movie_user'], features)

                    if labels is None:
                        all_predictions.append(predictions)
                        continue

                    # Flatten batched labels so they line up with predictions
                    if labels.dim() > 1:
                        labels = labels.reshape(-1)

                    # Print shapes for debugging
                    print(f"Predictions shape: {predictions.shape}, Labels shape: {labels.shape}")

                    # NOTE(review): when the lengths differ, both sides are cut
                    # to the common prefix. That only yields a meaningful
                    # accuracy if the first ``len(labels)`` predictions belong
                    # to the labelled nodes - confirm against the caller.
                    min_size = min(predictions.size(0), labels.size(0))
                    all_predictions.append(predictions[:min_size])
                    all_labels.append(labels[:min_size])

                except Exception as e:
                    print(f"Error during evaluation: {str(e)}")
                    traceback.print_exc()
                    continue

        if not all_predictions:
            return None, None

        predictions = torch.cat(all_predictions, dim=0)
        if not all_labels:
            return predictions, None

        labels = torch.cat(all_labels, dim=0)

        # Labelled and unlabelled batches may have been mixed; re-align.
        if predictions.size(0) != labels.size(0):
            print(f"WARNING: Size mismatch after concatenation. Predictions: {predictions.size(0)}, Labels: {labels.size(0)}")
            min_size = min(predictions.size(0), labels.size(0))
            predictions = predictions[:min_size]
            labels = labels[:min_size]

        accuracy = (predictions == labels).float().mean().item()
        return predictions, accuracy

    def train_with_early_stopping(self, train_dataloader, val_dataloader=None):
        """
        Train model with early stopping and robust error handling.

        Each epoch is scored with ``1 - validation accuracy`` when validation
        yields an accuracy, and with the training loss otherwise (no
        validation loader, no labels, or validation raised). The patience
        check runs after every epoch regardless of which score was used.
        An epoch that raises is reported and skipped.

        Args:
            train_dataloader: DataLoader for training data
            val_dataloader: DataLoader for validation data (optional)

        Returns:
            The model, with the best recorded weights loaded when at least
            one epoch completed.
        """
        import traceback

        best_model = None

        for epoch in range(self.num_epochs):
            print(f"Epoch {epoch+1}/{self.num_epochs}")

            try:
                train_loss, loss_components, _ = self.train(train_dataloader)

                self.history['train_loss'].append(train_loss)
                self.history['loss_components'].append(loss_components)

                print(f"Training loss: {train_loss:.4f}")
                print("Loss components:")
                for k, v in loss_components.items():
                    print(f"  {k}: {v:.4f}")

                val_accuracy = None
                validation_failed = False
                if val_dataloader is not None:
                    try:
                        _, val_accuracy = self.evaluate(val_dataloader)
                    except Exception as e:
                        print(f"Validation error: {str(e)}. Using training loss instead.")
                        traceback.print_exc()
                        validation_failed = True

                if val_accuracy is not None:
                    self.history['val_accuracy'].append(val_accuracy)
                    print(f"Validation accuracy: {val_accuracy:.4f}")
                    score = 1 - val_accuracy  # Convert to loss (1 - accuracy)
                    improved_message = f"New best model! (accuracy: {val_accuracy:.4f})"
                elif validation_failed:
                    score = train_loss
                    improved_message = f"New best model based on training loss: {train_loss:.4f}"
                else:
                    score = train_loss
                    improved_message = f"New best model! (loss: {train_loss:.4f})"

                if score < self.best_loss:
                    self.best_loss = score
                    self.no_improve_count = 0
                    best_model = self._snapshot_weights()
                    print(improved_message)
                else:
                    self.no_improve_count += 1
                    print(f"No improvement for {self.no_improve_count} epochs")

                if self.no_improve_count >= self.patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
            except Exception as e:
                print(f"Error during epoch {epoch+1}: {str(e)}")
                traceback.print_exc()
                print("Skipping to next epoch...")
                continue

        # Plotting is best-effort and must never lose the trained weights.
        try:
            self.plot_training_curves()
        except Exception as e:
            print(f"Error plotting training curves: {str(e)}")

        # Restore the best weights (covers both early stop and full run)
        if best_model is not None:
            self.model.load_state_dict(best_model)

        return self.model

    def plot_training_curves(self):
        """Plot training curves for loss and accuracy.

        Writes ``cucudistill_training_curves.png`` and, when loss components
        were recorded, ``cucudistill_loss_components.png`` to the current
        working directory.
        """
        plt.figure(figsize=(12, 5))

        # Plot training loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], 'b-', label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.legend()

        # Plot validation accuracy if available
        if self.history.get('val_accuracy'):
            plt.subplot(1, 2, 2)
            plt.plot(self.history['val_accuracy'], 'g-', label='Validation Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.title('Validation Accuracy')
            plt.legend()

        plt.tight_layout()
        plt.savefig('cucudistill_training_curves.png')
        plt.close()

        # Plot loss components
        if self.history['loss_components']:
            plt.figure(figsize=(14, 6))
            loss_types = list(self.history['loss_components'][0].keys())

            for loss_type in loss_types:
                # An epoch missing a component shows up as a gap, not a crash.
                values = [epoch_losses.get(loss_type, float('nan'))
                          for epoch_losses in self.history['loss_components']]
                plt.plot(values, label=loss_type)

            plt.xlabel('Epoch')
            plt.ylabel('Loss Value')
            plt.title('Loss Components')
            plt.legend()
            plt.tight_layout()
            plt.savefig('cucudistill_loss_components.png')
            plt.close()

#==================================================================================================
# Main Function: Hypergraph Classification with CuCoDistill
#==================================================================================================
def main(folder_path='C:\\IMDB'):
    """Run the end-to-end CuCoDistill hypergraph node-classification pipeline.

    Steps: load the relation spreadsheets, build hypergraphs and incidence
    matrices, assemble node features and genre labels, split the labelled
    nodes 70/15/15, train with early stopping, evaluate on the test split,
    time teacher vs. student inference and write the metric plots
    (``confusion_matrix.png``, ``precision_recall_f1.png``) to the current
    working directory.

    Args:
        folder_path: Directory containing the input data files. Defaults to
            the location that used to be hard-coded.

    Raises:
        ValueError: If no labelled movie node remains after filtering.
    """
    print("Starting CuCoDistill Hypergraph Node Classification")

    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Step 1: Process data
    print("Processing data...")
    ground_truth_ratings, label_encoder = process_data(folder_path)

    # Step 2: Create hypergraphs for each relation type
    print("Creating hypergraphs...")
    hyper_MU, att_MU = hypergraph_MU(folder_path)
    hyper_MD, att_MD = hypergraph_MD(folder_path)
    hyper_MA, att_MA = hypergraph_MA(folder_path)

    # Step 3: Generate incidence matrices
    print("Generating incidence matrices...")
    incidence_matrix_MU, user_idx_map, movie_idx_map_MU = generate_incidence_matrices_MU(hyper_MU, att_MU)
    incidence_matrix_MD, director_idx_map, movie_idx_map_MD = generate_incidence_matrices_MD(hyper_MD, att_MD)
    incidence_matrix_MA, actor_idx_map, movie_idx_map_MA = generate_incidence_matrices_MA(hyper_MA, att_MA)

    # Convert to torch tensors
    incidence_matrix_MU = torch.tensor(incidence_matrix_MU, dtype=torch.float32)
    incidence_matrix_MD = torch.tensor(incidence_matrix_MD, dtype=torch.float32)
    incidence_matrix_MA = torch.tensor(incidence_matrix_MA, dtype=torch.float32)

    # Package incidence matrices
    incidence_matrices = {
        'movie_user': incidence_matrix_MU.t(),  # Transpose to match paper's format
        'movie_director': incidence_matrix_MD.t(),
        'movie_actor': incidence_matrix_MA.t()
    }

    # Step 4: Collect movie nodes and their genre labels
    print("Preparing node features and labels...")

    # Get movie IDs from movie_idx_map_MU (using this as primary)
    movie_ids = []
    movie_indices = []
    for movie_node, idx in movie_idx_map_MU.items():
        try:
            # Node names are expected to look like "<prefix>:<movie id>"
            movie_id = int(float(movie_node.split(':')[1]))
        except (ValueError, IndexError):
            print(f"Warning: Could not parse movie ID from {movie_node}")
            continue
        movie_ids.append(movie_id)
        movie_indices.append(idx)

    # Create mapping for genre labels
    movie_id_to_genre = {}
    if 'genreID_encoded' in ground_truth_ratings.columns:
        for _, row in ground_truth_ratings.drop_duplicates('movieID').iterrows():
            movie_id_to_genre[row['movieID']] = row['genreID_encoded']

    # Keep only movies that have a genre label; labels and node indices are
    # appended together so position i of one always describes position i of
    # the other.
    labels = []
    valid_movie_indices = []
    for movie_id, movie_idx in zip(movie_ids, movie_indices):
        if movie_id in movie_id_to_genre:
            labels.append(movie_id_to_genre[movie_id])
            valid_movie_indices.append(movie_idx)

    # Convert to tensor
    if labels:
        labels = torch.tensor(labels, dtype=torch.long)
        print(f"Created labels tensor with shape: {labels.shape}")
    else:
        print("Warning: No labels found! Creating synthetic labels.")
        # Create synthetic labels
        num_classes = 10  # Adjust as needed
        labels = torch.randint(0, num_classes, (len(movie_indices),))
        valid_movie_indices = movie_indices
        print(f"Created synthetic labels tensor with shape: {labels.shape}")

    # Step 5: Create node feature matrix
    num_movies_MU = incidence_matrix_MU.shape[1]

    # Set hidden dimension
    hidden_dim = 64

    # Initialize features with random values for all nodes
    # For real applications, you would use actual node features if available
    features = torch.randn(num_movies_MU, hidden_dim, dtype=torch.float32)

    # Print feature dimensions
    print(f"Feature dimensions: {features.shape}")

    # Step 6: Split data into train/validation/test sets
    print(f"Number of labeled movies: {len(labels)}")
    print(f"Number of valid movie indices: {len(valid_movie_indices)}")

    # Drop node indices outside the feature matrix, together with their
    # labels. Trimming the labels by prefix instead would shift every label
    # after the first dropped index onto the wrong node.
    in_range = [idx < num_movies_MU for idx in valid_movie_indices]
    if not all(in_range):
        labels = labels[torch.tensor(in_range, dtype=torch.bool)]
        valid_movie_indices = [idx for idx, keep in zip(valid_movie_indices, in_range) if keep]

    print(f"After adjustment - labeled movies: {len(labels)}, valid indices: {len(valid_movie_indices)}")

    if len(labels) == 0:
        raise ValueError("No labeled movie nodes available; cannot build train/val/test splits")

    # Now create splits
    indices = torch.randperm(len(labels))
    train_size = int(0.7 * len(indices))
    val_size = int(0.15 * len(indices))

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size+val_size]
    test_indices = indices[train_size+val_size:]

    # Get corresponding movie indices
    train_movie_indices = [valid_movie_indices[i] for i in train_indices.tolist()]
    val_movie_indices = [valid_movie_indices[i] for i in val_indices.tolist()]
    test_movie_indices = [valid_movie_indices[i] for i in test_indices.tolist()]

    # Create mask tensors for selecting nodes
    # NOTE(review): the masks are only used for the diagnostics below. The
    # datasets receive the permuted label subsets together with the full
    # feature matrix - confirm the model/trainer align labels to nodes the
    # way this expects.
    train_mask = torch.zeros(num_movies_MU, dtype=torch.bool)
    val_mask = torch.zeros(num_movies_MU, dtype=torch.bool)
    test_mask = torch.zeros(num_movies_MU, dtype=torch.bool)

    for mask, node_indices in ((train_mask, train_movie_indices),
                               (val_mask, val_movie_indices),
                               (test_mask, test_movie_indices)):
        for idx in node_indices:
            if idx < num_movies_MU:
                mask[idx] = True

    # Create train/val/test labels
    train_labels = labels[train_indices]
    val_labels = labels[val_indices]
    test_labels = labels[test_indices]

    print(f"Train/Val/Test splits: {len(train_labels)}/{len(val_labels)}/{len(test_labels)}")
    print(f"Train mask sum: {train_mask.sum().item()}, Val mask sum: {val_mask.sum().item()}, Test mask sum: {test_mask.sum().item()}")

    # Ensure movie_user incidence matrix matches the number of movies we have labels for
    # NOTE(review): the matrices were transposed above, so size(1) / columns
    # here are axis 0 of the original incidence matrices (the code above reads
    # the movie count from axis 1). Selecting columns by movie index looks
    # suspicious; behaviour is kept unchanged - confirm the intended axis.
    if incidence_matrices['movie_user'].size(1) > len(valid_movie_indices):
        # Only keep columns for movies we have labels for
        for relation in incidence_matrices:
            source_matrix = incidence_matrices[relation]
            new_matrix = torch.zeros((source_matrix.size(0), len(valid_movie_indices)),
                                     dtype=source_matrix.dtype,
                                     device=source_matrix.device)
            for i, idx in enumerate(valid_movie_indices):
                # Indices beyond this relation's width leave a zero column
                if idx < source_matrix.size(1):
                    new_matrix[:, i] = source_matrix[:, idx]
            incidence_matrices[relation] = new_matrix

        # Also adjust the features
        new_features = torch.zeros((len(valid_movie_indices), features.size(1)),
                                   dtype=features.dtype,
                                   device=features.device)
        for i, idx in enumerate(valid_movie_indices):
            if idx < features.size(0):
                new_features[i] = features[idx]
        features = new_features

        print(f"Adjusted incidence matrices and features to match labeled movies.")
        print(f"New incidence matrix shape: {incidence_matrices['movie_user'].shape}")
        print(f"New features shape: {features.shape}")

    # Output diagnostics
    print(f"Final shapes - Features: {features.shape}, Labels: {labels.shape}, Train labels: {train_labels.shape}")
    print(f"Incidence matrix: {incidence_matrices['movie_user'].shape}")

    if features.size(0) != len(labels):
        print(f"WARNING: Number of nodes in features ({features.size(0)}) doesn't match number of labels ({len(labels)})")
        # Align them
        min_size = min(features.size(0), len(labels))
        features = features[:min_size]
        labels = labels[:min_size]
        # Filter split indices with a boolean mask so they stay index
        # tensors; indexing with a Python list of 0-d tensors is unreliable.
        train_indices = train_indices[train_indices < min_size]
        val_indices = val_indices[val_indices < min_size]
        test_indices = test_indices[test_indices < min_size]
        train_labels = labels[train_indices]
        val_labels = labels[val_indices]
        test_labels = labels[test_indices]
        print(f"After alignment: Features: {features.shape}, Labels: {labels.shape}")
        print(f"Train: {len(train_labels)}, Val: {len(val_labels)}, Test: {len(test_labels)}")

    # Update output_dim based on actual number of classes
    output_dim = int(labels.max().item() + 1)
    print(f"Number of classes: {output_dim}")

    # Step 7: Initialize CuCoDistill model
    print("Initializing CuCoDistill model...")
    input_dim = features.shape[1]

    model = CuCoDistill(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        k=10,  # Number of neighbors for student model
        low_rank_dim=16,  # Dimension for low-rank factorization
        dropout=0.2,
        temperature=0.1,
        total_steps=1000,
        device=device
    ).to(device)

    # Step 8: Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # Step 9: Create datasets and dataloaders (all share the same graph)
    train_dataset = HypergraphDataset(
        incidence_matrices=incidence_matrices,
        features=features,
        labels=train_labels
    )
    val_dataset = HypergraphDataset(
        incidence_matrices=incidence_matrices,
        features=features,
        labels=val_labels
    )
    test_dataset = HypergraphDataset(
        incidence_matrices=incidence_matrices,
        features=features,
        labels=test_labels
    )

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Step 10: Train the model
    print("Training CuCoDistill model...")
    trainer = HypergraphClassificationTrainer(
        model=model,
        optimizer=optimizer,
        device=device,
        num_epochs=50,
        patience=5
    )

    trainer.train_with_early_stopping(train_dataloader, val_dataloader)

    # Step 11: Evaluate on test set
    print("Evaluating on test set...")
    model.eval()

    test_predictions, test_accuracy = trainer.evaluate(test_dataloader)

    if test_accuracy is not None:
        print(f"Test accuracy: {test_accuracy:.4f}")
    else:
        print("No labels for test evaluation.")

    # Step 12: Measure efficiency improvements
    print("Measuring efficiency improvements...")

    # Time teacher model (HypergraphTripleAttention)
    teacher_model = model.teacher
    teacher_time = _time_inference(
        lambda inc, feats: teacher_model(inc['movie_user'], feats),
        test_dataloader, device)

    # Time student model
    student_model = model.student
    student_time = _time_inference(
        lambda inc, feats: student_model(feats, model.teacher_attention),
        test_dataloader, device)

    # A zero student time would otherwise raise ZeroDivisionError
    speedup = teacher_time / student_time if student_time > 0 else float('inf')

    print(f"Teacher inference time: {teacher_time:.4f} seconds")
    print(f"Student inference time: {student_time:.4f} seconds")
    print(f"Speedup: {speedup:.2f}x")

    # Step 13: Print confusion matrix and per-class accuracy for classification
    if test_accuracy is not None:
        try:
            _report_classification_metrics(test_labels, test_predictions, label_encoder)
        except Exception as e:
            # Reporting is best-effort; training results are already printed.
            print(f"Error generating classification metrics: {str(e)}")
            import traceback
            traceback.print_exc()

    print("Analysis complete!")


def _time_inference(forward_fn, dataloader, device):
    """Return the wall-clock seconds needed to run ``forward_fn`` on every batch.

    ``forward_fn`` receives the on-device incidence-matrix dict and feature
    tensor of each batch. CUDA work is synchronised before the clock is read
    so that asynchronous kernel launches are included in the measurement.
    """
    def _sync():
        if device.type == 'cuda':
            torch.cuda.synchronize()

    _sync()
    start_time = time.perf_counter()
    with torch.no_grad():
        for batch in dataloader:
            incidence_matrices_batch, features_batch, _ = batch
            incidence_matrices_batch = {k: v.to(device) for k, v in incidence_matrices_batch.items()}
            features_batch = features_batch.to(device)
            forward_fn(incidence_matrices_batch, features_batch)
    _sync()
    return time.perf_counter() - start_time


def _report_classification_metrics(test_labels, test_predictions, label_encoder):
    """Print the confusion matrix / classification report and save the plots.

    Writes ``confusion_matrix.png`` and ``precision_recall_f1.png``.
    """
    labels_list = test_labels.cpu().numpy()
    predictions_list = test_predictions.cpu().numpy()

    # The trainer cuts predictions and labels to their common prefix when
    # their lengths differ; mirror that here so the arrays line up.
    common = min(len(labels_list), len(predictions_list))
    labels_list = labels_list[:common]
    predictions_list = predictions_list[:common]

    # Classes present in either array, in ascending order (np.unique sorts)
    unique_classes = np.unique(np.concatenate([labels_list, predictions_list]))

    # Calculate confusion matrix; rows/columns follow unique_classes
    cm = confusion_matrix(labels_list, predictions_list, labels=unique_classes)
    print("\nConfusion Matrix:")
    print(cm)

    # Use label encoder class names when available, plain ids otherwise
    if hasattr(label_encoder, 'classes_'):
        class_names = [
            str(label_encoder.classes_[cls]) if cls < len(label_encoder.classes_) else f"Class {cls}"
            for cls in unique_classes
        ]
    else:
        class_names = [str(cls) for cls in unique_classes]

    # Generate classification report with explicit labels
    report = classification_report(
        labels_list,
        predictions_list,
        labels=unique_classes,  # Explicitly specify which labels to include
        target_names=class_names
    )
    print("\nClassification Report:")
    print(report)

    # Calculate per-class accuracy using the confusion matrix
    per_class_accuracy = cm.diagonal() / np.maximum(cm.sum(axis=1), 1)  # Avoid div by zero
    print("\nPer-class Accuracy:")
    for class_name, class_accuracy in zip(class_names, per_class_accuracy):
        print(f"Class {class_name}: {class_accuracy:.4f}")

    # Plot confusion matrix with proper labels
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    tick_marks = np.arange(len(unique_classes))
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)

    # Annotate each cell; switch text colour on dark cells for readability
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('confusion_matrix.png')
    plt.close()

    # Additional analysis: Precision, Recall, F1 per class.
    # Read the numbers from the structured report rather than parsing the
    # printed text, which breaks on class names containing whitespace.
    report_dict = classification_report(
        labels_list,
        predictions_list,
        labels=unique_classes,
        target_names=class_names,
        output_dict=True
    )
    metrics = [
        (name,
         report_dict[name]['precision'],
         report_dict[name]['recall'],
         report_dict[name]['f1-score'])
        for name in class_names
    ]

    if metrics:
        # Create bar chart
        metric_labels = [m[0] for m in metrics]
        precision = [m[1] for m in metrics]
        recall = [m[2] for m in metrics]
        f1 = [m[3] for m in metrics]

        x = np.arange(len(metric_labels))
        width = 0.25

        fig, ax = plt.subplots(figsize=(14, 8))
        rects1 = ax.bar(x - width, precision, width, label='Precision')
        rects2 = ax.bar(x, recall, width, label='Recall')
        rects3 = ax.bar(x + width, f1, width, label='F1-score')

        ax.set_xlabel('Classes')
        ax.set_ylabel('Scores')
        ax.set_title('Precision, Recall, and F1-score by class')
        ax.set_xticks(x)
        ax.set_xticklabels(metric_labels, rotation=45, ha='right')
        ax.legend()

        # Add value labels on top of bars
        for rects in (rects1, rects2, rects3):
            for rect in rects:
                height = rect.get_height()
                ax.annotate(f'{height:.2f}',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),  # 3 points vertical offset
                            textcoords="offset points",
                            ha='center', va='bottom')

        fig.tight_layout()
        plt.savefig('precision_recall_f1.png')
        plt.close()

    return model, teacher_time, student_time, speedup

if __name__ == "__main__":
    # Script entry point: run the whole pipeline once. main() returns a
    # 4-tuple of (model, teacher_time, student_time, speedup).
    outcome = main()
    model, teacher_time, student_time, speedup = outcome

    # Report the teacher-vs-student timing comparison returned by main().
    print("\n--- CuCoDistill Summary ---")
    print(f"Teacher inference time: {teacher_time:.4f} seconds")
    print(f"Student inference time: {student_time:.4f} seconds")
    print(f"Speedup achieved: {speedup:.2f}x")

    # Closing banner, printed only when the run above finished without raising.
    print("CuCoDistill for hypergraph node classification successfully implemented and evaluated!")