#!/usr/bin/env python3
"""
Batch Centrality Analysis Runner
Runs centrality analysis across multiple configurations and aggregates results.
"""

import argparse
import csv
import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd

class CentralityBatchRunner:
    """Run ``centrality_analyzer.py`` over a grid of configurations and aggregate results.

    Each configuration (seed, share budget, knowledge regime) is executed as a
    separate subprocess. Key values are scraped from the analyzer's stdout and
    collected into a summary CSV. Progress is echoed to the console and appended
    to a plain-text log inside ``output_dir``.
    """

    # Centrality measures whose Spearman-correlation rows are parsed from stdout.
    CENTRALITY_MEASURES = ('degree', 'betweenness', 'closeness', 'eigenvector')

    def __init__(self, output_dir="centrality_results"):
        """Create the output directory and start a fresh batch log.

        Args:
            output_dir: Directory that receives per-run CSVs, the batch log and
                the summary CSV. Missing parent directories are created.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested paths such as "results/run1" also work.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.batch_log = self.output_dir / "batch_centrality_log.txt"
        self.summary_file = self.output_dir / "centrality_summary.csv"

        # Truncate any log left over from a previous batch. The encoding is
        # explicit because log lines contain emoji, which a platform default
        # such as cp1252 cannot encode.
        with open(self.batch_log, 'w', encoding='utf-8') as f:
            f.write(f"Batch centrality analysis started at {datetime.now().isoformat()}\n")
            f.write("=" * 60 + "\n")

    def log_message(self, message):
        """Print ``message`` with an HH:MM:SS timestamp and append it to the batch log."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        log_line = f"[{timestamp}] {message}"
        print(log_line)
        with open(self.batch_log, 'a', encoding='utf-8') as f:
            f.write(log_line + "\n")

    @classmethod
    def _parse_analyzer_output(cls, stdout):
        """Extract winner info and Spearman correlations from analyzer stdout.

        Parsing is best-effort. Lines that do not match the expected layout are
        skipped rather than raising, so a successful analyzer run is never
        reported as a failure because of one oddly shaped line.

        Args:
            stdout: Full standard output of ``centrality_analyzer.py``.

        Returns:
            Tuple ``(winner_id, convergence_round, winner_centralities, correlations)``:
            - ``winner_id`` and ``convergence_round`` are strings or ``None``.
            - ``winner_centralities`` maps a measure name to a float.
            - ``correlations`` maps a measure name to
              ``{'correlation': float, 'p_value': float}``.
        """
        winner_id = None
        convergence_round = None
        winner_centralities = {}
        correlations = {}

        in_correlations = False
        for line in stdout.strip().split('\n'):
            if "Winner: Agent" in line:
                parts = line.split()
                # Token positions follow the analyzer's print format. Guard each
                # index on its own: the round needs six tokens, not four.
                if len(parts) >= 4:
                    winner_id = parts[3].rstrip(',')
                if len(parts) >= 6:
                    convergence_round = parts[5]

            if "Centralities: " in line:
                cent_part = line.split("Centralities: ", 1)[1]
                for cent_pair in cent_part.split(", "):
                    key, sep, value = cent_pair.partition("=")
                    if not sep:
                        continue  # token without "=": not a key=value pair
                    try:
                        winner_centralities[key.strip()] = float(value)
                    except ValueError:
                        continue

            if "Spearman Correlations" in line:
                in_correlations = True
                continue

            # A line containing the party emoji closes the correlation section.
            if in_correlations and "🎊" in line:
                in_correlations = False
                continue

            if in_correlations:
                parts = line.split()
                if len(parts) >= 3 and parts[0] in cls.CENTRALITY_MEASURES:
                    try:
                        corr_val = float(parts[1])
                        # A missing p-value ("N/A") is recorded as 1.0.
                        p_val = float(parts[2]) if parts[2] != 'N/A' else 1.0
                    except ValueError:
                        continue
                    correlations[parts[0]] = {'correlation': corr_val, 'p_value': p_val}

        return winner_id, convergence_round, winner_centralities, correlations

    def run_single_centrality_analysis(self, seed, share_budget, a_true=5, b_false=2,
                                       max_rounds=30, timeout=600):
        """Run ``centrality_analyzer.py`` once and parse its key results.

        Args:
            seed: Forwarded as ``--seed``.
            share_budget: Forwarded as ``--share_budget``.
            a_true: Forwarded as ``--initial_true_facts``.
            b_false: Forwarded as ``--initial_false_facts``.
            max_rounds: Forwarded as ``--max_rounds``.
            timeout: Seconds to wait for the subprocess before giving up.

        Returns:
            Dict with ``success``, ``elapsed`` (seconds) and ``error``. On success
            it also has ``winner_id``, ``convergence_round``,
            ``winner_centralities``, ``correlations`` and ``csv_file`` (the
            per-run CSV path requested from the analyzer).
        """
        csv_filename = f"agents_s{seed}_share{share_budget}_regime{a_true}_{b_false}.csv"
        csv_path = self.output_dir / csv_filename

        cmd = [
            # Use the interpreter running this script rather than whatever
            # "python" resolves to on PATH, which may be missing or another env.
            sys.executable, "centrality_analyzer.py",
            "--seed", str(seed),
            "--share_budget", str(share_budget),
            "--initial_true_facts", str(a_true),
            "--initial_false_facts", str(b_false),
            "--max_rounds", str(max_rounds),
            "--output_csv", csv_filename,
            "--output_dir", str(self.output_dir)
        ]

        start_time = time.time()
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "elapsed": time.time() - start_time,
                "error": f"Timeout after {timeout / 60:g} minutes"
            }
        except Exception as e:
            # Per-run isolation: one launch failure (e.g. OSError) must not
            # abort the whole batch. The failure is reported in the record.
            return {
                "success": False,
                "elapsed": time.time() - start_time,
                "error": str(e)
            }
        elapsed = time.time() - start_time

        if result.returncode != 0:
            return {
                "success": False,
                "elapsed": elapsed,
                # stderr can be empty (e.g. the child was killed), so fall
                # back to the exit code to keep the failure diagnosable.
                "error": result.stderr or
                         f"centrality_analyzer.py exited with code {result.returncode}"
            }

        winner_id, convergence_round, winner_centralities, correlations = \
            self._parse_analyzer_output(result.stdout)

        return {
            "success": True,
            "elapsed": elapsed,
            "winner_id": winner_id,
            "convergence_round": convergence_round,
            "winner_centralities": winner_centralities,
            "correlations": correlations,
            "csv_file": csv_path,
            "error": None
        }

    def run_batch_analysis(self, share_budgets=None, seeds=None, regimes=None, max_rounds=30,
                           timeout=600):
        """Run the analyzer for every (regime, share budget, seed) combination.

        Args:
            share_budgets: Share budgets to test. Defaults to ``[3, 5, 7, 10]``.
            seeds: Seeds to test. Defaults to 42..46.
            regimes: ``(a_true, b_false)`` tuples. Defaults to ``[(5, 2)]``.
            max_rounds: Forwarded to every run.
            timeout: Per-run subprocess timeout in seconds.

        Returns:
            List of per-run record dicts, one per combination. The same records
            are written to ``self.summary_file``.
        """
        if share_budgets is None:
            share_budgets = [3, 5, 7, 10]
        if seeds is None:
            seeds = list(range(42, 47))  # 5 seeds for faster testing
        if regimes is None:
            regimes = [(5, 2)]  # Start with baseline regime

        total_runs = len(share_budgets) * len(seeds) * len(regimes)
        completed = 0
        start_time = time.time()

        self.log_message("🚀 Starting batch centrality analysis:")
        self.log_message(f"   Share budgets: {share_budgets}")
        self.log_message(f"   Seeds: {seeds}")
        self.log_message(f"   Regimes: {regimes}")
        self.log_message(f"   Total runs: {total_runs}")

        batch_results = []

        for a_true, b_false in regimes:
            regime_name = f"({a_true},{b_false})"
            self.log_message(f"\n🎯 Starting regime {regime_name}...")

            for share in share_budgets:
                for seed in seeds:
                    # Extrapolate from runs that have actually finished. The
                    # remaining count includes the run about to start.
                    elapsed_total = time.time() - start_time
                    eta = (elapsed_total / completed) * (total_runs - completed) if completed else 0.0
                    completed += 1

                    self.log_message(f"🔄 Run {completed}/{total_runs}: "
                                     f"{regime_name} share={share} seed={seed} "
                                     f"(ETA: {eta/60:.1f}min)")

                    result = self.run_single_centrality_analysis(
                        seed, share, a_true, b_false, max_rounds, timeout=timeout)

                    # Store result with metadata
                    result_record = {
                        'seed': seed,
                        'share_budget': share,
                        'a_true': a_true,
                        'b_false': b_false,
                        'regime': regime_name,
                        'success': result['success'],
                        'elapsed': result['elapsed']
                    }

                    if result['success']:
                        result_record.update({
                            'winner_id': result.get('winner_id'),
                            'convergence_round': result.get('convergence_round'),
                            'csv_file': str(result.get('csv_file')),
                        })

                        # Flatten winner centralities into winner_<measure> columns.
                        for cent, value in result.get('winner_centralities', {}).items():
                            result_record[f'winner_{cent}'] = value

                        # Flatten correlations into corr_<measure>/pval_<measure> columns.
                        for cent, corr_data in result.get('correlations', {}).items():
                            result_record[f'corr_{cent}'] = corr_data['correlation']
                            result_record[f'pval_{cent}'] = corr_data['p_value']

                        self.log_message(f"   ✅ Success: winner={result.get('winner_id')}, "
                                         f"round={result.get('convergence_round')}, "
                                         f"{result['elapsed']:.1f}s")
                    else:
                        result_record['error'] = result.get('error')
                        self.log_message(f"   ❌ Failed: {result.get('error')}")

                    batch_results.append(result_record)

        # Save batch summary
        self.save_batch_summary(batch_results)

        total_elapsed = time.time() - start_time
        successes = sum(1 for r in batch_results if r['success'])
        # Guard against an empty grid (any empty config list gives zero runs).
        success_rate = successes / len(batch_results) if batch_results else 0.0

        self.log_message("\n✅ Batch analysis complete!")
        self.log_message(f"   Total time: {total_elapsed/60:.1f} minutes")
        self.log_message(f"   Success rate: {success_rate*100:.1f}%")
        self.log_message(f"   Results saved to {self.summary_file}")

        return batch_results

    def save_batch_summary(self, batch_results):
        """Write ``batch_results`` to ``self.summary_file`` as CSV.

        Columns are the sorted union of all record keys, because failed and
        successful runs carry different fields. Missing cells are left empty.
        Nothing is written when ``batch_results`` is empty.
        """
        if not batch_results:
            return

        fieldnames = sorted({key for record in batch_results for key in record})

        # UTF-8 because error messages and paths may contain non-ASCII text.
        with open(self.summary_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(batch_results)

def parse_regime(text):
    """Parse an ``"a_true,b_false"`` string such as ``"5,2"`` into ``(5, 2)``.

    Used as an argparse ``type=`` so a malformed value produces a normal usage
    error instead of an uncaught traceback.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not two comma-separated integers.
    """
    parts = text.split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            f"invalid regime {text!r}: expected 'a_true,b_false', e.g. '5,2'")
    try:
        return int(parts[0]), int(parts[1])
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid regime {text!r}: both values must be integers") from None


def main():
    """Command-line entry point: parse arguments and run the batch analysis."""
    parser = argparse.ArgumentParser(description='Run batch centrality analysis')
    parser.add_argument("--output_dir", default="centrality_results", help="Output directory")
    parser.add_argument("--share_budgets", nargs='+', type=int, default=[3, 5, 7, 10],
                        help="Share budgets to test")
    parser.add_argument("--seeds", nargs='+', type=int, default=list(range(42, 47)),
                        help="Seeds to test")
    parser.add_argument("--max_rounds", type=int, default=30, help="Maximum rounds")
    # One or more regimes; "--regime 5,2" keeps working exactly as before.
    parser.add_argument("--regime", nargs='+', type=parse_regime, default=[(5, 2)],
                        help="Knowledge regime(s) as a_true,b_false (e.g. 5,2 or 5,2 3,4)")
    args = parser.parse_args()

    runner = CentralityBatchRunner(args.output_dir)

    runner.run_batch_analysis(
        share_budgets=args.share_budgets,
        seeds=args.seeds,
        regimes=args.regime,
        max_rounds=args.max_rounds
    )

    print(f"\n🎊 Batch centrality analysis complete! Check {args.output_dir}/ for results")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
