import alfworld.agents.environment as environment
import alfworld.agents.modules.generic as generic
from os.path import join as pjoin
from alfworld.info import ALFWORLD_DATA
import numpy as np
import pdb
import re, os, json
from copic_wrapper import Agent

# ANSI SGR escape sequences used to colour console output; RESET restores
# the terminal's default rendition.
RED, BLUE, GREEN, YELLOW, RESET = (
    f"\033[{sgr_code}m" for sgr_code in (31, 34, 32, 33, 0)
)

# Load the run configuration once at start-up. The same object is reused for
# every task type; the loop below overwrites config["env"]["task_types"].
# NOTE(review): where load_config() finds the config file is not visible
# here (presumably a command-line argument) - confirm.
config = generic.load_config()

# Task type ids handled by this script.
_TASK_IDS = range(1, 7)

# Number of episodes run in the testing phase, keyed by task type id.
task_num = dict(zip(_TASK_IDS, (24, 18, 31, 23, 21, 17)))

# Success rate that must be reached before learning stops, keyed by task
# type id: 0.9 for task types 1-5 and 0.8 for task type 6.
sr_threshold = {**dict.fromkeys(range(1, 6), 0.9), 6: 0.8}

def _make_env(env_type, train_eval, seed=240803):
    """Create a seeded, batch-size-1 environment for one data split.

    Args:
        env_type: Name of the environment class to look up on the
            ``environment`` module (taken from ``config["env"]["type"]``).
        train_eval: Split name passed to the environment constructor,
            e.g. ``"train"`` or ``"eval_out_of_distribution"``.
        seed: Seed applied to the initialised environment.

    Returns:
        The object returned by ``init_env(batch_size=1)``, already seeded.
    """
    new_env = getattr(environment, env_type)(config, train_eval=train_eval)
    new_env = new_env.init_env(batch_size=1)
    new_env.seed(seed)
    return new_env


# For each task type: learn on the training split until the task's success
# rate threshold is met, then evaluate on the out-of-distribution split and
# save the per-episode interaction steps.
for task_id in range(1, 7):
    # Restrict the environment to the current task type only.
    config["env"]["task_types"] = [task_id]
    env_type = config["env"]["type"]
    threshold = sr_threshold[task_id]

    print(RED + "=" * 20 + "CoPiC Learning" + "=" * 20 + RESET)
    env = _make_env(env_type, "train")
    agent = Agent(
        env, task_id,
        # Alternative models that have been used with this script:
        #   "deepseek-chat", "deepseek-coder", "llama2-70B",
        #   "qwen2.5-72B-Instruct", "Meta-Llama-3-8B-Instruct",
        #   "Qwen2.5-Coder-14B-Instruct"
        llm_model_name="Qwen2.5-14B-Instruct"
    )
    # NOTE(review): this loop has no retry cap; it only ends once the
    # threshold is reached.
    while True:
        success_rate = agent.solve_goal(num_episodes=20, threshold=threshold)
        if success_rate >= threshold:
            # Report the threshold actually used (it is not 0.9 for every task).
            print(RED + f"Task Success Rate: {success_rate} >= {threshold}, so I will stop the training." + RESET)
            break
        # Threshold not reached: give the agent a fresh, identically seeded
        # training environment and run another round.
        env = _make_env(env_type, "train")
        agent.env = env

    print(GREEN + "=" * 20 + "CoPiC Testing" + "=" * 20 + RESET)
    t_env = _make_env(env_type, "eval_out_of_distribution")
    agent.env = t_env
    # NOTE(review): threshold=-1 / learn=False presumably disable early
    # stopping and learning during evaluation - confirm in Agent.solve_goal.
    success_rate = agent.solve_goal(num_episodes=task_num[task_id], threshold=-1, learn=False)
    print(GREEN + "=" * 20 + "CoPiC Testing SR: " + f"{success_rate}" + "=" * 20 + RESET)

    # save test_interaction_steps: {eps: steps}
    test_interaction_steps_path = os.path.join(agent.llm2planner.planners_path, f"test_interaction_steps_{task_id}.json")
    # Use a context manager so the file is flushed and closed deterministically.
    with open(test_interaction_steps_path, "w") as steps_file:
        json.dump(agent.test_interaction_steps, steps_file)