import os
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from scipy import sparse
from sklearn.preprocessing import normalize

# Seed both random generators used below: PyTorch (Gaussian-process draws via
# torch.randn) and NumPy (choice of the BFS start node when subsampling).
torch.manual_seed(42)
np.random.seed(42)

def load_cora_data(data_dir="./data"):
    """Return the single graph of the Planetoid Cora dataset.

    Features are passed through ``T.NormalizeFeatures()`` as a dataset
    transform.

    Args:
        data_dir: Base directory; the dataset root is ``<data_dir>/Planetoid``.

    Returns:
        The first (and only) graph object of the dataset.
    """
    cora = Planetoid(
        root=f'{data_dir}/Planetoid',
        name='Cora',
        transform=T.NormalizeFeatures(),
    )
    return cora[0]

def calculate_graph_laplacian(edge_index, num_nodes=None, normalized=True):
    """
    Calculate the graph Laplacian matrix of an unweighted, undirected graph.

    The edge list is symmetrised and binarised, so listing an edge in one or
    both directions, or more than once, yields the same adjacency A with
    entries in {0, 1}.

    Args:
        edge_index: PyTorch tensor of shape [2, num_edges] containing edges
        num_nodes: Number of nodes in the graph (optional; inferred as
            max index + 1, or 0 for an empty edge list)
        normalized: If True return I - D^(-1/2) A D^(-1/2), else D - A

    Returns:
        laplacian: Graph Laplacian matrix (dense float32 PyTorch tensor of
            shape [num_nodes, num_nodes])
    """
    import scipy.sparse as sp

    edge_index_np = edge_index.detach().cpu().numpy()

    # Infer number of nodes if not provided (guard the empty edge list,
    # where .max() would raise)
    if num_nodes is None:
        num_nodes = int(edge_index_np.max()) + 1 if edge_index_np.size else 0

    rows, cols = edge_index_np
    adj_matrix = sp.csr_matrix(
        (np.ones(len(rows)), (rows, cols)), shape=(num_nodes, num_nodes)
    )

    # Make the graph undirected, then binarise: csr_matrix sums duplicate
    # (i, j) entries, which would otherwise give repeated edges weight > 1.
    adj_matrix = adj_matrix.maximum(adj_matrix.T)
    adj_matrix = (adj_matrix > 0).astype(np.float64)

    degrees = np.asarray(adj_matrix.sum(axis=1)).ravel()

    if normalized:
        # D^(-1/2); isolated nodes (degree 0) give inf, which is mapped to 0
        # so their rows/columns of the normalised adjacency stay zero.
        with np.errstate(divide='ignore'):
            d_inv_sqrt = np.power(degrees, -0.5)
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)

        # Normalized Laplacian: I - D^(-1/2) A D^(-1/2)
        normalized_adj = d_inv_sqrt_mat @ adj_matrix @ d_inv_sqrt_mat
        laplacian = sp.eye(num_nodes) - normalized_adj
    else:
        # Standard (combinatorial) Laplacian: D - A
        laplacian = sp.diags(degrees) - adj_matrix

    # toarray() yields a plain ndarray (todense() would give np.matrix)
    return torch.tensor(laplacian.toarray(), dtype=torch.float32)

def sample_gaussian_process(laplacian, nu, kappa=1.0):
    """
    Draw one sample from the zero-mean Gaussian process on the graph nodes
    with covariance (2*nu/kappa^2 + Lambda)^(-nu), Lambda the graph Laplacian.

    With Lambda = U diag(lambda) U^T, the sample is U @ (Phi(lambda) * z),
    where Phi(lambda) = (2*nu/kappa^2 + lambda)^(-nu/2) and z ~ N(0, I).
    Its covariance is therefore U diag(Phi^2) U^T.

    Args:
        laplacian: Graph Laplacian matrix (square, symmetric torch tensor)
        nu: Smoothness parameter
        kappa: Scaling parameter

    Returns:
        sample: 1-D float64 tensor of length laplacian.shape[0]
    """
    # Work in float64: for large nu the filter underflows in float32
    # (e.g. nu = 100 gives values around 1e-115), producing an all-zero sample.
    eigvals, eigvecs = torch.linalg.eigh(laplacian.to(torch.float64))

    # A Laplacian is positive semi-definite; clip round-off negatives so a
    # small shift 2*nu/kappa^2 cannot yield a negative base (NaN power).
    eigvals = eigvals.clamp(min=0.0)

    # Spectral filter Phi(lambda) = (2*nu/kappa^2 + lambda)^(-nu/2)
    operator = torch.pow(2 * nu / kappa**2 + eigvals, -nu / 2)

    # Standard normal draw (drawn in the default dtype, then promoted)
    z = torch.randn(laplacian.shape[0]).to(torch.float64)

    # Elementwise filter in the eigenbasis, then map back: a length-n vector
    # (the previous outer product produced an n x n matrix).
    return eigvecs @ (operator * z)

def subsample_graph(data, n_samples=200):
    """
    Extract a connected subgraph by sampling from the original graph.

    If the largest connected component has at most ``n_samples`` nodes, the
    whole component is used. Otherwise a random node of it is chosen, and the
    first ``n_samples`` nodes of the BFS tree from that node are kept (BFS
    depth capped at 100).

    Args:
        data: PyG data object (only ``data.edge_index`` is read)
        n_samples: Number of nodes to include in the subgraph

    Returns:
        subgraph: NetworkX graph with nodes 0..k-1, inserted in that order
        node_map: Mapping from new node indices to original indices; subgraph
            node ``i`` corresponds to original node ``node_map[i]``

    Raises:
        ValueError: if ``data`` contains no edges.
    """
    # Convert to an undirected NetworkX graph (plain Python ints as nodes)
    G = nx.Graph()
    edge_index = data.edge_index.cpu().numpy()
    G.add_edges_from(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    if G.number_of_nodes() == 0:
        raise ValueError("Cannot subsample a graph with no edges")

    # Find largest connected component
    largest_cc = max(nx.connected_components(G), key=len)

    if len(largest_cc) <= n_samples:
        subgraph_nodes = sorted(largest_cc)
    else:
        # Start from a random node of the component (sorted for determinism)
        # and grow the subgraph by BFS so the result stays connected.
        start_node = int(np.random.choice(sorted(largest_cc)))
        subgraph_nodes = list(nx.bfs_tree(G, start_node, depth_limit=100))[:n_samples]

    # Relabel explicitly from subgraph_nodes. nx.convert_node_labels_to_integers
    # numbers nodes in the subgraph's own iteration order, which need not match
    # subgraph_nodes, so node_map (and hence class labels) would be wrong.
    old_to_new = {old: new for new, old in enumerate(subgraph_nodes)}
    subgraph = nx.Graph()
    subgraph.add_nodes_from(range(len(subgraph_nodes)))
    subgraph.add_edges_from(
        (old_to_new[u], old_to_new[v]) for u, v in G.subgraph(subgraph_nodes).edges()
    )

    node_map = dict(enumerate(subgraph_nodes))
    return subgraph, node_map

def _as_node_array(x, num_nodes):
    """Convert per-node data (tensor, array or sequence) to a flat NumPy array of length num_nodes.

    Extra entries are dropped and missing entries are zero-padded.
    """
    if torch.is_tensor(x):
        # detach/cpu so grad-tracking or GPU tensors convert cleanly
        arr = x.detach().cpu().numpy()
    else:
        arr = np.asarray(x)
    arr = arr.reshape(-1)
    if len(arr) > num_nodes:
        arr = arr[:num_nodes]
    elif len(arr) < num_nodes:
        arr = np.pad(arr, (0, num_nodes - len(arr)))
    return arr


def plot_gaussian_field_on_graph(G, values, title, edge_weights=None, node_labels=None, cmap='coolwarm',
                                 node_size=100, edge_width=1.0, save_path=None):
    """
    Plot a scalar field (or class labels) on a graph.

    Entry k of ``values`` / ``node_labels`` is drawn on the k-th node in
    sorted node order. For graphs labelled 0..n-1 this is node k, matching
    the row order of a Laplacian built from those node ids.

    Args:
        G: NetworkX graph
        values: Node values (e.g., Gaussian process sample); used only when
            ``node_labels`` is None
        title: Plot title
        edge_weights: Optional dict mapping (u, v) edges to a weight (e.g.
            covariance); either orientation is looked up, and missing edges
            default to 0.1. Weights are min-max scaled for coloring.
        node_labels: Optional labels for nodes; if given, nodes are colored
            categorically by label instead of by ``values``
        cmap: Colormap for node values
        node_size: Size of nodes
        edge_width: Width of edges
        save_path: Path to save the figure (not saved if falsy)
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    # Compute positions using spring layout
    pos = nx.spring_layout(G, seed=42)

    num_nodes = G.number_of_nodes()
    try:
        node_order = sorted(G.nodes())
    except TypeError:  # non-comparable node labels: fall back to insertion order
        node_order = list(G.nodes())

    if edge_weights is not None:
        edges = list(G.edges())
        if edges:
            raw = np.array(
                [edge_weights.get(e, edge_weights.get((e[1], e[0]), 0.1)) for e in edges],
                dtype=float,
            )
            # Scale-free min-max normalisation: an absolute epsilon would flatten
            # tiny-magnitude covariances (e.g. large nu) to a constant.
            span = raw.max() - raw.min()
            edge_colors = (raw - raw.min()) / span if span > 0 else np.zeros_like(raw)

            # Separate mappable so the edge colorbar can be drawn
            edge_color_mappable = plt.cm.ScalarMappable(cmap=plt.cm.Blues)
            edge_color_mappable.set_array(edge_colors)

            nx.draw_networkx_edges(
                G, pos,
                edgelist=edges,
                width=edge_width,
                edge_color=edge_colors,
                edge_cmap=plt.cm.Blues,
                alpha=0.7,
                ax=ax
            )

            edge_cbar = plt.colorbar(edge_color_mappable, ax=ax, label='Edge Covariance', shrink=0.75, pad=0.05)
            edge_cbar.ax.set_ylabel('Edge Covariance', fontsize=12)
    else:
        nx.draw_networkx_edges(G, pos, width=edge_width, alpha=0.3, ax=ax)

    if node_labels is not None:
        labels_np = _as_node_array(node_labels, num_nodes)
        cmap_categorical = plt.cm.tab10

        # One draw call per class so each gets a legend entry
        for i, label in enumerate(np.unique(labels_np)):
            nodelist = [node for node, node_label in zip(node_order, labels_np) if node_label == label]
            if nodelist:
                nx.draw_networkx_nodes(
                    G, pos,
                    nodelist=nodelist,
                    node_color=[cmap_categorical(i % 10)] * len(nodelist),
                    node_size=node_size,
                    alpha=0.9,
                    edgecolors='black',
                    linewidths=0.5,
                    label=f'Class {int(label)}',
                    ax=ax
                )

        ax.legend(scatterpoints=1, title='Node Classes', fontsize=10)
    else:
        values_np = _as_node_array(values, num_nodes)

        # Explicit nodelist so values_np[k] lands on node_order[k]
        nodes = nx.draw_networkx_nodes(
            G, pos,
            nodelist=node_order,
            node_color=values_np,
            cmap=cmap,
            node_size=node_size,
            alpha=0.9,
            edgecolors='black',
            linewidths=0.5,
            ax=ax
        )

        node_cbar = plt.colorbar(nodes, ax=ax, label='Gaussian Process Value', shrink=0.75, pad=0.05)
        node_cbar.ax.set_ylabel('Gaussian Process Value', fontsize=12)

    ax.set_title(title, fontsize=16)
    ax.set_axis_off()
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure: callers (e.g. main) create several per run
    plt.close(fig)

def compute_edge_covariance(G, cov_matrix):
    """
    Map every edge of ``G`` to the absolute covariance between its endpoints.

    Args:
        G: NetworkX graph whose nodes index rows/columns of ``cov_matrix``
        cov_matrix: Covariance matrix supporting ``[u, v].item()``

    Returns:
        edge_weights: Dict ``{(u, v): abs(cov_matrix[u, v])}`` for each edge
            as yielded by ``G.edges()``
    """
    return {
        (src, dst): abs(cov_matrix[src, dst].item())
        for src, dst in G.edges()
    }

def main():
    """Run the pipeline: load Cora, extract a subgraph, sample and plot GPs for several nu.

    Plots are written to ``results/``; two figures are produced per nu (one
    colored by class label, one by GP value).
    """
    os.makedirs("results", exist_ok=True)

    print("Loading Cora dataset...")
    data = load_cora_data()
    print(f"Cora dataset has {data.num_nodes} nodes and {data.num_edges} edges")

    n_samples = 200  # Number of nodes to include in the subgraph
    print(f"Extracting a subgraph with {n_samples} nodes...")
    subgraph, node_map = subsample_graph(data, n_samples=n_samples)
    num_nodes = subgraph.number_of_nodes()
    print(f"Subgraph has {num_nodes} nodes and {subgraph.number_of_edges()} edges")

    # Node labels for the subgraph: subgraph node i <-> original node node_map[i]
    original_labels = data.y
    subgraph_labels = torch.tensor([original_labels[node_map[i]].item() for i in range(num_nodes)])

    # Edge list in both directions for the Laplacian; guard the edgeless case,
    # where torch.tensor([]).T would not have shape [2, 0].
    edges = list(subgraph.edges())
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).T
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    print("Computing graph Laplacian...")
    laplacian = calculate_graph_laplacian(edge_index, num_nodes=num_nodes)

    kappa = 1.0
    nu_values = [0.01, 100]  # Two different values of nu

    # Spectrum computed once, in float64: for nu = 100 the covariance factor
    # (2*nu/kappa^2 + lambda)^(-nu) is ~1e-230, which underflows to 0 in float32.
    eigvals, eigvecs = torch.linalg.eigh(laplacian.to(torch.float64))
    eigvals = eigvals.clamp(min=0.0)  # clip round-off negatives of a PSD matrix

    for nu in nu_values:
        print(f"Processing nu = {nu}...")

        # Covariance (2*nu/kappa^2 + Lambda)^(-nu) = U diag(s) U^T
        spectral_cov = torch.pow(2 * nu / kappa**2 + eigvals, -nu)
        covariance = (eigvecs * spectral_cov) @ eigvecs.T

        sample = sample_gaussian_process(laplacian, nu, kappa)

        edge_weights = compute_edge_covariance(subgraph, covariance)

        # Plot with nodes colored by class label
        title = f"Gaussian Process on Cora Graph (nu = {nu})"
        save_path = f"results/cora_gaussian_process_nu_{nu}.png"

        plot_gaussian_field_on_graph(
            subgraph,
            sample,
            title,
            edge_weights=edge_weights,
            node_labels=subgraph_labels,
            cmap='coolwarm',
            node_size=100,
            save_path=save_path
        )

        # Plot with nodes colored by Gaussian process value (no labels)
        title = f"Gaussian Process Values on Cora Graph (nu = {nu})"
        save_path = f"results/cora_gaussian_process_values_nu_{nu}.png"

        plot_gaussian_field_on_graph(
            subgraph,
            sample,
            title,
            edge_weights=edge_weights,
            node_labels=None,  # No labels, color by Gaussian process values
            cmap='coolwarm',
            node_size=100,
            save_path=save_path
        )

        print(f"Plots saved to results/cora_gaussian_process_nu_{nu}.png and results/cora_gaussian_process_values_nu_{nu}.png")

if __name__ == "__main__":
    main() 