import os
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import seaborn as sns

# Fix NumPy's and PyTorch's global random state so that graph generation
# (NumPy) and Gaussian-process sampling (PyTorch) repeat exactly across runs.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)

def _generate_two_class_graph(n_per_class, avg_degree, preference_ratio, prefer_same_class):
    """Build a random two-class graph with a tunable share of same-class edges.

    Shared engine behind ``generate_homophilic_graph`` and
    ``generate_heterophilic_graph``.

    Args:
        n_per_class: Number of nodes in each of the two classes.
        avg_degree: Target average degree; ``int(n_nodes * avg_degree / 2)``
            edges are requested.
        preference_ratio: Probability that a proposed edge follows the
            preferred rule.
        prefer_same_class: If True the preferred rule is "connect within the
            same class"; otherwise it is "connect to the other class".

    Returns:
        ``(G, labels)``: an undirected ``nx.Graph`` on nodes
        ``0 .. 2*n_per_class-1`` and a float array with 0.0 for the first
        ``n_per_class`` nodes and 1.0 for the rest.

    Warns:
        RuntimeWarning: if the attempt budget runs out before the requested
            number of edges could be added.
    """
    import warnings

    n_nodes = n_per_class * 2

    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))

    # Assign labels: first n_per_class nodes are class 0, rest are class 1
    labels = np.zeros(n_nodes)
    labels[n_per_class:] = 1

    # Ascending node ids of each class, built once instead of rescanning all
    # nodes on every attempt. The candidate lists derived from them have the
    # same contents and order as a full scan, so the random draws are the same.
    class_nodes = (list(range(n_per_class)), list(range(n_per_class, n_nodes)))

    # Sum of degrees equals twice the number of edges.
    n_edges = int(n_nodes * avg_degree / 2)

    edges_added = 0
    max_attempts = n_edges * 10  # Safeguard against infinite loops
    attempts = 0

    while edges_added < n_edges and attempts < max_attempts:
        node_i = np.random.randint(0, n_nodes)
        own_class = int(labels[node_i])

        # Follow the preferred rule with probability `preference_ratio`.
        follows_preference = np.random.random() < preference_ratio
        same_class = follows_preference if prefer_same_class else not follows_preference

        if same_class:
            potential_nodes = [j for j in class_nodes[own_class] if j != node_i]
        else:
            potential_nodes = class_nodes[1 - own_class]

        if potential_nodes:
            node_j = np.random.choice(potential_nodes)
            # Duplicate proposals are rejected but still consume an attempt.
            if not G.has_edge(node_i, node_j):
                G.add_edge(node_i, node_j)
                edges_added += 1

        attempts += 1

    if edges_added < n_edges:
        warnings.warn(
            f"Only {edges_added} of {n_edges} requested edges were added after "
            f"{attempts} attempts; avg_degree may be too large for the class sizes.",
            RuntimeWarning,
            stacklevel=3,
        )

    return G, labels


def generate_homophilic_graph(n_per_class=30, avg_degree=5, homophily_ratio=0.8):
    """Generate a homophilic graph where nodes prefer to connect to the same class.

    Each proposed edge stays within the proposing node's class with
    probability ``homophily_ratio``. Otherwise it crosses to the other class.
    Duplicate proposals are rejected, so the realised same-class fraction only
    approximates the ratio.

    Returns:
        ``(G, labels)``: an ``nx.Graph`` on ``2 * n_per_class`` nodes and a
        float label array (0.0 for the first half, 1.0 for the second).
    """
    return _generate_two_class_graph(
        n_per_class, avg_degree, homophily_ratio, prefer_same_class=True
    )


def generate_heterophilic_graph(n_per_class=30, avg_degree=5, heterophily_ratio=0.8):
    """Generate a heterophilic graph where nodes prefer to connect to different classes.

    Each proposed edge crosses to the other class with probability
    ``heterophily_ratio``. Otherwise it stays within the proposing node's
    class. Duplicate proposals are rejected, so the realised cross-class
    fraction only approximates the ratio.

    Returns:
        ``(G, labels)``: an ``nx.Graph`` on ``2 * n_per_class`` nodes and a
        float label array (0.0 for the first half, 1.0 for the second).
    """
    return _generate_two_class_graph(
        n_per_class, avg_degree, heterophily_ratio, prefer_same_class=False
    )

def calculate_homophily_score(G, labels):
    """Return the fraction of edges of ``G`` whose two endpoints share a label.

    ``labels`` is indexed by node id. An edgeless graph scores 0.
    """
    edge_list = list(G.edges())
    if not edge_list:
        return 0
    matching = sum(1 for u, v in edge_list if labels[u] == labels[v])
    return matching / len(edge_list)

def calculate_laplacian(G):
    """Calculate the symmetric normalized graph Laplacian ``I - D^{-1/2} A D^{-1/2}``.

    Args:
        G: A networkx graph. Row/column ``i`` of the result corresponds to the
            ``i``-th node in ``G``'s node order (the order used by
            ``nx.to_numpy_array``).

    Returns:
        A dense ``(n, n)`` ``torch.float32`` tensor.

    Note:
        A small epsilon (1e-10) is added to every degree before the inverse
        square root, so isolated nodes do not divide by zero. Their adjacency
        row is all zeros, so they end up with ``L[i, i] = 1`` and no
        off-diagonal entries.
    """
    A = nx.to_numpy_array(G)
    degrees = A.sum(axis=1)

    # D is diagonal, so D^{-1/2} is simply the elementwise reciprocal square
    # root of the degrees; no dense O(n^3) matrix inverse is needed.
    d_inv_sqrt = 1.0 / np.sqrt(degrees + 1e-10)

    # Scaling rows and columns by d_inv_sqrt computes D^{-1/2} A D^{-1/2}.
    L = np.eye(A.shape[0]) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    return torch.tensor(L, dtype=torch.float32)

def compute_covariance(laplacian, nu, kappa=1.0):
    """Compute the covariance matrix ``(2*nu/kappa^2 + L)^(-nu)`` via eigendecomposition.

    Args:
        laplacian: Symmetric ``(n, n)`` torch tensor or array-like, e.g. the
            output of ``calculate_laplacian``.
        nu: Exponent / smoothness parameter (expected > 0).
        kappa: Length-scale parameter.

    Returns:
        ``(n, n)`` float64 NumPy array.

    Note:
        The computation runs in float64. In float32, ``nu = 100`` gives values
        around ``200**-100 ~ 1e-230``, which underflow to exactly zero. For
        much larger ``nu`` (around 130 and above when ``kappa = 1``) even
        float64 underflows.
    """
    L64 = torch.as_tensor(laplacian, dtype=torch.float64)
    eigvals, eigvecs = torch.linalg.eigh(L64)

    # A graph Laplacian is positive semi-definite; clamp tiny negative
    # round-off so a small 2*nu/kappa**2 shift cannot yield a negative base
    # (which would make the fractional power NaN).
    eigvals = eigvals.clamp(min=0.0)

    spectrum = torch.pow(2 * nu / kappa**2 + eigvals, -nu)

    # V diag(s) V^T, with the diagonal scaling applied by broadcasting.
    cov = (eigvecs * spectrum) @ eigvecs.T
    return cov.numpy()

def sample_gaussian_process(laplacian, nu, kappa=1.0):
    """Draw one sample from the zero-mean GP with covariance ``(2*nu/kappa^2 + L)^(-nu)``.

    Args:
        laplacian: Symmetric ``(n, n)`` torch tensor or array-like, e.g. the
            output of ``calculate_laplacian``.
        nu: Exponent / smoothness parameter (expected > 0).
        kappa: Length-scale parameter.

    Returns:
        1-D float64 NumPy array of length ``n``, with one value per node.

    The standard-normal draw uses ``torch``'s global generator, so results
    depend on ``torch.manual_seed``.
    """
    L64 = torch.as_tensor(laplacian, dtype=torch.float64)
    eigvals, eigvecs = torch.linalg.eigh(L64)

    # The Laplacian is PSD; clamp round-off negatives (see compute_covariance).
    eigvals = eigvals.clamp(min=0.0)

    # Square root of the covariance spectrum: (2*nu/kappa^2 + lambda)^(-nu/2).
    # float64 keeps large nu (e.g. 200**-50 for nu=100) from underflowing to 0.
    operator = torch.pow(2 * nu / kappa**2 + eigvals, -nu / 2)

    # Draw the normals in the default dtype so a given seed yields the same
    # values as a plain torch.randn(n), then upcast.
    z = torch.randn(L64.shape[0]).to(torch.float64)

    # x = V diag(s) z has covariance V diag(s^2) V^T = (2*nu/kappa^2 + L)^(-nu).
    sample = eigvecs @ (operator * z)
    return sample.numpy()

def plot_original_graphs(homo_G, homo_labels, hetero_G, hetero_labels, homo_score, hetero_score):
    """Draw both toy graphs side by side, colouring nodes by class.

    Writes ``results/toy_graphs_original.png`` and then shows the figure.

    Returns:
        ``(homo_pos, hetero_pos)``: the spring-layout positions, so later
        figures can reuse the same node placement.
    """
    _, axes = plt.subplots(1, 2, figsize=(16, 7))

    # Same seeded layout settings for both graphs.
    spring_kwargs = {"k": 0.3, "iterations": 50, "seed": 42}
    homo_pos = nx.spring_layout(homo_G, **spring_kwargs)
    hetero_pos = nx.spring_layout(hetero_G, **spring_kwargs)

    panels = (
        (axes[0], homo_G, homo_pos, homo_labels,
         f"Homophilic Graph (Homophily = {homo_score:.2f})"),
        (axes[1], hetero_G, hetero_pos, hetero_labels,
         f"Heterophilic Graph (Homophily = {hetero_score:.2f})"),
    )
    class_styles = ((0, 'blue', 'Class 0'), (1, 'red', 'Class 1'))

    for ax, graph, pos, labels, title in panels:
        nx.draw_networkx_edges(graph, pos, ax=ax, alpha=0.5, width=1.0)
        for cls, colour, legend_name in class_styles:
            members = [node for node, lab in enumerate(labels) if lab == cls]
            nx.draw_networkx_nodes(graph, pos, nodelist=members, ax=ax,
                                   node_color=colour, node_size=80, alpha=0.8,
                                   label=legend_name)
        ax.set_title(title, fontsize=14)
        ax.legend(fontsize=12)
        ax.set_axis_off()

    plt.tight_layout()
    plt.savefig('results/toy_graphs_original.png', dpi=300, bbox_inches='tight')
    plt.show()

    return homo_pos, hetero_pos

def plot_covariance_matrices(homo_cov, hetero_cov, nu):
    """Show both covariance matrices as heatmaps on one shared colour scale.

    Writes ``results/toy_graphs_covariance_nu_{nu}.png`` and then shows it.
    """
    _, (left, right) = plt.subplots(1, 2, figsize=(16, 7))

    # A common vmin/vmax makes the two heatmaps directly comparable.
    shared_scale = {
        "vmin": min(homo_cov.min(), hetero_cov.min()),
        "vmax": max(homo_cov.max(), hetero_cov.max()),
    }

    for ax, matrix, title in (
        (left, homo_cov, f"Homophilic Graph Covariance (nu = {nu})"),
        (right, hetero_cov, f"Heterophilic Graph Covariance (nu = {nu})"),
    ):
        sns.heatmap(matrix, ax=ax, cmap='viridis', square=True, **shared_scale)
        ax.set_title(title, fontsize=14)

    plt.tight_layout()
    plt.savefig(f'results/toy_graphs_covariance_nu_{nu}.png', dpi=300, bbox_inches='tight')
    plt.show()

def _normalized_edge_covariance(G, cov_matrix):
    """Map each edge ``(u, v)`` of ``G`` to ``|cov_matrix[u, v]|`` min-max scaled to [0, 1].

    Node ids are used directly as row/column indices into ``cov_matrix``.
    If every edge has the same magnitude, all weights become 0.0.
    """
    raw = {(u, v): abs(cov_matrix[u, v]) for u, v in G.edges()}
    if not raw:
        return raw
    lo, hi = min(raw.values()), max(raw.values())
    span = hi - lo if hi > lo else 1.0
    return {edge: (weight - lo) / span for edge, weight in raw.items()}


def _draw_covariance_edges(ax, pos, edge_weights):
    """Draw each edge as a line coloured by its [0, 1] weight on the Blues colormap."""
    blues = plt.cm.Blues
    for (u, v), weight in edge_weights.items():
        ax.plot([pos[u][0], pos[v][0]], [pos[u][1], pos[v][1]],
                color=blues(weight), alpha=0.7, linewidth=1.5)


def _node_coordinates(pos, n_nodes):
    """Stack ``pos[0] .. pos[n_nodes-1]`` into an ``(n_nodes, 2)`` float array."""
    return np.asarray([pos[node] for node in range(n_nodes)], dtype=float).reshape(-1, 2)


def plot_gaussian_processes(homo_G, homo_labels, homo_pos, homo_sample, homo_cov,
                           hetero_G, hetero_labels, hetero_pos, hetero_sample, hetero_cov, nu):
    """Plot the Gaussian process samples on both graphs in a 2x2 grid.

    Top row: nodes coloured by class. Bottom row: nodes coloured by GP value.
    In every panel, each edge ``(u, v)`` is coloured by ``|cov[u, v]|``,
    min-max scaled to [0, 1] within its own graph. Nodes are drawn above the
    edges so their colours stay visible.

    Both GP panels share a single colour normalization (the range over both
    samples), which is the same normalization the GP colorbar shows.

    Args:
        homo_G, hetero_G: Graphs; node ids are assumed to be ``0 .. n-1``,
            because they index ``*_cov``, ``*_sample`` and ``*_pos``.
        homo_labels, hetero_labels: Per-node class labels (0 or 1).
        homo_pos, hetero_pos: Mapping node -> (x, y) layout position.
        homo_sample, hetero_sample: 1-D arrays with one GP value per node.
        homo_cov, hetero_cov: ``(n, n)`` covariance matrices.
        nu: Used in panel titles and in the output filename.

    Writes ``results/toy_graphs_gaussian_nu_{nu}.png`` and then shows the figure.
    """
    from matplotlib.lines import Line2D

    fig, axes = plt.subplots(2, 2, figsize=(16, 14))

    node_size = 100
    # zorder=3 puts nodes above the edge lines (Line2D defaults to zorder 2).
    node_style = dict(s=node_size, alpha=0.8, edgecolors='black', linewidths=0.5, zorder=3)

    homo_sample = np.asarray(homo_sample)
    hetero_sample = np.asarray(hetero_sample)

    # One normalization for both GP panels AND the GP colorbar, so the
    # colours drawn match the scale shown.
    gp_norm = Normalize(
        vmin=min(homo_sample.min(), hetero_sample.min()),
        vmax=max(homo_sample.max(), hetero_sample.max()),
    )

    panels = (
        ("Homophilic", np.asarray(homo_labels), homo_pos, homo_sample,
         _normalized_edge_covariance(homo_G, homo_cov)),
        ("Heterophilic", np.asarray(hetero_labels), hetero_pos, hetero_sample,
         _normalized_edge_covariance(hetero_G, hetero_cov)),
    )

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Class 0'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=10, label='Class 1'),
    ]

    # Row 1: nodes coloured by class, edges by covariance.
    for col, (name, labels, pos, _, edge_weights) in enumerate(panels):
        ax = axes[0, col]
        _draw_covariance_edges(ax, pos, edge_weights)
        xy = _node_coordinates(pos, len(labels))
        for cls, colour in ((0, 'blue'), (1, 'red')):
            mask = labels == cls
            ax.scatter(xy[mask, 0], xy[mask, 1], c=colour, **node_style)
        ax.set_title(f"{name} Graph (nu = {nu})\nNodes by Class", fontsize=14)
        ax.set_axis_off()
        ax.legend(handles=legend_elements, loc='upper right')

    # Edge weights are already scaled to [0, 1].
    sm_edge = ScalarMappable(norm=Normalize(vmin=0.0, vmax=1.0), cmap=plt.cm.Blues)
    sm_edge.set_array([])
    fig.colorbar(sm_edge, ax=axes[0, :], location='bottom',
                 shrink=0.6, pad=0.05, label='Edge Covariance')

    # Row 2: nodes coloured by Gaussian process values, edges by covariance.
    for col, (name, _, pos, sample, edge_weights) in enumerate(panels):
        ax = axes[1, col]
        _draw_covariance_edges(ax, pos, edge_weights)
        xy = _node_coordinates(pos, len(sample))
        ax.scatter(xy[:, 0], xy[:, 1], c=sample, cmap='coolwarm', norm=gp_norm, **node_style)
        ax.set_title(f"{name} Graph (nu = {nu})\nNodes by Gaussian Process", fontsize=14)
        ax.set_axis_off()

    sm_gp = ScalarMappable(norm=gp_norm, cmap='coolwarm')
    sm_gp.set_array([])
    fig.colorbar(sm_gp, ax=axes[1, :], location='bottom',
                 shrink=0.6, pad=0.05, label='Gaussian Process Value')

    plt.tight_layout()
    plt.savefig(f'results/toy_graphs_gaussian_nu_{nu}.png', dpi=300, bbox_inches='tight')
    plt.show()

def main():
    """Run the toy experiment end to end, writing every figure under ``results/``.

    It builds one homophilic and one heterophilic two-class graph, reports
    their size and homophily score, and plots them. It then builds their
    Laplacians. For each ``nu`` it plots the covariance matrices and one GP
    sample per graph.
    """
    os.makedirs("results", exist_ok=True)

    # Experiment settings: 30 nodes per class, i.e. 60 nodes per graph.
    n_per_class, avg_degree = 30, 5
    homophily_ratio = heterophily_ratio = 0.8
    nu_values = [0.01, 100]  # two contrasting values of nu
    kappa = 1.0

    def report(kind, graph, score):
        # Print a short summary of a generated graph.
        print(f"{kind} graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        print(f"Homophily score: {score:.2f} (ratio of same-class edges)")

    # Graph generation consumes NumPy's RNG: homophilic first, then heterophilic.
    print("Generating homophilic graph...")
    homo_G, homo_labels = generate_homophilic_graph(
        n_per_class=n_per_class, avg_degree=avg_degree, homophily_ratio=homophily_ratio
    )
    homo_score = calculate_homophily_score(homo_G, homo_labels)
    report("Homophilic", homo_G, homo_score)

    print("\nGenerating heterophilic graph...")
    hetero_G, hetero_labels = generate_heterophilic_graph(
        n_per_class=n_per_class, avg_degree=avg_degree, heterophily_ratio=heterophily_ratio
    )
    hetero_score = calculate_homophily_score(hetero_G, hetero_labels)
    report("Heterophilic", hetero_G, hetero_score)

    print("\nPlotting original graphs...")
    homo_pos, hetero_pos = plot_original_graphs(
        homo_G, homo_labels, hetero_G, hetero_labels, homo_score, hetero_score
    )

    print("\nComputing graph Laplacians...")
    homo_laplacian, hetero_laplacian = (calculate_laplacian(g) for g in (homo_G, hetero_G))

    for nu in nu_values:
        print(f"\nProcessing nu = {nu}...")

        homo_cov = compute_covariance(homo_laplacian, nu, kappa)
        hetero_cov = compute_covariance(hetero_laplacian, nu, kappa)
        plot_covariance_matrices(homo_cov, hetero_cov, nu)

        # GP sampling consumes torch's RNG: homophilic first, then heterophilic.
        homo_sample = sample_gaussian_process(homo_laplacian, nu, kappa)
        hetero_sample = sample_gaussian_process(hetero_laplacian, nu, kappa)

        plot_gaussian_processes(
            homo_G, homo_labels, homo_pos, homo_sample, homo_cov,
            hetero_G, hetero_labels, hetero_pos, hetero_sample, hetero_cov,
            nu,
        )

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main() 