import os

import babyai_text
import babyai.utils as utils
import numpy as np
import openai

from env import BabyAIEnv

def load_text(fpaths, by_lines=False):
    """Read a text file and return its contents.

    Args:
        fpaths: Path of the single file to read (despite the plural name).
        by_lines: When truthy, return a list of lines, each keeping its
            trailing newline as ``readlines`` does; otherwise return the
            whole file as one string.
    """
    with open(fpaths, "r") as handle:
        return handle.readlines() if by_lines else handle.read()

def deepcopy_list(list_msgs):
    """Return a copy of a chat-message list keeping only "role" and "content".

    Every message is rebuilt as a brand-new dict, so editing the copy (for
    example appending text to a message's content) leaves the originals
    untouched. Any other keys are dropped; a message lacking either key
    raises KeyError.
    """
    return [{"role": message["role"], "content": message["content"]} for message in list_msgs]

# Take the API key from the environment so a real key never has to be
# written into (and committed with) this file; the placeholder is kept as
# the fallback so behavior is unchanged when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "REPLACE_THIS_WITH_YOUR_API_KEY")

map_name = 'BabyAI-GoToObj-v0'
env = BabyAIEnv(map_name)
# 'BOT' agent — presumably BabyAI's scripted bot; the script only uses it as
# the expert reference for per-step action accuracy.
agent = utils.load_agent(env.env, 'BOT', None, 'agent', False, map_name)

# Reflection prompt template; it is filled with str.format(traj=...), so it
# must contain a "{traj}" placeholder. Overridable via REFLECT_PROMPT_PATH.
reflect_dir = os.environ.get(
    "REFLECT_PROMPT_PATH",
    "/zfsauton2/home/wentsec/incontext_RL/collect/reflect4.txt",
)
reflect = load_text(reflect_dir)

# Running counters accumulated over all episodes:
#   origin_*  -> first rollout (no advice); new_* -> rollout with verbal feedback
#   *_correct / *_incorrect     -> steps where the LLM action did / did not
#                                  equal the expert action
#   *_succ_traj / *_fail_traj   -> episodes with positive / non-positive return
stat = dict.fromkeys(
    (
        "origin_correct",
        "origin_incorrect",
        "new_correct",
        "new_incorrect",
        "origin_succ_traj",
        "origin_fail_traj",
        "new_succ_traj",
        "new_fail_traj",
    ),
    0,
)

# Number of evaluation episodes; episode i is reset with seed i * 42 + 17.
NUM_EPISODES = 100
# Attempts per LLM call before the whole run is aborted.
MAX_LLM_RETRIES = 3


def _parse_action(reply, valid_actions):
    """Return the index of the first action string contained in *reply*.

    Matching is a plain substring test in list order, so if the reply
    mentions several actions the one listed first in *valid_actions* wins.
    Returns None when no action string appears.
    """
    for idx, candidate in enumerate(valid_actions):
        if candidate in reply:
            return idx
    return None


def _query_llm_action(messages, valid_actions):
    """Ask gpt-4o-mini for the next action and return its index.

    An API/response error or a reply that names no valid action counts as a
    failed attempt, so a stale action from an earlier step is never reused.
    After MAX_LLM_RETRIES failed attempts the run is aborted (SystemExit).
    """
    for _ in range(MAX_LLM_RETRIES):
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o-mini",
                messages=messages,
                temperature=0.7,
            )
            reply = response["choices"][0]["message"]["content"]
        except Exception as err:  # not bare: let Ctrl-C / SystemExit through
            print("Action query failed:", err)
            continue
        action = _parse_action(reply, valid_actions)
        if action is not None:
            return action
        print("No valid action found in reply:", reply)
    raise SystemExit(f"No valid action after {MAX_LLM_RETRIES} attempts")


def _expert_action_index(expert_action, valid_actions):
    """Map the expert agent's action to an index into *valid_actions*.

    The action's lower-cased ``.name`` is matched as a substring of each
    (lower-cased) action string. Returns None when nothing matches, so an
    unmapped expert action never compares equal to the LLM's index (a raw
    enum left in place could, if the enum is int-valued).
    """
    name = expert_action.name.lower()
    for idx, candidate in enumerate(valid_actions):
        if name in candidate.lower():
            return idx
    return None


def _step_label(step):
    """Return the prompt label for *step* steps after the start: "t", "t+1", ..."""
    return "t" if step == 0 else f"t+{step}"


def _format_trajectory(observations, actions, result):
    """Render one rollout as the text substituted for {traj} in the reflection prompt.

    observations[i] is the observation at step t+i and actions[i] the action
    taken there, so there is one more observation than actions. An
    intermediate observation identical to the previous one is written as
    "State unchanged.". *result* (the success/failure sentence) ends the text.
    """
    parts = []
    for step, action_text in enumerate(actions):
        if step > 0 and observations[step] == observations[step - 1]:
            seen = "State unchanged."
        else:
            seen = observations[step]
        parts.append(f"At step {_step_label(step)}, you observed: {seen}\n")
        parts.append(f"You took action: {action_text}\n")
    last = len(actions)
    parts.append(f"At step {_step_label(last)}, you observed: {observations[last]}\n")
    parts.append(result + "\n")
    return "".join(parts)


def _get_verbal_feedback(system_text, trajectory_text):
    """Ask gpt-4o-mini to reflect on a trajectory and return its advice.

    The reply must contain "Feedback: "; everything after the first
    occurrence is returned, stripped. Aborts the run (SystemExit) after
    MAX_LLM_RETRIES failed attempts.
    """
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": reflect.format(traj=trajectory_text)},
    ]
    for _ in range(MAX_LLM_RETRIES):
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o-mini",
                messages=messages,
                temperature=0.2,
            )
            reply = response["choices"][0]["message"]["content"]
        except Exception as err:  # not bare: let Ctrl-C / SystemExit through
            print("Reflection query failed:", err)
            continue
        _, marker, feedback = reply.partition("Feedback: ")
        if marker:
            return feedback.strip()
        print("No 'Feedback: ' section in reply:", reply)
    raise SystemExit(f"No verbal feedback after {MAX_LLM_RETRIES} attempts")


def _rollout(seed, prefix, advice=None):
    """Play one LLM-driven episode and record it in ``stat`` under *prefix*.

    Each step's LLM action is compared with the expert agent's action
    (``{prefix}_correct`` / ``{prefix}_incorrect``); the episode counts as
    ``{prefix}_succ_traj`` when its total reward is positive, otherwise
    ``{prefix}_fail_traj``. When *advice* is given it is appended to a copy
    of the last prompt message at every step.

    Returns (episode_return, observations, actions, system_text): the text
    trajectory plus the first prompt message's content captured at reset.
    """
    obs, info = env.reset(seed=seed)
    agent.on_reset()
    system_text = info["prompt"][0]["content"]
    observations = [info["obs_text"]]
    actions = []
    episode_return = 0  # fresh for every rollout
    done = False

    while not done:
        # Query the expert on the current observation, before stepping.
        expert_action = agent.act(obs, info)["action"]
        if advice is None:
            messages = info["prompt"]
        else:
            messages = deepcopy_list(info["prompt"])
            messages[-1]["content"] += "Here is the advice you should follow: " + advice + "\n"

        action = _query_llm_action(messages, env.all_actions)
        obs, reward, done, info = env.step(action)

        observations.append(info["obs_text"])
        actions.append(env.all_actions[action])
        episode_return += reward

        agreed = action == _expert_action_index(expert_action, env.all_actions)
        stat[f"{prefix}_correct" if agreed else f"{prefix}_incorrect"] += 1

    stat[f"{prefix}_succ_traj" if episode_return > 0 else f"{prefix}_fail_traj"] += 1
    return episode_return, observations, actions, system_text


for episode in range(NUM_EPISODES):
    seed = episode * 42 + 17

    # 1) Rollout without advice, then ask the LLM to reflect on it.
    episode_return, observations, actions, system_text = _rollout(seed, "origin")
    if episode_return > 0:
        episode_result = "You have completed the task successfully!"
    else:
        episode_result = "You have failed the task."
    # Keep only the system prompt text before its observation-presentation sentence.
    sys_txt = system_text.split("\nIn a moment I will present you an observation.")[0].strip()
    verbal_fb = _get_verbal_feedback(
        sys_txt, _format_trajectory(observations, actions, episode_result)
    )

    print("verbal feedback:\n", verbal_fb)
    print("====================================")

    # 2) Replay the same seed with the feedback appended to every prompt.
    _rollout(seed, "new", advice=verbal_fb)

    print("stat:", stat)
        
        