from pathlib import Path
from demos.ipc.src.household.household_environment_state import HouseholdEnvironmentState
from demos.ipc.src.logistics.logistics_environment_state import LogisticsEnvironmentState
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from pddl.core import Constant
import json
from typing import Generic, Tuple, TypeVar, Union

# Root of the vendored LLMs-World-Models-for-Planning checkout, located under
# ``3rdparty/`` four ``.parent`` steps up from this file's path.
llm_wm_dir = Path(__file__).parent.parent.parent.parent.joinpath(
    "3rdparty", "LLMs-World-Models-for-Planning"
)

# Type variable for the environment-state element returned by a sample generator.
T = TypeVar("T")


class BenchmarkSampleGenerator(Generic[T]):
    """Interface for an indexable source of benchmark samples.

    Concrete subclasses provide ``__len__`` and ``generate``; the versions
    defined here only raise ``NotImplementedError``. ``T`` is the type of the
    second element of the triple returned by ``generate``.
    """

    def __len__(self) -> int:
        """Return how many samples the generator can produce (subclass hook)."""
        raise NotImplementedError

    def generate(self, idx: int) -> Tuple[str, T, PDDLProblem]:
        """Return the ``(str, T, PDDLProblem)`` triple for sample ``idx`` (subclass hook)."""
        raise NotImplementedError


class PlanningBenchmarkSampleGenerator(
    BenchmarkSampleGenerator[Union[HouseholdEnvironmentState, LogisticsEnvironmentState]]
):
    """Generate planning samples from the LLMs-World-Models-for-Planning task files.

    Each sample holds three things: a natural-language instruction, the
    environment state parsed from the instruction's first test case and first
    scene, and a PDDL problem skeleton with one ``Constant`` per object in that
    scene.

    Args:
        task: Which task file to load, either ``"logistics"`` or ``"household"``.
        use_hl_types: Household only. If True, object types come from the
            ``"family"`` field and are mapped to high-level PDDL types.
            Otherwise the raw ``"name"`` field is used as the type. Logistics
            always uses the ``"family"`` field and ignores this flag.

    Raises:
        ValueError: If ``task`` is not a supported task name.
    """

    # Task name -> JSON file under ``experiments/planning_tasks``.
    _TASK_FILES = {
        "logistics": "logistics_planning_tasks.json",
        "household": "household_planning_tasks.json",
    }

    # Household object family -> high-level PDDL type (used when use_hl_types=True).
    _HOUSEHOLD_HL_TYPES = {
        "furniture_appliance": "furniture_appliance",
        "small_items": "household_object",
        "small_receptacle": "household_object",
        "agent": "agent",
    }

    def __init__(self, task: str, use_hl_types: bool = True):
        super().__init__()
        # Fail fast. Otherwise an unknown task would leave ``planning_tasks``
        # unset and surface later as a confusing AttributeError.
        if task not in self._TASK_FILES:
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {sorted(self._TASK_FILES)}."
            )
        self.task = task
        self.use_hl_types = use_hl_types
        resources_dir = llm_wm_dir / "experiments/planning_tasks"
        task_file = resources_dir / self._TASK_FILES[task]
        # Decode explicitly as UTF-8 so loading does not depend on the platform locale.
        self.planning_tasks = json.loads(task_file.read_text(encoding="utf-8"))
        # Instructions in file order; ``generate(idx)`` indexes into this list.
        self.instructions = list(self.planning_tasks.keys())

    def __len__(self):
        """Return the number of instructions loaded from the task file."""
        return len(self.instructions)

    def _type_tag(self, obj_data: dict) -> str:
        """Return the PDDL type tag for one object entry of a scene's ``"objects"`` dict.

        Raises:
            KeyError: If a household family has no high-level mapping
                (only when ``use_hl_types`` is True).
        """
        obj_type = obj_data["type"]
        if self.task == "logistics":
            # Logistics always uses the family as its type, regardless of use_hl_types.
            return obj_type["family"]
        if not self.use_hl_types:
            return obj_type["name"]
        family = obj_type["family"]
        try:
            return self._HOUSEHOLD_HL_TYPES[family]
        except KeyError:
            raise KeyError(
                f"No high-level type mapping for household object family {family!r}"
            ) from None

    def generate(
        self, idx: int
    ) -> Tuple[str, Union[HouseholdEnvironmentState, LogisticsEnvironmentState], PDDLProblem]:
        """Build the sample for the ``idx``-th instruction.

        Args:
            idx: Index into ``self.instructions``.

        Returns:
            ``(instruction, env_state, problem_skeleton)``:
            - ``env_state`` is parsed from the chosen scene's ``"objects"`` entry.
            - ``problem_skeleton`` is a ``PDDLProblem`` with one ``Constant`` per
              object in that scene.

        Raises:
            IndexError: If ``idx`` is out of range, or if the instruction has
                no test cases or its first case has no scenes.
            KeyError: If a household object family has no high-level type
                mapping (only when ``use_hl_types`` is True).
        """
        instruction = self.instructions[idx]
        task_info = self.planning_tasks[instruction]

        # One instruction may have several test cases, in which object states
        # vary. Each case may have several scenes, in which states of
        # task-irrelevant objects vary. Selection is deterministic: the first
        # case and its first scene are always used.
        first_case = list(task_info)[0]
        first_scene = list(task_info[first_case])[0]
        object_info = task_info[first_case][first_scene]["objects"]

        pddl_objects = [
            Constant(name=obj_name, type_tag=self._type_tag(obj_data))
            for obj_name, obj_data in object_info.items()
        ]

        if self.task == "logistics":
            env_state = LogisticsEnvironmentState.parse_environment_state(object_info)
        else:
            env_state = HouseholdEnvironmentState.parse_environment_state(object_info)

        problem_skeleton = PDDLProblem(objects=pddl_objects)

        return instruction, env_state, problem_skeleton
