import json
import os
from typing import Any, Dict, List, Optional, Tuple


_DEFAULT_BASE_PATH = (
    "/Users/ashwin/localdesktop/Economies-COLM2025/economy-of-minds/"
    "experiment2_pass3/experiment2_v2_baseline"
)


def _load_json(path: str) -> Any:
    """Read and parse a JSON file decoded as UTF-8."""
    # Explicit encoding: transcripts may contain non-ASCII text, and the
    # platform default encoding (e.g. cp1252 on Windows) would misdecode it.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _find_result_file(result_dir: str, level_num: int) -> Optional[str]:
    """Return the path of the first (name-sorted) result file for a level, or None."""
    # The trailing underscore keeps level 1 from matching level 12, 13, ...
    prefix = f"result_level_{level_num}_"
    # os.listdir order is arbitrary; sort so the same file is chosen on every
    # run when several result files exist for one level.
    matches = sorted(name for name in os.listdir(result_dir) if name.startswith(prefix))
    return os.path.join(result_dir, matches[0]) if matches else None


def _build_trajectory(
    planner_outputs: List[str],
    action_history: List[Any],
    planner_costs: List[Any],
) -> List[Tuple[str, str, str, str]]:
    """Pair each step's agent actions with the planner output in effect at that step."""
    # Steps at which the planner was invoked (non-zero cost). A set makes the
    # per-step membership test O(1) instead of a scan over a list.
    planner_call_steps = {i for i, cost in enumerate(planner_costs) if cost > 0}
    last_output_idx = len(planner_outputs) - 1

    trajectory = []
    planner_idx = 0
    current_output = planner_outputs[0] if planner_outputs else ""
    for step_idx, actions in enumerate(action_history):
        # The first output is active from step 0; each later planner call
        # advances to the next output, never past the last one.
        # NOTE(review): if a planner call is ever recorded at step 0, output 0
        # is replaced by output 1 at that very step -- confirm this matches how
        # the planner log and planner_costs are aligned.
        if step_idx in planner_call_steps and planner_idx < last_output_idx:
            planner_idx += 1
            current_output = planner_outputs[planner_idx]
        # Pad to three agent slots with "none" so every row has the same width.
        padded = actions + ["none"] * (3 - len(actions))
        trajectory.append((current_output, *padded))
    return trajectory


def process_level(
    level_num: int,
    num_agents: int,
    base_path: str = _DEFAULT_BASE_PATH,
) -> List[Tuple[str, str, str, str]]:
    """
    Process a single level's log and result files to create granular trajectory data.

    Args:
        level_num: The level number to process.
        num_agents: Number of agents (1, 2, or 3); selects the log/result subdirectory.
        base_path: Experiment root that contains the ``logs/`` and ``results/``
            trees. Defaults to the original hard-coded experiment directory.

    Returns:
        List of tuples containing (planner_output, agent0_act, agent1_act, agent2_act),
        one per step of the action history read from the result file. Missing
        agent slots are filled with "none". Returns [] (after printing a
        message) when no result file exists for the level.

    Raises:
        OSError: if the log file or the result directory cannot be read.
        ValueError: if a file is not valid JSON.
        KeyError / IndexError: if the JSON lacks the expected structure.
    """
    log_data = _load_json(
        f"{base_path}/logs/gpt-4o-v2/planner/{num_agents}/history_level_{level_num}.json"
    )
    # Every assistant turn, in order (entries appear to be [role, content]
    # pairs). Nothing is skipped; the first output is in effect from step 0.
    planner_outputs = [entry[1] for entry in log_data["planner"] if entry[0] == "assistant"]

    result_dir = f"{base_path}/results/gpt-4o-v2/planner/{num_agents}"
    result_path = _find_result_file(result_dir, level_num)
    if result_path is None:
        print(f"No result file found for level {level_num}")
        return []

    run = _load_json(result_path)["1.0"][0]
    # planner_costs is indexed by step; a positive cost marks a planner call.
    return _build_trajectory(planner_outputs, run["action_history"][0], run["planner_costs"])

def main() -> None:
    """
    Write one granular-trajectory JSON file per agent count (1, 2 and 3).

    For each agent count, levels 0 through 12 are run through process_level.
    Levels that yield no steps are left out; levels that raise are reported
    and skipped. Each result is saved as granular_trajectory_<n>agent.json in
    the current working directory.
    """
    for agent_count in (1, 2, 3):
        per_level = {}

        for level in range(13):
            # Best-effort batch: one bad level must not abort the others.
            try:
                steps = process_level(level, agent_count)
            except Exception as err:
                print(f"Error processing level {level} ({agent_count} agents): {err}")
                continue
            if not steps:
                continue
            per_level[f'level_{level}'] = steps
            print(f"Processed level {level} ({agent_count} agents): {len(steps)} steps")

        out_name = f'granular_trajectory_{agent_count}agent.json'
        with open(out_name, 'w') as handle:
            json.dump(per_level, handle, indent=2)

        print(f"Successfully created {out_name}")

if __name__ == '__main__':
    # Run the batch conversion only when executed as a script, not on import.
    main()