#!/usr/bin/env python3
"""
Comprehensive Analysis of Parallel vs Sequential Test-Time Scaling Results
ICLR 2026 Submission Code

This script analyzes the experimental results from all benchmarks and generates
publication-ready figures and statistics.
"""

import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import seaborn as sns
from collections import defaultdict
import warnings
# Silence all warnings so the console report stays readable. Note that this
# hides every warning category for the whole process.
warnings.filterwarnings('ignore')

# Set style for publication-quality plots.
# matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8';
# older releases only know the un-versioned name, so fall back to it.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    # matplotlib raises OSError when a style name cannot be resolved.
    plt.style.use('seaborn')
sns.set_palette("husl")

def _parse_percent(value):
    """Convert an accuracy value to a float percentage, or return None.

    Accepts plain numbers (``85`` / ``85.0``) and strings with an optional
    percent sign (``"85.0%"``). Anything else yields ``None``.
    """
    if isinstance(value, bool):
        # bool is a subclass of int but is never a meaningful accuracy.
        return None
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace('%', '').strip())
        except ValueError:
            return None
    return None


def _select_accuracy(accuracy_data, preferred_keys):
    """Pick one accuracy figure from a scalar or a ``{method: value}`` dict.

    Keys in ``preferred_keys`` are tried first (in order); after that the
    remaining entries are tried in file order. Returns ``None`` when no
    entry can be parsed.
    """
    if not isinstance(accuracy_data, dict):
        return _parse_percent(accuracy_data)

    ordered_keys = [key for key in preferred_keys if key in accuracy_data]
    ordered_keys += [key for key in accuracy_data if key not in preferred_keys]
    for key in ordered_keys:
        parsed = _parse_percent(accuracy_data[key])
        if parsed is not None:
            return parsed
    return None


def extract_json_metrics(filepath):
    """Extract key metrics from a JSON report file.

    Args:
        filepath: Path of the JSON report to read.

    Returns:
        A dict that may contain ``accuracy`` (float, percent) and, when the
        report has an ``experiment_summary`` object, ``total_questions``,
        ``total_tokens``, ``avg_tokens_per_question``,
        ``avg_time_per_question``, ``total_api_calls``, ``model``,
        ``strategy`` and ``max_chains``. An empty dict is returned when the
        file cannot be read, is not valid JSON, or is not a JSON object.
        A single malformed field no longer discards the whole record.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error processing {filepath}: {e}")
        return {}

    if not isinstance(data, dict):
        print(f"Error processing {filepath}: expected a JSON object, "
              f"got {type(data).__name__}")
        return {}

    metrics = {}

    # Two report layouts exist: 'final_accuracy_percent' (numbers) and
    # 'final_accuracies_percent' (strings such as "85.0%"). The first layout
    # that yields a usable value wins.
    accuracy_layouts = (
        ('final_accuracy_percent', ('simple_majority',)),
        ('final_accuracies_percent', ('simple_majority', 'entropy_weighted')),
    )
    for report_key, preferred_keys in accuracy_layouts:
        if report_key not in data:
            continue
        accuracy = _select_accuracy(data[report_key], preferred_keys)
        if accuracy is not None:
            metrics['accuracy'] = accuracy
            break

    # Extract experiment summary
    summary = data.get('experiment_summary')
    if isinstance(summary, dict):
        tokens_used = summary.get('total_tokens_used', {})
        if isinstance(tokens_used, dict):
            total_tokens = tokens_used.get('total', 0)
        elif isinstance(tokens_used, (int, float)) and not isinstance(tokens_used, bool):
            # Some reports may store the total directly instead of a breakdown.
            total_tokens = tokens_used
        else:
            total_tokens = 0

        metrics['total_questions'] = summary.get('total_questions_processed', 0)
        metrics['total_tokens'] = total_tokens
        metrics['avg_tokens_per_question'] = summary.get('average_tokens_per_question', 0)
        metrics['avg_time_per_question'] = summary.get('average_time_per_question_sec', 0)
        metrics['total_api_calls'] = summary.get('total_api_calls', 0)
        metrics['model'] = summary.get('model', 'unknown')
        metrics['strategy'] = summary.get('strategy', 'unknown')
        # Sequential runs record 'max_steps' instead of 'max_chains'.
        metrics['max_chains'] = summary.get('max_chains', summary.get('max_steps', 0))

    return metrics

def parse_filename_info(filename):
    """Derive benchmark, strategy, chain count and model from a filename.

    Matching is case-insensitive substring matching on ``filename``. Fields
    that cannot be recognised keep their placeholder value: ``'unknown'``
    for the text fields and ``0`` for ``chains``.

    Returns:
        A dict with the keys ``benchmark``, ``strategy``, ``chains`` and
        ``model``.
    """
    import re

    lowered = filename.lower()

    # Each rule is (label, substrings); the first rule with a hit wins.
    benchmark_rules = (
        ('AIME 2024', ('aime24', 'aime_2024')),
        ('AIME 2025', ('aime25', 'aime_2025')),
        ('GPQA Diamond', ('gpqa',)),
    )
    strategy_rules = (
        ('Parallel', ('parallel', 'par_')),
        ('Sequential', ('sequential', 'seq_')),
    )
    known_models = ('gpt-oss-120b', 'gpt-oss-20b', 'qwen3-30b', 'qwen3-235b', 'kimi-k2')

    def first_label(rules):
        # Label of the first rule whose substrings occur in the name.
        for label, needles in rules:
            if any(needle in lowered for needle in needles):
                return label
        return 'unknown'

    # A run of digits directly followed by "chain" or "step", e.g. "8chains".
    count_match = re.search(r'(\d+)(?:chain|step)', lowered)

    return {
        'benchmark': first_label(benchmark_rules),
        'strategy': first_label(strategy_rules),
        'chains': int(count_match.group(1)) if count_match else 0,
        'model': next((name for name in known_models if name in lowered), 'unknown'),
    }

def collect_all_results(results_dir="."):
    """Collect all experimental results from JSON files.

    Recursively scans ``results_dir`` for ``*.json`` files whose path contains
    one of the keywords ``report``, ``result``, ``aime`` or ``gpqa``
    (case-insensitive), and merges the metrics read from each file with the
    metadata parsed from its filename.

    Filename metadata takes precedence. Where the filename only yields a
    placeholder (``'unknown'`` or ``0`` chains), the value from the report's
    own ``experiment_summary`` is used instead, so it is not overwritten by
    the placeholder.

    Args:
        results_dir: Directory to scan. Defaults to the current directory.

    Returns:
        A ``pandas.DataFrame`` with one row per usable report (empty when
        nothing was found). Each row also carries the source ``filepath``.
    """
    keywords = ('report', 'result', 'aime', 'gpqa')
    results = []

    # sorted() keeps the row order reproducible; rglob order is not guaranteed.
    for filepath in sorted(Path(results_dir).rglob("*.json")):
        path_text = str(filepath).lower()
        if not any(keyword in path_text for keyword in keywords):
            continue

        metrics = extract_json_metrics(filepath)
        if not metrics:
            continue

        file_info = parse_filename_info(filepath.name)

        # Fall back to the report's self-declared model/strategy. The declared
        # text is run through the filename parser so it maps to the same
        # canonical labels ('Parallel', 'gpt-oss-120b', ...) where possible.
        for field in ('model', 'strategy'):
            if file_info[field] != 'unknown':
                continue
            declared = metrics.get(field)
            if not isinstance(declared, str) or not declared:
                continue
            canonical = parse_filename_info(declared)[field]
            file_info[field] = canonical if canonical != 'unknown' else declared

        # Same for the chain/step count, which the report stores as max_chains.
        if not file_info['chains']:
            declared_chains = metrics.get('max_chains')
            if (isinstance(declared_chains, int)
                    and not isinstance(declared_chains, bool)
                    and declared_chains > 0):
                file_info['chains'] = declared_chains

        results.append({**metrics, **file_info, 'filepath': str(filepath)})

    return pd.DataFrame(results)

def create_accuracy_comparison_plot(df):
    """Create the per-benchmark accuracy comparison plot.

    Draws one bar-chart panel per benchmark (AIME 2024, AIME 2025, GPQA
    Diamond) showing mean accuracy by chain/step count and strategy, and
    saves it as ``accuracy_comparison_by_benchmark.png`` and ``.pdf`` in the
    current working directory.

    Args:
        df: Results frame with the columns ``benchmark``, ``chains``,
            ``strategy`` and ``accuracy``.
    """
    benchmarks = ['AIME 2024', 'AIME 2025', 'GPQA Diamond']
    fig, axes = plt.subplots(1, len(benchmarks), figsize=(18, 6))

    try:
        for ax, benchmark in zip(axes, benchmarks):
            ax.set_title(f'{benchmark}', fontsize=14, fontweight='bold')

            benchmark_data = df[df['benchmark'] == benchmark]
            pivot_data = None
            if not benchmark_data.empty:
                # Group by strategy and chains
                pivot_data = benchmark_data.pivot_table(
                    values='accuracy',
                    index='chains',
                    columns='strategy',
                    aggfunc='mean'
                )

            if pivot_data is None or pivot_data.empty:
                # Label the panel instead of leaving an unexplained blank axes.
                ax.text(0.5, 0.5, 'No data', ha='center', va='center',
                        transform=ax.transAxes, fontsize=12)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            pivot_data.plot(kind='bar', ax=ax, width=0.7)
            ax.set_xlabel('Number of Chains/Steps', fontsize=12)
            ax.set_ylabel('Accuracy (%)', fontsize=12)
            ax.legend(title='Strategy', fontsize=10)
            ax.grid(True, alpha=0.3)
            ax.tick_params(axis='x', rotation=0)

        fig.tight_layout()
        fig.savefig('accuracy_comparison_by_benchmark.png', dpi=300, bbox_inches='tight')
        fig.savefig('accuracy_comparison_by_benchmark.pdf', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print("Saved: accuracy_comparison_by_benchmark.png/pdf")

def create_token_efficiency_plot(df):
    """Create the token efficiency scatter plot.

    Plots accuracy against tokens spent per correct answer
    (``avg_tokens_per_question / (accuracy / 100)``), one colour per
    strategy, and saves it as ``token_efficiency_analysis.png`` and ``.pdf``
    in the current working directory.

    The input frame is not modified. Rows whose accuracy is not positive get
    no efficiency value (NaN) instead of a division-by-zero ``inf``.

    Args:
        df: Results frame with the columns ``strategy``, ``accuracy`` and
            ``avg_tokens_per_question``.
    """
    # Calculate tokens per correct answer on a copy so the caller's frame
    # does not silently gain a column.
    accuracy = df['accuracy']
    plot_data = df.assign(
        tokens_per_correct=df['avg_tokens_per_question'] / (accuracy.where(accuracy > 0) / 100)
    )

    # tab10 starts with the blue/orange pair this plot has always used and
    # gives further strategies their own colour instead of reusing the first.
    palette = plt.get_cmap('tab10').colors

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        for i, strategy in enumerate(plot_data['strategy'].dropna().unique()):
            strategy_data = plot_data[plot_data['strategy'] == strategy]
            ax.scatter(
                strategy_data['accuracy'],
                strategy_data['tokens_per_correct'],
                label=strategy,
                alpha=0.7,
                s=100,
                # color= (not c=) so an RGB tuple is never read as per-point data.
                color=palette[i % len(palette)]
            )

        ax.set_xlabel('Accuracy (%)', fontsize=12)
        ax.set_ylabel('Tokens per Correct Answer', fontsize=12)
        ax.set_title('Token Efficiency: Accuracy vs Tokens per Correct Answer',
                     fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig('token_efficiency_analysis.png', dpi=300, bbox_inches='tight')
        fig.savefig('token_efficiency_analysis.pdf', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print("Saved: token_efficiency_analysis.png/pdf")

def create_model_comparison_plot(df):
    """Create the grouped bar chart comparing models by strategy.

    Plots the mean accuracy of every (model, strategy) pair and saves it as
    ``model_comparison.png`` and ``.pdf`` in the current working directory.
    Nothing is drawn when fewer than two distinct models are present.

    Bar width and tick positions are derived from the number of strategies,
    so groups stay centred and do not overlap for any strategy count.

    Args:
        df: Results frame with the columns ``model``, ``strategy`` and
            ``accuracy``.
    """
    if 'model' not in df.columns or df['model'].nunique() < 2:
        print("Insufficient model data for comparison plot")
        return

    # Rows are models, columns are strategies; missing pairs are drawn as 0.
    accuracy_table = (
        df.groupby(['model', 'strategy'])['accuracy']
        .mean()
        .unstack('strategy')
        .fillna(0)
    )
    if accuracy_table.empty:
        print("Insufficient model data for comparison plot")
        return

    models = list(accuracy_table.index)
    strategies = list(accuracy_table.columns)

    x = np.arange(len(models))
    # 0.35 is the established width for two strategies; shrink it for more
    # so a group never spills into its neighbour.
    width = min(0.35, 0.8 / len(strategies))

    fig, ax = plt.subplots(1, 1, figsize=(14, 8))
    try:
        for i, strategy in enumerate(strategies):
            ax.bar(x + i * width, accuracy_table[strategy].to_numpy(), width,
                   label=strategy, alpha=0.8)

        ax.set_xlabel('Model', fontsize=12)
        ax.set_ylabel('Average Accuracy (%)', fontsize=12)
        ax.set_title('Model Performance Comparison by Strategy', fontsize=14, fontweight='bold')
        # Centre each tick under its group of bars.
        ax.set_xticks(x + width * (len(strategies) - 1) / 2)
        ax.set_xticklabels(models, rotation=45, ha='right')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
        fig.savefig('model_comparison.pdf', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print("Saved: model_comparison.png/pdf")

def generate_summary_statistics(df):
    """Print summary statistics and save them to a text file.

    Prints overall counts, accuracy by benchmark and strategy, token
    efficiency by strategy, and a per-benchmark scaling table. The first
    three sections are also written to ``experimental_results_summary.txt``
    in the current working directory.

    The input frame is not modified. Rows whose accuracy is not positive are
    excluded from the token-efficiency figures rather than contributing a
    division-by-zero ``inf``.

    Args:
        df: Results frame with the columns ``benchmark``, ``strategy``,
            ``model``, ``chains``, ``accuracy`` and
            ``avg_tokens_per_question``.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE EXPERIMENTAL RESULTS SUMMARY")
    print("="*80)

    # Overall statistics
    print(f"Total Experiments: {len(df)}")
    print(f"Benchmarks: {df['benchmark'].unique()}")
    print(f"Strategies: {df['strategy'].unique()}")
    print(f"Models: {df['model'].unique()}")

    print("\n" + "-"*80)
    print("ACCURACY BY BENCHMARK AND STRATEGY")
    print("-"*80)

    # Accuracy by benchmark and strategy
    accuracy_summary = (
        df.groupby(['benchmark', 'strategy'])['accuracy']
        .agg(['mean', 'std', 'count'])
        .round(2)
    )
    print(accuracy_summary)

    print("\n" + "-"*80)
    print("TOKEN EFFICIENCY BY STRATEGY")
    print("-"*80)

    # Token efficiency by strategy, computed on a separate Series so the
    # caller's frame does not silently gain a column.
    accuracy = df['accuracy']
    tokens_per_correct = df['avg_tokens_per_question'] / (accuracy.where(accuracy > 0) / 100)
    efficiency_summary = (
        tokens_per_correct.groupby(df['strategy'])
        .agg(['mean', 'std'])
        .round(2)
    )
    print(efficiency_summary)

    print("\n" + "-"*80)
    print("SCALING ANALYSIS")
    print("-"*80)

    # Scaling analysis
    for benchmark in df['benchmark'].unique():
        benchmark_data = df[df['benchmark'] == benchmark]
        if len(benchmark_data) > 1:
            print(f"\n{benchmark}:")
            scaling_summary = benchmark_data.groupby(['strategy', 'chains'])['accuracy'].mean().round(2)
            print(scaling_summary)

    # Save summary to file (real newlines, explicit encoding).
    report_lines = [
        "COMPREHENSIVE EXPERIMENTAL RESULTS SUMMARY",
        "="*50,
        "",
        f"Total Experiments: {len(df)}",
        f"Benchmarks: {list(df['benchmark'].unique())}",
        f"Strategies: {list(df['strategy'].unique())}",
        f"Models: {list(df['model'].unique())}",
        "",
        "ACCURACY BY BENCHMARK AND STRATEGY",
        "-"*40,
        str(accuracy_summary),
        "",
        "TOKEN EFFICIENCY BY STRATEGY",
        "-"*30,
        str(efficiency_summary),
    ]
    with open('experimental_results_summary.txt', 'w', encoding='utf-8') as f:
        f.write("\n".join(report_lines) + "\n")

    print("\nSaved: experimental_results_summary.txt")

def main():
    """Run the full analysis: collect results, plot, and summarise.

    Collects every JSON report below the current directory, drops rows
    without a positive accuracy or chain count, then generates the three
    figures and the text summary. Returns early with a message when no usable
    results are available.
    """
    print("🔍 COMPREHENSIVE ANALYSIS OF TEST-TIME SCALING EXPERIMENTS")
    print("="*60)

    # Collect all results
    print("\n📊 Collecting experimental results...")
    df = collect_all_results()

    if df.empty:
        print("❌ No experimental results found. Please ensure JSON result files are in the current directory or subdirectories.")
        return

    print(f"✅ Found {len(df)} experimental results")

    # The accuracy column only exists if at least one report contained one.
    if 'accuracy' not in df.columns:
        print("❌ None of the result files contains an accuracy value; nothing to analyse.")
        return

    # Filter out invalid data. Non-numeric entries become NaN and are dropped
    # by the comparisons below; .copy() detaches the result from the
    # unfiltered frame.
    df['accuracy'] = pd.to_numeric(df['accuracy'], errors='coerce')
    df['chains'] = pd.to_numeric(df['chains'], errors='coerce')
    df = df[(df['accuracy'] > 0) & (df['chains'] > 0)].copy()

    if df.empty:
        print("❌ No results with a positive accuracy and chain count remain after filtering.")
        return

    # Reports without an experiment summary carry no token statistics; give
    # the downstream token-efficiency code an empty column to work with.
    if 'avg_tokens_per_question' not in df.columns:
        df['avg_tokens_per_question'] = float('nan')

    print("📈 Generating analysis plots and statistics...")

    # Generate plots
    create_accuracy_comparison_plot(df)
    create_token_efficiency_plot(df)
    create_model_comparison_plot(df)

    # Generate summary statistics
    generate_summary_statistics(df)

    print("\n🎉 Analysis complete! Generated files:")
    print("  - accuracy_comparison_by_benchmark.png/pdf")
    print("  - token_efficiency_analysis.png/pdf")
    print("  - model_comparison.png/pdf")
    print("  - experimental_results_summary.txt")


if __name__ == "__main__":
    main()