#!/usr/bin/env python3



import ast
import os
import os.path as osp
import warnings

import habitat_sim
import omegaconf
from habitat.sims.habitat_simulator.sim_utilities import object_shortname_from_handle
from hydra.utils import instantiate

from partnr.sims.metadata_interface import MetadataInterface

# Silence every Python warning for the rest of the process. The imports that
# appear above this line have already executed, so warnings raised while
# importing them are not affected; later imports (hydra, pandas, ...) are.
warnings.filterwarnings("ignore")
import json

import hydra
import pandas as pd
from habitat.datasets.rearrange.run_episode_generator import get_config_defaults
from habitat_sim.nav import NavMeshSettings
from omegaconf import OmegaConf

from partnr.utils.sim import find_receptacles


# Generate LLM task instructions for each configured scene; subclasses implement generate_instructions_on_scene.
class InstructionGeneratorHSSD:
    """Base class for generating task instructions for scenes with an LLM.

    The constructor's keyword arguments form ``self.config``.
    ``generate_instructions`` loads each configured scene in habitat-sim,
    collects scene information (object categories, furniture per region,
    receptacle handles, region ids) and writes the output of
    ``generate_instructions_on_scene`` -- which subclasses must provide --
    under ``config.output_path``.
    """

    def __init__(self, **kwargs):
        """Store the configuration and build the LLM.

        :param kwargs: generator configuration, wrapped with
            ``OmegaConf.create``. Its ``llm`` node must have an ``llm`` entry
            that is called with the node itself to build the model.
        """
        self.config = OmegaConf.create(kwargs)
        # The simulator is created lazily by ``initialize_sim``.
        self.sim = None
        # Habitat rearrange defaults: gpu id and navmesh agent parameters.
        self.cfg = get_config_defaults()

        # Per-scene state, filled in by initialize_fresh_scene / obtain_scene_info.
        self.receptacles = None
        self.scene_info = None
        self.mi = None

        llm_config = self.config.llm
        self.llm = llm_config.llm(llm_config)

    @staticmethod
    def _ordered_unique(values):
        """Return the distinct items of ``values`` in first-seen order.

        Used instead of ``list(set(...))`` so the generated prompt does not
        change from run to run due to string-hash randomization.
        """
        return list(dict.fromkeys(values))

    def make_output_path(self):
        """Return the root directory that all generated files are written under."""
        return self.config.output_path

    def initialize_sim(self, scene_name: str, dataset_path: str) -> None:
        """
        Initialize a new Simulator object with a selected scene and dataset.

        The first call creates a ``habitat_sim.Simulator`` and loads the object
        configs listed in ``config.additional_object_paths``; later calls reuse
        the existing simulator through ``reconfigure``.

        :param scene_name: id of the scene to load.
        :param dataset_path: scene dataset config file.
        """
        # A single pinhole RGB camera coincident with the agent body node.
        # For debugging visualizations place the default agent where you want
        # the camera, with local -Z oriented toward the point of focus.
        camera_resolution = [540, 720]
        sensors = {
            "rgb": {
                "sensor_type": habitat_sim.SensorType.COLOR,
                "resolution": camera_resolution,
                "position": [0.0, 0.0, 0.0],
                "orientation": [0.0, 0.0, 0.0],
            }
        }

        backend_cfg = habitat_sim.SimulatorConfiguration()
        backend_cfg.scene_dataset_config_file = dataset_path
        backend_cfg.scene_id = scene_name
        backend_cfg.enable_physics = True
        backend_cfg.gpu_device_id = self.cfg.gpu_device_id

        sensor_specs = []
        for sensor_uuid, sensor_params in sensors.items():
            sensor_spec = habitat_sim.CameraSensorSpec()
            sensor_spec.uuid = sensor_uuid
            sensor_spec.sensor_type = sensor_params["sensor_type"]
            sensor_spec.resolution = sensor_params["resolution"]
            sensor_spec.position = sensor_params["position"]
            sensor_spec.orientation = sensor_params["orientation"]
            sensor_spec.sensor_subtype = habitat_sim.SensorSubType.PINHOLE
            sensor_specs.append(sensor_spec)

        agent_cfg = habitat_sim.agent.AgentConfiguration()
        agent_cfg.sensor_specifications = sensor_specs

        hab_cfg = habitat_sim.Configuration(backend_cfg, [agent_cfg])
        if self.sim is None:
            self.sim = habitat_sim.Simulator(hab_cfg)

            # NOTE(review): these templates are only loaded when the simulator
            # is first created; confirm they survive close(destroy=True) below.
            object_attr_mgr = self.sim.get_object_template_manager()
            for object_path in self.config.additional_object_paths:
                object_attr_mgr.load_configs(osp.abspath(object_path))
        else:
            if self.sim.config.sim_cfg.scene_id != scene_name:
                self.sim.close(destroy=True)
            if self.sim.config.sim_cfg.scene_id == scene_name:
                # Reconfiguring with an identical scene would not reset it, so
                # first load the NONE scene to force a reset.
                # TODO: we should fix this to provide an appropriate reset method
                proxy_backend_cfg = habitat_sim.SimulatorConfiguration()
                proxy_backend_cfg.scene_id = "NONE"
                proxy_backend_cfg.gpu_device_id = self.cfg.gpu_device_id
                proxy_hab_cfg = habitat_sim.Configuration(
                    proxy_backend_cfg, [agent_cfg]
                )
                self.sim.reconfigure(proxy_hab_cfg)
            self.sim.reconfigure(hab_cfg)

        # Put the agent (and so the debug camera) at the center of the scene
        # bounding box.
        scene_bb = self.sim.get_active_scene_graph().get_root_node().cumulative_bb
        self.sim.agents[0].scene_node.translation = scene_bb.center()

    def initialize_fresh_scene(self, scene_id: str):
        """
        Load ``scene_id`` from ``config.scene_dataset``, record its receptacles
        in ``self.receptacles`` and recompute the navmesh.
        """
        self.initialize_sim(scene_id, self.config.scene_dataset)

        self.receptacles = find_receptacles(self.sim)

        # Generate the navmesh from the habitat default agent parameters.
        navmesh_settings = NavMeshSettings()
        navmesh_settings.set_defaults()
        navmesh_settings.agent_radius = self.cfg.agent_radius
        navmesh_settings.agent_height = self.cfg.agent_height
        navmesh_settings.include_static_objects = True
        navmesh_settings.agent_max_climb = self.cfg.agent_max_climb
        navmesh_settings.agent_max_slope = self.cfg.agent_max_slope
        self.sim.recompute_navmesh(
            self.sim.pathfinder,
            navmesh_settings,
        )

    def generate_instructions(self):
        """Generate instructions for the configured scenes and write them to disk.

        ``config.scene_id_torun == -1`` processes every entry of
        ``config.scene_ids``; any other value selects one scene by index.
        For each scene and each of the ``config.calls_per_scene`` calls this
        writes ``input_prompt.json`` and ``output_gen/gen_<call>.json``. It
        also writes ``scene_info.json`` per scene and ``config.yaml`` in the
        root output directory, each only if the file does not exist yet.
        """
        output_path = self.make_output_path()
        if self.config.scene_id_torun == -1:
            all_scene_ids = self.config.scene_ids
        else:
            all_scene_ids = [self.config.scene_ids[self.config.scene_id_torun]]
        # One config file for the whole run, stored in the root output dir.
        config_path = f"{output_path}/config.yaml"

        for scene_id in all_scene_ids:
            # run_per_scene writes directly into output_path; otherwise each
            # scene gets its own sub-directory.
            if self.config.run_per_scene:
                output_path_scene = output_path
            else:
                output_path_scene = osp.join(output_path, scene_id)
            scene_info_path = f"{output_path_scene}/scene_info.json"
            output_path_scene_gen = f"{output_path_scene}/output_gen/"

            self.initialize_fresh_scene(scene_id)
            self.scene_info = self.obtain_scene_info()

            print(f"Parsing Scene: {scene_id}")
            for iter_call in range(self.config.calls_per_scene):
                os.makedirs(output_path_scene, exist_ok=True)

                if not os.path.exists(scene_info_path):
                    with open(scene_info_path, "w+") as f:
                        f.write(json.dumps(self.scene_info, indent=4))

                if not os.path.exists(config_path):
                    with open(config_path, "w+") as f:
                        OmegaConf.save(self.config, f)

                (
                    current_instructions,
                    output_dict,
                ) = self.generate_instructions_on_scene()

                # Overwritten on every call: holds the prompt(s) of the last call.
                with open(f"{output_path_scene}/input_prompt.json", "w+") as f:
                    f.write(json.dumps(output_dict))

                os.makedirs(output_path_scene_gen, exist_ok=True)

                full_file = f"{output_path_scene_gen}/gen_{iter_call}.json"
                with open(full_file, "w+") as f:
                    # Raw LLM text is written as-is; parsed results as JSON.
                    if isinstance(current_instructions, str):
                        f.write(current_instructions)
                    else:
                        f.write(json.dumps(current_instructions))

    def validate_scene(self):
        """Print the scene's region furniture, articulated-object categories and lexicon.

        Diagnostic only; requires ``self.mi`` (set by ``obtain_scene_info``).
        """
        region_furn = self.mi.get_region_rec_contents(self.sim)
        print("All regions and their furniture")
        print(region_furn)

        print("\nAll articulated objs")
        aom = self.sim.get_articulated_object_manager()
        all_ao_info = [
            self.mi.get_object_category(object_shortname_from_handle(obj.handle))
            for obj in aom.get_objects_by_handle_substring().values()
        ]
        print(all_ao_info)

        print("\n Scene lexicon")
        scene_lexicon = self.mi.get_scene_lexicon(self.sim)
        print(scene_lexicon)

    def obtain_scene_info(self):
        """Build the scene description used to fill the LLM prompts.

        Side effects: sets ``self.mi`` to a fresh ``MetadataInterface`` with
        caches for the current scene and prints diagnostics via
        ``validate_scene``.

        :return: dict with keys ``objects`` (unique ``clean_category`` values
            of the object metadata CSV, in first-seen order), ``furniture``,
            ``receptacle_to_handle`` and ``room_to_id``.
        """
        # Relative paths to the metadata files.
        metadata_dict = {
            "metadata_folder": "data/fpss/metadata/",
            "obj_metadata": "object_categories_filtered_no_thor.csv",
            "room_objects_json": "room_objects.json",
            "staticobj_metadata": "fpmodels-with-decomposed.csv",
        }

        self.mi = MetadataInterface(metadata_dict)
        self.mi.refresh_scene_caches(self.sim)
        self.validate_scene()

        ovmm_metadata = pd.read_csv(
            os.path.join(
                metadata_dict["metadata_folder"], metadata_dict["obj_metadata"]
            )
        )
        object_names = self._ordered_unique(ovmm_metadata["clean_category"])

        # Region semantic name -> semantic-scene region id.
        room_to_id = {
            room_name: self.sim.semantic_scene.regions[region_index].id
            for room_name, region_index in self.mi.region_semname_to_id.items()
        }

        return {
            "objects": object_names,
            "furniture": self.mi.get_region_rec_contents(self.sim),
            "receptacle_to_handle": self.mi.recobj_semname_to_handle,
            "room_to_id": room_to_id,
        }


class AllAtOnceGenerator(InstructionGeneratorHSSD):
    """Generate task instructions, and optionally their states, with an LLM.

    One LLM call on the ``config.prompt_file_task`` prompt produces the task
    instructions. If ``config.prompt_file_init`` names a prompt file, one
    further LLM call per instruction asks for its "initial state" and
    "final state".
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.init_prompt_template()
        # Passed as ``max_length`` to every ``self.llm.generate`` call.
        self.output_length = self.config.output_length

    def init_prompt_template(self):
        """Read the task prompt and, if configured, the init-state prompt.

        ``self.prompt_text_init`` is ``None`` when ``config.prompt_file_init``
        is None or at most one character long.
        """
        prompt_file_task = self.config.prompt_file_task
        prompt_file_init = self.config.prompt_file_init
        with open(prompt_file_task, "r") as f:
            self.prompt_text_task = f.read()

        # A null path (e.g. `null` in the YAML) must disable the init prompt
        # instead of crashing in len().
        if prompt_file_init is not None and len(prompt_file_init) > 1:
            with open(prompt_file_init, "r") as f:
                self.prompt_text_init = f.read()
        else:
            self.prompt_text_init = None

    def generate_instructions_on_scene(self):
        """Query the LLM for instructions on the current ``self.scene_info``.

        :return: ``(instructions, prompts)``. Without an init prompt,
            ``instructions`` is the raw LLM answer (a string) and ``prompts``
            is ``{"formatted_prompt": ...}``. With an init prompt,
            ``instructions`` is a list of dicts with keys "initial state",
            "final state" and "instruction". ``prompts`` then holds the task
            prompt and the last init prompt, which is ``None`` when no
            instruction line was found.
        """
        import time

        k = self.config.generations_per_call
        # One line per room that has furniture: "<room>: <f1>, <f2>, ...".
        receptacles_str = "".join(
            f"{room}: " + ", ".join(furniture) + "\n"
            for room, furniture in self.scene_info["furniture"].items()
            if len(furniture) > 0
        )
        objs_str = "\n".join(self.scene_info["objects"])
        formatted_prompt = self.prompt_text_task.format(
            house_furniture=receptacles_str, objects_list=objs_str, k=k
        )

        t1 = time.time()
        instructions = self.llm.generate(
            formatted_prompt, max_length=self.output_length
        )
        print(time.time() - t1)

        if self.prompt_text_init is None:
            return instructions, {"formatted_prompt": formatted_prompt}

        # Keep the text between the first "[" and the next "]" (or the end of
        # the answer if there is no closing bracket); one instruction per line.
        start = instructions.find("[") + 1
        end = instructions.find("]", start)
        if end == -1:
            end = len(instructions)
        instruction_lines = instructions[start:end].split("\n")

        all_init_instructions = []
        # Stays None if the answer contains no instruction lines.
        formatted_prompt_init = None
        for inst in instruction_lines:
            if not inst.strip():
                continue
            formatted_prompt_init = self.prompt_text_init.format(
                house_furniture=receptacles_str,
                objects_list=objs_str,
                task_instruction=inst,
            )
            inits = self.llm.generate(
                formatted_prompt_init, max_length=self.output_length
            )
            # If the answer ends with "]", keep only what is between its first
            # "[" and last "]".
            if inits.endswith("]"):
                inits = inits[inits.find("[") + 1 : inits.rfind("]")]

            # Strip the list-separator comma; the last list item has none.
            instruction = inst.rstrip()
            if instruction.endswith(","):
                instruction = instruction[:-1]

            try:
                # literal_eval only accepts Python literals, so the untrusted
                # LLM text is never executed.
                inits_dict = ast.literal_eval(inits.lstrip())
                init_instruction = {
                    "initial state": inits_dict["initial state"],
                    "final state": inits_dict["final state"],
                    "instruction": instruction,
                }
            except Exception as e:
                # LLM output is free-form; skip unparsable answers but say why.
                print(f"error: could not parse init state for {inst!r}: {e!r}")
                continue
            all_init_instructions.append(init_instruction)

        return all_init_instructions, {
            "formatted_prompt_task": formatted_prompt,
            "formatted_prompt_init": formatted_prompt_init,
        }


@hydra.main(
    version_base=None,
    config_path="../conf/",
    config_name="benchmark_gen_llama2.yaml",
)
def main(cfg: omegaconf.DictConfig):
    """Build the configured instruction generator and run it.

    ``cfg`` is composed by Hydra from ``../conf/benchmark_gen_llama2.yaml``;
    its ``generator`` node is passed to ``hydra.utils.instantiate``.
    """
    generator_cfg = OmegaConf.create(cfg).generator
    generator = instantiate(generator_cfg)

    # Echo every configured scene id before generation starts.
    for sid in generator.config.scene_ids:
        print("Scene id:", sid)

    generator.generate_instructions()


if __name__ == "__main__":
    # The hydra.main decorator supplies ``cfg``, so call with no arguments.
    main()
