from jericho import FrotzEnv
import os
import re
from verl.environments.base import BaseEnv

import json
def read_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    with open(path, encoding='utf8') as handle:
        content = handle.read()
    return json.loads(content)
        
# Command templates shown to the agent. JerichoEnv substitutes this text for
# the "<action_space>" placeholder in its prompts. The leading and trailing
# empty entries give the string a leading and a trailing newline.
ACTION_SPACE = "\n".join([
    "",
    "- Inventory: check things you are carrying",
    "- Look: check your surroundings",
    "- Examine {place/obj}: check the details of something",
    "- Take {obj}: pickup obj",
    "- Put down {obj}: leave a obj at your current place.",
    "- Drop {obj}",
    "- South: go south",
    "- North: go north",
    "- East: go east",
    "- West: go west",
    "- Up: go up",
    "- Down: go down",
    "",
])

class JerichoEnv(BaseEnv):
    """Text-adventure environment wrapping a Jericho ``FrotzEnv`` game.

    The agent earns points by producing observations that contain the
    milestone strings listed in ``env_config["obs_to_reward"]``. Each
    milestone is credited at most once. The score is the fraction of
    milestones reached, scaled to the range 0-10.
    """

    # Any character that is not an ASCII letter, digit or whitespace.
    _NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9\s]')

    def __init__(self, env_config, special_settings, add_examples=True):
        """Create the game, the reward milestones and the filled-in prompts.

        Args:
            env_config: mapping read with the keys ``game_file``,
                ``obs_to_reward``, ``game_id``, ``game_name``,
                ``difficulty`` and ``goal``.
            special_settings: forwarded unchanged to ``BaseEnv.__init__``.
            add_examples: when True, fill the ``<examples>`` placeholder of
                ``self.user_prompt`` from the bundled JSON file. Otherwise,
                fill it with "no example.".
        """
        super().__init__(special_settings)
        self.env_config = env_config
        self.env_name = "jericho"
        self.max_step = 50

        # Initialise the game from env_config.
        self.env = FrotzEnv(env_config["game_file"])

        # Reward bookkeeping: number of milestones reached so far.
        self.points = 0
        # Keep a private copy: update_points() removes matched patterns, and
        # mutating env_config's list would break reset() and any other env
        # built from the same config.
        self.obs_to_reward = self._fresh_reward_patterns()
        self.num_obs_to_reward = len(self.obs_to_reward)

        # Start the game to obtain its opening text.
        init_obs, _info = self.env.reset()

        # Task metadata.
        self.task_id = env_config["game_id"]
        self.task_name = env_config["game_name"]
        self.task_gamefile = env_config["game_file"]
        self.difficulty = env_config["difficulty"]
        self.task_description = env_config["goal"]
        self.init_obs = self.get_init_obs(init_obs)

        # Fill the placeholders of the main prompt.
        # NOTE(review): user_prompt / user_prompt_for_deepthink are presumably
        # set by BaseEnv.__init__ from special_settings - confirm.
        self.user_prompt = self.user_prompt.replace("<action_space>", ACTION_SPACE)
        self.user_prompt = self.user_prompt.replace("<Task>", self.task_description)
        self.user_prompt = self.user_prompt.replace("<Init Obs>", self.init_obs)

        # Few-shot examples.
        if add_examples:
            # NOTE: this path is relative to the current working directory.
            data = read_json("./verl/environments/jericho/jericho_vanilla_prompt.json")
            examples = data["examples"]
            self.user_prompt = self.user_prompt.replace("<examples>", ''.join(examples))
        else:
            self.user_prompt = self.user_prompt.replace("<examples>", "no example.")

        # Fill the placeholders of the action ("deep think") prompt.
        self.user_prompt_for_deepthink = self.user_prompt_for_deepthink.replace("<action_space>", ACTION_SPACE)
        self.user_prompt_for_deepthink = self.user_prompt_for_deepthink.replace("<Task>", self.task_description)
        self.user_prompt_for_deepthink = self.user_prompt_for_deepthink.replace("<Init Obs>", self.init_obs)

    def _fresh_reward_patterns(self):
        """Return a new list copy of ``env_config["obs_to_reward"]`` (``[]`` when it is None)."""
        patterns = self.env_config["obs_to_reward"]
        return list(patterns) if patterns is not None else []

    def get_init_obs(self, init_obs):
        """Normalise the game's opening text into a single line.

        Removes every span fenced by ``-----`` markers (per the original
        author, the copyright banner; this may need per-game checking). It
        then turns newlines into spaces, collapses runs of spaces and strips
        the result.
        """
        text = re.sub(r'-----([\s\S]*?)-----', '', init_obs)
        text = re.sub(r'\n', ' ', text)
        text = re.sub(r' {2,}', ' ', text)
        return text.strip()

    def clean_up_text(self, observation):
        """Flatten an observation to one line and drop ``[Your score ...]`` notices."""
        cleaned_text = re.sub(r'\n', ' ', observation)
        cleaned_text = re.sub(r' {2,}', ' ', cleaned_text)
        cleaned_text = re.sub(r'\[Your score .*?\]', '', cleaned_text)
        return cleaned_text

    def _match_style(self, obs, pattern):
        """Return True if ``pattern`` occurs in ``obs``.

        Both strings are first stripped of every character that is not an
        ASCII letter, digit or whitespace. The comparison is case-sensitive.
        """
        obs = self._NON_ALNUM_RE.sub('', obs)
        pattern = self._NON_ALNUM_RE.sub('', pattern)
        return pattern in obs

    def update_points(self, obs):
        """Credit at most one not-yet-reached milestone found in ``obs``.

        The first matching pattern is removed from ``self.obs_to_reward`` so
        it cannot be credited again. Mutating the list while iterating is safe
        here because the loop breaks immediately after the removal.
        """
        if not self.obs_to_reward:
            return

        for pattern in self.obs_to_reward:
            if self._match_style(obs, pattern):
                self.points += 1
                self.obs_to_reward.remove(pattern)
                break

    def _get_action_space(self):
        """Return the game's currently valid actions plus the "check valid actions" pseudo-action."""
        return self.env.get_valid_actions() + ["check valid actions"]

    def validate_check_valid_actions(self, action):
        """Return True if ``action`` asks for the valid-action list (mentions both "check" and "action")."""
        action = action.strip().lower()
        return "check" in action and "action" in action

    def step(self, action):
        """Apply ``action`` to the game and update the score and episode flags.

        The "check valid actions" pseudo-action returns the list of valid
        game commands instead of advancing the game. It still counts toward
        ``max_step``.

        Returns:
            (observation, reward, score, gameDone). ``reward`` is the change
            in ``score`` since the previous step, and ``score`` is in [0, 10].
        """
        if self.validate_check_valid_actions(action):
            observation = "You can take the following actions: " + ", ".join(self._get_action_space())
        else:
            try:
                observation, _reward, _done, _info = self.env.step(action)
                observation = self.clean_up_text(observation)
            except Exception:
                # Errors raised by the game engine are reported to the agent
                # as an illegal action instead of aborting the rollout.
                observation = "illegal action!"

            self.update_points(observation)

        if self.num_obs_to_reward:
            progress = self.points / self.num_obs_to_reward
        else:
            # No milestones configured: nothing can be earned (and no ZeroDivisionError).
            progress = 0.0
        done = (progress == 1)

        score = progress * 10  # 0~10

        # Advance the step counter and decide whether the episode has ended.
        # NOTE(review): current_step / score / gameDone / done / over are
        # presumably initialised by BaseEnv - confirm.
        self.current_step += 1
        if self.current_step >= self.max_step or done:
            self.gameDone = True

        self.reward = score - self.score
        self.score = score

        # Classify the finished episode: done if any milestone was reached,
        # otherwise over.
        if self.gameDone:
            if score > 0:
                self.done = True
            else:
                self.over = True

        return observation, self.reward, self.score, self.gameDone

    def reset(self):
        """Restart the game, restore all reward milestones, then reset the base state."""
        # Restart the underlying game so a new episode does not continue the old one.
        self.env.reset()
        self.points = 0
        # Fresh copy: the previous episode removed matched patterns from its list.
        self.obs_to_reward = self._fresh_reward_patterns()
        self.num_obs_to_reward = len(self.obs_to_reward)
        super().reset()