import os
import os.path as osp
import json

from embodied_cd.environments.base import ContinualEnvironment
from embodied_cd.environments.default import AlfredEnv


class ContinualAlfredEnv(ContinualEnvironment):
    """Continual-learning wrapper that steps through a sequence of ALFRED splits.

    Every call to :meth:`increment` builds a new ``AlfredEnv`` from the split
    file of the current stage and then advances the stage counter.
    :meth:`reset` and :meth:`step` are forwarded to that environment.
    """

    name = "continual_alfred"

    def __init__(self, cl_type: str, split: str = "train_1") -> None:
        """Initialise stage bookkeeping and load the first stage.

        Args:
            cl_type: Continual-learning scenario. It selects the
                ``<cl_type>_il`` sub-directory of the split folder;
                ``"behavior"`` uses a stage limit of 6, any other value 4.
            split: Split name. A name containing ``"valid"`` is used directly
                as the JSON file name; otherwise the text after the last
                underscore is used as the ``rand`` id of the per-stage file.
        """
        super().__init__()

        self.cl_type = cl_type
        self.split = split

        self.stage_id = 0
        self.max_stage_id = 6 if cl_type == "behavior" else 4
        # Both locations are relative paths, so they resolve against the
        # current working directory of the process.
        self.json_path = "externals/cl-alfred/embodied_split"
        self.alfred_path = "externals/alfworld"

        self.env = None
        # Load the first stage right away so ``self.env`` is usable.
        self.increment()

    def reset(self):
        """Reset the environment of the current stage and return its result."""
        return self.env.reset()

    def step(self, action):
        """Forward ``action`` to the current stage's environment."""
        return self.env.step(action)

    def _train_split_path(self, rand_id: str) -> str:
        """Return the per-stage split file path for the current ``stage_id``."""
        return (
            f"{self.json_path}/{self.cl_type}_il/"
            f"embodied_data_disjoint_rand{rand_id}_cls1_task{self.stage_id}.json"
        )

    def increment(self) -> None:
        """Create the environment for the current stage and advance the counter.

        Once ``stage_id`` has reached ``max_stage_id`` a message is printed and
        nothing else changes. For non-validation splits with
        ``cl_type == "behavior"``, a stage whose first sample's task field
        contains ``"movable"`` is skipped in favour of the following stage.
        """
        if self.stage_id >= self.max_stage_id:
            print("All tasks have been completed")
            return

        if "valid" in self.split:
            # The validation path does not depend on ``stage_id``.
            alfred_split = f"{self.json_path}/{self.cl_type}_il/{self.split}.json"
        else:
            rand_id = self.split.split("_")[-1]
            alfred_split = self._train_split_path(rand_id)

            # Read with an explicit codec: the default encoding of ``open``
            # depends on the platform locale, while JSON files are UTF-8.
            with open(alfred_split, encoding="utf-8") as f:
                # Presumably the task-type identifier of the stage's first
                # sample -- TODO confirm against the split file schema.
                sample = json.load(f)[0]["task"]["task"]

            if self.cl_type == "behavior" and "movable" in sample:
                print("Skipping unsupported movable task")
                # Move on to the next stage file, clamped to the stage limit.
                self.stage_id = min(self.stage_id + 1, self.max_stage_id)
                alfred_split = self._train_split_path(rand_id)

        print(f"increment: {alfred_split}")

        self.env = AlfredEnv(split=alfred_split)
        self.stage_id += 1
