import torch
from torch_geometric.utils import to_dense_adj
import numpy as np

def calculate_class_heterophily(data):
    """
    Calculate heterophily factor for each class in the graph.

    The neighbours of a node are treated as a *set*: every distinct
    ``(source, target)`` pair in ``edge_index`` is counted once, however many
    times it is repeated (self-loops are counted like any other edge). For a
    class ``c`` the score is ``1 - same / total``, where ``total`` is the
    number of distinct edges whose source node has label ``c`` and ``same``
    is how many of those edges also end at a node of class ``c``.

    Args:
        data: PyG Data object containing edge_index and y (node labels).
            ``y`` may be a 1-D label vector, an ``[N, 1]`` column, or an
            ``[N, C]`` matrix with C > 1, which is reduced with ``argmax``.
            Nodes with a negative label are never counted as edge sources.

    Returns:
        class_heterophily: Dictionary mapping class labels to their heterophily
            scores; classes with no outgoing edges map to 0.0.
        overall_heterophily: The overall heterophily score of the graph, or
            0.0 if no edge leaves a labelled node.
    """
    edge_index = data.edge_index
    labels = data.y

    # Normalise labels to a 1-D integer vector. reshape(-1) (rather than
    # squeeze()) keeps a single-node graph 1-D so it can still be indexed.
    if labels.dim() > 1 and labels.shape[1] > 1:
        # For multi-label classification, use argmax
        num_classes = labels.shape[1]
        labels = labels.argmax(dim=1)
    else:
        labels = labels.reshape(-1).long()
        num_classes = max(int(labels.max()) + 1, 0) if labels.numel() > 0 else 0

    # Collapse repeated (src, dst) pairs so each neighbour is counted once.
    # Working on the edge list avoids building an N x N dense adjacency,
    # which needs O(N^2) memory.
    if edge_index.numel() > 0:
        edge_index = torch.unique(edge_index, dim=1)
    src_labels = labels[edge_index[0]]
    dst_labels = labels[edge_index[1]]

    # Only edges leaving a labelled node contribute (negative = unlabelled).
    valid = src_labels >= 0
    src_valid = src_labels[valid]
    same_mask = src_valid == dst_labels[valid]

    # Integer counts stay exact for any edge count (float32 would not
    # beyond 2**24).
    totals = torch.bincount(src_valid, minlength=num_classes).tolist()
    sames = torch.bincount(src_valid[same_mask], minlength=num_classes).tolist()

    # Heterophily = 1 - homophily; classes without edges default to 0.0.
    class_heterophily = {
        c: (1.0 - sames[c] / totals[c]) if totals[c] > 0 else 0.0
        for c in range(num_classes)
    }

    total_count = sum(totals)
    overall_heterophily = (1.0 - sum(sames) / total_count) if total_count > 0 else 0.0

    return class_heterophily, overall_heterophily


def calculate_class_heterophily_efficient(data):
    """
    Calculate heterophily factor for each class in the graph more efficiently.

    Every column of ``edge_index`` is counted, so a repeated edge contributes
    once per occurrence. For a class ``c`` the score is ``1 - same / total``,
    where ``total`` is the number of edges whose source node has label ``c``
    and ``same`` is how many of those edges also end at a node of class ``c``.
    All counting is done with a single ``bincount`` pass over the edges.

    Args:
        data: PyG Data object containing edge_index and y (node labels).
            ``y`` may be a 1-D label vector, an ``[N, 1]`` column, or an
            ``[N, C]`` matrix with C > 1, which is reduced with ``argmax``.
            Edges whose source node has a negative label are ignored.

    Returns:
        class_heterophily: Dictionary mapping class labels to their heterophily
            scores; classes with no outgoing edges map to 0.0.
        overall_heterophily: The overall heterophily score of the graph, or
            0.0 if no edge leaves a labelled node.
    """
    edge_index = data.edge_index
    labels = data.y

    # Normalise labels to a 1-D integer vector. reshape(-1) (rather than
    # squeeze()) keeps a single-node graph 1-D so it can still be indexed.
    if labels.dim() > 1 and labels.shape[1] > 1:
        # For multi-label classification, use argmax
        num_classes = labels.shape[1]
        labels = labels.argmax(dim=1)
    else:
        labels = labels.reshape(-1).long()
        num_classes = max(int(labels.max()) + 1, 0) if labels.numel() > 0 else 0

    # Labels at both ends of every edge.
    src_labels = labels[edge_index[0]]
    dst_labels = labels[edge_index[1]]

    # Negative labels (unlabelled nodes) never match a class, so drop them;
    # bincount also requires non-negative input.
    valid = src_labels >= 0
    src_valid = src_labels[valid]
    same_mask = src_valid == dst_labels[valid]

    # One O(E) pass per counter instead of one masked pass per class; integer
    # counts stay exact for any edge count (float32 would not beyond 2**24).
    totals = torch.bincount(src_valid, minlength=num_classes).tolist()
    sames = torch.bincount(src_valid[same_mask], minlength=num_classes).tolist()

    # Heterophily = 1 - homophily; classes without edges default to 0.0.
    class_heterophily = {
        c: (1.0 - sames[c] / totals[c]) if totals[c] > 0 else 0.0
        for c in range(num_classes)
    }

    total_count = sum(totals)
    overall_heterophily = (1.0 - sum(sames) / total_count) if total_count > 0 else 0.0

    return class_heterophily, overall_heterophily


def analyze_graph_heterophily(data):
    """
    Summarise and print the heterophily profile of a graph.

    Per-class and overall scores come from
    ``calculate_class_heterophily_efficient``. On top of those, this reports
    the min, max, mean and ``np.std`` of the per-class scores. Each of these
    falls back to 0.0 when there are no classes.

    Args:
        data: PyG Data object containing edge_index and y (node labels)

    Returns:
        dict: Keys 'overall_heterophily', 'class_heterophily',
        'min_heterophily', 'max_heterophily', 'mean_heterophily' and
        'std_heterophily'.
    """
    per_class, overall = calculate_class_heterophily_efficient(data)
    scores = list(per_class.values())

    # Guard clause: min/max/std are undefined for an empty list.
    if not scores:
        lowest = highest = average = spread = 0.0
    else:
        lowest = min(scores)
        highest = max(scores)
        average = sum(scores) / len(scores)
        spread = np.std(scores)

    stats = dict(
        overall_heterophily=overall,
        class_heterophily=per_class,
        min_heterophily=lowest,
        max_heterophily=highest,
        mean_heterophily=average,
        std_heterophily=spread,
    )

    # Assemble the console report line by line, then emit it.
    report = [
        "Graph Heterophily Analysis:",
        f"Overall heterophily: {overall:.4f}",
        "Class heterophily:",
    ]
    report.extend(f"  Class {label}: {score:.4f}" for label, score in per_class.items())
    report.append(f"Min class heterophily: {lowest:.4f}")
    report.append(f"Max class heterophily: {highest:.4f}")
    report.append(f"Mean class heterophily: {average:.4f}")
    report.append(f"Std class heterophily: {spread:.4f}")
    for line in report:
        print(line)

    return stats


def compute_heterophily_matrix(data):
    """
    Compute pairwise heterophily between classes.

    Args:
        data: PyG Data object containing edge_index and y (node labels).
            ``y`` may be a 1-D label vector, an ``[N, 1]`` column, or an
            ``[N, C]`` matrix with C > 1, which is reduced with ``argmax``.
            Edges with a negative-labelled (unlabelled) endpoint are ignored.

    Returns:
        torch.Tensor: ``[num_classes, num_classes]`` matrix in the default
            floating dtype where H[i,j] is the proportion of edges from class
            i to class j relative to total edges from class i. Rows of classes
            with no outgoing edges are all zero.
    """
    edge_index = data.edge_index
    labels = data.y

    # Normalise labels to a 1-D integer vector. reshape(-1) (rather than
    # squeeze()) keeps a single-node graph 1-D so it can still be indexed.
    if labels.dim() > 1 and labels.shape[1] > 1:
        # For multi-label classification, use argmax
        num_classes = labels.shape[1]
        labels = labels.argmax(dim=1)
    else:
        labels = labels.reshape(-1).long()
        num_classes = max(int(labels.max()) + 1, 0) if labels.numel() > 0 else 0

    src_labels = labels[edge_index[0]]
    dst_labels = labels[edge_index[1]]

    # A label of -1 used as a tensor index would wrap around to the last
    # class, so edges touching unlabelled nodes are dropped explicitly.
    valid = (src_labels >= 0) & (dst_labels >= 0)

    # Encode each (source class, target class) pair as one integer and count
    # all pairs in a single vectorised pass instead of a per-edge Python loop.
    pair_ids = src_labels[valid] * num_classes + dst_labels[valid]
    edge_counts = torch.bincount(pair_ids, minlength=num_classes * num_classes)
    edge_counts = edge_counts.reshape(num_classes, num_classes).double()

    # Normalize by row sums to get proportions
    row_sums = edge_counts.sum(dim=1, keepdim=True)
    row_sums[row_sums == 0] = 1  # Avoid division by zero
    heterophily_matrix = (edge_counts / row_sums).to(torch.get_default_dtype())

    return heterophily_matrix


if __name__ == "__main__":
    import os
    import sys
    from torch_geometric.datasets import Planetoid, Amazon, Coauthor, Twitch
    from torch_geometric.utils import to_undirected

    # Set up data directory.
    # NOTE(review): this relative path is resolved against the current working
    # directory, not this file's location.
    data_dir = "../../data/"

    # Create output directory if it doesn't exist
    output_dir = os.path.join(data_dir, "heterophily_analysis")
    os.makedirs(output_dir, exist_ok=True)

    # Output file
    output_file = os.path.join(output_dir, "heterophily_statistics.txt")

    def write_to_file(dataset_name, stats, file_handle):
        """Write one dataset's heterophily statistics to ``file_handle``."""
        file_handle.write(f"=== {dataset_name} Dataset ===\n")
        file_handle.write(f"Overall heterophily: {stats['overall_heterophily']:.4f}\n")
        file_handle.write("Class heterophily:\n")
        for c, score in stats['class_heterophily'].items():
            file_handle.write(f"  Class {c}: {score:.4f}\n")
        file_handle.write(f"Min class heterophily: {stats['min_heterophily']:.4f}\n")
        file_handle.write(f"Max class heterophily: {stats['max_heterophily']:.4f}\n")
        file_handle.write(f"Mean class heterophily: {stats['mean_heterophily']:.4f}\n")
        file_handle.write(f"Std class heterophily: {stats['std_heterophily']:.4f}\n")
        file_handle.write("\n")

    def write_error(dataset_name, error, file_handle):
        """Record a dataset whose statistics could not be computed."""
        file_handle.write(f"=== {dataset_name} Dataset ===\n")
        file_handle.write(f"Error processing dataset: {error}\n\n")

    def load_ogbn_arxiv():
        """Load ogbn-arxiv and convert its edge_index to an undirected graph."""
        # Imported lazily: ogb is an optional dependency, and a missing install
        # should only fail this dataset (caught by the caller) rather than
        # abort the whole analysis before any other dataset is processed.
        from ogb.nodeproppred import PygNodePropPredDataset
        graph = PygNodePropPredDataset(name='ogbn-arxiv',
                                       root=os.path.join(data_dir, 'ogb'))[0]
        graph.edge_index = to_undirected(graph.edge_index)
        return graph

    planetoid_root = os.path.join(data_dir, 'Planetoid')
    amazon_root = os.path.join(data_dir, 'Amazon')
    coauthor_root = os.path.join(data_dir, 'Coauthor')
    twitch_root = os.path.join(data_dir, 'Twitch')

    # Statistics jobs in report order:
    # (name written to the report, name used in console messages, loader).
    # Lambda default arguments bind each loop variable at definition time.
    stat_jobs = []
    for name in ['cora', 'citeseer', 'pubmed']:
        stat_jobs.append((name.capitalize(), name,
                          lambda name=name: Planetoid(root=planetoid_root,
                                                      name=name,
                                                      split='public')[0]))
    for name in ['Computers', 'Photo']:
        stat_jobs.append((f"Amazon-{name}", f"Amazon-{name}",
                          lambda name=name: Amazon(root=amazon_root, name=name)[0]))
    for name in ['CS', 'Physics']:
        stat_jobs.append((f"Coauthor-{name}", f"Coauthor-{name}",
                          lambda name=name: Coauthor(root=coauthor_root, name=name)[0]))
    stat_jobs.append(("OGB-ArXiv", "ogbn-arxiv", load_ogbn_arxiv))
    for region in ['DE', 'EN', 'ES', 'FR', 'PT', 'RU']:
        stat_jobs.append((f"Twitch-{region}", f"Twitch-{region}",
                          lambda region=region: Twitch(root=twitch_root, name=region)[0]))

    # Datasets that additionally get a class-to-class heterophily matrix.
    matrix_jobs = [
        ('Cora', lambda: Planetoid(root=planetoid_root, name='cora', split='public')[0]),
        ('Citeseer', lambda: Planetoid(root=planetoid_root, name='citeseer', split='public')[0]),
        ('Pubmed', lambda: Planetoid(root=planetoid_root, name='pubmed', split='public')[0]),
        ('Amazon-Computers', lambda: Amazon(root=amazon_root, name='Computers')[0]),
        ('Amazon-Photo', lambda: Amazon(root=amazon_root, name='Photo')[0]),
    ]

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Graph Dataset Heterophily Analysis\n")
        f.write("================================\n\n")

        for display_name, label, load in stat_jobs:
            try:
                print(f"Processing {label}...")
                stats = analyze_graph_heterophily(load())
                write_to_file(display_name, stats, f)
            except Exception as e:
                # Best effort: record the failure and continue with the rest.
                print(f"Error processing {label}: {e}")
                write_error(display_name, e, f)

        # Add heterophily matrices for selected datasets
        f.write("\n\nHeterophily Matrices\n")
        f.write("==================\n\n")

        # Unpacking in the loop header guarantees display_name is bound in
        # the except clause below.
        for display_name, load in matrix_jobs:
            try:
                matrix = compute_heterophily_matrix(load())
                f.write(f"{display_name} Heterophily Matrix:\n")
                f.write(f"{matrix.cpu().numpy()}\n\n")
            except Exception as e:
                print(f"Error processing heterophily matrix for {display_name}: {e}")
                f.write(f"{display_name} Heterophily Matrix: Error - {e}\n\n")

    print(f"Analysis complete! Results saved to {output_file}")