import json
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

def analyze_trajectory(result_file):
    """Render a cost/completion summary table for one run and save it as a PNG.

    Reads the JSON results at ``result_file``, takes the first entry stored
    under the ``'1.0'`` key, samples it every five steps, appends a ``'Final'``
    row, draws the rows as a matplotlib table and writes the image to
    ``figures/cost_analysis_level<L>_<N>agents_<timestamp>.png`` (the
    ``figures`` directory is created if it is missing).

    Args:
        result_file: Path to a results JSON file. The level shown in the title
            and in the output filename is the text that follows ``level_`` in
            the file stem, up to the next underscore.

    Raises:
        OSError: If ``result_file`` cannot be opened.
        KeyError: If ``'1.0'`` or one of the expected per-run keys is missing.
    """
    with open(result_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Only the first run recorded under alpha "1.0" is analysed.
    alpha_data = data['1.0'][0]

    all_orders = alpha_data['all_orders_list']
    prices = alpha_data['prices']  # cumulative total cost at each step
    planner_costs = alpha_data['planner_costs']  # planner cost of each step
    executor_costs = alpha_data['executor_costs']  # per-step list of executor costs
    # NOTE(review): dish_completion is treated as one truthy/falsy entry per
    # step, as the original indexing did - confirm against the results writer.
    dish_completion = alpha_data['dish_completion']
    total_steps = alpha_data['stop_step']

    level_num = Path(result_file).stem.split('level_')[-1].split('_')[0]
    num_agents = len(executor_costs[0]) if executor_costs else 0  # executors only

    planner_model = "gpt-4o-v2"  # Planner always uses gpt-4o-v2
    executor_models = _executor_model_names(alpha_data, num_agents)

    # range() never yields its stop value, so the last step is always added.
    sample_steps = [*range(0, total_steps, 5), total_steps]

    trajectory_data = []
    for step in sample_steps:
        if step >= len(prices):
            break

        row_data = {
            'Step': step,
            'Orders in Queue': len(all_orders[step]) if step < len(all_orders) else 0,
            # Count from the start rather than adding a fixed 5-step window:
            # the last sampled step is usually not a multiple of 5, so a
            # window there overlaps the previous one and double-counts.
            'Orders Completed': sum(1 for done in dish_completion[:step] if done),
            'Total Cost ($)': round(prices[step], 4),
        }
        row_data.update(
            _cost_cells(planner_model, planner_costs, executor_models,
                        executor_costs, end=step + 1)
        )
        trajectory_data.append(row_data)

    final_row = {
        'Step': 'Final',
        'Orders in Queue': len(all_orders[-1]) if all_orders else 0,
        'Orders Completed': sum(1 for done in dish_completion if done),
        'Total Cost ($)': round(prices[-1], 4) if prices else 0,
    }
    final_row.update(
        _cost_cells(planner_model, planner_costs, executor_models, executor_costs)
    )
    trajectory_data.append(final_row)

    df = pd.DataFrame(trajectory_data)

    fig, ax = plt.subplots(figsize=(24, 10))
    ax.axis('tight')
    ax.axis('off')

    # Step, Orders in Queue, Orders Completed, Total Cost, Planner Cost, then
    # one wide column per executor (their headers carry the model name).
    col_widths = [0.07, 0.09, 0.09, 0.11, 0.15] + [0.15] * len(executor_models)

    table = ax.table(
        cellText=df.values,
        colLabels=df.columns,
        cellLoc='center',
        loc='center',
        colWidths=col_widths,
    )
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.2, 1.5)

    ax.set_title(
        f'Order Completion vs Cost Analysis - Level {level_num} ({num_agents} Agents)',
        pad=20,
    )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f'figures/cost_analysis_level{level_num}_{num_agents}agents_{timestamp}.png'

    # Do not rely on the caller having created the output directory.
    Path(output_filename).parent.mkdir(parents=True, exist_ok=True)
    try:
        fig.savefig(output_filename, bbox_inches='tight', dpi=300)
    finally:
        # Close this specific figure even if saving fails, so it is not leaked.
        plt.close(fig)
    print(f"Analysis saved to {output_filename}")


def _executor_model_names(alpha_data, num_agents):
    """Return exactly ``num_agents`` display names for the executor models."""
    if 'executor_models' in alpha_data:
        # Keep only the last path component of names such as "org/model".
        names = [model.split('/')[-1] for model in alpha_data['executor_models']]
    elif num_agents == 2:
        names = ['gpt-4o-mini', 'Llama-3.1-70B']
    else:
        names = ['gpt-4o-mini', 'Llama-3.1-70B', 'Qwen2.5-Coder-32B-Instruct']

    names = names[:num_agents]
    # Pad so that no executor's cost column is silently dropped when fewer
    # names than executors are known.
    names += ['unknown'] * (num_agents - len(names))
    return names


def _cost_cells(planner_model, planner_costs, executor_models, executor_costs, end=None):
    """Return planner/executor cost cells summed over steps ``[:end]``."""
    # Slicing clamps to the list length, so a step past the end of a cost list
    # yields the full total instead of resetting the cumulative column to 0.
    cells = {
        f'Total Planner ({planner_model}) Cost ($)': round(sum(planner_costs[:end]), 4)
    }
    for i, model in enumerate(executor_models):
        executor_total = sum(step_costs[i] for step_costs in executor_costs[:end])
        cells[f'Total Executor {i+1} ({model}) Cost ($)'] = round(executor_total, 4)
    return cells

if __name__ == '__main__':
    # Create figures directory if it doesn't exist
    Path('figures').mkdir(exist_ok=True)

    # Collect every result file anywhere below results/
    results_dir = Path('results')
    result_files = list(results_dir.rglob('result_level_*.json'))
    if not result_files:
        print("No result files found")
        # `exit` is a helper added by the `site` module for interactive use
        # and is absent under `python -S`; raising SystemExit always works
        # and yields the same exit status.
        raise SystemExit(1)

    # Analyse only the most recently modified result file
    latest_file = max(result_files, key=lambda x: x.stat().st_mtime)
    print(f"Analyzing latest file: {latest_file}")
    analyze_trajectory(str(latest_file))