import os
import re
import sys
import json
import argparse
from pathlib import Path

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from agent_manager import AgentManager
from overcooked import World

from cost_success_matrices import create_cost_matrix_seeded, get_agent_success_percentages 

# Model identifiers accepted by ``--model_name``; the first entry is the CLI default.
AVAILABLE_MODELS = [
    # Meta Llama 3.1 family
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
    # Other model families
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    "command-r-plus-08-2024",
    "gpt-4o",
    "mistral-large-latest",
]

# Level identifiers "level_0" .. "level_12"; ``--level all`` runs every one of them.
LEVELS = ["level_%d" % idx for idx in range(13)]

# Agent coordination strategies accepted by ``--type``.
AGENT_TYPES = "individual debate orchestrator planner random auction".split()

def _print_matrix(title, matrix, num_agents, actions):
    """Print a per-agent, per-action matrix as a box-drawn table.

    Args:
        title: Heading printed (with a trailing colon) above the table.
        matrix: Indexable as ``matrix[agent_index][action]``, yielding a number
            formatted with two decimals.
        num_agents: Number of agent rows to print (indices ``0 .. num_agents - 1``).
        actions: Action names, used both as column headers and as the inner key.
    """
    rule = "─" * 62
    print(f"\n{title}:")
    print(rule)
    # Header row
    print(f"{'Agent':6} │", end="")
    for action in actions:
        print(f" {action:8} │", end="")
    print("\n" + rule)
    # Data rows
    for i in range(num_agents):
        print(f"{i:6} │", end="")
        for action in actions:
            print(f" {matrix[i][action]:8.2f} │", end="")
        print("\n" + rule)
    print()


def _run_episode(env, agent_manager, obs, max_steps, budget, verbose):
    """Roll out one episode until ``max_steps`` is reached or the budget is exceeded.

    Args:
        env: The ``World`` instance, already reset.
        agent_manager: Produces a plan per step via ``step(obs, step, verbose=...)``.
        obs: Observation returned by ``env.reset()``.
        max_steps: Step limit for the episode.
        budget: Cost limit, or ``None`` for an unbudgeted run.
        verbose: Forwarded to ``agent_manager.step``.

    Returns:
        ``(stop_step, plans)``: the number of steps taken (including steps whose
        plan was empty and therefore not sent to the environment) and the list
        of plans, one per step taken.
    """
    plans = []
    step = 0
    while True:
        print(f"STEP: {step}/{max_steps}")
        if step >= max_steps:
            break

        plan = agent_manager.step(obs, step, verbose=verbose)

        # Only budgeted runs can run out of budget; comparing against None
        # would raise TypeError. The plan that crossed the budget is neither
        # executed nor recorded.
        if budget is not None and agent_manager.cost > budget:
            break

        print('-' * 25)
        print("PLAN: ", plan)
        print("Completed tasks: ", env.task_manager.accomplished_tasks())
        print('-' * 25)

        # An empty/falsy plan is recorded but not sent to the environment.
        # NOTE(review): the `done` flag from env.step is not used to end the
        # episode early; confirm that running to max_steps is intended.
        if plan:
            obs, _done, _info = env.step(plan)
        plans.append(plan)
        step += 1
    return step, plans


def main(args):
    """Run the configured agents on one or all levels and save per-level results.

    For every level and every alpha, builds a ``World``, runs ``args.max_episode``
    episodes (each with a fresh ``AgentManager``), and writes the collected
    statistics as JSON to
    ``<results_dir>/<model basename>/<type>/<num_agents>/result_<level>[_<budget>].json``.

    Args:
        args: Parsed command-line namespace from this module's argument parser.
    """
    model = args.model_name
    with_feedback = not args.without_feedback
    with_notes = not args.without_notes
    num_agents = int(args.num_agents)
    level_to_run = args.level
    max_episode = args.max_episode
    max_steps = args.max_steps
    verbose = args.verbose
    agent_type = args.type
    budget = args.budget
    # Debate rounds are only passed on for the debate agent type.
    num_debate_rounds = args.debete_rounds if agent_type == 'debate' else None

    # Cost / success-rate matrices are only generated for budgeted runs.
    actions = ["noop", "goto", "put", "get", "activate"]
    if budget is not None:
        cost_matrix = create_cost_matrix_seeded(num_agents, actions)
        success_matrix = get_agent_success_percentages(num_agents, actions, cost_matrix)
    else:
        cost_matrix = None
        success_matrix = None

    alphas = [1.0]
    look_ahead_steps = 20

    levels = LEVELS if level_to_run == "all" else [level_to_run]

    print("---------- Starting ----------")
    print("Model: ", model)
    print("Agent Type: ", agent_type)
    print("Budget: ", budget)
    # The matrices are None in unbudgeted runs; indexing them would fail.
    if cost_matrix is not None:
        _print_matrix("Cost Matrix", cost_matrix, num_agents, actions)
    if success_matrix is not None:
        _print_matrix("Success Rate Matrix", success_matrix, num_agents, actions)

    # The output directory depends only on CLI arguments, so compute it once.
    model2write = model.split('/')[-1]
    results_dir = Path(args.results_dir, model2write, agent_type, str(num_agents))

    for level in levels:
        # Optional previously saved table in the working directory; alphas
        # already present in it are skipped below.
        # NOTE(review): model names containing '/' make this a nested path.
        if not with_feedback:
            save_file_name = f'result_{level}_{num_agents}_wo_feedback_{model}.json'
        elif not with_notes:
            save_file_name = f'result_{level}_{num_agents}_wo_notes_{model}.json'
        else:
            save_file_name = f'result_{level}_{num_agents}_{model}.json'

        if os.path.exists(save_file_name):
            with open(save_file_name, 'r') as f:
                table = json.load(f)
        else:
            table = {}

        for alpha in alphas:
            # JSON object keys are always strings, so a loaded table stores
            # alpha 1.0 under "1.0". Check before inserting the fresh entry:
            # inserting first would leave both 1.0 and "1.0" in the table and
            # json.dump would then emit duplicate keys.
            if str(alpha) in table:
                continue
            table[alpha] = []

            env = World(recipe_filename='./assets/recipe.json', task_filename='./assets/tasks_level_final.json',
                        level=level, use_task_lifetime_interval_oracle=True,
                        alpha=alpha, beta=2.5, num_agents=num_agents, override_agent=True,
                        success_rates=success_matrix)

            for _ in range(max_episode):
                obs = env.reset()
                agent_manager = AgentManager(env, model, num_agents, with_feedback, with_notes,
                                             look_ahead_steps, agent_type, budget,
                                             num_debate_rounds, cost_matrix)

                stop_step, plans = _run_episode(env, agent_manager, obs, max_steps, budget, verbose)

                task_manager = env.task_manager
                table[alpha].append({
                    # Finished (succeeded + failed) tasks plus those still open.
                    'total': env.success_count + env.failed_count + len(task_manager._current_task_list),
                    'success': env.success_count,
                    'failed': env.failed_count,
                    'alpha': alpha,
                    'noop_count': env.noop_count,
                    # History fields are lists with one entry per episode record.
                    'action_history': [plans],
                    'action_success_history': [env.action_success_history],
                    # Prompts are not collected by this script; kept for schema compatibility.
                    'prompt_history': [[]],
                    'all_tasks': task_manager._all_tasks,
                    'all_orders_list': task_manager._all_orders_list,
                    'order_times': task_manager.order_times,
                    'acomplished_task_list': task_manager.accomplished_tasks(),
                    'task_interval': task_manager._task_interval,
                    'task_lifetime': task_manager._task_lifetime,
                    'max_num_tasks': task_manager._max_num_tasks,
                    'dish_completion': task_manager._dish_completion,
                    'all_orders_lifetime': task_manager._all_orders_lifetime,
                    'stop_step': stop_step,
                    'cost': agent_manager.cost,
                    'cost_matrix': cost_matrix,
                    'success_matrix': success_matrix,
                    # Every plan of the episode (previously only the last one,
                    # which was also unbound when max_steps <= 0).
                    'plans': plans,
                    'predictions': agent_manager.predictions,
                })

            # Save results; the file is rewritten after each alpha completes.
            results_dir.mkdir(parents=True, exist_ok=True)
            if budget is not None:
                result_name = f'result_{level}_{budget}.json'
            else:
                result_name = f'result_{level}.json'
            with open(Path(results_dir, result_name), 'w') as fp:
                json.dump(table, fp)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="invoking GPT")

    # Add the arguments
    parser.add_argument('-m', '--model_name', metavar='model_name', type=str, choices=AVAILABLE_MODELS, help='model to use', default=AVAILABLE_MODELS[0])
    # The two ablations cannot be combined. argparse enforces this with a
    # usage error, unlike an `assert`, which is stripped under `python -O`.
    ablation = parser.add_mutually_exclusive_group()
    ablation.add_argument('--without_feedback', action='store_true', help='without feedback')
    ablation.add_argument('--without_notes', action='store_true', help='without notes')
    parser.add_argument('-n', '--num_agents', metavar='num_agents', type=int, help='number of agents', default=1)
    parser.add_argument('-l', '--level', metavar='level', type=str, help='level of the game', default='level_0')
    parser.add_argument('-e', '--max_episode', metavar='max_episode', type=int, help='max episode', default=1)
    parser.add_argument('-s', '--max_steps', metavar='max_steps', type=int, help='max steps', default=60)
    parser.add_argument('-rd', '--results_dir', metavar='results_dir', type=str, help='results directory', default='results/')
    parser.add_argument('-t', '--type', metavar='type', type=str, help='type of the agent', choices=AGENT_TYPES, default='individual')
    parser.add_argument('-v', '--verbose', action='store_true', help='verbose')
    parser.add_argument('-dr', '--debete_rounds', metavar='debate_rounds', type=int, help='number of debate rounds', default=3)
    parser.add_argument('-b', '--budget', metavar='budget', type=int, help='budget', default=None)
    # Parse the arguments
    args = parser.parse_args()

    main(args)