from vh_dataset.dataset.virtualhome import KG
from embodied_cd.environments.default.virtualhome import VirtualHomeEnv


# Natural-language instructions, grouped by evaluation category. The trailing
# comment on each entry names the primitive task it stands for. Position i in
# this list is paired with COMPLEX_SUCCESS_CONDITIONS[i] (both are indexed by
# the same task_id).
COMPLEX_TASKS_SET = [
    # --- context_object ---
    *(
        "Turn on device that allows watching the movie",  # Turn on tv
        "Place something to read on desk",  # Place book on desk
        "Place item that can hold coffee on desk",  # Place mug on desk
        "Put item for cleaning spilled liquid in closet",  # Put towel in closet
        "Open appliance that keeps food fresh",  # Open fridge
    ),
    # --- context_task ---
    *(
        "Warm the mug until it’s hot",  # Put mug in microwave
        "I’ll be starting remote work soon. Please get things set up",  # Turn on computer
        "Return the book to where it belongs",  # Put book in bookshelf
        "This apple is for later, store it fresh",  # Put apple in fridge
        "Can you check if stove is working?",  # Turn on stove
    ),
    # --- composite_task ---
    *(
        "Open cabinet and open dishwasher",  # Open cabinet & Open dishwasher
        "Turn on tv and place chips on sofa",  # Turn on tv & Place chips on sofa
        "Place paper on desk and turn on computer",  # Place paper on desk & Turn on computer
        "Put mug in microwave and place plate on microwave",  # Put mug in microwave & Place plate on microwave
        "Turn on microwave and turn on radio and turn on tv",  # Turn on microwave & Turn on radio & Turn on tv
    ),
    # --- composite_abstract_task ---
    *(
        "Turn on two electronic devices",  # Turn on Any2(radio, tv, computer, dishwasher, microwave, stove)
        "Open two electronic devices",  # Open Any2(dishwasher, microwave, stove)
        "Put all edible items in fridge",  # Put apple in fridge & Put chips in fridge
        "Place any edible item on desk",  # Place Any1(apple, chips) on desk
        "Put all dishes in dishwasher",  # Put plate in dishwasher & Put mug in dishwasher
    ),
    # --- noisy_instruction ---
    *(
        "On sofa, please gently place the book",  # Place book on sofa
        "It is storming outside. Switch on the radio so I can hear the news",  # Turn on radio
        "Could you set up the TV for me? I want to watch something",  # Turn on tv
        "Please switch on the stove!",  # Turn on stove
        "The cabinet must be kept open",  # Open cabinet
    ),
]


# Goal predicates, index-aligned with COMPLEX_TASKS_SET. Each entry is one of:
#   * a single (subject, relation, object) triple;
#   * a list of triples, handed to the env together as required conditions;
#   * a dict with "task_set" (candidate triples) and "min_success"
#     (presumably how many candidates must hold -- NOTE(review): confirm
#     against the consumer of these conditions).
COMPLEX_SUCCESS_CONDITIONS = [
    # --- context_object ---
    *(
        ("tv", "is", "on"),
        ("book", "on", "desk"),
        ("mug", "on", "desk"),
        ("towel", "inside", "closet"),
        ("fridge", "is", "open"),
    ),
    # --- context_task ---
    *(
        ("mug", "inside", "microwave"),
        ("computer", "is", "on"),
        ("book", "inside", "bookshelf"),
        ("apple", "inside", "fridge"),
        ("stove", "is", "on"),
    ),
    # --- composite_task ---
    *(
        [("cabinet", "is", "open"), ("dishwasher", "is", "open")],
        [("tv", "is", "on"), ("chips", "on", "sofa")],
        [("paper", "on", "desk"), ("computer", "is", "on")],
        [("mug", "inside", "microwave"), ("plate", "on", "microwave")],
        [("microwave", "is", "on"), ("radio", "is", "on"), ("tv", "is", "on")],
    ),
    # --- composite_abstract_task ---
    *(
        # NOTE(review): the candidate set includes "washingmachine", which the
        # matching instruction comment in COMPLEX_TASKS_SET does not list.
        dict(
            task_set=[
                ("radio", "is", "on"),
                ("tv", "is", "on"),
                ("computer", "is", "on"),
                ("dishwasher", "is", "on"),
                ("microwave", "is", "on"),
                ("stove", "is", "on"),
                ("washingmachine", "is", "on"),
            ],
            min_success=2,
        ),
        dict(
            task_set=[
                ("dishwasher", "is", "open"),
                ("microwave", "is", "open"),
                ("stove", "is", "open"),
            ],
            min_success=2,
        ),
        [("apple", "inside", "fridge"), ("chips", "inside", "fridge")],
        dict(
            task_set=[("apple", "on", "desk"), ("chips", "on", "desk")],
            min_success=1,
        ),
        [("plate", "inside", "dishwasher"), ("mug", "inside", "dishwasher")],
    ),
    # --- noisy_instruction ---
    *(
        ("book", "on", "sofa"),
        ("radio", "is", "on"),
        ("tv", "is", "on"),
        ("stove", "is", "on"),
        ("cabinet", "is", "open"),
    ),
]


class VirtualComplexEnv(VirtualHomeEnv):
    """VirtualHome environment driven by the complex instruction set.

    ``task_id`` indexes both ``COMPLEX_TASKS_SET`` (the instruction text) and
    ``COMPLEX_SUCCESS_CONDITIONS`` (the goal predicates). A success condition
    may be:

    * a single ``(subject, relation, object)`` triple,
    * a list of triples, all handed to the underlying env as required
      conditions, or
    * a dict ``{"task_set": [...], "min_success": k}``. Every candidate in
      ``task_set`` is handed to the env, and the episode is declared
      successful once at least ``k`` of the per-condition flags reported in
      ``info["condwise_success"]`` are truthy.

    NOTE(review): ``self.env``, ``self.easier_room_navigation``,
    ``self.repetitive_action_patience``, ``self.emb_fn``,
    ``self.num_topk_edge`` and the helpers ``_is_valid_action``,
    ``_navigate_adjacent_room`` and ``_parse_obs`` are not defined here and
    are presumably provided by ``VirtualHomeEnv``.
    """

    name = "virtualhome_complex"

    # Primitive task type for every complex instruction. For composite
    # instructions that mix primitives, the type of the first sub-task is used.
    # NOTE(review): confirm that convention with consumers of info["task_type"].
    _COMPLEX_TASK_TYPES = {
        # context_object
        "Turn on device that allows watching the movie": "turn_on",
        "Place something to read on desk": "put",
        "Place item that can hold coffee on desk": "put",
        "Put item for cleaning spilled liquid in closet": "put_in",
        "Open appliance that keeps food fresh": "open",
        # context_task
        "Warm the mug until it’s hot": "put_in",
        "I’ll be starting remote work soon. Please get things set up": "turn_on",
        "Return the book to where it belongs": "put_in",
        "This apple is for later, store it fresh": "put_in",
        "Can you check if stove is working?": "turn_on",
        # composite_task
        "Open cabinet and open dishwasher": "open",
        "Turn on tv and place chips on sofa": "turn_on",
        "Place paper on desk and turn on computer": "put",
        "Put mug in microwave and place plate on microwave": "put_in",
        "Turn on microwave and turn on radio and turn on tv": "turn_on",
        # composite_abstract_task
        "Turn on two electronic devices": "turn_on",
        "Open two electronic devices": "open",
        "Put all edible items in fridge": "put_in",
        "Place any edible item on desk": "put",
        "Put all dishes in dishwasher": "put_in",
        # noisy_instruction
        "On sofa, please gently place the book": "put",
        "It is storming outside. Switch on the radio so I can hear the news": "turn_on",
        "Could you set up the TV for me? I want to watch something": "turn_on",
        "Please switch on the stove!": "turn_on",
        "The cabinet must be kept open": "open",
        # Phrasings from an earlier version of the task list; kept so they
        # still resolve.
        "Put something to read on desk": "put",
        "Make sure the towel goes in with the laundry": "put_in",
    }

    # Same mapping keyed by the normalised form used in _parse_task_type
    # (lower-cased, spaces -> underscores), so the lookup actually matches.
    _NORMALIZED_TASK_TYPES = {
        _task.lower().replace(" ", "_"): _task_type
        for _task, _task_type in _COMPLEX_TASK_TYPES.items()
    }

    def step(self, action):
        """Execute one agent action.

        The action is lower-cased and stripped, and ``"switch "`` is rewritten
        to ``"switchon "``. The episode ends immediately (reward 0, success
        False) when the action is invalid, or when the last
        ``self.repetitive_action_patience`` actions are all identical.

        Returns:
            ``(obs, reward, done, info)``. On the normal path ``info`` holds
            the env's step info plus ``task_type``, ``task``, ``kg``,
            ``state`` and ``history``.
        """
        action = action.lower().strip()
        # The underlying env uses the single-token verb "switchon".
        action = action.replace("switch ", "switchon ")
        task = COMPLEX_TASKS_SET[self.task_id]

        if not self._is_valid_action(action):
            return self._abort_step(f"{action} is not a valid action.", task)

        # Let the agent reach a named room without knowing how rooms connect.
        room_navigation = False
        if self.easier_room_navigation and action.startswith("walk"):
            target = action.split(" ")[-1]
            if target in ("livingroom", "kitchen", "bedroom", "bathroom"):
                room_navigation = True
                self._navigate_adjacent_room(target)

        self.history.append(str((f"step {self.timestep+1}", action)))

        # History entries look like "('step N', '<action>')". Recover the
        # action part of the most recent `patience` entries.
        patience = self.repetitive_action_patience
        recent_actions = [h.split("', '")[1][:-2] for h in self.history[-patience:]]
        if len(recent_actions) == patience and len(set(recent_actions)) == 1:
            return self._abort_step(f"Lost in repetitive action {action}.", task)

        obs, reward, done, info = self.env.step(action)
        self.timestep += 1

        self.kg.add(obs["visible_graph"], self.timestep, use_refinement=True)
        self.kg.add(obs["agent_graph"], self.timestep, use_refinement=True)

        if not room_navigation and not info["success"]:
            # The env reported no success and this was not a shortcut room
            # jump: return a fixed message and keep the episode running.
            obs = "Nothing seems to be happened."
            done = False
        else:
            obs = self._parse_obs(self.kg, task, prev_action=action)

        # "Any k of N" goals: succeed once enough candidate conditions hold.
        min_success = getattr(self, "_min_success", None)
        if (
            min_success is not None
            and sum(info.get("condwise_success", ())) >= min_success
        ):
            info["success"] = True
            done = True

        info["task_type"] = self._parse_task_type(task)
        info["task"] = task
        info["kg"] = self.kg.clone()
        info["state"] = self.kg.retrieve(
            [task],
            embedding_fns=self.emb_fn,
            num_edges=self.num_topk_edge,
        )
        info["history"] = ", ".join(self.history).replace("'", "")
        return obs, reward, done, info

    def _abort_step(self, obs, task):
        """Return ``(obs, 0, True, info)`` for a step that ends the episode early.

        ``info`` uses the same ``task_type``/``task`` keys as the normal path.
        """
        info = {
            "success": False,
            "task_type": self._parse_task_type(task),
            "task": task,
        }
        return obs, 0, True, info

    def reset(self, task_id=0, env_id=None, init_rooms=None):
        """Start a new episode on complex task ``task_id``.

        Args:
            task_id: Index into ``COMPLEX_TASKS_SET`` and
                ``COMPLEX_SUCCESS_CONDITIONS``.
            env_id: Scene id for the underlying env. Defaults to the last
                entry of ``self.env.valid_env_id``; non-int values are
                converted with ``int()``. ``self.env_id`` stores the value
                as passed in.
            init_rooms: Forwarded unchanged to ``self.env.reset``.

        Returns:
            ``(obs, info)`` for the initial state, or ``(None, None)`` when
            the underlying env's ``reset`` does not return a dict.
        """
        self.task_id = task_id
        self.env_id = env_id
        self.timestep = 0
        self._holding = None
        self.history = []

        if env_id is None:
            env_id = self.env.valid_env_id[-1]
        if not isinstance(env_id, int):
            env_id = int(env_id)

        required_condition, self._min_success = self._split_condition(
            COMPLEX_SUCCESS_CONDITIONS[task_id]
        )
        self.env.set_task(
            {
                "required_condition": required_condition,
                "prohibited_condition": [],
            }
        )
        obs = self.env.reset(environment_id=env_id, init_rooms=init_rooms)
        if not isinstance(obs, dict):
            return None, None

        self.kg = KG(self.env.get_position_graph())
        self.kg.add(obs["visible_graph"], 0, use_refinement=True)
        self.kg.add(obs["agent_graph"], 0, use_refinement=True)

        task = COMPLEX_TASKS_SET[task_id]
        obs = self._parse_obs(self.kg, task, prev_action="reset env")
        info = {
            "task_type": self._parse_task_type(task),
            "task": task,
            "kg": self.kg.clone(),
            "state": self.kg.retrieve(
                [task],
                embedding_fns=self.emb_fn,
                num_edges=self.num_topk_edge,
            ),
            # NOTE(review): "histroy" typo kept verbatim in case downstream
            # prompts or parsers match this exact text.
            "history": "No action histroy.",
        }
        return obs, info

    @staticmethod
    def _split_condition(condition):
        """Turn a ``COMPLEX_SUCCESS_CONDITIONS`` entry into ``(required, min_success)``.

        ``required`` is the list of triples to hand to the env. ``min_success``
        is ``None`` unless the entry is a ``{"task_set", "min_success"}`` dict.
        """
        if isinstance(condition, dict):
            return list(condition["task_set"]), condition["min_success"]
        if isinstance(condition, list):
            return condition, None
        return [condition], None

    def _parse_task_type(self, raw_task):
        """Return the primitive task type (e.g. ``"turn_on"``) for ``raw_task``.

        Matching is case-insensitive. Spaces and underscores are treated
        alike, because both sides are normalised with
        ``lower().replace(" ", "_")``.

        Raises:
            KeyError: if ``raw_task`` is not a known complex instruction.
        """
        key = raw_task.lower().replace(" ", "_")
        try:
            return self._NORMALIZED_TASK_TYPES[key]
        except KeyError:
            raise KeyError(f"Unknown complex task: {raw_task!r}") from None
