#!/usr/bin/env python3
"""
Visual Comparison of DOSAGE vs Other Community Detection Methods
"""

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from subgraphs.top_k import top_k_overlapping_densest_subgraphs
from utils.edge_reader import read_edges_from_file
import time
import community
from sklearn.cluster import SpectralClustering

def louvain_method(G):
    """Louvain method for community detection.

    Args:
        G: networkx graph handed to ``community.best_partition``.

    Returns:
        list of node sets, one per community label in the partition returned
        by ``community.best_partition``. If that call fails, the error is
        printed and the whole node set is returned as a single community so
        the comparison can continue.
    """
    try:
        partition = community.best_partition(G)
    except Exception as e:
        # Best-effort fallback, but surface the failure instead of hiding it;
        # catching Exception (not a bare except) lets Ctrl-C still interrupt.
        print(f"Louvain error: {e}")
        return [set(G.nodes())]

    communities = {}
    for node, comm_id in partition.items():
        communities.setdefault(comm_id, set()).add(node)
    return list(communities.values())

def girvan_newman_method(G, k=2):
    """Girvan-Newman algorithm, stopped at the first level with >= k communities.

    ``nx.community.girvan_newman`` yields successively finer partitions, and
    each tuple has more communities than the one before. On a connected graph
    the first yielded tuple already holds 2 communities, so it must not be
    indexed as ``levels[k - 1]``. The generator is consumed lazily, so edge
    removal stops as soon as the wanted level is reached instead of running
    the full dendrogram.

    Args:
        G: undirected networkx graph.
        k: desired number of communities.

    Returns:
        list of node sets:

        * the connected components, if ``G`` already has at least ``k``;
        * otherwise the first Girvan-Newman level with at least ``k``
          communities;
        * the whole node set as a single community if no level reaches ``k``
          or the algorithm fails (the error is printed).
    """
    try:
        components = [set(c) for c in nx.connected_components(G)]
        if len(components) >= k:
            return components
        for level in nx.community.girvan_newman(G):
            if len(level) >= k:
                return [set(c) for c in level]
        return [set(G.nodes())]
    except Exception as e:
        print(f"Girvan-Newman error: {e}")
        return [set(G.nodes())]

def label_propagation_method(G):
    """Label propagation algorithm.

    Args:
        G: networkx graph.

    Returns:
        list of node sets from ``nx.community.label_propagation_communities``.
        On failure the error is printed and the whole node set is returned as
        a single community.
    """
    try:
        return [set(c) for c in nx.community.label_propagation_communities(G)]
    except Exception as e:
        # Keep the comparison running, but do not hide the failure.
        print(f"Label propagation error: {e}")
        return [set(G.nodes())]

def greedy_modularity_method(G):
    """Greedy modularity optimization.

    Args:
        G: networkx graph.

    Returns:
        list of node sets from ``nx.community.greedy_modularity_communities``.
        The result is converted from frozensets to plain sets so it has the
        same type as the other methods' results. On failure the error is
        printed and the whole node set is returned as a single community.
    """
    try:
        return [set(c) for c in nx.community.greedy_modularity_communities(G)]
    except Exception as e:
        # Keep the comparison running, but do not hide the failure.
        print(f"Greedy modularity error: {e}")
        return [set(G.nodes())]

def dosage_method(G, lambda_param=1.0, k=2, *, min_subset_size=5,
                  max_subset_cap=50, k_hop=2):
    """DOSAGE method: top-k overlapping densest subgraphs as communities.

    Args:
        G: networkx graph.
        lambda_param: forwarded unchanged to
            ``top_k_overlapping_densest_subgraphs``.
        k: number of subgraphs requested.
        min_subset_size: smallest subgraph size passed to the search.
        max_subset_cap: cap on the largest subgraph size. The effective
            maximum is ``min(max_subset_cap, n // 2)``, but never below
            ``min_subset_size``.
        k_hop: forwarded unchanged to the search.

    Returns:
        list of node sets, one per returned subgraph. On any exception the
        error is printed and the whole node set is returned as a single
        community.
    """
    try:
        # Half the graph size is meant to keep the subgraphs diverse. Clamp to
        # min_subset_size: on small graphs (n < 2 * min_subset_size) the
        # unclamped max would fall below the min, leaving no valid sizes.
        max_size = max(min_subset_size,
                       min(max_subset_cap, G.number_of_nodes() // 2))
        subgraphs = top_k_overlapping_densest_subgraphs(
            G, k, lambda_param, min_subset_size=min_subset_size,
            max_subset_size=max_size, k_hop=k_hop
        )
        return [set(sg.nodes()) for sg in subgraphs]
    except Exception as e:
        print(f"DOSAGE error: {e}")
        return [set(G.nodes())]

def evaluate_communities(predicted_communities, ground_truth):
    """Evaluate community detection quality against a ground-truth partition.

    Overlapping predictions are flattened to a hard assignment: each node
    keeps the index of the first predicted community containing it. Some
    nodes appear on only one side (predicted or ground truth). Such a node
    gets the label ``-1`` on the side where it is missing, so all of them are
    scored as one shared extra cluster.

    Args:
        predicted_communities: sequence of node collections (sets, frozensets,
            lists) as returned by the detection methods in this file.
        ground_truth: mapping of community id -> collection of member nodes.
            If a node is listed under several ids, the last one wins.

    Returns:
        dict with these keys:

        * ``'ari'`` and ``'nmi'``: 0.0 if sklearn rejects the labelings.
        * ``'coverage'``: fraction of scored nodes with a predicted label.
        * ``'num_communities'``: number of predicted communities.
    """
    predicted_labels = {}
    for index, members in enumerate(predicted_communities):
        for node in members:
            # setdefault keeps the first assignment for overlapping nodes
            predicted_labels.setdefault(node, index)

    true_labels = {
        node: comm_id
        for comm_id, members in ground_truth.items()
        for node in members
    }

    all_nodes = predicted_labels.keys() | true_labels.keys()

    # ARI/NMI do not depend on sample order; sorting only makes the arrays
    # deterministic. Mixed node types (e.g. str ids read from a file next to
    # int ground-truth ids) cannot be compared, so fall back to repr order
    # instead of raising.
    try:
        ordered_nodes = sorted(all_nodes)
    except TypeError:
        ordered_nodes = sorted(all_nodes, key=repr)

    pred_array = [predicted_labels.get(node, -1) for node in ordered_nodes]
    true_array = [true_labels.get(node, -1) for node in ordered_nodes]

    try:
        ari = adjusted_rand_score(true_array, pred_array)
        nmi = normalized_mutual_info_score(true_array, pred_array)
    except (ValueError, TypeError) as e:
        # Report a zero score rather than aborting the whole comparison.
        print(f"Metric computation failed: {e}")
        ari = 0.0
        nmi = 0.0

    assigned_nodes = sum(1 for label in pred_array if label != -1)
    coverage = assigned_nodes / len(pred_array) if pred_array else 0.0

    return {
        'ari': ari,
        'nmi': nmi,
        'coverage': coverage,
        'num_communities': len(predicted_communities)
    }

def create_datasets():
    """Build the benchmark graphs together with their ground-truth partitions.

    Returns:
        dict mapping dataset name -> (graph, ground_truth), where ground_truth
        maps a community id to the set of member nodes. The synthetic graph is
        drawn from numpy's global RNG, so seed it beforehand for
        reproducibility.
    """
    # Karate Club graph, read from the edge list on disk.
    karate_graph = nx.Graph()
    karate_graph.add_edges_from(read_edges_from_file('data/edges-karate.txt'))
    # Ground truth uses node ids 1..34; presumably these match the ids in the
    # edge file — verify against the data.
    # NOTE(review): networkx's karate_club_graph 'club' attribute puts node 10
    # (1-indexed) in the second faction; confirm which labeling is intended.
    karate_truth = {
        0: {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, 18, 20, 22},
        1: {15, 16, 19, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34},
    }

    # Synthetic graph: three blocks of 10 consecutive nodes. Within each block
    # every pair is linked with probability 0.6, and the blocks are visited in
    # order so the RNG is consumed in a fixed sequence.
    block_size = 10
    block_starts = (0, 10, 20)
    planted_graph = nx.Graph()
    for start in block_starts:
        stop = start + block_size
        for u in range(start, stop):
            for v in range(u + 1, stop):
                if np.random.random() < 0.6:
                    planted_graph.add_edge(u, v)

    # A handful of bridges between the blocks.
    planted_graph.add_edges_from([(5, 15), (15, 25), (8, 18)])

    planted_truth = {
        block_id: set(range(start, start + block_size))
        for block_id, start in enumerate(block_starts)
    }

    return {
        'Karate Club': (karate_graph, karate_truth),
        'Simple Communities': (planted_graph, planted_truth),
    }

def run_comparison(dataset_name, G, ground_truth):
    """Run every detection method on ``G`` and score it against ``ground_truth``.

    Args:
        dataset_name: dataset label; selects the DOSAGE lambda sweep.
        G: networkx graph to partition.
        ground_truth: mapping of community id -> set of member nodes.

    Returns:
        dict mapping method display name -> dict with these keys:

        * ``'communities'``: list of node sets.
        * ``'metrics'``: as returned by ``evaluate_communities``.
        * ``'runtime'``: wall-clock seconds spent in the method.

        A method that raises gets an error message, the whole node set as a
        single community, zero ARI/NMI, and a complete metrics dict
        (including ``num_communities``).
    """
    methods = {
        'Louvain': lambda: louvain_method(G),
        'Girvan-Newman': lambda: girvan_newman_method(G, k=len(ground_truth)),
        'Label Propagation': lambda: label_propagation_method(G),
        'Greedy Modularity': lambda: greedy_modularity_method(G),
    }

    # Add DOSAGE with dataset-specific lambda values
    if dataset_name == 'Simple Communities':
        # Simple Communities needs higher lambda for distinct communities
        lambda_values = [0.5, 1.0, 2.0, 5.0, 10.0]
    else:
        # Karate Club works well with lower lambda values
        lambda_values = [0.5, 1.0, 2.0, 5.0]
    # Ask DOSAGE for as many subgraphs as there are ground-truth communities.
    k_value = len(ground_truth)

    for lambda_param in lambda_values:
        # Bind lambda_param as a default argument. A plain closure would read
        # the loop variable when called, after the loop has finished, so every
        # DOSAGE entry would run with the last lambda value.
        methods[f'DOSAGE (λ={lambda_param})'] = (
            lambda lp=lambda_param: dosage_method(G, lp, k_value)
        )

    results = {}

    for method_name, method_func in methods.items():
        try:
            start_time = time.perf_counter()  # monotonic, high resolution
            communities = method_func()
            runtime = time.perf_counter() - start_time
            metrics = evaluate_communities(communities, ground_truth)
        except Exception as e:
            print(f"Error in {method_name}: {e}")
            fallback = [set(G.nodes())]
            results[method_name] = {
                'communities': fallback,
                # Same keys as evaluate_communities so plotting never KeyErrors.
                'metrics': {'ari': 0.0, 'nmi': 0.0, 'coverage': 1.0,
                            'num_communities': len(fallback)},
                'runtime': 0.0
            }
        else:
            results[method_name] = {
                'communities': communities,
                'metrics': metrics,
                'runtime': runtime
            }

    return results

def _draw_bar_panel(ax, method_names, values, colors, title, ylabel,
                    label_fmt, label_offset, ylim=None):
    """Draw a per-method bar chart on ``ax`` and print each value above its bar."""
    positions = range(len(method_names))
    bars = ax.bar(positions, values, color=colors, alpha=0.7)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xticks(positions)
    ax.set_xticklabels(method_names, rotation=45, ha='right', fontsize=9)
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(*ylim)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                label_fmt.format(value), ha='center', va='bottom',
                fontsize=8, fontweight='bold')


def _best_by_ari(method_names, results):
    """Return the method with the highest ARI, or None if ``method_names`` is empty."""
    if not method_names:
        return None
    return max(method_names, key=lambda m: results[m]['metrics']['ari'])


def create_comparison_plots(dataset_name, results):
    """Create a 2x3 figure comparing all methods and save it as a PNG.

    Panels:

    * ARI, NMI, runtime and number of communities as bar charts
      (DOSAGE variants in red, the others in blue);
    * an ARI-vs-runtime scatter plot;
    * a text summary of the best traditional and best DOSAGE method by ARI.

    Args:
        dataset_name: used in the title and in the output file name
            ``visual_comparison_<name>.png``.
        results: mapping as returned by ``run_comparison``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'DOSAGE vs Traditional Methods: {dataset_name}', fontsize=16, fontweight='bold')

    # Extract data
    method_names = list(results.keys())
    ari_values = [results[m]['metrics']['ari'] for m in method_names]
    nmi_values = [results[m]['metrics']['nmi'] for m in method_names]
    runtime_values = [results[m]['runtime'] for m in method_names]
    # Fallback entries may lack 'num_communities'; derive it from the communities.
    num_communities = [
        results[m]['metrics'].get('num_communities', len(results[m]['communities']))
        for m in method_names
    ]

    # Color coding: DOSAGE variants red, everything else blue
    colors = ['red' if 'DOSAGE' in m else 'blue' for m in method_names]

    # ARI ranges over [-0.5, 1]; extend the axis downward so negative scores stay visible.
    lowest_ari = min(ari_values, default=0.0)
    ari_ylim = (lowest_ari - 0.1 if lowest_ari < 0 else 0, 1.1)

    _draw_bar_panel(axes[0, 0], method_names, ari_values, colors,
                    'Community Detection Quality (ARI)', 'Adjusted Rand Index',
                    '{:.3f}', 0.02, ylim=ari_ylim)
    _draw_bar_panel(axes[0, 1], method_names, nmi_values, colors,
                    'Community Detection Quality (NMI)', 'Normalized Mutual Information',
                    '{:.3f}', 0.02, ylim=(0, 1.1))
    _draw_bar_panel(axes[0, 2], method_names, runtime_values, colors,
                    'Algorithm Runtime', 'Runtime (seconds)',
                    '{:.3f}s', 0.001)
    _draw_bar_panel(axes[1, 0], method_names, num_communities, colors,
                    'Number of Communities Found', 'Number of Communities',
                    '{}', 0.1)

    # Plot 5: ARI vs Runtime Scatter
    ax5 = axes[1, 1]
    traditional_methods = [m for m in method_names if 'DOSAGE' not in m]
    dosage_methods = [m for m in method_names if 'DOSAGE' in m]

    ax5.scatter([results[m]['runtime'] for m in traditional_methods],
                [results[m]['metrics']['ari'] for m in traditional_methods],
                c='blue', s=100, alpha=0.7, label='Traditional Methods')
    ax5.scatter([results[m]['runtime'] for m in dosage_methods],
                [results[m]['metrics']['ari'] for m in dosage_methods],
                c='red', s=100, alpha=0.7, label='DOSAGE Methods')
    ax5.set_title('Quality vs Speed Trade-off', fontweight='bold')
    ax5.set_xlabel('Runtime (seconds)')
    ax5.set_ylabel('Adjusted Rand Index')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    # Plot 6: Summary Statistics
    ax6 = axes[1, 2]
    ax6.axis('off')

    best_traditional = _best_by_ari(traditional_methods, results)
    best_dosage = _best_by_ari(dosage_methods, results)

    if best_traditional is not None and best_dosage is not None:
        summary_text = f"""
    SUMMARY STATISTICS
    
    Best Traditional Method:
    • {best_traditional}
    • ARI: {results[best_traditional]['metrics']['ari']:.4f}
    • NMI: {results[best_traditional]['metrics']['nmi']:.4f}
    • Runtime: {results[best_traditional]['runtime']:.4f}s
    
    Best DOSAGE Method:
    • {best_dosage}
    • ARI: {results[best_dosage]['metrics']['ari']:.4f}
    • NMI: {results[best_dosage]['metrics']['nmi']:.4f}
    • Runtime: {results[best_dosage]['runtime']:.4f}s
    
    Performance Gap:
    • ARI Difference: {results[best_traditional]['metrics']['ari'] - results[best_dosage]['metrics']['ari']:.4f}
    • NMI Difference: {results[best_traditional]['metrics']['nmi'] - results[best_dosage]['metrics']['nmi']:.4f}
    """
    else:
        # max() over an empty group would raise; show a placeholder instead.
        summary_text = """
    SUMMARY STATISTICS
    
    Not enough methods to compare.
    """

    ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes, fontsize=10,
             verticalalignment='top', fontfamily='monospace', fontweight='bold')

    plt.tight_layout()
    plt.savefig(f'visual_comparison_{dataset_name.lower().replace(" ", "_")}.png',
                dpi=300, bbox_inches='tight')
    plt.show()
    # With a non-interactive backend show() returns without closing the figure,
    # so close it explicitly to avoid accumulating figures across datasets.
    plt.close(fig)

def main():
    """Entry point: build each dataset, compare all methods on it, and plot the results."""
    # Seed numpy's global RNG so the randomly generated graph is reproducible.
    np.random.seed(42)

    rule = "=" * 60
    print("Running Visual Comparison of DOSAGE vs Traditional Methods")
    print(rule)

    for name, (graph, truth) in create_datasets().items():
        print(f"\nProcessing {name}...")
        print(f"Graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        print(f"Ground truth: {len(truth)} communities")

        comparison = run_comparison(name, graph, truth)
        create_comparison_plots(name, comparison)

        slug = name.lower().replace(' ', '_')
        print(f"Visual comparison saved as: visual_comparison_{slug}.png")

    print("\n" + rule)
    print("Visual comparison complete!")
    print("Generated plots show:")
    for bullet in (
        "• Community detection quality (ARI & NMI)",
        "• Algorithm runtime",
        "• Number of communities found",
        "• Quality vs speed trade-off",
        "• Summary statistics",
    ):
        print(bullet)

if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
