"""
Comprehensive Dataset Analyzer for GraGR Research
================================================

This script provides comprehensive analysis of all datasets used in the GraGR research,
including detailed statistics, visualizations, and comparisons.
"""

import os
import warnings

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from torch_geometric.datasets import Planetoid, WikiCS, WebKB, TUDataset
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, degree, to_undirected
warnings.filterwarnings('ignore')

class ComprehensiveDatasetAnalyzer:
    """Comprehensive analyzer for all datasets in GraGR research."""
    
    def __init__(self):
        """Create an analyzer with no datasets loaded and no results yet."""
        # Dataset name -> loaded object (filled by the _load_* methods).
        self.datasets = dict()
        # Dataset name -> statistics dict (filled by analyze_all_datasets).
        self.dataset_stats = dict()
        # The same statistics dicts, kept in analysis order for tabulation.
        self.results = list()
        
    def load_all_datasets(self):
        """Run every dataset-group loader in turn and report how many loaded.

        The groups are loaded in a fixed order: citation networks, WebKB,
        structural graphs, molecular datasets, then TU datasets. Each loader
        stores whatever it managed to load in ``self.datasets``.
        """
        print("Loading all datasets for comprehensive analysis...")

        group_loaders = (
            self._load_citation_networks,   # Citation networks
            self._load_webkb_datasets,      # WebKB datasets
            self._load_structural_graphs,   # Structural graphs
            self._load_molecular_datasets,  # Molecular datasets
            self._load_tu_datasets,         # TU datasets
        )
        for load_group in group_loaders:
            load_group()

        print(f"✓ Loaded {len(self.datasets)} datasets successfully")
        
    def _load_citation_networks(self):
        """Load citation network datasets."""
        print("  → Loading citation networks...")
        
        try:
            cora_dataset = Planetoid(root='./data', name='Cora')
            self.datasets['Cora'] = cora_dataset[0]
            print("    ✓ Cora loaded")
        except Exception as e:
            print(f"    ✗ Error loading Cora: {e}")
            
        try:
            citeseer_dataset = Planetoid(root='./data', name='CiteSeer')
            self.datasets['CiteSeer'] = citeseer_dataset[0]
            print("    ✓ CiteSeer loaded")
        except Exception as e:
            print(f"    ✗ Error loading CiteSeer: {e}")
            
        try:
            pubmed_dataset = Planetoid(root='./data', name='PubMed')
            self.datasets['PubMed'] = pubmed_dataset[0]
            print("    ✓ PubMed loaded")
        except Exception as e:
            print(f"    ✗ Error loading PubMed: {e}")
    
    def _load_webkb_datasets(self):
        """Load WebKB datasets."""
        print("  → Loading WebKB datasets...")
        
        webkb_datasets = ['Texas', 'Cornell', 'Wisconsin']
        
        for name in webkb_datasets:
            try:
                dataset = WebKB(root='./data', name=name)
                data = dataset[0]
                
                # Fix WebKB datasets masks
                if hasattr(data, 'train_mask') and data.train_mask.dim() > 1:
                    data.train_mask = data.train_mask[:, 0]
                if hasattr(data, 'val_mask') and data.val_mask.dim() > 1:
                    data.val_mask = data.val_mask[:, 0]
                if not hasattr(data, 'test_mask') or data.test_mask.dim() > 1:
                    test_indices = ~(data.train_mask | data.val_mask)
                    data.test_mask = test_indices
                
                # Ensure masks are boolean
                data.train_mask = data.train_mask.bool()
                if hasattr(data, 'val_mask'):
                    data.val_mask = data.val_mask.bool()
                if hasattr(data, 'test_mask'):
                    data.test_mask = data.test_mask.bool()
                
                self.datasets[name] = data
                print(f"    ✓ {name} loaded")
            except Exception as e:
                print(f"    ✗ Error loading {name}: {e}")
    
    def _load_structural_graphs(self):
        """Load structural graph datasets."""
        print("  → Loading structural graphs...")
        
        try:
            wikics_dataset = WikiCS(root='./data')
            data = wikics_dataset[0]
            
            # Fix WikiCS masks
            if hasattr(data, 'train_mask') and data.train_mask.dim() > 1:
                data.train_mask = data.train_mask[:, 0]
                data.val_mask = data.val_mask[:, 0]
                
            if not hasattr(data, 'test_mask') or data.test_mask.dim() > 1:
                test_indices = ~(data.train_mask | data.val_mask)
                data.test_mask = test_indices
                
            data.train_mask = data.train_mask.bool()
            data.val_mask = data.val_mask.bool()
            data.test_mask = data.test_mask.bool()
            
            self.datasets['WikiCS'] = data
            print("    ✓ WikiCS loaded")
        except Exception as e:
            print(f"    ✗ Error loading WikiCS: {e}")
    
    def _load_molecular_datasets(self):
        """Load the ogbg-molhiv dataset object under the key ``'OGB-MolHIV'``.

        The whole dataset object (not a single graph) is stored. Any failure,
        including a failed ``ogb`` import, is printed and skipped.
        """
        print("  → Loading molecular datasets...")

        dataset_key = 'OGB-MolHIV'
        try:
            # Imported inside the try so a missing ogb package is reported
            # like any other load failure instead of aborting the run.
            from ogb.graphproppred import PygGraphPropPredDataset

            self.datasets[dataset_key] = PygGraphPropPredDataset(
                name='ogbg-molhiv',
                root='./data',
            )
            print("    ✓ OGB-MolHIV loaded")
        except Exception as e:
            print(f"    ✗ Error loading OGB-MolHIV: {e}")
    
    def _load_tu_datasets(self):
        """Load TU datasets (PROTEINS and MUTAG)."""
        print("  → Loading TU datasets...")
        
        try:
            proteins_dataset = TUDataset(root='./data', name='PROTEINS')
            self.datasets['TU-PROTEINS'] = proteins_dataset
            print(f"    ✓ TU-PROTEINS loaded: {len(proteins_dataset)} graphs")
        except Exception as e:
            print(f"    ✗ Error loading TU-PROTEINS: {e}")
            
        try:
            mutag_dataset = TUDataset(root='./data', name='MUTAG')
            self.datasets['TU-MUTAG'] = mutag_dataset
            print(f"    ✓ TU-MUTAG loaded: {len(mutag_dataset)} graphs")
        except Exception as e:
            print(f"    ✗ Error loading TU-MUTAG: {e}")
    
    def analyze_single_graph_dataset(self, name, data):
        """Analyze a single graph dataset."""
        stats = {
            'Dataset': name,
            'Task_Level': 'Node',
            'Task_Type': 'Classification',
            'Training_Type': 'Transductive',
            'Category': self._get_dataset_category(name)
        }
        
        # Basic statistics
        stats['Num_Graphs'] = 1
        stats['Num_Nodes'] = data.num_nodes
        stats['Num_Edges'] = data.num_edges
        stats['Num_Features'] = data.num_node_features
        stats['Num_Labels'] = len(torch.unique(data.y)) if hasattr(data, 'y') else 0
        
        # Split sizes
        if hasattr(data, 'train_mask'):
            stats['Train_Size'] = data.train_mask.sum().item()
        else:
            stats['Train_Size'] = 'N/A'
            
        if hasattr(data, 'val_mask'):
            stats['Val_Size'] = data.val_mask.sum().item()
        else:
            stats['Val_Size'] = 'N/A'
            
        if hasattr(data, 'test_mask'):
            stats['Test_Size'] = data.test_mask.sum().item()
        else:
            stats['Test_Size'] = 'N/A'
        
        # Graph properties
        if data.num_nodes > 0:
            max_edges = data.num_nodes * (data.num_nodes - 1) // 2
            stats['Density'] = data.num_edges / max_edges if max_edges > 0 else 0
            stats['Avg_Degree'] = (2 * data.num_edges) / data.num_nodes
        else:
            stats['Density'] = 0
            stats['Avg_Degree'] = 0
        
        # Heterophily (for node classification)
        if hasattr(data, 'y') and data.y is not None:
            stats['Heterophily'] = self._calculate_heterophily(data)
        else:
            stats['Heterophily'] = 'N/A'
        
        return stats
    
    def analyze_multi_graph_dataset(self, name, dataset):
        """Analyze a multi-graph dataset."""
        stats = {
            'Dataset': name,
            'Task_Level': 'Graph',
            'Task_Type': 'Classification',
            'Training_Type': 'Inductive',
            'Category': self._get_dataset_category(name)
        }
        
        # Basic statistics
        stats['Num_Graphs'] = len(dataset)
        
        # Aggregate statistics across all graphs
        num_nodes_list = []
        num_edges_list = []
        num_features_list = []
        labels_list = []
        
        for data in dataset:
            num_nodes_list.append(data.num_nodes)
            num_edges_list.append(data.num_edges)
            num_features_list.append(data.num_node_features)
            if hasattr(data, 'y') and data.y is not None:
                labels_list.append(data.y.item())
        
        stats['Num_Nodes'] = f"{np.mean(num_nodes_list):.1f}±{np.std(num_nodes_list):.1f}"
        stats['Num_Edges'] = f"{np.mean(num_edges_list):.1f}±{np.std(num_edges_list):.1f}"
        stats['Num_Features'] = num_features_list[0] if num_features_list else 0
        stats['Num_Labels'] = len(set(labels_list)) if labels_list else 0
        
        # Split sizes (approximate for multi-graph datasets)
        if name == 'OGB-MolHIV':
            stats['Train_Size'] = 32901
            stats['Val_Size'] = 4113
            stats['Test_Size'] = 4113
        else:
            # For TU datasets, use standard split
            total_graphs = len(dataset)
            train_size = int(0.8 * total_graphs)
            val_size = int(0.1 * total_graphs)
            test_size = total_graphs - train_size - val_size
            stats['Train_Size'] = train_size
            stats['Val_Size'] = val_size
            stats['Test_Size'] = test_size
        
        # Graph properties (average)
        avg_nodes = np.mean(num_nodes_list)
        avg_edges = np.mean(num_edges_list)
        if avg_nodes > 0:
            max_edges = avg_nodes * (avg_nodes - 1) // 2
            stats['Density'] = avg_edges / max_edges if max_edges > 0 else 0
            stats['Avg_Degree'] = (2 * avg_edges) / avg_nodes
        else:
            stats['Density'] = 0
            stats['Avg_Degree'] = 0
        
        stats['Heterophily'] = 'N/A'  # Not applicable for graph-level tasks
        
        return stats
    
    def _get_dataset_category(self, name):
        """Map a dataset name to its category label, or ``'Other'`` if unknown."""
        members_by_category = (
            ('Citation Networks', ('Cora', 'CiteSeer', 'PubMed')),
            ('WebKB Datasets', ('Texas', 'Cornell', 'Wisconsin')),
            ('Structural Graphs', ('WikiCS',)),
            ('Molecular Datasets', ('OGB-MolHIV',)),
            ('TU Datasets', ('TU-PROTEINS', 'TU-MUTAG')),
        )
        for category, members in members_by_category:
            if name in members:
                return category
        return 'Other'
    
    def _calculate_heterophily(self, data):
        """Calculate heterophily for a graph."""
        if not hasattr(data, 'y') or data.y is None:
            return 'N/A'
        
        edge_index = data.edge_index
        y = data.y
        
        if edge_index.size(1) == 0:
            return 0.0
        
        # Calculate edge homophily
        edge_homophily = (y[edge_index[0]] == y[edge_index[1]]).float().mean().item()
        heterophily = 1 - edge_homophily
        
        return round(heterophily, 3)
    
    def analyze_all_datasets(self):
        """Analyze all loaded datasets."""
        print("\nAnalyzing all datasets...")
        
        for name, data in self.datasets.items():
            print(f"  → Analyzing {name}...")
            
            # Check if it's a multi-graph dataset
            if (isinstance(data, list) or 
                hasattr(data, '__len__') and not hasattr(data, 'num_nodes') or
                name == 'OGB-MolHIV'):  # OGB-MolHIV is a dataset object
                # Multi-graph dataset
                stats = self.analyze_multi_graph_dataset(name, data)
            else:
                # Single graph dataset
                stats = self.analyze_single_graph_dataset(name, data)
            
            self.dataset_stats[name] = stats
            self.results.append(stats)
            print(f"    ✓ {name} analyzed")
    
    def generate_detailed_analysis(self):
        """Generate detailed analysis for TU datasets."""
        print("\nGenerating detailed analysis for TU datasets...")
        
        tu_analysis = {}
        
        for name in ['TU-PROTEINS', 'TU-MUTAG']:
            if name in self.datasets:
                print(f"  → Detailed analysis for {name}...")
                dataset = self.datasets[name]
                
                # Graph size distribution
                graph_sizes = [data.num_nodes for data in dataset]
                edge_counts = [data.num_edges for data in dataset]
                
                # Label distribution
                labels = [data.y.item() for data in dataset]
                label_counts = {}
                for label in labels:
                    label_counts[label] = label_counts.get(label, 0) + 1
                
                # Feature analysis
                feature_dims = [data.num_node_features for data in dataset]
                
                # Edge attribute analysis (if available)
                has_edge_attrs = any(hasattr(data, 'edge_attr') and data.edge_attr is not None for data in dataset)
                
                tu_analysis[name] = {
                    'total_graphs': len(dataset),
                    'graph_sizes': {
                        'mean': np.mean(graph_sizes),
                        'std': np.std(graph_sizes),
                        'min': np.min(graph_sizes),
                        'max': np.max(graph_sizes)
                    },
                    'edge_counts': {
                        'mean': np.mean(edge_counts),
                        'std': np.std(edge_counts),
                        'min': np.min(edge_counts),
                        'max': np.max(edge_counts)
                    },
                    'label_distribution': label_counts,
                    'feature_dimension': feature_dims[0] if feature_dims else 0,
                    'has_edge_attributes': has_edge_attrs
                }
                
                print(f"    ✓ {name} detailed analysis complete")
        
        return tu_analysis
    
    def create_dataset_table(self):
        """Create comprehensive dataset table."""
        print("\nCreating comprehensive dataset table...")
        
        # Convert results to DataFrame
        df = pd.DataFrame(self.results)
        
        # Reorder columns for better readability
        column_order = [
            'Dataset', 'Num_Graphs', 'Num_Nodes', 'Num_Edges', 'Num_Features', 
            'Num_Labels', 'Task_Level', 'Task_Type', 'Training_Type', 
            'Train_Size', 'Val_Size', 'Test_Size', 'Density', 'Avg_Degree', 
            'Heterophily', 'Category'
        ]
        
        df = df[column_order]
        
        # Save to CSV
        df.to_csv('Research_Tables_and_Results/comprehensive_dataset_statistics.csv', index=False)
        
        # Create markdown table
        markdown_table = df.to_markdown(index=False)
        
        with open('Research_Tables_and_Results/comprehensive_dataset_table.md', 'w') as f:
            f.write("# Comprehensive Dataset Statistics\n\n")
            f.write(markdown_table)
        
        print("    ✓ Dataset table created")
        return df
    
    def create_visualizations(self):
        """Apply the plotting style, then render the three figures in order."""
        print("\nCreating dataset visualizations...")

        # Global look shared by all figures.
        plt.style.use('seaborn-v0_8')
        sns.set_palette("husl")

        figure_builders = (
            self._create_size_comparison_plot,  # 1. Dataset size comparison
            self._create_tu_datasets_analysis,  # 2. TU datasets detailed analysis
            self._create_categories_overview,   # 3. Dataset categories overview
        )
        for build_figure in figure_builders:
            build_figure()

        print("    ✓ Visualizations created")
    
    def _create_size_comparison_plot(self):
        """Save a 2x2 bar-chart figure comparing dataset sizes.

        The panels show nodes, edges, features and labels per dataset, taken
        from ``self.results``. For multi-graph datasets, whose node and edge
        counts are ``"mean±std"`` strings, the mean is plotted. The figure is
        written to ``Research_Tables_and_Results/dataset_size_comparison.png``;
        the directory is created if it is missing.
        """
        output_dir = 'Research_Tables_and_Results'
        os.makedirs(output_dir, exist_ok=True)

        def plotted_value(value):
            # "mean±std" strings contribute their mean; numbers pass through.
            if isinstance(value, str):
                return float(value.split('±')[0])
            return value

        datasets = [stats['Dataset'] for stats in self.results]
        # (title, y-axis label, bar colour, values), in row-major panel order.
        panels = (
            ('Number of Nodes', 'Nodes', 'skyblue',
             [plotted_value(stats['Num_Nodes']) for stats in self.results]),
            ('Number of Edges', 'Edges', 'lightcoral',
             [plotted_value(stats['Num_Edges']) for stats in self.results]),
            ('Number of Features', 'Features', 'lightgreen',
             [stats['Num_Features'] for stats in self.results]),
            ('Number of Labels', 'Labels', 'gold',
             [stats['Num_Labels'] for stats in self.results]),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        try:
            fig.suptitle('Dataset Size Comparison', fontsize=16, fontweight='bold')

            for ax, (title, ylabel, color, values) in zip(axes.flat, panels):
                ax.bar(datasets, values, color=color, alpha=0.7)
                ax.set_title(title)
                ax.set_ylabel(ylabel)
                ax.tick_params(axis='x', rotation=45)

            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, 'dataset_size_comparison.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
    
    def _create_tu_datasets_analysis(self):
        """Save histograms of graph sizes and edge counts for the TU datasets.

        One column is drawn per loaded TU dataset (node-count histogram on
        top, edge-count histogram below). Nothing is drawn or written when no
        TU dataset is loaded. The figure is written to
        ``Research_Tables_and_Results/tu_datasets_analysis.png``; the
        directory is created if it is missing.
        """
        available_tu = [name for name in ('TU-PROTEINS', 'TU-MUTAG') if name in self.datasets]
        if not available_tu:
            return

        output_dir = 'Research_Tables_and_Results'
        os.makedirs(output_dir, exist_ok=True)

        # One column per available dataset, so no panel is left empty when
        # only one loaded. squeeze=False keeps axes two-dimensional even
        # for a single column.
        num_columns = len(available_tu)
        fig, axes = plt.subplots(2, num_columns, figsize=(7.5 * num_columns, 12), squeeze=False)
        try:
            fig.suptitle('TU Datasets Detailed Analysis', fontsize=16, fontweight='bold')

            for column, name in enumerate(available_tu):
                graph_sizes = []
                edge_counts = []
                for graph in self.datasets[name]:
                    graph_sizes.append(graph.num_nodes)
                    edge_counts.append(graph.num_edges)

                # Plot graph size distribution
                size_ax = axes[0, column]
                size_ax.hist(graph_sizes, bins=20, alpha=0.7, color='skyblue')
                size_ax.set_title(f'{name} - Graph Size Distribution')
                size_ax.set_xlabel('Number of Nodes')
                size_ax.set_ylabel('Frequency')

                # Plot edge count distribution
                edge_ax = axes[1, column]
                edge_ax.hist(edge_counts, bins=20, alpha=0.7, color='lightcoral')
                edge_ax.set_title(f'{name} - Edge Count Distribution')
                edge_ax.set_xlabel('Number of Edges')
                edge_ax.set_ylabel('Frequency')

            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, 'tu_datasets_analysis.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
    
    def _create_categories_overview(self):
        """Save a pie chart of how many analyzed datasets fall in each category.

        Counts come from the ``Category`` field of ``self.results``. The
        figure is written to
        ``Research_Tables_and_Results/dataset_categories_overview.png``; the
        directory is created if it is missing.
        """
        output_dir = 'Research_Tables_and_Results'
        os.makedirs(output_dir, exist_ok=True)

        # Count datasets by category
        category_counts = {}
        for stats in self.results:
            category = stats['Category']
            category_counts[category] = category_counts.get(category, 0) + 1

        # pie() converts its data with numpy, which cannot consume dict views
        # directly, so materialise both views as lists.
        names = list(category_counts.keys())
        sizes = list(category_counts.values())

        # Create pie chart
        fig = plt.figure(figsize=(10, 8))
        try:
            plt.pie(sizes, labels=names, autopct='%1.1f%%', startangle=90)
            plt.title('Dataset Distribution by Category', fontsize=16, fontweight='bold')
            plt.axis('equal')
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'dataset_categories_overview.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
    
    def print_summary(self):
        """Print comprehensive summary."""
        print("\n" + "="*80)
        print("COMPREHENSIVE DATASET ANALYSIS SUMMARY")
        print("="*80)
        
        print(f"\nTotal datasets analyzed: {len(self.results)}")
        
        # Category breakdown
        categories = {}
        for stats in self.results:
            category = stats['Category']
            categories[category] = categories.get(category, 0) + 1
        
        print("\nDataset categories:")
        for category, count in categories.items():
            print(f"  {category}: {count} datasets")
        
        # Task level breakdown
        task_levels = {}
        for stats in self.results:
            level = stats['Task_Level']
            task_levels[level] = task_levels.get(level, 0) + 1
        
        print("\nTask levels:")
        for level, count in task_levels.items():
            print(f"  {level}: {count} datasets")
        
        # Size ranges
        print("\nDataset size ranges:")
        for stats in self.results:
            name = stats['Dataset']
            if isinstance(stats['Num_Nodes'], str):
                nodes = stats['Num_Nodes']
            else:
                nodes = f"{stats['Num_Nodes']:,}"
            print(f"  {name}: {nodes} nodes")
        
        print("\n" + "="*80)
        print("ANALYSIS COMPLETED")
        print("="*80)
        print("\nGenerated files:")
        print("📊 comprehensive_dataset_statistics.csv - Complete dataset statistics")
        print("📊 comprehensive_dataset_table.md - Formatted markdown table")
        print("📊 dataset_size_comparison.png - Size comparison visualization")
        print("📊 tu_datasets_analysis.png - TU datasets detailed analysis")
        print("📊 dataset_categories_overview.png - Categories overview")

def main():
    """Run the full pipeline: load, analyze, tabulate, plot and report.

    Exits early with a message when no dataset could be loaded. After the
    general summary, a per-dataset breakdown is printed for any TU datasets
    that were analyzed.
    """
    analyzer = ComprehensiveDatasetAnalyzer()

    analyzer.load_all_datasets()
    if not analyzer.datasets:
        print("No datasets available! Exiting.")
        return

    # Statistics first; the table, plots and summary all read from them.
    analyzer.analyze_all_datasets()
    tu_analysis = analyzer.generate_detailed_analysis()
    analyzer.create_dataset_table()
    analyzer.create_visualizations()
    analyzer.print_summary()

    # Detailed TU report only when at least one TU dataset was analyzed.
    if not tu_analysis:
        return

    banner = "="*60
    print("\n" + banner)
    print("TU DATASETS DETAILED ANALYSIS")
    print(banner)

    for name in tu_analysis:
        analysis = tu_analysis[name]
        print(f"\n{name}:")
        print(f"  Total graphs: {analysis['total_graphs']}")
        print(f"  Graph sizes: {analysis['graph_sizes']['mean']:.1f}±{analysis['graph_sizes']['std']:.1f} nodes")
        print(f"  Edge counts: {analysis['edge_counts']['mean']:.1f}±{analysis['edge_counts']['std']:.1f} edges")
        print(f"  Feature dimension: {analysis['feature_dimension']}")
        print(f"  Has edge attributes: {analysis['has_edge_attributes']}")
        print(f"  Label distribution: {analysis['label_distribution']}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()