#!/usr/bin/env python3



import json
import os

import hydra
import omegaconf
from hydra.utils import instantiate
from omegaconf import OmegaConf

from dataset_generation.benchmark_generation.filter_instructions import filter_dataset
from dataset_generation.benchmark_generation.generate_episodes import (
    run_generation_over_proposals,
)
from dataset_generation.benchmark_generation.llmgen2episodes_pipeline import (
    get_state_dict,
    get_state_dict_fromperscenerun,
)
from dataset_generation.benchmark_generation.parse_generated_instructions import (
    InstructionParser,
)


@hydra.main(
    version_base=None,
    config_path="../conf/",
    config_name="benchmark_gen_llama2.yaml",
)
def main(cfg: omegaconf.DictConfig) -> None:
    """Run the benchmark generation pipeline end to end.

    Steps, in order:

    1. Instantiate ``cfg.generator`` and call ``generate_instructions()`` with
       its ``output_path`` redirected to ``<output folder>/<repeat_run_id>``.
    2. Parse the generated instructions with ``InstructionParser``.
    3. Load the parsed init-state dicts, filter them with ``filter_dataset``
       and save the result as ``init_state_dicts.json`` and
       ``valid_instructions.csv`` in the output folder.
    4. If anything survived filtering, call ``run_generation_over_proposals``
       and save the returned frame as ``final_instructions.csv``.

    The output folder is ``cfg.generator.output_path``; when
    ``cfg.generator.run_per_scene`` is set, it is the sub-folder of that path
    named after ``scene_ids[scene_id_torun]``. In per-scene mode the state
    dicts of run ids ``0..repeat_run_id`` are all read back for that scene.

    Args:
        cfg: Hydra config. Only ``cfg.generator`` is read here; a copy is
            modified, the object passed in is left untouched.
    """
    ## book-keeping
    init_state_dicts = "init_state_dicts.json"
    output_dataset = "dataset.json.gz"
    failure_log = "generator_failure_log.json"
    validated_instructions = "valid_instructions.csv"
    final_instructions = "final_instructions.csv"

    ## configs
    # Work on a copy so the output path can be rewritten below.
    inst_gen_config = OmegaConf.create(cfg)
    rootfolder = inst_gen_config.generator.output_path
    run_id = inst_gen_config.generator.repeat_run_id
    run_per_scene = inst_gen_config.generator.run_per_scene

    # Resolve the folder that receives every artifact of this invocation.
    save_out_folder = rootfolder
    scene_name = None
    if run_per_scene:
        scene_id = inst_gen_config.generator.scene_ids[
            inst_gen_config.generator.scene_id_torun
        ]
        print("inst_gen_config.generator.scene_ids", scene_id)
        # Take the scene name from the config value instead of parsing it back
        # out of the joined path, which only works when the separator is "/".
        scene_name = str(scene_id)
        save_out_folder = os.path.join(rootfolder, scene_name)

    gen_config = {
        "scene_dataset": "data/fpss/hssd-hab-partnr.scene_dataset_config.json",
        "additional_object_paths": [
            "data/objects/ycb/configs/",
            # "data/objects_ovmm/train_val/ai2thorhab/configs/objects",
            "data/objects_ovmm/train_val/amazon_berkeley/configs",
            "data/objects_ovmm/train_val/google_scanned/configs",
            "data/objects_ovmm/train_val/hssd/configs/objects",
        ],
        "ep_dataset_output": os.path.join(save_out_folder, output_dataset),
        "failure_log_output": os.path.join(save_out_folder, failure_log),
        "enable_check_obj_stability": False,
    }

    metadata_dict = {
        "metadata_folder": "data/fpss/metadata",
        "obj_metadata": "object_categories_filtered_no_thor.csv",
        # "obj_metadata": "object_categories_filtered.csv",
        "room_objects_json": "room_objects.json",
        "staticobj_metadata": "fpmodels-with-decomposed.csv",
    }

    ## generate instructions and object initialization proposals
    # Each repeat run writes into its own numbered sub-folder.
    inst_gen_config.generator.output_path = os.path.join(
        save_out_folder, str(run_id)
    )
    print("new path for inst gen", inst_gen_config.generator.output_path)
    inst_gen = instantiate(inst_gen_config.generator)
    inst_gen.generate_instructions()

    ## parse and filter instructions
    instr_parser = InstructionParser()
    # The second return value (HTML report) is not used here.
    res_stats, _html = instr_parser.parse_instructions(
        inst_gen_config.generator.output_path,
        save_output=True,
        generate_html=True,
        run_per_scene=run_per_scene,
        per_call_generation=inst_gen_config.generator.generations_per_call,
        add_clutter=inst_gen_config.generator.add_clutter,
    )
    print(
        "\nParsed instructions from:",
        inst_gen_config.generator.output_path,
        "to remove hallucinations. Overview:",
        res_stats,
    )

    print("\nLoading and filtering init state dicts from parsed output..")
    if run_per_scene:
        # Collect this scene's run folders for every run id up to the current one.
        folder_names = [
            os.path.join(save_out_folder, str(i)) for i in range(run_id + 1)
        ]
        print("Reading from folders", folder_names, "for scene:", scene_name)
        all_dicts = get_state_dict_fromperscenerun(folder_names, scene_name)
    else:
        all_dicts = get_state_dict(rootfolder)

    filtered_dict, validated_instructions_df = filter_dataset(all_dicts)
    init_state_path = os.path.join(save_out_folder, init_state_dicts)
    with open(init_state_path, "w", encoding="utf-8") as f:
        json.dump(filtered_dict, f)
    validated_instructions_df.to_csv(
        os.path.join(save_out_folder, validated_instructions), index=False
    )

    ## generate episodes
    num_proposals = len(filtered_dict)
    print("\nGenerating episodes for total", num_proposals, "episodes")
    if num_proposals > 0:
        final_generation_df, gen_eps_count = run_generation_over_proposals(
            gen_config, metadata_dict, filtered_dict, validated_instructions_df
        )
        final_generation_df.to_csv(
            os.path.join(save_out_folder, final_instructions), index=False
        )
    else:
        # Nothing passed filtering: skip generation, no final CSV is written.
        gen_eps_count = 0
    print(
        "\nGenerated", gen_eps_count, "out of total", num_proposals, "episodes.\n"
    )
    print("Valid episodes", gen_eps_count)


# Script entry point: main() is wrapped by @hydra.main, so it is called without
# arguments here.
if __name__ == "__main__":
    main()
