from util.logger import logger

from typing import List, Tuple, Dict

from omegaconf import DictConfig

from pathlib import Path

import gc

from tqdm.auto import tqdm

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import (
    load_yaml, save_yaml, 
    convert_numpy_type_to_native_type
)


def _list_metric_folder_names(
    exp_path: Path
) -> List[str]:
    """Return the sorted sub-directory names of `exp_path`, skipping `_metric`.

    Used when `folder_name_list` is the string "all".

    - Only directories are kept, so stray files next to the result folders
      are not mistaken for result folders.
    - `Path.name` is used instead of `Path.stem`, so a folder name that
      contains a dot is not truncated.
    - Sorting makes the traversal order independent of the file system's
      `iterdir()` order.
    """
    return sorted(
        folder_path.name
        for folder_path in exp_path.iterdir()
        if folder_path.is_dir() and folder_path.name != "_metric"
    )


def _load_reward_list(
    result_yaml_path: Path,
    reward_model_type: str,
    num_sample: int
) -> List[float]:
    """Load the first `num_sample` rewards stored under `reward_model_type`.

    The YAML file is read with `load_yaml`. Its `reward_model_type` entry is
    indexed by sample position.

    Raises:
        FileNotFoundError: if `result_yaml_path` does not exist. This fails
            fast instead of dropping into a debugger, which would hang or
            silently continue in non-interactive runs.
        KeyError: if the file has no `reward_model_type` entry.
        ValueError: if the file holds fewer than `num_sample` rewards.
    """
    if not result_yaml_path.is_file():
        logger(
            f"result_yaml_path: {result_yaml_path}", 
            log_type = "error"
        )
        raise FileNotFoundError(
            f"Metric file not found: {result_yaml_path}"
        )

    result_dict = load_yaml(result_yaml_path)
    result_list = result_dict[reward_model_type]

    if len(result_list) < num_sample:
        raise ValueError(
            f"{result_yaml_path} holds {len(result_list)} '{reward_model_type}' "
            f"rewards, but num_sample = {num_sample}."
        )

    # `range` (not slicing) so a non-positive `num_sample` yields no samples.
    return [result_list[sample_idx] for sample_idx in range(num_sample)]


def _compute_best_of_n_avg_reward(
    exp_root_path: Path,
    exp_name_list: List[str],
    folder_name_list,
    num_sample: int,
    reward_model_type: str
) -> float:
    """Return the mean over (sample, folder) pairs of the best reward across experiments.

    For every experiment in `exp_name_list`, per-sample rewards are read from
    `<exp_root_path>/<exp_name>/<folder_name>/_metric/<reward_model_type>.yaml`.
    `folder_name_list` is either an explicit list of folder names or the
    string "all", meaning every sub-directory except `_metric`.

    Rewards that share a sample index and folder name are grouped across
    experiments, because the grouping key ignores `exp_name`. The maximum of
    each group is kept, and the mean of those maxima is returned.

    Raises:
        ValueError: if no rewards were collected, for example when
            `exp_name_list` is empty or `num_sample` is 0. This replaces a
            bare ZeroDivisionError.
    """
    best_reward_dict: Dict[str, float] = {}

    for exp_name in exp_name_list:
        exp_path = exp_root_path / exp_name

        if folder_name_list == "all":
            tmp_folder_name_list = _list_metric_folder_names(exp_path)
        else:
            tmp_folder_name_list = folder_name_list

        for folder_name in tmp_folder_name_list:
            result_yaml_path = exp_path / folder_name / "_metric" / f"{reward_model_type}.yaml"
            reward_list = _load_reward_list(
                result_yaml_path, reward_model_type, num_sample
            )

            for sample_idx, reward in enumerate(reward_list):
                key = f"{sample_idx}-{folder_name}"

                # A running maximum is enough: only the best reward per key is used.
                if key not in best_reward_dict or reward > best_reward_dict[key]:
                    best_reward_dict[key] = reward

    if not best_reward_dict:
        raise ValueError(
            f"No rewards collected under {exp_root_path}; "
            "check exp_name_list / folder_name_list / num_sample."
        )

    return sum(best_reward_dict.values()) / len(best_reward_dict)


def get_baseline_scaling_line_chart_str_implement(
    cfg: DictConfig
) -> str:
    """Build, log and return the baseline scaling line-chart string.

    Config keys read, all under `cfg["task"]` and unwrapped with `get_true_value`:
        inference_step_list: steps to evaluate. Each step is also a
            sub-directory name under `exp_root_pre_path`.
        path.exp_root_pre_path: root directory of the experiments.
        path.exp_name_list: experiment directories compared per step.
        path.folder_name_list: list of result folder names, or "all".
        path.num_sample: number of samples read from each metric file.
        reward_model_type: metric file stem and YAML key to read.

    Each step is reduced with `_compute_best_of_n_avg_reward`, using
    `<exp_root_pre_path>/<inference_step>` as the experiment root.

    Returns:
        A string of the form "[(0, 0.0), (step, avg_reward), ...]". The
        leading `(0, 0.0)` is a fixed origin point. The same string is also
        written to the log.

    Raises:
        FileNotFoundError, KeyError, ValueError: propagated from the reward
            loading and aggregation helpers.
    """
    # ---------= [Inference Step List] =---------
    logger(f"[Inference Step List] Loading started. ")

    inference_step_list = get_true_value(cfg["task"]["inference_step_list"])

    logger(f"    inference_step_list: {inference_step_list}")

    logger(
        f"[Inference Step List] Loading finished. "
        "\n"
    )

    # ---------= [Path] =---------
    logger(f"[Path] Loading started. ")

    exp_root_pre_path = get_true_value(cfg["task"]["path"]["exp_root_pre_path"])
    exp_name_list = get_true_value(cfg["task"]["path"]["exp_name_list"])
    folder_name_list = get_true_value(cfg["task"]["path"]["folder_name_list"])
    num_sample = get_true_value(cfg["task"]["path"]["num_sample"])

    logger(f"    exp_root_pre_path: {exp_root_pre_path}")
    logger(f"    exp_name_list: {exp_name_list}")
    logger(f"    folder_name_list: {folder_name_list}")
    logger(f"    num_sample: {num_sample}")

    logger(
        f"[Path] Loading finished. "
        "\n"
    )

    # ---------= [Reward Model Type] =---------
    logger(f"[Reward Model Type] Loading started. ")

    reward_model_type = get_true_value(cfg["task"]["reward_model_type"])

    logger(f"    reward_model_type: {reward_model_type}")

    logger(
        f"[Reward Model Type] Loading finished. "
        "\n"
    )

    # ---------= [All Components Loaded] =---------
    logger(
        f"All components loaded. "
        "\n"
    )

    # ---------= [Prepare Everything] =---------
    exp_root_pre_path = Path(exp_root_pre_path)

    # ---------= [Display Quantitative Result] =---------
    # `(0, 0.0)` is the fixed first point of the chart.
    res_pair_str_list = ["(0, 0.0)"]

    for inference_step in tqdm(inference_step_list):
        res = _compute_best_of_n_avg_reward(
            exp_root_path = exp_root_pre_path / f"{inference_step}",
            exp_name_list = exp_name_list,
            folder_name_list = folder_name_list,
            num_sample = num_sample,
            reward_model_type = reward_model_type
        )

        res_pair_str_list.append(f"({inference_step}, {res})")

        # ---------= [Clean Up] =---------
        gc.collect()

    # Joining the whole list, origin included, keeps the output well-formed
    # ("[(0, 0.0)]") even when `inference_step_list` is empty.
    res_pair_str = "[" + ", ".join(res_pair_str_list) + "]"

    logger("res_pair_str: ")
    logger(res_pair_str)

    # `get_baseline_scaling_line_chart_str_implement()` done
    return res_pair_str


def get_baseline_scaling_line_chart_str(
    cfg: DictConfig
):
    """Public entry point for the baseline scaling line-chart task.

    All work is delegated to `get_baseline_scaling_line_chart_str_implement`,
    which writes the resulting chart string to the log. This wrapper
    returns None.
    """
    get_baseline_scaling_line_chart_str_implement(cfg = cfg)
    return None
