#!/usr/bin/env python3



# isort: skip_file

import time
import os
import json
import glob

from omegaconf import OmegaConf, open_dict

import habitat
from hydra import compose, initialize
import numpy as np
from torch import multiprocessing as mp
import traceback


from partnr.utils import cprint


from habitat.config.default_structured_configs import (
    HabitatConfigPlugin,
    register_hydra_plugin,
)
from habitat_baselines.config.default_structured_configs import (
    HabitatBaselinesConfigPlugin,
)

from partnr.agent.env import (
    register_measures,
    register_sensors,
)

from partnr.agent.env.dataset import CollaborationDatasetV0


def summarize_verifications(save_results_dir: str):
    """
    Aggregate the per-episode verification results stored in a directory.

    Every regular file in ``save_results_dir`` other than ``summary.json`` is
    read as a JSON result containing ``success_init`` and ``info``. The share
    of episodes that initialized without an exception and the mean of
    ``task_percent_complete`` / ``task_state_success`` over those episodes are
    printed, and all per-episode results are written to
    ``<save_results_dir>/summary.json`` keyed by file name.

    Args:
        save_results_dir: directory holding the per-episode JSON result files.
    """
    summary_name = "summary.json"
    summary_json = os.path.join(save_results_dir, summary_name)
    metric_keys = ("task_percent_complete", "task_state_success")
    metric_sums = {key: 0 for key in metric_keys}
    no_breaks = 0
    summary_content = {}

    # Sorted so the summary file has a deterministic key order.
    for file in sorted(glob.glob(os.path.join(save_results_dir, "*"))):
        ep_name = os.path.basename(file)
        # Skip the summary written by a previous run (exact name match, so an
        # episode called e.g. "x_summary.json" is still counted) and anything
        # that is not a regular file.
        if ep_name == summary_name or not os.path.isfile(file):
            continue
        with open(file, "r") as f:
            content = json.load(f)

        if content["success_init"]:
            no_breaks += 1
            for key in metric_keys:
                metric_sums[key] += content["info"][key]
        summary_content[ep_name] = content

    # Count only the episode files that were actually read; a stale
    # summary.json must not dilute the percentage.
    total = len(summary_content)
    if total == 0:
        print(f"No episode results found in {save_results_dir}")
    else:
        print(f"No exceptions: {no_breaks*100/total:.2f}%")
        for key, val in metric_sums.items():
            # No successful episode means there is nothing to average.
            mean = val / no_breaks if no_breaks else float("nan")
            print(f"{key}: {mean}")

    with open(summary_json, "w+") as f:
        f.write(json.dumps(summary_content))


def get_output_file(save_results_dir, eid):
    """
    Return the path of the JSON result file for episode ``eid``.

    The results directory is created first if it does not exist yet.
    """
    os.makedirs(save_results_dir, exist_ok=True)
    result_name = "{}.json".format(eid)
    return os.path.join(save_results_dir, result_name)


def process_success_at_init(env, save_results_dir, eid):
    """
    Write the result file for an episode that initialized without error.

    The file records ``success_init: True`` together with the
    ``task_percent_complete`` and ``task_state_success`` values taken from
    ``env.get_metrics()``.
    """
    tracked_metrics = ("task_percent_complete", "task_state_success")

    output_file = get_output_file(save_results_dir, eid)
    all_metrics = env.get_metrics()

    info = {}
    for name in tracked_metrics:
        info[name] = all_metrics[name]

    serialized = json.dumps({"success_init": True, "info": info})
    with open(output_file, "w+") as out:
        out.write(serialized)


def process_error_at_init(save_results_dir, eid):
    """
    Write the result file for an episode whose initialization raised.

    Must be called while the exception is being handled: the formatted
    traceback of the current exception is stored under ``info`` alongside
    ``success_init: False``.
    """
    output_file = get_output_file(save_results_dir, eid)
    # format_exc() already returns a str, so no extra conversion is needed.
    serialized = json.dumps(
        {"success_init": False, "info": traceback.format_exc()}
    )
    with open(output_file, "w+") as out:
        out.write(serialized)


def run_verifier(
    config,
    save_results_dir,
    dataset: CollaborationDatasetV0 = None,
    conn=None,
):
    """
    Reset the environment once per episode and record the outcome to disk.

    For every episode in the environment, ``env.reset()`` is called. On
    success the init metrics are written via ``process_success_at_init``; on
    any exception the traceback is written via ``process_error_at_init``.

    Args:
        config: composed config for the environment. If None, an error is
            printed and the function returns without doing anything else.
        save_results_dir: directory receiving one JSON result file per episode.
        dataset: dataset handed to ``habitat.Env``.
        conn: optional pipe connection. ``[0]`` is sent on it once all
            episodes have been processed, and it is always closed before
            returning so the receiving end never waits on a dead worker.
    """
    try:
        if config is None:
            cprint("Failed to setup config. Exiting", "red")
            return

        # Set logging-related environment variables before the env is built.
        os.environ["GLOG_minloglevel"] = "3"  # noqa: SIM112
        os.environ["MAGNUM_LOG"] = "quiet"
        os.environ["HABITAT_SIM_LOG"] = "quiet"

        register_hydra_plugin(HabitatBaselinesConfigPlugin)
        register_hydra_plugin(HabitatConfigPlugin)

        with open_dict(config):
            config.habitat.simulator.agents_order = sorted(
                config.habitat.simulator.agents.keys()
            )

            # Propagate the metadata folder config to the simulator
            config_dict = OmegaConf.create(
                OmegaConf.to_container(config.habitat, resolve=True)
            )
            # TODO: refactor this. We shouldn't need to copy configs into other subconfigs to pass information. This is done now because CollaborationSim needs metadata paths for init.
            config_dict.simulator.metadata = config.habitat.dataset.metadata
            config.habitat = config_dict

        register_sensors(config)
        register_measures(config)

        env = habitat.Env(config=config, dataset=dataset)
        try:
            env.sim.dynamic_target = np.zeros(3)

            # NOTE(review): the recorded id is taken from the episode list,
            # which assumes env.reset() visits episodes in list order —
            # confirm against the environment's episode iterator settings.
            for episode in env.episodes:
                eid = episode.episode_id
                try:
                    env.reset()
                    process_success_at_init(env, save_results_dir, eid)
                except Exception:
                    # A failing episode is the thing being measured: record it
                    # and keep going with the remaining episodes.
                    process_error_at_init(save_results_dir, eid)

            # Signal completion only after every episode has been processed.
            if conn is not None:
                conn.send([0])
        finally:
            # Release the simulator even if something above raised.
            env.close()
            del env
    finally:
        if conn is not None:
            conn.close()

    return


def _chunk_slices(num_items: int, num_chunks: int):
    """Split ``range(num_items)`` into contiguous, non-empty, near-equal slices (at most ``num_chunks``)."""
    base_size, remainder = divmod(num_items, num_chunks)
    slices = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks take one extra item each.
        size = base_size + (1 if i < remainder else 0)
        if size == 0:
            # More chunks requested than items: the rest would be empty.
            break
        slices.append(slice(start, start + size))
        start += size
    return slices


def verify_dataset_parallel(
    dataset_path: str, save_results_dir: str, num_proc: int = 1
):
    """
    Verify every episode of a dataset and summarize the results.

    The dataset is loaded once, split into contiguous chunks, and each chunk
    is checked by ``run_verifier`` (in this process when ``num_proc == 1``,
    otherwise in one "forkserver" subprocess per chunk). When all workers are
    done the elapsed time is printed and ``summarize_verifications`` runs over
    ``save_results_dir``.

    Args:
        dataset_path: path to the dataset file; passed to the config as
            ``habitat.dataset.data_path``.
        save_results_dir: directory receiving per-episode results and the
            summary.
        num_proc: number of worker processes; must be at least 1. No more
            workers than episodes are started.

    Raises:
        ValueError: if ``num_proc`` is smaller than 1.
    """
    if num_proc < 1:
        raise ValueError(f"num_proc must be >= 1, got {num_proc}")

    # Set up hydra config
    with initialize(version_base=None, config_path="../../partnr/conf"):
        config = compose(
            config_name="benchmark_gen/evaluation_validation.yaml",
            overrides=[
                "+habitat.dataset.metadata.metadata_folder=data/fpss/metadata",
                "+habitat.dataset.metadata.obj_metadata=object_categories_filtered.csv",
                "+habitat.dataset.metadata.staticobj_metadata=fpmodels-with-decomposed.csv",
                f"habitat.dataset.data_path={dataset_path}",
            ],
        )

    config = OmegaConf.create(config)
    t0 = time.time()
    dataset = CollaborationDatasetV0(config.habitat.dataset)
    num_episodes = len(dataset.episodes)

    if num_proc == 1:
        new_dataset = CollaborationDatasetV0(
            config=config.habitat.dataset, episodes=dataset.episodes
        )
        run_verifier(config, save_results_dir, new_dataset)
    else:
        # Process episodes in parallel
        mp_ctx = mp.get_context("forkserver")
        proc_infos = []

        for episode_slice in _chunk_slices(num_episodes, num_proc):
            new_dataset = CollaborationDatasetV0(
                config=config.habitat.dataset,
                episodes=dataset.episodes[episode_slice],
            )

            parent_conn, child_conn = mp_ctx.Pipe()
            proc_args = (config, save_results_dir, new_dataset, child_conn)
            p = mp_ctx.Process(target=run_verifier, args=proc_args)
            p.start()
            # Drop the parent's copy of the child end. Otherwise recv() below
            # can never see EOF and would block forever if the worker dies
            # before sending its completion message.
            child_conn.close()
            proc_infos.append((parent_conn, p))

        # Wait for each worker's completion message, then reap it.
        for conn, proc in proc_infos:
            finished = True
            try:
                conn.recv()
            except (EOFError, OSError):
                # The worker closed its end (or died) without signalling.
                finished = False
            finally:
                conn.close()
            proc.join()
            if not finished or proc.exitcode != 0:
                cprint(
                    f"Verifier process {proc.pid} did not finish cleanly "
                    f"(exit code {proc.exitcode}); its episodes may be "
                    "missing from the results.",
                    "red",
                )

    e_t = time.time() - t0
    print("Elapsed Time: ", e_t)

    summarize_verifications(save_results_dir)


if __name__ == "__main__":
    # A script that verifies that episodes load successfully and that the
    # evaluation measures don't crash.
    #
    # To run:
    # python dataset_generation/benchmark_generation/verify_dataset.py \
    #     --dataset-path=path/to/dataset_name.json.gz \
    #     --save-results-dir=data/episode_checks/dataset_name \
    #     --num-proc=5
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Verify that episodes load successfully and that the evaluation "
            "measures don't crash."
        )
    )
    # Required: without a dataset path there is nothing to verify, and the
    # default results directory below is derived from it.
    parser.add_argument(
        "--dataset-path",
        type=str,
        required=True,
        help="Path to the dataset file, e.g. path/to/dataset_name.json.gz",
    )
    parser.add_argument(
        "--save-results-dir",
        type=str,
        default="",
        help=(
            "Directory for the per-episode results. Defaults to "
            "data/episode_checks/<dataset name>."
        ),
    )
    parser.add_argument(
        "--num-proc",
        type=int,
        default=1,
        help="Number of worker processes.",
    )
    args = parser.parse_args()

    save_results_dir = args.save_results_dir
    if save_results_dir == "":
        # Dataset name = file name up to its first dot.
        dset_name = os.path.basename(args.dataset_path).split(".")[0]
        save_results_dir = f"data/episode_checks/{dset_name}"

    verify_dataset_parallel(args.dataset_path, save_results_dir, args.num_proc)
