import os
import os.path as osp

import yaml

from embodied_cd.environments.base import BaseEnvironment, CustomAlfredTWEnv


class AlfredWorldEnv(BaseEnvironment):
    """Un-batched text wrapper around the ALFWorld TextWorld environment.

    Builds a ``CustomAlfredTWEnv`` with ``batch_size=1`` and exposes a
    single-episode ``reset`` / ``step`` interface.

    - Observations are plain strings.
    - The step reward is the episode's ``won`` flag cast to ``int``.
    - ``done`` is the first (only) batch entry.
    """

    name = "alfred_world"

    # Default ALFWorld config. It is a relative path, so it is resolved
    # against the current working directory (same as before).
    DEFAULT_CONFIG_PATH = osp.join(
        "externals", "alfworld", "configs", "base_config.yaml"
    )

    # Public split names -> ``train_eval`` values passed to CustomAlfredTWEnv.
    # Splits not listed here are forwarded verbatim.
    _SPLIT_TO_TRAIN_EVAL = {
        "train": "train",
        "valid_seen": "eval_in_distribution",
        "valid_unseen": "eval_out_of_distribution",
    }

    # Task-directory name prefixes -> short task-type labels.
    _TASK_TYPE_PREFIXES = {
        "pick_and_place": "put",
        "pick_clean_then_place": "clean",
        "pick_heat_then_place": "heat",
        "pick_cool_then_place": "cool",
        "look_at_obj": "examine",
        "pick_two_obj": "puttwo",
    }

    def __init__(self, split="train", config_path=None):
        """Load the ALFWorld YAML config and build the underlying env.

        Args:
            split: ``"train"``, ``"valid_seen"`` or ``"valid_unseen"``. Any
                other value is passed through unchanged as ``train_eval``.
            config_path: Path to an ALFWorld YAML config. ``None`` (the
                default) uses ``DEFAULT_CONFIG_PATH``.
        """
        super().__init__()

        if config_path is None:
            config_path = self.DEFAULT_CONFIG_PATH
        # NOTE(review): FullLoader is acceptable for the bundled config, but
        # it should not be pointed at untrusted YAML. Consider yaml.safe_load.
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        train_eval = self._SPLIT_TO_TRAIN_EVAL.get(split, split)

        self.env = CustomAlfredTWEnv(config, train_eval=train_eval).init_env(
            batch_size=1
        )

    def step(self, action):
        """Execute one action and return un-batched results.

        Args:
            action: A command string, or a sequence of command strings (one
                per batch slot; the env is built with ``batch_size=1``).

        Returns:
            ``(obs, reward, done, info)``, where:

            - ``obs`` is the first observation with any leading
              ``"You arrive at loc ... . "`` sentence removed.
            - ``reward`` is ``int(info["won"][0])``.
            - ``done`` is the first done flag.
            - ``info`` is the env's (still batched) info dict.
        """
        # Normalize to a list *before* any string handling, so sequence input
        # no longer crashes on ``str.startswith`` and every element gets the
        # same "put" rewrite.
        actions = [action] if isinstance(action, str) else list(action)
        actions = [self._normalize_put(a) for a in actions]

        obs, _, done, info = self.env.step(actions)
        obs = self._parse_obs(obs[0])

        # The env's own reward is discarded: success (``won``) is the reward.
        reward = int(info["won"][0])

        return obs, reward, done[0], info

    @staticmethod
    def _normalize_put(action):
        """Rewrite the first standalone "in" or "on" of a "put" command to "in/on".

        "in" takes precedence over "on". Non-"put" commands, and commands
        that already contain an "in/on" token, are returned unchanged.
        """
        if not action.startswith("put"):
            return action
        tokens = action.split(" ")
        if "in/on" in tokens:
            return action
        for preposition in ("in", "on"):
            if preposition in tokens:
                tokens[tokens.index(preposition)] = "in/on"
                break
        return " ".join(tokens)

    def reset(self):
        """Start a new episode.

        Sets ``self.goal`` to the text after the last "Your task is to:",
        stripped of surrounding spaces and periods.

        Returns:
            ``(obs, info)``, where:

            - ``obs`` is the initial observation without its first
              blank-line-separated paragraph; the remaining paragraphs are
              joined by single newlines.
            - ``info`` is the env's info dict with ``"task_type"`` and
              ``"task"`` added.
        """
        obs, info = self.env.reset()
        self.goal = obs[0].split("Your task is to:")[-1].strip(" .")

        obs = "\n".join(obs[0].split("\n\n")[1:])
        info["task_type"] = self._parse_task_type(info["extra.gamefile"][0])
        info["task"] = self.goal

        return obs, info

    def _parse_task_type(self, raw_task):
        """Map a "/"-separated gamefile path to a short task-type label.

        The third- and second-to-last path components are joined and matched
        against ``_TASK_TYPE_PREFIXES``.

        Returns:
            The matching label, or ``None`` if no prefix matches.
        """
        task_dir = "/".join(raw_task.split("/")[-3:-1])
        for prefix, label in self._TASK_TYPE_PREFIXES.items():
            if task_dir.startswith(prefix):
                return label
        return None

    def _parse_obs(self, obs):
        """Drop a leading "You arrive at loc ... . " sentence from ``obs``.

        If the prefix is present but no ". " follows it, ``obs`` is returned
        unchanged. Previously ``find`` returned -1 in that case and the slice
        silently dropped the first character.
        """
        if obs.startswith("You arrive at loc "):
            _, sep, rest = obs.partition(". ")
            if sep:
                return rest
        return obs
