import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple

from openai import AzureOpenAI

# Shared Azure OpenAI client, created once at import time.
# Settings come from environment variables; the API version and endpoint fall
# back to the literals below, while the API key has no fallback (it is None
# when AZURE_OPENAI_API_KEY is unset).
client = AzureOpenAI(
    azure_endpoint=os.environ.get(
        "AZURE_OPENAI_ENDPOINT", "https://llm-co-ncus.openai.azure.com/"
    ),
    api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
)

# Pricing table keyed by model name. Each entry holds the "input" and "output"
# rates that the cost calculation multiplies by token counts and then divides
# by 1e6, i.e. the rates are per million tokens (presumably USD, since costs
# are printed with a "$" prefix).
models_cost = {
    "gpt-4o-v2": dict(input=2.5, output=10),
}

def load_trajectory_data(file_path: str) -> Dict[str, List[Tuple[str, str, str, str]]]:
    """
    Load the granular trajectory data from the JSON file.

    The file is decoded as UTF-8 (the standard JSON interchange encoding)
    rather than with the platform's locale encoding, so non-ASCII text in
    planner outputs loads the same way on every operating system.

    Note that the JSON decoder returns arrays as lists, so each step is a
    list at runtime even though the annotation describes it as a tuple.

    Args:
        file_path: Path to the granular_trajectory.json file

    Returns:
        Dictionary containing trajectory data for each level

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _format_step(step_index: int, step_data: Tuple[str, ...], num_agents: int) -> str:
    """Render one trajectory step (planner line plus agent lines) for the prompt."""
    planner_output, *agent_actions = step_data
    # Long planner outputs are cut to 200 characters and marked as truncated.
    truncation_mark = '...' if len(planner_output) > 200 else ''
    lines = [
        f"Step {step_index}:\n",
        f"Planner: {planner_output[:200]}{truncation_mark}\n",
    ]
    lines.extend(
        f"Agent {agent_idx}: {action}\n"
        for agent_idx, action in enumerate(agent_actions[:num_agents])
    )
    lines.append("\n")
    return "".join(lines)


def prepare_prompt_for_level(level_name: str, trajectory: List[Tuple[str, str, str, str]], num_agents: int) -> str:
    """
    Prepare a prompt for GPT-4o to analyze agent behaviors for a specific level.

    Only the first 20 and the last 10 steps of the trajectory are included.
    When the trajectory is short enough that these two windows would overlap,
    each step is still emitted exactly once with its true index, and the
    "..." omission marker is written only if steps were actually left out.

    Args:
        level_name: The name of the level (e.g., 'level_0')
        trajectory: List of steps; each step holds the planner output followed
            by one action per agent
        num_agents: Number of agents in the run

    Returns:
        Formatted prompt for GPT-4o
    """
    head_count = 20
    tail_count = 10

    parts = [
        f"# Analysis for {level_name} ({num_agents} agents)\n\n",
        "You are an expert in analyzing multi-agent coordination and behavior in a kitchen simulation environment.\n\n",
        "Below is a trajectory of planner outputs and agent actions for a kitchen simulation level.\n",
        "Each entry contains:\n",
        "1. The planner's instructions\n",
    ]
    # Item 1 is the planner, so the agent items are numbered from 2 upwards.
    parts.extend(f"{i + 2}. Agent {i}'s action\n" for i in range(num_agents))
    parts.append("\nTrajectory (limited to first 20 and last 10 steps for brevity):\n\n")

    total_steps = len(trajectory)
    head_end = min(head_count, total_steps)
    # Never start the tail before the end of the head: this prevents duplicate
    # steps and negative step labels for trajectories shorter than 30 steps.
    tail_start = max(total_steps - tail_count, head_end)

    parts.extend(
        _format_step(i, trajectory[i], num_agents) for i in range(head_end)
    )

    if tail_start > head_end:
        # Steps between the head and the tail were skipped.
        parts.append("...\n\n")

    parts.extend(
        _format_step(i, trajectory[i], num_agents)
        for i in range(tail_start, total_steps)
    )

    parts.append("Based on this trajectory, provide a VERY CONCISE analysis. For each agent, write 1-2 sentences on:\n")
    parts.append("1. Behavior/Strategy: How did they respond to the planner's instructions?\n")
    parts.append("2. Ability: How well did they perform their assigned tasks?\n")
    parts.append("3. Performance: How efficient were they?\n\n")

    parts.append("Format your response as:\n")
    parts.extend(
        f"AGENT {i}: [concise summary of behavior, ability, and performance]\n"
        for i in range(num_agents)
    )
    parts.append("OVERALL: [1-2 sentence team assessment]\n\n")

    parts.append("Keep your entire response under 200 words. Be direct and specific.")

    return "".join(parts)

def query_gpt4o(prompt: str, model: str = "gpt-4o-v2", max_tokens: int = 300) -> Tuple[str, float]:
    """
    Query GPT-4o with the given prompt.

    The prompt is sent as a single user message through the module-level
    client. The price is computed from the token usage reported in the
    response and the per-million-token rates stored in ``models_cost``.

    Args:
        prompt: The prompt for GPT-4o
        model: The model to use; it is sent to the API and is also the key
            used to look up pricing in ``models_cost``
        max_tokens: Maximum number of tokens in the response

    Returns:
        Tuple containing the response and the price of the API call

    Raises:
        KeyError: If ``model`` has no entry in ``models_cost``.
    """
    # Resolve pricing first so an unknown model fails before a billable request.
    rates = models_cost[model]

    response = client.chat.completions.create(
        # Honour the caller's choice so the request and the pricing lookup
        # always refer to the same model.
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        max_tokens=max_tokens,
    )

    usage = response.usage
    # Rates are expressed per one million tokens.
    price = (
        usage.prompt_tokens * rates['input']
        + usage.completion_tokens * rates['output']
    ) / 1e6

    return response.choices[0].message.content, price

def analyze_levels(trajectory_data: Dict[str, List[Tuple[str, str, str, str]]],
                  num_agents: int,
                  levels_to_analyze: Optional[List[str]] = None) -> Dict[str, str]:
    """
    Analyze specific levels or all levels in the trajectory data.

    For every requested level a prompt is built, sent to the model, and the
    returned text is stored under the level name. Progress, the cost of each
    call, and the accumulated total cost are printed.

    Args:
        trajectory_data: Dictionary containing trajectory data for each level
        num_agents: Number of agents in the run
        levels_to_analyze: List of level names to analyze, or None to analyze all

    Returns:
        Dictionary containing the analysis for each level; requested levels
        that are missing from trajectory_data are skipped and not included
    """
    # None means "every level present in the data", in the data's own order.
    requested = list(trajectory_data.keys()) if levels_to_analyze is None else levels_to_analyze

    analyses: Dict[str, str] = {}
    total_cost = 0

    for level_name in requested:
        trajectory = trajectory_data.get(level_name)
        if level_name not in trajectory_data:
            print(f"Level {level_name} not found in trajectory data, skipping...")
            continue

        print(f"Analyzing {level_name}...")
        response, cost = query_gpt4o(
            prepare_prompt_for_level(level_name, trajectory, num_agents)
        )
        total_cost += cost
        analyses[level_name] = response
        print(f"Completed analysis for {level_name}. Cost: ${cost:.6f}")

        # Sleep to avoid hitting rate limits
        time.sleep(1)

    print(f"Total cost: ${total_cost:.6f}")
    return analyses

def save_analyses(analyses: Dict[str, str], output_file: str) -> None:
    """
    Write the per-level analyses to disk as indented JSON.

    Args:
        analyses: Mapping from level name to its analysis text
        output_file: Destination path; an existing file is overwritten
    """
    with open(output_file, mode='w') as handle:
        # Two-space indentation keeps the saved file human-readable.
        json.dump(analyses, fp=handle, indent=2)

    print(f"Analysis saved to {output_file}")

def generate_summary_report(analyses: Dict[str, str], num_agents: int, output_file: str) -> None:
    """
    Generate a concise summary report of all analyses.

    All per-level analyses are combined into one prompt, the model is asked
    for a short cross-level summary (up to 400 response tokens), and the
    returned text is written to ``output_file``. The report is written as
    UTF-8 because model output may contain characters that the platform's
    default encoding cannot represent.

    Args:
        analyses: Dictionary containing the analysis for each level
        num_agents: Number of agents in the run
        output_file: Path to the output file
    """
    sections = [
        f"Create a BRIEF summary of {num_agents}-agent performance across multiple kitchen simulation levels.\n\n",
        "Here are the individual agent analyses for each level:\n\n",
    ]
    sections.extend(
        f"{level_name}: {analysis}\n\n" for level_name, analysis in analyses.items()
    )
    sections.extend([
        "Provide a concise summary (max 250 words) that covers:\n",
        "1. Overall patterns in agent behavior and performance\n",
        "2. Key strengths and weaknesses of each agent\n",
        "3. How agent performance varied across levels\n",
        "4. Brief recommendations for improvement\n\n",
        "Format as short paragraphs with clear headings.",
    ])
    prompt = "".join(sections)

    print("Generating summary report...")
    response, cost = query_gpt4o(prompt, max_tokens=400)
    print(f"Summary report generated. Cost: ${cost:.6f}")

    # Explicit UTF-8 avoids UnicodeEncodeError on platforms whose default
    # encoding cannot represent every character in the model's reply.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(response)
    print(f"Summary report saved to {output_file}")

def main() -> None:
    """
    Run the full analysis pipeline for the 1-agent and then the 2-agent runs.

    For each configuration the trajectory file is looked up in the current
    working directory; when it is missing the configuration is skipped.
    Otherwise the levels are analyzed, the per-level analyses are saved as
    JSON, and a summary report is written as Markdown.
    """
    for agent_count in (1, 2):
        # Every file name for a configuration shares this suffix.
        suffix = f"{agent_count}agent"
        source_path = f"granular_trajectory_{suffix}.json"

        if not os.path.exists(source_path):
            print(f"Trajectory file {source_path} not found, skipping...")
            continue

        print(f"\nAnalyzing {agent_count}-agent runs...")
        per_level = analyze_levels(load_trajectory_data(source_path), agent_count)

        # Persist the raw per-level analyses before building the summary.
        save_analyses(per_level, f"agent_analyses_{suffix}.json")
        generate_summary_report(per_level, agent_count, f"summary_report_{suffix}.md")

        print(f"{agent_count}-agent analysis complete!")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()