from util.logger import logger

from typing import List, Tuple, Dict

from omegaconf import DictConfig

import gc

from pathlib import Path

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import (
    load_yaml, save_yaml, 
    convert_numpy_type_to_native_type
)
from util.file_util import copy_file


def _list_folder_names(
    exp_path: Path, 
    folder_name_list
):
    """Return the folder names to scan under ``exp_path``.

    When ``folder_name_list`` is the string ``"all"``, every sub-directory of
    ``exp_path`` except ``_metric`` is returned. The list is sorted so the scan
    order is deterministic, because ``Path.iterdir`` order is arbitrary. Plain
    files are skipped since each folder is expected to contain ``result/``.
    Any other value is returned unchanged.
    """
    if folder_name_list == "all":
        return sorted(
            folder_path.name
                for folder_path in exp_path.iterdir()
                    if folder_path.is_dir() and folder_path.name != "_metric"
        )

    return folder_name_list


def _copy_sample_png(
    folder_path: Path, 
    folder_name: str, 
    sample_idx_str: str, 
    best_sample_idx: int, 
    exp_name: str, 
    inference_step_root_path: Path
) -> None:
    """Copy one experiment's PNG for a sample into the collection tree.

    Source:      ``<folder_path>/png/<sample_idx>/<sample_idx>_<best_sample_idx>.png``
    Destination: ``<inference_step_root_path>/<folder_name>/<sample_idx>/<exp_name>.png``
    """
    png_path = folder_path / "png" / sample_idx_str / f"{sample_idx_str}_{best_sample_idx}.png"

    copy_file(
        src_path = png_path, 

        dst_root_path = inference_step_root_path / folder_name / sample_idx_str, 
        filename = f"{exp_name}.png"
    )


def _record_reward(
    save_res_dict: Dict, 
    folder_name: str, 
    sample_idx_str: str, 
    exp_name: str, 
    reward
) -> None:
    """Store ``reward`` at ``save_res_dict[folder_name][sample_idx_str][exp_name]``.

    Intermediate dicts are created on demand. Existing entries, e.g. loaded
    from a previous ``_metric.yaml``, are kept.
    """
    save_res_dict.setdefault(folder_name, {}).setdefault(sample_idx_str, {})[exp_name] = reward


def collect_result_ours_implement(
    cfg: DictConfig
):
    """Gather per-sample rewards from several experiments and write ``_metric.yaml``.

    Config keys read from ``cfg["task"]``:
        path:       exp_root_path, exp_name_list, folder_name_list, num_sample
        reward:     reward_type, original
        collection: collect, only_best, save_collection_root_path, sd_type,
                    dataset, task, num_inference_step

    For every experiment / folder / sample index, the function loads
    ``<exp_root_path>/<exp_name>/<folder>/result/<sample_idx>.yaml``. It takes
    the last entry of its ``best_merged_reward_list`` as the reward.

    * ``only_best``: for each (folder, sample), only the experiment with the
      highest reward is recorded (and its PNG copied when ``collect``).
    * otherwise: every experiment's reward is recorded (and every PNG copied
      when ``collect``).

    Results are merged into any existing ``_metric.yaml`` under
    ``<save_collection_root_path>/<sd_type>/<dataset>/<task>/<num_inference_step>/``
    as ``{folder: {sample_idx: {exp_name: reward}}}``, and the file is rewritten.

    Raises:
        NotImplementedError: if ``original`` is False.
        FileNotFoundError: if a per-sample result yaml is missing.
        ValueError: if a result's ``best_merged_reward_list`` is empty.
    """
    # ---------= [Path] =---------
    logger(f"[Path] Loading started. ")

    exp_root_path = get_true_value(cfg["task"]["path"]["exp_root_path"])
    exp_name_list = get_true_value(cfg["task"]["path"]["exp_name_list"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])
    num_sample = get_true_value(cfg["task"]["path"]["num_sample"])

    logger(f"    exp_root_path: {exp_root_path}")
    logger(f"    exp_name_list: {exp_name_list}")
    logger(f"    folder_name_list: {folder_name_list}")
    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Path] Loading finished. "
        "\n"
    )

    # ---------= [Reward] =---------
    logger(f"[Reward] Loading started. ")

    reward_type = get_true_value(cfg["task"]["reward"]["reward_type"])
    original = get_true_value(cfg["task"]["reward"]["original"])

    logger(f"    reward_type: {reward_type}")
    logger(f"    original: {original}")

    logger(
        f"[Reward] Loading finished. "
        "\n"
    )

    # ---------= [Collection] =---------
    logger(f"[Collection] Loading started. ")

    collect = get_true_value(cfg["task"]["collection"]["collect"])
    only_best = get_true_value(cfg["task"]["collection"]["only_best"])
    save_collection_root_path = get_true_value(cfg["task"]["collection"]["save_collection_root_path"])
    sd_type = get_true_value(cfg["task"]["collection"]["sd_type"])
    dataset = get_true_value(cfg["task"]["collection"]["dataset"])
    task_name = get_true_value(cfg["task"]["collection"]["task"])
    num_inference_step = get_true_value(cfg["task"]["collection"]["num_inference_step"])

    logger(f"    collect: {collect}")
    logger(f"    only_best: {only_best}")
    logger(f"    save_collection_root_path: {save_collection_root_path}")
    logger(f"    sd_type: {sd_type}")
    logger(f"    dataset: {dataset}")
    logger(f"    task_name: {task_name}")
    logger(f"    num_inference_step: {num_inference_step}")

    logger(
        f"[Collection] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # Only the "original" reward (last entry of `best_merged_reward_list`) is
    # implemented. Validate once here instead of inside the per-sample loop.
    if not original:
        raise NotImplementedError(
            f"Only support `original = True`. "
        )

    # ---------= [Prepare Everything] =---------
    exp_root_path = Path(exp_root_path)

    # `_metric.yaml` is always read from and written to this directory, so it
    # must exist even when `collect` is False (previously a NameError).
    inference_step_root_path = (
        Path(save_collection_root_path) / sd_type / dataset / task_name / f"{num_inference_step}"
    )

    # ---------= [Display Quantitative Result] =---------
    save_res_dict = {}

    res_dict_path = inference_step_root_path / "_metric.yaml"
    if res_dict_path.is_file():
        # An empty yaml file may load as None; merge into an empty dict instead.
        save_res_dict = load_yaml(res_dict_path) or {}

    # (folder_name, sample_idx) -> [(reward, exp_name, best_sample_idx), ...]
    sample_reward_tuple_list_dict: Dict[Tuple[str, int], List[Tuple]] = {}

    for exp_name in exp_name_list:
        exp_path = exp_root_path / exp_name

        for folder_name in _list_folder_names(exp_path, folder_name_list):
            folder_path = exp_path / folder_name

            for sample_idx in range(num_sample):
                sample_idx_str = f"{sample_idx}"
                res_yaml_path = folder_path / "result" / f"{sample_idx}.yaml"

                if not res_yaml_path.is_file():
                    logger(
                        f"File `{res_yaml_path}` not exists. ", 
                        log_type = "error"
                    )
                    raise FileNotFoundError(
                        f"Result file `{res_yaml_path}` does not exist. "
                    )

                best_merged_reward_list = load_yaml(res_yaml_path)["best_merged_reward_list"]

                if not best_merged_reward_list:
                    raise ValueError(
                        f"`best_merged_reward_list` in `{res_yaml_path}` is empty. "
                    )

                # The last entry is the reward that is reported; its index is
                # also used to locate the matching PNG file.
                best_sample_idx = len(best_merged_reward_list) - 1
                reward = best_merged_reward_list[best_sample_idx]

                if only_best:
                    # Defer: the winner is chosen after every experiment is read.
                    sample_reward_tuple_list_dict.setdefault((folder_name, sample_idx), []).append(
                        (reward, exp_name, best_sample_idx)
                    )
                    continue

                # ---------= [Collect Samples] =---------
                if collect:
                    _copy_sample_png(
                        folder_path, folder_name, sample_idx_str, 
                        best_sample_idx, exp_name, inference_step_root_path
                    )

                # ---------= [Record Reward] =---------
                _record_reward(save_res_dict, folder_name, sample_idx_str, exp_name, reward)

    if only_best:
        for (folder_name, sample_idx), reward_tuple_list in sample_reward_tuple_list_dict.items():
            # Highest reward wins. Ties are broken like the previous descending
            # tuple sort: by exp_name, then best_sample_idx.
            reward, exp_name, best_sample_idx = max(reward_tuple_list)

            sample_idx_str = f"{sample_idx}"

            # ---------= [Collect Samples] =---------
            if collect:
                _copy_sample_png(
                    exp_root_path / exp_name / folder_name, folder_name, sample_idx_str, 
                    best_sample_idx, exp_name, inference_step_root_path
                )

            # ---------= [Record Reward] =---------
            _record_reward(save_res_dict, folder_name, sample_idx_str, exp_name, reward)

    logger(f"save_res_dict: {save_res_dict}")

    save_res_dict = convert_numpy_type_to_native_type(save_res_dict)

    save_yaml(
        save_res_dict, 

        yaml_root_path = inference_step_root_path, 
        yaml_filename = "_metric.yaml"
    )

    # `collect_result_ours_implement()` done


def collect_result_ours(
    cfg: DictConfig
):
    """Public wrapper around ``collect_result_ours_implement``.

    Args:
        cfg: Task configuration forwarded unchanged to the implementation.

    Returns:
        Whatever ``collect_result_ours_implement`` returns (``None``).
    """
    return collect_result_ours_implement(cfg)
