#!/usr/bin/env python3



import glob
import json
import os

from dataset_generation.benchmark_generation.filter_instructions import filter_dataset
from dataset_generation.benchmark_generation.generate_episodes import (
    run_generation_over_proposals,
)
from dataset_generation.benchmark_generation.parse_generated_instructions import (
    InstructionParser,
)


def get_state_dict_fromperscenerun(folder_names, scene_name=""):
    """Collect parsed episode proposals produced by a per-scene run.

    For every folder, each direct child whose path contains "output_parsed"
    is scanned for ``*.json`` files. Each file is loaded and tagged with
    ``scene_id`` (the given scene name), a sequential ``episode_id`` of the
    form ``test|<scene_name>|<n>``, and ``file_path`` (the JSON it came from).

    Args:
        folder_names: A single folder path or a list of folder paths.
        scene_name: Scene id assigned to every proposal. Required; an empty
            string or None aborts with an empty result.

    Returns:
        List of loaded proposal dicts, or ``[]`` if ``scene_name`` is missing.
    """
    # `not scene_name` also rejects None, which is what the CLI defaults to.
    if not scene_name:
        print("scene name is required when parsing outputs from per scene run")
        return []
    if isinstance(folder_names, str):
        folder_names = [folder_names]

    episode_dict = []
    count = 0
    for folder_name in folder_names:
        # glob order is filesystem-dependent; sort so episode ids are reproducible.
        parsed_dirs = [
            path
            for path in sorted(glob.glob(f"{folder_name}/*"))
            if "output_parsed" in path
        ]
        for parsed_dir in parsed_dirs:
            for json_path in sorted(glob.glob(f"{parsed_dir}/*.json")):
                with open(json_path, "r") as f:
                    content = json.load(f)
                content["scene_id"] = scene_name
                content["episode_id"] = f"test|{scene_name}|{count}"
                content["file_path"] = json_path
                count += 1
                episode_dict.append(content)
    print("Total episode generation proposals:", count)
    return episode_dict


def get_state_dict(folder_names):
    """Collect parsed episode proposals from a multi-scene run layout.

    Each folder is expected to hold one sub-directory per scene; proposals are
    read from ``<folder>/<scene>/output_parsed/*.json``. Direct children whose
    *name* contains "yaml", "json" or "csv" (config / summary files) are
    skipped. Each loaded dict is tagged with ``scene_id`` (the scene directory
    name), a sequential ``episode_id`` (``test|<scene>|<n>``) and ``file_path``.

    Args:
        folder_names: A single folder path or a list of folder paths.

    Returns:
        List of loaded proposal dicts.
    """
    if isinstance(folder_names, str):
        folder_names = [folder_names]

    skip_markers = ("yaml", "json", "csv")
    episode_dict = []
    count = 0
    for folder_name in folder_names:
        # Test only the entry's own name: checking the full path would drop
        # every scene whenever the root folder path itself contains a marker.
        # Sorting makes the sequential episode ids reproducible.
        scenes = [
            scene
            for scene in sorted(glob.glob(f"{folder_name}/*"))
            if not any(m in os.path.basename(scene) for m in skip_markers)
        ]
        for scene in scenes:
            scene_name = os.path.basename(scene)
            for file_path in sorted(glob.glob(f"{scene}/output_parsed/*.json")):
                with open(file_path, "r") as f:
                    content = json.load(f)
                content["scene_id"] = scene_name
                content["episode_id"] = f"test|{scene_name}|{count}"
                content["file_path"] = file_path
                count += 1
                episode_dict.append(content)
    print("Total episode generation proposals:", count)
    return episode_dict


if __name__ == "__main__":
    # Pipeline: parse LLM instructions -> load/filter init state dicts ->
    # generate episodes, writing intermediate and final artifacts to rootfolder.
    # TODO: combine with end2end_pipeline script in the future to reduce code duplication
    import argparse

    def _str2bool(value):
        """Parse a CLI boolean ("true"/"false", "yes"/"no", "1"/"0", ...).

        argparse's ``type=bool`` treats any non-empty string (including
        "False") as True, so an explicit parser is required.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    rootfolder = "/"
    # Set both of these to parse outputs of a per-scene run instead of rootfolder.
    folder_names = None
    scene_name = None
    # folder_names = [f"{rootfolder}/folder_1",
    #                 f"{rootfolder}/folder_2"]

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--rootfolder",
        type=str,
        default=rootfolder,
        help="Absolute path to the folder containing instructions and scene inits generated from LLMs.",
    )
    parser.add_argument(
        "--genpercall",
        type=int,
        default=5,
        help="Number of generations made per LLM call, Must match param used during benchmark gen as defined in benchmark gen config usually.",
    )
    parser.add_argument(
        "--addclutter",
        type=_str2bool,
        nargs="?",
        const=True,  # bare `--addclutter` enables clutter
        default=False,
        help="Option to add clutter during episode generation, automatically excludes task relevant objects (true/false)",
    )

    args, _ = parser.parse_known_args()
    rootfolder = args.rootfolder
    init_state_dicts = "init_state_dicts.json"
    output_dataset = "dataset.json.gz"
    failure_log = "generator_failure_log.json"
    validated_instructions = "valid_instructions.csv"
    final_instructions = "final_instructions.csv"

    gen_config = {
        "scene_dataset": "data/fpss/hssd-hab-partnr.scene_dataset_config.json",
        "additional_object_paths": [
            "data/objects/ycb/configs/",
            # "data/objects_ovmm/train_val/ai2thorhab/configs/objects",
            "data/objects_ovmm/train_val/amazon_berkeley/configs",
            "data/objects_ovmm/train_val/google_scanned/configs",
            "data/objects_ovmm/train_val/hssd/configs/objects",
        ],
        "ep_dataset_output": os.path.join(rootfolder, output_dataset),
        "failure_log_output": os.path.join(rootfolder, failure_log),
        "enable_check_obj_stability": False,
    }

    metadata_dict = {
        "metadata_folder": "data/fpss/metadata",
        "obj_metadata": "object_categories_filtered_no_thor.csv",
        # "obj_metadata": "object_categories_filtered.csv",
        "room_objects_json": "room_objects.json",
        "staticobj_metadata": "fpmodels-with-decomposed.csv",
    }

    instr_parser = InstructionParser()
    if folder_names is not None:
        print("reading data from", folder_names)
        if not folder_names:
            print("Missing folders info!!")
        res_stats = instr_parser.parse_instruction_folders(
            folder_names,
            save_output=True,
            generate_html=True,
            run_per_scene=True,
            per_call_generation=args.genpercall,
            add_clutter=args.addclutter,
        )
    else:
        print("reading data from", rootfolder)
        # Second return value (html) is not used by this script.
        res_stats, _html = instr_parser.parse_instructions(
            rootfolder,
            save_output=True,
            generate_html=True,
            run_per_scene=False,
            per_call_generation=args.genpercall,
            add_clutter=args.addclutter,
        )
    print("\nParsed instructions to remove hallucinations, Overview:", res_stats)

    print("\nLoading and filtering init state dicts from parsed output..")
    if folder_names is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if scene_name is None:
            raise ValueError("scene_name must be set when folder_names is provided")
        all_dicts = get_state_dict_fromperscenerun(folder_names, scene_name)
    else:
        all_dicts = get_state_dict(args.rootfolder)

    filtered_dict, validated_instructions_df = filter_dataset(all_dicts)
    with open(os.path.join(rootfolder, init_state_dicts), "w") as f:
        json.dump(filtered_dict, f)
    validated_instructions_df.to_csv(
        os.path.join(rootfolder, validated_instructions), index=False
    )

    print("\nGenerating episodes for total", len(filtered_dict), "episodes")
    final_generation_df, gen_eps_count = run_generation_over_proposals(
        gen_config, metadata_dict, filtered_dict, validated_instructions_df
    )
    final_generation_df.to_csv(
        os.path.join(rootfolder, final_instructions), index=False
    )
    print("\nGenerated", gen_eps_count, "out of total", len(filtered_dict), "episodes")
