#!/usr/bin/env python3



import csv
import gzip
import json
import os
import shutil
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union

import habitat
import hydra
import omegaconf
import tqdm
from hydra.utils import instantiate
from omegaconf import OmegaConf

from dataset_generation.benchmark_generation.verify_dataset import (
    verify_dataset_parallel,
)
from dataset_generation.eval_gen_parsing import (
    DependencyParser,
    InstructionParser,
    LLMGenerationError,
    PropositionParser,
    SkipParser,
    TemporalParser,
    TerminalSatisfactionParser,
    TieParser,
    metadata_to_state_string,
    proposition_to_llm_output_str,
    temporal_words_in_str,
)
from partnr.agent.env.dataset import CollaborationDatasetV0, CollaborationEpisode
from partnr.agent.env.evaluation.evaluation_functions import (
    DifferentArgConstraint,
    EvaluationConstraint,
    EvaluationProposition,
    EvaluationPropositionDependency,
    SameArgConstraint,
    TemporalConstraint,
)


class LLMEvaluationGenerator:
    """
    Generates evaluation data for the episodes of a collaboration dataset.

    Workflow:
      1. `generate_plain_text_evaluations` prompts an LLM for propositions, a
         temporal constraint, and tie constraints per episode and writes them to
         plain-text files meant for manual verification/correction, alongside
         per-episode logs and metadata.
      2. `plaintext_evals_to_dataset` parses those plain-text files (using freshly
         regenerated metadata) and packs the result into a `CollaborationDatasetV0`
         written to `dataset_file_out`.
    """

    def __init__(
        self,
        dataset_file_in: str,
        dataset_file_out: str,
        plain_text_eval_dir: str,
        metadata_dir: str,
        log_dir: str,
        scene_info_file: str,
        filter_out_clutter: bool,
        recep_desc_in_prompt: bool,
        metadata_csv: str,
        affordances_csv: str,
        proposition_prompt_file: str,
        dag_prompt_file: str,
        tie_prompt_file: str,
        predicate_vocabulary_file: str,
        llm: Any,
        max_tokens_proposition_call: int,
        max_tokens_dag_call: int,
        max_tokens_tie_call: int,
        skip_temporal_prediction: bool,
        skip_tie_prediction: bool,
        skip_episode_default: bool,
        is_from_templates: bool,
        prompt_template_strs: Optional[Dict[int, str]] = None,
        plain_text_eval_dir_copy: str = "",
    ) -> None:
        """
        Args:
            dataset_file_in: gzipped JSON dataset whose "episodes" are processed.
            dataset_file_out: output path of the packed dataset (must end in ".json.gz").
            plain_text_eval_dir: directory receiving `episode_{eid}.py` plain-text evals.
            metadata_dir: directory receiving `episode_{eid}.json` metadata.
            log_dir: directory receiving `episode_{eid}.log` generation logs.
            scene_info_file: scene info JSON read when building metadata mappings.
            filter_out_clutter: if False, clutter objects are added to the metadata.
            recep_desc_in_prompt: stored on the instance; not read by this class.
            metadata_csv: CSV with a header row; column 0 is a handle template,
                column 1 its text description.
            affordances_csv: CSV mapping affordance keys to object/furniture classes.
            proposition_prompt_file: prompt template for the proposition call.
            dag_prompt_file: prompt template for the temporal (DAG) call.
            tie_prompt_file: prompt template for the tie call.
            predicate_vocabulary_file: JSON read for "predicates",
                "object_state_negations" and "affordance_to_predicates".
            llm: object exposing `generate(prompt=..., stop=..., max_length=...)`.
            max_tokens_proposition_call: `max_length` of the proposition call.
            max_tokens_dag_call: `max_length` of the temporal call.
            max_tokens_tie_call: `max_length` of the tie call.
            skip_temporal_prediction: if True, never call the LLM for temporal order.
            skip_tie_prediction: if True, never call the LLM for ties.
            skip_episode_default: forwarded to `SkipParser.to_plaintext`.
            is_from_templates: if True, inject a per-template example into the
                proposition prompt.
            prompt_template_strs: template task number -> prompt example; required
                when `is_from_templates` is True.
            plain_text_eval_dir_copy: optional directory receiving a second copy of
                each plain-text eval ("" disables the copy).

        Raises:
            IOError: `dataset_file_in` does not exist.
            AssertionError: `is_from_templates` is set without `prompt_template_strs`.
        """
        if not os.path.exists(dataset_file_in):
            raise IOError(f"Dataset file `{dataset_file_in}` does not exist.")
        # validate the config before creating any directories or reading files
        if is_from_templates and prompt_template_strs is None:
            raise AssertionError(
                "`prompt_template_strs` must be provided if `is_from_templates==True`."
            )

        os.makedirs(plain_text_eval_dir, exist_ok=True)
        os.makedirs(metadata_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
        if plain_text_eval_dir_copy != "":
            os.makedirs(plain_text_eval_dir_copy, exist_ok=True)

        with gzip.open(dataset_file_in, "rt") as f:
            self.dataset_in = json.load(f)

        self.plain_text_eval_dir = plain_text_eval_dir
        self.plain_text_eval_dir_copy = plain_text_eval_dir_copy
        self.metadata_dir = metadata_dir
        self.log_dir = log_dir
        self.dataset_file_out = dataset_file_out
        self.scene_info_file = scene_info_file
        self.filter_out_clutter = filter_out_clutter
        self.recep_desc_in_prompt = recep_desc_in_prompt
        self.metadata_csv = metadata_csv
        with open(predicate_vocabulary_file) as f:
            self.predicate_vocabulary: Dict[str, Any] = json.load(f)
        # must run after predicate_vocabulary is loaded (it reads the key map)
        self.affordances = self._load_affordances_dict(affordances_csv)

        with open(proposition_prompt_file, "r") as f:
            self.proposition_prompt_template = f.read()

        with open(dag_prompt_file, "r") as f:
            self.dag_prompt_template = f.read()

        with open(tie_prompt_file, "r") as f:
            self.tie_prompt_template = f.read()

        self.llm = llm
        self.max_tokens_proposition_call = max_tokens_proposition_call
        self.max_tokens_dag_call = max_tokens_dag_call
        self.max_tokens_tie_call = max_tokens_tie_call
        self.skip_temporal_prediction = skip_temporal_prediction
        self.skip_tie_prediction = skip_tie_prediction
        self.skip_episode_default = skip_episode_default
        self.is_from_templates = is_from_templates
        self.prompt_template_strs = prompt_template_strs

        predicates: Dict[str, List[Dict[str, Any]]] = self.predicate_vocabulary[
            "predicates"
        ]
        self.prop_parser = PropositionParser(predicates, self.affordances)
        self.dag_parser = TemporalParser(predicates)
        self.tie_parser = TieParser(predicates)
        self.dep_parser = DependencyParser(predicates)

    @staticmethod
    def extract_template_task_number(episode: Dict[str, Any]) -> int:
        """
        The template task number is an ID linking an episode to the index of the
        template episode used as basis during LLM generation. It has been stored in
        a couple inconsistent places: directly in `extra_info`, or in the last
        element of `extra_info["initial_state"]`.

        Raises:
            ValueError: if neither location contains `template_task_number`.
        """
        extra_info = episode["info"]["extra_info"]
        if "template_task_number" in extra_info:
            return int(extra_info["template_task_number"])
        # an empty/missing initial_state simply means "not found here"
        initial_state = extra_info.get("initial_state") or []
        if initial_state and "template_task_number" in initial_state[-1]:
            return int(initial_state[-1]["template_task_number"])
        raise ValueError("Cannot find `template_task_number` in episode object.")

    def generate_plain_text_evaluations(self) -> None:
        """Produces evaluation functions for episodes in self.dataset_in. The entities in
        evaluation functions are referenced via semantic names (table_1) instead of sim
        handles. Metadata is saved that maps from semantic names to handles.

        Episodes whose plain-text file already exists are skipped, so an
        interrupted run can be resumed without overwriting earlier results.

        For each newly processed episode, saves:
          - plaintext_evals/episode_{i}.py       * plaintext evaluation data
          - plaintext_evals_orig/episode_{i}.py  * copy (only if a copy dir is set)
          - logs/episode_{i}.log                 * prompts, raw+parsed outputs, failures
          - metadata/episode_{i}.json            * all relevant episode+scene metadata
        """
        for episode in tqdm.tqdm(self.dataset_in["episodes"]):
            eid = self._eid_from_episode(episode)

            eval_function_file = os.path.join(
                self.plain_text_eval_dir, f"episode_{eid}.py"
            )
            # check before any other work so finished episodes cost nothing
            if os.path.exists(eval_function_file):
                continue

            prop_prompt_template_ex = ""
            if self.is_from_templates:
                ttn = self.extract_template_task_number(episode)
                prop_prompt_template_ex = self.prompt_template_strs[ttn]

            metadata = self._generate_metadata_mappings(episode)
            plaintext_eval_str = self.generate_plaintext_eval(
                metadata, eid, prop_prompt_template_ex
            )

            # save results to files
            with open(eval_function_file, "w") as f:
                f.write(plaintext_eval_str)
            with open(os.path.join(self.metadata_dir, f"episode_{eid}.json"), "w") as f:
                json.dump(metadata, f, indent=2)

            # the copy is optional; its absence must not end the episode loop
            if self.plain_text_eval_dir_copy:
                eval_function_file_copy = os.path.join(
                    self.plain_text_eval_dir_copy, f"episode_{eid}.py"
                )
                with open(eval_function_file_copy, "w") as f:
                    f.write(plaintext_eval_str)

    def compile_plaintext(
        self,
        instruction: str,
        propositions: List[EvaluationProposition],
        tc_constraint: Optional[TemporalConstraint],
        tie_constraints: List[Union[SameArgConstraint, DifferentArgConstraint]],
        dependencies: List[EvaluationPropositionDependency],
    ) -> str:
        """
        Produces a string for the plain text file containing all generated evaluation data.
        This string is later saved for manual verification and correction.

        `tc_constraint` is None when proposition or temporal generation failed.
        """
        return (
            "# type: ignore\n"
            + InstructionParser.to_plaintext(instruction)
            + self.prop_parser.to_plaintext(propositions)
            + self.dag_parser.to_plaintext(tc_constraint)
            + self.tie_parser.to_plaintext(tie_constraints, propositions)
            + self.dep_parser.to_plaintext(dependencies)
            + TerminalSatisfactionParser.to_plaintext()
            + SkipParser.to_plaintext(self.skip_episode_default)
        )

    def generate_plaintext_eval(
        self, metadata: Dict, eid: int, prop_prompt_template_ex: str = ""
    ) -> str:
        """
        Generates a plaintext evaluation for a single episode consisting of a list
        of propositions, a temporal constraint, and a list of tie constraints. Each
        is generated from a separate LLM call. If an LLMGenerationError occurs, the
        failure is logged and the plaintext is produced with the components
        generated so far; later components are left empty.

        prop_prompt_template_ex, if provided, will inject a final example into the
        proposition generation prompt. See the default prompt for the required format.
        """
        state_str = metadata_to_state_string(
            metadata, self.predicate_vocabulary["object_state_negations"]
        )
        inst = metadata["instruction"]

        # generate propositions
        try:
            propositions = self.generate_propositions(
                eid, inst, state_str, metadata, prop_prompt_template_ex
            )
        except LLMGenerationError as e:
            self._log_results(eid, f"Failure in Proposition Generation. Reason:\n{e}")
            return self.compile_plaintext(inst, [], None, [], [])

        # generate the temporal constraint
        try:
            tc_constraint = self.generate_temporal_constraint(eid, inst, propositions)
        except LLMGenerationError as e:
            self._log_results(eid, f"Failure in DAG Generation. Reason:\n{e}")
            return self.compile_plaintext(inst, propositions, None, [], [])

        # generate the tied constraints
        try:
            tie_constraints = self.generate_ties(eid, inst, propositions)
        except LLMGenerationError as e:
            self._log_results(eid, f"Failure in Tie Generation. Reason:\n{e}")
            return self.compile_plaintext(inst, propositions, tc_constraint, [], [])

        self._log_results(eid, "Success.")
        return self.compile_plaintext(
            inst, propositions, tc_constraint, tie_constraints, []
        )

    def generate_propositions(
        self,
        eid: int,
        instruction: str,
        state_str: str,
        metadata: Dict,
        prop_prompt_template_ex: str = "",
    ) -> List[EvaluationProposition]:
        """
        Prompts an LLM to produce evaluation propositions.
        Parses the output into List[EvaluationProposition].

        The prompt template's {INSTRUCTION}, {INIT_STATE} and {TEMPLATE_EXAMPLE}
        placeholders are filled in; generation stops at "[END]".
        """
        prompt = self.proposition_prompt_template
        prompt = prompt.replace("{INSTRUCTION}", instruction)
        prompt = prompt.replace("{INIT_STATE}", state_str)
        template_key = "{TEMPLATE_EXAMPLE}"
        if prop_prompt_template_ex == "":
            # trim whitespace to match the other few-shot examples
            template_key += "\n\n"
        prompt = prompt.replace(template_key, prop_prompt_template_ex)

        self._log_results(eid, f"[prop call] LLM Prompt: \n{prompt}")

        propositions_str = self.llm.generate(
            prompt=prompt,
            stop="[END]",
            max_length=self.max_tokens_proposition_call,
        )
        self._log_results(eid, f"[prop call] Raw LLM Output: \n{propositions_str}")

        propositions = self.prop_parser.from_llm(propositions_str, metadata)
        self._log_results(
            eid,
            f"[prop call] Parsed LLM Output: \n{self.prop_parser.to_plaintext(propositions)}",
        )
        return propositions

    def generate_temporal_constraint(
        self,
        eid: int,
        instruction: str,
        propositions: List[EvaluationProposition],
    ) -> TemporalConstraint:
        """
        Infers the temporal order of a list of propositions using an LLM.
        Parses this result into DAG proposition groups.

        Returns an empty constraint without calling the LLM when temporal
        prediction is disabled or the instruction contains no temporal words.
        """
        empty_constraint = TemporalConstraint([], len(propositions))
        if self.skip_temporal_prediction:
            self._log_results(eid, "[dag call] Skipping.")
            return empty_constraint
        if not temporal_words_in_str(instruction):
            self._log_results(
                eid, "[dag call] Skipping: temporal words not in the instruction."
            )
            return empty_constraint

        dag_prompt = self.dag_prompt_template
        dag_prompt = dag_prompt.replace("{INSTRUCTION}", instruction)
        dag_prompt = dag_prompt.replace(
            "{PROPOSITIONS}", self.prop_parser.to_plaintext(propositions)
        )
        self._log_results(eid, f"[dag call] LLM Prompt: \n{dag_prompt}")

        dag_str = self.llm.generate(
            prompt=dag_prompt, stop="\n\n", max_length=self.max_tokens_dag_call
        )
        self._log_results(eid, f"[dag call] Raw LLM Output: \n{dag_str}")

        tc_constraint = self.dag_parser.from_llm(dag_str, n_props=len(propositions))
        self._log_results(
            eid,
            f"[dag call] Parsed LLM Output: \n{self.dag_parser.to_plaintext(tc_constraint)}",
        )
        return tc_constraint

    def generate_ties(
        self,
        eid: int,
        instruction: str,
        propositions: List[EvaluationProposition],
    ) -> List[Union[SameArgConstraint, DifferentArgConstraint]]:
        """
        Prompts an LLM for argument-tie constraints between propositions.

        Parsing errors are logged and treated as "no ties" because ties are not
        critical. The parsed ties are then passed through `filter_generated_ties`.
        Returns [] without calling the LLM when tie prediction is disabled.
        """
        if self.skip_tie_prediction:
            self._log_results(eid, "[tie call] Skipping.")
            return []

        tie_prompt = self.tie_prompt_template
        tie_prompt = tie_prompt.replace("{INSTRUCTION}", instruction)
        tie_prompt = tie_prompt.replace(
            "{PROPOSITIONS}", self.prop_parser.to_plaintext(propositions)
        )
        self._log_results(eid, f"[tie call] LLM Prompt: \n{tie_prompt}")

        # use the tie-specific budget (previously the DAG budget was used here)
        tie_str = self.llm.generate(
            prompt=tie_prompt, stop="\n\n", max_length=self.max_tokens_tie_call
        )
        self._log_results(eid, f"[tie call] Raw LLM Output: \n{tie_str}")

        try:
            constraints = self.tie_parser.from_llm(tie_str, propositions)
        except Exception as e:
            # don't crash, this call isn't critical; the error is logged
            self._log_results(eid, f"[tie call] Unexpected error: {str(e)}")
            constraints = []

        self._log_results(
            eid,
            f"[tie call] Parsed LLM Output: \n{self.tie_parser.to_plaintext(constraints, propositions)}",
        )

        filtered_constraints = self.filter_generated_ties(
            eid, constraints, propositions
        )
        n_before = len(constraints)
        n_after = len(filtered_constraints)
        if n_before != n_after:
            self._log_results(
                eid,
                f"[tie call] Filtered ties from {n_before} to {n_after} ties.",
            )
        return filtered_constraints

    def filter_generated_ties(
        self,
        eid: int,
        ties: List[Union[SameArgConstraint, DifferentArgConstraint]],
        propositions: List[EvaluationProposition],
    ) -> List[Union[SameArgConstraint, DifferentArgConstraint]]:
        """
        Keep ties that apply to 2 or more propositions
        and have more than one satisfying value.

        A tie is kept as soon as one of its (proposition index, arg name) pairs
        refers to a list-valued argument with more than one entry. Pairs that
        reference a missing proposition or argument are logged and ignored.
        """
        valid_ties = []
        for tie in ties:
            if len(tie.proposition_indices) < 2:
                continue

            for prop_idx, arg_name in zip(
                tie.proposition_indices, tie.args["arg_names"]
            ):
                try:
                    prop = propositions[prop_idx]
                    matched_arg = prop.args[arg_name]
                except (IndexError, KeyError) as e:
                    self._log_results(
                        eid, f"[tie call] Improper indices generated. Error: {str(e)}"
                    )
                    continue

                if not isinstance(matched_arg, list):
                    continue
                # if a matched arg has more than one possible value, keep this tie.
                if len(matched_arg) > 1:
                    valid_ties.append(tie)
                    break

        return valid_ties

    def parse_plaintext_eval(
        self, plaintext_str: str, metadata: Dict
    ) -> Tuple[
        str,
        List[EvaluationProposition],
        List[EvaluationConstraint],
        List[EvaluationPropositionDependency],
    ]:
        """
        Takes a string of the plain text eval for an episode and converts it to
        propositions and constraints.

        Returns:
            (instruction, propositions, constraints, dependencies), where
            constraints are [temporal, *ties, terminal-satisfaction].

        Raises:
            LLMGenerationError: if the annotator marked the episode as skipped.
        """

        # check if the annotator marked skip
        skipped, reason = SkipParser.from_plaintext(plaintext_str)
        if skipped:
            raise LLMGenerationError("Episode skipped. " + reason)

        instruction = InstructionParser.from_plaintext(plaintext_str)
        propositions = self.prop_parser.from_plaintext(plaintext_str, metadata)
        tc_constraint = self.dag_parser.from_plaintext(plaintext_str, len(propositions))
        tie_constraints = self.tie_parser.from_plaintext(plaintext_str, propositions)
        ts_constraint = TerminalSatisfactionParser.from_plaintext(
            plaintext_str, propositions
        )
        all_constraints: List[EvaluationConstraint] = [
            tc_constraint,
            *tie_constraints,
            ts_constraint,
        ]
        dependencies = self.dep_parser.from_plaintext(plaintext_str)
        return instruction, propositions, all_constraints, dependencies

    def plaintext_evals_to_dataset(
        self, eids_to_skip: Optional[Set[int]] = None
    ) -> None:
        """
        Converts semantic names in evaluation functions to handles, resulting in
        evaluation functions that can be loaded into partnr.

        Only `*.py` files in `plain_text_eval_dir` are considered. Episodes that
        fail to pack are printed and summarized (reason -> list of eids) in
        `packing_failures.json` next to `dataset_file_out`.

        Args:
            eids_to_skip: episode ids to leave out of the packed dataset.
        """
        if eids_to_skip is None:
            eids_to_skip = set()

        def eid_from_fname(fname: str) -> int:
            return int(fname.split(".")[0].split("_")[-1])

        # regenerate metadata in case handles changed.
        new_metadata = {
            self._eid_from_episode(ep): self._generate_metadata_mappings(ep)
            for ep in self.dataset_in["episodes"]
        }

        new_data: Dict[int, Dict[str, Any]] = {}  # eid -> propositions, constraints
        failure_summary: DefaultDict[str, List[int]] = defaultdict(list)
        # only the generated `episode_{eid}.py` files; stray files (backups, OS
        # metadata) would otherwise break the eid parsing
        fnames = [
            fname for fname in os.listdir(self.plain_text_eval_dir)
            if fname.endswith(".py")
        ]
        for fname in sorted(fnames, key=eid_from_fname):
            eid = eid_from_fname(fname)
            if eid in eids_to_skip:
                continue

            metadata = new_metadata.get(eid)
            if metadata is None:
                print(f"eid {eid} no longer in the orig dataset?")
                failure_summary["missing eid"].append(eid)
                continue

            with open(os.path.join(self.plain_text_eval_dir, fname)) as f:
                eval_fn_plain_text_str = f.read()
            try:
                (
                    instruction,
                    propositions,
                    constraints,
                    dependencies,
                ) = self.parse_plaintext_eval(eval_fn_plain_text_str, metadata)
            except LLMGenerationError as e:
                print(f"Failed to pack EID {eid} of file {fname}.\n\tReason: {str(e)}")
                failure_summary[str(e)].append(eid)
                continue
            except KeyError as e:
                # a lookup failure inside the parsers; record it as such rather
                # than mislabeling the episode as missing from the dataset
                print(f"Failed to pack EID {eid} of file {fname}.\n\tReason: KeyError {e}")
                failure_summary[f"KeyError: {e}"].append(eid)
                continue

            new_data[eid] = {
                "instruction": instruction,
                "propositions": propositions,
                "constraints": constraints,
                "dependencies": dependencies,
            }

        # log packing failures for future analysis and summary stats
        failure_file = os.path.join(
            os.path.dirname(self.dataset_file_out), "packing_failures.json"
        )
        with open(failure_file, "w") as f:
            json.dump(failure_summary, f, indent=2)

        self.save_new_dataset(new_data)

    def save_new_dataset(self, new_data: Dict[int, Any]) -> None:
        """
        Compile a CollaborationDataset and save it to disk.

        Episodes of `self.dataset_in` without an entry in `new_data` are dropped.

        Raises:
            AssertionError: if `dataset_file_out` does not end with ".json.gz".
        """
        # validate the output path before doing any work
        if not self.dataset_file_out.endswith(".json.gz"):
            raise AssertionError(
                f"Dataset file out should end with .json.gz. Found: `{self.dataset_file_out}`"
            )

        new_dataset = CollaborationDatasetV0()

        new_dataset.episodes = []
        for ep in self.dataset_in["episodes"]:
            eid = self._eid_from_episode(ep)

            # skip if we don't have evaluation data for this episode
            if eid not in new_data:
                continue

            instruction: str = new_data[eid]["instruction"]

            episode = CollaborationEpisode(  # type: ignore
                episode_id=ep["info"]["extra_info"]["episode_id"],
                scene_id=ep["scene_id"],
                scene_dataset_config=ep["scene_dataset_config"],
                additional_obj_config_paths=ep["additional_obj_config_paths"],
                start_position=ep["start_position"],
                start_rotation=ep["start_rotation"],
                ao_states=ep["ao_states"],
                rigid_objs=ep["rigid_objs"],
                targets=ep["targets"],
                markers=ep["markers"],
                name_to_receptacle=ep["name_to_receptacle"],
                instruction=instruction,
                info={
                    "object_labels": ep["info"]["object_labels"],
                    "initial_state": ep["info"]["extra_info"]["initial_state"],
                },
                evaluation_propositions=new_data[eid]["propositions"],
                evaluation_proposition_dependencies=new_data[eid]["dependencies"],
                evaluation_constraints=new_data[eid]["constraints"],
                object_states=ep["object_states"],
            )
            new_dataset.episodes.append(episode)

        with gzip.open(self.dataset_file_out, "wt") as f:
            s = new_dataset.to_json()
            f.write(s)
        print(
            f"packed {len(new_dataset.episodes)} episodes"
            f" out of {len(self.dataset_in['episodes'])}."
        )

    def _log_results(self, eid: int, msg: str) -> None:
        """Appends the message, followed by a blank line, to `episode_{eid}.log`."""
        with open(os.path.join(self.log_dir, f"episode_{eid}.log"), "a") as f:
            f.write(msg)
            f.write("\n\n")

    @staticmethod
    def _eid_from_episode(episode: Dict[str, Any]) -> int:
        """
        Extracts the episode ID integer from a raw episode dict.

        `extra_info["episode_id"]` may be an int, a digit string, or a string
        like "prefix|123.ext", in which case the number after the last "|" and
        before the first "." is used.
        """
        eid = episode["info"]["extra_info"]["episode_id"]
        if isinstance(eid, int):
            return eid
        if eid.isdigit():
            return int(eid)
        return int(eid.split("|")[-1].split(".")[0])

    @staticmethod
    def map_to_handle(name: str, entity_type: str, metadata: Dict[str, Any]) -> str:
        """
        Maps the name of an entity to its sim handle, guided by entity_type.

        entity_type is one of "object", "receptacle" or "room" (any other value
        raises KeyError). Objects are also accepted where a receptacle is asked for.

        Raises:
            LLMGenerationError: if `name` is not a known entity of that type.
        """
        map_key = {
            "object": "object_to_handle",
            "receptacle": "recep_to_handle",
            "room": "room_to_id",
        }[entity_type]

        valid_entities = metadata[map_key]

        # allow objects to act as receptacles
        if entity_type == "receptacle":
            valid_entities = {**metadata["object_to_handle"], **valid_entities}

        if name not in valid_entities:
            raise LLMGenerationError(f"`{name}` is not a valid {entity_type}.")
        return valid_entities[name]

    @staticmethod
    def _generate_hash_to_text(
        metadata_csv: str,
        entity_name_to_handle: Dict,
    ) -> Dict:
        """
        Get a mapping from entity name to the text description of its handle.

        The CSV's first row is skipped as a header; column 0 holds the handle
        template (the part of a handle before "_:") and column 1 the description.
        Entities whose template is not in the CSV map to "".
        """
        description_map: DefaultDict[str, str] = defaultdict(str)
        with open(metadata_csv, "r", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # header row; tolerate an empty file
            for row in reader:
                if not row:  # blank line
                    continue
                description_map[row[0]] = row[1]
        return {
            name: description_map[handle.split("_:")[0]]
            for name, handle in entity_name_to_handle.items()
        }

    def _generate_metadata_mappings(self, episode: Any) -> Dict[str, Any]:
        """
        Derives essential metadata mappings using the contents of a scene info file
        and the episode data.

        Reads `self.scene_info_file` (its "room_to_id", "furniture" and
        "receptacle_to_handle" entries) on every call and combines it with the
        episode's object instance info and object states. Mappings are sorted by
        (name prefix, numeric suffix) so the saved metadata is easy to scan.
        """

        def sort_k_single(entity_name: str):
            """
            Takes an entity name and returns a key that affords
            secondary sorting on the post index if it exists.
            """
            idx_str = entity_name.split("_")[-1]
            try:
                idx = int(idx_str)
                entity_name = "_".join(entity_name.split("_")[:-1])
            except ValueError:
                idx = 0
            return (entity_name, idx)

        def sorted_dict(d, key):
            return dict(sorted(d.items(), key=key))

        with open(self.scene_info_file, "r") as f:
            metadata = json.load(f)

        rooms = list(metadata["room_to_id"].keys())
        recep_to_room = {}
        for _room, _receptacles in metadata["furniture"].items():
            for receptacle in _receptacles:
                recep_to_room[receptacle] = _room

        (
            objects,
            object_to_handle,
            object_to_room,
            object_to_recep,
        ) = self._object_instance_info_from_episode(episode)

        receptacle_to_handle = metadata["receptacle_to_handle"]
        room_to_id = metadata["room_to_id"]

        recep_to_description = self._generate_hash_to_text(
            self.metadata_csv, receptacle_to_handle
        )

        object_to_states = self._get_semantic_object_states(
            object_to_handle, receptacle_to_handle, episode["object_states"]
        )

        # sort items so the saved metadata is easy to scan by eye
        objects = sorted(objects, key=sort_k_single)
        recep_to_description = sorted_dict(
            recep_to_description, key=lambda x: sort_k_single(x[0])
        )
        rooms = sorted(rooms, key=sort_k_single)
        object_to_recep = sorted_dict(
            object_to_recep, key=lambda x: (sort_k_single(x[1]), sort_k_single(x[0]))
        )
        object_to_room = sorted_dict(
            object_to_room, key=lambda x: (sort_k_single(x[1]), sort_k_single(x[0]))
        )
        object_to_states = sorted_dict(
            object_to_states, key=lambda x: (sort_k_single(x[0]))
        )
        receptacle_to_handle = sorted_dict(
            receptacle_to_handle, key=lambda x: (sort_k_single(x[0]))
        )
        recep_to_room = sorted_dict(
            recep_to_room, key=lambda x: (sort_k_single(x[1]), sort_k_single(x[0]))
        )
        room_to_id = sorted_dict(room_to_id, key=lambda x: (sort_k_single(x[0])))
        return {
            "objects": objects,
            "rooms": rooms,
            "object_to_recep": object_to_recep,
            "object_to_room": object_to_room,
            "recep_to_room": recep_to_room,
            "recep_to_description": recep_to_description,
            "object_to_states": object_to_states,
            "object_to_handle": object_to_handle,
            "recep_to_handle": receptacle_to_handle,
            "room_to_id": room_to_id,
            "instruction": episode["info"]["extra_info"]["instruction"],
        }

    def _object_instance_info_from_episode(
        self, episode: Dict[str, Any]
    ) -> Tuple[List[str], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """
        Assigns semantic names of the form `{category}_{index}` to episode objects.

        Task objects come from `extra_info["initial_state"]` (entries with a
        "name", a "template_task_number", or no object classes are skipped) and
        are paired positionally with the keys of `name_to_receptacle`. When
        `self.filter_out_clutter` is False, the remaining `name_to_receptacle`
        keys are added as clutter objects with room/receptacle "UNK". Indices
        are numbered per category across both groups without gaps.

        Returns:
            (objects, object_to_handle, object_to_room, object_to_recep)

        Raises:
            ValueError: if there are fewer handles than task objects.
        """
        handle_template_to_obj_cat = {
            v.split("_:")[0]: k
            for k, v in episode["info"]["extra_info"]["obj_info"].items()
        }
        object_handles = list(episode["name_to_receptacle"].keys())

        objects: List[str] = []
        object_to_room: Dict[str, str] = {}
        object_to_recep: Dict[str, str] = {}
        # next free index per category, shared by task objects and clutter
        object_cat_to_count: DefaultDict[str, int] = defaultdict(int)
        for state_element in episode["info"]["extra_info"]["initial_state"]:
            if (
                "name" in state_element
                or "template_task_number" in state_element
                or len(state_element["object_classes"]) == 0
            ):  # skip clutter and template transfer state elements
                continue

            obj_name = state_element["object_classes"][0]
            for _ in range(state_element["number"]):
                o = f"{obj_name}_{object_cat_to_count[obj_name]}"
                object_cat_to_count[obj_name] += 1
                object_to_room[o] = state_element["allowed_regions"][0]
                object_to_recep[o] = state_element["furniture_names"][0]
                objects.append(o)

        n_objs = len(objects)
        if len(object_handles) < n_objs:
            raise ValueError(
                f"Episode lists {n_objs} objects in `initial_state` but only "
                f"{len(object_handles)} handles in `name_to_receptacle`."
            )
        # NOTE: this mapping is tenuous and relies on CPython dict order
        object_to_handle: Dict[str, Any] = dict(zip(objects, object_handles))

        # add clutter objects if we aren't filtering them out.
        if not self.filter_out_clutter:
            for handle in object_handles[n_objs:]:
                handle_template = handle.split("_:")[0]
                obj_name = handle_template_to_obj_cat[handle_template]
                # take the next free index so clutter continues the numbering of
                # same-category task objects instead of skipping one
                number = object_cat_to_count[obj_name]
                object_cat_to_count[obj_name] += 1
                o = f"{obj_name}_{number}"
                objects.append(o)
                # TODO: room+receptacle info from 'name_to_receptacle'
                object_to_room[o] = "UNK"
                object_to_recep[o] = "UNK"
                # NOTE(review): the handle is rebuilt from the semantic index rather
                # than taken from `handle` itself; confirm the two always agree.
                object_to_handle[o] = handle_template + f"_:{format(number, '04d')}"

        return (
            objects,
            object_to_handle,
            object_to_room,
            object_to_recep,
        )

    def _get_semantic_object_states(
        self,
        object_to_handle: Dict[str, str],
        receptacle_to_handle: Dict[str, str],
        object_states: Dict[str, Dict[str, bool]],
    ) -> Dict[str, Dict[str, bool]]:
        """
        Maps an object state dictionary of the form:
            {"[affordance]": {"[handle]": bool, ...}, ...}
        to semantic names of the form:
            {"[semantic_name]": {"[affordance]": bool, ...}, ...}

        Handles that match neither an object nor a receptacle are dropped.
        If a handle appears in both maps, the receptacle name wins.
        """
        handle_to_obj = {
            v: k for k, v in (object_to_handle | receptacle_to_handle).items()
        }

        object_to_states: Dict[str, Dict[str, bool]] = defaultdict(dict)
        for affordance, d in object_states.items():
            for handle, value in d.items():
                if handle not in handle_to_obj:
                    continue
                object_to_states[handle_to_obj[handle]][affordance] = value
        return object_to_states

    def _load_affordances_dict(self, affordances_csv: str) -> Dict[str, Set[str]]:
        """
        Maps object state predicates to a set of object/furniture classes that
        have an associated affordance.

        Each CSV row is read as `affordance_key, <ignored>, class, class, ...`.
        Class cells are stripped of leading whitespace and of surrounding
        `'` / `['` / `']` fragments. Every predicate listed for the affordance
        key in `predicate_vocabulary["affordance_to_predicates"]` receives those
        classes. Blank rows are skipped.

        Raises:
            KeyError: if a row's affordance key is not in the vocabulary.
        """
        k_map = self.predicate_vocabulary["affordance_to_predicates"]
        affordances: Dict[str, Set[str]] = {
            p: set() for p in sorted({p for ps in k_map.values() for p in ps})
        }
        with open(affordances_csv, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:  # blank line
                    continue
                if row[0] not in k_map:
                    print(
                        f"Affordance key `{row[0]}` missing from key map {k_map.keys()}."
                    )
                    raise KeyError(row[0])
                # strip list-literal fragments left over from the comma split
                classes = [
                    aff.lstrip()
                    .removeprefix("'")
                    .removeprefix("['")
                    .removesuffix("'")
                    .removesuffix("']")
                    for aff in row[2:]
                ]
                for predicate in k_map[row[0]]:
                    affordances[predicate].update(classes)
        return affordances


def scene_dirs_from_dir(d: str, full_path: bool = True) -> List[str]:
    """
    List the scene sub-directories of `d`.

    A scene directory is an immediate sub-directory whose name starts with a
    digit. Results follow `os.listdir` order.

    Args:
        d: directory to scan.
        full_path: if True (default), return paths joined with `d`;
            otherwise return bare directory names.
    """
    found: List[str] = []
    for entry in os.listdir(d):
        candidate = os.path.join(d, entry)
        if not os.path.isdir(candidate) or not entry[0].isdigit():
            continue
        found.append(candidate if full_path else entry)
    return found


def setup_config(
    config: omegaconf.DictConfig,
) -> Tuple[omegaconf.DictConfig, List[str], Dict[int, str]]:
    """
    Extract the `eval_gen` job config and derive what this job should run.

    Returns:
        A tuple of (eval_gen config, scene ids to process, template prompt
        examples). Template prompt examples are only loaded when `generate`
        is enabled; otherwise an empty dict is returned. A `scene_index` of -1
        selects every entry of `scene_ids`; any other value selects one scene.
    """
    eval_cfg = OmegaConf.create(config).eval_gen

    prompt_template_strs = (
        load_template_prompt_examples(eval_cfg) if eval_cfg.generate else {}
    )

    # -1 means "all scenes"; otherwise pick the single indexed scene.
    selected_index = eval_cfg.scene_index
    scene_ids_to_run = (
        eval_cfg.scene_ids
        if selected_index == -1
        else [eval_cfg.scene_ids[selected_index]]
    )

    print("--- Config --- ")
    print(OmegaConf.to_yaml(eval_cfg))
    print("selected scenes:", scene_ids_to_run)
    return eval_cfg, scene_ids_to_run, prompt_template_strs


def run_single_scene(
    scene_id: str, config: omegaconf.DictConfig, prompt_template_strs: Dict[int, str]
) -> None:
    """
    Run the per-scene stages (LLM generation and/or packing) for `scene_id`.

    Inputs (`dataset.json.gz`, `scene_info.json`) are read from
    `[path_to_dataset_in]/[scene_id]/`. All outputs (the packed dataset,
    plaintext evals and their original copy, metadata, logs) go under
    `[output_path]/[run_name]/[scene_id]/`. `config.generate` and
    `config.pack` select which stages run.
    """
    scene_out = os.path.join(config.output_path, str(config.run_name), scene_id)

    def in_path(filename: str) -> str:
        return os.path.join(config.path_to_dataset_in, scene_id, filename)

    def out_path(name: str) -> str:
        return os.path.join(scene_out, name)

    generator = LLMEvaluationGenerator(
        dataset_file_in=in_path("dataset.json.gz"),
        dataset_file_out=out_path(f"{config.run_name}.json.gz"),
        plain_text_eval_dir=out_path("plaintext_evals"),
        plain_text_eval_dir_copy=out_path("plaintext_evals_orig"),
        metadata_dir=out_path("metadata"),
        log_dir=out_path("logs"),
        scene_info_file=in_path("scene_info.json"),
        filter_out_clutter=config.filter_out_clutter,
        recep_desc_in_prompt=config.recep_desc_in_prompt,
        metadata_csv=config.metadata_csv,
        affordances_csv=config.affordances_csv,
        proposition_prompt_file=config.proposition_prompt_file,
        dag_prompt_file=config.dag_prompt_file,
        tie_prompt_file=config.tie_prompt_file,
        predicate_vocabulary_file=config.predicate_vocabulary_file,
        # The configured LLM class is built via hydra, then given its own sub-config.
        llm=instantiate(config.llm.llm)(conf=config.llm),
        max_tokens_proposition_call=config.llm.max_tokens_proposition_call,
        max_tokens_dag_call=config.llm.max_tokens_dag_call,
        max_tokens_tie_call=config.llm.max_tokens_tie_call,
        skip_temporal_prediction=config.skip_temporal_prediction,
        skip_tie_prediction=config.skip_tie_prediction,
        skip_episode_default=config.skip_episode_default,
        is_from_templates=config.is_from_templates,
        prompt_template_strs=prompt_template_strs,
    )

    # Stages are independent flags; generation (if enabled) always runs before packing.
    if config.generate:
        generator.generate_plain_text_evaluations()
    if config.pack:
        generator.plaintext_evals_to_dataset()


def display_packing_stats(source_path: str, n_eps: int) -> None:
    """
    Summarizes the stats of packing failure modes as contained in each scene directory.

    Reads `[scene_dir]/packing_failures.json` (a mapping of failure mode to a
    list of failed entries) for every scene directory, writes the aggregate to
    `[source_path]/packing_summary.json`, and prints it. Scene directories
    without a `packing_failures.json` (e.g. scenes that were never packed) are
    skipped with a warning instead of aborting the summary.

    Args:
        source_path: directory containing the per-scene directories.
        n_eps: the number of episodes that were successfully packed.
    """

    def perc(n: int, m: int) -> str:
        # Guard empty runs and runs whose failure lists are all empty.
        if m == 0:
            return "0.0"
        return f"{round(100 * n / m, 2)}"

    failure_summary: DefaultDict[str, int] = defaultdict(int)
    for scene_dir in scene_dirs_from_dir(source_path):
        scene_packing_failures = os.path.join(scene_dir, "packing_failures.json")
        if not os.path.exists(scene_packing_failures):
            print(f"Warning: no packing failures file at {scene_packing_failures}")
            continue
        with open(scene_packing_failures, "r") as f:
            packing_failures = json.load(f)
        for k, v in packing_failures.items():
            failure_summary[k] += len(v)

    tot_failures = sum(failure_summary.values())
    tot_eps = n_eps + tot_failures

    with open(os.path.join(source_path, "packing_summary.json"), "w") as f:
        summary = {
            "episodes_before": tot_eps,
            "episodes_after": n_eps,
            "episodes_failed": tot_failures,
            "packed_percentage": float(perc(n_eps, tot_eps)),
            "failure_modes": failure_summary,
        }
        json.dump(summary, f, indent=2)

    print()
    print(" ------ Packing failures summary ------")
    print(f"Episodes after/before: {n_eps}/{tot_eps} ({perc(n_eps, tot_eps)}%)")
    print("failure modes:")
    for k, v in failure_summary.items():
        print(k, v, f"({perc(v, tot_failures)}%)")
    print()


def merge_datasets(source_path: str) -> None:
    """
    Merge scene-specific datasets into a single dataset and scene-specific
    metadata into a single metadata folder. Assigns new, sequential episode
    IDs and stores each episode's previous ID in `ep["info"]["episode_id"]`.

    Args:
        - source_path: path to the directory containing scene directories of
          generated and packed episodes. Its last path component is used as
          the run name.
    Produces:
        - [source_path]/[run_name].json.gz
        - [source_path]/metadata/episode_[new_episode_id].json
        - [source_path]/packing_summary.json (written by display_packing_stats)
    Raises:
        AssertionError: if no scene has a packed dataset, or a packed scene has
            no metadata directory.
    """
    # normpath drops a trailing separator so "runs/my_run/" still yields "my_run".
    run_name = os.path.basename(os.path.normpath(source_path))

    # (dataset path, metadata path template with an {eid} placeholder)
    datasets_to_merge: List[Tuple[str, str]] = []
    for scene_dir in scene_dirs_from_dir(source_path):
        dset_path = os.path.join(scene_dir, f"{run_name}.json.gz")
        metadata_path = os.path.join(scene_dir, "metadata/episode_{eid}.json")
        if not os.path.exists(dset_path):
            continue  # this scene produced no packed dataset for this run
        if not os.path.exists(os.path.dirname(metadata_path)):
            raise AssertionError(f"{dset_path} missing {metadata_path}")
        datasets_to_merge.append((dset_path, metadata_path))

    if not datasets_to_merge:
        raise AssertionError("no datasets to merge.")

    new_metadata_dir = os.path.join(source_path, "metadata")
    os.makedirs(new_metadata_dir, exist_ok=True)

    new_episodes = []
    new_eid = 0
    for dset, metadata_template in tqdm.tqdm(datasets_to_merge):
        with gzip.open(dset, "rt") as f:
            eps = json.load(f)["episodes"]
        new_episodes.extend(eps)
        for ep in eps:
            # Per-scene metadata files are keyed by the numeric tail of the old id
            # (ids of the form "...|<n>" are reduced to "<n>").
            eid = str(ep["episode_id"])
            if not eid.isdigit():
                eid = eid.split("|")[-1]
            metadata_f = metadata_template.format(eid=eid)

            ep["info"]["episode_id"] = ep["episode_id"]
            ep["episode_id"] = str(new_eid)
            shutil.copy(
                metadata_f, os.path.join(new_metadata_dir, f"episode_{new_eid}.json")
            )
            new_eid += 1

    display_packing_stats(source_path, len(new_episodes))

    with gzip.open(os.path.join(source_path, f"{run_name}.json.gz"), "wt") as f:
        s = json.dumps({"config": None, "episodes": new_episodes})
        f.write(s)


def verify_inference_in_sim(
    output_path: str,
    run_name: str,
    scene_index: Optional[int] = -1,
    scene_ids: Optional[List[str]] = None,
    verification_num_proc: Optional[int] = 1,
) -> None:
    """
    Verifies that each episode can load and evaluation doesn't crash.
    Which dataset is verified? One of:
        (1) a scene dataset (if scene_index is provided, i.e. not None/-1)
        (2) the merged dataset (if it exists)
        (3) all scene dataset files listed in scene_ids
    Results are written under [output_path]/[run_name]/verification
    (per-scene sub-directories in case (3)).

    Raises:
        ValueError: if scene_ids is required (case (1) or (3)) but is None.
    """
    run_dir = os.path.join(output_path, run_name)
    merged_dataset_path = os.path.join(run_dir, f"{run_name}.json.gz")
    results_dir = os.path.join(run_dir, "verification")

    def scene_dataset(scene: Any) -> str:
        # Scene ids may arrive from config as ints; os.path.join needs str.
        return os.path.join(run_dir, str(scene), f"{run_name}.json.gz")

    to_verify: List[Tuple[str, str]] = []  # (dataset_path, results_path)
    if scene_index is None or scene_index == -1:
        if os.path.exists(merged_dataset_path):
            to_verify = [(merged_dataset_path, results_dir)]
        elif scene_ids is None:
            raise ValueError(
                f"No merged dataset at {merged_dataset_path} and no scene_ids given."
            )
        else:
            for scene in scene_ids:
                to_verify.append(
                    (scene_dataset(scene), os.path.join(results_dir, str(scene)))
                )
    else:
        if scene_ids is None:
            raise ValueError("scene_ids is required when scene_index is provided.")
        to_verify = [(scene_dataset(scene_ids[scene_index]), results_dir)]

    for dataset_path, results_path in to_verify:
        # Each verification run initializes hydra itself; clear any global state first.
        hydra.core.global_hydra.GlobalHydra.instance().clear()
        verify_dataset_parallel(dataset_path, results_path, verification_num_proc)


def trim_failed_episodes(output_path: str, run_name: str) -> None:
    """
    Trims the merged dataset to episodes which passed simulator verification.
    Saves these episodes to a new file `[run_name]_verified.json.gz`.

    Reads `[output_path]/[run_name]/verification/summary.json`, whose keys are
    `<episode_id>.json` and whose values carry a `success_init` flag; episodes
    with a falsy flag are dropped.

    Raises:
        AssertionError: if the verification summary file does not exist.
    """
    summary_file = os.path.join(output_path, run_name, "verification", "summary.json")
    if not os.path.exists(summary_file):
        raise AssertionError(
            f"Summary file not found: {summary_file}. Run `verify` first."
        )

    with open(summary_file) as f:
        summary = json.load(f)

    dataset_file = os.path.join(output_path, run_name, f"{run_name}.json.gz")
    with gzip.open(dataset_file, "rt") as f:
        dataset = json.load(f)
    eps_before = len(dataset["episodes"])

    failed_eids = {
        int(k.removesuffix(".json"))
        for k, v in summary.items()
        if not v["success_init"]
    }
    dataset["episodes"] = [
        ep for ep in dataset["episodes"] if int(ep["episode_id"]) not in failed_eids
    ]
    eps_after = len(dataset["episodes"])

    dataset_file_out = os.path.join(
        output_path, run_name, f"{run_name}_verified.json.gz"
    )
    with gzip.open(dataset_file_out, "wt") as f:
        s = json.dumps(dataset)
        f.write(s)

    # An empty merged dataset would otherwise raise ZeroDivisionError.
    perc = round(100 * eps_after / eps_before, 2) if eps_before else 0.0
    print()
    print(f"Episodes in {run_name} (before): {eps_before}")
    print(f"Episodes in {run_name}  (after): {eps_after} ({perc}%)")
    print(f"Saved trimmed dataset to: {dataset_file_out}")
    print()


def load_template_prompt_examples(config: omegaconf.DictConfig) -> Dict[int, str]:
    """
    Produce a mapping of template task number to string of the proposition prompt example.

    Only active when `config.is_from_templates` is set; otherwise returns {}.
    For each episode in `config.template_dataset`, its metadata is read from
    `[template_dataset_dir]/[scene_id]/metadata/episode_[eid].json` and rendered,
    together with the episode's instruction and evaluation propositions, into a
    user/assistant prompt example. Keys are the episode's index in the template
    dataset, not its episode id.
    """
    if not config.is_from_templates:
        return {}

    with gzip.open(config.template_dataset, "rt") as f:
        template_episodes = json.load(f)["episodes"]
    if not template_episodes:
        return {}

    # The negation table does not depend on the episode: read it once, not per episode.
    with open(config.predicate_vocabulary_file) as f:
        object_state_negations = json.load(f)["object_state_negations"]

    prompt_example_strs = {}  # episode index (NOT eid) to prompt example string

    for i, ep in enumerate(template_episodes):
        # Metadata files are keyed by the numeric tail of ids like "...|<n>".
        eid = str(ep["info"]["episode_id"])
        if not eid.isdigit():
            eid = eid.split("|")[-1]

        metadata_f = os.path.join(
            config.template_dataset_dir,
            ep["scene_id"],
            "metadata",
            f"episode_{eid}.json",
        )
        with open(metadata_f) as f:
            metadata = json.load(f)

        metadata_str = metadata_to_state_string(metadata, object_state_negations)
        propositions_str = "".join(
            proposition_to_llm_output_str(p, metadata)
            for p in ep["evaluation_propositions"]
        ).lstrip("\n")
        ex_str = f"""<step> Source: user

The initial state is:
{metadata_str}

Instruction: "{ep["instruction"]}"

<step> Source: assistant
{propositions_str}
[END]
"""
        prompt_example_strs[i] = ex_str

    return prompt_example_strs


@hydra.main(
    version_base=None,
    config_path="../../partnr/conf/benchmark_gen",
    config_name="evaluation_gen_codellama.yaml",
)
def main(config: omegaconf.DictConfig) -> None:
    """
    Run every pipeline stage enabled in the `eval_gen` config, in order:
    per-scene generate/pack, merge of the scene datasets, then simulator
    verification followed by trimming of episodes that failed it.
    """
    eval_cfg, scene_ids_to_run, prompt_template_strs = setup_config(config)

    if eval_cfg.generate or eval_cfg.pack:
        for sid in map(str, scene_ids_to_run):
            run_single_scene(sid, eval_cfg, prompt_template_strs)

    if eval_cfg.merge:
        merge_datasets(os.path.join(eval_cfg.output_path, eval_cfg.run_name))

    if not eval_cfg.verify:
        return

    verify_inference_in_sim(
        eval_cfg.output_path,
        eval_cfg.run_name,
        eval_cfg.scene_index,
        eval_cfg.scene_ids,
        eval_cfg.verification_num_proc,
    )
    # Trimming reads the verification summary written by the step above.
    trim_failed_episodes(eval_cfg.output_path, eval_cfg.run_name)


if __name__ == "__main__":
    # Workflow:
    #
    # 1) Generate plain text evaluation functions via LLM.
    #        python -m dataset_generation.benchmark_generation.generate_evaluations
    #
    #    For batched generation, set all the desired parameters and submit
    #    the sbatch script instead:
    #        sbatch dataset_generation/benchmark_generation/run_eval_gen.sh
    #
    # 2) [Optional] Manually correct evaluation functions.
    #    Directly modify: [output_path]/[run_name]/[scene_id]/plaintext_evals/episode_[i].py
    #    Using reference: [output_path]/[run_name]/[scene_id]/metadata/episode_[i].json
    #
    # 3) Finalize the dataset: pack, merge, verify.
    #        python -m dataset_generation.benchmark_generation.generate_evaluations \
    #            eval_gen.generate=False \
    #            eval_gen.pack=True \
    #            eval_gen.merge=True \
    #            eval_gen.verify=True
    #    Recommendation: if manually correcting evaluation functions, run
    #    `eval_gen.pack=True` by itself first and use the results to fix your
    #    errors before continuing.
    main()
