import os
import json
import random

from tqdm import tqdm
from termcolor import colored

import textworld
import textworld.agents
import textworld.gym

from alfworld.agents.utils.misc import Demangler
from alfworld.agents.expert import HandCodedTWAgent, HandCodedAgentTimeout


# Numeric task-type ids (as listed in ``config['env']['task_types']``) mapped
# to the ``task_type`` strings stored in each trajectory's traj_data.json.
TASK_TYPES = dict(enumerate((
    "pick_and_place_simple",
    "look_at_obj_in_light",
    "pick_clean_then_place_in_recep",
    "pick_heat_then_place_in_recep",
    "pick_cool_then_place_in_recep",
    "pick_two_obj_and_place",
), start=1))


class AlfredDemangler(textworld.core.Wrapper):
    """TextWorld wrapper that renames game entities after a game is loaded.

    Each entity's ``name`` is replaced with
    ``Demangler.demangle_alfred_name(entity.id)``; the ``shuffle`` flag is
    forwarded to ``Demangler``.
    """

    def __init__(self, *args, shuffle=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.shuffle = shuffle

    def load(self, *args, **kwargs):
        super().load(*args, **kwargs)

        entity_infos = self._entity_infos
        name_mapper = Demangler(game_infos=entity_infos, shuffle=self.shuffle)
        for entity in entity_infos.values():
            entity.name = name_mapper.demangle_alfred_name(entity.id)


class AlfredInfos(textworld.core.Wrapper):
    """TextWorld wrapper that adds the loaded game file to every reset state.

    The first positional argument given to ``load`` (presumably the game file
    path) is remembered and exposed as ``state["extra.gamefile"]`` on reset.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._gamefile = None  # filled in by load()

    def load(self, *args, **kwargs):
        super().load(*args, **kwargs)
        # Recorded only after the wrapped environment loaded successfully.
        self._gamefile = args[0]

    def reset(self, *args, **kwargs):
        game_state = super().reset(*args, **kwargs)
        game_state["extra.gamefile"] = self._gamefile
        return game_state


class AlfredExpertType:
    """String constants naming the supported kinds of AlfredExpert.

    Plain class attributes (not an ``enum.Enum``), so they compare equal to
    the raw strings ``"handcoded"`` / ``"planner"``.
    """

    PLANNER = "planner"
    HANDCODED = "handcoded"


class AlfredTWEnv(object):
    '''
    Interface for Textworld Env.

    Collects the ALFWorld game files (``game.tw-pddl``) of one data split and
    provides helpers to load the PDDL logic and to check game solvability.
    '''

    # Accepted ``train_eval`` values -> key in ``config['dataset']`` holding the
    # root directory of that split.
    _SPLIT_DATA_KEYS = {
        "train": "data_path",
        "eval_in_distribution": "eval_id_data_path",
        "eval_out_of_distribution": "eval_ood_data_path",
    }

    def __init__(self, config, train_eval="train"):
        """Store the config and immediately collect the split's game files.

        Args:
            config: nested config mapping; ``config["env"]`` and
                ``config["dataset"]`` are read (see ``collect_game_files``).
            train_eval: one of ``"train"``, ``"eval_in_distribution"`` or
                ``"eval_out_of_distribution"``.

        Raises:
            ValueError: if ``train_eval`` is not one of the names above.
        """
        print("Initializing AlfredTWEnv...")
        self.config = config
        self.train_eval = train_eval

        if config["env"]["goal_desc_human_anns_prob"] > 0:
            msg = ("Warning! Changing `goal_desc_human_anns_prob` should be done with"
                   " the script `alfworld-generate`. Ignoring it and loading games as they are.")
            print(colored(msg, "yellow"))

        self.collect_game_files()

    def collect_game_files(self, verbose=False):
        """Populate ``self.game_files`` and ``self.num_games`` for the current split.

        Walks the split's data directory and keeps every ``game.tw-pddl`` whose
        sibling ``traj_data.json`` has an enabled task type and whose game data
        is marked ``solvable``. Directories whose path contains ``movable`` or
        ``Sliced`` are skipped. The list is then truncated to
        ``num_train_games`` (train) / ``num_eval_games`` (eval) when that value
        is positive.

        Args:
            verbose: print the reason each candidate game is skipped.

        Raises:
            ValueError: if ``self.train_eval`` is not a known split name.
        """
        def log(info):
            if verbose:
                print(info)

        self.game_files = []

        data_key = self._SPLIT_DATA_KEYS.get(self.train_eval)
        if data_key is None:
            # Previously this fell through with ``data_path`` unbound.
            raise ValueError(
                f"Unknown train_eval split {self.train_eval!r}; expected one of "
                f"{sorted(self._SPLIT_DATA_KEYS)}")
        data_path = os.path.expandvars(self.config['dataset'][data_key])

        log("Collecting solvable games...")

        # Translate configured numeric task-type ids; unknown ids are ignored.
        assert len(self.config['env']['task_types']) > 0
        task_types = [TASK_TYPES[tt_id] for tt_id in self.config['env']['task_types']
                      if tt_id in TASK_TYPES]

        # list() so tqdm knows the total number of directories up front.
        for root, _dirs, files in tqdm(list(os.walk(data_path, topdown=False))):
            if 'traj_data.json' not in files:
                continue

            json_path = os.path.join(root, 'traj_data.json')
            game_file_path = os.path.join(root, "game.tw-pddl")

            if 'movable' in root or 'Sliced' in root:
                log("Movable & slice trajs not supported %s" % (root))
                continue

            with open(json_path, 'r') as f:
                traj_data = json.load(f)

            if traj_data['task_type'] not in task_types:
                log("Skipping task type")
                continue

            if not os.path.exists(game_file_path):
                log(f"Skipping missing game! {game_file_path}")
                continue

            with open(game_file_path, 'r') as f:
                gamedata = json.load(f)

            # The 'solvable' flag must have been recorded by an earlier check;
            # this message is printed even when verbose is False.
            if 'solvable' not in gamedata:
                print(f"-> Skipping missing solvable key! {game_file_path}")
                continue

            if not gamedata['solvable']:
                log("Skipping known %s, unsolvable game!" % game_file_path)
                continue

            self.game_files.append(game_file_path)

        print(f"Overall we have {len(self.game_files)} games in split={self.train_eval}")
        self.num_games = len(self.game_files)

        # A non-positive limit means "use every collected game".
        if self.train_eval == "train":
            limit = self.config['dataset']['num_train_games']
        else:
            limit = self.config['dataset']['num_eval_games']
        if limit > 0:
            self.game_files = self.game_files[:limit]
        self.num_games = len(self.game_files)

        if self.train_eval == "train":
            print("Training with %d games" % (len(self.game_files)))
        else:
            print("Evaluating with %d games" % (len(self.game_files)))

    def get_game_logic(self):
        """Read the PDDL domain and grammar files named in ``config['logic']``.

        Environment variables in both paths are expanded. The file contents
        are stored in ``self.game_logic`` under ``"pddl_domain"`` and
        ``"grammar"``.
        """
        logic_cfg = self.config['logic']
        # Context managers close the handles (the previous version leaked them).
        with open(os.path.expandvars(logic_cfg['domain'])) as domain_file:
            pddl_domain = domain_file.read()
        with open(os.path.expandvars(logic_cfg['grammar'])) as grammar_file:
            grammar = grammar_file.read()
        self.game_logic = {
            "pddl_domain": pddl_domain,
            "grammar": grammar,
        }

    def is_solvable(self, env, game_file_path,
                    random_perturb=True, random_start=10, random_prob_after_state=0.15):
        """Use the expert to check whether ``game_file_path`` can be played through.

        With a planner expert (``env.expert_type == AlfredExpertType.PLANNER``),
        the expert plan from the first reset is returned directly. Otherwise
        the game is played until ``env.step`` reports done. Each step follows
        the first command of ``extra.expert_plan``, except when
        ``random_perturb`` is true: then a random admissible command is taken
        for the first ``random_start + 1`` steps, and afterwards with
        probability ``random_prob_after_state``.

        Returns:
            The planner's expert plan, or the list of executed commands, or
            ``None`` if any exception occurred while loading or playing (the
            exception is printed).
        """
        done = False
        steps = 0
        trajectory = []
        try:
            env.load(game_file_path)
            game_state = env.reset()
            if env.expert_type == AlfredExpertType.PLANNER:
                return game_state["extra.expert_plan"]

            while not done:
                expert_action = game_state['extra.expert_plan'][0]
                # Drawn on every step, even when the expert action is used, so
                # it always consumes RNG state.
                random_action = random.choice(game_state.admissible_commands)

                command = expert_action
                if random_perturb:
                    if steps <= random_start or random.random() < random_prob_after_state:
                        command = random_action

                game_state, _, done = env.step(command)
                trajectory.append(command)
                steps += 1
        except Exception as e:
            # Best-effort check: any failure marks the game as unsolvable.
            print("Unsolvable: %s (%s)" % (str(e), game_file_path))
            return None

        return trajectory


from verl.environments.base import BaseEnv

# Textual description of the ALFWorld command templates; it is substituted for
# the "<action_space>" placeholder in the prompt templates.
ACTION_SPACE = "\n" + "\n".join([
    "- go to {recep}",
    "- take {obj} from {recep}",
    "- put {obj} in/on {recep}",
    "- open {recep}",
    "- close {recep}",
    "- clean {obj} with {recep}",
    "- heat {obj} with {recep}",
    "- cool {obj} with {recep}",
    "- use {obj}",
    "- look",
    "- inventory",
    "where {obj} and {recep} correspond to objects and receptacles.",
]) + "\n"

import re
import json
def read_json(path):
    """Parse and return the UTF-8 encoded JSON document stored at ``path``."""
    with open(path, encoding='utf8') as handle:
        return json.load(handle)
    
class AlfworldEnv(BaseEnv):
    """Single ALFWorld game exposed through the verl ``BaseEnv`` interface.

    On construction, the game in ``env_config["game_file"]`` is loaded and
    reset. The task description and initial observation are parsed from the
    first observation and substituted into the prompt templates
    ``self.user_prompt`` and ``self.user_prompt_for_deepthink``. Those are
    presumably set by ``BaseEnv.__init__`` from ``special_settings``
    (NOTE(review): confirm). Finally, any actions in
    ``env_config["history_traj"]`` are replayed through ``self.step``.
    """

    # (regex, task type) pairs, tried in order; the first match wins. The
    # task-type names are also the keys used to select in-context examples.
    _TASK_PATTERNS = (
        ("put (?:a|some) ([a-z]+) (?:in|on) ([a-z]+).", "put"),
        ("clean (?:a|some) ([a-z]+) and put it in ([a-z]+).", "clean"),
        ("put (?:a|some) clean ([a-z]+) in ([a-z]+).", "clean"),
        ("heat (?:a|some) ([a-z]+) and put it in ([a-z]+).", "heat"),
        ("put (?:a|some) hot ([a-z]+) in ([a-z]+).", "heat"),
        ("cool (?:a|some) ([a-z]+) and put it in ([a-z]+).", "cool"),
        ("put (?:a|some) cool ([a-z]+) in ([a-z]+).", "cool"),
        ("look at ([a-z]+) under the ([a-z]+).", "examine"),
        ("examine the ([a-z]+) with the ([a-z]+).", "examine"),
        ("put two ([a-z]+) (?:in|on) ([a-z]+).", "puttwo"),
        ("find two ([a-z]+) and put them (?:in|on) ([a-z]+).", "puttwo"),
    )

    def __init__(self, env_config, special_settings, add_examples=True):
        """
        Args:
            env_config: mapping with ``"game_file"`` and ``"task_id"``, plus
                optional ``"history_traj"`` (list of step dicts to replay) and
                ``"max_step"`` (step budget, default 50).
            special_settings: forwarded to ``BaseEnv.__init__``.
            add_examples: when true, the ``<examples>`` placeholder is filled
                with the in-context examples for this task type; otherwise it
                is filled with ``"no example."``. Defaults to True, so callers
                that omit it get examples.

        Raises:
            ValueError: if the task description matches no known task pattern.
        """
        super().__init__(special_settings)
        self.env_config = env_config
        self.add_examples = add_examples
        self.max_step = env_config.get("max_step", 50)

        # Build the TextWorld env for this single game file.
        self.env = self._load_env_from_game_file(self.env_config["game_file"])

        # The first observation is expected to have exactly three
        # '\n\n'-separated parts: the 2nd is the initial observation and the
        # 3rd contains "Your task is to: ...".
        init_obs_and_task, info = self.env.reset()
        _, init_obs, task_str = init_obs_and_task.split('\n\n')
        # Keep the admissible commands so they can be listed before any step.
        self.valid_actions = info.get("admissible_commands", [])

        self.task_id = self.env_config["task_id"]
        gamefile = info['extra.gamefile']
        self.task_name = '/'.join(gamefile.split('/')[-3:-1])
        self.task_gamefile = gamefile
        self.task_description = task_str.split("Your task is to:")[-1].strip()
        self.init_obs = init_obs.strip()

        parsed_task = self.task_to_task_type(self.task_description)
        if parsed_task is None:
            raise ValueError(
                f"Unrecognized ALFWorld task description: {self.task_description!r}")
        self.task_type = parsed_task[0]

        # Prompt for action.
        self.user_prompt = self._fill_task_placeholders(self.user_prompt)
        if self.add_examples:
            # NOTE(review): path is relative to the current working directory.
            self.examples = read_json("./verl/environments/alfworld/alfworld_base.json")
            examples_text = ''.join(self.examples["examples"][self.task_type])
        else:
            examples_text = "no example."
        self.user_prompt = self.user_prompt.replace("<examples>", examples_text)

        # Prompt for deep thinking.
        self.user_prompt_for_deepthink = self._fill_task_placeholders(
            self.user_prompt_for_deepthink)

        # Replay the recorded trajectory. self.step dispatches to a subclass
        # override if one exists.
        self.init_history_traj = []
        for idx, traj in enumerate(env_config.get("history_traj", [])):
            action = traj["action"]
            if action is None:
                continue
            observation, _reward, score, _done = self.step(action)
            self.init_history_traj.append({
                "original_response": traj["original_response"],
                "think": traj["thought"],
                "action": traj["action"],
                "score": score,
                "format_score": traj["format_score"],
                "observation": observation,
            })
            if score != traj["score"]:
                print("step", idx)
                print("Score Not Equal!")
                print(traj["score"])
                print(score)

    def _fill_task_placeholders(self, template):
        """Substitute ``<action_space>``, ``<Task>`` and ``<Init Obs>`` in ``template``."""
        template = template.replace("<action_space>", ACTION_SPACE)
        template = template.replace("<Task>", self.task_description)
        return template.replace("<Init Obs>", self.init_obs)

    def task_to_task_type(self, task_str):
        """Classify an ALFWorld task description.

        Returns:
            ``(task_type, group1, group2)`` for the first pattern in
            ``_TASK_PATTERNS`` that ``re.search`` finds in ``task_str``, where
            the groups are the two captured words (e.g. object and
            receptacle). Returns ``None`` when no pattern matches.
        """
        for pattern, task_type in self._TASK_PATTERNS:
            match = re.search(pattern, task_str)
            if match:
                return task_type, match.group(1), match.group(2)
        return None

    def _load_env_from_game_file(self, game_file):
        """Register ``game_file`` as a TextWorld gym env and return a fresh instance.

        The env requests the ``won`` / ``admissible_commands`` infos plus the
        ``gamefile`` extra, and is wrapped with the ALFRED demangler and
        ``AlfredInfos``.
        """
        alfred_demangler = AlfredDemangler(shuffle=False)
        wrappers = [alfred_demangler, AlfredInfos]
        request_infos = textworld.EnvInfos(won=True, admissible_commands=True, extras=["gamefile"])

        env_id = textworld.gym.register_games([game_file], request_infos,
                                              max_episode_steps=100,
                                              asynchronous=False,
                                              wrappers=wrappers)
        return textworld.gym.make(env_id)

    def process_ob(self, ob):
        """Drop the ``"You arrive at loc N. "`` prefix from a movement observation.

        If the prefix is present but no ``'. '`` follows it, ``ob`` is returned
        unchanged. The previous version sliced from ``-1 + 2`` in that case
        and dropped the first character.
        """
        if ob.startswith('You arrive at loc '):
            sep = ob.find('. ')
            if sep != -1:
                ob = ob[sep + 2:]
        return ob

    def step(self, action):
        """Execute ``action`` and update the episode bookkeeping.

        Returns:
            ``(observation, reward, score, gameDone)``, where ``score`` is the
            env score multiplied by 10 and ``reward`` is its change since the
            previous step.
        """
        observation, score, done, info = self.env.step(action)
        self.valid_actions = info.get("admissible_commands", [])
        observation = self.process_ob(observation)
        score = score * 10  # score 0~10

        self.current_step += 1
        if self.current_step >= self.max_step or done:
            self.gameDone = True

        self.reward = score - self.score
        self.score = score

        # A finished episode counts as success (done) with a positive score,
        # otherwise as failure (over).
        if self.gameDone:
            if score > 0:
                self.done = True
            else:
                self.over = True

        return observation, self.reward, self.score, self.gameDone

class AlfworldEnv_AgentBoard(AlfworldEnv):
    """AlfworldEnv variant that scores progress with AgentBoard-style sub-goals.

    ``env_config["subgoals"]`` is a list of regex patterns. Each pattern that
    matches an observation marks its entry in ``finished_sub_goal`` as done.
    ``finished_sub_goal`` has one extra slot that the pattern check never sets,
    so sub-goal progress alone stays below 1. A step where the env reports
    ``done`` scores the maximum.

    The meta-action ``"check valid actions"`` lists the last known admissible
    commands without stepping the TextWorld env.
    """

    def __init__(self, env_config, special_settings):
        """
        Args:
            env_config: as for ``AlfworldEnv``, plus ``"subgoals"`` (list of
                regex patterns) and optional ``"no_example"`` (truthy to omit
                in-context examples).
            special_settings: forwarded to ``AlfworldEnv``.
        """
        # These must exist before AlfworldEnv.__init__, which replays
        # env_config["history_traj"] through self.step().
        self.sub_goal = env_config["subgoals"]
        self.finished_sub_goal = [0 for _ in range(len(self.sub_goal) + 1)]
        # Placeholder so "check valid actions" cannot raise AttributeError
        # before a real env step has reported admissible commands.
        self.valid_actions = []
        add_examples = not env_config.get("no_example", False)
        super().__init__(env_config, special_settings, add_examples)

    def _check_temperature_string(self, s, selected_obs):
        """Mark sub-goal ``i`` finished when pattern ``selected_obs[i]`` matches ``s``.

        Despite its name, this is a generic regex check; matches only ever
        set entries of ``self.finished_sub_goal`` to 1.0.
        """
        for idx, pattern in enumerate(selected_obs):
            if re.search(pattern, s):
                self.finished_sub_goal[idx] = 1.0

    def step(self, action):
        """Execute ``action`` or the ``"check valid actions"`` meta-action.

        The meta-action does not call the TextWorld env, but it still counts
        toward ``max_step``. The score is 10 when ``done`` is true; otherwise
        it is 10 times the fraction of set ``finished_sub_goal`` entries.

        Returns:
            ``(observation, reward, score, gameDone)``.
        """
        if action == "check valid actions":
            valid_actions = ", ".join(self.valid_actions)
            observation = f"Choose an action from these valid actions: {valid_actions}"
            score, done = self.score, self.done
        else:
            observation, score, done, info = self.env.step(action)
            self.valid_actions = info["admissible_commands"]
            observation = self.process_ob(observation)

        self.current_step += 1
        if self.current_step >= self.max_step or done:
            self.gameDone = True

        self._check_temperature_string(observation, self.sub_goal)
        if done:
            score = 1.0
        else:
            score = sum(self.finished_sub_goal) * 1.0 / len(self.finished_sub_goal)
        score = score * 10

        self.reward = score - self.score
        self.score = score

        if self.gameDone:
            if score > 0:
                self.done = True
            else:
                self.over = True

        return observation, self.reward, self.score, self.gameDone

    def reset(self):
        """Reset base state and the TextWorld env, then clear sub-goal progress."""
        super().reset()
        _, info = self.env.reset()
        self.valid_actions = info["admissible_commands"]
        self.finished_sub_goal = [0 for _ in range(len(self.sub_goal) + 1)]


if __name__ == "__main__":
    # Smoke test: build the wrapped TextWorld gym env for one ALFWorld game
    # file given on the command line and print its initial observation.
    # (The previous version referenced an undefined ``game_file``.)
    import argparse

    parser = argparse.ArgumentParser(
        description="Load a single ALFWorld game and print its initial observation.")
    parser.add_argument("game_file", help="path to a game.tw-pddl file")
    cli_args = parser.parse_args()

    alfred_demangler = AlfredDemangler(shuffle=False)
    wrappers = [alfred_demangler, AlfredInfos]
    # Register a new Gym environment.
    request_infos = textworld.EnvInfos(won=True, admissible_commands=True, extras=["gamefile"])

    env_id = textworld.gym.register_games([cli_args.game_file], request_infos,
                                          max_episode_steps=100,
                                          asynchronous=False,
                                          wrappers=wrappers)
    env = textworld.gym.make(env_id)
    obs, infos = env.reset()
    print(obs)

    
    

    

    