import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall, shortest_path
import networkx as nx
from torch_geometric.utils import to_networkx, to_dense_adj
from torch_geometric.data import Data



def construct_kernel_matrix(distances,alpha=1.0):
    """Turn a distance tensor into an exponential-decay kernel.

    Args:
        distances: Tensor of distances (any shape).
        alpha: Length scale dividing the distances before exponentiation.

    Returns:
        Tensor of the same shape holding ``exp(-distances / alpha)``.
    """
    negated = torch.neg(distances)
    return torch.exp(negated / alpha)


def energy_function(logits):
    """Compute per-row energies ``-log(sum(exp(logits)))`` in a numerically stable way.

    The row maximum is subtracted before exponentiating so that large logits
    do not overflow ``exp`` (the naive form returns ``-inf`` for e.g. 1000).

    Args:
        logits: 2-D array-like of shape [num_samples, num_classes].

    Returns:
        1-D NumPy array of shape [num_samples] with the energy of each row.
    """
    logits = np.asarray(logits)
    row_max = np.max(logits, axis=1, keepdims=True)
    # A non-finite maximum (+/-inf or NaN) cannot be used as a shift; fall back
    # to no shift for those rows so the result matches the naive formula.
    row_max[~np.isfinite(row_max)] = 0
    shifted_sum = np.sum(np.exp(logits - row_max), axis=1)
    return -(np.log(shifted_sum) + row_max[:, 0])

def energy_function_torch(logits):
    """Compute per-row energies ``-log(sum(exp(logits)))`` with PyTorch.

    Uses ``torch.logsumexp`` so that large logits do not overflow ``exp``
    (the naive form returns ``-inf`` for logits around 100 and above in float32).

    Args:
        logits: Tensor of shape [num_samples, num_classes].

    Returns:
        Tensor of shape [num_samples] with the energy of each row.
    """
    if not torch.is_floating_point(logits):
        # Integer logits are promoted so the reduction is computed in floating point.
        logits = logits.to(torch.get_default_dtype())
    return -torch.logsumexp(logits, dim=1)

def entropy_function_torch(logits):
    """Return the Shannon entropy ``-sum(p * log(p))`` of the softmax of each row.

    Args:
        logits: Tensor of shape [num_samples, num_classes].

    Returns:
        Tensor of shape [num_samples] with one entropy value per row.
    """
    p = torch.softmax(logits, dim=1)
    # The 1e-10 offset keeps log() finite when a probability underflows to 0.
    weighted_log = p * torch.log(p + 1e-10)
    return -torch.sum(weighted_log, dim=1)

def log_density_function(logits):
    """Return log-density estimates, defined as the negated energy of ``logits``."""
    energies = energy_function(logits)
    return -energies

def log_density_function_torch(logits):
    """Return log-density estimates (negated torch energy) for ``logits``."""
    energies = energy_function_torch(logits)
    return -energies

def scale_datasets_preserving_relative_size(dataset1, dataset2):
    """
    Scale two datasets to [0,1] while preserving their relative sizes.

    Both datasets are shifted and divided by the same global minimum and
    range, so values remain comparable across the two outputs.

    Args:
        dataset1: First array of values.
        dataset2: Second array of values.

    Returns:
        Tuple ``(scaled_dataset1, scaled_dataset2)``. If every value in both
        datasets is identical (zero range), arrays of zeros are returned
        instead of dividing by zero.
    """
    # Find global min and max across both datasets
    global_min = min(np.min(dataset1), np.min(dataset2))
    global_max = max(np.max(dataset1), np.max(dataset2))
    value_range = global_max - global_min

    if value_range == 0:
        # Degenerate case: all values coincide, so 0/0 would yield NaN.
        return (np.zeros_like(dataset1, dtype=float),
                np.zeros_like(dataset2, dtype=float))

    # Scale both datasets using the same min and range
    scaled_dataset1 = (dataset1 - global_min) / value_range
    scaled_dataset2 = (dataset2 - global_min) / value_range

    return scaled_dataset1, scaled_dataset2

def compute_distances_for_densities(x_points, density1, density2):
    """Return the Wasserstein distance between two sets of density values.

    Positions where either density is NaN or infinite are dropped from both
    arrays before the distance is computed.

    Args:
        x_points: Not used by the computation; kept for interface compatibility.
        density1: NumPy array of density values.
        density2: NumPy array of density values, same shape as ``density1``.

    Returns:
        The Wasserstein distance of the retained values, or ``inf`` when no
        position is valid in both arrays.
    """
    from scipy.stats import wasserstein_distance

    # Keep only positions that are finite in both arrays
    usable = np.isfinite(density1) & np.isfinite(density2)
    if usable.any():
        return wasserstein_distance(density1[usable], density2[usable])
    return float('inf')




################################################################################
# calculate pairwise shortest distances
################################################################################
def calculate_pairwise_shortest_distances(edge_index, num_nodes=None, directed=False, method='floyd_warshall'):
    """
    Calculate pairwise shortest (hop-count) distances for all nodes in a graph.

    The graph is treated as unweighted: repeated entries in ``edge_index``
    count as a single edge of length 1.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        directed: Whether the graph is directed (default: False)
        method: Algorithm to use ('floyd_warshall', 'dijkstra', 'bfs', or 'networkx')
                'floyd_warshall': SciPy Floyd-Warshall, suited to dense graphs
                'dijkstra': SciPy Dijkstra, suited to sparse graphs
                'bfs': SciPy Dijkstra with unit edge weights (SciPy's
                       ``shortest_path`` has no dedicated BFS mode)
                'networkx': Uses NetworkX ``all_pairs_shortest_path_length``

    Returns:
        distances: Float tensor of shape [num_nodes, num_nodes] containing
            shortest path distances. Unreachable pairs are set to ``num_nodes``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    if method == 'networkx':
        # Convert to NetworkX graph
        G = to_networkx(Data(edge_index=edge_index, num_nodes=num_nodes), to_undirected=not directed)

        # Fill a dense matrix from the per-source distance dictionaries;
        # pairs NetworkX does not report stay at infinity.
        dist_matrix = np.full((num_nodes, num_nodes), float('inf'))
        for source, reachable in nx.all_pairs_shortest_path_length(G):
            for target, hops in reachable.items():
                dist_matrix[source, target] = hops
    else:
        # Convert edge_index to a sparse adjacency matrix
        rows, cols = edge_index.cpu().numpy()
        data = np.ones(len(rows))
        adj_matrix = csr_matrix((data, (rows, cols)), shape=(num_nodes, num_nodes))

        # Duplicate (row, col) pairs are summed by the sparse constructor, which
        # would turn a repeated edge into a weight of 2 or more. Reset every
        # stored entry to 1 so distances are true hop counts.
        adj_matrix.sum_duplicates()
        adj_matrix.data[:] = 1.0

        # Make undirected if needed
        if not directed:
            adj_matrix = adj_matrix.maximum(adj_matrix.T)

        # Calculate shortest paths using the specified method
        if method == 'floyd_warshall':
            dist_matrix = floyd_warshall(csgraph=adj_matrix, directed=directed)
        elif method == 'dijkstra':
            dist_matrix = shortest_path(csgraph=adj_matrix, method='D', directed=directed)
        elif method == 'bfs':
            # Unit-weight Dijkstra behaves like a breadth-first search and avoids
            # the cubic cost of Floyd-Warshall on this branch.
            dist_matrix = shortest_path(csgraph=adj_matrix, method='D', directed=directed, unweighted=True)
        else:
            raise ValueError(f"Unknown method: {method}. Choose from 'floyd_warshall', 'dijkstra', 'bfs', or 'networkx'.")

    # Convert to PyTorch tensor
    distances = torch.from_numpy(dist_matrix).float()

    # Replace infinities (unreachable pairs) with a finite sentinel larger than
    # any real hop count in the graph
    distances[torch.isinf(distances)] = num_nodes

    return distances

def calculate_pairwise_shortest_distances_batched(edge_index, num_nodes=None, batch_size=1000, device='cpu'):
    """
    Memory-bounded all-pairs shortest (hop-count) distances via Floyd-Warshall.

    For every intermediate node ``k`` the update
    ``d[i, j] = min(d[i, j], d[i, k] + d[k, j])`` is applied to all rows, but
    only ``batch_size`` rows at a time so the temporary tensor never exceeds
    ``[batch_size, num_nodes]``.

    Edges are used exactly as given in ``edge_index`` (no symmetrization).

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        batch_size: Number of rows updated at once for each intermediate node
        device: Device to use for computation ('cuda' or 'cpu')

    Returns:
        distances: Tensor of shape [num_nodes, num_nodes] containing shortest
            path distances. Unreachable pairs are set to ``num_nodes``.
    """
    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Move edge_index to the desired device
    edge_index = edge_index.to(device)

    # Initialize distance matrix: inf everywhere, 1 on edges, 0 on the diagonal.
    # The diagonal is written last so self-loops cannot leave d[i, i] == 1.
    distances = torch.full((num_nodes, num_nodes), float('inf'), device=device)
    distances[edge_index[0], edge_index[1]] = 1
    distances.fill_diagonal_(0)

    for k in range(num_nodes):
        # Row k and column k are unchanged by iteration k (d[k, k] == 0), so
        # updating the matrix in place, chunk by chunk, is safe.
        via_k_row = distances[k].unsqueeze(0)  # [1, num_nodes]
        for start in range(0, num_nodes, batch_size):
            stop = min(start + batch_size, num_nodes)
            to_k = distances[start:stop, k].unsqueeze(1)  # [rows, 1]
            distances[start:stop] = torch.minimum(distances[start:stop], to_k + via_k_row)

    # Replace infinities with a large value
    distances[torch.isinf(distances)] = num_nodes

    return distances

def calculate_graph_statistics(edge_index, num_nodes=None):
    """
    Calculate various graph statistics on the undirected version of the graph.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)

    Returns:
        dict: Dictionary with the keys 'num_nodes', 'num_edges', 'avg_degree',
            'density', 'clustering_coefficient', 'num_components',
            'largest_component_size', 'largest_component_ratio' and 'diameter'.
            For graphs with at most 10000 nodes it also holds
            'avg_shortest_path', 'avg_betweenness_centrality' and
            'avg_closeness_centrality', all computed on the largest component.
            Above 1000 nodes the two centrality averages are estimates based
            on a random sample of at most 1000 nodes.
    """
    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Convert to NetworkX graph for easy analysis
    G = to_networkx(Data(edge_index=edge_index, num_nodes=num_nodes), to_undirected=True)

    # Calculate basic statistics
    stats = {
        'num_nodes': num_nodes,
        'num_edges': G.number_of_edges(),
        'avg_degree': sum(dict(G.degree()).values()) / num_nodes,
        'density': nx.density(G),
        'clustering_coefficient': nx.average_clustering(G),
    }

    # Calculate connected components
    components = list(nx.connected_components(G))
    stats['num_components'] = len(components)

    # Calculate largest component statistics
    largest_cc = max(components, key=len)
    largest_cc_graph = G.subgraph(largest_cc)

    stats['largest_component_size'] = len(largest_cc)
    stats['largest_component_ratio'] = len(largest_cc) / num_nodes

    # Calculate diameter (only for the largest component to avoid inf)
    if len(largest_cc) > 1:
        stats['diameter'] = nx.diameter(largest_cc_graph)
    else:
        stats['diameter'] = 0

    # Calculate more advanced statistics if the graph is not too large
    if num_nodes <= 10000:  # Avoid expensive computations for large graphs
        try:
            stats['avg_shortest_path'] = nx.average_shortest_path_length(largest_cc_graph)
        except nx.NetworkXException:
            # Only NetworkX's own failures are mapped to NaN; anything else is a
            # real error and should propagate.
            stats['avg_shortest_path'] = float('nan')

        if num_nodes <= 1000:
            stats['avg_betweenness_centrality'] = np.mean(list(nx.betweenness_centrality(largest_cc_graph).values()))
            stats['avg_closeness_centrality'] = np.mean(list(nx.closeness_centrality(largest_cc_graph).values()))
        else:
            # Estimate centralities from a sample of at most 1000 nodes
            component_nodes = list(largest_cc)
            sample_size = min(1000, len(component_nodes))

            # betweenness_centrality takes the NUMBER of pivot nodes as `k` and
            # draws the sample itself; it does not accept a node collection.
            betweenness = nx.betweenness_centrality(largest_cc_graph, k=sample_size)
            stats['avg_betweenness_centrality'] = np.mean(list(betweenness.values()))

            # closeness_centrality accepts a single node as `u`, so evaluate the
            # sampled nodes one at a time.
            sampled_nodes = np.random.choice(component_nodes, sample_size, replace=False)
            closeness = [nx.closeness_centrality(largest_cc_graph, u=int(node)) for node in sampled_nodes]
            stats['avg_closeness_centrality'] = np.mean(closeness)

    return stats

def visualize_graph_distances(distances, title="Pairwise Shortest Distances", save_path=None):
    """
    Plot a histogram and a heatmap of a pairwise shortest-distance matrix.

    The histogram covers all off-diagonal entries that are not infinite; the
    heatmap shows the full matrix as given.

    Args:
        distances: Tensor or array of shape [num_nodes, num_nodes] containing shortest path distances
        title: Title for the plot
        save_path: Path to save the figure (nothing is written when falsy)

    Returns:
        fig: The matplotlib figure
    """
    # Work on a NumPy copy/view regardless of the input type
    matrix = distances.cpu().numpy() if isinstance(distances, torch.Tensor) else distances

    # Off-diagonal entries only (self-distances are excluded), without infinities
    off_diagonal = ~np.eye(matrix.shape[0], dtype=bool)
    pair_distances = matrix[off_diagonal]
    pair_distances = pair_distances[~np.isinf(pair_distances)]

    fig, (hist_ax, heat_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: distribution of distances
    hist_ax.hist(pair_distances, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    hist_ax.set_xlabel('Shortest Path Distance')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Distribution of Pairwise Shortest Distances')
    hist_ax.grid(alpha=0.3)

    # Right panel: the distance matrix itself
    image = heat_ax.imshow(matrix, cmap='viridis', interpolation='nearest')
    heat_ax.set_title('Shortest Path Distance Matrix')
    plt.colorbar(image, ax=heat_ax, label='Distance')

    fig.suptitle(title, fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


################################################################################
# calculate homophily matrix
################################################################################
def calculate_node_homophily(edge_index, labels, directed=False):
    """
    Compute, for every node, the fraction of its neighbors sharing its label.

    Each column of ``edge_index`` adds the target to the source's neighbor
    list; when ``directed`` is False the source is also added to the target's
    list. Nodes without neighbors keep a score of 0.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        labels: PyTorch tensor of shape [num_nodes] containing node labels
        directed: Whether the graph is directed (default: False)

    Returns:
        homophily: Float32 CPU tensor of shape [num_nodes] with homophily scores for each node
    """
    num_nodes = labels.size(0)
    cpu_labels = labels.cpu()
    scores = torch.zeros(num_nodes, dtype=torch.float32)

    # Build node -> list of neighbors (repeated edges are kept as repeats)
    adjacency = {}
    for source, target in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        adjacency.setdefault(source, []).append(target)
        if not directed:
            adjacency.setdefault(target, []).append(source)

    for node in range(num_nodes):
        node_neighbors = adjacency.get(node)
        if not node_neighbors:
            # Isolated node: score stays 0
            continue
        own_label = cpu_labels[node].item()
        matches = sum(1 for other in node_neighbors if cpu_labels[other].item() == own_label)
        scores[node] = matches / len(node_neighbors)

    return scores

def calculate_node_homophily_vectorized(edge_index, labels, directed=False):
    """
    Tensor-based computation of per-node homophily (fraction of same-label neighbors).

    Every edge counts as a neighbor of its source; when ``directed`` is False
    it also counts as a neighbor of its target. Nodes with no counted
    neighbors get a score of 0.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        labels: PyTorch tensor of shape [num_nodes] containing node labels
        directed: Whether the graph is directed (default: False)

    Returns:
        homophily: Tensor of shape [num_nodes] with homophily scores for each node
            (allocated on the device of ``edge_index``)
    """
    device = edge_index.device
    num_nodes = labels.size(0)
    sources, targets = edge_index

    degree = torch.zeros(num_nodes, device=device)
    matching = torch.zeros(num_nodes, device=device)

    # True for edges whose two endpoints carry the same label
    endpoints_agree = labels[sources] == labels[targets]

    # Edge endpoints that "own" the edge as a neighbor relation: sources always,
    # targets too for undirected graphs.
    owners = (sources,) if directed else (sources, targets)
    for owner in owners:
        node_ids, edge_counts = owner.unique(return_counts=True)
        degree.scatter_add_(0, node_ids, edge_counts.float())

        agreeing_owner = owner[endpoints_agree]
        matching.scatter_add_(0, agreeing_owner,
                              torch.ones(agreeing_owner.numel(), device=device))

    # Divide only where a neighbor exists so isolated nodes stay at 0
    scores = torch.zeros_like(degree)
    has_neighbors = degree > 0
    scores[has_neighbors] = matching[has_neighbors] / degree[has_neighbors]

    return scores

def calculate_graph_homophily(edge_index, labels, directed=False):
    """
    Return the edge homophily of the graph: the fraction of edges in
    ``edge_index`` whose two endpoints have the same label.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        labels: PyTorch tensor of shape [num_nodes] containing node labels
        directed: Accepted for interface symmetry; it does not affect the result

    Returns:
        homophily: Float representing the graph-level homophily
    """
    sources, targets = edge_index
    endpoints_agree = torch.eq(labels[sources], labels[targets])
    return endpoints_agree.float().mean().item()

def visualize_node_homophily(edge_index, labels, node_pos=None, title="Node Homophily", save_path=None):
    """
    Draw the graph twice side by side: colored by node label and by node homophily.

    Node homophily is obtained from ``calculate_node_homophily_vectorized``
    with its default (undirected) setting, and the graph is drawn as an
    undirected NetworkX graph. A text box at the bottom of the figure reports
    the graph-level homophily and the mean of the per-node scores.
    
    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        labels: PyTorch tensor of shape [num_nodes] containing node labels
        node_pos: Optional dictionary mapping node indices to positions;
            when None a spring layout with a fixed seed (42) is computed
        title: Title for the plot
        save_path: Path to save the figure (nothing is written when falsy)
    
    Returns:
        fig: The matplotlib figure
    """
    # Calculate node homophily
    homophily_scores = calculate_node_homophily_vectorized(edge_index, labels)
    
    # Create graph
    num_nodes = labels.size(0)
    G = to_networkx(Data(edge_index=edge_index, num_nodes=num_nodes), to_undirected=True)
    
    # Create node attributes (plain Python values keyed by node index)
    node_labels = {i: labels[i].item() for i in range(num_nodes)}
    node_homophily = {i: homophily_scores[i].item() for i in range(num_nodes)}
    
    # Set node attributes
    nx.set_node_attributes(G, node_labels, 'label')
    nx.set_node_attributes(G, node_homophily, 'homophily')
    
    # Create figure: left panel for labels, right panel for homophily
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    
    # Get positions
    if node_pos is None:
        # Use spring layout for positioning; the fixed seed makes the layout repeatable
        node_pos = nx.spring_layout(G, seed=42)
    
    # Plot graph colored by label
    # One tab10 color per distinct label, sampled evenly over the colormap range
    unique_labels = sorted(set(node_labels.values()))
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))
    label_to_color = {label: colors[i] for i, label in enumerate(unique_labels)}
    # Colors are listed in G.nodes() order so they line up with the drawn nodes
    node_colors_by_label = [label_to_color[node_labels[node]] for node in G.nodes()]
    
    nx.draw_networkx(G, pos=node_pos, node_color=node_colors_by_label, 
                    with_labels=False, node_size=50, edge_color='gray', 
                    alpha=0.7, ax=ax1)
    
    # Add legend for labels (proxy markers, one per class)
    label_patches = [plt.Line2D([0], [0], marker='o', color='w', 
                              markerfacecolor=label_to_color[label], 
                              markersize=10, label=f'Class {label}') 
                    for label in unique_labels]
    ax1.legend(handles=label_patches, loc='upper right')
    ax1.set_title('Graph colored by node labels')
    
    # Plot graph colored by homophily
    node_colors_by_homophily = [node_homophily[node] for node in G.nodes()]
    # Nodes and edges are drawn separately so the node collection can feed the colorbar
    nodes = nx.draw_networkx_nodes(G, pos=node_pos, node_color=node_colors_by_homophily,
                                 cmap='viridis', node_size=50, alpha=0.8, ax=ax2)
    nx.draw_networkx_edges(G, pos=node_pos, edge_color='gray', alpha=0.3, ax=ax2)
    
    # Add colorbar
    cbar = plt.colorbar(nodes, ax=ax2)
    cbar.set_label('Homophily Score')
    
    # Add title
    ax2.set_title('Graph colored by node homophily')
    fig.suptitle(title, fontsize=16)
    
    # Calculate overall graph statistics
    # Graph homophily is computed over the edges exactly as listed in edge_index
    graph_homophily = calculate_graph_homophily(edge_index, labels)
    avg_node_homophily = homophily_scores.mean().item()
    
    # Add text with homophily statistics
    stats_text = (f"Graph Homophily: {graph_homophily:.4f}\n"
                 f"Avg Node Homophily: {avg_node_homophily:.4f}")
    fig.text(0.5, 0.01, stats_text, ha='center', fontsize=12, 
            bbox=dict(facecolor='white', alpha=0.5))
    
    plt.tight_layout()
    # Raise the bottom margin after tight_layout; presumably to leave room for the stats box
    plt.subplots_adjust(bottom=0.15)
    
    # Save if path is provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    return fig



################################################################################
# Laplacian matrix and diffusion calculations
################################################################################
def calculate_graph_laplacian(edge_index, num_nodes=None, normalized=True, return_sparse=False):
    """
    Build the Laplacian of the symmetrized graph described by ``edge_index``.

    The adjacency matrix is made symmetric with an element-wise maximum of the
    matrix and its transpose. With ``normalized=True`` the result is
    ``I - D^(-1/2) A D^(-1/2)``; otherwise it is ``D - A``.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        normalized: Whether to return the normalized Laplacian (default: True)
        return_sparse: Whether to return SciPy sparse matrices (default: False)

    Returns:
        If ``return_sparse`` is False: the Laplacian as a dense float32 PyTorch tensor.
        If ``return_sparse`` is True: a tuple ``(laplacian, degree_matrix,
        adjacency_matrix)`` of SciPy sparse matrices.
    """
    import torch
    import scipy.sparse as sp

    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Sparse adjacency with unit entries for the listed edges
    sources, targets = edge_index.cpu().numpy()
    unit_weights = np.ones(len(sources))
    adjacency = sp.csr_matrix((unit_weights, (sources, targets)), shape=(num_nodes, num_nodes))

    # Symmetrize so the graph is treated as undirected
    adjacency = adjacency.maximum(adjacency.T)

    # Degrees are the row sums of the symmetric adjacency
    node_degrees = np.array(adjacency.sum(axis=1)).flatten()
    degree_matrix = sp.diags(node_degrees)

    if not normalized:
        # Combinatorial Laplacian: D - A
        laplacian = degree_matrix - adjacency
    else:
        # D^(-1/2), with a small offset added to the degrees before the power
        with np.errstate(divide='ignore'):
            inv_sqrt_degrees = np.power(node_degrees + 1e-12, -0.5)
        inv_sqrt_degrees[np.isinf(inv_sqrt_degrees)] = 0.0
        scaling = sp.diags(inv_sqrt_degrees)

        # Symmetric normalized Laplacian: I - D^(-1/2) A D^(-1/2)
        laplacian = sp.eye(num_nodes) - scaling @ adjacency @ scaling

    if return_sparse:
        return laplacian, degree_matrix, adjacency

    # Dense float32 tensor for callers that want a PyTorch object
    return torch.tensor(laplacian.todense(), dtype=torch.float32)

def calculate_diffusion_distances_alpha(edge_index, num_nodes=None, alpha=1.0, t=1.0, normalized=True):
    """
    Calculate diffusion distances between all pairs of nodes with custom diffusion rate alpha.

    The heat kernel is the matrix exponential ``exp(-t * alpha * L)``, where L
    is the graph Laplacian. For a scalar ``alpha`` the Laplacian is scaled
    uniformly; for a matrix ``alpha`` it is scaled element-wise
    (``alpha[i, j] * L[i, j]``) before exponentiation. The diffusion distance
    between nodes i and j is the Euclidean distance between rows i and j of
    the heat kernel.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        alpha: Diffusion rate parameter (scalar or matrix of shape [num_nodes, num_nodes])
        t: Diffusion time parameter
        normalized: Whether to use the normalized Laplacian (default: True)

    Returns:
        distances: Float32 tensor of shape [num_nodes, num_nodes] containing diffusion distances
    """
    import torch
    import scipy.sparse as sp
    from scipy.sparse.linalg import expm

    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Get the Laplacian (sparse format for efficiency)
    laplacian, _, _ = calculate_graph_laplacian(edge_index, num_nodes, normalized, return_sparse=True)

    if isinstance(alpha, (int, float)):
        # Scalar alpha: apply uniformly
        scaled_laplacian = -t * alpha * laplacian
    else:
        # Matrix alpha: element-wise multiplication with the Laplacian
        if isinstance(alpha, torch.Tensor):
            alpha_np = alpha.detach().cpu().numpy()
        else:
            alpha_np = np.asarray(alpha)

        # toarray() yields a plain ndarray, for which `*` is element-wise.
        # todense() would yield np.matrix, for which `*` is a matrix product.
        laplacian_dense = laplacian.toarray()
        scaled_laplacian = sp.csr_matrix(-t * (alpha_np * laplacian_dense))

    # Calculate heat kernel: exp(-t * alpha * L)
    heat_kernel = expm(scaled_laplacian)
    heat_kernel_dense = torch.tensor(heat_kernel.toarray(), dtype=torch.float32)

    # d_t(i,j)^2 = ||P_t(i,:) - P_t(j,:)||^2, computed one source row at a time
    # so the temporary stays at [num_nodes, num_nodes]
    diff_distances = torch.zeros((num_nodes, num_nodes), dtype=torch.float32)
    for i in range(num_nodes):
        row_gap = heat_kernel_dense[i] - heat_kernel_dense
        diff_distances[i] = torch.sum(row_gap ** 2, dim=1)

    # Take square root to get actual distances
    return torch.sqrt(diff_distances)

def calculate_diffusion_kernel(edge_index, num_nodes=None, alpha=1.0, t=1.0, normalized=True):
    """
    Calculate the diffusion kernel (heat kernel) for a graph.

    The kernel is the matrix exponential ``exp(-t * alpha * L)``, where L is
    the graph Laplacian. For a scalar ``alpha`` the Laplacian is scaled
    uniformly; for a matrix ``alpha`` it is scaled element-wise
    (``alpha[i, j] * L[i, j]``) before exponentiation.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        alpha: Diffusion rate parameter (scalar or matrix of shape [num_nodes, num_nodes])
        t: Diffusion time parameter
        normalized: Whether to use the normalized Laplacian (default: True)

    Returns:
        heat_kernel: Diffusion kernel matrix as a dense float32 PyTorch tensor
    """
    import torch
    import scipy.sparse as sp
    from scipy.sparse.linalg import expm

    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Get the Laplacian (sparse format for efficiency)
    laplacian, _, _ = calculate_graph_laplacian(edge_index, num_nodes, normalized, return_sparse=True)

    if isinstance(alpha, (int, float)):
        # Scalar alpha: apply uniformly
        scaled_laplacian = -t * alpha * laplacian
    else:
        # Matrix alpha: element-wise multiplication with the Laplacian
        if isinstance(alpha, torch.Tensor):
            alpha_np = alpha.detach().cpu().numpy()
        else:
            alpha_np = np.asarray(alpha)

        # toarray() yields a plain ndarray, for which `*` is element-wise.
        # todense() would yield np.matrix, for which `*` is a matrix product.
        laplacian_dense = laplacian.toarray()
        scaled_laplacian = sp.csr_matrix(-t * (alpha_np * laplacian_dense))

    # Calculate heat kernel: exp(-t * alpha * L)
    heat_kernel = expm(scaled_laplacian)

    # Convert to dense tensor
    return torch.tensor(heat_kernel.toarray(), dtype=torch.float32)

def visualize_diffusion_process(edge_index, seed_nodes=None, num_nodes=None, time_points=None, 
                               alpha=1.0, normalized=True, node_pos=None, cmap='viridis', save_path=None):
    """
    Visualize the diffusion process starting from seed nodes over time.

    The first panel shows the initial state (1.0 on the seed nodes, 0 elsewhere).
    Each following panel shows ``heat_kernel(t) @ initial_state`` for one entry
    of ``time_points``, where the heat kernel comes from
    ``calculate_diffusion_kernel``. Panels are arranged in a grid with at most
    three columns.
    
    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        seed_nodes: List of seed nodes to start the diffusion from (default: [0])
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        time_points: List of time points to visualize (default: [0.1, 0.5, 1.0, 2.0, 5.0])
        alpha: Diffusion rate parameter (scalar or matrix)
        normalized: Whether to use the normalized Laplacian (default: True)
        node_pos: Optional dictionary mapping node indices to positions for consistent layout;
            when None a spring layout with a fixed seed (42) is computed
        cmap: Colormap to use for visualization
        save_path: Path to save the visualization (nothing is written when falsy)
        
    Returns:
        fig: The matplotlib figure
    """
    # Function-local imports of modules that are also imported at module level
    import torch
    import matplotlib.pyplot as plt
    import networkx as nx
    
    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1
    
    # Default time points
    if time_points is None:
        time_points = [0.1, 0.5, 1.0, 2.0, 5.0]
    
    # Default seed nodes
    if seed_nodes is None:
        seed_nodes = [0]  # Use first node as default
        
    # Create one-hot encoding for seed nodes
    initial_state = torch.zeros(num_nodes)
    initial_state[seed_nodes] = 1.0
    
    # Convert to NetworkX graph for visualization
    G = to_networkx(Data(edge_index=edge_index, num_nodes=num_nodes), to_undirected=True)
    
    # Compute node positions if not provided
    if node_pos is None:
        node_pos = nx.spring_layout(G, seed=42)
    
    # Create a figure with subplots for each time point
    n_plots = len(time_points) + 1  # +1 for initial state
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols  # ceiling division
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows))
    # Normalize `axes` to a flat indexable sequence: subplots returns an array
    # for a multi-panel grid and a single Axes for a 1x1 grid
    if n_rows * n_cols > 1:
        axes = axes.flatten()
    else:
        axes = [axes]
    
    # Plot initial state
    nx.draw_networkx_edges(G, pos=node_pos, alpha=0.3, ax=axes[0])
    nodes = nx.draw_networkx_nodes(G, pos=node_pos, node_color=initial_state, cmap=cmap, 
                                 node_size=100, ax=axes[0])
    plt.colorbar(nodes, ax=axes[0])
    axes[0].set_title('Initial State (t=0)')
    axes[0].axis('off')
    
    # Compute diffusion at each time point
    for i, t in enumerate(time_points):
        # Compute heat kernel at time t (recomputed from edge_index for every time point)
        heat_kernel = calculate_diffusion_kernel(edge_index, num_nodes, alpha, t, normalized)
        
        # Compute state at time t
        state_t = heat_kernel @ initial_state
        
        # Plot state at time t; panel i+1 because panel 0 holds the initial state
        nx.draw_networkx_edges(G, pos=node_pos, alpha=0.3, ax=axes[i+1])
        nodes = nx.draw_networkx_nodes(G, pos=node_pos, node_color=state_t, cmap=cmap, 
                                     node_size=100, ax=axes[i+1])
        plt.colorbar(nodes, ax=axes[i+1])
        axes[i+1].set_title(f'Diffusion at t={t}')
        axes[i+1].axis('off')
    
    # Hide any unused subplots (grid cells beyond the last drawn panel)
    for i in range(n_plots, len(axes)):
        axes[i].axis('off')
    
    # Set overall title
    plt.suptitle(f'Diffusion Process from Seed Node(s) {seed_nodes}', fontsize=16)
    # The layout rectangle stops at 0.97 of the figure height; presumably to keep the suptitle clear
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    
    # Save if path is provided
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    return fig

def create_custom_alpha_matrix(edge_index, num_nodes=None, node_features=None, feature_similarity='cosine', sigma=1.0):
    """
    Create a custom alpha matrix for the diffusion process based on node features.

    Pairwise feature similarity is used as the diffusion rate, so that similar
    nodes get a larger alpha between them. The result is clipped from below at
    0.001 so that every entry is strictly positive.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges.
            Not used by the computation; kept for interface compatibility.
        num_nodes: Number of nodes in the graph. Not used by the computation;
            kept for interface compatibility.
        node_features: Node feature tensor of shape [num_nodes, feature_dim],
            or None to request a constant alpha.
        feature_similarity: Similarity measure to use:
            - 'cosine': cosine similarity (scikit-learn)
            - 'euclidean': RBF kernel exp(-||x-y||^2 / (2 * sigma^2)) (scikit-learn)
            - 'custom': 1 / (1 + ||x-y||_2)
        sigma: Bandwidth parameter of the RBF kernel ('euclidean' only)

    Returns:
        alpha_matrix: The scalar 1.0 when node_features is None, otherwise a
            NumPy array of shape [num_nodes, num_nodes].

    Raises:
        ValueError: If feature_similarity is not one of the supported names.
    """
    # If no features provided, use constant alpha
    if node_features is None:
        return 1.0

    # Detach first: tensors that are tracked by autograd cannot be exported
    # to NumPy directly, and none of the similarities below need gradients.
    features = node_features.detach()

    # scikit-learn is imported lazily so that the branches that do not need it
    # (constant alpha, 'custom') work without it being installed.
    if feature_similarity == 'cosine':
        from sklearn.metrics.pairwise import cosine_similarity
        similarity_matrix = cosine_similarity(features.cpu().numpy())
    elif feature_similarity == 'euclidean':
        from sklearn.metrics.pairwise import rbf_kernel
        # RBF kernel: exp(-gamma * ||x-y||^2)
        similarity_matrix = rbf_kernel(features.cpu().numpy(), gamma=1.0/(2.0*sigma**2))
    elif feature_similarity == 'custom':
        # Inverse of the Euclidean feature distance
        pairwise_distances = torch.cdist(features, features, p=2).cpu().numpy()
        similarity_matrix = 1.0 / (1.0 + pairwise_distances)
    else:
        raise ValueError(f"Unknown similarity measure: {feature_similarity}")

    # Higher similarity -> faster diffusion; clip so alpha is always positive
    # (cosine similarity can be negative).
    return np.maximum(similarity_matrix, 0.001)

################################################################################
# calculate resistance distance matrix
################################################################################
def calculate_resistance_distances(edge_index, num_nodes=None, method='pinv', return_laplacian=False):
    """
    Calculate resistance distances (effective resistances) between all pairs of nodes.

    The resistance distance between nodes i and j is defined as the effective
    resistance between these nodes when the graph is viewed as an electrical network
    with unit resistances on each edge.

    Mathematically, it equals r_ij = L⁺_ii + L⁺_jj - 2*L⁺_ij, where L⁺ is the
    Moore-Penrose pseudoinverse of the graph Laplacian.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred from the
            largest index in edge_index if not provided)
        method: Method to compute the pseudoinverse ('pinv' or 'eigen')
        return_laplacian: Whether to also return the Laplacian and its pseudoinverse

    Returns:
        resistance_distances: float32 tensor of shape [num_nodes, num_nodes]
        (optional) laplacian: The dense graph Laplacian as a float32 tensor
        (optional) laplacian_pinv: The pseudoinverse of the Laplacian as a float32 tensor

    Raises:
        ValueError: If method is neither 'pinv' nor 'eigen'.
    """
    from scipy.linalg import pinv

    # Fail fast on a bad method name, before doing any expensive work
    if method not in ('pinv', 'eigen'):
        raise ValueError(f"Unknown method: {method}. Choose from 'pinv' or 'eigen'.")

    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Get the unnormalized Laplacian; the degree and adjacency matrices that
    # come back with it are not needed here.
    laplacian, _, _ = calculate_graph_laplacian(
        edge_index, num_nodes, normalized=False, return_sparse=True)

    # Dense copy for the pseudoinverse. np.asarray guarantees a plain ndarray
    # even if todense() hands back an np.matrix.
    laplacian_dense = np.asarray(laplacian.todense())

    if method == 'pinv':
        laplacian_pinv = pinv(laplacian_dense)
    else:
        # Pseudoinverse via the eigendecomposition L = V diag(w) V^T
        eigenvalues, eigenvectors = np.linalg.eigh(laplacian_dense)

        # Eigenvalues at (numerically) zero span the null space and are left at 0
        nonzero_mask = eigenvalues > 1e-10
        pinv_diag = np.zeros_like(eigenvalues)
        pinv_diag[nonzero_mask] = 1.0 / eigenvalues[nonzero_mask]

        # V * diag(1/w) * V^T; scaling the columns of V avoids building diag()
        laplacian_pinv = (eigenvectors * pinv_diag) @ eigenvectors.T

    # r_ij = L⁺_ii + L⁺_jj - 2*L⁺_ij for all pairs at once via broadcasting
    diag_pinv = np.diag(laplacian_pinv)
    resistance_distances = diag_pinv[:, None] + diag_pinv[None, :] - 2 * laplacian_pinv

    # Mirror the upper triangle so the result is exactly symmetric even when
    # the pseudoinverse carries round-off asymmetry.
    resistance_distances = np.triu(resistance_distances) + np.triu(resistance_distances, k=1).T

    # Round-off can push tiny distances below zero
    resistance_distances = np.maximum(resistance_distances, 0.0)

    # Convert to PyTorch tensor
    resistance_distances_tensor = torch.tensor(resistance_distances, dtype=torch.float32)

    if return_laplacian:
        laplacian_tensor = torch.tensor(laplacian_dense, dtype=torch.float32)
        laplacian_pinv_tensor = torch.tensor(laplacian_pinv, dtype=torch.float32)
        return resistance_distances_tensor, laplacian_tensor, laplacian_pinv_tensor
    return resistance_distances_tensor

def calculate_resistance_distances_batched(edge_index, num_nodes=None, batch_size=1000, device='cpu'):
    """
    Resistance distances for all node pairs, filled in row blocks.

    The Laplacian pseudoinverse L⁺ is computed once (densely, with SciPy) and
    moved to ``device``. The output matrix is then written ``batch_size`` rows
    at a time using r_ij = L⁺_ii + L⁺_jj - 2*L⁺_ij, so the temporaries created
    per step are at most [batch_size, num_nodes].

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        batch_size: Number of source nodes (rows) to process at once
        device: Device to use for computation ('cuda' or 'cpu')

    Returns:
        resistance_distances: float32 tensor of shape [num_nodes, num_nodes] on ``device``
    """
    import torch
    from scipy.linalg import pinv

    # Fall back to the largest node index when the size is not given
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Unnormalized Laplacian from the shared helper; only the first return value is used
    sparse_laplacian, _, _ = calculate_graph_laplacian(
        edge_index, num_nodes, normalized=False, return_sparse=True)

    # Dense pseudoinverse, then over to torch on the requested device
    pinv_np = pinv(sparse_laplacian.todense())
    pinv_t = torch.tensor(pinv_np, dtype=torch.float32).to(device)

    # L⁺_jj as a row vector, shared by every block
    pinv_diagonal = torch.diag(pinv_t)
    all_nodes_term = pinv_diagonal.unsqueeze(0)

    result = torch.zeros((num_nodes, num_nodes), dtype=torch.float32, device=device)

    for start in range(0, num_nodes, batch_size):
        stop = min(start + batch_size, num_nodes)

        # L⁺_ii as a column vector for the rows of this block
        block_term = pinv_diagonal[start:stop].unsqueeze(1)

        # Broadcasting yields the [stop - start, num_nodes] block directly
        block = block_term + all_nodes_term - 2 * pinv_t[start:stop, :]

        # Round-off can give slightly negative values; clip them to zero
        result[start:stop, :] = torch.clamp(block, min=0.0)

    return result

def calculate_effective_resistance(edge_index, i, j, num_nodes=None, method='pinv'):
    """
    Calculate the effective resistance between two specific nodes i and j.

    This is a convenience wrapper around ``calculate_resistance_distances``:
    the full pairwise matrix is still computed and a single entry is read from
    it, so the cost is the same as the all-pairs computation. When many pairs
    are needed, compute the matrix once and index it instead of calling this
    function repeatedly.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        i: Index of first node
        j: Index of second node
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)
        method: Method to compute the pseudoinverse ('pinv' or 'eigen')

    Returns:
        resistance: Effective resistance between nodes i and j as a Python float
    """
    # num_nodes=None is resolved by calculate_resistance_distances itself; the
    # Laplacian tensors are not requested because they would be discarded.
    resistance_distances = calculate_resistance_distances(
        edge_index, num_nodes, method=method)

    # Extract the resistance between nodes i and j
    return resistance_distances[i, j].item()

def visualize_resistance_distances(resistance_distances, node_labels=None, 
                                  title="Resistance Distances", save_path=None):
    """
    Visualize resistance distances between nodes.

    Draws a heatmap of the distance matrix and a histogram of the pairwise
    distances (upper triangle, diagonal excluded). When node labels are given,
    a third panel compares same-class and different-class distances.

    Args:
        resistance_distances: Tensor or array-like [num_nodes, num_nodes] of resistance distances
        node_labels: Optional tensor or array-like of node labels for grouping
        title: Title for the plot
        save_path: Path to save the figure (skipped when falsy)

    Returns:
        fig: The matplotlib figure
    """
    # Convert to numpy; detach so tensors tracked by autograd are accepted too
    if isinstance(resistance_distances, torch.Tensor):
        distances_np = resistance_distances.detach().cpu().numpy()
    else:
        distances_np = np.asarray(resistance_distances)

    # Allocate every panel in one grid. Adding a 1x3 subplot to a 1x2 figure
    # afterwards would put the third axes on top of the second one.
    has_labels = node_labels is not None
    n_panels = 3 if has_labels else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(20 if has_labels else 16, 7))
    ax1, ax2 = axes[0], axes[1]

    # Heatmap of resistance distances
    im = ax1.imshow(distances_np, cmap='viridis', interpolation='nearest')
    ax1.set_title('Resistance Distance Matrix')
    fig.colorbar(im, ax=ax1)

    # Histogram over the upper triangle (each unordered pair once, no diagonal)
    upper = np.triu_indices_from(distances_np, k=1)
    flat_distances = distances_np[upper]

    ax2.hist(flat_distances, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax2.set_xlabel('Resistance Distance')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Distribution of Pairwise Resistance Distances')
    ax2.grid(alpha=0.3)

    # If node labels are provided, show grouped statistics
    if has_labels:
        if isinstance(node_labels, torch.Tensor):
            labels_np = node_labels.detach().cpu().numpy()
        else:
            labels_np = np.asarray(node_labels)
        labels_np = labels_np.reshape(-1)

        # Split the same pairs used for the histogram by label agreement
        same_class = labels_np[upper[0]] == labels_np[upper[1]]
        intra_class_distances = flat_distances[same_class]
        inter_class_distances = flat_distances[~same_class]

        ax3 = axes[2]
        ax3.hist(intra_class_distances, bins=20, alpha=0.5, label='Same Class', color='blue')
        ax3.hist(inter_class_distances, bins=20, alpha=0.5, label='Different Classes', color='red')
        ax3.set_xlabel('Resistance Distance')
        ax3.set_ylabel('Frequency')
        ax3.set_title('Intra vs. Inter-Class Resistance Distances')
        ax3.legend()
        ax3.grid(alpha=0.3)

    # Set overall title; leave room for it above the panels
    fig.suptitle(title, fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save if path is provided (on the figure itself, not pyplot's "current" one)
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def compare_resistance_metrics(edge_index, node_features=None, labels=None, num_nodes=None):
    """
    Compare resistance distances with other distance metrics and node properties.

    Builds a 2x2 figure:
      - resistance vs. shortest-path distance (scatter + correlations)
      - heatmap of the resistance distance matrix
      - same-class vs. different-class resistance distances (only with labels)
      - resistance distance vs. feature cosine similarity (only with node_features)
    Panels whose input was not supplied are hidden.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        node_features: Optional tensor/array of node features
        labels: Optional tensor/array of node labels
        num_nodes: Number of nodes in the graph (optional, inferred if not provided)

    Returns:
        fig: The matplotlib figure
        metrics_dict: Dictionary with 'resistance_distances', 'shortest_distances'
            and, when node_features is given, 'feature_similarity'
    """
    from scipy.stats import pearsonr, spearmanr

    def _annotate_correlation(ax, x, y):
        # Correlations are only defined on finite pairs; unreachable node
        # pairs may carry non-finite distances, so drop those first.
        finite = np.isfinite(x) & np.isfinite(y)
        if np.count_nonzero(finite) < 2:  # Need at least 2 points for correlation
            return
        pearson, _ = pearsonr(x[finite], y[finite])
        spearman, _ = spearmanr(x[finite], y[finite])
        ax.text(0.05, 0.95, f'Pearson r: {pearson:.3f}\nSpearman ρ: {spearman:.3f}',
                transform=ax.transAxes, va='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    # Infer number of nodes if not provided
    if num_nodes is None:
        num_nodes = edge_index.max().item() + 1

    # Calculate resistance and shortest path distances
    resistance_distances = calculate_resistance_distances(edge_index, num_nodes)
    shortest_distances = calculate_pairwise_shortest_distances(edge_index, num_nodes)

    # Create a dictionary to store metrics
    metrics_dict = {
        'resistance_distances': resistance_distances,
        'shortest_distances': shortest_distances
    }

    # Create a figure for visualization
    fig, axes = plt.subplots(2, 2, figsize=(18, 14))

    # Flatten matrices over the upper triangle (each pair once, no diagonal)
    resistance_np = resistance_distances.cpu().numpy()
    mask = np.triu_indices(num_nodes, k=1)
    res_flat = resistance_np[mask]
    short_flat = shortest_distances.cpu().numpy()[mask]

    # Plot 1: Resistance vs Shortest Path
    axes[0, 0].scatter(short_flat, res_flat, alpha=0.5, s=10)
    axes[0, 0].set_xlabel('Shortest Path Distance')
    axes[0, 0].set_ylabel('Resistance Distance')
    axes[0, 0].set_title('Resistance vs Shortest Path Distance')
    _annotate_correlation(axes[0, 0], short_flat, res_flat)

    # Plot 2: Resistance Distance Heatmap
    im = axes[0, 1].imshow(resistance_np, cmap='viridis')
    axes[0, 1].set_title('Resistance Distance Matrix')
    fig.colorbar(im, ax=axes[0, 1])

    # Plot 3: If labels are provided, show class-based analysis
    if labels is not None:
        if isinstance(labels, torch.Tensor):
            labels_np = labels.detach().cpu().numpy()
        else:
            labels_np = np.asarray(labels)
        labels_np = labels_np.reshape(-1)

        # Split all node pairs by whether their labels agree
        same_class = labels_np[mask[0]] == labels_np[mask[1]]
        intra_res = res_flat[same_class].astype(np.float64)
        inter_res = res_flat[~same_class].astype(np.float64)

        # Box plot of intra vs inter class distances. Tick labels are set on
        # the axis rather than through boxplot(labels=...), whose keyword
        # differs between matplotlib versions.
        box_data = [intra_res, inter_res]
        axes[1, 0].boxplot(box_data)
        axes[1, 0].set_xticklabels(['Same Class', 'Different Classes'])
        axes[1, 0].set_ylabel('Resistance Distance')
        axes[1, 0].set_title('Intra vs Inter-Class Resistance Distances')

        # Annotate the mean of each non-empty group
        for position, data in enumerate(box_data, start=1):
            if len(data) == 0:
                continue
            y = np.mean(data)
            axes[1, 0].text(position, y, f'Mean: {y:.3f}', ha='center', va='bottom',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
    else:
        # Nothing to show here; hide the empty axes
        axes[1, 0].axis('off')

    # Plot 4: If node features are provided, analyze feature similarity vs resistance
    if node_features is not None:
        from sklearn.metrics.pairwise import cosine_similarity

        if isinstance(node_features, torch.Tensor):
            features_np = node_features.detach().cpu().numpy()
        else:
            features_np = node_features

        # Pairwise cosine similarity, flattened over the same pairs as res_flat
        feature_sim = cosine_similarity(features_np)
        feat_sim_flat = feature_sim[mask]

        axes[1, 1].scatter(feat_sim_flat, res_flat, alpha=0.5, s=10)
        axes[1, 1].set_xlabel('Feature Cosine Similarity')
        axes[1, 1].set_ylabel('Resistance Distance')
        axes[1, 1].set_title('Resistance Distance vs Feature Similarity')
        _annotate_correlation(axes[1, 1], feat_sim_flat, res_flat)

        # Store in metrics dictionary
        metrics_dict['feature_similarity'] = torch.tensor(feature_sim)
    else:
        # Nothing to show here; hide the empty axes
        axes[1, 1].axis('off')

    fig.tight_layout()

    return fig, metrics_dict

def calculate_edge_homophily_matrix(node_homophily, edge_index=None):
    """
    Build the pairwise edge homophily matrix from per-node homophily values.

    Entry (i, j) is the mean of the homophily of node i and node j. When
    edge_index is supplied, entries that are neither on the diagonal nor
    listed as an (src, dst) pair in edge_index are zeroed.

    Args:
        node_homophily: Tensor (or sequence) of shape [num_nodes] containing node homophily values
        edge_index: Optional tensor of shape [2, num_edges]; if provided, only
                   listed pairs and the diagonal keep non-zero entries

    Returns:
        edge_homophily_matrix: Tensor of shape [num_nodes, num_nodes] with the
                               pairwise mean homophily (masked if requested)
    """
    import torch

    n = len(node_homophily)

    # Accept plain sequences by promoting them to a float32 tensor
    if not isinstance(node_homophily, torch.Tensor):
        node_homophily = torch.tensor(node_homophily, dtype=torch.float32)

    # Column view varies along rows (node i), row view along columns (node j)
    square = (n, n)
    per_row = node_homophily.view(-1, 1).expand(*square)
    per_col = node_homophily.view(1, -1).expand(*square)
    pairwise_mean = (per_row + per_col) / 2.0

    # Without a graph there is nothing to mask
    if edge_index is None:
        return pairwise_mean

    # Boolean connectivity: self-loops on the diagonal plus every listed edge
    connected = torch.zeros(square, dtype=torch.bool, device=node_homophily.device)
    connected.fill_diagonal_(True)
    sources, targets = edge_index
    connected[sources, targets] = True

    # Multiplying by the 0/1 mask zeroes the non-connected entries
    return pairwise_mean * connected.float()

def visualize_edge_homophily(edge_homophily_matrix, edge_index=None, node_labels=None, title="Edge Homophily", save_path=None):
    """
    Visualize the edge homophily matrix.

    The left panel is a heatmap of the matrix. The right panel is either the
    graph with edges coloured by their homophily value (when edge_index is
    given) or a histogram of the matrix entries (otherwise).

    Args:
        edge_homophily_matrix: Tensor/array of shape [num_nodes, num_nodes] of edge homophily values
        edge_index: Optional tensor of shape [2, num_edges] for graph structure
        node_labels: Optional tensor/array of node labels used to colour the nodes
        title: Title for the plot
        save_path: Path to save the figure (skipped when falsy)

    Returns:
        fig: The matplotlib figure
    """
    # Convert to numpy; detach so tensors tracked by autograd are accepted too
    if isinstance(edge_homophily_matrix, torch.Tensor):
        homophily_np = edge_homophily_matrix.detach().cpu().numpy()
    else:
        homophily_np = np.asarray(edge_homophily_matrix)

    # Create figure with subplots
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    # Plot 1: Heatmap of edge homophily matrix
    im = axes[0].imshow(homophily_np, cmap='viridis', interpolation='nearest')
    axes[0].set_title('Edge Homophily Matrix')
    fig.colorbar(im, ax=axes[0])

    # Plot 2: Graph with edge colors based on homophily (if edge_index is provided)
    if edge_index is not None:
        num_nodes = homophily_np.shape[0]
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))

        if isinstance(edge_index, torch.Tensor):
            edges_np = edge_index.detach().cpu().numpy()
        else:
            edges_np = np.asarray(edge_index)

        # The graph is undirected, so (s, d) and (d, s) are one edge. Keep one
        # value per undirected edge (first occurrence wins) so that the colour
        # list stays aligned with the edge list handed to networkx.
        edge_values = {}
        for s, d in zip(edges_np[0].tolist(), edges_np[1].tolist()):
            if s == d:  # Skip self-loops
                continue
            key = (s, d) if s < d else (d, s)
            edge_values.setdefault(key, float(homophily_np[s, d]))

        edge_list = list(edge_values)
        edge_weights = list(edge_values.values())
        G.add_edges_from(edge_list)

        # Get node positions using spring layout
        pos = nx.spring_layout(G, seed=42)

        if node_labels is not None:
            if isinstance(node_labels, torch.Tensor):
                labels_np = node_labels.detach().cpu().numpy()
            else:
                labels_np = np.asarray(node_labels)
            labels_np = labels_np.reshape(-1)

            # Colour by the position of each label among the unique labels so
            # that node colours and legend entries use the same index, whatever
            # the raw label values are.
            unique_labels, label_index = np.unique(labels_np, return_inverse=True)
            label_index = np.asarray(label_index).reshape(-1)
            cmap = plt.get_cmap('tab10', len(unique_labels))
            node_colors = [cmap(int(label_index[node])) for node in G.nodes()]

            # Create legend handles
            legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(i), 
                                      markersize=10, label=f'Class {label}') 
                            for i, label in enumerate(unique_labels)]
        else:
            # Default node color
            node_colors = 'skyblue'
            legend_handles = []

        # Draw nodes
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=100, ax=axes[1])

        # Draw edges with color based on homophily; skipped when the graph has
        # no non-loop edges (min()/max() of an empty list would fail).
        if edge_list:
            weight_min, weight_max = min(edge_weights), max(edge_weights)
            nx.draw_networkx_edges(
                G, pos,
                edgelist=edge_list,
                edge_color=edge_weights,
                edge_cmap=plt.cm.viridis,
                edge_vmin=weight_min,
                edge_vmax=weight_max,
                width=2,
                ax=axes[1]
            )

            # Add colorbar for edges
            sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, 
                                     norm=plt.Normalize(vmin=weight_min, vmax=weight_max))
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=axes[1])
            cbar.set_label('Edge Homophily')

        # Add legend if node labels were provided
        if node_labels is not None:
            axes[1].legend(handles=legend_handles, loc='upper right')

        axes[1].set_title('Graph with Edge Homophily')
        axes[1].axis('off')
    else:
        # If no edge_index, plot histogram of edge homophily values
        flat_homophily = homophily_np.flatten()

        # If more than half of the entries are zero they are probably the
        # result of masking non-edges - keep only the positive values
        if (flat_homophily == 0).sum() > 0.5 * len(flat_homophily):
            flat_homophily = flat_homophily[flat_homophily > 0]

        # Plot histogram
        axes[1].hist(flat_homophily, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
        axes[1].set_xlabel('Edge Homophily')
        axes[1].set_ylabel('Frequency')
        axes[1].set_title('Distribution of Edge Homophily Values')
        axes[1].grid(alpha=0.3)

    # Set overall title; leave room for it above the panels
    fig.suptitle(title, fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save if path is provided (on the figure itself, not pyplot's "current" one)
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def create_alpha_from_homophily(node_homophily, edge_index=None, mode='direct', invert=False):
    """
    Create a custom alpha matrix for diffusion based on node homophily values.

    Args:
        node_homophily: Tensor (or sequence) of shape [num_nodes] containing node homophily values
        edge_index: Optional tensor of shape [2, num_edges]. If provided, entries
            that are neither on the diagonal nor listed in edge_index are set
            to the neutral value 1.0.
        mode: How to use homophily for alpha ('direct', 'similarity', 'difference')
            - 'direct': Use average of homophily_i and homophily_j as alpha
            - 'similarity': 1 - |h_i - h_j| / (max(h) - min(h)); all ones when
              every node has the same homophily
            - 'difference': 1 / (|h_i - h_j| + 1e-5), clamped to [0.1, 100]
        invert: Whether to invert the relationship. Values that all lie in
            [0, 1] become 1 - alpha; otherwise alpha is reflected inside its
            own range (alpha_min + alpha_max - alpha).

    Returns:
        alpha_matrix: Tensor of shape [num_nodes, num_nodes]

    Raises:
        ValueError: If mode is not one of the supported names.
    """
    # Get number of nodes
    num_nodes = len(node_homophily)

    # Convert to tensor if not already
    if not isinstance(node_homophily, torch.Tensor):
        node_homophily = torch.tensor(node_homophily, dtype=torch.float32)

    # homophily_i varies along rows (node i), homophily_j along columns (node j)
    homophily_i = node_homophily.view(-1, 1).expand(num_nodes, num_nodes)
    homophily_j = node_homophily.view(1, -1).expand(num_nodes, num_nodes)

    if mode == 'direct':
        # Use the average of node homophilies directly as alpha
        alpha_matrix = (homophily_i + homophily_j) / 2.0

    elif mode == 'similarity':
        abs_diff = torch.abs(homophily_i - homophily_j)
        max_diff = torch.max(node_homophily) - torch.min(node_homophily)
        if max_diff > 0:
            alpha_matrix = 1.0 - abs_diff / max_diff
        else:
            # All nodes have the same homophily - every alpha is 1.0. Built
            # from abs_diff so device and dtype match the branch above rather
            # than always being a CPU float32 tensor.
            alpha_matrix = torch.ones_like(
                abs_diff, dtype=torch.result_type(abs_diff, 1.0))

    elif mode == 'difference':
        # Inverse of the difference (epsilon avoids division by zero)
        alpha_matrix = 1.0 / (torch.abs(homophily_i - homophily_j) + 1e-5)

        # Cap values to avoid extremely large alphas
        alpha_matrix = torch.clamp(alpha_matrix, 0.1, 100.0)

    else:
        raise ValueError(f"Unknown mode: {mode}. Choose from 'direct', 'similarity', or 'difference'.")

    # Invert relationship if requested
    if invert:
        if torch.all((alpha_matrix >= 0) & (alpha_matrix <= 1)):
            # For values in [0,1], simply use 1-value
            alpha_matrix = 1.0 - alpha_matrix
        else:
            # Otherwise reflect each value inside [alpha_min, alpha_max] so the
            # largest becomes the smallest and the range is preserved
            alpha_min = alpha_matrix.min()
            alpha_max = alpha_matrix.max()
            if alpha_max > alpha_min:
                alpha_matrix = alpha_max - (alpha_matrix - alpha_min)
            # If all values are the same, inversion doesn't change anything

    # If edge_index is provided, mask out entries for non-connected nodes
    if edge_index is not None:
        # Adjacency mask: listed (src, dst) pairs plus self-loops
        adj_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool, 
                              device=node_homophily.device)
        src, dst = edge_index
        adj_mask[src, dst] = True
        adj_mask.fill_diagonal_(True)

        # Non-connected entries get 1.0 as the neutral value (no scaling effect)
        neutral_alpha = 1.0
        alpha_matrix = torch.where(
            adj_mask, alpha_matrix, torch.full_like(alpha_matrix, neutral_alpha))

    return alpha_matrix



#### Matern covariance

def calculate_matern_covariance(laplacian, nu, kappa, polynomial_type="rational", order=15):
    """
    Matérn-type covariance obtained by filtering the Laplacian spectrum.

    With L = V diag(w) V^T, the result is V diag((2*nu/kappa^2 + w)^(-nu)) V^T.
    If the eigendecomposition path raises a RuntimeError, a polynomial
    approximation is used instead.

    Args:
        laplacian: Square Laplacian tensor (passed to torch.linalg.eigh)
        nu: Exponent (smoothness) parameter
        kappa: Scaling parameter
        polynomial_type: "chebyshev" selects the Chebyshev fallback; any other
            value selects the rational fallback
        order: Order handed to the fallback approximation

    Returns:
        Covariance matrix as a tensor
    """
    try:
        # Spectral route: apply the Matérn filter to each eigenvalue
        spectrum, basis = torch.linalg.eigh(laplacian)
        shift = 2 * nu / kappa**2
        filtered_spectrum = torch.pow(shift + spectrum, -nu)
        return basis @ torch.diag(filtered_spectrum) @ basis.T
    except RuntimeError:
        # Spectral route failed; pick the requested approximation
        if polynomial_type != "chebyshev":
            return rational_matern_approx(laplacian, nu, kappa, order)
        return chebyshev_matern_approx(laplacian, nu, kappa, order)

def calculate_exponential_covariance(laplacian, kappa):
    """
    Exponential (heat-kernel style) covariance from the Laplacian spectrum.

    With L = V diag(w) V^T, the result is V diag(exp(-kappa^2 / 2 * w)) V^T.
    If the eigendecomposition path raises a RuntimeError, a Taylor series
    approximation is used instead.

    Args:
        laplacian: Square Laplacian tensor (passed to torch.linalg.eigh)
        kappa: Scaling parameter

    Returns:
        Covariance matrix as a tensor
    """
    try:
        # Spectral route: exponentiate the scaled eigenvalues
        spectrum, basis = torch.linalg.eigh(laplacian)
        rate = -(kappa ** 2) / 2
        decayed_spectrum = torch.exp(rate * spectrum)
        return basis @ torch.diag(decayed_spectrum) @ basis.T
    except RuntimeError:
        # Spectral route failed; use the series approximation
        return taylor_exp_approx(laplacian, kappa)

def compute_chebyshev_coeff(alpha, k):
    """
    Compute the coefficient used by the Chebyshev approximation.

    The closed forms evaluated here are:
        k == 0: 2^alpha * G(alpha + 1/2) / (sqrt(pi) * G(alpha + 1))
        k >  0: 2^alpha * (-1)^k * G(alpha + 1/2) * G(alpha - k + 1)
                / (sqrt(pi) * k! * G(alpha + 1) * (1 - 2k))
    where G is the gamma function.

    The gamma function has poles at the non-positive integers. For k != 0,
    when the evaluation hits a pole and alpha - k + 1 is a non-positive
    integer, the coefficient is defined as 0.0. Any other pole is reported as
    a ValueError; a log-space evaluation cannot help there, because
    math.lgamma is undefined at exactly the same points as math.gamma.

    Args:
        alpha: Power parameter (can be negative)
        k: Order of Chebyshev polynomial (integer)

    Returns:
        Coefficient value

    Raises:
        ValueError: If the expression is undefined for (alpha, k), e.g. when
            alpha + 1/2 is a non-positive integer, or alpha + 1 is one and k == 0.
    """
    import math

    try:
        if k == 0:
            return 2**alpha * math.gamma(alpha + 1/2) / (math.sqrt(math.pi) * math.gamma(alpha + 1))

        num = (-1)**k * math.gamma(alpha + 1/2) * math.gamma(alpha - k + 1)
        denom = math.sqrt(math.pi) * math.factorial(k) * math.gamma(alpha + 1) * (1 - 2*k)
        return 2**alpha * num / denom
    except ValueError as err:
        # If alpha-k+1 is a non-positive integer, the coefficient is taken to be 0
        pole_arg = alpha - k + 1
        if k != 0 and pole_arg <= 0 and pole_arg == int(pole_arg):
            return 0.0
        raise ValueError(
            f"Chebyshev coefficient is undefined for alpha={alpha}, k={k}: "
            "a gamma-function argument is a non-positive integer "
            "or k is not a valid order"
        ) from err

def chebyshev_matern_approx(L, nu, kappa, order=20):
    """
    Approximate Matérn kernel using Chebyshev polynomial approximation

    The Laplacian is rescaled with a Gershgorin bound on its largest
    eigenvalue, the shifted operator (2*nu/kappa^2)*I + L is expressed in
    that scaling, and a truncated series sum_i c_i * T_i(scaled_A) is
    accumulated with the three-term Chebyshev recurrence. The coefficients
    c_i come from compute_chebyshev_coeff(-nu, i), defined elsewhere in
    this file.

    Args:
        L: Graph Laplacian matrix (square tensor)
        nu: Smoothness parameter
        kappa: Scaling parameter
        order: Order of Chebyshev approximation; terms T_0 .. T_{order-1}
            are used (T_0 and T_1 are always included)

    Returns:
        K: Approximated Matérn kernel matrix, with 1e-6 added on the
            diagonal. The result uses the dtype obtained by promoting
            L's dtype with the default floating dtype, so a float64
            Laplacian yields a float64 kernel.
    """
    n = L.shape[0]
    # Build the identity in the dtype that mixing L with a default-dtype
    # tensor promotes to; a plain torch.eye(n) would pin the accumulator to
    # the default dtype and silently downcast float64 inputs.
    dtype = torch.promote_types(L.dtype, torch.get_default_dtype())
    I = torch.eye(n, device=L.device, dtype=dtype)

    # Scale L to have eigenvalues in [-1, 1]
    # We use a simple estimate based on Gershgorin circle theorem
    # NOTE(review): an all-zero L gives max_eigenval == 0 and a division by
    # zero below — confirm callers never pass an edgeless graph.
    row_sums = torch.sum(torch.abs(L), dim=1)
    max_eigenval = torch.max(row_sums)
    scaled_L = 2 * L / max_eigenval - I

    # Set up the operator A = (2*nu/(kappa^2))*I + L in the same scaling
    term = 2 * nu / (kappa**2)
    scaled_A = (2 / max_eigenval) * term * I + scaled_L

    # Chebyshev polynomials T_0 = I and T_1 = scaled_A seed the recurrence
    T_prev = I
    T_curr = scaled_A

    # Accumulate out-of-place so the running sum keeps the promoted dtype
    result = compute_chebyshev_coeff(-nu, 0) * T_prev
    result = result + compute_chebyshev_coeff(-nu, 1) * T_curr

    # Recurrence relation for higher order terms: T_{i} = 2*A*T_{i-1} - T_{i-2}
    for i in range(2, order):
        T_prev, T_curr = T_curr, 2 * scaled_A @ T_curr - T_prev
        result = result + compute_chebyshev_coeff(-nu, i) * T_curr

    # Final scaling to account for domain transformation
    scaling_factor = (max_eigenval / 2) ** nu
    K = scaling_factor * result

    # Add jitter to ensure positive definiteness
    K = K + 1e-6 * I

    return K

def rational_matern_approx(L, nu, kappa, order=15):
    """
    Approximate Matérn kernel A^{-nu} with A = (2*nu/kappa^2)*I + L

    A^{-1} is approximated with a truncated Neumann series on a rescaled
    copy of A. The power nu is then applied: the integer part by repeated
    multiplication and the fractional part through a symmetric
    eigendecomposition with eigenvalues clamped to at least 1e-10.

    Args:
        L: Graph Laplacian matrix (square tensor). torch.linalg.eigh is
            used, so L is assumed to be symmetric.
        nu: Smoothness parameter
        kappa: Scaling parameter
        order: Number of Neumann series terms, counting the identity term

    Returns:
        K: Approximated Matérn kernel matrix, with 1e-6 added on the
            diagonal. The dtype is L's dtype promoted with the default
            floating dtype (float64 in, float64 out).
    """
    n = L.shape[0]
    # Use the dtype that I + L promotes to; matmul does not type-promote, so
    # a default-dtype identity would raise on a float64 Laplacian.
    dtype = torch.promote_types(L.dtype, torch.get_default_dtype())
    I = torch.eye(n, device=L.device, dtype=dtype)

    def clamped_power(matrix, exponent):
        # Symmetric eigendecomposition; clamp eigenvalues so the real power
        # is well defined even if the approximation is slightly indefinite.
        eigvals, eigvecs = torch.linalg.eigh(matrix)
        eigvals = torch.clamp(eigvals, min=1e-10)
        return eigvecs @ torch.diag(eigvals.pow(exponent)) @ eigvecs.T

    # Set up the operator A = (2*nu/(kappa^2))*I + L
    shift = 2 * nu / (kappa**2)
    A = shift * I + L

    # Estimate largest eigenvalue of L (Gershgorin row sums) and rescale A
    # so the Neumann series below is applied to I - scaled_A.
    spectral_radius = torch.max(torch.sum(torch.abs(L), dim=1))
    scaling = 1.0 / (spectral_radius + shift)
    scaled_A = scaling * A

    # Compute approximate inverse using Neumann series:
    # (I - (I - scaled_A))^{-1} = I + (I-scaled_A) + (I-scaled_A)^2 + ...
    B = I - scaled_A
    A_inv_approx = I.clone()
    B_power = I.clone()
    for _ in range(1, order):
        B_power = B_power @ B
        A_inv_approx = A_inv_approx + B_power

    # Undo scaling: A^{-1} = scaling * scaled_A^{-1}
    A_inv_approx = scaling * A_inv_approx

    if nu <= 1:
        # For nu <= 1, raise the approximate inverse to the power nu directly
        try:
            K = clamped_power(A_inv_approx, nu)
        except RuntimeError:
            # Eigendecomposition failed: fall back to a crude surrogate.
            # For nu close to 1, A_inv_approx is already close to A^{-nu};
            # for smaller nu, blend it with the identity.
            if nu < 0.9:
                K = nu * A_inv_approx + (1 - nu) * I
            else:
                K = A_inv_approx + 1e-6 * I
    else:
        integer_part = int(nu)
        fractional_part = nu - integer_part

        # Compute A^{-integer_part} by repeated multiplication
        K = A_inv_approx.clone()
        for _ in range(1, integer_part):
            K = K @ A_inv_approx

        if fractional_part > 0:
            # A^{-nu} = A^{-integer_part} @ A^{-fractional_part}: the
            # fractional power must be taken of A^{-1} itself and then
            # multiplied in, not applied to the integer power.
            try:
                frac_factor = clamped_power(A_inv_approx, fractional_part)
            except RuntimeError:
                # If eigendecomposition fails, approximate the fractional
                # power by a weighted blend with the identity
                frac_factor = (
                    fractional_part * A_inv_approx
                    + (1 - fractional_part) * I
                )
            K = K @ frac_factor

    # Add small jitter to diagonal to ensure positive definiteness
    K = K + 1e-6 * I

    return K

def taylor_exp_approx(L, kappa, order=10):
    """
    Approximate exponential kernel exp(-kappa^2/2 * L) using Taylor series

    Args:
        L: Graph Laplacian matrix (square tensor)
        kappa: Scaling parameter
        order: Number of Taylor terms, counting the identity term; the
            highest matrix power used is order - 1

    Returns:
        K: Approximated exponential kernel matrix. The dtype is that of
            -kappa^2/2 * L promoted with the default floating dtype, so a
            float64 Laplacian yields a float64 kernel.
    """
    n = L.shape[0]

    # Compute exp(-kappa^2/2 * L) using Taylor series
    scaled_L = -kappa**2 / 2 * L

    # matmul does not type-promote, so the identity and the scaled operator
    # must share one dtype; a default-dtype identity would raise for
    # float64 (or float16) input.
    dtype = torch.promote_types(scaled_L.dtype, torch.get_default_dtype())
    scaled_L = scaled_L.to(dtype)
    I = torch.eye(n, device=L.device, dtype=dtype)

    K = I.clone()  # First term: I
    L_power = I
    factorial = 1.0

    for i in range(1, order):
        factorial *= i  # i!
        L_power = L_power @ scaled_L  # scaled_L^i
        K = K + L_power / factorial

    return K


