import os
import re
import sys
import json
import argparse
from pathlib import Path
from datetime import datetime

# Put the directory above this script's directory on the import path.
# NOTE(review): presumably this is what lets the project imports that follow
# resolve when the script is launched from its own folder -- confirm layout.
current_dir, _ = os.path.split(os.path.abspath(__file__))
parent_dir, _ = os.path.split(current_dir)
sys.path.append(parent_dir)

from agent_manager import AgentManager
from overcooked import World
from levels.utils import convert_to_prompt

# Model identifiers accepted by the --model_name option; the entry at
# index 2 is used as that option's default.
AVAILABLE_MODELS = [
    'mistral-large-latest',
    'mistral-small-latest',
    'meta-llama/Llama-3.1-70B-Instruct',
    'meta-llama/Meta-Llama-3.1-8B-Instruct',
    'command-a-03-2025',
    'command-r7b-12-2024',
    'gpt-4o-v2',
    'gpt-4o-mini',
]

# "level_0" .. "level_12": the levels run when --level is "all".
LEVELS = ['level_' + str(index) for index in range(13)]

# Values accepted by the --type option.
AGENT_TYPES = ['orchestrator', 'planner', 'individual']

def _resume_file_name(level, num_agents, model, with_feedback, with_notes):
    """Return the name of the per-level file checked for earlier results."""
    if not with_feedback:
        return f'result_{level}_{num_agents}_wo_feedback_{model}.json'
    if not with_notes:
        return f'result_{level}_{num_agents}_wo_notes_{model}.json'
    return f'result_{level}_{num_agents}_{model}.json'


def main(args):
    """Run the benchmark on the requested level(s) and write results to disk.

    For every level and every alpha value, a ``World`` is created and
    ``max_episode`` episodes of at most ``max_steps`` steps are played with a
    fresh ``AgentManager`` per episode. One record per episode is collected
    and dumped to ``<results_dir>/<model>/<type>/<num_agents>/
    result_<level>_<timestamp>.json``; the last agent manager's history is
    stored under the matching ``logs_dir`` path.

    Args:
        args: Parsed command-line namespace. Reads ``model_name``,
            ``executor_model_name``, ``without_feedback``, ``without_notes``,
            ``num_agents``, ``level``, ``max_episode``, ``max_steps``,
            ``verbose``, ``type``, ``results_dir``, ``logs_dir`` and, when
            present, ``cost_threshold``. An episode stops after the step on
            which the agent manager's accumulated price reaches
            ``cost_threshold``.
    """
    model = args.model_name
    num_agents = int(args.num_agents)
    if args.executor_model_name:
        executor_models = args.executor_model_name.split(',')
    else:
        # Default to the same model as the planner for every executor.
        executor_models = [model] * (num_agents - 1)
    with_feedback = not args.without_feedback
    with_notes = not args.without_notes
    max_episode = args.max_episode
    max_steps = args.max_steps
    verbose = args.verbose
    agent_type = args.type
    # getattr keeps callers that build their own namespace without this
    # attribute working; None disables the spending cap.
    cost_threshold = getattr(args, 'cost_threshold', None)

    alphas = [1.0]
    look_ahead_steps = 20

    levels = LEVELS if args.level == "all" else [args.level]

    print("---------- Starting ----------")
    print("Model: ", model)
    print("Executor Models: ", executor_models)
    print("Agent Type: ", agent_type)

    for level in levels:
        save_file_name = _resume_file_name(
            level, num_agents, model, with_feedback, with_notes)

        if os.path.exists(save_file_name):
            with open(save_file_name, 'r') as f:
                table = json.load(f)
        else:
            table = {}

        for alpha in alphas:
            # Keys loaded from JSON are strings, so compare against
            # str(alpha). Check BEFORE creating the fresh entry so a skipped
            # alpha does not gain an extra (float-keyed) empty list.
            if str(alpha) in table:
                continue
            table[alpha] = []

            env = World(recipe_filename='./assets/recipe.json', task_filename='./assets/tasks_level_final.json',
                        level=level, use_task_lifetime_interval_oracle=True,
                        alpha=alpha, beta=2.5, num_agents=num_agents, override_agent=True)

            # Stays None when max_episode is 0, so there is no history to store.
            agent_manager = None

            for _ in range(max_episode):
                action_histories = []
                prompt_history = []
                plans = []
                prices = []
                tokens = []
                planner_costs = []  # Planner cost reported for each step
                executor_costs = []  # Executor costs reported for each step

                step = 0
                # The environment is reset before the manager is built, as the
                # manager receives the environment in its constructor.
                obs = env.reset()

                agent_manager = AgentManager(env, model, num_agents, with_feedback, with_notes, look_ahead_steps, agent_type, executor_models)

                while step < max_steps:
                    plan = agent_manager.step(obs, step, verbose=verbose)

                    # Accumulated price after this step's model calls.
                    current_total_cost = agent_manager.price

                    # Managers without per-step cost tracking get zero entries.
                    if hasattr(agent_manager, 'step_costs'):
                        planner_costs.append(agent_manager.step_costs['planner'])
                        executor_costs.append(agent_manager.step_costs['executors'])
                    else:
                        planner_costs.append(0)
                        executor_costs.append([0] * (num_agents - 1))

                    prices.append(current_total_cost)

                    print('-'*25)
                    print("STEP: ", step)
                    print("PLAN: ", plan)
                    print("Current tasks: ", env.task_manager._current_task_list)
                    print("Task lifetime: ", env.task_manager._current_task_lifetime_list)
                    print(f"step info: new_task: {env.task_manager.just_new_task}, failed: {obs.just_failed}, success: {obs.task_just_success}")
                    print("Price: ", current_total_cost)
                    print('-'*25)

                    # An empty plan leaves the environment (and obs) untouched.
                    if plan:
                        obs, _done, _info = env.step(plan)
                    action_histories.append(plan)
                    plans.append(plan)
                    tokens.append(agent_manager.all_tokens)
                    step += 1

                    # The plan that pushed the cost over the cap has already
                    # been paid for, so it is executed before stopping.
                    if cost_threshold is not None and current_total_cost >= cost_threshold:
                        print("Cost threshold reached: ", cost_threshold)
                        break

                success = env.success_count
                failed = env.failed_count
                # Tasks still open at the end count towards the total as well.
                total = success + failed + len(env.task_manager._current_task_list)

                table[alpha].append({
                    'total': total,
                    'success': success,
                    'failed': failed,
                    'alpha': alpha,
                    'noop_count': env.noop_count,
                    # The three histories keep their one-element outer list.
                    'action_history': [action_histories],
                    'action_success_history': [env.action_success_history],
                    'prompt_history': [prompt_history],
                    'all_tasks': env.task_manager._all_tasks,
                    'all_orders_list': env.task_manager._all_orders_list,
                    'order_times': env.task_manager.order_times,
                    'acomplished_task_list': env.task_manager.accomplished_tasks(),
                    'task_interval': env.task_manager._task_interval,
                    'task_lifetime': env.task_manager._task_lifetime,
                    'max_num_tasks': env.task_manager._max_num_tasks,
                    'dish_completion': env.task_manager._dish_completion,
                    'all_orders_lifetime': env.task_manager._all_orders_lifetime,
                    'stop_step': step,
                    # Every plan of the episode (previously only the last one).
                    'plans': plans,
                    'predictions': agent_manager.predictions,
                    'price': agent_manager.price,
                    'prices': prices,
                    'tokens': tokens,
                    'planner_costs': planner_costs,
                    'executor_costs': executor_costs,
                })

            # Save results, organised by model, agent type and agent count.
            # Only the last path component of the model name is used so that
            # names containing '/' do not create extra directories.
            model2write = model.split('/')[-1]
            results_dir = Path(args.results_dir, model2write, agent_type, str(num_agents))
            results_dir.mkdir(parents=True, exist_ok=True)
            logs_dir = Path(args.logs_dir, model2write, agent_type, str(num_agents))
            logs_dir.mkdir(parents=True, exist_ok=True)

            # Add timestamp to the results
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            table['timestamp'] = timestamp

            with open(Path(results_dir, f'result_{level}_{timestamp}.json'), 'w') as fp:
                json.dump(table, fp)

            if agent_manager is not None:
                agent_manager.store_history(Path(logs_dir, f'history_{level}.json'))


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="invoking GPT")

    # Add the arguments
    parser.add_argument('-m','--model_name', metavar='model_name', type=str, choices=AVAILABLE_MODELS, help='model to use', default=AVAILABLE_MODELS[2])
    parser.add_argument('-ex','--executor_model_name', metavar='executor_model_name', type=str, help='Comma-separated list of executor models')
    # The two ablations cannot be combined. A mutually exclusive group makes
    # argparse reject that combination itself, unlike an assert, which is
    # stripped when Python runs with -O.
    ablation = parser.add_mutually_exclusive_group()
    ablation.add_argument('--without_feedback', action='store_true', help='without feedback')
    ablation.add_argument('--without_notes', action='store_true', help='without notes')
    parser.add_argument('-n', '--num_agents', metavar='num_agents', type=int, help='number of agents', default=2)
    parser.add_argument('-l', '--level', metavar='level', type=str, help='level of the game', default='level_8')
    parser.add_argument('-e', '--max_episode', metavar='max_episode', type=int, help='max episode', default=1)
    parser.add_argument('-s', '--max_steps', metavar='max_steps', type=int, help='max steps', default=1000)
    parser.add_argument('-rd', '--results_dir', metavar='results_dir', type=str, help='results directory', default='results/')
    parser.add_argument('-ld', '--logs_dir', metavar='logs_dir', type=str, help='logs directory', default='logs/')
    parser.add_argument('-t', '--type', metavar='type', type=str, help='type of the agent', choices=AGENT_TYPES, default='orchestrator')
    parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
    # '--debete_rounds' was a misspelling; it is kept as an alias (and as the
    # attribute name on the namespace) so existing invocations keep working.
    parser.add_argument('-dr', '--debate_rounds', '--debete_rounds', dest='debete_rounds', metavar='debate_rounds', type=int, help='number of debate rounds', default=3)
    parser.add_argument('-ct', '--cost_threshold', metavar='cost_threshold', type=float, help='stop when total cost reaches this value (in dollars)', default=None)
    # Parse the arguments
    args = parser.parse_args()

    main(args)