#!/usr/bin/env python3
"""
Comprehensive graph analysis for epitope prediction representations.

Analyzes and compares different graph construction methods:
- Base: Simple residue-only graphs  
- Simple: Epiformer with uni-relational edges
- GearNet: Epiformer with 7-relation edges
- RAAD: Epiformer with 4-relation edges

Key analyses:
- Connectivity statistics and degree distributions
- Epitope-specific patterns and binding interface characteristics
- Multi-relational edge type comparisons
- Epiformer structure properties
- Predictive power indicators

Output: Comprehensive plots and statistics for graph representation comparison.
"""

import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import networkx as nx
from collections import defaultdict
import argparse
from scipy import stats
from sklearn.manifold import TSNE
import warnings
# Suppress every Python warning process-wide to keep console output quiet
# (note: this also hides deprecation notices from plotting libraries).
warnings.filterwarnings('ignore')

# Set style. Matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8*' and later releases removed the old names, so try the current
# name first, then the legacy one, and otherwise keep Matplotlib's defaults
# instead of failing at import time.
for _style_name in ('seaborn-v0_8', 'seaborn'):
    try:
        plt.style.use(_style_name)
        break
    except OSError:
        pass
sns.set_palette("husl")


class GraphAnalyzer:
    """Comprehensive analysis of protein graph representations for epitope prediction"""
    
    def __init__(self, output_dir: str = "graph_analysis_results"):
        """Create the analyzer and make sure the output directory exists.

        Args:
            output_dir: Directory where plots and summaries are written.
                Missing parent directories are created too, so nested paths
                such as ``"results/run1/graphs"`` work. An existing directory
                is reused.
        """
        self.output_dir = Path(output_dir)
        # parents=True: a nested output path must not raise FileNotFoundError.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}
        
    def load_datasets(self, dataset_paths: Dict[str, str]) -> Dict[str, List]:
        """Load every graph dataset listed in ``dataset_paths``.

        Args:
            dataset_paths: Mapping from graph type name (e.g. ``'base'``) to the
                path of a file produced by ``torch.save``.

        Returns:
            Mapping from graph type to the loaded object for every path that
            exists and loads. Missing files and load errors are printed and
            that graph type is left out.
        """
        print("Loading datasets...")
        datasets = {}

        for graph_type, path in dataset_paths.items():
            if not os.path.exists(path):
                print(f"Warning: {path} not found, skipping {graph_type}")
                continue

            print(f"  Loading {graph_type} from {path}")
            try:
                # SECURITY: the datasets are arbitrary pickled Python objects
                # (graph containers, not plain tensors), so full unpickling is
                # required; torch >= 2.6 defaults to weights_only=True and would
                # reject them. Unpickling can execute code - only load files
                # from trusted sources.
                try:
                    data = torch.load(path, weights_only=False)
                except TypeError:
                    # torch < 1.13 has no ``weights_only`` keyword.
                    data = torch.load(path)
                datasets[graph_type] = data
                print(f"    Loaded {len(data)} complexes")
            except Exception as e:
                print(f"    Error loading {graph_type}: {e}")

        return datasets
    
    def compute_basic_statistics(self, datasets: Dict[str, List]) -> Dict:
        """Compute per-complex size and connectivity statistics per graph type.

        Args:
            datasets: Mapping from graph type to an iterable of heterogeneous
                graphs. ``'base'`` graphs are read through the
                ``('X_res', 'connects', 'X_res')`` relation only; every other
                graph type goes through ``_count_edges`` / ``_compute_degrees``.

        Returns:
            Mapping from graph type to a dict with:

            - ``node_counts``, ``edge_counts``, ``edge_densities``:
              ``{'ag': [...], 'ab': [...]}``, one entry per processed complex.
              Density is ``n_edges / (n * (n - 1) / 2)``; if edges are stored
              in both directions this can exceed 1.
            - ``degrees``: ``{'ag': [...], 'ab': [...]}``, per-residue degrees
              pooled over processed complexes.
            - ``epitope_ratios``: mean of ``complex_data['ag_res'].y`` per
              processed complex.
            - ``n_complexes``: number of complexes in the input, including any
              that failed to process.

        A complex contributes to these lists only if every statistic for it
        could be computed, so the per-complex lists stay aligned. Failures are
        printed and the complex is skipped.
        """
        print("\nComputing basic graph statistics...")
        # Not named ``stats`` to avoid shadowing the module-level scipy.stats.
        graph_stats = {}

        for graph_type, dataset in datasets.items():
            print(f"  Analyzing {graph_type}...")

            node_counts = {'ag': [], 'ab': []}
            edge_counts = {'ag': [], 'ab': []}
            degrees = {'ag': [], 'ab': []}
            epitope_ratios = []
            edge_densities = {'ag': [], 'ab': []}

            for complex_data in dataset:
                try:
                    n_nodes = {
                        'ag': complex_data['ag_res'].x.shape[0],
                        'ab': complex_data['ab_res'].x.shape[0],
                    }
                    n_edges = {}
                    chain_degrees = {}
                    for chain in ('ag', 'ab'):
                        node_key = f'{chain}_res'
                        if graph_type == 'base':
                            n_edges[chain], chain_degrees[chain] = self._base_chain_connectivity(
                                complex_data, node_key, n_nodes[chain])
                        else:
                            n_edges[chain] = self._count_edges(complex_data, node_key)
                            chain_degrees[chain] = self._compute_degrees(complex_data, node_key)

                    epitope_ratio = complex_data['ag_res'].y.float().mean().item()
                except Exception as e:
                    print(f"    Error processing complex: {e}")
                    continue

                # Commit only once every value for this complex is available.
                for chain in ('ag', 'ab'):
                    node_counts[chain].append(n_nodes[chain])
                    edge_counts[chain].append(n_edges[chain])
                    degrees[chain].extend(chain_degrees[chain])
                    max_edges = n_nodes[chain] * (n_nodes[chain] - 1) / 2
                    edge_densities[chain].append(
                        n_edges[chain] / max_edges if max_edges > 0 else 0)
                epitope_ratios.append(epitope_ratio)

            graph_stats[graph_type] = {
                'node_counts': node_counts,
                'edge_counts': edge_counts,
                'degrees': degrees,
                'epitope_ratios': epitope_ratios,
                'edge_densities': edge_densities,
                'n_complexes': len(dataset)
            }

        return graph_stats

    def _base_chain_connectivity(self, complex_data, chain_key: str,
                                 n_nodes: int) -> Tuple[int, List[int]]:
        """Return ``(edge_count, degrees)`` for one chain of a base graph.

        Only the ``(chain_key, 'connects', chain_key)`` relation is read. Each
        node's degree is its number of occurrences in ``edge_index[0]``. A
        missing relation, a ``None`` edge index, or an empty edge index yields
        zero edges and a degree of 0 for every node.
        """
        edge_key = (chain_key, 'connects', chain_key)
        edge_index = None
        if edge_key in complex_data.edge_types:
            edge_index = complex_data[edge_key].edge_index
        if edge_index is None or edge_index.shape[1] == 0:
            return 0, [0] * n_nodes
        node_degrees = torch.bincount(edge_index[0], minlength=n_nodes).tolist()
        return edge_index.shape[1], (node_degrees if node_degrees else [0] * n_nodes)
    
    def _count_edges(self, complex_data, chain_key):
        """Count intra-chain edges of ``chain_key`` in an epiformer graph.

        Sums the number of edges over every relation whose source AND
        destination node type is ``chain_key`` (e.g.
        ``('ag_res', 'r0', 'ag_res')``). Relations linking ``chain_key`` to a
        different node type (cross-chain or cross-level) are not counted. Such
        relations would otherwise be counted for both chains and would not fit
        the ``n * (n - 1) / 2`` density normaliser used for residue chains.

        Args:
            complex_data: Heterogeneous graph exposing ``edge_types`` as
                ``(src, relation, dst)`` tuples and per-relation ``edge_index``.
            chain_key: Node type to count, e.g. ``'ag_res'``.

        Returns:
            Total number of columns across the matching ``edge_index`` tensors.
            Relations with no readable ``edge_index`` contribute 0.
        """
        total_edges = 0
        for key in complex_data.edge_types:
            if key[0] != chain_key or key[-1] != chain_key:
                continue
            try:
                edge_index = complex_data[key].edge_index
                if edge_index is not None:
                    total_edges += edge_index.shape[1]
            except (KeyError, AttributeError, IndexError):
                # Relation declared without a usable edge_index: treat as empty.
                continue
        return total_edges
    
    def _compute_degrees(self, complex_data, chain_key):
        """Compute per-node degrees of ``chain_key`` in an epiformer graph.

        A node's degree is its number of occurrences in ``edge_index[0]``,
        summed over every self-relation ``(chain_key, *, chain_key)``.
        Multi-relational graphs therefore report the full multigraph degree,
        independent of the order in which relations are listed. Nodes without
        edges get degree 0 instead of being dropped.

        Args:
            complex_data: Heterogeneous graph with ``complex_data[chain_key].x``
                and ``(src, relation, dst)`` tuples in ``edge_types``.
            chain_key: Node type, e.g. ``'ag_res'``.

        Returns:
            List of float degrees, one per ``chain_key`` node. All entries are
            zero when the chain has no intra-chain edges.
        """
        n_nodes = complex_data[chain_key].x.shape[0]
        degrees = torch.zeros(n_nodes)
        for key in complex_data.edge_types:
            if key[0] != chain_key or key[-1] != chain_key:
                continue
            try:
                edge_index = complex_data[key].edge_index
            except (KeyError, AttributeError):
                continue
            if edge_index is None or edge_index.shape[1] == 0:
                continue
            # minlength keeps trailing isolated nodes; the slice guards against
            # indices >= n_nodes.
            degrees += torch.bincount(edge_index[0], minlength=n_nodes)[:n_nodes]
        return degrees.tolist()
    
    def analyze_epitope_patterns(self, datasets: Dict[str, List]) -> Dict:
        """Compare connectivity of epitope vs. non-epitope antigen residues.

        For each graph type this collects:

        - ``epitope_degrees`` / ``non_epitope_degrees``: per-residue degrees,
          split by the truthiness of ``complex_data['ag_res'].y``.
        - ``epitope_clustering``: one ``_compute_epitope_clustering`` value per
          complex with at least one epitope residue.
        - ``interface_distances``: never populated here (always empty).

        Base graphs use occurrences in ``edge_index[0]`` of the
        ``('ag_res', 'connects', 'ag_res')`` relation. Other graph types use
        ``_compute_node_degrees_from_edges``. Complexes that raise are skipped
        (best-effort). The number skipped and the last error are printed per
        graph type, so failures are not silent.

        Returns:
            Mapping from graph type to the four lists above.
        """
        print("\nAnalyzing epitope patterns...")
        epitope_stats = {}

        for graph_type, dataset in datasets.items():
            print(f"  Analyzing epitope patterns in {graph_type}...")

            epitope_degrees = []
            non_epitope_degrees = []
            epitope_clustering = []
            interface_distances = []  # placeholder: nothing fills this yet
            n_skipped = 0
            last_error = None

            for complex_data in dataset:
                try:
                    epitope_labels = complex_data['ag_res'].y
                    if graph_type == 'base':
                        ag_nodes = complex_data['ag_res'].x.shape[0]
                        ag_degrees = torch.zeros(ag_nodes)
                        if ('ag_res', 'connects', 'ag_res') in complex_data.edge_types:
                            ag_edge_index = complex_data['ag_res', 'connects', 'ag_res'].edge_index
                            if ag_edge_index is not None and ag_edge_index.shape[1] > 0:
                                ag_degrees = torch.bincount(ag_edge_index[0], minlength=ag_nodes)
                    else:
                        ag_degrees = self._compute_node_degrees_from_edges(complex_data, 'ag_res')

                    # Any non-zero label marks an epitope residue.
                    epitope_mask = epitope_labels.bool()
                    if epitope_mask.any():
                        epitope_degrees.extend(ag_degrees[epitope_mask].tolist())
                    if (~epitope_mask).any():
                        non_epitope_degrees.extend(ag_degrees[~epitope_mask].tolist())

                    epitope_indices = torch.where(epitope_mask)[0]
                    if len(epitope_indices) > 0:
                        epitope_clustering.append(
                            self._compute_epitope_clustering(complex_data, epitope_indices, graph_type))
                except Exception as e:
                    # Best-effort: skip malformed complexes, but report them below.
                    n_skipped += 1
                    last_error = e

            if n_skipped:
                print(f"    Skipped {n_skipped} complex(es) in {graph_type}; last error: {last_error}")

            epitope_stats[graph_type] = {
                'epitope_degrees': epitope_degrees,
                'non_epitope_degrees': non_epitope_degrees,
                'epitope_clustering': epitope_clustering,
                'interface_distances': interface_distances
            }

        return epitope_stats
    
    def _compute_node_degrees_from_edges(self, complex_data, chain_key):
        """Sum per-node degrees of ``chain_key`` over all of its self-relations.

        For every relation ``(chain_key, *, chain_key)``, each node's number of
        occurrences in ``edge_index[0]`` is added to its degree. The source and
        destination types are compared exactly, so a cross relation such as
        ``('ag_res', 'near_ag_res', 'ab_res')`` is not mistaken for a
        self-relation.

        Returns:
            Float tensor of length ``complex_data[chain_key].x.shape[0]``;
            nodes without edges have degree 0.
        """
        n_nodes = complex_data[chain_key].x.shape[0]
        degrees = torch.zeros(n_nodes)

        for key in complex_data.edge_types:
            if key[0] != chain_key or key[-1] != chain_key:
                continue
            try:
                edge_index = complex_data[key].edge_index
            except (KeyError, AttributeError):
                continue
            if edge_index is not None and edge_index.shape[1] > 0:
                # The slice guards against node indices >= n_nodes.
                degrees += torch.bincount(edge_index[0], minlength=n_nodes)[:n_nodes]

        return degrees
    
    def _compute_epitope_clustering(self, complex_data, epitope_indices, graph_type):
        """Average clustering coefficient of the epitope-induced subgraph.

        Nodes are the epitope residues. Edges are the antigen residue edges
        whose endpoints are both epitope residues. Base graphs use only
        ``('ag_res', 'connects', 'ag_res')``; other graph types merge every
        ``('ag_res', *, 'ag_res')`` relation into one undirected graph.

        Args:
            complex_data: Heterogeneous graph with antigen residue edges.
            epitope_indices: 1-D tensor of epitope residue indices.
            graph_type: ``'base'`` or the name of an epiformer variant.

        Returns:
            ``nx.average_clustering`` of that subgraph. Returns 0.0 when there
            are fewer than two epitope residues, no intra-epitope edges, or any
            error occurs (best-effort metric).
        """
        if len(epitope_indices) < 2:
            return 0.0

        try:
            epitope_nodes = epitope_indices.tolist()
            # Set membership keeps the edge filter O(E), instead of a full
            # tensor scan per edge for ``i in epitope_indices``.
            epitope_set = set(epitope_nodes)

            if graph_type == 'base':
                relation_keys = [k for k in complex_data.edge_types
                                 if k == ('ag_res', 'connects', 'ag_res')]
            else:
                relation_keys = [k for k in complex_data.edge_types
                                 if k[0] == 'ag_res' and k[-1] == 'ag_res']

            epitope_edges = []
            for key in relation_keys:
                try:
                    edge_index = complex_data[key].edge_index
                except (KeyError, AttributeError):
                    continue
                if edge_index is None:
                    continue
                epitope_edges.extend(
                    (i, j) for i, j in edge_index.t().tolist()
                    if i in epitope_set and j in epitope_set
                )

            if not epitope_edges:
                return 0.0

            G = nx.Graph()
            G.add_nodes_from(epitope_nodes)
            G.add_edges_from(epitope_edges)
            return nx.average_clustering(G)

        except Exception:
            # Best-effort: a malformed complex contributes 0.0.
            return 0.0
    
    def analyze_epiformer_levels(self, datasets: Dict[str, List]) -> Dict:
        """Collect node, edge and degree statistics for each epiformer level.

        Levels and the (antigen, antibody) node types inspected for each:

        - ``atom``:    ``ag_atom`` / ``ab_atom``
        - ``residue``: ``ag_res``  / ``ab_res``
        - ``edge``:    ``ag_edge`` / ``ab_edge`` (line-graph nodes)

        A level is recorded for a complex only when both of its node types
        appear in ``complex_data.node_types``. Per recorded level,
        ``node_counts`` and ``edge_counts`` gain two entries (antigen, then
        antibody) and ``degrees`` gains one entry per node. ``'base'`` datasets
        are skipped.

        If a complex raises, processing of that complex stops; levels already
        recorded for it are kept. The number of such complexes is printed.

        Returns:
            ``{graph_type: {level: {'node_counts': [...], 'edge_counts': [...],
            'degrees': [...]}}}``.
        """
        print("\nAnalyzing epiformer graph levels...")
        level_node_types = (
            ('atom', 'ag_atom', 'ab_atom'),
            ('residue', 'ag_res', 'ab_res'),
            ('edge', 'ag_edge', 'ab_edge'),
        )
        epiformer_stats = {}

        for graph_type, dataset in datasets.items():
            if graph_type == 'base':  # Skip base graphs
                continue

            print(f"  Analyzing epiformer levels in {graph_type}...")

            level_stats = {
                level: {'node_counts': [], 'edge_counts': [], 'degrees': []}
                for level, _, _ in level_node_types
            }
            n_skipped = 0
            last_error = None

            for complex_data in dataset:
                try:
                    for level, ag_type, ab_type in level_node_types:
                        if ag_type not in complex_data.node_types or ab_type not in complex_data.node_types:
                            continue
                        entry = level_stats[level]
                        # Each pair is fully evaluated before extending, so a
                        # failure never records only the antigen half.
                        entry['node_counts'].extend([
                            complex_data[ag_type].x.shape[0],
                            complex_data[ab_type].x.shape[0],
                        ])
                        entry['edge_counts'].extend([
                            self._count_edges_for_node_type(complex_data, ag_type),
                            self._count_edges_for_node_type(complex_data, ab_type),
                        ])
                        entry['degrees'].extend(
                            self._compute_degrees_for_node_type(complex_data, ag_type)
                            + self._compute_degrees_for_node_type(complex_data, ab_type)
                        )
                except Exception as e:
                    n_skipped += 1
                    last_error = e

            if n_skipped:
                print(f"    Skipped {n_skipped} complex(es) in {graph_type}; last error: {last_error}")

            epiformer_stats[graph_type] = level_stats

        return epiformer_stats
    
    def _count_edges_for_node_type(self, complex_data, node_type):
        """Count edges over every self-relation ``(node_type, *, node_type)``.

        Source and destination node types are compared exactly. A relation is
        therefore not miscounted just because ``node_type`` appears inside its
        relation name. Relations connecting ``node_type`` to a different node
        type are excluded. Relations whose ``edge_index`` is missing or
        ``None`` count as zero.

        Returns:
            Total number of edges (``edge_index`` columns) across the matching
            relations.
        """
        edge_count = 0
        for edge_key in complex_data.edge_types:
            if edge_key[0] != node_type or edge_key[-1] != node_type:
                continue
            try:
                edge_index = complex_data[edge_key].edge_index
                if edge_index is not None:
                    edge_count += edge_index.shape[1]
            except (KeyError, AttributeError, IndexError):
                continue
        return edge_count
    
    def _compute_degrees_for_node_type(self, complex_data, node_type):
        """Per-node degrees of ``node_type`` summed over its self-relations.

        Each node's degree is its number of occurrences in ``edge_index[0]``
        across every relation ``(node_type, *, node_type)``. Source and
        destination node types are compared exactly.

        Returns:
            List of float degrees, one per node. Returns ``[0]`` when the node
            features cannot be read, e.g. the node type is absent.
        """
        try:
            n_nodes = complex_data[node_type].x.shape[0]
        except (KeyError, AttributeError, IndexError):
            return [0]

        degrees = torch.zeros(n_nodes)
        for edge_key in complex_data.edge_types:
            if edge_key[0] != node_type or edge_key[-1] != node_type:
                continue
            try:
                edge_index = complex_data[edge_key].edge_index
            except (KeyError, AttributeError):
                continue
            if edge_index is not None and edge_index.shape[1] > 0:
                # The slice guards against node indices >= n_nodes.
                degrees += torch.bincount(edge_index[0], minlength=n_nodes)[:n_nodes]

        return degrees.tolist()

    def create_epiformer_plots(self, epiformer_stats: Dict):
        """Render the 2x3 epiformer-level figure and write it as PNG and PDF.

        The top row holds node counts, edge counts and degree distributions
        per level. The full-width bottom row compares averages across levels.
        Output goes to ``self.output_dir`` as ``epiformer_analysis.png``
        (300 dpi) and ``epiformer_analysis.pdf``.
        """
        print("\nGenerating epiformer level plots...")

        figure = plt.figure(figsize=(15, 10))
        grid = figure.add_gridspec(2, 3, hspace=0.4, wspace=0.3)

        panels = (
            (self._plot_epiformer_node_counts, grid[0, 0]),
            (self._plot_epiformer_edge_counts, grid[0, 1]),
            (self._plot_epiformer_degrees, grid[0, 2]),
            (self._plot_level_comparison, grid[1, :]),
        )
        for draw_panel, slot in panels:
            draw_panel(figure, slot, epiformer_stats)

        # ``figure`` is pyplot's current figure here, so titling and saving it
        # directly is equivalent to the implicit plt.* calls.
        figure.suptitle('Epiformer Graph Level Analysis', fontsize=16, fontweight='bold')
        figure.savefig(self.output_dir / 'epiformer_analysis.png', dpi=300, bbox_inches='tight')
        figure.savefig(self.output_dir / 'epiformer_analysis.pdf', bbox_inches='tight')
        print(f"  Saved epiformer analysis to {self.output_dir}")
        
    def _plot_epiformer_node_counts(self, fig, gs, epiformer_stats):
        """Box-plot node counts per (graph type, level), coloured by level.

        Args:
            fig: Figure to add the panel to.
            gs: Subplot spec locating the panel.
            epiformer_stats: Output of ``analyze_epiformer_levels``. Levels with
                no recorded node counts are omitted.
        """
        ax = fig.add_subplot(gs)

        levels = ['atom', 'residue', 'edge']
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
        level_color = dict(zip(levels, colors))

        data_for_plot = []
        labels = []
        box_colors = []
        for graph_type, level_data in epiformer_stats.items():
            for level in levels:
                if level in level_data and level_data[level]['node_counts']:
                    data_for_plot.append(level_data[level]['node_counts'])
                    labels.append(f'{graph_type}\n{level}')
                    box_colors.append(level_color[level])

        if data_for_plot:
            # Tick labels are set separately: boxplot's ``labels=`` keyword is
            # deprecated since Matplotlib 3.9 (renamed ``tick_labels``), and this
            # form works on old and new versions. Boxes sit at x = 1..n.
            bp = ax.boxplot(data_for_plot, patch_artist=True)
            ax.set_xticks(range(1, len(labels) + 1))
            ax.set_xticklabels(labels)
            for patch, color in zip(bp['boxes'], box_colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

        ax.set_ylabel('Node Count', fontsize=12)
        ax.set_title('Node Counts by Level', fontsize=14, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)
        
    def _plot_epiformer_edge_counts(self, fig, gs, epiformer_stats):
        """Box-plot edge counts per (graph type, level), coloured by level.

        Args:
            fig: Figure to add the panel to.
            gs: Subplot spec locating the panel.
            epiformer_stats: Output of ``analyze_epiformer_levels``. Levels with
                no recorded edge counts are omitted.
        """
        ax = fig.add_subplot(gs)

        levels = ['atom', 'residue', 'edge']
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
        level_color = dict(zip(levels, colors))

        data_for_plot = []
        labels = []
        box_colors = []
        for graph_type, level_data in epiformer_stats.items():
            for level in levels:
                if level in level_data and level_data[level]['edge_counts']:
                    data_for_plot.append(level_data[level]['edge_counts'])
                    labels.append(f'{graph_type}\n{level}')
                    box_colors.append(level_color[level])

        if data_for_plot:
            # Tick labels are set separately: boxplot's ``labels=`` keyword is
            # deprecated since Matplotlib 3.9 (renamed ``tick_labels``), and this
            # form works on old and new versions. Boxes sit at x = 1..n.
            bp = ax.boxplot(data_for_plot, patch_artist=True)
            ax.set_xticks(range(1, len(labels) + 1))
            ax.set_xticklabels(labels)
            for patch, color in zip(bp['boxes'], box_colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

        ax.set_ylabel('Edge Count', fontsize=12)
        ax.set_title('Edge Counts by Level', fontsize=14, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)
        
    def _plot_epiformer_degrees(self, fig, gs, epiformer_stats):
        """Overlay one density-normalised degree histogram per epiformer level.

        Degrees from all graph types are pooled per level. Levels with no
        degrees are left out. Each histogram uses 30 bins and is drawn as a
        line through the bin centres.
        """
        ax = fig.add_subplot(gs)

        palette = {'atom': '#FF6B6B', 'residue': '#4ECDC4', 'edge': '#45B7D1'}

        for level, color in palette.items():
            pooled = [
                degree
                for level_data in epiformer_stats.values()
                if level in level_data and level_data[level]['degrees']
                for degree in level_data[level]['degrees']
            ]
            if not pooled:
                continue
            density, bin_edges = np.histogram(pooled, bins=30, density=True)
            centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            ax.plot(centers, density, label=level.capitalize(),
                    color=color, linewidth=2, alpha=0.8)

        ax.set_xlabel('Node Degree', fontsize=12)
        ax.set_ylabel('Probability Density', fontsize=12)
        ax.set_title('Degree Distributions by Level', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
    def _plot_level_comparison(self, fig, gs, epiformer_stats):
        """Grouped bars of average nodes, edges and degree per epiformer level.

        For each (metric, level) the mean is computed per graph type and then
        averaged across graph types. The bar height is 0 when no graph type has
        data. All three metrics share one y-axis even though their scales
        differ.
        """
        ax = fig.add_subplot(gs)

        levels = ['atom', 'residue', 'edge']
        metric_sources = {
            'avg_nodes': 'node_counts',
            'avg_edges': 'edge_counts',
            'avg_degree': 'degrees',
        }
        positions = np.arange(len(levels))
        bar_width = 0.25

        for offset, (metric, source_key) in enumerate(metric_sources.items()):
            heights = []
            for level in levels:
                per_graph_means = [
                    np.mean(level_data[level][source_key])
                    for level_data in epiformer_stats.values()
                    if level in level_data and level_data[level][source_key]
                ]
                heights.append(np.mean(per_graph_means) if per_graph_means else 0)
            ax.bar(positions + offset * bar_width, heights, bar_width,
                   label=metric.replace('_', ' ').title(), alpha=0.8)

        ax.set_xlabel('Epiformer Level', fontsize=12)
        ax.set_ylabel('Average Value', fontsize=12)
        ax.set_title('Statistics Comparison Across Levels', fontsize=14, fontweight='bold')
        # Centre ticks under the middle bar of each three-bar group.
        ax.set_xticks(positions + bar_width)
        ax.set_xticklabels([level.capitalize() for level in levels])
        ax.legend()
        ax.grid(True, alpha=0.3)

    def compare_edge_types(self, datasets: Dict[str, List]) -> Dict:
        """Collect per-relation edge counts and densities for epiformer graphs.

        An edge type ``(src, relation, dst)`` is recorded when ``relation`` is
        ``'connects'`` or a numbered relation ``r<k>`` (``r0``, ``r1``, ...).
        This covers every GearNet relation, not just ``r0``-``r3``. Density is
        ``n_edges / (n * (n - 1) / 2)``, where ``n`` is the antigen residue
        count when ``'ag_res'`` appears in the key and the antibody residue
        count otherwise. If edges are stored in both directions the density can
        exceed 1. ``'base'`` datasets are skipped.

        Returns:
            ``{graph_type: {'relation_counts': {rel: [...]},
            'relation_densities': {rel: [...]}}}`` with one list entry per
            (complex, matching edge type).
        """
        print("\nComparing edge types...")
        edge_type_stats = {}

        for graph_type, dataset in datasets.items():
            if graph_type in ['base']:  # Skip base graphs
                continue

            print(f"  Analyzing edge types in {graph_type}...")

            relation_counts = defaultdict(list)
            relation_densities = defaultdict(list)
            n_skipped = 0
            last_error = None

            for complex_data in dataset:
                try:
                    ag_nodes = complex_data['ag_res'].x.shape[0]
                    ab_nodes = complex_data['ab_res'].x.shape[0]

                    for key in complex_data.edge_types:
                        relation = str(key[1]) if isinstance(key, tuple) and len(key) == 3 else ''
                        is_numbered = relation[:1] == 'r' and relation[1:2].isdigit()
                        if relation != 'connects' and not is_numbered:
                            continue

                        try:
                            edge_index = complex_data[key].edge_index
                            if edge_index is None:
                                continue
                            n_edges = edge_index.shape[1]
                        except (KeyError, AttributeError, IndexError):
                            continue

                        relation_counts[relation].append(n_edges)
                        n_nodes = ag_nodes if 'ag_res' in str(key) else ab_nodes
                        max_edges = n_nodes * (n_nodes - 1) / 2
                        relation_densities[relation].append(
                            n_edges / max_edges if max_edges > 0 else 0)

                except Exception as e:
                    n_skipped += 1
                    last_error = e

            if n_skipped:
                print(f"    Skipped {n_skipped} complex(es) in {graph_type}; last error: {last_error}")

            edge_type_stats[graph_type] = {
                'relation_counts': dict(relation_counts),
                'relation_densities': dict(relation_densities)
            }

        return edge_type_stats
    
    def create_comprehensive_plots(self, basic_stats: Dict, epitope_stats: Dict, edge_stats: Dict,
                                   epiformer_stats: Optional[Dict] = None):
        """Render the main comparison figure, write summary files, and optionally epiformer plots.

        The 4x4 grid is filled in rows 0-2: degree distributions, count
        distributions, epitope vs. non-epitope degrees, edge density, epitope
        ratios, and (when ``edge_stats`` is non-empty) edge-type analysis. The
        figure is saved to ``self.output_dir`` as
        ``comprehensive_analysis.{png,pdf}``. When ``epiformer_stats`` is
        truthy, the epiformer figure and CSV are produced as well.

        Args:
            basic_stats: Output of ``compute_basic_statistics``.
            epitope_stats: Output of ``analyze_epitope_patterns``.
            edge_stats: Output of ``compare_edge_types`` (may be empty).
            epiformer_stats: Output of ``analyze_epiformer_levels``, or None.
        """
        print("\nGenerating plots...")

        # NOTE: mutates global rcParams, so this also changes the default size
        # of any figure created later without an explicit figsize.
        plt.rcParams['figure.figsize'] = (20, 15)

        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(4, 4, hspace=0.6, wspace=0.3)

        # 1. Degree distributions comparison
        self._plot_degree_distributions(fig, gs[0, :2], basic_stats)

        # 2. Node/Edge count distributions
        self._plot_count_distributions(fig, gs[0, 2:], basic_stats)

        # 3. Epitope vs Non-epitope degree comparison
        self._plot_epitope_degree_comparison(fig, gs[1, :2], epitope_stats)

        # 4. Edge density comparison
        self._plot_edge_density_comparison(fig, gs[1, 2:], basic_stats)

        # 5. Epitope ratio distributions
        self._plot_epitope_ratios(fig, gs[2, :2], basic_stats)

        # 6. Edge type analysis (if available)
        if edge_stats:
            self._plot_edge_type_analysis(fig, gs[2, 2:], edge_stats)

        # 7. Summary statistics go to CSV rather than an in-figure table
        self._save_summary_csv(basic_stats, epitope_stats)

        # Title/save ``fig`` explicitly instead of pyplot's implicit "current
        # figure", so a helper above that creates a figure cannot redirect them.
        fig.suptitle('Comprehensive Graph Analysis for Epitope Prediction', fontsize=16, fontweight='bold')
        fig.savefig(self.output_dir / 'comprehensive_analysis.png', dpi=300, bbox_inches='tight')
        fig.savefig(self.output_dir / 'comprehensive_analysis.pdf', bbox_inches='tight')
        print(f"  Saved comprehensive analysis to {self.output_dir}")

        # Create epiformer analysis plots if data available
        if epiformer_stats:
            self.create_epiformer_plots(epiformer_stats)
            self._save_epiformer_csv(epiformer_stats)
        
    def _plot_degree_distributions(self, fig, gs, basic_stats):
        """Draw one density-normalised degree histogram line per graph type.

        Antigen ('ag') and antibody ('ab') degree lists are concatenated per
        graph type, binned into 50 bins, and drawn as a line through the bin
        centres (GearNet-paper style). Graph types with no degrees are skipped.
        """
        palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
        axis = fig.add_subplot(gs)

        for idx, graph_type in enumerate(basic_stats):
            deg = basic_stats[graph_type]['degrees']
            pooled = deg['ag'] + deg['ab']
            if not pooled:
                continue
            density, edges = np.histogram(pooled, bins=50, density=True)
            centres = (edges[:-1] + edges[1:]) / 2
            axis.plot(
                centres,
                density,
                label=f'{graph_type.capitalize()}',
                color=palette[idx % len(palette)],
                linewidth=2,
                alpha=0.8,
            )

        axis.set_xlabel('Node Degree', fontsize=12)
        axis.set_ylabel('Probability Density', fontsize=12)
        axis.set_title('Degree Distributions Comparison', fontsize=14, fontweight='bold')
        axis.legend()
        axis.grid(True, alpha=0.3)
        
    def _plot_count_distributions(self, fig, gs, basic_stats):
        """Bar-plot mean antigen/antibody node and edge counts per graph type.

        Four bars per graph type: AG nodes, AB nodes, AG edges and AB edges.
        Edge counts are divided by 10 so they share a readable axis with nodes.
        """
        axis = fig.add_subplot(gs)
        types = list(basic_stats)

        def _means(metric, side):
            # Mean of one per-complex count list for every graph type, in order.
            return [np.mean(basic_stats[t][metric][side]) for t in types]

        series = [
            (-1.5, _means('node_counts', 'ag'), 'AG Nodes'),
            (-0.5, _means('node_counts', 'ab'), 'AB Nodes'),
            (0.5, [v / 10 for v in _means('edge_counts', 'ag')], 'AG Edges (÷10)'),
            (1.5, [v / 10 for v in _means('edge_counts', 'ab')], 'AB Edges (÷10)'),
        ]

        bar_w = 0.2
        positions = np.arange(len(types))
        for offset, heights, label in series:
            axis.bar(positions + offset * bar_w, heights, bar_w, label=label, alpha=0.8)

        axis.set_xlabel('Graph Type', fontsize=12)
        axis.set_ylabel('Average Count', fontsize=12)
        axis.set_title('Node and Edge Counts', fontsize=14, fontweight='bold')
        axis.set_xticks(positions)
        axis.set_xticklabels([t.capitalize() for t in types])
        axis.legend()
        axis.grid(True, alpha=0.3)
        
    def _plot_epitope_degree_comparison(self, fig, gs, epitope_stats):
        """Box-plot epitope vs non-epitope node degrees for each graph type.

        A graph type contributes a pair of boxes (epitope, non-epitope) only
        when both of its degree lists are non-empty. Epitope boxes are drawn in
        a darker red than their non-epitope partners.

        Args:
            fig: Matplotlib figure to add the subplot to.
            gs: Gridspec slot for the subplot.
            epitope_stats: Mapping graph type -> dict with 'epitope_degrees'
                and 'non_epitope_degrees' lists.
        """
        ax = fig.add_subplot(gs)

        data_for_plot = []
        labels = []
        for graph_type, stats in epitope_stats.items():
            if stats['epitope_degrees'] and stats['non_epitope_degrees']:
                name = graph_type.capitalize()
                data_for_plot.extend([stats['epitope_degrees'], stats['non_epitope_degrees']])
                labels.extend([f'{name}\nEpitope', f'{name}\nNon-epitope'])

        if data_for_plot:
            bp = ax.boxplot(data_for_plot, patch_artist=True)
            # Set tick labels explicitly instead of via boxplot's ``labels``
            # kwarg, which Matplotlib 3.9 deprecated (renamed ``tick_labels``).
            # Boxes sit at the default positions 1..n.
            ax.set_xticks(np.arange(1, len(labels) + 1))
            ax.set_xticklabels(labels)

            # Alternate colours: even boxes are epitope, odd boxes non-epitope.
            pair_colors = ('#FF6B6B', '#FFB6B6')
            for idx, patch in enumerate(bp['boxes']):
                patch.set_facecolor(pair_colors[idx % 2])
                patch.set_alpha(0.7)

        ax.set_ylabel('Node Degree', fontsize=12)
        ax.set_title('Epitope vs Non-epitope Degrees', fontsize=14, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)
        
    def _plot_edge_density_comparison(self, fig, gs, basic_stats):
        """Bar-plot mean antigen vs antibody edge density for each graph type."""
        axis = fig.add_subplot(gs)
        types = list(basic_stats)
        positions = np.arange(len(types))
        bar_w = 0.35

        # Antigen bars left of each tick, antibody bars right of it.
        for sign, side, label in ((-1, 'ag', 'Antigen'), (1, 'ab', 'Antibody')):
            heights = [np.mean(basic_stats[t]['edge_densities'][side]) for t in types]
            axis.bar(positions + sign * bar_w / 2, heights, bar_w, label=label, alpha=0.8)

        axis.set_xlabel('Graph Type', fontsize=12)
        axis.set_ylabel('Edge Density', fontsize=12)
        axis.set_title('Edge Density Comparison', fontsize=14, fontweight='bold')
        axis.set_xticks(positions)
        axis.set_xticklabels([t.capitalize() for t in types])
        axis.legend()
        axis.grid(True, alpha=0.3)
        
    def _plot_epitope_ratios(self, fig, gs, basic_stats):
        """Overlay density-normalised histograms of per-complex epitope ratios.

        Each graph type with a non-empty 'epitope_ratios' list gets its own
        semi-transparent 20-bin histogram; empty lists are skipped.
        """
        axis = fig.add_subplot(gs)

        for graph_type in basic_stats:
            ratios = basic_stats[graph_type]['epitope_ratios']
            if not ratios:
                continue
            axis.hist(ratios, bins=20, alpha=0.7, label=graph_type.capitalize(), density=True)

        axis.set_xlabel('Epitope Ratio', fontsize=12)
        axis.set_ylabel('Density', fontsize=12)
        axis.set_title('Epitope Ratio Distributions', fontsize=14, fontweight='bold')
        axis.legend()
        axis.grid(True, alpha=0.3)
        
    def _plot_edge_type_analysis(self, fig, gs, edge_stats):
        """Plot mean edge count per relation type as grouped bars.

        There is one bar group per relation type and one bar per graph type
        within each group. Bar width shrinks with the number of graph types, so
        groups never overlap. It stays 0.35 for two or fewer types, matching the
        previous layout.

        Args:
            fig: Matplotlib figure to add the subplot to.
            gs: Gridspec slot for the subplot.
            edge_stats: Mapping graph type -> {'relation_counts': {relation: counts}}.
                A relation absent for a graph type is plotted as 0.
        """
        ax = fig.add_subplot(gs)

        all_relations = sorted({rel for stats in edge_stats.values() for rel in stats['relation_counts']})
        if not all_relations:
            return

        n_types = len(edge_stats)
        # A fixed 0.35 width only fits two types per unit-spaced group; with
        # three or four multi-relational graph types the bars ran into the
        # neighbouring group. Keep the total group width under 0.8.
        width = min(0.35, 0.8 / n_types)
        x = np.arange(len(all_relations))

        for i, (graph_type, stats) in enumerate(edge_stats.items()):
            counts = [np.mean(stats['relation_counts'].get(rel, [0])) for rel in all_relations]
            ax.bar(x + i * width, counts, width, label=graph_type.capitalize(), alpha=0.8)

        ax.set_xlabel('Relation Type', fontsize=12)
        ax.set_ylabel('Average Edge Count', fontsize=12)
        ax.set_title('Edge Type Analysis', fontsize=14, fontweight='bold')
        # Centre each tick under its whole group of n_types bars.
        ax.set_xticks(x + width * (n_types - 1) / 2)
        ax.set_xticklabels([str(rel) for rel in all_relations])
        ax.legend()
        ax.grid(True, alpha=0.3)
        
    def _save_summary_csv(self, basic_stats, epitope_stats):
        """Save summary statistics as CSV file"""
        print("  Saving summary statistics to CSV...")
        
        # Prepare summary data
        summary_data = []
        
        for graph_type, stats in basic_stats.items():
            avg_nodes = np.mean(stats['node_counts']['ag']) + np.mean(stats['node_counts']['ab'])
            avg_edges = np.mean(stats['edge_counts']['ag']) + np.mean(stats['edge_counts']['ab'])
            avg_degree = np.mean(stats['degrees']['ag'] + stats['degrees']['ab']) if stats['degrees']['ag'] else 0
            avg_density = np.mean(stats['edge_densities']['ag'] + stats['edge_densities']['ab'])
            avg_epitope_ratio = np.mean(stats['epitope_ratios']) if stats['epitope_ratios'] else 0
            
            # Get epitope degree info if available
            epitope_avg_degree = 0
            non_epitope_avg_degree = 0
            if graph_type in epitope_stats:
                if epitope_stats[graph_type]['epitope_degrees']:
                    epitope_avg_degree = np.mean(epitope_stats[graph_type]['epitope_degrees'])
                if epitope_stats[graph_type]['non_epitope_degrees']:
                    non_epitope_avg_degree = np.mean(epitope_stats[graph_type]['non_epitope_degrees'])
            
            summary_data.append({
                'Graph_Type': graph_type,
                'Avg_Total_Nodes': avg_nodes,
                'Avg_AG_Nodes': np.mean(stats['node_counts']['ag']),
                'Avg_AB_Nodes': np.mean(stats['node_counts']['ab']),
                'Avg_Total_Edges': avg_edges,
                'Avg_AG_Edges': np.mean(stats['edge_counts']['ag']),
                'Avg_AB_Edges': np.mean(stats['edge_counts']['ab']),
                'Avg_Degree': avg_degree,
                'Edge_Density': avg_density,
                'Epitope_Ratio': avg_epitope_ratio,
                'Epitope_Avg_Degree': epitope_avg_degree,
                'Non_Epitope_Avg_Degree': non_epitope_avg_degree,
                'N_Complexes': stats['n_complexes']
            })
        
        # Save as CSV
        df = pd.DataFrame(summary_data)
        csv_path = self.output_dir / 'graph_analysis_summary.csv'
        df.to_csv(csv_path, index=False)
        print(f"    Saved summary statistics to {csv_path}")
        
    def _save_epiformer_csv(self, epiformer_stats):
        """Save epiformer level statistics as CSV file"""
        if not epiformer_stats:
            return
            
        print("  Saving epiformer statistics to CSV...")
        
        epiformer_data = []
        
        for graph_type, stats in epiformer_stats.items():
            for level in ['atom', 'residue', 'edge']:
                if level in stats:
                    level_stats = stats[level]
                    
                    epiformer_data.append({
                        'Graph_Type': graph_type,
                        'Level': level,
                        'Avg_Nodes': np.mean(level_stats['node_counts']) if level_stats['node_counts'] else 0,
                        'Avg_Edges': np.mean(level_stats['edge_counts']) if level_stats['edge_counts'] else 0,
                        'Avg_Degree': np.mean(level_stats['degrees']) if level_stats['degrees'] else 0,
                        'Max_Nodes': np.max(level_stats['node_counts']) if level_stats['node_counts'] else 0,
                        'Max_Edges': np.max(level_stats['edge_counts']) if level_stats['edge_counts'] else 0,
                        'Max_Degree': np.max(level_stats['degrees']) if level_stats['degrees'] else 0
                    })
        
        # Save as CSV
        df = pd.DataFrame(epiformer_data)
        csv_path = self.output_dir / 'epiformer_analysis_summary.csv'
        df.to_csv(csv_path, index=False)
        print(f"    Saved epiformer statistics to {csv_path}")
    
    def generate_report(self, basic_stats: Dict, epitope_stats: Dict, edge_stats: Dict):
        """Generate comprehensive analysis report"""
        print("\nGenerating analysis report...")
        
        report_path = self.output_dir / 'graph_analysis_report.md'
        
        with open(report_path, 'w') as f:
            f.write("# Comprehensive Graph Analysis Report\n\n")
            f.write("Analysis of protein graph representations for epitope prediction.\n\n")
            
            f.write("## Summary\n\n")
            f.write(f"- **Graph types analyzed**: {list(basic_stats.keys())}\n")
            f.write(f"- **Total complexes**: {sum(stats['n_complexes'] for stats in basic_stats.values())}\n")
            f.write(f"- **Analysis timestamp**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            
            f.write("## Key Findings\n\n")
            
            # Basic statistics summary
            f.write("### Connectivity Statistics\n\n")
            for graph_type, stats in basic_stats.items():
                avg_degree = np.mean(stats['degrees']['ag'] + stats['degrees']['ab']) if stats['degrees']['ag'] else 0
                avg_density = np.mean(stats['edge_densities']['ag'] + stats['edge_densities']['ab'])
                f.write(f"- **{graph_type.capitalize()}**: Avg degree = {avg_degree:.2f}, Edge density = {avg_density:.4f}\n")
            
            f.write("\n### Epitope Analysis\n\n")
            for graph_type, stats in epitope_stats.items():
                if stats['epitope_degrees'] and stats['non_epitope_degrees']:
                    epitope_deg = np.mean(stats['epitope_degrees'])
                    non_epitope_deg = np.mean(stats['non_epitope_degrees'])
                    f.write(f"- **{graph_type.capitalize()}**: Epitope avg degree = {epitope_deg:.2f}, Non-epitope = {non_epitope_deg:.2f}\n")
            
            f.write(f"\n## Detailed Results\n\n")
            f.write("See generated plots for comprehensive visualizations:\n")
            f.write("- `comprehensive_analysis.png`: Main comparison plots\n")
            f.write("- `comprehensive_analysis.pdf`: High-quality PDF version\n\n")
            
            f.write("## Conclusions\n\n")
            f.write("1. **Graph representation comparison**: [Analysis shows differences in connectivity patterns]\n")
            f.write("2. **Epitope-specific insights**: [Epitope residues show different degree patterns]\n")
            f.write("3. **Recommendations**: [Based on observed patterns for epitope prediction]\n\n")
        
        print(f"  Report saved to {report_path}")


def main():
    """Command-line entry point for the graph analysis pipeline.

    Steps:
    1. Parse CLI options.
    2. Load the base / simple / gearnet / raad datasets from ``--data_dir``.
    3. Optionally truncate each dataset to ``--max_complexes`` complexes.
    4. Run every analysis and write plots, CSV summaries and a Markdown report
       into ``--output_dir``.

    Returns early, after printing a message, when none of the dataset files
    could be loaded.
    """
    parser = argparse.ArgumentParser(description='Comprehensive graph analysis for epitope prediction')
    parser.add_argument('--data_dir', type=str,
                        default='../../../../data/asep/m3epi/',
                        help='Directory containing graph datasets')
    parser.add_argument('--output_dir', type=str,
                        default='../../../../results/hgraphepi/m3epi/figures/',
                        help='Output directory for analysis results')
    parser.add_argument('--max_complexes', type=int, default=None,
                        help='Maximum number of complexes to analyze per dataset')

    args = parser.parse_args()

    # Dataset paths. --data_dir used to be parsed but ignored in favour of a
    # hard-coded cwd-relative path; its default resolves to that same location.
    data_dir = args.data_dir
    dataset_paths = {
        'base': os.path.join(data_dir, 'base_dataset.pkl'),
        'simple': os.path.join(data_dir, 'simple_epiformer_dataset.pkl'),
        'gearnet': os.path.join(data_dir, 'gearnet_epiformer_dataset.pkl'),
        'raad': os.path.join(data_dir, 'epiformer_dataset_test.pkl')
    }

    # GraphAnalyzer.__init__ calls mkdir(exist_ok=True) without parents=True,
    # so a nested output path (like the default) fails when its parents are
    # missing. Create the whole chain first.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    analyzer = GraphAnalyzer(args.output_dir)

    # Load datasets
    datasets = analyzer.load_datasets(dataset_paths)

    if not datasets:
        print("No datasets found. Please ensure graph datasets are created first.")
        return

    # Limit dataset size if specified
    if args.max_complexes:
        for graph_type in datasets:
            datasets[graph_type] = datasets[graph_type][:args.max_complexes]
            print(f"Limited {graph_type} to {len(datasets[graph_type])} complexes")

    # Perform analyses
    basic_stats = analyzer.compute_basic_statistics(datasets)
    epitope_stats = analyzer.analyze_epitope_patterns(datasets)
    edge_stats = analyzer.compare_edge_types(datasets)
    epiformer_stats = analyzer.analyze_epiformer_levels(datasets)

    # Generate visualizations
    analyzer.create_comprehensive_plots(basic_stats, epitope_stats, edge_stats, epiformer_stats)

    # Generate report
    analyzer.generate_report(basic_stats, epitope_stats, edge_stats)

    print(f"\n✅ Analysis complete! Results saved to: {analyzer.output_dir}")
    print("📊 Key outputs:")
    print("   - comprehensive_analysis.png: Main visualization")
    print("   - comprehensive_analysis.pdf: Publication-quality plots")
    # Epiformer outputs are only produced when epiformer_stats is truthy
    # (see create_comprehensive_plots), so only advertise them then.
    if epiformer_stats:
        print("   - epiformer_analysis.png: Epiformer level analysis")
        print("   - epiformer_analysis.pdf: Epiformer plots PDF")
    print("   - graph_analysis_summary.csv: Summary statistics table")
    if epiformer_stats:
        print("   - epiformer_analysis_summary.csv: Epiformer statistics table")
    print("   - graph_analysis_report.md: Detailed analysis report")


if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported as a module.
    main()