import glob
import json
import os
from typing import Any, Dict, List, Optional, Tuple


def extract_planner_plan(history_data: Dict[str, List]) -> List[str]:
    """Collect the planner's assistant messages from a history record.

    Reads the ``"planner"`` conversation, treating each entry as a
    ``(role, content)`` pair, and returns the content of every entry whose
    role is ``"assistant"`` and whose position in the conversation is
    greater than 2. All matching plans are returned in conversation order,
    not just the most recent one. A missing ``"planner"`` key yields an
    empty list.
    """
    conversation = history_data.get("planner", [])
    # Positions 0-2 are treated as the initial examples and never reported.
    return [
        turn[1]
        for position, turn in enumerate(conversation)
        if turn[0] == "assistant" and position > 2
    ]

def extract_agent_actions(result_data: Dict[str, Any], num_agents: int) -> List[Tuple[str, ...]]:
    """Extract per-step agent actions from the result data.

    Only the first experiment stored under the ``"1.0"`` key is read, and
    within it only the first element of ``"action_history"``.

    Args:
        result_data: Parsed result JSON.
        num_agents: Number of action slots to keep for every step.

    Returns:
        One tuple per step, shaped
        ``(agent0_action, agent1_action, ...)``. Steps carrying more than
        ``num_agents`` actions are truncated; steps carrying fewer are
        padded with empty strings. An empty list is returned when the
        ``"1.0"`` entry is missing or empty, or when the stored action
        history itself is empty.
    """
    experiments = result_data.get("1.0")
    if not experiments:
        return []

    action_history = experiments[0].get("action_history", [[]])
    # A present-but-empty (or null) history has no first element to read;
    # treat it as "no steps" rather than failing on the [0] lookup below.
    if not action_history:
        return []

    agent_actions = []
    for step in action_history[0]:
        actions = list(step[:num_agents])
        # Pad short steps so every tuple exposes one slot per agent.
        actions.extend([""] * (num_agents - len(actions)))
        agent_actions.append(tuple(actions))
    return agent_actions

def extract_completion_data(result_data: Dict[str, Any]) -> Tuple[List[bool], List[List[str]], List[str]]:
    """Pull the completion-related fields out of the first experiment.

    Looks at the first element stored under the ``"1.0"`` key and returns
    a triple ``(dish_completion, all_orders_list, accomplished_tasks)``:
    the per-step completion flags, the orders listed for each step, and
    the tasks accomplished overall. Any field missing from the experiment
    defaults to an empty list, and three empty lists are returned when the
    ``"1.0"`` entry is absent or empty.
    """
    if "1.0" not in result_data or not result_data["1.0"]:
        return [], [], []

    first_run = result_data["1.0"][0]
    return (
        first_run.get("dish_completion", []),
        first_run.get("all_orders_list", []),
        # The key is looked up with this exact spelling; presumably it
        # mirrors the name used by whatever writes the result files.
        first_run.get("acomplished_task_list", []),
    )

def process_level_data(level: int, num_agents: int) -> List[Tuple[Any, ...]]:
    """Build the step-by-step history log for one level.

    Loads the result file matching
    ``results/gpt-4o-v2/planner/<num_agents>/result_level_<level>_*.json``
    and the history file
    ``logs/gpt-4o-v2/planner/<num_agents>/history_level_<level>.json``
    (both relative to the current working directory), then combines the
    planner plans, agent actions and dish-completion data into one record
    per step.

    Args:
        level: Level number used in the file names.
        num_agents: Number of agents; selects the directory and fixes how
            many action slots every record carries.

    Returns:
        A list with one tuple per step, shaped
        ``(plan, agent0_action, ..., agentN_action, completed_orders,
        completed_this_step)``. ``completed_orders`` is a per-step snapshot
        (a fresh list) of the orders recognised as completed so far. An
        empty list is returned, after printing a message, when no result
        file matches or when either file cannot be read or parsed.
    """
    # Find the result file with the timestamp
    result_pattern = f"results/gpt-4o-v2/planner/{num_agents}/result_level_{level}_*.json"
    result_files = glob.glob(result_pattern)

    if not result_files:
        print(f"Result file not found for level {level} with {num_agents} agents")
        return []

    # glob returns matches in arbitrary order; sort so the same file is
    # selected on every run when several timestamped results exist.
    result_file = sorted(result_files)[0]
    history_file = f"logs/gpt-4o-v2/planner/{num_agents}/history_level_{level}.json"

    # Load data. OSError covers missing/unreadable files; ValueError covers
    # malformed JSON and undecodable bytes.
    try:
        with open(result_file, 'r', encoding='utf-8') as f:
            result_data = json.load(f)

        with open(history_file, 'r', encoding='utf-8') as f:
            history_data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error loading data for level {level} with {num_agents} agents: {e}")
        return []

    # Extract data
    plans = extract_planner_plan(history_data)
    agent_actions = extract_agent_actions(result_data, num_agents)
    dish_completion, all_orders_list, accomplished_tasks = extract_completion_data(result_data)

    history_log = []

    # Orders recognised as completed up to the current step, in the order
    # they were detected.
    completed_orders_so_far = []

    current_plan_idx = 0
    last_plan = plans[0] if plans else ""

    for i, actions in enumerate(agent_actions):
        # Heuristic: assume a new plan is issued roughly every 3 steps and
        # keep the last plan once they run out.
        if i > 0 and i % 3 == 0 and current_plan_idx < len(plans) - 1:
            current_plan_idx += 1
            last_plan = plans[current_plan_idx]

        # Steps beyond the recorded completion flags count as "not completed".
        completed_this_step = dish_completion[i] if i < len(dish_completion) else False

        # Both all_orders_list[i - 1] and all_orders_list[i] must exist to
        # compare the orders before and after this step; that holds for
        # every 0 < i < len(all_orders_list), including the last index.
        if completed_this_step and 0 < i < len(all_orders_list):
            after_orders = set(all_orders_list[i])
            # Walk the "before" list itself (not a set) so the order in
            # which dishes are recorded is reproducible between runs.
            for dish in all_orders_list[i - 1]:
                if (
                    dish in accomplished_tasks
                    and dish not in after_orders
                    and dish not in completed_orders_so_far
                ):
                    completed_orders_so_far.append(dish)

        # One slot per agent, padded with "" if a step has fewer actions.
        agent_slots = [actions[j] if j < len(actions) else "" for j in range(num_agents)]

        # Snapshot the completed list so later steps do not mutate this record.
        history_log.append(
            (last_plan, *agent_slots, list(completed_orders_so_far), completed_this_step)
        )

    return history_log

def generate_granular_history(num_agents: int) -> Dict[int, List[Tuple[str, str, str, str, List[str], bool]]]:
    """Assemble the history logs of every level for one agent count.

    Calls ``process_level_data`` for levels 0 through 12, in order, and
    returns a mapping from level number to its history log. Levels whose
    log comes back empty are left out of the mapping.
    """
    level_count = 13  # levels 0-12, based on the files
    logs_by_level = (
        (level, process_level_data(level, num_agents))
        for level in range(level_count)
    )
    return {level: log for level, log in logs_by_level if log}

def save_granular_history(granular_history: Dict[int, List[Tuple[str, str, str, str, List[str], bool]]],
                         num_agents: int):
    """Write the granular history to a JSON file and report where it went.

    The file is named ``granular_history_<num_agents>agent.json`` and is
    created in (or overwritten in) the current working directory, using a
    2-space indent. Because the data goes through ``json.dump``, integer
    level keys are written as JSON strings and tuples as JSON arrays.
    """
    destination = f"granular_history_{num_agents}agent.json"
    with open(destination, 'w') as handle:
        json.dump(granular_history, handle, indent=2)

    print(f"Granular history for {num_agents} agents saved to {destination}")

def print_sample(granular_history: Dict[int, List[Tuple[str, str, str, str, List[str], bool]]], num_agents: int):
    """Print a short preview of the level 1 history.

    Shows the first five steps of level 1 and then, if one exists, the
    first step from the sixth onward whose completion flag is truthy.
    Nothing is printed when level 1 is missing or has no entries. Each
    plan is cut to its first 100 characters and always followed by
    ``...``.
    """
    level_one = granular_history.get(1)
    if not level_one:
        return

    def show_details(record):
        # Shared body of a step report: plan, one line per agent, then
        # the two trailing completion fields.
        print(f"  Plan: {record[0][:100]}...")
        for agent_idx in range(num_agents):
            print(f"  Agent {agent_idx}: {record[agent_idx + 1]}")
        print(f"  Completed Orders So Far: {record[-2]}")
        print(f"  Order Completed This Step: {record[-1]}")

    print(f"\nSample from Level 1 ({num_agents} agents):")
    for step_no, record in enumerate(level_one[:5], start=1):
        print(f"Step {step_no}:")
        show_details(record)
        print()

    # Steps 1-5 were already shown above, so start looking at step 6.
    for step_no, record in enumerate(level_one[5:], start=6):
        if record[-1]:
            print(f"\nCompleted Order Step (Step {step_no}):")
            show_details(record)
            break

if __name__ == "__main__":
    # For each supported team size (1, 2 and 3 agents): build the history,
    # write it to disk, then print a short preview of level 1.
    for agent_count in (1, 2, 3):
        print(f"\nProcessing {agent_count}-agent data...")
        history_by_level = generate_granular_history(agent_count)
        save_granular_history(history_by_level, agent_count)
        print_sample(history_by_level, agent_count)