import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from scipy import sparse
import seaborn as sns

# Add the current directory to sys.path to ensure imports work correctly
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Import functions from visualization script
from visualize_gaussian_process_on_cora import (
    load_cora_data, 
    calculate_graph_laplacian,
    subsample_graph,
    compute_edge_covariance
)

def plot_covariance_matrices(laplacian, nu_values, kappa=1.0, save_path=None):
    """
    Plot covariance matrices for different values of nu on a shared color scale.

    With the eigendecomposition L = U diag(lambda) U^T of the Laplacian, the
    covariance for each nu is U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T.

    Args:
        laplacian: Graph Laplacian matrix as a square torch tensor (it is
            passed to torch.linalg.eigh, so it is treated as symmetric).
        nu_values: Non-empty list of nu values to compare (one panel each).
        kappa: Scaling parameter
        save_path: Path to save the figure (nothing is saved when falsy).

    Returns:
        The matplotlib Figure holding one heatmap per nu value.

    Raises:
        ValueError: If nu_values is empty.
    """
    n_matrices = len(nu_values)
    if n_matrices == 0:
        raise ValueError("nu_values must contain at least one value")

    fig, axes = plt.subplots(1, n_matrices, figsize=(n_matrices * 6, 5))
    if n_matrices == 1:
        axes = [axes]

    # Eigendecompose once, on CPU in float64: for large nu the spectrum
    # (2*nu/kappa^2 + lambda)^(-nu) is tiny (200^-100 ~ 1e-230 for nu=100,
    # kappa=1) and would underflow to exactly 0 in float32.
    lap = laplacian.detach().to(device="cpu", dtype=torch.float64)
    eigvals, eigvecs = torch.linalg.eigh(lap)
    # A Laplacian is positive semi-definite; clamp round-off negatives so the
    # base of the power stays non-negative.
    eigvals = eigvals.clamp(min=0)

    matrices = []
    for nu in nu_values:
        spectrum = torch.pow(2 * nu / kappa**2 + eigvals, -nu)
        # Scaling column k of U by spectrum[k] equals U @ diag(spectrum).
        covariance = (eigvecs * spectrum) @ eigvecs.T
        matrices.append(covariance.numpy())

    # Common color scale across panels so they are directly comparable.
    min_val = min(float(matrix.min()) for matrix in matrices)
    max_val = max(float(matrix.max()) for matrix in matrices)

    for ax, nu, cov_matrix in zip(axes, nu_values, matrices):
        sns.heatmap(
            cov_matrix,
            ax=ax,
            cmap='viridis',
            vmin=min_val,
            vmax=max_val
        )
        ax.set_title(f'Covariance Matrix (nu = {nu})')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def plot_covariance_decay(laplacian, nu_values, kappa=1.0, reference_node=0, top_n=20, node_labels=None, save_path=None):
    """
    Draw the graph once per nu value, showing how covariance with a reference
    node is spread over the other nodes.

    With L = U diag(lambda) U^T, the covariance for each nu is
    U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T. In each panel:
      - edges are shaded by the min-max normalized |covariance| between their
        endpoints;
      - the reference node is drawn in red and labelled 'Ref';
      - the top_n other nodes with the highest covariance to the reference are
        numbered 1..top_n, and any of them adjacent to the reference get a red
        edge.

    Args:
        laplacian: Graph Laplacian matrix (square torch tensor, treated as
            symmetric).
        nu_values: List of nu values to compare (one panel each).
        kappa: Scaling parameter
        reference_node: Index of node to use as reference. If it has no
            neighbours, the first node that has one is used instead (a
            warning is emitted).
        top_n: Number of highest-covariance nodes, excluding the reference
            node, to number and highlight.
        node_labels: Optional per-node class labels indexable by node index
            (e.g. a tensor or array). When given, nodes are colored by class;
            otherwise by their covariance with the reference node.
        save_path: Path to save the figure (nothing is saved when falsy).

    Returns:
        The matplotlib Figure.

    Raises:
        networkx.NodeNotFound: If reference_node is not a node of the graph.
    """
    import warnings

    # Eigendecompose once, on CPU in float64: for large nu the spectrum
    # (2*nu/kappa^2 + lambda)^(-nu) underflows to exactly 0 in float32.
    lap = laplacian.detach().to(device="cpu", dtype=torch.float64)
    eigvals, eigvecs = torch.linalg.eigh(lap)
    eigvals = eigvals.clamp(min=0)  # Laplacian is PSD; drop round-off negatives

    fig, axes = plt.subplots(1, len(nu_values), figsize=(len(nu_values) * 7, 6))
    if len(nu_values) == 1:
        axes = [axes]

    # Rebuild the graph from I - L. Its off-diagonal entries are nonzero
    # exactly where L has an edge entry; the weights are only approximate
    # adjacency values, and here they serve structure and layout. The
    # diagonal is zeroed so the conversion never creates self-loops.
    n = lap.shape[0]
    adj = torch.eye(n, dtype=torch.float64) - lap
    adj.fill_diagonal_(0)
    G = nx.from_numpy_array(adj.numpy())

    if reference_node not in G:
        raise nx.NodeNotFound(f"Reference node {reference_node} is not in the graph")
    if G.degree(reference_node) == 0:
        # An isolated node has no covariance structure to show along edges;
        # fall back to the first node that has at least one neighbour.
        fallback = next((node for node in G.nodes() if G.degree(node) > 0), None)
        if fallback is not None:
            warnings.warn(
                f"Reference node {reference_node} is isolated; using node {fallback} instead"
            )
            reference_node = fallback

    # Position nodes using spring layout
    pos = nx.spring_layout(G, seed=42)

    # Loop-invariant structures, built once for all panels.
    edges = list(G.edges())
    edge_array = np.array(edges, dtype=int).reshape(-1, 2)
    other_nodes = [node for node in G.nodes() if node != reference_node]
    labels_np = None if node_labels is None else np.asarray(node_labels)
    cmap_categorical = plt.cm.tab10

    for i, nu in enumerate(nu_values):
        ax = axes[i]

        spectrum = torch.pow(2 * nu / kappa**2 + eigvals, -nu)
        covariance_np = ((eigvecs * spectrum) @ eigvecs.T).numpy()

        # Covariance of the reference node with every node
        ref_covariance = covariance_np[reference_node, :]

        # Rank the *other* nodes by covariance with the reference node, so
        # exactly top_n of them get highlighted and numbered from 1. A stable
        # sort makes tie-breaking deterministic.
        ranked = [
            int(node)
            for node in np.argsort(-ref_covariance, kind='stable')
            if node != reference_node
        ]
        top_indices = ranked[:top_n]

        if edges:
            edge_colors = np.abs(covariance_np[edge_array[:, 0], edge_array[:, 1]])
            lo, hi = edge_colors.min(), edge_colors.max()
            # Min-max normalize. Checking hi > lo rather than adding a fixed
            # epsilon keeps this meaningful for very small (large-nu) values.
            if hi > lo:
                edge_colors = (edge_colors - lo) / (hi - lo)
            else:
                edge_colors = np.zeros_like(edge_colors)

            nx.draw_networkx_edges(
                G, pos,
                edgelist=edges,
                width=1.0,
                edge_color=edge_colors,
                edge_cmap=plt.cm.Blues,
                alpha=0.4,
                ax=ax
            )
        else:
            nx.draw_networkx_edges(G, pos, alpha=0.2, ax=ax)

        # Draw reference node
        nx.draw_networkx_nodes(
            G, pos,
            nodelist=[reference_node],
            node_color='red',
            node_size=300,
            ax=ax
        )

        if labels_np is not None:
            # Categorical coloring by class label
            for class_idx, label in enumerate(np.unique(labels_np)):
                nodelist = [node for node in other_nodes if labels_np[node] == label]
                if nodelist:
                    nx.draw_networkx_nodes(
                        G, pos,
                        nodelist=nodelist,
                        node_color=[cmap_categorical(class_idx % 10)] * len(nodelist),
                        node_size=100,
                        alpha=0.9,
                        edgecolors='black',
                        linewidths=0.5,
                        ax=ax,
                        label=f'Class {int(label)}' if i == 0 else None  # Legend entries only on the first panel
                    )

            # Legend only on the first panel to avoid duplication
            if i == 0:
                ax.legend(scatterpoints=1, title='Node Classes', loc='upper right', fontsize=8)

        elif other_nodes:
            # Color the other nodes by their covariance with the reference node
            nodes = nx.draw_networkx_nodes(
                G, pos,
                nodelist=other_nodes,
                node_color=[ref_covariance[node] for node in other_nodes],
                cmap='viridis',
                node_size=100,
                vmin=np.min(ref_covariance),
                vmax=np.max(ref_covariance),
                ax=ax
            )

            cbar = fig.colorbar(nodes, ax=ax, shrink=0.75)
            cbar.set_label('Covariance value')

        # Red edges from the reference node to adjacent top-ranked nodes
        highlight_edges = [
            (reference_node, target)
            for target in top_indices
            if G.has_edge(reference_node, target)
        ]
        nx.draw_networkx_edges(
            G, pos,
            edgelist=highlight_edges,
            width=2,
            edge_color='red',
            ax=ax
        )

        # 'Ref' for the reference node, ranks 1..top_n for the top nodes
        labels = {reference_node: 'Ref'}
        labels.update({target: str(rank) for rank, target in enumerate(top_indices, start=1)})

        nx.draw_networkx_labels(
            G, pos,
            labels=labels,
            font_size=10,
            font_color='black',
            ax=ax
        )

        ax.set_title(f'Covariance Structure (nu = {nu})')
        ax.set_axis_off()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def plot_covariance_vs_distance(laplacian, nu_values, kappa=1.0, save_path=None):
    """
    Plot how covariance decays with graph distance

    With L = U diag(lambda) U^T, the covariance for each nu is
    U diag((2*nu/kappa^2 + lambda)^(-nu)) U^T. The function considers every
    ordered pair of distinct nodes (i, j) in the same connected component and
    groups its covariance by the shortest-path (hop) distance between i and
    j. For each nu it draws one curve of mean +/- std per distance.
    Self-covariances (distance 0) are excluded, so the x-axis starts at 1.

    Args:
        laplacian: Graph Laplacian matrix (square torch tensor, treated as
            symmetric).
        nu_values: List of nu values to compare (one curve each).
        kappa: Scaling parameter
        save_path: Path to save the figure (nothing is saved when falsy).

    Returns:
        The matplotlib Figure.
    """
    # Eigendecompose once, on CPU in float64: for large nu the spectrum
    # (2*nu/kappa^2 + lambda)^(-nu) underflows to exactly 0 in float32.
    lap = laplacian.detach().to(device="cpu", dtype=torch.float64)
    eigvals, eigvecs = torch.linalg.eigh(lap)
    eigvals = eigvals.clamp(min=0)  # Laplacian is PSD; drop round-off negatives

    # Rebuild the graph from I - L. Off-diagonal entries are nonzero exactly
    # where L has an edge entry. The diagonal is zeroed to avoid self-loops.
    n = lap.shape[0]
    adj = torch.eye(n, dtype=torch.float64) - lap
    adj.fill_diagonal_(0)
    G = nx.from_numpy_array(adj.numpy())

    # Hop-distance matrix. all_pairs_shortest_path_length reports only
    # reachable pairs, so -1 marks pairs in different connected components.
    dist_mat = np.full((n, n), -1, dtype=np.int64)
    for src, lengths in nx.all_pairs_shortest_path_length(G):
        for dst, hops in lengths.items():
            dist_mat[src, dst] = hops

    # Distinct, connected pairs (distance >= 1), computed once for all nu.
    src_idx, dst_idx = np.nonzero(dist_mat >= 1)
    pair_dist = dist_mat[src_idx, dst_idx]
    # Only distances that actually occur; no empty bin is plotted as a fake 0.
    distance_bins = np.unique(pair_dist).tolist()

    fig, ax = plt.subplots(figsize=(10, 6))

    if distance_bins:
        for nu in nu_values:
            spectrum = torch.pow(2 * nu / kappa**2 + eigvals, -nu)
            covariance_np = ((eigvecs * spectrum) @ eigvecs.T).numpy()
            pair_cov = covariance_np[src_idx, dst_idx]

            mean_covs = [np.mean(pair_cov[pair_dist == dist]) for dist in distance_bins]
            std_covs = [np.std(pair_cov[pair_dist == dist]) for dist in distance_bins]

            ax.errorbar(
                distance_bins,
                mean_covs,
                yerr=std_covs,
                label=f'nu = {nu}',
                capsize=3,
                marker='o',
                markersize=8,
                linewidth=2
            )

    ax.set_xlabel('Graph Distance', fontsize=14)
    ax.set_ylabel('Average Covariance', fontsize=14)
    ax.set_title('Covariance Decay with Graph Distance', fontsize=16)
    ax.legend(fontsize=12)
    ax.grid(alpha=0.3)
    ax.set_xticks(distance_bins)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

def main():
    """
    Run the Cora covariance-visualization pipeline and write four figures
    under ./results.

    Steps: load Cora, subsample a subgraph, build its Laplacian, then plot
    the following for nu in {0.01, 100}:
      - covariance matrices;
      - covariance structure, colored by class and by covariance;
      - covariance vs. graph distance.
    Each figure is closed after saving, because this entry point only writes
    to disk.
    """
    # Create results directory
    os.makedirs("results", exist_ok=True)

    print("Loading Cora dataset...")
    data = load_cora_data()
    print(f"Cora dataset has {data.num_nodes} nodes and {data.num_edges} edges")

    n_samples = 200  # Number of nodes to include in the subgraph
    print(f"Extracting a subgraph with {n_samples} nodes...")
    subgraph, node_map = subsample_graph(data, n_samples=n_samples)
    num_nodes = subgraph.number_of_nodes()
    print(f"Subgraph has {num_nodes} nodes and {subgraph.number_of_edges()} edges")

    # Class label of each subgraph node, looked up through node_map.
    # Assumes subgraph nodes are 0..num_nodes-1 and node_map maps them to
    # original Cora indices. TODO confirm against subsample_graph.
    original_labels = data.y
    subgraph_labels = torch.tensor([original_labels[node_map[i]].item() for i in range(num_nodes)])

    # Turn the (E, 2) edge list into a (2, E) edge_index. The explicit dtype
    # and reshape keep a well-formed (2, 0) long tensor when the subgraph has
    # no edges.
    edge_index = torch.tensor(list(subgraph.edges()), dtype=torch.long).reshape(-1, 2).T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # Make it bidirectional

    print("Computing graph Laplacian...")
    laplacian = calculate_graph_laplacian(edge_index, num_nodes=num_nodes)

    kappa = 1.0
    nu_values = [0.01, 100]  # Two different values of nu

    print("Plotting covariance matrices...")
    save_path = "results/covariance_matrices_comparison.png"
    fig = plot_covariance_matrices(laplacian, nu_values, kappa, save_path=save_path)
    plt.close(fig)
    print(f"Plot saved to {save_path}")

    print("Plotting covariance structure with node labels...")
    save_path = "results/covariance_structure_labeled_comparison.png"
    fig = plot_covariance_decay(laplacian, nu_values, kappa, reference_node=0, node_labels=subgraph_labels, save_path=save_path)
    plt.close(fig)
    print(f"Plot saved to {save_path}")

    print("Plotting covariance structure with covariance coloring...")
    save_path = "results/covariance_structure_comparison.png"
    fig = plot_covariance_decay(laplacian, nu_values, kappa, reference_node=0, node_labels=None, save_path=save_path)
    plt.close(fig)
    print(f"Plot saved to {save_path}")

    print("Plotting covariance vs. distance relationship...")
    save_path = "results/covariance_vs_distance.png"
    fig = plot_covariance_vs_distance(laplacian, nu_values, kappa, save_path=save_path)
    plt.close(fig)
    print(f"Plot saved to {save_path}")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()