from util.logger import logger

from typing import Optional, Union, List, Tuple

from pathlib import Path

import numpy as np

import gc

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import load_yaml


def _expand_reward_steps(
    nfe_dynamics_list: List[int], 
    reward_list: List[float], 

    nfe_limit: int
) -> List[float]:
    """Expand paired (NFE, reward) entries into a dense per-NFE reward list.

    Entry ``i`` of ``reward_list`` is written to every index after the
    previous entry's NFE up to and including ``nfe_dynamics_list[i]``. The
    last written reward is then carried forward up to ``nfe_limit``.
    Index 0 is never written and stays ``0.0``.

    Args:
        nfe_dynamics_list: NFE value attached to each reward entry.
        reward_list: Reward attached to each NFE entry; pairs are formed with
            ``zip``, so surplus entries of the longer list are ignored.
        nfe_limit: Largest index of the returned list.

    Returns:
        A list of ``nfe_limit + 1`` rewards indexed by NFE.
    """
    reward_scaling_list = [0.0] * (nfe_limit + 1)
    cur_idx = 0

    for nfe_dynamics, reward in zip(nfe_dynamics_list, reward_list):
        # Clamp to the budget so an entry recorded beyond `nfe_limit` cannot
        # index past the end of the list.
        last_idx = min(nfe_dynamics, nfe_limit)

        while cur_idx + 1 <= last_idx:
            cur_idx += 1
            reward_scaling_list[cur_idx] = reward

    # `cur_idx` is the last index that already holds a value, so the carry
    # forward must start right after it; starting at `cur_idx` itself would
    # overwrite the most recent reward with the one before it.
    for idx in range(cur_idx + 1, nfe_limit + 1):
        reward_scaling_list[idx] = reward_scaling_list[idx - 1]

    return reward_scaling_list


def get_reward_scaling_list_from_result(
    setting_root_path: Union[str, Path], 
    setting_name_list: Optional[List[str]] = None, 

    folder_name_list: Optional[List[str]] = None, 

    original: Optional[bool] = True, 
    reward_type: Optional[str] = "hps_v2", 

    num_sample: Optional[int] = 1, 

    use_merged_reward: Optional[bool] = False, 

    nfe_limit: Optional[int] = 999
) -> List[float]:
    """Average per-NFE reward curves read from YAML result files.

    For every setting, folder and sample index the function loads
    ``<root>/<setting>/<folder>/result/<sample_idx>.yaml``, expands its reward
    entries into a dense list indexed by NFE, merges the lists that share the
    same ``"<sample_idx>-<folder>"`` key across settings with an element-wise
    maximum, and finally averages the merged lists.

    Args:
        setting_root_path: Directory that contains one sub-directory per
            setting.
        setting_name_list: Settings to read. When ``None``, every entry of
            ``setting_root_path`` except ``_metric`` is used.
        folder_name_list: Folders to read inside each setting. When ``None``,
            the folders are discovered separately for each setting (every
            entry except ``_metric``).
        original: If true, rewards come from the result YAML itself;
            otherwise they are read from
            ``<folder>/_metric/<reward_type>/<sample_idx>.yaml`` under the key
            ``reward_type``.
        reward_type: Metric sub-directory and YAML key used when ``original``
            is false.
        num_sample: Number of sample indices (``0 .. num_sample - 1``) read
            per folder.
        use_merged_reward: With ``original`` true, selects
            ``best_merged_reward_list`` instead of ``last_final_reward_list``.
        nfe_limit: Largest NFE index; the returned list has
            ``nfe_limit + 1`` entries.

    Returns:
        The averaged reward for each NFE index from 0 to ``nfe_limit``.

    Raises:
        AssertionError: If no sample was collected.
    """
    if not isinstance(setting_root_path, Path):
        setting_root_path = Path(setting_root_path)

    if setting_name_list is None:
        # `.name` (not `.stem`) so that directory names containing a dot are
        # kept intact and the path can be rebuilt below.
        setting_name_list = [
            setting_path.name \
                for setting_path in setting_root_path.iterdir()
                    if setting_path.name != "_metric"
        ]

    sample_res_list_dict = {}

    for setting_name in setting_name_list:
        setting_path = setting_root_path / setting_name

        # Resolve the folders per setting without touching the argument, so
        # the folders found for one setting are not reused for the next one.
        if folder_name_list is None:
            cur_folder_name_list = [
                folder_path.name \
                    for folder_path in setting_path.iterdir()
                        if folder_path.name != "_metric"
            ]
        else:
            cur_folder_name_list = folder_name_list

        for folder_name in cur_folder_name_list:
            folder_path = setting_path / folder_name

            for sample_idx in range(num_sample):
                res_dict = load_yaml(
                    folder_path / "result" / f"{sample_idx}.yaml"
                )

                best_trajectory_updated_nfe_cal_dynamics_list \
                    = res_dict["best_trajectory_updated_nfe_cal_dynamics_list"]

                if original:
                    if use_merged_reward:
                        reward_list = res_dict["best_merged_reward_list"]
                    else:
                        reward_list = res_dict["last_final_reward_list"]
                else:
                    metric_dict = load_yaml(
                        folder_path / "_metric" / reward_type / f"{sample_idx}.yaml"
                    )
                    reward_list = metric_dict[reward_type]

                reward_scaling_list = _expand_reward_steps(
                    best_trajectory_updated_nfe_cal_dynamics_list, 
                    reward_list, 
                    nfe_limit
                )

                # The same sample of the same folder may appear under several
                # settings; keep the element-wise best reward among them.
                tmp_folder_name = f"{sample_idx}-{folder_name}"
                merged_list = sample_res_list_dict.get(tmp_folder_name)

                if merged_list is None:
                    sample_res_list_dict[tmp_folder_name] = reward_scaling_list
                else:
                    for nfe in range(nfe_limit + 1):
                        merged_list[nfe] = max(
                            merged_list[nfe], 
                            reward_scaling_list[nfe]
                        )

    num_merged_sample = len(sample_res_list_dict)
    assert num_merged_sample > 0, "no result sample was collected"

    avg_reward_scaling_list = np.array(
        [0.0] * (nfe_limit + 1)
    )

    for reward_scaling_list in sample_res_list_dict.values():
        avg_reward_scaling_list += reward_scaling_list

    avg_reward_scaling_list /= num_merged_sample

    avg_reward_scaling_list = avg_reward_scaling_list.tolist()

    # ---------= [Clean Up] =---------
    gc.collect()

    # `get_reward_scaling_list_from_result()` done
    return avg_reward_scaling_list


def get_turning_point_tuple_list(
    reward_scaling_list: List[float], 

    eps: Optional[float] = 1e-8
) -> List[Tuple[int, float]]:
    """Compress a dense per-NFE reward list into its turning points.

    A turning point is an index whose reward differs from the previous
    index's reward by more than ``eps``.

    Args:
        reward_scaling_list: Reward for each NFE index, starting at index 0.
        eps: Minimum absolute change between neighbouring entries that counts
            as a turning point.

    Returns:
        A list of ``(nfe, reward)`` tuples. For non-empty input it always
        starts with ``(0, 0.0)`` (the value stored at index 0 is not
        consulted) and always ends at the last index; if the last index is
        not itself a turning point, a closing tuple repeating the reward of
        the last turning point is appended. An empty input yields an empty
        list.
    """
    # Without any entry there is no point to report; bail out before the
    # closing-point logic below indexes into an empty list.
    if len(reward_scaling_list) == 0:
        return []

    nfe_limit = len(reward_scaling_list) - 1

    # Index 0 is always reported as the fixed origin.
    turning_point_tuple_list = [(0, 0.0)]

    for nfe in range(1, nfe_limit + 1):
        cur_reward = reward_scaling_list[nfe]

        if abs(cur_reward - reward_scaling_list[nfe - 1]) > eps:
            turning_point_tuple_list.append(
                (nfe, cur_reward)
            )

    # Close the curve at the last index so the final plateau is represented.
    last_nfe, last_reward = turning_point_tuple_list[-1]

    if last_nfe != nfe_limit:
        turning_point_tuple_list.append(
            (nfe_limit, last_reward)
        )

    # `get_turning_point_tuple_list()` done
    return turning_point_tuple_list
