import numpy as np
import gym
import re

class BabyAIEnv:
    """Text/chat wrapper around a BabyAI "go to" gym environment for LLM agents.

    Every observation returned by ``reset``/``step`` comes with an ``info`` dict
    augmented with:

    * ``info["obs_text"]`` - the filtered, newline-joined textual observation
      built from ``info["descriptions"]``;
    * ``info["prompt"]`` - a list of chat messages (``{"role", "content"}``
      dicts) asking the model for the next action.

    Args:
        map_name: Environment id passed to ``gym.make``.
        max_steps: ``step`` reports ``done=True`` once this many steps were taken.
    """

    # Extracts the target of a "go to the X" / "go to a X" mission.
    _GOAL_PATTERN = re.compile(r"go to (?:the|a) (.+)")
    _NUMBER_PATTERN = re.compile(r"\d+")
    # Wall descriptions are dropped from the observation except this exact one.
    # NOTE(review): exact string match (trailing '.', "1 steps") - confirm it
    # matches the environment's wording, otherwise every wall is filtered out.
    _KEPT_WALL_DESCRIPTION = "You see a wall 1 steps forward."

    def __init__(self, map_name, max_steps=15):
        self.map_name = map_name
        self.max_steps = max_steps
        # Index in this list == integer action id accepted by ``step``.
        self.all_actions = ['Turn left', 'Turn right', 'Go forward', 'Pick up', 'Drop', 'Toggle', 'Done']
        self.env = gym.make(map_name)
        self.system_prompt = (
            "You are an agent playing a simple navigation game. Your goal is to {goal}. "
            "The following are the possible actions you can take in the game, followed by "
            "a short description of each action:\n\n"
            "Turn left: turn to the left,\n"
            "Turn right: turn to the right,\n"
            "Go forward: take one step forward,\n"
            "Pick up: pick up the object below you,\n"
            "Drop: drop the object that you are holding,\n"
            "Toggle: manipulate the object in front of you.\n\n"
            "In a moment I will present you an observation.\n\n"
            "Tips:\n"
            "- Once the desired object you want to interact or pickup in front of you, "
            "you can use the 'toggle' action to interact with it.\n"
            "- It doesn't make sense to repeat the same action over and over if the "
            "observation doesn't change.\n"
            "- answer the alphanumerical action, not the description.\n\n"
            "PLAY!"
        )
        self.user_prompt = (
            "\n\nYou always have to output one of the above actions at a time and no other text. "
            "You always have to output an action until the episode terminates."
        )
        # Episode state, initialized here so the attributes always exist
        # (previously they were only created by reset()).
        self.last_obs, self.last_action = None, None
        self.steps = 0
        self.total_rew = 0

    @classmethod
    def _extract_goal(cls, mission):
        """Return the target object of a "go to" mission (e.g. ``"red ball"``).

        Raises:
            ValueError: if the mission is not of the form "go to the/a <object>".
        """
        match = cls._GOAL_PATTERN.search(mission)
        if match is None:
            raise ValueError(f"Cannot extract a 'go to' target from mission {mission!r}")
        return match.group(1)

    @classmethod
    def _step_count(cls, description):
        """Return the sum of all integers appearing in a description string."""
        return sum(int(n) for n in cls._NUMBER_PATTERN.findall(description))

    @classmethod
    def _is_valid_start(cls, descriptions, max_start_dist=3):
        """Return True if the start layout is acceptable.

        A layout is acceptable when at least one non-wall description exists and
        every non-wall description has a step count of at most ``max_start_dist``.
        """
        found_object = False
        for description in descriptions:
            if "wall" in description:
                continue
            if cls._step_count(description) > max_start_dist:
                return False
            found_object = True
        return found_object

    def _get_prompt(self, obs, info):
        """Build the textual observation and chat prompt, storing both in ``info``.

        Args:
            obs: Environment observation; must contain a ``"mission"`` string.
            info: Environment info; must contain ``"descriptions"`` (list of str).
                Mutated in place: ``"obs_text"`` and ``"prompt"`` are added.

        Returns:
            The ``(obs, info)`` pair.

        Raises:
            ValueError: if the mission is not a "go to the/a <object>" mission.
        """
        mission = obs["mission"]
        goal_obj = self._extract_goal(mission)

        lines = []
        for description in info["descriptions"]:
            if goal_obj in description:
                description += f" ({self._step_count(description)} steps away from you)"
            # Walls are noise for the model, except the one directly ahead.
            if "wall" in description and description != self._KEPT_WALL_DESCRIPTION:
                continue
            if description != "\n":
                lines.append(description + ".\n")
        obs_text = "".join(lines) or "You see nothing."

        system_msg = {"role": "system", "content": self.system_prompt.format(goal=mission)}
        current = "Current Observation: " + obs_text + self.user_prompt
        if self.last_obs is None:
            prompt = [system_msg, {"role": "user", "content": current}]
        else:
            # An unchanged observation is taken to mean the last action had no effect.
            if self.last_obs == obs_text:
                current += "\nYour previous action was invalid. Please try another action."
            prompt = [
                system_msg,
                {"role": "user", "content": self.last_obs},
                {"role": "assistant", "content": self.last_action},
                {"role": "user", "content": current},
            ]
        if goal_obj not in obs_text:
            prompt[-1]["content"] += f" Turn left or turn right to find the {goal_obj}."
        info["prompt"] = prompt
        info["obs_text"] = obs_text
        return obs, info

    def reset(self, seed=None, max_attempts=1000):
        """Start a new episode whose start layout passes ``_is_valid_start``.

        The environment (and numpy's global RNG) is seeded once, then layouts
        are sampled until one is accepted. Seeding only once keeps the result
        deterministic for a given seed. Re-seeding on every retry would replay
        the same rejected layout forever.

        Args:
            seed: Optional seed for ``self.env`` and ``np.random``.
            max_attempts: Maximum number of layouts to sample.

        Returns:
            ``(obs, info)`` with ``info["prompt"]`` and ``info["obs_text"]`` set.

        Raises:
            RuntimeError: if no acceptable layout is found in ``max_attempts`` tries.
        """
        if seed is not None:
            self.env.seed(seed)
            np.random.seed(seed)
        for _ in range(max_attempts):
            # NOTE(review): reset() unpacks (obs, info) while step() unpacks 4
            # values; confirm this matches the installed gym/BabyAI version.
            obs, info = self.env.reset()
            if self._is_valid_start(info["descriptions"]):
                break
        else:
            raise RuntimeError(
                f"No acceptable start layout found in {max_attempts} attempts for {self.map_name!r}"
            )

        self.last_obs, self.last_action = None, None
        obs, info = self._get_prompt(obs, info)
        self.steps = 0
        self.total_rew = 0
        return obs, info

    def step(self, action):
        """Apply an integer action id (index into ``self.all_actions``).

        An out-of-range action ends the episode: the underlying environment is
        reset and ``(obs, 0, True, info)`` is returned.

        Returns:
            ``(obs, reward, done, info)``. ``done`` is also forced to True once
            ``max_steps`` steps have been taken.
        """
        if action not in range(len(self.all_actions)):
            # The returned prompt belongs to a fresh episode, so don't carry the
            # previous episode's history into it.
            self.last_obs, self.last_action = None, None
            obs, info = self.env.reset()
            obs, info = self._get_prompt(obs, info)
            return obs, 0, True, info

        obs, reward, done, info = self.env.step(action)
        obs, info = self._get_prompt(obs, info)
        self.steps += 1
        self.total_rew += reward
        if self.steps >= self.max_steps:
            done = True
        self.last_obs = info["obs_text"]
        self.last_action = self.all_actions[action]
        return obs, reward, done, info
