
#!/usr/bin/env python3
"""
Systematic Experiment Runner for Network-Constrained Truth Recovery
Implements the comprehensive experimental design with dummy universe generation.
"""

import random
import argparse
import csv
import networkx as nx
from collections import defaultdict
import time
from datetime import datetime

def generate_dummy_universe(rng, size=20):
    """Build a synthetic universe of (fact, negation) string pairs.

    The statements are derived purely from their 1-based index, so no LLM
    call is needed. ``rng`` is accepted for interface compatibility but is
    not consulted. Returns a list of ``size`` ``(fact, negation)`` tuples.
    """
    def make_pair(index):
        statement = f"F{index} relates to G{index}"
        return statement, f"NOT({statement})"

    return [make_pair(index) for index in range(1, size + 1)]

def sample_true_set(pairs, rng):
    """Draw the hidden truth set: one member of every (fact, negation) pair."""
    chosen = set()
    for fact, neg in pairs:
        # One rng.choice call per pair, made in pair order.
        chosen.add(rng.choice([fact, neg]))
    return chosen

def assign_agent_knowledge(G, T, pairs, rng, a_true=5, b_false=2):
    """Assign a_true truths + b_false false facts to each agent.

    Args:
        G: graph-like object; only ``G.nodes()`` is used.
        T: set of true facts.
        pairs: iterable of (fact, negation) tuples forming the universe.
        rng: ``random.Random``-like object providing ``sample``.
        a_true: number of true facts per agent (capped at ``len(T)``).
        b_false: number of false facts per agent (capped at the number of
            universe facts not in ``T``).

    Returns:
        dict mapping each node to the set of facts it starts with.

    The pools are sorted before sampling. Sets of strings iterate in an
    order that depends on per-process hash randomization, so sampling from
    ``list(some_set)`` would not be reproducible for a fixed seed. Sorting
    requires the facts to be mutually orderable (they are strings here).
    """
    universe = {f for pair in pairs for f in pair}
    # Build both pools once, in a deterministic order, instead of per node.
    true_pool = sorted(T)
    false_pool = sorted(universe - T)
    n_true = min(a_true, len(true_pool))
    n_false = min(b_false, len(false_pool))

    assignments = {}
    for node in G.nodes():
        truths = rng.sample(true_pool, n_true)
        falses = rng.sample(false_pool, n_false)
        assignments[node] = set(truths + falses)
    return assignments

def resolve_contradictions_simple(agent_store, agent_scores):
    """Simplified contradiction resolution using majority vote.

    For every agent, facts are grouped by base proposition: a fact of the
    form ``NOT(<base>)`` is the negation of ``<base>``. When an agent holds
    both a fact and its negation, the one with the higher score in
    ``agent_scores[agent_id]`` is kept; ties keep the positive fact.
    Facts without a held counterpart are kept unchanged.

    Args:
        agent_store: dict mapping agent id -> iterable of fact strings.
        agent_scores: dict mapping agent id -> mapping of fact -> score.
            Missing agents or facts are treated as score 0.

    Returns:
        (resolved_store, total_contradictions): the per-agent resolved fact
        sets, and the number of fact/negation pairs resolved across all
        agents.
    """
    resolved_store = {}
    total_contradictions = 0

    for agent_id, facts in agent_store.items():
        # .get avoids inserting keys into a defaultdict and tolerates
        # plain dicts that lack an entry.
        scores = agent_scores.get(agent_id, {})

        # Split held facts into positive statements and negations,
        # keyed by their base proposition.
        positives = set()
        negatives = {}
        for fact in facts:
            if fact.startswith('NOT(') and fact.endswith(')'):
                negatives[fact[4:-1]] = fact  # strip "NOT(" and ")"
            else:
                positives.add(fact)

        resolved_facts = set()
        for base in positives | negatives.keys():
            negation_name = negatives.get(base)
            if negation_name is None:
                # Only the positive fact is held.
                resolved_facts.add(base)
            elif base not in positives:
                # Only the negation is held.
                resolved_facts.add(negation_name)
            else:
                # Both are held: existence is decided by presence, not by
                # score, so zero-score facts are neither dropped nor
                # overlooked as contradictions.
                total_contradictions += 1
                if scores.get(negation_name, 0) > scores.get(base, 0):
                    resolved_facts.add(negation_name)
                else:
                    # Fact wins outright or on a tie (arbitrary but consistent).
                    resolved_facts.add(base)

        resolved_store[agent_id] = resolved_facts

    return resolved_store, total_contradictions

def select_top(score_dict, k):
    """Pick the k best-scoring facts, highest score first.

    Ties on score are broken by ascending fact key, so the result is
    deterministic for a given ``score_dict``.
    """
    def rank(fact):
        return (-score_dict[fact], fact)

    ordered = sorted(score_dict, key=rank)
    return ordered[:k]

def has_full_recovery(agent_facts, T):
    """Return True when every fact in T is among agent_facts.

    This is a superset test: facts held in addition to T do not prevent a
    True result.
    """
    held = set(agent_facts)
    return T <= held

def run_experiment(G, T, assignments, max_rounds=30, share_budget=3, verbose=False):
    """
    Run a single experiment with the given parameters.

    Each round, every agent sends its ``share_budget`` top-ranked facts
    (ranked by 1 + number of times received) to each neighbor. Receivers add
    the facts to their store and increase the facts' scores by the number of
    copies received. Contradictions are then resolved per agent, and the run
    stops as soon as some agent's resolved facts contain all of ``T``.

    Args:
        G: graph providing ``nodes()`` and ``neighbors(u)``.
        T: set of true facts.
        assignments: dict mapping node -> initial facts (must cover every
            node of G).
        max_rounds: maximum number of sharing rounds.
        share_budget: number of facts each agent sends per round.
        verbose: print per-round progress when True.

    Returns:
        dict with keys ``status`` ("converged" / "no convergence"),
        ``round`` (round of convergence, or the last round actually run),
        ``winner``, ``correct``, ``incorrect``, ``score``
        (correct - incorrect), ``elapsed_time`` and ``contradictions_final``.
        If G has no nodes, ``winner`` is None and ``score`` is ``-inf``.
    """
    agent_store = {u: set(facts) for u, facts in assignments.items()}
    agent_scores = {u: defaultdict(int) for u in G.nodes()}

    start_time = time.time()
    # Last round actually executed; stays 0 when max_rounds < 1.
    rounds_run = 0

    for t in range(1, max_rounds + 1):
        rounds_run = t
        receipts = {u: defaultdict(int) for u in G.nodes()}

        # --- Agents share facts based on current knowledge ---
        facts_shared_this_round = 0
        for u in G.nodes():
            # Every held fact gets a base weight of 1 plus its receipt count.
            scores = {f: 1 + agent_scores[u][f] for f in agent_store[u]}
            to_send = select_top(scores, share_budget)

            for v in G.neighbors(u):
                for f in to_send:
                    receipts[v][f] += 1
                    facts_shared_this_round += 1

        # --- Update stores with received facts ---
        new_facts_count = 0
        for u in G.nodes():
            for f, cnt in receipts[u].items():
                if f not in agent_store[u]:
                    new_facts_count += 1
                agent_store[u].add(f)
                agent_scores[u][f] += cnt

        # --- Resolve contradictions after updating knowledge ---
        resolved_store, total_contradictions = resolve_contradictions_simple(agent_store, agent_scores)

        if verbose and total_contradictions > 0:
            print(f"  🔧 Round {t}: Resolved {total_contradictions} contradictions")

        if verbose:
            print(f"  📤 Round {t}: {facts_shared_this_round} facts shared, {new_facts_count} new facts learned")

        # Check for agents with full recovery
        recovery_candidates = [
            u for u in G.nodes() if has_full_recovery(resolved_store[u], T)
        ]

        if recovery_candidates:
            if verbose:
                print(f"  🎯 Found {len(recovery_candidates)} agents with full recovery!")

            # Return first successful agent (in node iteration order)
            winner = recovery_candidates[0]
            correct = len(resolved_store[winner] & T)
            incorrect = len(resolved_store[winner] - T)
            return {
                "status": "converged",
                "round": t,
                "winner": winner,
                "correct": correct,
                "incorrect": incorrect,
                "score": correct - incorrect,
                "elapsed_time": time.time() - start_time,
                "contradictions_final": total_contradictions
            }

        # Early termination if no learning
        if new_facts_count == 0 and t > 5:
            if verbose:
                print(f"  ⏹️  Early termination: no new facts learned in round {t}")
            break

    # Nobody recovered: report the round actually reached, which is lower
    # than max_rounds after an early termination.
    if verbose:
        print(f"❌ No convergence after {rounds_run} rounds")

    # Final contradiction resolution (also covers the max_rounds < 1 case,
    # where the loop body never ran).
    resolved_store, total_contradictions = resolve_contradictions_simple(agent_store, agent_scores)

    # Find best agent based on resolved knowledge; the first agent with the
    # strictly highest score wins.
    best_agent = None
    best_score = -float('inf')
    best_correct = 0
    best_incorrect = 0
    for u in G.nodes():
        correct = len(resolved_store[u] & T)
        incorrect = len(resolved_store[u] - T)
        score = correct - incorrect
        if score > best_score:
            # Counts are stored alongside the winner so a falsy node id
            # such as 0 is still reported correctly.
            best_agent = u
            best_score = score
            best_correct = correct
            best_incorrect = incorrect

    return {
        "status": "no convergence",
        "round": rounds_run,
        "winner": best_agent,
        "correct": best_correct,
        "incorrect": best_incorrect,
        "score": best_score,
        "elapsed_time": time.time() - start_time,
        "contradictions_final": total_contradictions
    }

def main():
    """Command-line entry point: run one experiment and append it to a CSV.

    Parses the CLI options, loads the graph from an edge list, builds the
    dummy universe and hidden truth set from the seeded RNG, assigns initial
    knowledge, runs the experiment, appends one result row to the CSV file
    (writing the header when the file is new or empty) and prints a summary.
    On a graph-loading or experiment error a message is printed and the
    function returns without writing a row.
    """
    parser = argparse.ArgumentParser(description='Run systematic truth recovery experiments')
    parser.add_argument("--graph", default="grqc_undirected.edgelist", help="Network file")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--csv", default="experiment_results.csv", help="Output CSV file")
    parser.add_argument("--max_rounds", type=int, default=30, help="Maximum rounds")
    parser.add_argument("--share_budget", type=int, default=3, help="Facts shared per neighbor")
    parser.add_argument("--initial_true_facts", type=int, default=5, help="True facts per agent")
    parser.add_argument("--initial_false_facts", type=int, default=2, help="False facts per agent")
    parser.add_argument("--truth_size", type=int, default=20, help="Size of truth set")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    # NOTE: --timeout is accepted for CLI compatibility but is not enforced
    # anywhere in this function.
    parser.add_argument("--timeout", type=int, default=600, help="Timeout in seconds (10 min)")
    args = parser.parse_args()

    print(f"🚀 Starting experiment: seed={args.seed}, share_budget={args.share_budget}, "
          f"knowledge=({args.initial_true_facts},{args.initial_false_facts})")

    start_time = time.time()
    rng = random.Random(args.seed)

    # Load graph
    try:
        G = nx.read_edgelist(args.graph, nodetype=int)
        if args.verbose:
            print(f"📊 Loaded graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    except Exception as e:
        print(f"❌ Error loading graph: {e}")
        return

    # Generate dummy universe (no LLM calls)
    pairs = generate_dummy_universe(rng, args.truth_size)

    # Sample hidden truth set
    T = sample_true_set(pairs, rng)
    if args.verbose:
        print(f"🎯 Generated truth set with {len(T)} facts")

    # Assign initial knowledge
    assignments = assign_agent_knowledge(G, T, pairs, rng, args.initial_true_facts, args.initial_false_facts)

    # Run experiment (Ctrl-C and unexpected errors are reported, not raised)
    try:
        result = run_experiment(G, T, assignments, args.max_rounds, args.share_budget, args.verbose)
    except KeyboardInterrupt:
        print("⏹️  Experiment interrupted")
        return
    except Exception as e:
        print(f"❌ Experiment failed: {e}")
        return

    # Add metadata to result
    result.update({
        "seed": args.seed,
        "a_true": args.initial_true_facts,
        "b_false": args.initial_false_facts,
        "share": args.share_budget,
        "tmax": args.max_rounds,
        "n_nodes": G.number_of_nodes(),
        "timestamp": datetime.now().isoformat()
    })

    # Write result to CSV
    fieldnames = ["seed", "status", "round", "winner", "correct", "incorrect", "score",
                  "a_true", "b_false", "share", "tmax", "n_nodes", "elapsed_time",
                  "contradictions_final", "timestamp"]

    # Write the header when the file is missing or exists but is empty
    # (e.g. it was pre-created), so the CSV always starts with column names.
    import os
    write_header = not os.path.exists(args.csv) or os.path.getsize(args.csv) == 0

    with open(args.csv, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(result)

    # Print summary
    elapsed = time.time() - start_time
    # Guard the percentage: --truth_size 0 (or negative) yields an empty
    # truth set and would otherwise divide by zero after the row is written.
    if args.truth_size > 0:
        percent_correct = result['correct'] / args.truth_size * 100
    else:
        percent_correct = 0.0
    print(f"✅ Experiment complete ({elapsed:.1f}s)")
    print(f"   Status: {result['status']}")
    print(f"   Performance: {result['correct']}/{args.truth_size} correct ({percent_correct:.1f}%)")
    print(f"   Rounds: {result['round']}")
    print(f"   Result written to {args.csv}")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
