import json
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import glob

def _count_completed_dishes(result_file):
    """Return the number of truthy ``dish_completion`` entries in one result file."""
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(result_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return sum(1 for completion in data["1.0"][0]["dish_completion"] if completion)


def load_experiment_data(exp_dir, num_agents):
    """Return the average number of completed orders per level for one agent count.

    Result files are read from
    ``<exp_dir>/results/gpt-4o-v2/planner/<num_agents>/result_level_*.json``.
    For each file, the truthy entries of ``data["1.0"][0]["dish_completion"]``
    are counted; the counts are averaged over the files that could be read.

    Args:
        exp_dir: Experiment root directory (``str`` or path-like).
        num_agents: Agent count; used as the name of the innermost directory.

    Returns:
        The mean completed-order count over successfully processed files, or
        ``0`` when the results directory is missing, holds no matching files,
        or none of the files could be processed.

    Unreadable or malformed files are reported on stdout and skipped rather
    than aborting the whole aggregation.
    """
    results_dir = Path(exp_dir) / "results" / "gpt-4o-v2" / "planner" / str(num_agents)
    if not results_dir.exists():
        print(f"No results directory found at {results_dir}")
        return 0

    # Path.glob applies the pattern to file names only, so glob metacharacters
    # in the directory path (e.g. "[" or "*") are taken literally. Sorting makes
    # the processing and log order deterministic.
    result_files = sorted(results_dir.glob("result_level_*.json"))
    if not result_files:
        print(f"No result files found in {results_dir}")
        return 0

    total_completed = 0
    num_levels = 0
    for result_file in result_files:
        try:
            completed = _count_completed_dishes(result_file)
        except (OSError, ValueError, KeyError, IndexError, TypeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON or bad encoding;
            # KeyError/IndexError/TypeError: JSON does not have the expected shape.
            print(f"Error processing {result_file}: {e}")
            continue
        total_completed += completed
        num_levels += 1

    # Only files that were processed successfully count towards the average.
    return total_completed / num_levels if num_levels > 0 else 0

def create_comparison_plot():
    """Plot average completed orders against agent count for each experiment.

    For every configured experiment directory, the per-level average returned
    by ``load_experiment_data`` is collected for 1, 2 and 3 agents, printed,
    and drawn as one coloured line. The figure is saved as
    ``experiment_comparison.png`` (relative to the current working directory)
    and then closed.
    """
    agent_counts = [1, 2, 3]

    # One row per plotted line: (legend label, experiment directory, line colour)
    series = [
        ('Baseline', 'experiment2_v2_baseline', 'blue'),
        ('Success Rates', 'experiment2_v2_success_rates', 'green'),
        ('Historical', 'experiment2_v2_historical', 'red'),
    ]

    plt.figure(figsize=(10, 6))

    for label, directory, line_color in series:
        averages = []
        for count in agent_counts:
            # Load and report one point at a time so the summary line directly
            # follows any messages emitted while loading that point.
            average = load_experiment_data(directory, count)
            averages.append(average)
            print(f"{label} - {count} agents: {average:.2f} orders per level")

        plt.plot(
            agent_counts,
            averages,
            color=line_color,
            label=label,
            marker='o',
            markersize=8,
            linewidth=2,
        )

    # Axis decoration
    plt.xticks(agent_counts)
    plt.xlabel('Number of Agents', fontsize=12)
    plt.ylabel('Average Orders Completed per Level', fontsize=12)
    plt.title('Average Orders Completed vs Number of Agents', fontsize=14)
    plt.grid(True)
    plt.legend()

    # Write the figure to disk and release it
    plt.tight_layout()
    plt.savefig('experiment_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

# Build the comparison chart only when executed as a script, not on import.
if __name__ == "__main__":
    create_comparison_plot()