import os
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
from scipy import sparse
import seaborn as sns
import datetime

# Seed PyTorch's default generator and NumPy's legacy global RNG up front so
# graph generation and Gaussian-process sampling repeat exactly across runs.
torch.manual_seed(42)
np.random.seed(42)

def generate_homophilic_graph(n_per_class=30, avg_degree=5, homophily_ratio=0.8):
    """Build a two-class random graph whose edges mostly join same-class nodes.

    Nodes ``0 .. n_per_class-1`` get label 0 and the next ``n_per_class`` nodes
    get label 1. Edges are drawn one at a time from NumPy's global RNG:
    pick a random anchor node, then with probability ``homophily_ratio`` pick
    its partner uniformly from the anchor's own class (excluding itself),
    otherwise from the other class. Draws that hit an existing edge are
    discarded. Sampling stops once ``int(n_nodes * avg_degree / 2)`` edges
    exist or after ten times that many attempts, so the result may be sparser.

    Returns:
        (G, labels): the ``nx.Graph`` and a float NumPy array of node labels.
    """
    total = n_per_class * 2
    graph = nx.Graph()
    graph.add_nodes_from(range(total))

    labels = np.zeros(total)
    labels[n_per_class:] = 1

    target_edges = int(total * avg_degree / 2)
    edges_added = 0
    # Cap attempts at 10x the target so saturated class pairs cannot loop forever.
    for _ in range(target_edges * 10):
        if edges_added >= target_edges:
            break
        anchor = np.random.randint(0, total)
        anchor_label = labels[anchor]
        # True -> partner from the same class, False -> from the other class
        want_same = np.random.random() < homophily_ratio
        candidates = [j for j in range(total)
                      if j != anchor and (labels[j] == anchor_label) == want_same]
        if not candidates:
            continue
        partner = np.random.choice(candidates)
        if graph.has_edge(anchor, partner):
            continue  # duplicate draw; the attempt still counts
        graph.add_edge(anchor, partner)
        edges_added += 1

    return graph, labels

def generate_heterophilic_graph(n_per_class=30, avg_degree=5, heterophily_ratio=0.8):
    """Build a two-class random graph whose edges mostly join different-class nodes.

    Nodes ``0 .. n_per_class-1`` get label 0 and the next ``n_per_class`` nodes
    get label 1. Edges are drawn one at a time from NumPy's global RNG:
    pick a random anchor node, then with probability ``heterophily_ratio`` pick
    its partner uniformly from the other class, otherwise from the anchor's own
    class (excluding itself). Draws that hit an existing edge are discarded.
    Sampling stops once ``int(n_nodes * avg_degree / 2)`` edges exist or after
    ten times that many attempts, so the result may be sparser.

    Returns:
        (G, labels): the ``nx.Graph`` and a float NumPy array of node labels.
    """
    total = n_per_class * 2
    graph = nx.Graph()
    graph.add_nodes_from(range(total))

    labels = np.zeros(total)
    labels[n_per_class:] = 1

    target_edges = int(total * avg_degree / 2)
    edges_added = 0
    # Cap attempts at 10x the target so saturated class pairs cannot loop forever.
    for _ in range(target_edges * 10):
        if edges_added >= target_edges:
            break
        anchor = np.random.randint(0, total)
        anchor_label = labels[anchor]
        # True -> partner from the other class, False -> from the same class
        want_cross = np.random.random() < heterophily_ratio
        candidates = [j for j in range(total)
                      if j != anchor and (labels[j] != anchor_label) == want_cross]
        if not candidates:
            continue
        partner = np.random.choice(candidates)
        if graph.has_edge(anchor, partner):
            continue  # duplicate draw; the attempt still counts
        graph.add_edge(anchor, partner)
        edges_added += 1

    return graph, labels

def calculate_homophily_score(G, labels):
    """Return the fraction of edges of ``G`` whose two endpoints share a label.

    Args:
        G: Graph exposing ``edges()`` as (u, v) pairs.
        labels: Indexable node -> label lookup.

    Returns:
        Same-class edge count divided by total edge count, or ``0`` when the
        graph has no edges.
    """
    edge_list = list(G.edges())
    if not edge_list:
        return 0
    matching = sum(1 for u, v in edge_list if labels[u] == labels[v])
    return matching / len(edge_list)

def calculate_graph_laplacian(G, normalized=True):
    """Return the graph Laplacian of ``G`` as a dense float32 torch tensor.

    Args:
        G: networkx graph. Rows/columns follow the order of ``G.nodes()``
            (presumably ``0 .. n-1`` here, since labels are indexed by node id).
        normalized: If True (default), return the symmetric normalised
            Laplacian ``I - D^{-1/2} A D^{-1/2}``; otherwise ``D - A``.

    Returns:
        (n, n) ``torch.float32`` tensor.

    In the normalised case an isolated (degree-0) node gets ``D^{-1/2} = 0``,
    so its row/column of ``D^{-1/2} A D^{-1/2}`` is zero and its Laplacian row
    is just the identity row (diagonal 1).
    """
    n_nodes = G.number_of_nodes()
    adj_matrix = nx.adjacency_matrix(G)

    # Row sums of A; np.asarray(...).ravel() handles both sparse matrix (n, 1)
    # and sparse array (n,) return shapes.
    degrees = np.asarray(adj_matrix.sum(axis=1)).ravel()

    if normalized:
        with np.errstate(divide='ignore'):
            d_inv_sqrt = 1.0 / np.sqrt(degrees)
        # Degree-0 nodes produce inf above; zero them so they add nothing.
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sparse.diags(d_inv_sqrt)
        normalized_adj = d_inv_sqrt_mat @ adj_matrix @ d_inv_sqrt_mat
        laplacian = sparse.eye(n_nodes) - normalized_adj
    else:
        laplacian = sparse.diags(degrees) - adj_matrix

    # toarray() yields a plain ndarray for both sparse matrices and arrays.
    return torch.tensor(laplacian.toarray(), dtype=torch.float32)

def sample_gaussian_process(laplacian, nu, kappa=1.0, square=True):
    """Draw one sample from a graph GP with covariance (2*nu/kappa^2 + L)^(-nu).

    With ``laplacian = U diag(lambda) U^T``, a sample is
    ``f = U (2*nu/kappa**2 + lambda)^(-nu/2) * z`` with ``z ~ N(0, I)``, whose
    covariance is ``U diag((2*nu/kappa**2 + lambda)^(-nu)) U^T``.

    Args:
        laplacian: Symmetric (n, n) tensor, e.g. from calculate_graph_laplacian.
        nu: Smoothness exponent.
        kappa: Length-scale parameter.
        square: If True (default, the historical behaviour), return ``f**2``,
            which is non-negative. If False, return the raw sample ``f``.

    Returns:
        1D tensor of length n (the squared sample unless ``square=False``).

    NOTE: for large nu the spectral scale underflows in float32. For example,
    nu=100 and kappa=1 gives roughly 200**-50, about 1e-115, so the sample
    becomes all zeros.
    """
    eigvals, eigvecs = torch.linalg.eigh(laplacian)

    # Per-eigencomponent standard deviation (2*nu/kappa^2 + lambda)^(-nu/2).
    spectral_scale = torch.pow(2 * nu / kappa**2 + eigvals, -nu / 2)

    z = torch.randn(laplacian.shape[0], dtype=eigvecs.dtype, device=eigvecs.device)

    # Scale each spectral coordinate of the white noise, then map back to node
    # space. Building the outer product op (x) z and keeping one column would
    # apply only op[0] to every coordinate, giving plain white noise.
    sample = eigvecs @ (spectral_scale * z)

    if square:
        sample = torch.square(sample)
    return sample

def compute_covariance_matrix(laplacian, nu, kappa=1.0):
    """Return the dense covariance ``U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T``.

    ``lambda``/``U`` are the eigenvalues/eigenvectors of the symmetric
    ``laplacian`` tensor. The result has the same dtype as the eigendecomposition.

    NOTE: with float32 input and large nu the weights underflow to zero. For
    example, nu=100 and kappa=1 on a normalised Laplacian (eigenvalues in
    [0, 2]) gives roughly 1e-230, so the matrix becomes all zeros.
    """
    spectrum, basis = torch.linalg.eigh(laplacian)
    weights = (2 * nu / kappa**2 + spectrum) ** (-nu)
    return basis @ torch.diag(weights) @ basis.T

def compute_edge_covariance(G, cov_matrix):
    """Map each edge (u, v) of ``G`` to its min-max scaled |cov_matrix[u, v]|.

    Weights are rescaled to [0, 1] over the edges of ``G``. If every edge has
    the same magnitude, the divisor falls back to 1, so all weights become 0.
    Keys follow ``G.edges()`` order. An edge-less graph yields ``{}``.
    """
    raw = {(u, v): abs(cov_matrix[u, v].item()) for u, v in G.edges()}
    if not raw:
        return raw

    magnitudes = np.array(list(raw.values()))
    lo, hi = magnitudes.min(), magnitudes.max()
    span = hi - lo if hi > lo else 1.0
    return {edge: (weight - lo) / span for edge, weight in raw.items()}

def plot_graph(G, labels, pos, title, save_path=None):
    """Draw ``G`` with class-0 nodes in blue and class-1 nodes in red.

    Args:
        G: networkx graph to draw.
        labels: Per-node labels indexed by position (``labels[i]`` is used as
            the label of node ``i``); only labels 0 and 1 are drawn.
        pos: Node -> (x, y) layout, e.g. from ``nx.spring_layout``.
        title: Figure title.
        save_path: Optional output file, written at 600 dpi.

    The figure is closed before returning, even on error, so repeated calls
    do not accumulate open pyplot figures and their memory.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        # Faint, thin edges so node colours dominate
        nx.draw_networkx_edges(G, pos, alpha=0.3, width=1.0)

        class0_nodes = [i for i, label in enumerate(labels) if label == 0]
        class1_nodes = [i for i, label in enumerate(labels) if label == 1]

        nx.draw_networkx_nodes(G, pos, nodelist=class0_nodes, node_color=['blue'] * len(class0_nodes),
                               node_size=100, alpha=0.8, label='Class 0')
        nx.draw_networkx_nodes(G, pos, nodelist=class1_nodes, node_color=['red'] * len(class1_nodes),
                               node_size=100, alpha=0.8, label='Class 1')

        plt.title(title, fontsize=16)
        plt.legend()
        plt.axis('off')
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=600, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_covariance_matrix(cov_matrix, title, save_path=None):
    """Draw ``cov_matrix`` as a square viridis heatmap and optionally save it.

    Args:
        cov_matrix: 2D torch tensor (as returned by compute_covariance_matrix)
            or anything ``np.asarray`` accepts.
        title: Figure title.
        save_path: Optional output file, written at 600 dpi.

    The figure is closed before returning, even on error, so repeated calls
    do not accumulate open pyplot figures.
    """
    # detach/cpu so tensors that track gradients or live off-CPU also work
    if torch.is_tensor(cov_matrix):
        data = cov_matrix.detach().cpu().numpy()
    else:
        data = np.asarray(cov_matrix)

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(data, cmap='viridis', square=True)
        plt.title(title, fontsize=16)
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=600, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_gaussian_process_result(G, labels, values, edge_cov, pos, title, save_path=None):
    """Save two figures of a Gaussian-process result on ``G``.

    Figure 1 colours nodes by class (0 blue, 1 red). Figure 2 colours nodes by
    ``values`` (viridis). In both, edges are shaded with ``Blues`` on a fixed
    [0, 1] scale, which is the output range of compute_edge_covariance, so the
    edge colorbar matches the drawn edges.

    Args:
        G: networkx graph.
        labels: Per-node labels indexed by position; labels 0 and 1 are drawn.
        values: One value per node (tensor or array-like). A 2D input is
            reduced as follows. An (n, n) matrix uses its diagonal. Otherwise
            the first column is used if it has n rows, or the first row if it
            has n columns.
        edge_cov: Mapping (u, v) -> weight. An edge missing in both
            orientations falls back to 0.1.
        pos: Node -> (x, y) layout.
        title: Title prefix for both figures.
        save_path: Optional base path. Outputs go to ``<root>_class<ext>`` and
            ``<root>_gp<ext>``, where ``root, ext = os.path.splitext(save_path)``.

    Raises:
        ValueError: if ``values`` cannot be reduced to one value per node. No
            figure is drawn in that case, and no random values are substituted.
    """
    n_nodes = G.number_of_nodes()

    # Validate the node values before drawing anything.
    if torch.is_tensor(values):
        values_np = values.detach().cpu().numpy()
    else:
        values_np = np.asarray(values)

    if values_np.ndim > 1:
        if values_np.shape[0] == n_nodes and values_np.shape[1] == n_nodes:
            print("Detected covariance matrix instead of sample vector. Using diagonal values.")
            values_np = np.diag(values_np)
        elif values_np.shape[0] == n_nodes:
            print("Using first column of matrix")
            values_np = values_np[:, 0]
        elif values_np.shape[1] == n_nodes:
            print("Using first row of matrix")
            values_np = values_np[0, :]

    if values_np.ndim == 0 or len(values_np) != n_nodes:
        raise ValueError(
            f"Expected one value per node ({n_nodes} nodes), got array of shape {values_np.shape}"
        )
    node_values = values_np.flatten()

    # splitext keeps the two outputs distinct even without a '.png' suffix.
    if save_path:
        root, ext = os.path.splitext(save_path)

    edges = list(G.edges())
    edge_colors = [edge_cov.get(edge, edge_cov.get((edge[1], edge[0]), 0.1)) for edge in edges]
    class0_nodes = [i for i, label in enumerate(labels) if label == 0]
    class1_nodes = [i for i, label in enumerate(labels) if label == 1]

    def _draw_edges(ax):
        # Fixed 0..1 colour range so the edge colorbar is truthful
        nx.draw_networkx_edges(G, pos, edgelist=edges, width=2.0,
                               edge_color=edge_colors, edge_cmap=plt.cm.Blues,
                               edge_vmin=0.0, edge_vmax=1.0, alpha=0.7, ax=ax)

    def _add_edge_colorbar(fig, ax):
        edge_sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues, norm=plt.Normalize(0.0, 1.0))
        edge_sm.set_array([])
        fig.colorbar(edge_sm, ax=ax, label='Edge Covariance', shrink=0.75)

    # 1. Nodes by class, edges by covariance
    fig1, ax1 = plt.subplots(figsize=(12, 10))
    try:
        _draw_edges(ax1)
        nx.draw_networkx_nodes(G, pos, nodelist=class0_nodes, node_color=['blue'] * len(class0_nodes),
                               node_size=150, alpha=0.8, label='Class 0', ax=ax1)
        nx.draw_networkx_nodes(G, pos, nodelist=class1_nodes, node_color=['red'] * len(class1_nodes),
                               node_size=150, alpha=0.8, label='Class 1', ax=ax1)
        _add_edge_colorbar(fig1, ax1)

        ax1.set_title(f"{title} - Nodes by Class, Edges by Covariance", fontsize=16)
        ax1.legend(loc='upper right')
        ax1.set_axis_off()
        fig1.tight_layout()

        if save_path:
            fig1.savefig(f"{root}_class{ext}", dpi=600, bbox_inches='tight')
    finally:
        plt.close(fig1)

    # 2. Nodes by Gaussian-process value, edges by covariance
    fig2, ax2 = plt.subplots(figsize=(12, 10))
    try:
        _draw_edges(ax2)
        nodes = nx.draw_networkx_nodes(G, pos, node_size=150,
                                       node_color=node_values,
                                       cmap=plt.cm.viridis, alpha=0.9,
                                       edgecolors='black', linewidths=0.5, ax=ax2)
        fig2.colorbar(nodes, ax=ax2, label='Gaussian Process Value', shrink=0.75)
        _add_edge_colorbar(fig2, ax2)

        ax2.set_title(f"{title} - Nodes by Gaussian Process, Edges by Covariance", fontsize=16)
        ax2.set_axis_off()
        fig2.tight_layout()

        if save_path:
            fig2.savefig(f"{root}_gp{ext}", dpi=600, bbox_inches='tight')
    finally:
        plt.close(fig2)

def _top_nonedge_pairs(graph, cov_np, k=10):
    """Return up to ``k`` non-adjacent node pairs with the largest |covariance|.

    Pairs are (r, c) with r < c, ordered by decreasing ``abs(cov_np[r, c])``.
    Ties keep row-major order because Python's sort is stable.
    """
    candidates = []
    n = cov_np.shape[0]
    for r in range(n):
        for c in range(r + 1, n):  # upper triangle only
            if not graph.has_edge(r, c):  # only pairs that are not already edges
                candidates.append((r, c, abs(cov_np[r, c])))
    candidates.sort(key=lambda item: item[2], reverse=True)
    return [(r, c) for r, c, _ in candidates[:k]]


def plot_heterophilic_varying_nu(hetero_graph, hetero_labels, hetero_pos, hetero_laplacian, kappa=1.0,
                                 save_path=None, nu_values=(0.1, 5.0)):
    """
    Plot a heterophilic graph's GP samples for several nu values.

    The figure has one row per value in ``nu_values`` (by default two rows, 0.1 and 5.0).
    - Left panel: nodes coloured by the sample returned from
      sample_gaussian_process, on a shared scale starting at 0.
    - Right panel: nodes coloured by class (0 blue, 1 red).
    - Both panels: existing edges are shaded by covariance, and the 10 most
      strongly covarying non-adjacent pairs are overlaid as red dashed lines.

    Args:
        hetero_graph: Heterophilic graph
        hetero_labels: Node labels, indexed by node id
        hetero_pos: Node positions
        hetero_laplacian: Graph Laplacian (torch tensor)
        kappa: Scaling parameter
        save_path: If given, a PNG is written there, and a PDF is written next
            to it with the same stem and a ``.pdf`` extension.
        nu_values: Smoothness values, one figure row each.

    Raises:
        ValueError: if ``nu_values`` is empty.

    Fonts are enlarged for a LaTeX document. They are applied with
    ``plt.rc_context``, so the global rcParams are left untouched.
    """
    nu_values = list(nu_values)
    if not nu_values:
        raise ValueError("nu_values must contain at least one value")
    n_rows = len(nu_values)

    font_params = {
        'font.size': 30,
        'axes.titlesize': 30,
        'axes.labelsize': 26,
        'xtick.labelsize': 26,
        'ytick.labelsize': 26,
        'legend.fontsize': 30,
        'figure.titlesize': 30,
    }

    # Compute all samples/covariances first so colour scales are shared.
    samples = []
    edge_weights_list = []
    covariance_matrices = []
    for nu in nu_values:
        sample = sample_gaussian_process(hetero_laplacian, nu, kappa)
        cov = compute_covariance_matrix(hetero_laplacian, nu, kappa)
        covariance_matrices.append(cov)
        edge_weights_list.append(compute_edge_covariance(hetero_graph, cov))
        if torch.is_tensor(sample):
            samples.append(sample.detach().cpu().numpy())
        else:
            samples.append(np.array(sample))

    # Node colour scale shared by all rows; values are non-negative, so start at 0.
    node_norm = plt.Normalize(0.0, np.concatenate(samples).max())

    # Shared edge scale; the defaults cover an edge-less graph.
    all_edge_weights = [w for edge_cov in edge_weights_list for w in edge_cov.values()]
    edge_min = min(all_edge_weights, default=0.0)
    edge_max = max(all_edge_weights, default=0.0)
    edge_span = edge_max - edge_min

    edges = list(hetero_graph.edges())
    class0_nodes = [node for node, label in enumerate(hetero_labels) if label == 0]
    class1_nodes = [node for node, label in enumerate(hetero_labels) if label == 1]

    with plt.rc_context(font_params):
        fig = plt.figure(figsize=(22, 8 * n_rows))
        try:
            # Layout: colorbar column, two panel columns, spacer; legend row at the bottom
            gs = plt.GridSpec(n_rows + 1, 4, height_ratios=[1] * n_rows + [0.2],
                              width_ratios=[0.15, 1, 1, 0.15], figure=fig)
            # Per row: (GP-coloured panel, label-coloured panel)
            row_axes = [(fig.add_subplot(gs[row, 1]), fig.add_subplot(gs[row, 2]))
                        for row in range(n_rows)]
            cbar_ax = fig.add_subplot(gs[0:n_rows, 0])
            corr_legend_ax = fig.add_subplot(gs[n_rows, 1])
            class_legend_ax = fig.add_subplot(gs[n_rows, 2])
            class_legend_ax.axis('off')
            corr_legend_ax.axis('off')

            class0_patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue',
                                      markersize=18, label='Class 0')
            class1_patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red',
                                      markersize=18, label='Class 1')
            class_legend_ax.legend(handles=[class0_patch, class1_patch], loc='center',
                                   fontsize=30, title="Node Classes", title_fontsize=30)

            red_line = plt.Line2D([0], [0], color='red', lw=3.0, linestyle='dashed',
                                  label='High Matern correlation')
            corr_legend_ax.legend(handles=[red_line], loc='center', fontsize=30)

            for row, nu in enumerate(nu_values):
                edge_cov = edge_weights_list[row]
                edge_colors = [edge_cov.get(edge, edge_cov.get((edge[1], edge[0]), edge_min))
                               for edge in edges]
                # Rescale to the shared range. A zero span (all weights equal) maps to 0.
                edge_colors_norm = [(val - edge_min) / edge_span if edge_span > 0 else 0.0
                                    for val in edge_colors]
                high_corr_edges = _top_nonedge_pairs(hetero_graph, covariance_matrices[row].numpy(), k=10)

                for is_label_plot, ax in zip((False, True), row_axes[row]):
                    if edges:
                        nx.draw_networkx_edges(hetero_graph, hetero_pos, edgelist=edges,
                                               width=3.0, edge_color=edge_colors_norm,
                                               edge_cmap=plt.cm.Blues, alpha=0.7, ax=ax)
                    if high_corr_edges:
                        nx.draw_networkx_edges(hetero_graph, hetero_pos, edgelist=high_corr_edges,
                                               width=2.5, edge_color='red', style='dashed',
                                               alpha=0.8, ax=ax)

                    if is_label_plot:
                        nx.draw_networkx_nodes(hetero_graph, hetero_pos, nodelist=class0_nodes,
                                               node_color='blue', node_size=220, alpha=0.9,
                                               edgecolors='black', linewidths=1.0, ax=ax)
                        nx.draw_networkx_nodes(hetero_graph, hetero_pos, nodelist=class1_nodes,
                                               node_color='red', node_size=220, alpha=0.9,
                                               edgecolors='black', linewidths=1.0, ax=ax)
                    else:
                        nx.draw_networkx_nodes(hetero_graph, hetero_pos, node_size=220,
                                               node_color=[plt.cm.viridis(node_norm(val)) for val in samples[row]],
                                               alpha=0.9, edgecolors='black', linewidths=1.0, ax=ax)

                    plot_type = "Node Labels" if is_label_plot else "Gaussian Process"
                    ax.set_title(f"$\\nu = {nu}$, {plot_type}", fontsize=30)
                    ax.set_axis_off()

            node_sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=node_norm)
            node_sm.set_array([])
            node_cbar = fig.colorbar(node_sm, cax=cbar_ax, orientation='vertical')
            node_cbar.set_label('Gaussian Process Value', fontsize=30)
            node_cbar.ax.tick_params(labelsize=26)

            fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.3, hspace=0.3)

            if save_path:
                fig.savefig(save_path, dpi=600, bbox_inches='tight')
                # splitext avoids overwriting the PNG when there is no '.png' suffix
                pdf_path = os.path.splitext(save_path)[0] + '.pdf'
                fig.savefig(pdf_path, format='pdf', dpi=600, bbox_inches='tight',
                            pad_inches=0.1, transparent=False)
                print(f"Saved figures to {save_path} and {pdf_path}")
        finally:
            plt.close(fig)

def main(output_dir="results"):
    """Run the demo end to end and write every figure under ``output_dir``.

    Steps:
    1. Build a homophilic and a heterophilic two-class graph, report their
       homophily, and plot them.
    2. For each nu in ``nu_values``, plot the covariance matrix and a
       Gaussian-process sample for both graphs.
    3. Draw the heterophilic varying-nu comparison figure.

    Args:
        output_dir: Directory for all output images (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Parameters
    n_per_class = 30  # 30 nodes per class = 60 nodes total
    avg_degree = 5
    homophily_ratio = 0.8
    heterophily_ratio = 0.8
    # NOTE(review): the Laplacians are float32. For nu=100 (kappa=1) the
    # spectral weights (200 + lambda)^-100, about 1e-230, underflow to 0, so
    # the nu=100 covariance and sample figures come out blank. Consider a
    # smaller nu or a variance-normalised kernel.
    nu_values = [0.01, 100]
    kappa = 1.0

    # Timestamp keeps successive versions of the comparison figure apart
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # 1. Generate homophilic graph
    print("Generating homophilic graph...")
    homo_graph, homo_labels = generate_homophilic_graph(
        n_per_class=n_per_class,
        avg_degree=avg_degree,
        homophily_ratio=homophily_ratio
    )

    homo_score = calculate_homophily_score(homo_graph, homo_labels)
    print(f"Homophilic graph: {homo_graph.number_of_nodes()} nodes, {homo_graph.number_of_edges()} edges")
    print(f"Homophily score: {homo_score:.2f} (ratio of same-class edges)")

    # 2. Generate heterophilic graph
    print("\nGenerating heterophilic graph...")
    hetero_graph, hetero_labels = generate_heterophilic_graph(
        n_per_class=n_per_class,
        avg_degree=avg_degree,
        heterophily_ratio=heterophily_ratio
    )

    hetero_score = calculate_homophily_score(hetero_graph, hetero_labels)
    print(f"Heterophilic graph: {hetero_graph.number_of_nodes()} nodes, {hetero_graph.number_of_edges()} edges")
    print(f"Homophily score: {hetero_score:.2f} (ratio of same-class edges)")

    # 3. Compute layout positions
    print("\nComputing graph layouts...")
    homo_pos = nx.spring_layout(homo_graph, seed=42)
    hetero_pos = nx.spring_layout(hetero_graph, seed=42)

    # 4. Plot original graphs
    print("Plotting original graphs...")
    plot_graph(
        homo_graph,
        homo_labels,
        homo_pos,
        f"Homophilic Graph (Homophily = {homo_score:.2f})",
        save_path=f"{output_dir}/homophilic_graph_simple.png"
    )

    plot_graph(
        hetero_graph,
        hetero_labels,
        hetero_pos,
        f"Heterophilic Graph (Homophily = {hetero_score:.2f})",
        save_path=f"{output_dir}/heterophilic_graph_simple.png"
    )

    # 5. Compute graph Laplacians
    print("\nComputing graph Laplacians...")
    homo_laplacian = calculate_graph_laplacian(homo_graph)
    hetero_laplacian = calculate_graph_laplacian(hetero_graph)

    # 6. Visualize results for each nu value
    for nu in nu_values:
        print(f"\nProcessing nu = {nu}...")

        homo_cov = compute_covariance_matrix(homo_laplacian, nu, kappa)
        hetero_cov = compute_covariance_matrix(hetero_laplacian, nu, kappa)

        plot_covariance_matrix(
            homo_cov,
            f"Homophilic Graph Covariance (nu = {nu})",
            save_path=f"{output_dir}/homophilic_covariance_simple_nu_{nu}.png"
        )

        plot_covariance_matrix(
            hetero_cov,
            f"Heterophilic Graph Covariance (nu = {nu})",
            save_path=f"{output_dir}/heterophilic_covariance_simple_nu_{nu}.png"
        )

        homo_sample = sample_gaussian_process(homo_laplacian, nu, kappa)
        hetero_sample = sample_gaussian_process(hetero_laplacian, nu, kappa)

        print(f"Homo sample shape: {homo_sample.shape if hasattr(homo_sample, 'shape') else 'no shape'}")
        print(f"Hetero sample shape: {hetero_sample.shape if hasattr(hetero_sample, 'shape') else 'no shape'}")

        homo_edge_cov = compute_edge_covariance(homo_graph, homo_cov)
        hetero_edge_cov = compute_edge_covariance(hetero_graph, hetero_cov)

        plot_gaussian_process_result(
            homo_graph,
            homo_labels,
            homo_sample,
            homo_edge_cov,
            homo_pos,
            f"Homophilic Graph (nu = {nu})",
            save_path=f"{output_dir}/homophilic_gaussian_simple_nu_{nu}.png"
        )

        plot_gaussian_process_result(
            hetero_graph,
            hetero_labels,
            hetero_sample,
            hetero_edge_cov,
            hetero_pos,
            f"Heterophilic Graph (nu = {nu})",
            save_path=f"{output_dir}/heterophilic_gaussian_simple_nu_{nu}.png"
        )

    # 7. 2x2 comparison figure for the heterophilic graph (nu = 0.1 and 5.0):
    #    GP-coloured vs. label-coloured nodes
    print("\nCreating special plot for heterophilic graph with varying nu values...")
    plot_heterophilic_varying_nu(
        hetero_graph,
        hetero_labels,
        hetero_pos,
        hetero_laplacian,
        kappa=kappa,
        save_path=f"{output_dir}/heterophilic_gaussian_varying_nu_{timestamp}.png"
    )


if __name__ == "__main__":
    main()