#!/usr/bin/env python3
"""
Centrality Analysis Tool for Network-Constrained Truth Recovery
Extracts agent scores, centralities, and their relationships per run.
"""

import random
import argparse
import csv
import networkx as nx
from collections import defaultdict
import time
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path

def generate_dummy_universe(rng, size=20):
    """Build a synthetic universe of (fact, negation) pairs without LLM calls.

    Args:
        rng: Random generator. Accepted for signature parity with the other
            setup helpers; this function never consults it.
        size: Number of pairs to produce.

    Returns:
        A list of ``(fact, negation)`` string tuples numbered 1..size, where
        the negation is the fact wrapped in ``NOT(...)``.
    """
    statements = (f"F{idx} relates to G{idx}" for idx in range(1, size + 1))
    return [(statement, f"NOT({statement})") for statement in statements]

def sample_true_set(pairs, rng):
    """Choose one statement out of every (fact, negation) pair.

    Args:
        pairs: Iterable of ``(fact, negation)`` tuples.
        rng: Object providing ``choice(sequence)``; it is called once per
            pair, in iteration order, with ``[fact, negation]``.

    Returns:
        The set of chosen statements.
    """
    chosen = set()
    for fact, neg in pairs:
        chosen.add(rng.choice([fact, neg]))
    return chosen

def assign_agent_knowledge(G, T, pairs, rng, a_true=5, b_false=2):
    """Give every agent a random mix of true and false statements.

    Args:
        G: Graph-like object; only ``G.nodes()`` is used.
        T: Collection of true statements.
        pairs: Iterable of ``(fact, negation)`` tuples forming the universe.
        rng: Object providing ``sample(sequence, k)`` (e.g. ``random.Random``).
        a_true: True statements per agent, capped at the size of ``T``.
        b_false: False statements per agent, capped at the false-pool size.

    Returns:
        Dict mapping each node to the set of statements it starts with.
    """
    truth_set = set(T)
    universe = {statement for pair in pairs for statement in pair}

    # Sample from *sorted* pools. Iteration order of a set of strings depends
    # on per-process hash randomization, so sampling from list(set) would
    # produce different assignments on every run even with a fixed seed.
    true_pool = sorted(truth_set)
    false_pool = sorted(universe - truth_set)
    n_true = min(a_true, len(true_pool))
    n_false = min(b_false, len(false_pool))

    assignments = {}
    for node in G.nodes():
        truths = rng.sample(true_pool, n_true)
        falses = rng.sample(false_pool, n_false)
        assignments[node] = set(truths + falses)
    return assignments

def resolve_contradictions_simple(agent_store, agent_scores):
    """Resolve fact/negation conflicts in each agent's store by score.

    Statements are grouped by base proposition: ``X`` and ``NOT(X)`` share the
    base ``X``. When an agent holds both sides, the side with the strictly
    higher accumulated score is kept and a tie keeps the positive fact
    (arbitrary but consistent). A statement whose counterpart is absent is
    always kept.

    Presence is decided by what the agent actually holds, not by whether the
    score is positive. A held statement with score 0 (for example initial
    knowledge that no neighbour ever echoed) therefore stays in the resolved
    store and still takes part in contradiction detection.

    Args:
        agent_store: Mapping agent id -> iterable of statements held.
        agent_scores: Mapping agent id -> mapping statement -> numeric score.
            Missing agents or statements count as 0. The mapping is only read
            via ``.get`` and is never modified.

    Returns:
        Tuple ``(resolved_store, total_contradictions)``: a dict mapping agent
        id -> set of statements after resolution, and the number of
        fact/negation conflicts found across all agents.
    """
    resolved_store = {}
    total_contradictions = 0

    for agent_id, facts in agent_store.items():
        scores = agent_scores.get(agent_id, {})

        # base proposition -> [positive statement or None, negation or None]
        held = {}
        for statement in facts:
            if statement.startswith('NOT(') and statement.endswith(')'):
                # Strip the leading "NOT(" and trailing ")" to get the base.
                held.setdefault(statement[4:-1], [None, None])[1] = statement
            else:
                held.setdefault(statement, [None, None])[0] = statement

        resolved_facts = set()
        for positive, negation in held.values():
            if positive is None:
                resolved_facts.add(negation)
            elif negation is None:
                resolved_facts.add(positive)
            else:
                total_contradictions += 1
                # Negation wins only on a strict majority; ties keep the fact.
                if scores.get(negation, 0) > scores.get(positive, 0):
                    resolved_facts.add(negation)
                else:
                    resolved_facts.add(positive)

        resolved_store[agent_id] = resolved_facts

    return resolved_store, total_contradictions

def select_top(score_dict, k):
    """Return the ``k`` highest-scoring facts.

    Facts are ordered by descending score; equal scores are ordered by the
    fact itself so the result is deterministic.
    """
    ranked = sorted(score_dict, key=lambda fact: (-score_dict[fact], fact))
    return ranked[:k]

def has_full_recovery(agent_facts, T):
    """Report whether ``agent_facts`` contains every statement in ``T``.

    Only containment is tested: statements outside ``T`` do not prevent a
    True result.
    """
    held = set(agent_facts)
    return T <= held

def compute_centralities(G):
    """Calculate four centrality measures for ``G`` in one pass.

    Progress and elapsed time are printed. If eigenvector power iteration
    does not converge within 1000 iterations, a warning is printed and the
    degree centrality mapping is reused under the 'eigenvector' key.

    Returns:
        Dict with keys 'degree', 'betweenness', 'closeness' and
        'eigenvector', each holding the mapping returned by the matching
        networkx centrality function.
    """
    print("🔍 Computing centralities...")
    began = time.time()

    degree = nx.degree_centrality(G)
    betweenness = nx.betweenness_centrality(G)
    closeness = nx.closeness_centrality(G)

    try:
        eigenvector = nx.eigenvector_centrality(G, max_iter=1000)
    except nx.PowerIterationFailedConvergence:
        print("⚠️  Eigenvector centrality failed to converge, using degree as fallback")
        eigenvector = degree

    print(f"✅ Centralities computed in {time.time() - began:.1f}s")

    return {
        'degree': degree,
        'betweenness': betweenness,
        'closeness': closeness,
        'eigenvector': eigenvector,
    }

def analyze_agent_performance(G, T, resolved_store, centralities):
    """Score every agent and attach its centrality values.

    An agent's score is the number of held statements that are in ``T`` minus
    the number that are not.

    Args:
        G: Graph-like object; only ``G.nodes()`` is used.
        T: Set of true statements.
        resolved_store: Mapping node -> set of statements held.
        centralities: Mapping with 'degree', 'betweenness', 'closeness' and
            'eigenvector' entries, each a node -> value mapping.

    Returns:
        List of per-agent dicts sorted by score, highest first. Agents with
        equal scores keep their ``G.nodes()`` order.
    """
    def build_row(node):
        beliefs = resolved_store[node]
        n_right = len(beliefs & T)
        n_wrong = len(beliefs - T)
        return {
            'agent_id': node,
            'score': n_right - n_wrong,
            'correct': n_right,
            'incorrect': n_wrong,
            'degree': centralities['degree'][node],
            'betweenness': centralities['betweenness'][node],
            'closeness': centralities['closeness'][node],
            'eigenvector': centralities['eigenvector'][node],
        }

    rows = [build_row(node) for node in G.nodes()]
    return sorted(rows, key=lambda row: row['score'], reverse=True)

def print_top_k_table(agent_scores, k=20):
    """Print a fixed-width table of the first ``k`` entries of ``agent_scores``.

    The caller is expected to pass the list already sorted by score; this
    function only slices and formats it.
    """
    rule_width = 100
    shown = agent_scores[:k]

    print(f"\n📊 Top-{min(k, len(agent_scores))} Agents by Score:")
    print("=" * rule_width)
    print(f"{'Agent':>8} {'Score':>6} {'Correct':>8} {'Incorrect':>10} {'Degree':>8} {'Between':>8} {'Close':>8} {'Eigen':>8}")
    print("-" * rule_width)

    for row in shown:
        counts = f"{row['agent_id']:>8} {row['score']:>6} {row['correct']:>8} {row['incorrect']:>10} "
        measures = f"{row['degree']:>8.4f} {row['betweenness']:>8.4f} {row['closeness']:>8.4f} {row['eigenvector']:>8.4f}"
        print(counts + measures)

def print_winner_line(winner_agent, convergence_round):
    """Print the winning agent's id, round and centralities, or a no-winner note.

    Args:
        winner_agent: Per-agent dict with 'agent_id' and the four centrality
            keys, or a falsy value when there is no winner.
        convergence_round: Round number shown next to the winner.
    """
    if not winner_agent:
        print("\n❌ No winner (no convergence)")
        return

    print(f"\n🏆 Winner: Agent {winner_agent['agent_id']}, Round {convergence_round}")
    print(f"   Centralities: degree={winner_agent['degree']:.4f}, betweenness={winner_agent['betweenness']:.4f}, "
          f"closeness={winner_agent['closeness']:.4f}, eigenvector={winner_agent['eigenvector']:.4f}")

def _pooled_std(first, second):
    """Return the pooled sample standard deviation of two value groups.

    A group with fewer than two values has no sample variance and carries a
    weight of ``n - 1 == 0`` in the pooled formula. It is skipped rather than
    passed to ``np.var(..., ddof=1)``, which would return NaN and poison the
    result. Returns 0.0 when the groups together leave no degrees of freedom.
    """
    dof = len(first) + len(second) - 2
    if dof <= 0:
        return 0.0
    weighted_var = sum(
        (len(values) - 1) * np.var(values, ddof=1)
        for values in (first, second)
        if len(values) > 1
    )
    return float(np.sqrt(weighted_var / dof))


def compute_group_statistics(agent_scores):
    """Compare centralities of the top-decile agents with everyone else.

    ``agent_scores`` is expected to be sorted best-first. The first
    ``max(1, n // 10)`` entries form the top group and the rest are the
    "others". For each centrality the group means, their difference and
    Cohen's d (difference / pooled standard deviation) are printed and
    returned.

    An empty group has mean 0. The effect size is 0 when there are no
    "others" or when the pooled standard deviation is 0.

    Args:
        agent_scores: List of per-agent dicts containing 'degree',
            'betweenness', 'closeness' and 'eigenvector'.

    Returns:
        Dict mapping centrality name -> dict with 'top_mean', 'other_mean',
        'difference' and 'effect_size'.
    """
    n_agents = len(agent_scores)
    top_decile_size = max(1, n_agents // 10)

    top_decile = agent_scores[:top_decile_size]
    others = agent_scores[top_decile_size:]

    centrality_cols = ['degree', 'betweenness', 'closeness', 'eigenvector']

    print(f"\n📈 Group Statistics (Top-{top_decile_size} vs Others):")
    print("=" * 80)
    print(f"{'Centrality':>12} {'Top Mean':>10} {'Others Mean':>12} {'Difference':>12} {'Effect Size':>12}")
    print("-" * 80)

    group_stats = {}

    for cent in centrality_cols:
        top_values = [agent[cent] for agent in top_decile]
        other_values = [agent[cent] for agent in others]

        # np.mean([]) is NaN (with a warning); report 0 for an empty group.
        top_mean = np.mean(top_values) if top_values else 0.0
        other_mean = np.mean(other_values) if other_values else 0.0
        difference = top_mean - other_mean

        # Effect size (Cohen's d); only meaningful when both groups exist.
        effect_size = 0
        if other_values:
            pooled_std = _pooled_std(top_values, other_values)
            if pooled_std > 0:
                effect_size = difference / pooled_std

        group_stats[cent] = {
            'top_mean': top_mean,
            'other_mean': other_mean,
            'difference': difference,
            'effect_size': effect_size
        }

        print(f"{cent:>12} {top_mean:>10.4f} {other_mean:>12.4f} {difference:>12.4f} {effect_size:>12.2f}")

    return group_stats

def compute_rank_correlations(agent_scores):
    """Print and return Spearman correlations between score and each centrality.

    A correlation is only computed when both the scores and the centrality
    values contain more than one distinct value. Otherwise the entry is
    recorded as correlation 0 with p-value 1.0 and printed as N/A.

    Args:
        agent_scores: List of per-agent dicts containing 'score' and the four
            centrality keys.

    Returns:
        Dict mapping centrality name -> {'correlation': ..., 'p_value': ...}.
    """
    from scipy.stats import spearmanr

    scores = [agent['score'] for agent in agent_scores]

    print(f"\n🔗 Spearman Correlations (Score vs Centrality):")
    print("=" * 50)
    print(f"{'Centrality':>12} {'Correlation':>12} {'P-value':>10}")
    print("-" * 50)

    correlations = {}

    for cent in ('degree', 'betweenness', 'closeness', 'eigenvector'):
        cent_values = [agent[cent] for agent in agent_scores]

        # A constant input has no rank ordering, so skip the computation.
        if not (len(set(scores)) > 1 and len(set(cent_values)) > 1):
            correlations[cent] = {'correlation': 0, 'p_value': 1.0}
            print(f"{cent:>12} {'N/A':>12} {'N/A':>10}")
            continue

        corr, p_value = spearmanr(scores, cent_values)
        correlations[cent] = {'correlation': corr, 'p_value': p_value}
        print(f"{cent:>12} {corr:>12.4f} {p_value:>10.4f}")

    return correlations

def run_centrality_experiment(G, T, assignments, max_rounds=30, share_budget=3, centralities=None):
    """Simulate rounds of fact sharing until some agent holds every truth.

    Each round, every agent sends its ``share_budget`` highest-priority
    statements to all neighbours (priority = 1 + how many times it has
    received that statement). Received statements are merged into each
    agent's store, contradictions are resolved, and the round ends the run
    if any agent's resolved store contains all of ``T``. The run also stops
    once a round later than round 5 teaches nobody a new statement.

    Args:
        G: Graph providing ``nodes()`` and ``neighbors(node)``.
        T: Set of true statements.
        assignments: Mapping node -> initial statements.
        max_rounds: Upper bound on the number of rounds.
        share_budget: Statements each agent sends per round.
        centralities: Accepted for call compatibility; not used here.

    Returns:
        Tuple ``(resolved_store, convergence_round, winner_id)``. The last
        two are None when no agent reached full recovery.
    """
    knowledge = {node: set(initial) for node, initial in assignments.items()}
    received_totals = {node: defaultdict(int) for node in G.nodes()}

    convergence_round = None
    winner_id = None

    for round_no in range(1, max_rounds + 1):
        inbox = {node: defaultdict(int) for node in G.nodes()}

        # Broadcast phase: all sends are based on the pre-round knowledge.
        for sender in G.nodes():
            priorities = {
                fact: 1 + received_totals[sender][fact]
                for fact in knowledge[sender]
            }
            outgoing = select_top(priorities, share_budget)
            for receiver in G.neighbors(sender):
                for fact in outgoing:
                    inbox[receiver][fact] += 1

        # Merge phase: absorb what arrived and count newly learned statements.
        learned = 0
        for node in G.nodes():
            known = knowledge[node]
            totals = received_totals[node]
            for fact, copies in inbox[node].items():
                if fact not in known:
                    learned += 1
                known.add(fact)
                totals[fact] += copies

        resolved_store, _ = resolve_contradictions_simple(knowledge, received_totals)

        # The first node (in graph order) with full recovery wins.
        winner_id = next(
            (node for node in G.nodes() if has_full_recovery(resolved_store[node], T)),
            None,
        )
        if winner_id is not None:
            convergence_round = round_no
            break

        # Give the process a few rounds before treating stagnation as final.
        if learned == 0 and round_no > 5:
            break

    # Also covers the case where the loop body never ran.
    if convergence_round is None:
        resolved_store, _ = resolve_contradictions_simple(knowledge, received_totals)

    return resolved_store, convergence_round, winner_id

def save_agent_csv(agent_scores, output_file, metadata):
    """Write one CSV row per agent, with the run metadata repeated on each row.

    Args:
        agent_scores: List of per-agent dicts (left unmodified).
        output_file: Path of the CSV file to create or overwrite.
        metadata: Dict of run parameters merged into every row; its values
            take precedence over same-named agent fields.
    """
    agent_columns = ['agent_id', 'score', 'correct', 'incorrect', 'degree', 'betweenness',
                     'closeness', 'eigenvector']
    run_columns = ['seed', 'share_budget', 'a_true', 'b_false', 'tmax']

    with open(output_file, 'w', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=agent_columns + run_columns)
        writer.writeheader()
        writer.writerows({**agent, **metadata} for agent in agent_scores)

def main():
    """Command-line entry point for one centrality-vs-performance run.

    Parses the arguments, loads the edge-list graph, computes centralities,
    builds a dummy truth universe, runs the sharing experiment and prints the
    top-K table, winner line, group statistics and rank correlations. When
    ``--output_csv`` is given, per-agent results are written to that file
    inside ``--output_dir``, which is created (including parents) if needed.

    Returns None. A graph that cannot be loaded, or that has no nodes, is
    reported on stdout and ends the run early.
    """
    parser = argparse.ArgumentParser(description='Analyze centrality vs performance in truth recovery')
    parser.add_argument("--graph", default="grqc_undirected.edgelist", help="Network file")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max_rounds", type=int, default=30, help="Maximum rounds")
    parser.add_argument("--share_budget", type=int, default=3, help="Facts shared per neighbor")
    parser.add_argument("--initial_true_facts", type=int, default=5, help="True facts per agent")
    parser.add_argument("--initial_false_facts", type=int, default=2, help="False facts per agent")
    parser.add_argument("--truth_size", type=int, default=20, help="Size of truth set")
    parser.add_argument("--top_k", type=int, default=20, help="Number of top agents to show")
    parser.add_argument("--output_csv", help="Save all agent data to CSV")
    parser.add_argument("--output_dir", default="centrality_results", help="Output directory")
    args = parser.parse_args()

    print(f"🎯 Centrality Analysis: seed={args.seed}, share_budget={args.share_budget}, "
          f"knowledge=({args.initial_true_facts},{args.initial_false_facts})")

    rng = random.Random(args.seed)

    # Load graph. This is the script's top-level boundary, so any loading
    # failure (missing file, malformed line, ...) is reported, not raised.
    try:
        G = nx.read_edgelist(args.graph, nodetype=int)
    except Exception as e:
        print(f"❌ Error loading graph: {e}")
        return

    # An empty graph would crash later inside the eigenvector centrality
    # computation, so reject it here with a readable message.
    if G.number_of_nodes() == 0:
        print(f"❌ Error loading graph: {args.graph} contains no edges")
        return
    print(f"📊 Loaded graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Compute centralities once
    centralities = compute_centralities(G)

    # Generate dummy universe
    pairs = generate_dummy_universe(rng, args.truth_size)
    T = sample_true_set(pairs, rng)
    print(f"🎯 Generated truth set with {len(T)} facts")

    # Assign initial knowledge
    assignments = assign_agent_knowledge(G, T, pairs, rng, args.initial_true_facts, args.initial_false_facts)

    # Run experiment
    print(f"\n🚀 Running experiment...")
    start_time = time.time()
    resolved_store, convergence_round, winner_id = run_centrality_experiment(
        G, T, assignments, args.max_rounds, args.share_budget, centralities)

    elapsed = time.time() - start_time
    print(f"✅ Experiment completed in {elapsed:.1f}s")

    if convergence_round:
        print(f"🎯 Converged in round {convergence_round}")
    else:
        print("❌ No convergence")

    # Analyze agent performance
    agent_scores = analyze_agent_performance(G, T, resolved_store, centralities)

    # Find winner in agent_scores
    winner_agent = None
    if winner_id is not None:
        winner_agent = next((agent for agent in agent_scores if agent['agent_id'] == winner_id), None)

    # Print results (the statistics helpers print their own tables)
    print_top_k_table(agent_scores, args.top_k)
    print_winner_line(winner_agent, convergence_round)
    compute_group_statistics(agent_scores)
    compute_rank_correlations(agent_scores)

    # Save to CSV if requested
    if args.output_csv:
        output_dir = Path(args.output_dir)
        # parents=True so a nested --output_dir such as "results/run1" works.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / args.output_csv

        metadata = {
            'seed': args.seed,
            'share_budget': args.share_budget,
            'a_true': args.initial_true_facts,
            'b_false': args.initial_false_facts,
            'tmax': args.max_rounds
        }

        save_agent_csv(agent_scores, output_file, metadata)
        print(f"\n💾 Agent data saved to {output_file}")

    print(f"\n🎊 Centrality analysis complete!")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
