import json
import os
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def _executor_names(n_agents):
    """Return the executor model display names for an ``n_agents`` run.

    Single-agent runs use only Llama; runs with two or more agents use the
    fixed three-model list, of which callers use the first ``n_agents`` entries.

    Raises:
        ValueError: if ``n_agents`` exceeds the number of known executor models.
    """
    if n_agents == 1:
        names = ['Llama-3.1-70B']  # For 1 agent, only Llama
    else:
        names = ['gpt-4o-mini', 'Llama-3.1-70B', 'Qwen2.5-32B-Instruct']  # Keep original order for 2+ agents
    if n_agents > len(names):
        raise ValueError(
            f"n_agents={n_agents} exceeds the number of known executor models ({len(names)})"
        )
    return names


def _load_level_rows(results_path, n_agents, executor_names):
    """Build one summary row per level (0-12) from the JSON files in ``results_path``.

    For each level the alphabetically first file named ``result_level_<level>_*``
    is read; levels without such a file are skipped. The file is expected to be
    a dict whose first value is a list of runs; only the first run is used.

    Returns:
        list of dicts with keys 'Level', 'Orders Completed', 'Total Cost ($)',
        'Planner Cost (gpt-4o-v2) ($)' and one 'Executor <i> (<model>) ($)' key
        per agent. Costs are rounded to 4 decimals.

    Raises:
        FileNotFoundError: if ``results_path`` does not exist.
    """
    # Listed once and sorted so the file picked for a level is deterministic
    # (os.listdir returns entries in arbitrary order).
    file_names = sorted(os.listdir(results_path))
    rows = []
    for level in range(13):
        # The trailing underscore keeps level 1 from matching "result_level_10_...".
        prefix = f"result_level_{level}_"
        level_files = [name for name in file_names if name.startswith(prefix)]
        if not level_files:
            continue

        with open(os.path.join(results_path, level_files[0]), 'r', encoding='utf-8') as f:
            result = json.load(f)

        # Get the first (and only) key-value pair, then the first run in it
        run_data = list(result.values())[0][0]

        # 'acomplished_task_list' (sic) is the key name used by the result files
        orders_completed = len(run_data['acomplished_task_list'])
        final_cost = run_data['price']

        # Each cost list is summed into a run total; every entry of
        # 'executor_costs' is indexed by agent position.
        planner_cost = sum(run_data['planner_costs'])
        executor_costs = [
            sum(entry[i] for entry in run_data['executor_costs'])
            for i in range(n_agents)
        ]

        row = {
            'Level': level,
            'Orders Completed': orders_completed,
            'Total Cost ($)': round(final_cost, 4),
            'Planner Cost (gpt-4o-v2) ($)': round(planner_cost, 4),
        }
        for i in range(n_agents):
            row[f'Executor {i+1} ({executor_names[i]}) ($)'] = round(executor_costs[i], 4)
        rows.append(row)
    return rows


def _save_table_figure(rows, n_agents, executor_names, output_path):
    """Render the per-level ``rows`` as a matplotlib table image at ``output_path``."""
    fig = plt.figure(figsize=(24, len(rows) * 0.5 + 2))  # Wide; grows taller with more rows
    ax = fig.add_subplot(111)

    # Hide axes so only the table is visible
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

    column_labels = ['Step', 'Orders Completed', 'Total Cost ($)',
                     'Total Planner (gpt-4o-v2) Cost ($)']
    column_labels += [
        f'Total Executor {i+1} ({executor_names[i]}) Cost ($)' for i in range(n_agents)
    ]

    # Cells are formatted from the row dicts rather than DataFrame.iterrows(),
    # which upcasts the integer columns to float ("Level 0.0", "3.0").
    table_data = []
    for row in rows:
        cells = [
            f"Level {row['Level']}",
            str(row['Orders Completed']),
            f"{row['Total Cost ($)']:.4f}",
            f"{row['Planner Cost (gpt-4o-v2) ($)']:.4f}",
        ]
        cells += [
            f"{row[f'Executor {i+1} ({executor_names[i]}) ($)']:.4f}" for i in range(n_agents)
        ]
        table_data.append(cells)

    n_cols = len(column_labels)
    # Step, Orders Completed, Total Cost, Planner Cost, then one width per executor
    col_widths = [0.08, 0.08, 0.12, 0.18] + [0.18] * (n_cols - 4)

    table = ax.table(
        cellText=table_data,
        colLabels=column_labels,
        cellLoc='center',
        loc='center',
        colWidths=col_widths
    )

    # Row 0 is the header row; data rows start at 1
    for j in range(n_cols):
        table[(0, j)].set_facecolor('#e6e6e6')
    for i in range(1, len(table_data) + 1):
        for j in range(n_cols):
            table[(i, j)].set_facecolor('#f2f2f2')

    # Fixed font size and enlarged cells for readability
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.4, 1.8)

    fig.savefig(output_path, bbox_inches='tight', dpi=300, pad_inches=0.5)
    plt.close(fig)


def _save_scatter_figure(df, n_agents, output_path):
    """Plot total cost against orders completed, one labelled point per level."""
    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(df['Orders Completed'], df['Total Cost ($)'], alpha=0.6, s=150)

    # Iterating the columns directly keeps 'Level' integral (iterrows would yield floats)
    for level, orders, cost in zip(df['Level'], df['Orders Completed'], df['Total Cost ($)']):
        ax.annotate(f"Level {level}",
                    (orders, cost),
                    xytext=(10, 10), textcoords='offset points',
                    fontsize=12)

    ax.set_xlabel('Orders Completed', fontsize=14)
    ax.set_ylabel('Total Cost ($)', fontsize=14)
    ax.set_title(f'Cost vs Performance Across Levels ({n_agents} Agents, GPT-4o-v2)', fontsize=16, pad=20)
    ax.grid(True, alpha=0.3)

    # Extra padding beyond the automatic margins so labels are not clipped
    ax.margins(0.2)
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    ax.set_xlim(x_min - 0.5, x_max + 0.5)
    ax.set_ylim(y_min - 0.05, y_max + 0.05)

    fig.savefig(output_path, bbox_inches='tight', dpi=300, pad_inches=0.5)
    plt.close(fig)


def analyze_results(n_agents, results_subdir):
    """Summarise per-level results for one agent configuration and save two figures.

    Reads ``results/gpt-4o-v2/planner/<results_subdir>/result_level_<N>_*`` for
    levels 0-12. Writes a cost table (``level_comparison_table_<n>agents_gpt4ov2.png``)
    and a cost-vs-orders scatter plot (``cost_vs_orders_<n>agents_gpt4ov2.png``)
    into ``figures/<n_agents>agentruns/``.

    Args:
        n_agents: number of executor agents in the run (at most 3).
        results_subdir: name of the results subdirectory to read.

    Returns:
        pandas.DataFrame with one row per level found, with columns 'Level',
        'Orders Completed', 'Total Cost ($)', 'Planner Cost (gpt-4o-v2) ($)' and
        one 'Executor <i> (<model>) ($)' column per agent.

    Raises:
        ValueError: if ``n_agents`` exceeds the number of known executor models.
        FileNotFoundError: if the results directory is missing or contains no
            level result files.
    """
    # Validate up front so a bad n_agents fails before any I/O
    executor_names = _executor_names(n_agents)

    results_path = f"results/gpt-4o-v2/planner/{results_subdir}"
    rows = _load_level_rows(results_path, n_agents, executor_names)
    if not rows:
        raise FileNotFoundError(
            f"No result_level_<N>_* files for levels 0-12 found in {results_path}"
        )

    # Create output directory for agent-specific figures
    output_dir = Path('figures') / f'{n_agents}agentruns'
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(rows)

    _save_table_figure(
        rows, n_agents, executor_names,
        output_dir / f'level_comparison_table_{n_agents}agents_gpt4ov2.png',
    )
    _save_scatter_figure(
        df, n_agents,
        output_dir / f'cost_vs_orders_{n_agents}agents_gpt4ov2.png',
    )

    return df

# Create necessary directories (analyze_results also creates its own output directory)
Path('figures').mkdir(exist_ok=True)
for n in [1, 2, 3]:
    (Path('figures') / f'{n}agentruns').mkdir(exist_ok=True)

# Analyze 1, 2, and 3-agent results and report the files each run wrote.
# analyze_results names its outputs '<n>agents' for every n, including n == 1.
agent_dfs = {}
for n in (1, 2, 3):
    agent_dfs[n] = analyze_results(n, str(n))
    print(f"\nGenerated {n}-agent figures in figures/{n}agentruns/:")
    print(f"1. Table figure: level_comparison_table_{n}agents_gpt4ov2.png")
    print(f"2. Scatter plot: cost_vs_orders_{n}agents_gpt4ov2.png")
df_1agent, df_2agents, df_3agents = agent_dfs[1], agent_dfs[2], agent_dfs[3]

# Each entry: (summary DataFrame, legend label, point-label prefix, label offset in points).
# Offsets differ per agent count so labels on overlapping points stay readable.
comparison_series = [
    (df_1agent, '1 Agent (Llama-only)', '1A-L', (12, 12)),
    (df_2agents, '2 Agents (4o-mini + Llama)', '2A-L', (12, -12)),
    (df_3agents, '3 Agents (4o-mini + Llama + Qwen)', '3A-L', (-12, 12)),
]

# Create comparison plot
comparison_fig = plt.figure(figsize=(24, 18))
for df, label, _, _ in comparison_series:
    plt.scatter(df['Orders Completed'], df['Total Cost ($)'], alpha=0.7, s=250, label=label)

# Label each point with its agent count and level. The columns are iterated
# directly so 'Level' stays an integer (iterrows would give "1A-L0.0").
for df, _, prefix, offset in comparison_series:
    for level, orders, cost in zip(df['Level'], df['Orders Completed'], df['Total Cost ($)']):
        plt.annotate(f"{prefix}{level}",
                     (orders, cost),
                     xytext=offset, textcoords='offset points',
                     fontsize=14,
                     bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))  # White background behind labels

plt.xlabel('Orders Completed', fontsize=16)
plt.ylabel('Total Cost ($)', fontsize=16)
plt.title('Cost vs Performance Comparison (1 vs 2 vs 3 Agents, GPT-4o-v2)', fontsize=18, pad=20)

# Customize grid
plt.grid(True, alpha=0.3, linestyle='--')

# Legend placed outside the axes, to the right
plt.legend(fontsize=14, bbox_to_anchor=(1.05, 1), loc='upper left')

# Extra padding beyond the automatic margins so labels are not clipped
plt.margins(0.25)
x_min, x_max = plt.xlim()
y_min, y_max = plt.ylim()
plt.xlim(x_min - 1, x_max + 1)
plt.ylim(y_min - 0.1, y_max + 0.1)

# Make tick labels larger
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

# Save comparison plot in root figures directory
comparison_fig.savefig('figures/cost_vs_orders_comparison_all_gpt4ov2.png', bbox_inches='tight', dpi=300, pad_inches=0.7)
plt.close(comparison_fig)

print("\nGenerated comparison figure in figures/:")
print("1. Comparison plot: cost_vs_orders_comparison_all_gpt4ov2.png")