import os

import babyai_text
import babyai.utils as utils
import numpy as np
import openai

from env import BabyAIEnv

def load_text(fpaths, by_lines=False, encoding="utf-8"):
    """Read a text file and return its contents.

    Args:
        fpaths: Path of the single file to read (despite the plural name).
        by_lines: If true, return a list of lines, each keeping its trailing
            newline as ``readlines`` does; otherwise return one string.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the machine's locale settings.

    Returns:
        ``list[str]`` when *by_lines* is true, otherwise ``str``.
    """
    with open(fpaths, "r", encoding=encoding) as fp:
        return fp.readlines() if by_lines else fp.read()

def deepcopy_list(list_msgs):
    """Return a new list of new message dicts built from *list_msgs*.

    Each output dict carries only the "role" and "content" entries of the
    corresponding input message; any other keys are dropped. The outer list
    and every dict are fresh objects, so editing the copy (for example
    appending to a message's "content") leaves the input untouched.
    """
    return [
        {"role": message["role"], "content": message["content"]}
        for message in list_msgs
    ]


# NOTE(review): this script evaluates the BabyAI environment built below
# (`env = BabyAIEnv(map_name)`). A Crafter environment used to be created here
# with make_crafter_env("crafter", "go_to_obj", config, render_mode=None), but
# `config` was never defined (NameError before any episode ran) and the result
# was immediately overwritten by the BabyAI env. The Crafter setup and its
# hard-coded sys.path entry were therefore removed.


# Read the key from the environment instead of hard-coding a secret in source.
# The placeholder remains the fallback, so behaviour is unchanged when
# OPENAI_API_KEY is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "REPLACE_THIS_WITH_YOUR_API_KEY")

map_name = 'BabyAI-GoToObj-v0'
env = BabyAIEnv(map_name)
# The 'BOT' agent's actions serve as the expert reference that the LLM's
# choices are compared against in the episode loop.
agent = utils.load_agent(env.env, 'BOT', None, 'agent', False, map_name)

# Reflection prompt templates. Despite the *_dir names these are file paths.
# reflect_open is a str.format template with {o_t}, {a_t} and {traj} fields;
# reflect_close is appended verbatim after it.
reflect_open_dir = "/zfsauton2/home/wentsec/incontext_RL/collect/reflect2.txt"
reflect_close_dir = "/zfsauton2/home/wentsec/incontext_RL/collect/reflect3.txt"
reflect_open = load_text(reflect_open_dir)
reflect_close = load_text(reflect_close_dir)

# "<before>2<after>" counts whether a step's action matched the expert
# before and after reflection; succ_traj/fail_traj count episode outcomes.
stat = {
    "correct2correct": 0,
    "correct2incorrect": 0,
    "incorrect2correct": 0,
    "incorrect2incorrect": 0,
    "succ_traj": 0,
    "fail_traj": 0,
}

NUM_EPISODES = 100


def _chat(messages, parse, max_retries=3):
    """Send *messages* to gpt-4o-mini and return ``parse(reply_text)``.

    The request is retried up to *max_retries* times when either the API call
    raises or *parse* rejects the reply. After the last failed attempt the
    script aborts with a non-zero exit status that reports the last error.

    NOTE: ``openai.ChatCompletion`` is the legacy (openai<1.0) interface.
    """
    last_err = None
    for _ in range(max_retries):
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o-mini",
                messages=messages,
                temperature=0.7,
            )
            reply = response["choices"][0]["message"]["content"]
            return parse(reply)
        except Exception as err:
            # Broad on purpose: this is the retry boundary for API/network
            # errors as well as unparseable replies. KeyboardInterrupt and
            # other non-Exception signals still propagate.
            last_err = err
    raise SystemExit(f"LLM query failed after {max_retries} attempts: {last_err!r}")


def _parse_action(text):
    """Return the index of the first entry of env.all_actions found in *text*.

    Raises ValueError when the reply names no valid action, so the caller
    retries instead of silently reusing a stale action from an earlier step.
    """
    for idx, valid_action in enumerate(env.all_actions):
        if valid_action in text:
            return idx
    raise ValueError(f"no valid action in reply: {text!r}")


def _parse_feedback(text):
    """Return the advice between the first and second "Feedback: " markers."""
    if "Feedback: " not in text:
        raise ValueError(f"no 'Feedback: ' marker in reply: {text!r}")
    return text.split("Feedback: ")[1].strip()


def _expert_action_index(expert_action):
    """Map the bot's action to its index in env.all_actions by name match."""
    name = expert_action.name.lower()
    for idx, valid_action in enumerate(env.all_actions):
        if name in valid_action.lower():
            return idx
    # NOTE(review): no textual match, so the raw bot action is returned
    # unchanged (as before). Confirm env.all_actions covers every bot action.
    return expert_action


def _describe_future(t, traj_obs_buffer, traj_action_buffer, episode_result):
    """Render the observations/actions after step *t* plus the episode outcome."""
    traj_len = len(traj_action_buffer)
    lines = []
    for idx in range(t + 1, traj_len):
        if traj_obs_buffer[idx - 1] == traj_obs_buffer[idx]:
            lines.append(f"At step t+{idx-t}, you observed: State unchanged.")
        else:
            lines.append(f"At step t+{idx-t}, you observed: " + traj_obs_buffer[idx])
        lines.append("You took action: " + traj_action_buffer[idx])
    # traj_obs_buffer has one more entry than traj_action_buffer: the final obs.
    lines.append(f"At step t+{traj_len-t}, you observed: " + traj_obs_buffer[traj_len])
    lines.append(episode_result)
    return "".join(line + "\n" for line in lines)


for episode in range(NUM_EPISODES):
    obs, info = env.reset()
    agent.on_reset()
    done = False

    # traj_obs_buffer holds o_0..o_T; the other buffers hold one entry per step.
    traj_obs_buffer = [info["obs_text"]]
    traj_action_buffer = []    # action strings a_0..a_{T-1}
    traj_expert_a_buffer = []  # expert action at each step
    traj_prompt_buffer = []    # prompt the LLM was shown at each step
    episode_return = 0
    sys_txt = info["prompt"][0]["content"]
    sys_txt = sys_txt.split("\nIn a moment I will present you an observation.")[0].strip()

    while not done:
        expert_action = _expert_action_index(agent.act(obs, info)["action"])
        # Snapshot this step's prompt BEFORE env.step() replaces `info`. The
        # reflection pass re-asks exactly this prompt, and the answer must be
        # scored against this step's expert action.
        traj_prompt_buffer.append(deepcopy_list(info["prompt"]))
        action = _chat(info["prompt"], _parse_action)

        obs, reward, done, info = env.step(action)

        traj_obs_buffer.append(info["obs_text"])
        traj_action_buffer.append(env.all_actions[action])
        traj_expert_a_buffer.append(expert_action)
        episode_return += reward

    if episode_return > 0:
        episode_result = "You have completed the task successfully!"
        stat["succ_traj"] += 1
    else:
        episode_result = "You have failed the task."
        stat["fail_traj"] += 1
    print("Episode result:", episode_result)

    for t in range(len(traj_action_buffer)):
        future = _describe_future(t, traj_obs_buffer, traj_action_buffer, episode_result)
        user_txt = reflect_open.format(
            o_t=traj_obs_buffer[t], a_t=traj_action_buffer[t], traj=future
        )
        user_txt += reflect_close
        msg = [
            {"role": "system", "content": sys_txt},
            {"role": "user", "content": user_txt},
        ]
        verbal_fb = _chat(msg, _parse_feedback)

        agent_prompt = traj_prompt_buffer[t]
        # Newline keeps the advice from being glued onto the prompt's last line.
        agent_prompt[-1]["content"] += "\nAdvice you should follow: " + verbal_fb
        new_action = _chat(agent_prompt, _parse_action)

        expert_action = traj_expert_a_buffer[t]
        action = env.all_actions.index(traj_action_buffer[t])
        before = "correct" if action == expert_action else "incorrect"
        after = "correct" if new_action == expert_action else "incorrect"
        stat[f"{before}2{after}"] += 1

    print("stat:", stat)
        
        