import os
import os.path as osp
import json

import numpy as np
from virtualhome.simulation.environment.resources import TASKS_SET

from embodied_cd.environments.base import ContinualEnvironment
from embodied_cd.environments.default import VirtualHomeEnv


class ContinualVirtualHomeEnv(ContinualEnvironment):
    """Continual-learning wrapper around ``VirtualHomeEnv``.

    Episodes are grouped into stages. ``stage_id`` selects the current stage
    and ``increment()`` advances to the next one.

    * ``cl_type="behavior"``: each stage fixes one task type, taken from
      ``task_list[stage_id]``. The environment is sampled from ``env_list``.
    * ``cl_type="environment"``: each stage fixes one environment, taken from
      ``env_list[stage_id]``. The task type is sampled from ``task_list``.

    For training splits (``split`` containing ``"train"``), the stage order is
    a permutation drawn from an RNG seeded with the integer after the last
    ``"_"`` in ``split`` (e.g. ``"train_3"`` -> seed 3). The same split name
    therefore always yields the same stage sequence.
    """

    name = "continual_virtualhome"

    # Candidate task names (looked up in virtualhome's TASKS_SET) per task type.
    # Kept at class level so the table is not rebuilt on every sample.
    _TASK_CANDIDATES = {
        "turn_on": [
            "Turn on tv",
            "Turn on radio",
            "Turn on microwave",
            "Turn on stove",
            "Turn on computer",
        ],
        "put": [
            "Put apple on desk",
            "Put book on sofa",
            "Put mug to coffeetable",
            "Put plate on microwave",
            "Put towel on washingmachine",
        ],
        "open": [
            "Open cabinet",
            "Open dishwasher",
            "Open microwave",
            "Open stove",
        ],
        "putin": [
            "Place towel in closet",
            "Place book in bookshelf",
        ],
        "open_put": [
            "Place paper in cabinet",
            "Place mug in microwave",
            "Place plate in dishwasher",
        ],
    }

    def __init__(self, cl_type, split="train_1"):
        """Create the wrapped environment and, for train splits, the stage order.

        Args:
            cl_type: ``"behavior"`` or ``"environment"``; selects which factor
                changes from stage to stage.
            split: Split name. For names containing ``"train"``, the text after
                the last ``"_"`` must parse as an int; it seeds the RNG.

        Raises:
            ValueError: if ``split`` is a training split and ``cl_type`` is not
                ``"behavior"`` or ``"environment"``.
        """
        super().__init__()

        self.cl_type = cl_type
        self.split = split

        self.stage_id = 0
        # One stage per task type (5 listed below) or per environment (the
        # last 8 valid environment ids).
        self.max_stage_id = 5 if cl_type == "behavior" else 8

        is_train = "train" in split
        stage_seq_id = int(split.split("_")[-1]) if is_train else 0
        self.rng = np.random.default_rng(stage_seq_id)

        self.env = VirtualHomeEnv()
        # NOTE(review): for non-train splits ``task_list``/``env_list`` are not
        # assigned here, so ``reset()`` will fail unless they are set
        # elsewhere. Confirm the intended evaluation setup.
        if is_train:
            self.task_list = ["turn_on", "put", "open", "putin", "open_put"]
            self.env_list = self.env.valid_env_id[-8:]

            if cl_type == "behavior":
                # Sampling without replacement gives a random stage order.
                self.task_list = self.rng.choice(
                    self.task_list, size=self.max_stage_id, replace=False
                )
                print(
                    f"Incremental type: {cl_type} | Stage sequence: {self.task_list.tolist()}"
                )
            elif cl_type == "environment":
                self.env_list = self.rng.choice(
                    self.env_list, size=self.max_stage_id, replace=False
                )
                print(
                    f"Incremental type: {cl_type} | Stage sequence: {self.env_list.tolist()}"
                )
            else:
                raise ValueError("Invalid continual learning type")

    def reset(self, max_retries=None):
        """Start a new episode drawn from the current stage's distribution.

        A (task, environment) pair is sampled and passed to
        ``self.env.reset``. Any exception raised by that call, or a falsy
        observation, counts as a failed attempt, and a fresh pair is sampled.

        Args:
            max_retries: Maximum number of extra attempts after the first one
                fails. ``None`` (the default) retries indefinitely.

        Returns:
            The ``(obs, info)`` pair returned by ``self.env.reset``.

        Raises:
            ValueError: if ``cl_type`` is not a supported type.
            RuntimeError: if ``max_retries`` is set and every attempt fails.
                The last exception, if any, is chained as the cause.
        """
        attempt = 0
        while True:
            task_type, task_id, env_id = self._sample_episode()
            last_error = None
            try:
                obs, info = self.env.reset(task_id=task_id, env_id=env_id)
            except Exception as err:  # simulator failures are retried, not fatal
                last_error = err
                obs = None

            if obs:
                break
            if max_retries is not None and attempt >= max_retries:
                raise RuntimeError(
                    f"Failed to reset environment after {attempt + 1} attempt(s)"
                ) from last_error
            attempt += 1

        print(f"Task: {task_type} | Environment: {env_id}")
        return obs, info

    def step(self, action):
        """Forward ``action`` to the wrapped environment and return its result."""
        return self.env.step(action)

    def increment(self):
        """Advance to the next stage.

        There is no bounds check. ``reset()`` indexes ``task_list``/``env_list``
        with ``stage_id``, so callers should stop before ``stage_id`` reaches
        ``max_stage_id``.
        """
        self.stage_id += 1

    def _sample_episode(self):
        """Sample ``(task_type, task_id, env_id)`` for the current stage.

        The RNG calls happen in a fixed order so that the episode sequences
        are reproducible for a given seed.
        """
        if self.cl_type == "behavior":
            task_type = self.task_list[self.stage_id]
            task_id = self._sample_task_id(task_type)
            env_id = self._sample_env_id()
        elif self.cl_type == "environment":
            task_type = self.rng.choice(self.task_list, 1)[0]
            task_id = self._sample_task_id(task_type)
            env_id = self.env_list[self.stage_id]
        else:
            raise ValueError("Invalid continual learning type")
        return task_type, task_id, env_id

    def _sample_task_id(self, task_type):
        """Pick a random task of ``task_type`` and return its index in ``TASKS_SET``.

        Raises:
            KeyError: if ``task_type`` is not a key of ``_TASK_CANDIDATES``.
        """
        task_name = self.rng.choice(self._TASK_CANDIDATES[task_type])
        return TASKS_SET.index(task_name)

    def _sample_env_id(self):
        """Return an environment id drawn uniformly from ``env_list``."""
        return self.rng.choice(self.env_list)
