import argparse
import json
import os
import re
import sys
from pathlib import Path
from datetime import datetime

# Put the directory one level above this script on the import path so the
# project modules imported below can be resolved when the file is run directly.
current_dir = os.path.split(os.path.abspath(__file__))[0]
parent_dir = os.path.split(current_dir)[0]
sys.path.append(parent_dir)

from agent_manager import AgentManager
from levels.utils import convert_to_prompt
from overcooked import World

# Model identifiers accepted by the --model_name / --executor_model_name
# options. The order matters: the CLI defaults are taken by index.
AVAILABLE_MODELS = [
    "mistral-large-latest",
    "mistral-small-latest",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "command-a-03-2025",
    "command-r7b-12-2024",
    "gpt-4o-v2",
    "gpt-4o-mini",
    "Qwen/Qwen2.5-Coder-32B-Instruct",
]

# Level names iterated when the level option is "all": level_0 .. level_12.
LEVELS = ['level_' + str(level_index) for level_index in range(13)]

# Agent types accepted by the --type option.
AGENT_TYPES = ['orchestrator', 'planner']

def _resolve_executor_models(args, num_agents):
    """Return the list of executor model names, one entry per agent.

    Uses the comma-separated ``args.executor_model_names`` when it is set,
    otherwise repeats ``args.executor_model_name`` for every agent.

    Raises:
        ValueError: if the number of listed models differs from
            ``num_agents`` or if any listed name is empty.
    """
    if not args.executor_model_names:
        # Use the same model for all agents
        return [args.executor_model_name] * num_agents

    # Tolerate "model1, model2" by trimming whitespace around each name.
    executor_models = [name.strip() for name in args.executor_model_names.split(',')]
    # Ensure we have exactly the right number of models
    if len(executor_models) != num_agents:
        raise ValueError(f"Number of executor models ({len(executor_models)}) must match number of agents ({num_agents})")
    if not all(executor_models):
        raise ValueError(f"Executor model names must not be empty (got {args.executor_model_names!r})")
    return executor_models


def main(args):
    """Run the experiment described by the parsed CLI arguments and save results.

    For every requested level (all of ``LEVELS`` when ``args.level`` is
    "all") and every alpha, a ``World`` is created and ``args.max_episode``
    episodes are played. An episode stops after ``args.max_steps`` steps or
    as soon as the summed ``agent_manager.individual_costs`` reaches
    ``args.cost_threshold``. After the episodes of one alpha have finished,
    the result table is written to a timestamped JSON file under
    ``<results_dir>/<model>/<agent_type>/<num_agents>/``.

    Args:
        args: namespace providing ``model_name``, ``executor_model_name``,
            ``executor_model_names``, ``without_feedback``, ``without_notes``,
            ``num_agents``, ``level``, ``max_episode``, ``max_steps``,
            ``verbose``, ``type``, ``cost_threshold`` and ``results_dir``.

    Raises:
        ValueError: if ``executor_model_names`` does not provide exactly one
            non-empty model name per agent.
    """
    model = args.model_name
    # Convert first so the executor-model count is compared against an int.
    num_agents = int(args.num_agents)

    # Handle multiple executor models
    executor_models = _resolve_executor_models(args, num_agents)

    with_feedback = not args.without_feedback
    with_notes = not args.without_notes
    level_to_run = args.level
    max_episode = args.max_episode
    max_steps = args.max_steps
    verbose = args.verbose
    agent_type = args.type
    cost_threshold = args.cost_threshold

    alphas = [1.0]
    look_ahead_steps = 20

    levels = LEVELS if level_to_run == "all" else [level_to_run]

    print("---------- Starting ----------")
    print("Model: ", model)
    print("Executor Models: ", executor_models)
    print("Agent Type: ", agent_type)
    print("Cost Threshold: $", cost_threshold)

    for level in levels:
        if not with_feedback:
            save_file_name = f'result_{level}_{num_agents}_wo_feedback_{model}.json'
        elif not with_notes:
            save_file_name = f'result_{level}_{num_agents}_wo_notes_{model}.json'
        else:
            save_file_name = f'result_{level}_{num_agents}_{model}.json'

        # NOTE(review): this file is only ever read here; results are written
        # to the timestamped file below, so it must come from elsewhere.
        if os.path.exists(save_file_name):
            with open(save_file_name, 'r') as f:
                table = json.load(f)
        else:
            table = {}

        for alpha in alphas:
            # JSON object keys are strings, so both previously saved and new
            # results are keyed by str(alpha). Check *before* creating the
            # entry so existing results are skipped rather than shadowed.
            alpha_key = str(alpha)
            if alpha_key in table:
                continue
            table[alpha_key] = []

            env = World(recipe_filename='experiment2/assets/recipe.json', task_filename='experiment2/assets/tasks_level_final.json',
                        level=level, use_task_lifetime_interval_oracle=True,
                        alpha=alpha, beta=2.5, num_agents=num_agents, override_agent=True)

            for _ in range(max_episode):
                step = 0
                obs = env.reset()

                agent_manager = AgentManager(env, model, num_agents, with_feedback, with_notes, look_ahead_steps, agent_type, executor_models)

                action_histories = []
                plans = []
                # Never filled in here; kept so the saved record keeps its
                # 'prompt_history' field.
                prompt_history = []

                while step < max_steps:
                    plan = agent_manager.step(obs, step, verbose=verbose)

                    print('-'*25)
                    print(convert_to_prompt(obs))
                    print("feedback: ", env.feedback)
                    print("STEP: ", step)
                    print("PLAN: ", plan)
                    print("Current tasks: ", env.task_manager._current_task_list)
                    print("Task lifetime: ", env.task_manager._current_task_lifetime_list)
                    print(f"step info: new_task: {env.task_manager.just_new_task}, failed: {obs.just_failed}, success: {obs.task_just_success}")
                    total_cost = sum(agent_manager.individual_costs.values())
                    print(f"Total Cost: ${total_cost:.4f}")
                    print("Individual Costs:")
                    for agent, cost in agent_manager.individual_costs.items():
                        print(f"  {agent}: ${cost:.4f}")
                    print('-'*25)

                    # An empty plan leaves the environment untouched but still
                    # counts as a step and is recorded in the histories.
                    if plan:
                        obs, _done, _info = env.step(plan)
                    action_histories.append(plan)
                    plans.append(plan)
                    step += 1

                    # Check if we've hit the cost threshold (total cost including planner and executors)
                    total_cost = sum(agent_manager.individual_costs.values())
                    if total_cost >= cost_threshold:
                        print(f"\nReached cost threshold of ${cost_threshold} at step {step}")
                        print(f"Total cost: ${total_cost:.4f}")
                        break

                success = env.success_count
                failed = env.failed_count
                # Tasks still pending at the end of the episode count towards the total.
                total = success + failed + len(env.task_manager._current_task_list)

                table[alpha_key].append({
                    'total': total,
                    'success': success,
                    'failed': failed,
                    'alpha': alpha,
                    'noop_count': env.noop_count,
                    'action_history': [action_histories],
                    'action_success_history': [env.action_success_history],
                    'prompt_history': [prompt_history],
                    'all_tasks': env.task_manager._all_tasks,
                    'all_orders_list': env.task_manager._all_orders_list,
                    'order_times': env.task_manager.order_times,
                    'acomplished_task_list': env.task_manager.accomplished_tasks(),
                    'task_interval': env.task_manager._task_interval,
                    'task_lifetime': env.task_manager._task_lifetime,
                    'max_num_tasks': env.task_manager._max_num_tasks,
                    'dish_completion': env.task_manager._dish_completion,
                    'all_orders_lifetime': env.task_manager._all_orders_lifetime,
                    'stop_step': step,
                    # Every plan of the episode, not just the last one.
                    'plans': plans,
                    'predictions': agent_manager.predictions,
                    'price': agent_manager.price,
                    'individual_costs': agent_manager.individual_costs
                })

            # Save Results
            # Drop any "org/" prefix so the model name is a single path component.
            model2write = model.split('/')[-1]
            results_dir = Path(args.results_dir, model2write, agent_type, str(num_agents))
            results_dir.mkdir(parents=True, exist_ok=True)

            # Add timestamp to filename to avoid overwriting
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            result_file = Path(results_dir, f'result_{level}_{timestamp}.json')

            with open(result_file, 'w') as fp:
                json.dump(table, fp)


if __name__ == '__main__':
    # Command-line entry point: build the parser, validate the options and
    # hand the parsed namespace to main().
    parser = argparse.ArgumentParser(description="invoking GPT")

    # Add the arguments
    parser.add_argument('-m','--model_name', metavar='model_name', type=str, choices=AVAILABLE_MODELS, help='model to use', default=AVAILABLE_MODELS[2])
    parser.add_argument('-ex','--executor_model_name', metavar='executor_model_name', type=str, choices=AVAILABLE_MODELS, help='model to use for all agents', default=AVAILABLE_MODELS[3])
    parser.add_argument('-exs','--executor_model_names', metavar='executor_model_names', type=str, help='comma-separated list of models for each agent (e.g. "model1,model2")', default=None)
    # The two ablation switches cannot be combined. Let argparse enforce this
    # so the check still runs under `python -O` (a bare assert would not) and
    # the user gets a usage message instead of a traceback.
    ablation_group = parser.add_mutually_exclusive_group()
    ablation_group.add_argument('--without_feedback', action='store_true', help='without feedback')
    ablation_group.add_argument('--without_notes', action='store_true', help='without notes')
    parser.add_argument('-n', '--num_agents', metavar='num_agents', type=int, help='number of agents', default=1)
    parser.add_argument('-l', '--level', metavar='level', type=str, help='level of the game', default='level_0')
    parser.add_argument('-e', '--max_episode', metavar='max_episode', type=int, help='max episode', default=1)
    parser.add_argument('-s', '--max_steps', metavar='max_steps', type=int, help='max steps', default=1000)  # Increased default
    parser.add_argument('-rd', '--results_dir', metavar='results_dir', type=str, help='results directory', default='experiment2/results/')
    # NOTE(review): the default 'individual' is not listed in AGENT_TYPES.
    # argparse does not check defaults against `choices`, so it is accepted
    # silently -- confirm that AgentManager supports it.
    parser.add_argument('-t', '--type', metavar='type', type=str, help='type of the agent', choices=AGENT_TYPES, default='individual')
    parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
    # '--debete_rounds' is the historical (misspelled) flag; the correctly
    # spelled alias is accepted too. The attribute name stays `debete_rounds`.
    parser.add_argument('-dr', '--debete_rounds', '--debate_rounds', dest='debete_rounds', metavar='debate_rounds', type=int, help='number of debate rounds', default=3)
    parser.add_argument('-ct', '--cost_threshold', metavar='cost_threshold', type=float, help='stop when cost reaches this threshold', default=4.0)
    # Parse the arguments
    args = parser.parse_args()

    main(args)