from scienceworld import ScienceWorldEnv
from verl.environments.base import BaseEnv

import re
import json
def read_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).
    """
    with open(path, mode='r', encoding='utf8') as handle:
        return json.load(handle)

# Human-readable catalogue of the textual commands offered to the agent.
# It is substituted for the "<action_space>" placeholder of the prompt
# templates when an environment is constructed, so every character below is
# part of the prompt and must not be edited for purely cosmetic reasons.
# NOTE(review): the "disconnect {OBJ} to {OBJ}" entry reads oddly -- confirm
# the exact disconnect syntax against the simulator's action grammar.
ACTION_SPACE = """
**Manipulation**: 
- open {OBJ} / close {OBJ}: Interact with a container.
- pick up {OBJ}: Add an object to your inventory.
- put down {OBJ}: Remove an object from your inventory.
- move {OBJ} to {OBJ}: Transfer an object.
- pour {OBJ} into {OBJ}: Pour a substance.
- dunk {OBJ} into {OBJ}: Immerse a container in a liquid.
- mix {OBJ}: Chemically combine contents.

**Inspection**:
- inventory: Display items you're carrying.
- look around: Survey your surroundings.
- look at {OBJ}: Examine an object closely.
- look in {OBJ}: Peek inside a container.
- read {OBJ}: Review written content.

**Device Operations**:
- activate {OBJ} / deactivate {OBJ}: Toggle a device.
- use {OBJ} [on {OBJ}]: Utilize a device or item.
- connect {OBJ} to {OBJ} / disconnect {OBJ} to {OBJ}: connect / disconnect electrical components

**Movement**:
- go to {LOC}: Relocate.

**Miscellaneous**:
- {Number}: When the Env feedback has "Please enter the number for the action you intended..."
- eat {OBJ}: Consume an edible item.
- flush {OBJ}: Activate a flushing mechanism.
- focus on {OBJ}: Direct attention to a particular object.
- wait: take no action for 10 steps
- wait1: take no action for a step
  
Where:
- {OBJ}: Object
- {LOC}: Location
- {Number}: Number, such as 0, 1, 2, ...
"""

class SciworldEnv(BaseEnv):
    """A single ScienceWorld task exposed through the ``BaseEnv`` interface.

    The constructor loads one task variation, fills the prompt templates with
    the action space, task description and first observation, and optionally
    replays a recorded trajectory so the simulator resumes from that point.

    NOTE(review): ``user_prompt``, ``user_prompt_for_deepthink``,
    ``current_step``, ``score``, ``gameDone``, ``done`` and ``over`` are read
    or updated here without being created in this class; they are presumably
    initialised by ``BaseEnv`` -- confirm against the base class.
    """

    def __init__(self, env_config, special_settings, add_examples=False, simpleStr=""):
        """Create the simulator, prepare the prompts and replay any history.

        Args:
            env_config: Mapping with ``"task_name"`` and ``"var"`` (passed to
                the simulator's ``load``) and, optionally, ``"history_traj"``:
                a list of recorded steps to replay.
            special_settings: Forwarded unchanged to ``BaseEnv.__init__``.
            add_examples: When true, the ``"examples"`` entry of the bundled
                JSON file replaces the ``<examples>`` prompt placeholder;
                otherwise the literal text ``no example.`` is used.
            simpleStr: Third argument to the simulator's ``load``; per the
                original author's note, passing ``"easy"`` changes the
                difficulty.
        """
        super().__init__(special_settings)
        self.max_step = 50

        # Build the simulator and load the requested task variation.
        self.env = ScienceWorldEnv("", None, envStepLimit=200)
        self.env.load(env_config["task_name"], env_config["var"], simpleStr)

        # Start the episode; the info dict identifies the loaded task.
        first_obs, info = self.env.reset()

        # Task metadata.
        self.task_type = info["taskName"]
        self.task_var = info["variationIdx"]
        self.task_id = f"{info['taskName']}_{info['variationIdx']}"
        self.task_name = self.task_id
        self.task_gamefile = None
        self.task_description = self.env.get_task_description()
        self.init_obs = self._clean_obs(first_obs)

        # Placeholder -> value pairs shared by both prompt templates. They are
        # applied in this order, one replacement at a time.
        substitutions = (
            ("<action_space>", ACTION_SPACE),
            ("<Task>", self.task_description),
            ("<Init Obs>", self.init_obs),
        )

        # Main prompt.
        for placeholder, value in substitutions:
            self.user_prompt = self.user_prompt.replace(placeholder, value)

        # Few-shot examples (or an explicit statement that there are none).
        if add_examples:
            examples = read_json("./verl/environments/sciworld/scienceworld_base.json")["examples"]
        else:
            examples = "no example."
        self.user_prompt = self.user_prompt.replace("<examples>", examples)

        # Prompt used for the "deepthink" action mode.
        for placeholder, value in substitutions:
            self.user_prompt_for_deepthink = self.user_prompt_for_deepthink.replace(placeholder, value)

        # Replay a previously recorded trajectory, if one was supplied.
        self.init_history_traj = []
        for step_idx, past in enumerate(env_config.get("history_traj", [])):
            action = past["action"]
            if action is None:
                # Recorded steps without an action are skipped entirely.
                continue

            observation, _reward, score, _finished = self.step(action)
            self.init_history_traj.append({
                "original_response": past["original_response"],
                "think": past["thought"],
                "action": past["action"],
                "score": score,
                "format_score": past["format_score"],
                "observation": observation,
            })

            # Report (but tolerate) a replay whose score diverges from the log.
            if score != past["score"]:
                print("step", step_idx)
                print("Score Not Equal!")
                print(past["score"])
                print(score)

            # Observations are deliberately not compared with the recorded
            # ones: per the original author's note, object placement in
            # ScienceWorld is random.

    def _process_ob(self, obs):
        """Append a numeric-answer hint to ambiguity prompts.

        Observations that do not contain the simulator's ambiguity message
        are returned unchanged.
        """
        ambiguity_marker = "Ambiguous request: Please enter the number for the action you intended"
        if ambiguity_marker not in obs:
            return obs
        return obs + "\nPlease only give your action by a single number, such as 0, 1, 2, ..."

    def _clean_obs(self, s):
        """Return ``s`` with every newline and tab replaced by one space."""
        return s.replace('\n', ' ').replace('\t', ' ')

    def step(self, action):
        """Execute ``action`` in the simulator and update the episode state.

        Args:
            action: Command string passed to the simulator's ``step``.

        Returns:
            Tuple ``(observation, reward, score, game_done)`` where the reward
            and the score are the simulator's values divided by 10.
        """
        raw_obs, reward, done, info = self.env.step(action)
        observation = self._process_ob(self._clean_obs(raw_obs))

        # The episode ends on the local step budget or when the simulator
        # reports completion.
        self.current_step += 1
        if self.current_step >= self.max_step or done:
            self.gameDone = True

        self.reward = reward / 10
        raw_score = info["score"]
        if raw_score != -100:
            # A raw score of -100 leaves the previously stored score in place.
            self.score = raw_score / 10

        # Classify the finished episode by the sign of the final raw score.
        if self.gameDone:
            if raw_score > 0:
                self.done = True
            else:
                self.over = True

        return observation, self.reward, self.score, self.gameDone
    

class SciworldEnv_Agentboard(SciworldEnv):
    """ScienceWorld environment scored by sub-goal progress.

    ``env_config["subgoals"]`` is a list of regular-expression patterns. Each
    pattern that matches an observation marks its sub-goal as finished. The
    progress list has one slot more than there are patterns; that extra slot
    is never set by pattern matching, so the progress fraction stays below 1
    until the episode is reported as done, at which point the score is set to
    its maximum.
    """

    def __init__(self, env_config, special_settings):
        """Set up sub-goal tracking, then build the underlying environment.

        Args:
            env_config: Same mapping as for ``SciworldEnv`` plus a
                ``"subgoals"`` list of regex patterns.
            special_settings: Forwarded to the parent constructor.
        """
        # The parent constructor replays env_config["history_traj"] through
        # self.step(), which resolves to the override below and reads the
        # sub-goal state. It therefore has to exist *before* super().__init__
        # runs; assigning it afterwards raised AttributeError for any
        # non-empty history and would also discard the replayed progress.
        self.sub_goal = env_config["subgoals"]
        self.finished_sub_goal = [0] * (len(self.sub_goal) + 1)
        super().__init__(env_config, special_settings, True, "easy")

    def _check_temperature_string(self, s, selected_obs):
        """Mark every sub-goal whose pattern is found in ``s`` as finished.

        Args:
            s: Observation text to search.
            selected_obs: Sequence of regex patterns; index ``i`` corresponds
                to ``self.finished_sub_goal[i]``.
        """
        for i, pattern in enumerate(selected_obs):
            if re.search(pattern, s):
                self.finished_sub_goal[i] = 1.0

    def step(self, action):
        """Execute ``action`` and score it by sub-goal progress.

        The special action ``"check valid actions"`` does not advance the
        simulator; it returns the list of currently possible actions as the
        observation. It still consumes one step of the local step budget.

        Args:
            action: Command string, or ``"check valid actions"``.

        Returns:
            Tuple ``(observation, reward, score, game_done)``. The score is
            10 once the episode is done, otherwise ten times the fraction of
            finished sub-goal slots; the reward is the score change produced
            by this step.
        """
        if action == "check valid actions":
            valid_actions = ", ".join(self.env.get_possible_actions())
            observation = f"Choose an action from these valid actions: {valid_actions}"
            done = self.done
        else:
            # Only the observation and the completion flag of the simulator
            # are used; the score is recomputed from sub-goals below.
            observation, _, done, _ = self.env.step(action)
            observation = self._clean_obs(observation)
            observation = self._process_ob(observation)

        # The episode ends on the local step budget or on completion.
        self.current_step += 1
        if self.current_step >= self.max_step or done:
            self.gameDone = True

        self._check_temperature_string(observation, self.sub_goal)
        if done:
            score = 1.0
        else:
            score = sum(self.finished_sub_goal) / len(self.finished_sub_goal)
        score = score * 10

        self.reward = score - self.score
        self.score = score

        # Classify the finished episode by whether any progress was made.
        if self.gameDone:
            if score > 0:
                self.done = True
            else:
                self.over = True

        return observation, self.reward, self.score, self.gameDone

    def reset(self):
        """Reset the parent state and clear all sub-goal progress.

        Returns:
            Whatever the parent ``reset`` returns (``None`` if it returns
            nothing).
        """
        result = super().reset()
        self.finished_sub_goal = [0] * (len(self.sub_goal) + 1)
        return result