import babyai_text
import babyai.utils as utils

import numpy as np
import openai
from openai import OpenAI
from env import BabyAIEnv

def load_text(fpaths, by_lines=False):
    """Read the text file at ``fpaths``.

    Args:
        fpaths: Path of the file to open in text mode.
        by_lines: When true, return a list with one entry per line (line
            endings kept); otherwise return the whole content as one string.

    Returns:
        ``list[str]`` if ``by_lines`` is true, else ``str``.
    """
    with open(fpaths, "r") as handle:
        return handle.readlines() if by_lines else handle.read()

def deepcopy_list(list_msgs):
    """Return a new list of chat messages copied from ``list_msgs``.

    Each message is rebuilt as a fresh dict holding only its ``"role"`` and
    ``"content"`` entries, so editing the copy does not touch the original
    message dicts. Any other keys of the input messages are not carried over.
    """
    return [
        {"role": entry["role"], "content": entry["content"]}
        for entry in list_msgs
    ]

# Environment under test and the agent whose actions are used as the
# reference ("expert") actions in the episode loop below.
map_name = 'BabyAI-GoToObj-v0'
env = BabyAIEnv(map_name)
agent = utils.load_agent(env.env, 'BOT', None, 'agent', False, map_name)

# Locations of the reflection prompt templates.
reflect_open_dir = "/zfsauton2/home/wentsec/incontext_RL/collect/reflect2.txt"
reflect_close_dir = "/zfsauton2/home/wentsec/incontext_RL/collect/reflect3.txt"
reflect_test_dir = "/zfsauton2/home/wentsec/incontext_RL/collect/reflect2_test.txt"

# Load each template as a single string, in the order open, close, test.
reflect_open, reflect_close, reflect_test = (
    load_text(template_path)
    for template_path in (reflect_open_dir, reflect_close_dir, reflect_test_dir)
)

# Running counters for the whole run. The four "<before>2<after>" keys count
# how the agent's action compared with the expert's action before and after
# it received advice; the last two count successful and failed episodes.
stat = dict.fromkeys(
    (
        "correct2correct",
        "correct2incorrect",
        "incorrect2correct",
        "incorrect2incorrect",
        "succ_traj",
        "fail_traj",
    ),
    0,
)

# Client used by gpt(); the key below is a placeholder that has to be replaced
# with a real API key before gpt() can be used.
api_key = "REPLACE_THIS_WITH_YOUR_API_KEY"
openai_client = OpenAI(api_key=api_key)


def gpt(model, msg, temperature):
    """Request a chat completion through ``openai_client``.

    Args:
        model: Name of the model to query.
        msg: Chat messages, passed through as the ``messages`` argument.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The message content of the first returned choice.
    """
    completion = openai_client.chat.completions.create(
        temperature=temperature,
        messages=msg,
        model=model,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

# Client used by llama(); it talks to the endpoint at api_base through the
# same OpenAI client class. NOTE(review): the "Empty" key is presumably
# ignored by the local server -- confirm.
api_key = "Empty"
api_base = "http://localhost:8000/v1"
llama_client = OpenAI(api_key=api_key, base_url=api_base)


def llama(model, msg, temperature):
    """Request a chat completion through ``llama_client``.

    Args:
        model: Name of the model to query.
        msg: Chat messages, passed through as the ``messages`` argument.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The message content of the first returned choice.
    """
    completion = llama_client.chat.completions.create(
        model=model,
        temperature=temperature,
        messages=msg,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
    

# Number of attempts allowed for each LLM query before the run is aborted.
_MAX_LLM_ATTEMPTS = 3


def _match_action(response, all_actions):
    """Return the index of the first entry of ``all_actions`` found in ``response``.

    Entries are tested in list order with a substring check. ``None`` is
    returned when no entry occurs in the response.
    """
    for action_idx, action_name in enumerate(all_actions):
        if action_name in response:
            return action_idx
    return None


def _query_action(prompt):
    """Ask the agent LLM for an action and return its index in ``env.all_actions``.

    Every attempt parses a fresh reply, so an unparsable reply can never fall
    back to an action selected at an earlier step. A failed request and an
    unparsable reply both count as a failed attempt.

    Raises:
        SystemExit: when no attempt out of ``_MAX_LLM_ATTEMPTS`` yields an action.
    """
    for _ in range(_MAX_LLM_ATTEMPTS):
        try:
            # response = gpt("gpt-4o-mini", prompt, 0.7)
            response = llama("meta-llama/Llama-3.2-3B-Instruct", prompt, 0.7)
            action_idx = _match_action(response, env.all_actions)
        except Exception as err:  # external-service boundary: retry on any failure
            print("agent query error:", err)
            continue
        if action_idx is not None:
            return action_idx
    raise SystemExit(
        f"agent LLM produced no valid action after {_MAX_LLM_ATTEMPTS} attempts"
    )


def _query_feedback(msg):
    """Ask the reflection LLM for feedback on one step.

    Returns:
        A ``(fb_response, verbal_fb)`` pair: the raw reply, and the stripped
        text found right after the first "Conclusion:" marker of that reply.

    Raises:
        SystemExit: when no attempt out of ``_MAX_LLM_ATTEMPTS`` yields a reply
            containing "Conclusion:".
    """
    for _ in range(_MAX_LLM_ATTEMPTS):
        # Reset per attempt so the error print below never touches an unbound
        # name or shows a reply that belongs to an earlier attempt.
        fb_response = None
        try:
            fb_response = gpt("gpt-4o-mini", msg, 0.7)
            # fb_response = llama("meta-llama/Llama-3.2-3B-Instruct", msg, 0.7)
            return fb_response, fb_response.split("Conclusion:")[1].strip()
        except Exception:  # request failure or reply without "Conclusion:"
            print("verbal_fb error", fb_response)
    raise SystemExit(
        f"reflection LLM produced no usable feedback after {_MAX_LLM_ATTEMPTS} attempts"
    )


# Main experiment: roll out one episode with the LLM agent, then for every
# step ask a second LLM for advice based on what happened afterwards, re-query
# the agent with that advice, and count whether its choice moved towards or
# away from the expert's action.
for episode in range(100):

    obs, info = env.reset()
    agent.on_reset()
    done = False

    # Index i of the prompt/action/expert buffers and of traj_obs_buffer all
    # refer to step i; traj_obs_buffer has one extra trailing entry holding
    # the observation returned by the last step.
    traj_obs_buffer = [info["obs_text"]]
    traj_action_buffer = []
    traj_expert_a_buffer = []
    traj_prompt_buffer = []
    episode_return = 0
    sys_txt = info["prompt"][0]["content"]
    sys_txt = sys_txt.split("\nIn a moment I will present you an observation.")[0].strip()

    while not done:

        expert_action = agent.act(obs, info)["action"]

        # Snapshot the prompt the agent acts on BEFORE stepping the env, so
        # that it stays aligned with the action and expert action of this step.
        step_prompt = deepcopy_list(info["prompt"])
        action = _query_action(info["prompt"])

        obs, reward, done, info = env.step(action)

        # Translate the expert's action into an index of env.all_actions by
        # matching its lower-cased name against the action strings.
        # NOTE(review): when no string contains the name, expert_action keeps
        # the value returned by agent.act -- confirm every name is covered.
        for valid_action in env.all_actions:
            if expert_action.name.lower() in valid_action.lower():
                expert_action = env.all_actions.index(valid_action)
                break

        traj_prompt_buffer.append(step_prompt)
        traj_obs_buffer.append(info["obs_text"])
        traj_action_buffer.append(env.all_actions[action])
        traj_expert_a_buffer.append(expert_action)
        episode_return += reward

    if episode_return > 0:
        episode_result = "You have completed the task successfully!"
        stat["succ_traj"] += 1
    else:
        episode_result = "You have failed the task."
        stat["fail_traj"] += 1
    print("Episode result:", episode_result)

    traj_len = len(traj_action_buffer)
    start_idx = 0 #max(0, traj_len-5)

    # for idx in range(len(traj_obs_buffer)):
    #     old_obs = traj_obs_buffer[idx]
    #     obs = "\n"
    #     for o in old_obs.split("\n"):
    #         if "wall" in o and ("2" in o or "3" in o or "4" in o or "5" in o or "6" in o):
    #             continue
    #         elif obs == "":
    #             continue
    #         else:
    #             obs += o + "\n"
    #     if obs == "":
    #         obs = "You see nothing."
    #     traj_obs_buffer[idx] = obs

    for t in range(start_idx, traj_len):

        # all_empty = True
        # for idx in range(t, traj_len):
        #     if traj_obs_buffer[idx] != "You see nothing.":
        #         all_empty = False
        #         break
        # if all_empty:
        #     continue

        # Describe everything that happened after step t: each later
        # observation (or a notice that nothing changed) with the action taken
        # there, then the final observation and the episode outcome.
        future_parts = []
        for idx in range(t + 1, traj_len):
            if traj_obs_buffer[idx - 1] == traj_obs_buffer[idx]:
                future_parts.append(f"At step t+{idx-t}, you observed:\nState unchanged. You should try a different action.\n")
            else:
                future_parts.append(f"At step t+{idx-t}, you observed:\n" + traj_obs_buffer[idx] + "\n")
            future_parts.append("You took action: " + traj_action_buffer[idx] + "\n")
        future_parts.append(f"At step t+{traj_len-t}, you observed:\n" + traj_obs_buffer[traj_len] + "\n")
        future_parts.append(episode_result + "\n")
        future = "".join(future_parts)

        o_t = traj_obs_buffer[t]
        a_t = traj_action_buffer[t]
        user_txt = reflect_test.format(o_t=o_t, a_t=a_t, traj=future)
        # user_txt = reflect_open.format(o_t=o_t, a_t=a_t, traj=future)
        # user_txt += reflect_close
        msg = [
            {"role": "system", "content": sys_txt},
            {"role": "user", "content": user_txt}
        ]

        fb_response, verbal_fb = _query_feedback(msg)

        # The stored prompt is a private copy, so appending the advice here
        # does not affect anything the environment holds.
        agent_prompt = traj_prompt_buffer[t]
        agent_prompt[-1]["content"] += "Advice you should follow: " + verbal_fb

        new_action = _query_action(agent_prompt)

        expert_action = traj_expert_a_buffer[t]
        action = env.all_actions.index(traj_action_buffer[t])
        if new_action == expert_action and action == expert_action:
            stat["correct2correct"] += 1
        elif new_action == expert_action and action != expert_action:
            stat["incorrect2correct"] += 1
        elif new_action != expert_action and action == expert_action:
            stat["correct2incorrect"] += 1
            # Dump the full exchange for cases where advice made things worse.
            print("usr:", user_txt)
            print("----"*20)
            print("fb:", fb_response)
            print("----"*20)
            print("action:", env.all_actions[new_action], "expert:", env.all_actions[expert_action])
            print("====="*20)
        else:
            stat["incorrect2incorrect"] += 1

    print("stat:", stat)
        
        