#!/usr/bin/env python3



import copy
import glob
import html
import json
import os
import random
import shutil
from json.decoder import JSONDecodeError

# Alias table used while validating parsed episodes: keys are names as they may
# show up in the generated episodes, values are the names they are replaced by
# before being looked up in the scene information.
# Initial-state object types are lower-cased, stripped and have spaces replaced
# by underscores before the lookup; final-state furniture names are looked up
# as-is.
conversion_dict = {
    "plant": "plant_container",
    "watering_can": "pitcher",
    "reading_lamp": "lamp",
    "phone": "cellphone",
    "ball": "basketball",
    "towel": "hand_towel",
    "washcloth": "hand_towel",
    "laundry": "bath_towel",
    "laundry_basket": "basket",
    "hamper": "basket",
    "toy_car": "toy_vehicle",
    "stuffed_animal": "stuffed_toy",
    "dish": "plate",
    "shoes": "shoe",
    "cutlery": "fork",
    "dishes": "plate",
    "mug": "cup",
    "cups": "cup",
    "bat": "baseballbat",
    "dirty_dishes": "plate",
    "clean_dishes": "plate",
    "dirty_laundry": "bath_towel",
}


def html_sanitize(content_str):
    """Escape HTML special characters and render newlines as ``<br>`` tags."""
    return html.escape(content_str).replace("\n", "<br>")


class InstructionParser:
    def parse_instruction_folders(
        self,
        folder_names,
        save_output=False,
        generate_html=False,
        run_per_scene=False,
        per_call_generation=5,
        add_clutter=False,
    ):
        dict_total = {
            "total_theoretically_possible": 0,
            "total_obtained": 0,
            "good_parsing": 0,
        }
        all_html = []
        for folder_name in folder_names:
            dict_res, other_info = self.parse_instructions(
                folder_name,
                save_output,
                generate_html,
                run_per_scene,
                per_call_generation,
                add_clutter,
            )
            all_html += other_info["html_output"]
            for elem in dict_res:
                dict_total[elem] += dict_res[elem]
        if generate_html:
            html_head = """
                <!DOCTYPE html>
                    <html>
                    <head>
                        <style>
                            .code-block {
                                background-color: #f0f0f0;
                                border: 1px solid #ddd;
                                padding: 10px;
                                font-family: monospace;
                                white-space: pre;
                                overflow-x: auto;
                            }
                            .bad-code-block {
                                background-color: #fde0e0;
                                border: 1px solid #ddd;
                                padding: 10px;
                                font-family: monospace;
                                white-space: pre;
                                overflow-x: auto;
                            }
                            .good-code-block {
                                background-color: #C1E1C1;
                                border: 1px solid #ddd;
                                padding: 10px;
                                font-family: monospace;
                                white-space: pre;
                                overflow-x: auto;
                            }
                        </style>
                    </head>
                    <input type="checkbox" id="toggleButton1" onclick="toggleCodeBlock('toggleButton1', 'ccode-block')" checked>Show Bad Objects</input><br>
                    <input type="checkbox" id="toggleButton2" onclick="toggleCodeBlock('toggleButton2', 'cbad-code-block')" checked>Show Bad JSON</input> Show hallucinated<br>
                    <input type="checkbox" id="toggleButton3" onclick="toggleCodeBlock('toggleButton3', 'cgood-code-block')" checked>Show good</input>

                    <body>
            """
            script = """
                <script>
                function toggleCodeBlock(button_str, class_name) {
                    var checkBox = document.getElementById(button_str);
                    var codeBlocks = document.getElementsByClassName(class_name);
                    if (checkBox.checked == true){
                        for (codeBlock of codeBlocks){
                            codeBlock.style.display = "block";
                        }
                    } else {
                        for (codeBlock of codeBlocks){
                            codeBlock.style.display = "none";
                        }
                    }
                }
            </script>
            """
            html_str = html_head + "<br>".join(all_html) + script + "</body>"
            with open("result.html", "w+") as f:
                f.write(html_str)
            return dict_total
        else:
            return dict_total

    def parse_instructions(
        self,
        folder_name,
        save_output=False,
        generate_html=False,
        run_per_scene=False,
        per_call_generation=5,
        add_clutter=False,
    ):
        """
        Read a folder of episodes generated by the LLM, parses them and makes sure that
        objects exist
        """
        # TODO: this should be generated somewhere else
        html_out = []
        good_parsing = 0
        total = 0
        missing_objects = 0
        missing_furniture = 0
        all_missing = []

        if not run_per_scene:
            scenes = [
                scene
                for scene in glob.glob(f"{folder_name}/*")
                if "yaml" not in scene and "json" not in scene and "csv" not in scene
            ]
        else:
            scenes = [folder_name]
        for scene_path in scenes:
            parsed_folder = f"{scene_path}/output_parsed"
            gen_folder = f"{scene_path}/output_gen"
            file_for_eval = f"{scene_path}/scene_info.json"
            out_folder = "/".join(folder_name.split("/")[0:-1])
            # out_folder = f"{folder_name}/"
            scene_info = self.load_scene_info(file_for_eval)
            shutil.copy2(file_for_eval, out_folder)

            if save_output and not os.path.isdir(parsed_folder):
                os.makedirs(parsed_folder)
            files = glob.glob(f"{gen_folder}/*")
            total_theo = len(files) * per_call_generation

            for file_path in files:
                with open(file_path, "r") as f:
                    content = f.read()
                fi = content.find("[")
                ei = content.rfind("]")

                content = content[fi : ei + 1]
                content_parsed = self.parse_to_json(content)

                html_file = ""
                if len(content_parsed) == 0 and generate_html:
                    # html_file += "<p> Bad JSON Parsing </p><br>"
                    html_file += (
                        "<div class='cbad-code-block'>"
                        + f"<p>{file_path}</p><br>"
                        + "<div class='bad-code-block'>"
                    )
                    html_file += html_sanitize(content)
                    html_file += "</div></div>"

                    html_out.append(html_file)

                for ind, episode_init in enumerate(content_parsed):
                    total += 1
                    dest_file = file_path.replace("output_gen", "output_parsed")
                    dest_file = dest_file.replace(".json", f"_{ind}.json")

                    (
                        is_valid,
                        content_parsed,
                        missing_objects,
                        missing_furniture,
                        missing_room,
                    ) = self.episode_init_valid(episode_init, scene_info, add_clutter)

                    print(
                        "is_valid:",
                        is_valid,
                        "missing_objects:",
                        missing_objects,
                        "missing_furniture:",
                        missing_furniture,
                        "missing_room:",
                        missing_room,
                    )
                    # breakpoint()

                    html_out.append(html_file)
                    if is_valid:
                        html_file += (
                            "<div class='ccode-block'>"
                            + f"<p>{file_path}</p><br>"
                            + "<div class='code-block'>"
                        )

                        good_parsing += 1

                        if save_output:
                            with open(dest_file, "w+") as f:
                                f.write(json.dumps(content_parsed, indent=4))
                    else:
                        html_file += (
                            "<div class='cgood-code-block'>"
                            + f"<p>{file_path}</p><br>"
                            + "<div class='good-code-block'>"
                        )

                        all_missing += missing_objects

                    html_file += html_sanitize(json.dumps(content_parsed, indent=4))
                    html_file += "</div></div>"
        return {
            "total_theoretically_possible": total_theo,
            "total_obtained": total,
            "good_parsing": good_parsing,
        }, {"html_output": html_out}

    def episode_init_valid(self, init_episode, scene_info, add_clutter=False):
        """
        Check if the episode initialization is valid
        """
        # scene_info = self.load_scene_info(scene_path)
        missing_objects = []
        missing_furniture = []
        missing_room = []
        missing_spatial_anchor = []
        parsed_init_state = []
        parsed_fin_state = []
        task_relevant_objects = []
        necessary_fields = ["object_type", "furniture_name", "region", "how_many"]

        # check if initializations exist
        if "initial state" in init_episode:
            init_state_key = "initial state"
        elif "inital state" in init_episode:
            # NOTE: catch misspelled "initial" also
            init_state_key = "inital state"
        else:
            print("initial or final state not in the input!!")
            return False, [], [], [], []

        if not all(x in init_episode for x in [init_state_key, "final state"]):
            print("initial or final state not in the input!!")
            return False, [], [], [], []

        # check hallucinations in initial state
        for init_obj in init_episode[init_state_key]:
            # ensure all init fields are present
            if not all(x in init_obj for x in necessary_fields):
                continue

            # Convert object
            init_obj["object_type"] = (
                init_obj["object_type"].lower().strip().replace(" ", "_")
            )

            if init_obj["object_type"] in conversion_dict:
                init_obj["object_type"] = conversion_dict[init_obj["object_type"]]

            if init_obj["object_type"] not in scene_info["objects"]:
                missing_objects.append(init_obj["object_type"])
            else:
                task_relevant_objects.append(init_obj["object_type"])

            if (
                init_obj["furniture_name"] not in scene_info["all_furniture"]
                and init_obj["furniture_name"] != "floor"
            ):
                missing_furniture.append(init_obj["furniture_name"])

            if (
                init_obj["region"] not in scene_info["furniture"].keys()
                and init_obj["region"] not in scene_info["all_rooms"]
            ):
                missing_room.append(init_obj["region"])

            parsed_init_state.append(
                {
                    "number": init_obj["how_many"],
                    "object_classes": [init_obj["object_type"]],
                    "furniture_names": [init_obj["furniture_name"]],
                    "allowed_regions": [init_obj["region"]],
                }
            )

        if add_clutter:
            clutter_num = random.randint(0, 5)
            clutter_num = str(clutter_num)
            ##use the task_relevant_objects list above to control clutter gen
            parsed_init_state.append(
                {
                    "name": "common sense",
                    "excluded_object_classes": task_relevant_objects,
                    "exclude_existing_objects": True,
                    "number": clutter_num,
                    "common_sense_object_classes": True,  # this specifies region->object metadata is used for sampling
                    "location": "on",
                    "furniture_names": [],
                },
            )
        new_init = copy.deepcopy(init_episode)
        del new_init[init_state_key]
        new_init["initial_state"] = parsed_init_state

        # check hallucinations in final state
        for fin_obj in init_episode["final state"]:
            # ensure all init fields are present
            if not all(x in fin_obj for x in necessary_fields):
                break

            if fin_obj["furniture_name"] in conversion_dict:
                fin_obj["furniture_name"] = conversion_dict[fin_obj["furniture_name"]]

            if (
                fin_obj["furniture_name"] not in scene_info["all_furniture"]
                and fin_obj["furniture_name"] != "floor"
            ):
                missing_furniture.append(fin_obj["furniture_name"])

            if (
                fin_obj["region"] not in scene_info["furniture"].keys()
                and fin_obj["region"] not in scene_info["all_rooms"]
            ):
                missing_room.append(fin_obj["region"])
            if "spatial_anchor" in fin_obj:
                for anchor in fin_obj["spatial_anchor"]:
                    if (
                        anchor not in scene_info["furniture"].keys()
                        and anchor not in scene_info["objects"]
                    ):
                        missing_spatial_anchor.append(fin_obj["spatial_anchor"])

            parsed_fin_state.append(
                {
                    "number": fin_obj["how_many"],
                    "object_classes": [fin_obj["object_type"]],
                    "furniture_names": [fin_obj["furniture_name"]],
                    "allowed_regions": [fin_obj["region"]],
                }
            )
        del new_init["final state"]
        new_init["final_state"] = parsed_fin_state

        if (
            len(missing_objects) > 0
            or len(missing_furniture) > 0
            or len(missing_spatial_anchor) > 0
        ):
            return False, new_init, missing_objects, missing_furniture, missing_room
        else:
            return True, new_init, [], [], []

    def parse_to_json(self, content_parsed):
        """
        Parse a JSON string, repairing a few known defects if necessary.

        The stripped text is first parsed as-is, so valid JSON is never
        altered. If that fails, the following repairs are applied one after
        another (each on top of the previous ones), retrying after each step:

        1. collapse ``,,`` into ``,``
        2. replace ``\\n[]"reason`` with ``,\\n"reason``
        3. replace ``\\n]\\n]`` with ``\\n]`` and collapse ``,,`` again, since
           the second repair can introduce a doubled comma

        Args:
            content_parsed: the JSON text to parse.

        Returns:
            The decoded JSON value, or an empty dict when every attempt
            fails (the last decoding error is printed in that case).
        """
        text = content_parsed.strip()
        repairs = (
            lambda s: s.replace(",,", ","),
            lambda s: s.replace('\n[]"reason', ',\n"reason'),
            lambda s: s.replace("\n]\n]", "\n]").replace(",,", ","),
        )

        try:
            return json.loads(text)
        except JSONDecodeError as err:
            last_error = err

        for repair in repairs:
            text = repair(text)
            try:
                return json.loads(text)
            except JSONDecodeError as err:
                last_error = err

        print(last_error)
        return {}

    def load_scene_info(self, sceneinfo_file):
        """
        Loads scene information
        """
        # sceneinfo_file = f"{scenepath}/scene_info.json"
        with open(sceneinfo_file, "r") as f:
            scene_info = json.load(f)
        scene_info["all_furniture"] = []
        scene_info["all_rooms"] = []
        for room, furniture_room in scene_info["furniture"].items():
            scene_info["all_furniture"] += furniture_room
            if room not in scene_info["all_rooms"]:
                scene_info["all_rooms"].append(room)
        scene_info["all_furniture"] = list(set(scene_info["all_furniture"]))
        return scene_info


if __name__ == "__main__":
    ## Add this if you would like to run this script as a standalone!
    # root_path and the folder name below look like placeholders; point them
    # at the folder(s) that hold the generated scenes before running.
    root_path = "/"
    folder_names = [f"{root_path}/folder_name"]
    instr_parser = InstructionParser()
    # The parser builds and returns the aggregated counters itself, so no
    # separate totals dict is prepared here.
    res = instr_parser.parse_instruction_folders(
        folder_names,
        save_output=True,
        generate_html=True,
        run_per_scene=False,
        per_call_generation=5,
    )
    print(res)
