import os
import os.path as osp
import json

from vh_dataset.dataset.virtualhome import KG, make_embedding_fn
from virtualhome.simulation.environment.unity_environment import UnityEnvironment
from virtualhome.simulation.environment.resources import TASKS_SET, SUCCESS_CONDITIONS

from embodied_cd.environments.base import BaseEnvironment


class VirtualHomeWorldEnv(BaseEnvironment):
    """Text-action wrapper around the VirtualHome ``UnityEnvironment``.

    Actions are strings of the form ``"<verb> <target>"`` (e.g.
    ``"walk kitchen"``). Each step forwards the action to the simulator,
    merges the returned ``visible_graph`` / ``agent_graph`` into a ``KG``,
    and renders a natural-language observation from that graph.
    """

    name = "virtualhome_world"

    # Rooms the agent can walk to; they are always accepted as action targets.
    _ROOMS = ("kitchen", "bedroom", "livingroom", "bathroom")

    # Verbs accepted by ``_is_valid_action``.
    _AVAIL_ACTIONS = ("walk", "open", "close", "switchon", "grab", "put", "putin")

    # Resource files that define the valid object names. These are relative
    # paths, so they resolve against the current working directory.
    _PROPERTIES_PATH = (
        "externals/virtualhome/virtualhome/resources/properties_data_unity.json"
    )
    _EQUIVALENCE_PATH = (
        "externals/virtualhome/virtualhome/resources/class_name_equivalence.json"
    )

    # Lazily-built frozenset of valid targets (see ``_avail_targets``).
    _avail_targets_cache = None

    # Normalized task description (lowercase, spaces -> "_") -> task category.
    _TASK_TYPES = {
        "turn_on_tv": "turn_on",
        "turn_on_radio": "turn_on",
        "turn_on_microwave": "turn_on",
        "turn_on_stove": "turn_on",
        "turn_on_computer": "turn_on",
        "open_cabinet": "open",
        "open_dishwasher": "open",
        "open_microwave": "open",
        "open_stove": "open",
        "put_mug_to_coffeetable": "put",
        "put_apple_on_desk": "put",
        "put_book_on_sofa": "put",
        "place_book_in_bookshelf": "put_in",
        "place_towel_in_closet": "put_in",
        "put_towel_on_washingmachine": "put",
        "place_paper_in_cabinet": "open_put",
        "place_mug_in_microwave": "open_put",
        "place_plate_in_dishwasher": "open_put",
        "put_plate_on_microwave": "open_put",
    }

    def __init__(self, num_topk_edge: int = 12):
        """Connect to the simulator at 127.0.0.1 and set up episode state.

        Args:
            num_topk_edge: passed as ``num_edges`` to ``KG.retrieve`` when
                building ``info["state"]``.
        """
        super().__init__()

        self.env = UnityEnvironment(url="127.0.0.1", seed=0)
        self.valid_env_id = self.env.valid_env_id

        # Per-episode state; (re)initialized by ``reset``.
        self.kg = None
        self.task_id = None
        self.env_id = None
        self.timestep = 0
        self._holding = None
        self.history = []

        self.emb_fn = make_embedding_fn()
        self.num_topk_edge = num_topk_edge

    def step(self, action):
        """Execute one text action.

        Args:
            action: ``"<verb> <target>"``. Case and surrounding whitespace are
                ignored, and ``"switch "`` is accepted as ``"switchon "``.

        Returns:
            ``(obs, reward, done, info)``. An invalid action ends the episode
            immediately (``done=True``, reward 0) without calling the
            simulator. If the simulator reports ``success`` as falsy, ``obs``
            is a fixed message and ``done`` is forced to False.
        """
        action = action.lower().strip()
        # Accept "switch" as an alias for the "switchon" verb.
        action = action.replace("switch ", "switchon ")

        task = TASKS_SET[self.task_id]
        task_type = self._parse_task_type(task)

        if not self._is_valid_action(action):
            info = {
                "success": False,
                # NOTE: on this path "task" has always held the task *type*
                # (unlike the valid path, where it is the raw task). Kept for
                # compatibility; "task_type" is added so both paths expose it.
                "task": task_type,
                "task_type": task_type,
            }
            return f"{action} is not a valid action.", 0, True, info

        self.history.append(str((f"step {self.timestep+1}", action)))

        obs, reward, done, info = self.env.step(action)
        self.timestep += 1

        # Merge the newly observed subgraphs into the running knowledge graph.
        self.kg.add(obs["visible_graph"], self.timestep, use_refinement=True)
        self.kg.add(obs["agent_graph"], self.timestep, use_refinement=True)

        if info["success"]:
            obs = self._parse_obs(self.kg, task, prev_action=action)
        else:
            obs = "Nothing seems to be happened."
            done = False

        info["task_type"] = task_type
        info["task"] = task
        info["kg"] = self.kg.clone()
        info["state"] = self._retrieve_state(task)
        info["history"] = ", ".join(self.history).replace("'", "")
        return obs, reward, done, info

    def reset(self, task_id=0, env_id=None, init_rooms=None):
        """Start a new episode.

        Args:
            task_id: index into ``TASKS_SET`` and ``SUCCESS_CONDITIONS``.
            env_id: simulator environment id; defaults to the last entry of
                ``self.env.valid_env_id``.
            init_rooms: forwarded to ``UnityEnvironment.reset``.

        Returns:
            ``(obs, info)``, or ``(None, None)`` if the simulator reset did
            not return a dict.
        """
        if env_id is None:
            env_id = self.env.valid_env_id[-1]

        self.task_id = task_id
        # Record the id actually used, including the defaulted one.
        self.env_id = env_id
        self.timestep = 0
        self._holding = None
        self.history = []

        self.env.set_task(
            {
                "required_condition": [SUCCESS_CONDITIONS[task_id]],
                "prohibited_condition": [],
            }
        )
        obs = self.env.reset(environment_id=env_id, init_rooms=init_rooms)
        if not isinstance(obs, dict):
            return None, None

        task = TASKS_SET[task_id]
        self.kg = KG(self.env.get_position_graph())
        self.kg.add(obs["visible_graph"], 0, use_refinement=True)
        self.kg.add(obs["agent_graph"], 0, use_refinement=True)

        obs = self._parse_obs(self.kg, task, prev_action="reset env")
        info = {
            "task_type": self._parse_task_type(task),
            "task": task,
            "kg": self.kg.clone(),
            "state": self._retrieve_state(task),
            # NOTE(review): "histroy" typo is preserved; changing it would
            # change the text handed to downstream consumers.
            "history": "No action histroy.",
        }
        return obs, info

    def _retrieve_state(self, task):
        """Call ``self.kg.retrieve`` for ``task`` with this env's embedding fn
        and ``num_topk_edge`` edge budget."""
        return self.kg.retrieve(
            [task],
            embedding_fns=self.emb_fn,
            num_edges=self.num_topk_edge,
        )

    def _parse_task_type(self, raw_task):
        """Map a task description to its category.

        The description is lowercased and spaces become underscores, so
        ``"turn on tv"`` maps to ``"turn_on"``.

        Raises:
            KeyError: if the normalized task is not in ``_TASK_TYPES``.
        """
        task = raw_task.lower().replace(" ", "_")
        try:
            return self._TASK_TYPES[task]
        except KeyError:
            raise KeyError(f"Unknown VirtualHome task: {raw_task!r}") from None

    def _parse_obs(self, obs, goal, prev_action):
        """Render a natural-language observation after ``prev_action``.

        Args:
            obs: knowledge graph; each element of ``obs.edges`` must stringify
                as ``"(source, relation, target)"``.
            goal: task description (used for retrieval and shown on reset).
            prev_action: ``"<verb> <target>"``; ``"reset env"`` renders the
                episode-start observation.

        Raises:
            ValueError: if ``prev_action`` is not two space-separated tokens,
                or its verb is not one this method renders.
        """
        edges = []
        for edge in obs.edges:
            source, relation, target = str(edge).strip("()").split(", ")
            edges.append((source, relation, target))

        verb, target = prev_action.split(" ")

        if verb == "reset":
            player_room = self._get_player_room(edges)
            new_obs = f"You are in the middle of a {player_room}.\n"
            new_obs += obs.retrieve([goal], embedding_fns=self.emb_fn)
            new_obs += f"\nYour task is to: {goal}."
            return new_obs

        if verb == "walk":
            if target in self._ROOMS:
                player_room = self._get_player_room(edges)
                objects = self._get_objects_in_room(edges, player_room)
                adjacent_rooms = self._get_adjacent_rooms(edges, player_room)
                return (
                    f"You are in the middle of a {player_room}. "
                    f"In {player_room} you see {', '.join(objects)}. "
                    f"Adjacent to {player_room} is {', '.join(adjacent_rooms)}. "
                )
            object_status = self._get_object_status(edges, target)
            return f"You walk to {target}. {target} is {', '.join(object_status)}. "

        if verb in ("open", "close", "switchon"):
            object_status = self._get_object_status(edges, target)
            return f"You {verb} {target}. {target} is {', '.join(object_status)}. "

        if verb == "grab":
            self._holding = target
            return f"You grab {target}. "

        if verb in ("put", "putin"):
            # NOTE(review): _holding is never cleared after a put, so a later
            # put reports the same object again; confirm whether intended.
            preposition = "on" if verb == "put" else "in"
            object_status = self._get_object_status(edges, target)
            return (
                f"You put {self._holding or 'nothing'} {preposition} {target}. "
                f"{target} is {', '.join(object_status)}. "
            )

        raise ValueError(f"Cannot render an observation for action {prev_action!r}")

    def _get_player_room(self, edges):
        """Return the target of the first ``character -inside-> X`` edge, or None."""
        return next(
            (
                target
                for source, relation, target in edges
                if source == "character" and relation == "inside"
            ),
            None,
        )

    def _get_objects_in_room(self, edges, room):
        """Return the sorted, de-duplicated non-character nodes ``inside`` ``room``.

        Sorting keeps the rendered text deterministic (set order of strings
        depends on per-process hash randomization).
        """
        return sorted(
            {
                source
                for source, relation, target in edges
                if source != "character" and relation == "inside" and target == room
            }
        )

    def _get_adjacent_rooms(self, edges, room):
        """Return the sorted nodes joined to ``room`` by an ``adjacent`` edge
        in either direction, excluding ``room`` itself."""
        adjacent_rooms = set()
        for source, relation, target in edges:
            if relation != "adjacent":
                continue
            if source == room:
                adjacent_rooms.add(target)
            elif target == room:
                adjacent_rooms.add(source)
        adjacent_rooms.discard(room)
        return sorted(adjacent_rooms)

    def _get_object_status(self, edges, obj):
        """Return sorted, de-duplicated ``"<relation> <target>"`` strings for
        every edge whose source is ``obj``."""
        return sorted(
            {f"{relation} {target}" for source, relation, target in edges if source == obj}
        )

    def _is_valid_action(self, action):
        """Return True iff ``action`` is exactly ``"<verb> <target>"`` with a
        verb from ``_AVAIL_ACTIONS`` and a target from ``_avail_targets()``."""
        tokens = action.split(" ")
        if len(tokens) != 2:
            return False
        verb, target = tokens
        return verb in self._AVAIL_ACTIONS and target in self._avail_targets()

    @classmethod
    def _avail_targets(cls):
        """Return the cached frozenset of valid action targets.

        Valid targets are:
        - the room names;
        - every key of the Unity properties file;
        - every class name and equivalent name from the equivalence file,
          with underscores removed.

        The files are read on first use only, not on every ``step``.
        """
        if cls._avail_targets_cache is None:
            targets = set(cls._ROOMS)
            with open(cls._PROPERTIES_PATH, "r", encoding="utf-8") as f:
                targets.update(json.load(f))  # iterating a dict yields its keys
            with open(cls._EQUIVALENCE_PATH, "r", encoding="utf-8") as f:
                for name, equivalents in json.load(f).items():
                    targets.add(name.replace("_", ""))
                    targets.update(e.replace("_", "") for e in equivalents)
            cls._avail_targets_cache = frozenset(targets)
        return cls._avail_targets_cache
