#!/usr/bin/env python3
"""
Hypergraph-specific evaluation framework for DOSAGE algorithm.
Evaluates DOSAGE using metrics appropriate for hypergraph construction.
"""

import networkx as nx
import numpy as np
from subgraphs.top_k import top_k_overlapping_densest_subgraphs
from utils.edge_reader import read_edges_from_file
from hypergraph.hypergraph import graph_to_hypergraph
import time

def calculate_hypergraph_metrics(hyperedges, G):
    """
    Calculate hypergraph-specific metrics.

    Degenerate inputs (no hyperedges, a graph without nodes, hyperedges whose
    nodes are not present in ``G``) yield zero-valued metrics instead of NaN
    values or exceptions.

    Args:
        hyperedges: List of hyperedges (sets of nodes; any iterable of nodes
            is accepted and converted to a set)
        G: Original graph

    Returns:
        Dictionary of hypergraph metrics with the keys 'coverage',
        'avg_hyperedge_size', 'size_variance', 'avg_overlap', 'avg_density',
        'connectivity_ratio' and 'hypergraph_modularity'
    """
    # Normalise once: the overlap computation below relies on set intersection.
    edge_sets = [set(edge) for edge in hyperedges]
    num_hyperedges = len(edge_sets)
    metrics = {}

    # 1. Coverage: fraction of graph nodes included in at least one hyperedge
    covered_nodes = set().union(*edge_sets)
    total_nodes = G.number_of_nodes()
    metrics['coverage'] = len(covered_nodes) / total_nodes if total_nodes else 0.0

    # 2./3. Mean and variance of hyperedge sizes.
    # np.mean/np.var of an empty list give NaN plus a RuntimeWarning, so the
    # empty case is reported as 0.0 like the other metrics.
    sizes = [len(edge) for edge in edge_sets]
    metrics['avg_hyperedge_size'] = np.mean(sizes) if sizes else 0.0
    metrics['size_variance'] = np.var(sizes) if sizes else 0.0

    # 4. Overlap ratio: average number of hyperedges per covered node
    membership_counts = {}
    for edge in edge_sets:
        for node in edge:
            membership_counts[node] = membership_counts.get(node, 0) + 1
    metrics['avg_overlap'] = (
        np.mean(list(membership_counts.values())) if membership_counts else 0
    )

    # 5./6. Density (edges per node of the induced subgraph) and connectivity.
    # Both need the same induced subgraph, so it is built once per hyperedge.
    # Hyperedges with fewer than two nodes contribute nothing to either sum
    # but still count in the denominators.
    total_density = 0
    connected_count = 0
    for edge in edge_sets:
        if len(edge) <= 1:
            continue
        subgraph = G.subgraph(edge)
        induced_nodes = subgraph.number_of_nodes()
        if induced_nodes == 0:
            # None of the hyperedge's nodes exist in G: nothing to divide by,
            # and nx.is_connected is undefined on a null graph.
            continue
        total_density += subgraph.number_of_edges() / induced_nodes
        if nx.is_connected(subgraph):
            connected_count += 1

    metrics['avg_density'] = total_density / num_hyperedges if num_hyperedges else 0
    metrics['connectivity_ratio'] = (
        connected_count / num_hyperedges if num_hyperedges else 0
    )

    # 7. Modularity of the hyperedge-overlap graph: one node per hyperedge,
    # an edge between two hyperedges weighted by the number of shared nodes.
    modularity = 0.0
    if num_hyperedges > 1:
        overlap_graph = nx.Graph()
        # The graph is undirected, so each unordered pair is visited once.
        for i, first in enumerate(edge_sets):
            for j in range(i + 1, num_hyperedges):
                shared = len(first & edge_sets[j])
                if shared > 0:
                    overlap_graph.add_edge(i, j, weight=shared)

        if overlap_graph.number_of_edges() > 0:
            # NOTE(review): the partition passed here is a single community
            # holding every node of the overlap graph, for which modularity is
            # expected to be ~0 whatever the structure. Confirm whether a real
            # community partition was intended.
            modularity = nx.community.modularity(
                overlap_graph, [list(overlap_graph.nodes())]
            )
    metrics['hypergraph_modularity'] = modularity

    return metrics

def evaluate_dosage_hypergraph(G, lambda_param, k, min_size, max_size):
    """
    Evaluate DOSAGE algorithm using hypergraph-specific metrics.

    Runs the top-k overlapping densest subgraph search with ``k_hop=2``,
    treats the node set of every returned subgraph as one hyperedge and
    scores the resulting hyperedges against ``G``.

    Args:
        G: NetworkX graph
        lambda_param: Trade-off parameter
        k: Number of subgraphs to find
        min_size: Minimum subgraph size
        max_size: Maximum subgraph size

    Returns:
        Dictionary with evaluation results: the metrics produced by
        calculate_hypergraph_metrics plus 'runtime' (seconds spent in the
        subgraph search only), 'num_hyperedges', 'subgraphs' and 'hyperedges'
    """
    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be distorted by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    # Run DOSAGE
    subgraphs = top_k_overlapping_densest_subgraphs(
        G, k, lambda_param, min_size, max_size, k_hop=2
    )

    runtime = time.perf_counter() - start_time

    # Convert subgraphs to hyperedges (one node set per subgraph)
    hyperedges = [set(sg.nodes()) for sg in subgraphs]

    # Calculate hypergraph metrics
    metrics = calculate_hypergraph_metrics(hyperedges, G)

    # Add runtime and basic info alongside the computed metrics
    metrics['runtime'] = runtime
    metrics['num_hyperedges'] = len(hyperedges)
    metrics['subgraphs'] = subgraphs
    metrics['hyperedges'] = hyperedges

    return metrics

def compare_dosage_configurations(G, lambda_values, k_values, min_size, max_size):
    """
    Sweep DOSAGE over every (k, lambda) combination and collect the results.

    Each combination is evaluated independently; a failing combination is
    reported on stdout and recorded as ``{'error': <message>}`` so the sweep
    continues with the remaining combinations.

    Args:
        G: NetworkX graph
        lambda_values: List of lambda values to test
        k_values: List of k values to test
        min_size: Minimum subgraph size
        max_size: Maximum subgraph size

    Returns:
        Dictionary mapping a configuration label to its metrics dictionary
        (or to an error record), in sweep order with k as the outer loop
    """
    outcome = {}

    # k varies slowest, lambda fastest.
    grid = ((k, lam) for k in k_values for lam in lambda_values)

    for k, lambda_param in grid:
        label = f"DOSAGE_k={k}_λ={lambda_param}"
        print(f"Testing {label}...")

        try:
            report = evaluate_dosage_hypergraph(G, lambda_param, k, min_size, max_size)
            outcome[label] = report

            # Short per-configuration summary
            print(f"  Coverage: {report['coverage']:.3f}")
            print(f"  Avg hyperedge size: {report['avg_hyperedge_size']:.2f}")
            print(f"  Avg overlap: {report['avg_overlap']:.2f}")
            print(f"  Connectivity: {report['connectivity_ratio']:.3f}")
            print(f"  Runtime: {report['runtime']:.3f}s")
        except Exception as exc:
            # Best-effort sweep: keep going and remember what went wrong.
            print(f"  Error: {exc}")
            outcome[label] = {'error': str(exc)}

    return outcome

def analyze_karate_club_hypergraph():
    """Sweep DOSAGE configurations on the Karate Club graph and report the best ones.

    Reads the edge list from 'data/edges-karate.txt', evaluates every
    combination of the hard-coded lambda and k values, prints the winning
    configuration for each headline metric and returns the full results
    dictionary produced by compare_dosage_configurations.
    """
    print("HYPERGRAPH EVALUATION: Karate Club Dataset")
    print("=" * 60)

    # Build the graph from the edge list on disk
    G = nx.Graph()
    G.add_edges_from(read_edges_from_file('data/edges-karate.txt'))

    print(f"Graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Parameter sweep
    results = compare_dosage_configurations(
        G,
        lambda_values=[0.5, 1.0, 2.0, 5.0],
        k_values=[2, 3, 4],
        min_size=5,
        max_size=25,
    )

    # Report the top configuration per metric
    print(f"\n{'='*60}")
    print("BEST CONFIGURATIONS BY METRIC:")
    print(f"{'='*60}")

    for metric in ('coverage', 'avg_density', 'connectivity_ratio', 'hypergraph_modularity'):
        # Only configurations that ran successfully and expose this metric
        scored = [
            (name, outcome[metric])
            for name, outcome in results.items()
            if 'error' not in outcome and metric in outcome
        ]
        if not scored:
            continue
        # max() keeps the first entry on ties, i.e. the earliest configuration
        best_config, best_value = max(scored, key=lambda pair: pair[1])
        print(f"Best {metric}: {best_config} = {best_value:.4f}")

    return results

def analyze_simple_communities_hypergraph():
    """Analyze DOSAGE on Simple Communities using hypergraph metrics.

    Builds a synthetic graph of three 10-node communities (nodes 0-9, 10-19
    and 20-29) in which every intra-community pair is linked with probability
    0.6, joins the communities with three fixed bridge edges, then sweeps the
    hard-coded lambda and k values and prints the best configuration for each
    headline metric.

    Edges are sampled from NumPy's global random state, so the graph is only
    reproducible if the caller seeds ``np.random`` beforehand.

    Returns:
        Dictionary with the results of compare_dosage_configurations
    """
    print("\n\nHYPERGRAPH EVALUATION: Simple Communities Dataset")
    print("=" * 60)

    num_communities = 3
    community_size = 10
    intra_edge_probability = 0.6

    # Create graph
    G = nx.Graph()

    # Communities are consecutive node ranges; pairs are sampled in the order
    # community -> i -> j so the sequence of random draws is well defined.
    for community in range(num_communities):
        first = community * community_size
        last = first + community_size
        for i in range(first, last):
            for j in range(i + 1, last):
                if np.random.random() < intra_edge_probability:
                    G.add_edge(i, j)

    # Add a few inter-community edges
    G.add_edge(5, 15)
    G.add_edge(15, 25)
    G.add_edge(8, 18)

    # A node that received no sampled edge would otherwise be missing from G
    # and shrink the denominator of the coverage metric. Adding after the
    # edges leaves the order of already-present nodes untouched.
    G.add_nodes_from(range(num_communities * community_size))

    print(f"Graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Test different configurations
    lambda_values = [0.5, 1.0, 2.0, 5.0]
    k_values = [2, 3, 4]
    min_size = 5
    max_size = 20

    results = compare_dosage_configurations(G, lambda_values, k_values, min_size, max_size)

    # Find best configuration for each metric
    print(f"\n{'='*60}")
    print("BEST CONFIGURATIONS BY METRIC:")
    print(f"{'='*60}")

    metrics_to_optimize = ['coverage', 'avg_density', 'connectivity_ratio', 'hypergraph_modularity']

    for metric in metrics_to_optimize:
        # Skip configurations that failed or do not report this metric
        valid_results = {
            name: outcome
            for name, outcome in results.items()
            if 'error' not in outcome and metric in outcome
        }
        if valid_results:
            best_config = max(valid_results, key=lambda name: valid_results[name][metric])
            best_value = valid_results[best_config][metric]
            print(f"Best {metric}: {best_config} = {best_value:.4f}")

    return results

if __name__ == "__main__":
    # Fix the NumPy seed before any analysis runs; the synthetic communities
    # graph draws its edges from np.random.
    np.random.seed(42)

    karate_results = analyze_karate_club_hypergraph()
    simple_results = analyze_simple_communities_hypergraph()

    # Closing banner followed by a legend of the reported metrics
    print(f"\n{'='*80}")
    print("HYPERGRAPH EVALUATION COMPLETE")
    print(f"{'='*80}")
    print("\n".join((
        "DOSAGE is now evaluated using hypergraph-specific metrics:",
        "• Coverage: How many nodes are included in hyperedges",
        "• Average hyperedge size: Typical size of hyperedges",
        "• Overlap: How much nodes participate in multiple hyperedges",
        "• Density: How well hyperedges preserve graph density",
        "• Connectivity: How many hyperedges are connected subgraphs",
        "• Modularity: Structure quality of the hypergraph",
    )))
