from util.logger import logger

from typing import List, Tuple, Dict

from omegaconf import DictConfig

import gc

from pathlib import Path

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import (
    load_yaml, save_yaml, 
    convert_numpy_type_to_native_type
)


def _list_folder_names(
    exp_path: Path, 
    folder_name_list
) -> List[str]:
    """Resolve which run folders to read under ``exp_path``.

    If ``folder_name_list`` is the string ``"all"``, every sub-directory of
    ``exp_path`` except ``_metric`` is returned, sorted by name so the
    traversal order is deterministic. Otherwise ``folder_name_list`` is
    returned unchanged.
    """
    if folder_name_list == "all":
        # Use `.name`, not `.stem`: `.stem` drops everything after the last dot
        # (e.g. `run.v2` -> `run`), which yields a path that does not exist.
        # Plain files (logs, configs, ...) cannot contain a `result/` folder,
        # so only directories are considered.
        return sorted(
            folder_path.name
            for folder_path in exp_path.iterdir()
            if folder_path.is_dir() and folder_path.name != "_metric"
        )

    return folder_name_list


def _require_file(
    yaml_path: Path
) -> None:
    """Ensure ``yaml_path`` exists as a regular file.

    Raises:
        FileNotFoundError: after logging an error, if the file is missing.
    """
    if not yaml_path.is_file():
        logger(
            f"File `{yaml_path}` not exists. ", 
            log_type = "error"
        )

        # Fail fast with a clear error instead of pausing in a debugger,
        # which would hang batch / non-interactive runs.
        raise FileNotFoundError(f"Result file not found: {yaml_path}")


def _load_sample_record(
    exp_path: Path, 
    folder_name: str, 
    sample_idx: int, 
    *, 
    reward_type, 
    original: bool, 
    use_merged_reward: bool
) -> Tuple:
    """Load the summary record of one sample in one run folder.

    Reads ``<exp_path>/<folder_name>/result/<sample_idx>.yaml``. When
    ``original`` is false, the reward is instead taken from
    ``<exp_path>/<folder_name>/_metric/<reward_type>/<sample_idx>.yaml``.

    Returns:
        ``(best_reward, nfe_dynamics, nfe_intermediate, nfe_final, time_cost)``,
        where every entry is the last element of the corresponding YAML list.

    Raises:
        FileNotFoundError: if a required YAML file is missing.
    """
    folder_path = exp_path / folder_name

    res_yaml_path = folder_path / "result" / f"{sample_idx}.yaml"
    _require_file(res_yaml_path)
    res_dict = load_yaml(res_yaml_path)

    nfe_dynamics = res_dict["best_trajectory_updated_nfe_cal_dynamics_list"][-1]
    nfe_intermediate = res_dict["best_trajectory_updated_nfe_cal_intermediate_reward_list"][-1]
    nfe_final = res_dict["best_trajectory_updated_nfe_cal_final_reward_list"][-1]

    time_cost = res_dict["best_trajectory_updated_wall_clock_time_list"][-1]

    if original:
        # Reward as recorded in the result file itself.
        if use_merged_reward:
            best_reward = res_dict["best_merged_reward_list"][-1]
        else:
            best_reward = res_dict["last_final_reward_list"][-1]
    else:
        # Reward read from the separate `_metric/<reward_type>/` file.
        metric_yaml_path = folder_path / "_metric" / reward_type / f"{sample_idx}.yaml"
        _require_file(metric_yaml_path)
        best_reward = load_yaml(metric_yaml_path)[reward_type][-1]

    return (
        best_reward, 
        nfe_dynamics, nfe_intermediate, nfe_final, 
        time_cost
    )


def _best_record_key(
    record: Tuple
) -> Tuple:
    """Sort key under which the *best* record compares smallest.

    Ranks by highest ``best_reward`` first, then fewest dynamics NFE, fewest
    final-reward NFE, and finally lowest wall-clock time.
    """
    # NOTE(review): `nfe_intermediate` (record[2]) is not part of the
    # tie-break; confirm this is intentional.
    return (-record[0], record[1], record[3], record[4])


def display_result_ours_implement(
    cfg: DictConfig
):
    """Aggregate per-sample results across experiments and print averages.

    Config values read (each through ``get_true_value``):
        - ``cfg["task"]["reward"]["reward_type"]`` / ``["original"]``
        - ``cfg["task"]["path"]["exp_root_path"]`` / ``["exp_name_list"]`` /
          ``["folder_name_list"]`` / ``["num_sample"]``
        - ``cfg["task"]["use_merged_reward"]``

    For every experiment in ``exp_name_list``, every run folder (all
    sub-directories except ``_metric`` when ``folder_name_list == "all"``),
    and every ``sample_idx`` in ``range(num_sample)``, one record is loaded.
    Records sharing the same ``(sample_idx, folder_name)`` across experiments
    are reduced to the best one (see ``_best_record_key``), and the averages
    of those best records are printed.

    Raises:
        FileNotFoundError: if an expected result / metric YAML is missing.
    """
    # ---------= [Reward] =---------
    logger(f"[Reward] Loading started. ")

    reward_type = get_true_value(cfg["task"]["reward"]["reward_type"])
    original = get_true_value(cfg["task"]["reward"]["original"])

    logger(f"    reward_type: {reward_type}")
    logger(f"    original: {original}")

    logger(
        f"[Reward] Loading finished. "
        "\n"
    )

    # ---------= [Exp Root Path] =---------
    logger(f"[Exp Root Path] Loading started. ")

    exp_root_path = get_true_value(cfg["task"]["path"]["exp_root_path"])
    exp_name_list = get_true_value(cfg["task"]["path"]["exp_name_list"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])
    num_sample = get_true_value(cfg["task"]["path"]["num_sample"])

    logger(f"    exp_root_path: {exp_root_path}")
    logger(f"    exp_name_list: {exp_name_list}")
    logger(f"    folder_name_list: {folder_name_list}")
    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Exp Root Path] Loading finished. "
        "\n"
    )

    # ---------= [Use Merged Reward] =---------
    logger(f"[Use Merged Reward] Loading started. ")

    use_merged_reward = get_true_value(cfg["task"]["use_merged_reward"])

    logger(f"    use_merged_reward: {use_merged_reward}")

    logger(
        f"[Use Merged Reward] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    exp_root_path = Path(exp_root_path)

    # ---------= [Collect Records] =---------
    # Key: f"{sample_idx}-{folder_name}".
    # Value: one (best_reward, nfe_dynamics, nfe_intermediate, nfe_final,
    # time_cost) record per experiment in which that pair was loaded.
    sample_record_list_dict: Dict[str, List[Tuple]] = {}

    for exp_name in exp_name_list:
        exp_path = exp_root_path / exp_name

        for folder_name in _list_folder_names(exp_path, folder_name_list):
            for sample_idx in range(num_sample):
                record = _load_sample_record(
                    exp_path, folder_name, sample_idx, 
                    reward_type = reward_type, 
                    original = original, 
                    use_merged_reward = use_merged_reward
                )

                sample_record_list_dict.setdefault(
                    f"{sample_idx}-{folder_name}", []
                ).append(record)

        # ---------= [Clean Up] =---------
        gc.collect()

    # ---------= [Display Quantitative Result] =---------
    tot_num_sample = len(sample_record_list_dict)

    if tot_num_sample == 0:
        # Avoid a bare ZeroDivisionError when nothing was found to average.
        logger(
            f"No result found under `{exp_root_path}`; nothing to display. ", 
            log_type = "error"
        )
        return

    # Best record per (sample_idx, folder_name) across experiments.
    # `min` returns the first minimal element, matching a stable sort + [0].
    best_record_list = [
        min(record_list, key = _best_record_key)
        for record_list in sample_record_list_dict.values()
    ]

    (
        avg_best_reward, 
        avg_nfe_dynamics, avg_nfe_intermediate, avg_nfe_final, 
        avg_time_cost
    ) = (sum(column) / tot_num_sample for column in zip(*best_record_list))

    # `avg_time_cost` is already the mean over samples; it must not be
    # divided by `tot_num_sample` a second time.
    avg_time_cost_per_sample = avg_time_cost

    logger(f"[Path]")
    logger(f"    exp_root_path: {exp_root_path}")
    logger(f"    folder_name_list: {folder_name_list}")

    print(f"[Result]")
    print(f"avg_best_reward: {avg_best_reward:.4f}")
    print(f"avg_nfe_dynamics: {avg_nfe_dynamics:.2f}")
    print(f"avg_nfe_intermediate: {avg_nfe_intermediate:.2f}")
    print(f"avg_nfe_final: {avg_nfe_final:.2f}")
    print(f"avg_time_cost_per_sample: {round(avg_time_cost_per_sample)}")


def display_result_ours(
    cfg: DictConfig
):
    """Public entry point: print the aggregated quantitative results for ``cfg``.

    All of the work (loading per-sample YAML results and printing averages)
    is delegated to ``display_result_ours_implement``.
    """
    display_result_ours_implement(cfg = cfg)
