import os
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
from scipy import sparse
import seaborn as sns

# Import necessary functions from existing scripts
from visualize_gaussian_process_on_cora import (
    calculate_graph_laplacian,
    sample_gaussian_process,
    compute_edge_covariance
)

# Fix the global RNGs so every run is reproducible: PyTorch drives the
# Gaussian-process draws (torch.randn), NumPy drives the random graph
# generators (np.random.randint / random / choice).
torch.manual_seed(42)
np.random.seed(42)

def generate_homophilic_graph(n_per_class=30, avg_degree=5, homophily_ratio=0.8):
    """
    Generate a two-class random graph whose edges mostly join same-class nodes.

    Nodes ``0 .. n_per_class-1`` get label 0 and the remaining ``n_per_class``
    nodes get label 1. Edges are proposed one at a time: a random endpoint is
    picked, then with probability ``homophily_ratio`` its partner is drawn from
    the same class (excluding itself), otherwise from the other class.
    Duplicate proposals are rejected, and the loop stops after
    ``10 * n_edges`` attempts, printing a warning if it fell short.

    Args:
        n_per_class: Number of nodes per class
        avg_degree: Target average degree; the graph aims for
            ``int(n_nodes * avg_degree / 2)`` edges
        homophily_ratio: Probability that a proposed edge is intra-class

    Returns:
        G: NetworkX graph; every node carries an integer ``'label'`` attribute
        labels: float NumPy array of node labels (0.0 / 1.0)
    """
    n_nodes = n_per_class * 2

    # Create empty graph
    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))

    # Assign labels: first n_per_class nodes are class 0, rest are class 1
    labels = np.zeros(n_nodes)
    labels[n_per_class:] = 1

    # Number of edges needed for the desired average degree
    n_edges = int(n_nodes * avg_degree / 2)

    # Candidate partners per class never change during the loop, so build them
    # once instead of rescanning every node on each attempt. Lists are in
    # ascending node order, so np.random.choice sees the same candidates as a
    # per-attempt scan would.
    classes = np.unique(labels)
    same_pool = {c: np.flatnonzero(labels == c).tolist() for c in classes}
    other_pool = {c: np.flatnonzero(labels != c).tolist() for c in classes}

    edges_added = 0
    max_attempts = n_edges * 10  # Safeguard against infinite loops
    attempts = 0

    while edges_added < n_edges and attempts < max_attempts:
        attempts += 1

        # Pick a random node
        node_i = np.random.randint(0, n_nodes)
        label_i = labels[node_i]

        # Homophilic (same class) proposal with probability homophily_ratio
        if np.random.random() < homophily_ratio:
            potential_nodes = [j for j in same_pool[label_i] if j != node_i]
        else:
            # node_i is never in the other-class pool, so no filtering needed
            potential_nodes = other_pool[label_i]

        if not potential_nodes:
            continue

        node_j = np.random.choice(potential_nodes)

        # Add edge if it doesn't already exist
        if not G.has_edge(node_i, node_j):
            G.add_edge(node_i, node_j)
            edges_added += 1

    if edges_added < n_edges:
        print(f"Warning: Only able to add {edges_added} of {n_edges} edges while maintaining homophily ratio")

    # Store labels as node attributes
    for i in range(n_nodes):
        G.nodes[i]['label'] = int(labels[i])

    return G, labels

def generate_heterophilic_graph(n_per_class=30, avg_degree=5, heterophily_ratio=0.8):
    """
    Generate a two-class random graph whose edges mostly join different-class nodes.

    Nodes ``0 .. n_per_class-1`` get label 0 and the remaining ``n_per_class``
    nodes get label 1. Edges are proposed one at a time: a random endpoint is
    picked, then with probability ``heterophily_ratio`` its partner is drawn
    from the other class, otherwise from the same class (excluding itself).
    Duplicate proposals are rejected, and the loop stops after
    ``10 * n_edges`` attempts, printing a warning if it fell short.

    Args:
        n_per_class: Number of nodes per class
        avg_degree: Target average degree; the graph aims for
            ``int(n_nodes * avg_degree / 2)`` edges
        heterophily_ratio: Probability that a proposed edge is inter-class

    Returns:
        G: NetworkX graph; every node carries an integer ``'label'`` attribute
        labels: float NumPy array of node labels (0.0 / 1.0)
    """
    n_nodes = n_per_class * 2

    # Create empty graph
    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))

    # Assign labels: first n_per_class nodes are class 0, rest are class 1
    labels = np.zeros(n_nodes)
    labels[n_per_class:] = 1

    # Number of edges needed for the desired average degree
    n_edges = int(n_nodes * avg_degree / 2)

    # Candidate partners per class never change during the loop, so build them
    # once instead of rescanning every node on each attempt. Lists are in
    # ascending node order, so np.random.choice sees the same candidates as a
    # per-attempt scan would.
    classes = np.unique(labels)
    same_pool = {c: np.flatnonzero(labels == c).tolist() for c in classes}
    other_pool = {c: np.flatnonzero(labels != c).tolist() for c in classes}

    edges_added = 0
    max_attempts = n_edges * 10  # Safeguard against infinite loops
    attempts = 0

    while edges_added < n_edges and attempts < max_attempts:
        attempts += 1

        # Pick a random node
        node_i = np.random.randint(0, n_nodes)
        label_i = labels[node_i]

        # Heterophilic (different class) proposal with probability heterophily_ratio
        if np.random.random() < heterophily_ratio:
            # node_i is never in the other-class pool, so no filtering needed
            potential_nodes = other_pool[label_i]
        else:
            potential_nodes = [j for j in same_pool[label_i] if j != node_i]

        if not potential_nodes:
            continue

        node_j = np.random.choice(potential_nodes)

        # Add edge if it doesn't already exist
        if not G.has_edge(node_i, node_j):
            G.add_edge(node_i, node_j)
            edges_added += 1

    if edges_added < n_edges:
        print(f"Warning: Only able to add {edges_added} of {n_edges} edges while maintaining heterophily ratio")

    # Store labels as node attributes
    for i in range(n_nodes):
        G.nodes[i]['label'] = int(labels[i])

    return G, labels

def calculate_homophily_score(G, labels):
    """
    Compute edge homophily: the fraction of edges whose endpoints share a label.

    Args:
        G: NetworkX graph (only ``G.edges()`` is used)
        labels: Node labels, indexable by node id

    Returns:
        homophily_score: Fraction in [0, 1]; the integer 0 when G has no edges
    """
    edge_list = list(G.edges())
    if not edge_list:
        return 0

    matching = sum(1 for u, v in edge_list if labels[u] == labels[v])
    return matching / len(edge_list)

def plot_graph_with_labels(G, labels, title, pos=None, node_size=200, save_path=None):
    """
    Draw ``G`` in a new figure with each node coloured by its class label.

    Class ``k`` is drawn in matplotlib colour cycle entry ``C{k}`` and added to
    the legend as ``Class k``.

    Args:
        G: NetworkX graph
        labels: Per-node class labels, indexed by node position
        title: Figure title
        pos: Optional precomputed layout; a spring layout with seed 42 is used
            when omitted
        node_size: Marker size of the nodes
        save_path: If given, the figure is also saved there at 300 dpi

    Returns:
        pos: The layout that was used, so later plots can reuse it
    """
    layout = nx.spring_layout(G, seed=42) if pos is None else pos

    plt.figure(figsize=(10, 8))

    # Edges first so the nodes are drawn on top of them
    nx.draw_networkx_edges(G, layout, alpha=0.3, width=1.0)

    for cls in np.unique(labels):
        members = [idx for idx, lab in enumerate(labels) if lab == cls]
        nx.draw_networkx_nodes(
            G,
            layout,
            nodelist=members,
            node_color=[f'C{int(cls)}'],
            node_size=node_size,
            alpha=0.8,
            edgecolors='black',
            linewidths=0.5,
            label=f'Class {int(cls)}',
        )

    plt.title(title, fontsize=16)
    plt.legend()
    plt.axis('off')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return layout

def sample_gaussian_process(laplacian, nu, kappa=1.0):
    """
    Draw one sample from a zero-mean Gaussian process on a graph.

    With the eigendecomposition L = U diag(lambda) U^T of ``laplacian``, the
    covariance is

        K = U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T,

    and a sample is f = U diag((2*nu/kappa^2 + lambda)^(-nu/2)) z with
    z ~ N(0, I), so that Cov[f] = K.

    Note: this definition shadows the ``sample_gaussian_process`` imported from
    ``visualize_gaussian_process_on_cora`` at the top of this file.

    Args:
        laplacian: Symmetric graph Laplacian (torch tensor or array-like;
            non-strided, e.g. sparse, tensors are densified)
        nu: Smoothness parameter
        kappa: Scaling parameter

    Returns:
        sample: 1-D float64 tensor with one value per node
    """
    laplacian = torch.as_tensor(laplacian)
    if laplacian.layout != torch.strided:
        # torch.linalg.eigh needs a dense matrix
        laplacian = laplacian.to_dense()

    # Work in float64: for large nu the spectrum is tiny, e.g. nu = 100 gives
    # (200 + lambda)^(-50) ~ 1e-115, which underflows to 0 in float32.
    eigvals, eigvecs = torch.linalg.eigh(laplacian.detach().to(torch.float64))

    # A graph Laplacian is positive semi-definite; clamp round-off negatives so
    # the base of the fractional power stays positive (otherwise NaN for tiny nu).
    eigvals = eigvals.clamp(min=0)

    # Phi(lambda) = (2*nu/kappa^2 + lambda)^(-nu/2); Phi^2 is the covariance spectrum
    operator = torch.pow(2 * nu / kappa**2 + eigvals, -nu / 2)

    # Standard normal draw (drawn in the default dtype, then promoted)
    z = torch.randn(laplacian.shape[0]).to(device=eigvals.device, dtype=torch.float64)

    # f = U (Phi * z): an n-vector, not an n x n outer product
    return eigvecs @ (operator * z)

def plot_gaussian_process_on_graph(G, labels, values, title, edge_weights=None, pos=None, 
                                   node_size=200, save_path=None):
    """
    Plot a Gaussian process on a graph as two side-by-side panels.

    Left panel: nodes coloured by class (label 0 blue, label 1 red; nodes with
    any other label are not drawn). Right panel: nodes coloured by their
    Gaussian-process value on a 'coolwarm' colormap. When ``edge_weights`` is
    given (and the graph has edges), the edges in both panels are coloured by
    weight on a 'Blues' colormap. Each panel then gets a colorbar showing the
    actual weight range.

    Args:
        G: NetworkX graph. ``labels`` and ``values`` are indexed by node
            position, which assumes the nodes are 0..n-1 in ``G.nodes()`` order.
        labels: Node labels for class-based coloring
        values: Gaussian process values, one per node (torch tensor or
            array-like with exactly ``G.number_of_nodes()`` entries)
        title: Title prefix for both panels
        edge_weights: Optional mapping ``(u, v) -> weight`` (e.g. edge
            covariance). Either orientation of an edge is looked up; edges
            missing from the mapping get 0.1.
        pos: Node positions; a spring layout with seed 42 is used if omitted
        node_size: Size of nodes
        save_path: Path to save the figure

    Raises:
        ValueError: if ``values`` does not hold exactly one entry per node
    """
    # Validate before creating any figure so a bad call leaves no stray figure
    values_np = _gp_values_per_node(values, G.number_of_nodes())

    if pos is None:
        pos = nx.spring_layout(G, seed=42)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
    edge_style = _gp_edge_style(G, edge_weights)

    # FIRST SUBPLOT: nodes colored by class
    _gp_draw_edges(G, pos, ax1, edge_style)
    for class_value, color in ((0, 'blue'), (1, 'red')):
        nodelist = [i for i, label in enumerate(labels) if label == class_value]
        if nodelist:
            nx.draw_networkx_nodes(
                G, pos,
                nodelist=nodelist,
                node_color=color,
                node_size=node_size,
                alpha=0.8,
                edgecolors='black',
                linewidths=0.5,
                label=f'Class {class_value}',
                ax=ax1,
            )
    ax1.set_title(f"{title} - Nodes by Class", fontsize=16)
    ax1.legend()
    ax1.set_axis_off()

    # SECOND SUBPLOT: nodes colored by Gaussian process values
    _gp_draw_edges(G, pos, ax2, edge_style)

    if values_np.size:
        vmin, vmax = np.min(values_np), np.max(values_np)
    else:
        vmin, vmax = 0.0, 1.0  # empty graph: any finite range will do
    norm = plt.Normalize(vmin, vmax)
    cmap = plt.cm.coolwarm
    node_colors = [cmap(val) for val in norm(values_np)]

    nx.draw_networkx_nodes(
        G, pos,
        node_color=node_colors,
        node_size=node_size,
        alpha=0.8,
        edgecolors='black',
        linewidths=0.5,
        ax=ax2,
    )

    # Colorbar for node values (set_array([]) keeps older matplotlib happy)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    node_cbar = fig.colorbar(sm, ax=ax2, shrink=0.75, pad=0.05)
    node_cbar.set_label('Gaussian Process Value', fontsize=12)

    ax2.set_title(f"{title} - Nodes by Gaussian Process", fontsize=16)
    ax2.set_axis_off()

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def _gp_values_per_node(values, n_nodes):
    """Return ``values`` as a 1-D NumPy array of length ``n_nodes``; raise ValueError otherwise."""
    if torch.is_tensor(values):
        # detach/cpu so tensors that require grad or live on a GPU can be plotted
        values_np = values.detach().cpu().numpy()
    else:
        values_np = np.asarray(values)
    if values_np.size != n_nodes:
        raise ValueError(
            f"Expected one Gaussian process value per node ({n_nodes}), "
            f"got an array of shape {values_np.shape}"
        )
    return values_np.reshape(n_nodes)


def _gp_edge_style(G, edge_weights):
    """Return ``(edges, weights, norm)`` for weight-coloured edges, or None for plain edges."""
    if edge_weights is None:
        return None
    edges = list(G.edges())
    if not edges:
        # Nothing to colour: fall back to plain drawing instead of building a
        # colorbar from an empty range.
        return None
    weights = np.array(
        [float(edge_weights.get(e, edge_weights.get((e[1], e[0]), 0.1))) for e in edges]
    )
    # Raw (unnormalised) range, so the colorbar shows actual covariance values
    norm = plt.Normalize(vmin=weights.min(), vmax=weights.max())
    return edges, weights, norm


def _gp_draw_edges(G, pos, ax, edge_style):
    """Draw G's edges on ``ax``: weight-coloured plus colorbar if ``edge_style`` is set, else plain."""
    if edge_style is None:
        nx.draw_networkx_edges(G, pos, alpha=0.3, width=1.0, ax=ax)
        return

    edges, weights, norm = edge_style
    nx.draw_networkx_edges(
        G, pos,
        edgelist=edges,
        width=2.0,
        edge_color=weights,
        edge_cmap=plt.cm.Blues,
        edge_vmin=norm.vmin,
        edge_vmax=norm.vmax,
        alpha=0.7,
        ax=ax,
    )

    mappable = plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.Blues)
    mappable.set_array([])
    edge_cbar = ax.figure.colorbar(mappable, ax=ax, shrink=0.75, pad=0.05)
    edge_cbar.set_label('Edge Covariance', fontsize=12)

def plot_covariance_matrix(laplacian, nu, kappa=1.0, title="Covariance Matrix", save_path=None):
    """
    Plot the graph-kernel covariance matrix of a Laplacian as a heatmap.

    With the eigendecomposition L = U diag(lambda) U^T of ``laplacian``, the
    covariance is K = U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T.

    Args:
        laplacian: Symmetric graph Laplacian (torch tensor or array-like;
            non-strided, e.g. sparse, tensors are densified)
        nu: Smoothness parameter
        kappa: Scaling parameter
        title: Title for the plot (`` (nu = <nu>)`` is appended)
        save_path: Path to save the figure

    Returns:
        covariance: (n, n) float64 torch tensor K
    """
    laplacian = torch.as_tensor(laplacian)
    if laplacian.layout != torch.strided:
        # torch.linalg.eigh needs a dense matrix
        laplacian = laplacian.to_dense()

    # Work in float64: for large nu the spectrum is tiny, e.g. nu = 100 gives
    # (200 + lambda)^(-100) ~ 1e-230, which underflows to 0 in float32.
    eigvals, eigvecs = torch.linalg.eigh(laplacian.detach().to(torch.float64))

    # A graph Laplacian is positive semi-definite; clamp round-off negatives so
    # the base of the fractional power stays positive.
    eigvals = eigvals.clamp(min=0)

    # Covariance spectrum (2*nu/kappa^2 + lambda)^(-nu)
    spectrum = torch.pow(2 * nu / kappa**2 + eigvals, -nu)

    # U diag(s) U^T without materialising diag(s): scale column j of U by s_j
    covariance = (eigvecs * spectrum) @ eigvecs.T
    covariance_np = covariance.cpu().numpy()

    plt.figure(figsize=(10, 8))

    # Plot covariance matrix
    sns.heatmap(
        covariance_np,
        cmap='viridis',
        square=True,
        xticklabels=10,
        yticklabels=10
    )

    plt.title(f"{title} (nu = {nu})", fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return covariance

def main():
    """
    Run the homophilic-vs-heterophilic Gaussian-process experiment.

    Builds one homophilic and one heterophilic two-class graph, reports their
    edge homophily, plots them and computes their Laplacians. Then, for every
    nu in ``nu_values``, it plots each graph's covariance matrix, draws one GP
    sample per graph and plots that sample on the graph. All figures are saved
    under ``results/``.
    """
    os.makedirs("results", exist_ok=True)

    # ---- Experiment configuration ----
    n_per_class = 30  # 30 nodes per class = 60 nodes total
    avg_degree = 5
    homophily_ratio = 0.8
    heterophily_ratio = 0.8
    nu_values = [0.01, 100]  # two contrasting values of the smoothness parameter
    kappa = 1.0

    def undirected_edge_index(graph):
        # (2, 2|E|) tensor listing every edge of `graph` in both directions
        one_way = torch.tensor(list(graph.edges())).T
        return torch.cat([one_way, one_way.flip(0)], dim=1)

    # ---- 1. Homophilic graph ----
    print("Generating homophilic graph...")
    g_homo, y_homo = generate_homophilic_graph(
        n_per_class=n_per_class,
        avg_degree=avg_degree,
        homophily_ratio=homophily_ratio,
    )
    score_homo = calculate_homophily_score(g_homo, y_homo)
    print(f"Homophilic graph: {g_homo.number_of_nodes()} nodes, {g_homo.number_of_edges()} edges")
    print(f"Homophily score: {score_homo:.2f} (ratio of same-class edges)")

    # ---- 2. Heterophilic graph ----
    print("\nGenerating heterophilic graph...")
    g_hetero, y_hetero = generate_heterophilic_graph(
        n_per_class=n_per_class,
        avg_degree=avg_degree,
        heterophily_ratio=heterophily_ratio,
    )
    score_hetero = calculate_homophily_score(g_hetero, y_hetero)
    print(f"Heterophilic graph: {g_hetero.number_of_nodes()} nodes, {g_hetero.number_of_edges()} edges")
    print(f"Homophily score: {score_hetero:.2f} (ratio of same-class edges)")

    # ---- 3. Plot both graphs; keep their layouts for the GP plots ----
    print("\nPlotting original graphs...")
    layout_homo = plot_graph_with_labels(
        g_homo,
        y_homo,
        f"Homophilic Graph (Homophily = {score_homo:.2f})",
        save_path="results/homophilic_graph.png",
    )
    layout_hetero = plot_graph_with_labels(
        g_hetero,
        y_hetero,
        f"Heterophilic Graph (Homophily = {score_hetero:.2f})",
        save_path="results/heterophilic_graph.png",
    )

    # ---- 4. Graph Laplacians ----
    print("\nComputing graph Laplacians...")
    lap_homo = calculate_graph_laplacian(
        undirected_edge_index(g_homo),
        num_nodes=g_homo.number_of_nodes(),
    )
    lap_hetero = calculate_graph_laplacian(
        undirected_edge_index(g_hetero),
        num_nodes=g_hetero.number_of_nodes(),
    )

    # ---- 5-7. Covariances, samples and GP plots for every nu ----
    print("\nVisualizing covariance matrices...")

    for nu in nu_values:
        cov_homo = plot_covariance_matrix(
            lap_homo,
            nu,
            kappa,
            title=f"Homophilic Graph Covariance",
            save_path=f"results/homophilic_covariance_nu_{nu}.png",
        )
        cov_hetero = plot_covariance_matrix(
            lap_hetero,
            nu,
            kappa,
            title=f"Heterophilic Graph Covariance",
            save_path=f"results/heterophilic_covariance_nu_{nu}.png",
        )

        print(f"\nSampling Gaussian processes with nu = {nu}...")
        sample_homo = sample_gaussian_process(lap_homo, nu, kappa)
        weights_homo = compute_edge_covariance(g_homo, cov_homo)
        sample_hetero = sample_gaussian_process(lap_hetero, nu, kappa)
        weights_hetero = compute_edge_covariance(g_hetero, cov_hetero)

        print(f"Visualizing Gaussian processes with nu = {nu}...")
        plot_gaussian_process_on_graph(
            g_homo,
            y_homo,
            sample_homo,
            f"Homophilic Graph (nu = {nu})",
            edge_weights=weights_homo,
            pos=layout_homo,
            save_path=f"results/homophilic_gaussian_nu_{nu}.png",
        )
        plot_gaussian_process_on_graph(
            g_hetero,
            y_hetero,
            sample_hetero,
            f"Heterophilic Graph (nu = {nu})",
            edge_weights=weights_hetero,
            pos=layout_hetero,
            save_path=f"results/heterophilic_gaussian_nu_{nu}.png",
        )


if __name__ == "__main__":
    main()