from util.logger import logger

from typing import Optional, Tuple, Union, Dict, List, Set, Callable, TypeVar, Generic, NewType, Protocol

from omegaconf import OmegaConf, DictConfig

from pathlib import Path

import torch

import gc

from tqdm.auto import tqdm

from util.basic_util import pause, get_global_variable, is_none, get_true_value
from util.image_util import load_img_path, save_pil_as_png
from util.yaml_util import (
    load_yaml, save_yaml, 
    convert_numpy_type_to_native_type
)


def gen_random_list(
    n: int, 
    l: float, r: float
) -> List[float]:
    """Draw ``n`` samples uniformly from the half-open interval ``[l, r)``.

    Sampling goes through NumPy's global (legacy) RNG via
    ``np.random.uniform``, so it is reproducible with ``np.random.seed``.

    Args:
        n: number of samples to draw.
        l: lower bound of the interval (inclusive).
        r: upper bound of the interval (exclusive).

    Returns:
        A plain Python ``list`` of ``n`` floats.
    """
    # NumPy is not imported at module level in this file, so `np` must be
    # brought into scope here (same local-import style as other functions).
    import numpy as np

    res_np = np.random.uniform(
        l, r, 
        n
    )

    # `tolist()` converts NumPy scalars to native Python floats.
    return res_np.tolist()

def test_MDP(
    cfg: DictConfig    
):
    """Smoke-test `MyMDP.simulate_step_list` on a random trajectory.

    Builds a `MyMDP` with no state space, action space or time horizon. It
    draws a (1, 2) initial state and 5 actions uniformly from [-1, 1), then
    rolls the MDP forward with `verbose = True`.

    Args:
        cfg: run configuration; currently unused.
    """
    # `np` is not imported at module level in this file; without this import
    # the call to `np.random.rand` below raises NameError.
    import numpy as np

    from .toy.my_mdp import MyMDP

    mdp = MyMDP(
        state_space = None, 
        action_space = None, 
        time_horizon = None
    )

    init_state = np.random.rand(1, 2)

    action_list = gen_random_list(
        n = 5, 
        l = -1.0, r = 1.0
    )

    (
        state_list, 
        reward_list
    ) = mdp.simulate_step_list(
        state = init_state, 
        action_list = action_list, 

        verbose = True
    )
    
    # breakpoint()

    # `test_MDP()` done
    pass

def test_MCTS(
    cfg: DictConfig    
):
    """Run `MyMCTS` on the toy MDP and display the search tree and result.

    Uses a horizon-5 `MyMDP` without explicit state/action spaces and a
    random (2, 3) initial state. Search is bounded by NFE limits of 200
    dynamics calls and 20 intermediate- and final-reward calls each.

    Args:
        cfg: run configuration; currently unused.
    """
    # `np` is not imported at module level in this file; without this import
    # the call to `np.random.rand` below raises NameError.
    import numpy as np

    n = 5

    from .toy.my_mdp import MyMDP
    from .toy.my_mcts import MyMCTS

    mdp = MyMDP(
        state_space = None, 
        action_space = None, 
        time_horizon = n
    )

    init_state = np.random.rand(2, 3)

    logger(f"init_state: {init_state}")

    nfe_cal_dynamics_lim = 200
    nfe_cal_intermediate_reward_lim = 20
    nfe_cal_final_reward_lim = 20

    mcts = MyMCTS(
        mdp = mdp, 
        init_state = init_state, 

        nfe_cal_dynamics_lim = nfe_cal_dynamics_lim, 
        nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim, 
        nfe_cal_final_reward_lim = nfe_cal_final_reward_lim
    )

    # The returned root is not needed; results are shown via the display helpers.
    mcts.run()

    mcts.display_tree()

    mcts.display_result(display_trajectory = True)

    mcts.display_info()

    # breakpoint()

    # `test_MCTS()` done
    pass

def test_controllability_gramian(
    cfg: DictConfig    
):
    """Log controllability Gramians and their spectral decompositions.

    Locally linearizes a horizon-5 `MyMDP` around a random (2, 3) initial
    state with a constant action of 0.5, which yields `A_list` and `B_list`.
    It then computes one Gramian `W` per timestep and decomposes each with
    `k_proj_z = 4`, logging `W`, `Lambda` and `U` per timestep. Ends in an
    unconditional `breakpoint()`, so run it interactively.

    Args:
        cfg: run configuration; currently unused.
    """
    # `np` is not imported at module level in this file; without this import
    # the call to `np.random.rand` below raises NameError.
    import numpy as np

    from .toy.my_mdp import MyMDP
    from OT_MCTS.src.optimal_control.controllability_matrix import (
        cal_controllability_gramian_list, 
        cal_spectral_decomposition
    )

    n = 5

    mdp = MyMDP(
        state_space = None, 
        action_space = None, 
        time_horizon = n
    )

    init_state = np.random.rand(2, 3)

    logger(f"init_state: {init_state}")

    A_list, B_list = mdp.locally_linearize(
        init_state = init_state, 
        action_list = 0.5
    )

    logger(f"A_list: {A_list}")
    logger(f"B_list: {B_list}")

    W_list = cal_controllability_gramian_list(
        A_list = A_list, 
        B_list = B_list
    )

    k_proj_z = 4

    Lambda_list = [] 
    U_list = []

    for W in W_list:
        Lambda, U = cal_spectral_decomposition(
            W = W, 
            k_proj_z = k_proj_z
        )

        Lambda_list.append(Lambda)
        U_list.append(U)

    for i, (W, Lambda, U) in enumerate(
        zip(
            W_list, 
            Lambda_list, U_list
        )
    ):
        logger(f"[Timestep {i}]")

        logger(f"    W: {W}")
        logger(f"    Lambda: {Lambda}")
        logger(f"    U: {U}")

    # Intentional: drop into the debugger to inspect the results above.
    breakpoint()

def test_lqr(
    cfg: DictConfig    
):
    """Exercise `OptimalControlSolver` around a fixed action trajectory.

    Builds a horizon-5 `MyMDP` over the toy state space and an action space
    with eta in [0, 1]. It initializes the solver around a random (2, 3)
    initial state and the actions [0.1, 0.2, 0.3, 0.4, 0.5], then updates the
    cost matrices (`update_cost_gamma`) and the DARE `P` list
    (`update_dare_P`).

    It then runs `test_optimal_action()`, which compares the MDP's optimal
    action with the solver's estimate. The other nested diagnostics can be
    enabled by uncommenting their calls.

    Args:
        cfg: run configuration; currently unused.
    """
    # `np` is not imported at module level in this file; the nested helpers
    # below also rely on this binding through their enclosing scope.
    import numpy as np

    from util.linalg_util.definite_matrix_util import (
        is_positive_definite, 
        is_positive_semi_definite
    )

    from OT_MCTS.src.optimal_control.optimal_control_solver import OptimalControlSolver
    
    from .toy.state_space import StateSpace
    from .toy.action_space import ActionSpace
    from .toy.my_mdp import MyMDP

    state_space = StateSpace()
    action_space = ActionSpace(
        eta_low = 0.0, 
        eta_high = 1.0
    )
    n = 5

    mdp = MyMDP(
        state_space = state_space, 
        action_space = action_space, 
        time_horizon = n
    )

    oc_solver = OptimalControlSolver(
        mdp = mdp, 

        # cost Gamma
        omega_z = 0.5, 
        omega_eta = 0.01, 

        # finite difference
        finite_difference_accuracy_order = "SECOND",  # ["SECOND", "FOURTH", "SIXTH", "EIGHTH"]
        finite_difference_eps = 1e-8, 

        # force positive (semi-)definite
        force_positive_semi_definite_max_tolerance = 1e-8, 
        force_positive_definite_max_tolerance = 1e-8
    )

    init_state = np.random.rand(2, 3)

    # action_list = 0.5
    action_list = [0.1, 0.2, 0.3, 0.4, 0.5]
    # Wrap each scalar action as a 1-element array.
    action_list = [
        np.array([action]) \
            for action in action_list
    ]

    logger(f"init_state: {init_state}")
    logger(f"action_list: {action_list}")

    oc_solver.init_everything_around_a_trajectory(
        init_state = init_state, 
        action_list = action_list
    )
    oc_solver.update_cost_gamma()
    oc_solver.update_dare_P()

    def check_definite(
    ):
        """Log the cost matrices and whether each is positive (semi-)definite."""
        Gamma_zf = oc_solver.Gamma_zf
        Gamma_z_list = oc_solver.Gamma_z_list
        Gamma_eta = oc_solver.Gamma_eta

        logger(f"Gamma_zf: {Gamma_zf}")
        logger(f"Gamma_z_list: {Gamma_z_list}")
        logger(f"Gamma_eta: {Gamma_eta}")

        Gamma_zf_is_positive_semi_definite = is_positive_semi_definite(Gamma_zf)
        Gamma_z_is_positive_semi_definite_list = [
            is_positive_semi_definite(Gamma_z) \
                for Gamma_z in Gamma_z_list
        ]
        Gamma_eta_is_positive_definite = is_positive_definite(Gamma_eta)

        logger(f"Gamma_zf_is_positive_semi_definite: {Gamma_zf_is_positive_semi_definite}")
        logger(f"Gamma_z_is_positive_semi_definite_list: {Gamma_z_is_positive_semi_definite_list}")
        logger(f"Gamma_eta_is_positive_definite: {Gamma_eta_is_positive_definite}")
    
    # check_definite()

    def display_eigenvalue_list(
    ):
        """Log the eigenvalues of every cost matrix held by the solver."""
        # These names are locals of `check_definite()`, not of this scope,
        # so fetch them from the solver here (otherwise: NameError).
        Gamma_zf = oc_solver.Gamma_zf
        Gamma_z_list = oc_solver.Gamma_z_list
        Gamma_eta = oc_solver.Gamma_eta

        Gamma_zf_eigenvalue_list = np.linalg.eigvals(Gamma_zf)
        Gamma_z_eigenvalue_list_list = [
            np.linalg.eigvals(Gamma_z) \
                for Gamma_z in Gamma_z_list
        ]
        Gamma_eta_eigenvalue_list = np.linalg.eigvals(Gamma_eta)

        logger(f"Gamma_zf_eigenvalue_list: {Gamma_zf_eigenvalue_list}")
        logger(f"Gamma_z_eigenvalue_list_list: {Gamma_z_eigenvalue_list_list}")
        logger(f"Gamma_eta_eigenvalue_list: {Gamma_eta_eigenvalue_list}")

    # display_eigenvalue_list()

    def test_next_state_estimation(
    ):
        """Compare true vs. solver-estimated dynamics near the nominal state at t = 3."""
        t = 3
        
        state = oc_solver.state_list[t]
        # Perturb the nominal state so the estimate is tested off-trajectory.
        state = state + 0.1 * np.random.rand(*state.shape)

        action = action_list[t] + 0.0 * np.random.rand(1)
    
        next_state = mdp.cal_dynamics(
            state = state, 
            action = action
        )

        next_state_estimated = oc_solver.cal_estimated_dynamics(
            state = state, 
            action = action, 
            t = t
        )

        logger(f"next_state: {next_state}")
        logger(f"next_state_estimated: {next_state_estimated}")

    # test_next_state_estimation()

    def display_dare_P(
    ):
        """Log the solver's DARE `P` list."""
        P_list = oc_solver.P_list

        logger(f"P_list: {P_list}")

    # display_dare_P()

    def test_optimal_action(
    ):
        """Compare the MDP's optimal action and reward with the solver's estimate at t = 3."""
        t = 3
        
        state = oc_solver.state_list[t]
        # Perturb the nominal state so the estimate is tested off-trajectory.
        state = state + 0.1 * np.random.rand(*state.shape)

        optimal_action = mdp.cal_optimal_action(state = state)

        optimal_action_estimated = oc_solver.cal_optimal_action(
            state = state, 
            t = t
        )

        optimal_reward = mdp.cal_intermediate_reward(
            state = state, 
            action = optimal_action
        )

        optimal_reward_estimated = mdp.cal_intermediate_reward(
            state = state, 
            action = optimal_action_estimated
        )

        logger(
            f"optimal_action: {optimal_action}, optimal_reward: {optimal_reward}"
        )
        logger(
            f"optimal_action_estimated: {optimal_action_estimated}, optimal_reward_estimated: {optimal_reward_estimated}"
        )

    test_optimal_action()

    # breakpoint()

def test_batch_MCTS(
    cfg: DictConfig
):
    """Run `MyMCTS` on the toy MDP with the NumPy backend and batched NFE calls.

    Builds a horizon-5 `MyMDP` with (2, 3)-shaped states and (1,)-shaped
    actions, eta in [0, 1]. It uses a batch size of 100 for dynamics,
    intermediate-reward and final-reward evaluation.

    It then runs MCTS from a random initial state with uniform expansion and
    simulation policies. Importance sampling and online optimal-control
    updates are enabled. The result is displayed with its trajectory.

    Args:
        cfg: run configuration; currently unused.
    """
    import numpy as np

    from .toy.state_space import StateSpace
    from .toy.action_space import ActionSpace
    from .toy.my_mdp import MyMDP

    from .toy.my_mcts import MyMCTS

    ver = "numpy"

    state_shape = (2, 3)
    action_shape = (1, )

    state_space = StateSpace(
        shape = state_shape, 

        ver = ver
    )
    action_space = ActionSpace(
        eta_low = 0.0, 
        eta_high = 1.0, 

        shape = action_shape, 

        ver = ver
    )

    n = 5

    cal_dynamics_batch_size = 100
    cal_intermediate_reward_batch_size = 100
    cal_final_reward_batch_size = 100

    reward_shape = (1, )
    
    mdp = MyMDP(
        state_space = state_space, 
        action_space = action_space, 
        time_horizon = n, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size = cal_dynamics_batch_size, 
        cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
        cal_final_reward_batch_size = cal_final_reward_batch_size, 

        reward_shape = reward_shape
    )

    # Derive the initial state's shape from `state_shape` so it cannot drift
    # out of sync with the state space declared above.
    init_state = np.random.rand(*state_shape)

    def test_MDP(
    ):
        """Roll the MDP forward from `init_state` with `n` uniformly sampled actions."""
        action_list = [
            action_space.sample_uniform_element() \
                for _ in range(n)
        ]

        (
            state_list, 
            reward_list
        ) = mdp.simulate_step_list(
            state = init_state, 
            action_list = action_list, 

            verbose = True
        )

        logger(f"state_list: {state_list}")
        logger(f"reward_list: {reward_list}")

        # `test_MDP()` done
        pass

    # test_MDP()

    # ---------= [Upper Confidence Bound (UCB)] =---------
    exploration_coef = 2.0
    depth_coef = 2.0

    # ---------= [Expansion Policy] =---------
    # expansion_enable_importance_sampling = False
    expansion_enable_importance_sampling = True
    expansion_importance_sampling_J_star_scaling_factor = 0.95
    # expansion_importance_sampling_J_star_scaling_factor = 0.96
    # expansion_importance_sampling_J_star_scaling_factor = 0.97
    # expansion_importance_sampling_J_star_scaling_factor = 0.98
    # expansion_importance_sampling_J_star_scaling_factor = 0.99
    num_per_iteration_selection = 1
    # per_iteration_expansion_lim = 1
    per_iteration_expansion_lim = 2
    # per_iteration_expansion_lim = 3

    # ---------= [NFE Limit] =---------
    # nfe_cal_dynamics_lim = 1000
    # nfe_cal_dynamics_lim = 2000
    # nfe_cal_dynamics_lim = 3000
    nfe_cal_dynamics_lim = 5000
    # nfe_cal_dynamics_lim = 10000
    # Note: the intermediate reward is updated every time.
    # nfe_cal_intermediate_reward_lim = 500
    nfe_cal_intermediate_reward_lim = 1000
    # nfe_cal_intermediate_reward_lim = 1500
    # nfe_cal_intermediate_reward_lim = 2000
    # nfe_cal_intermediate_reward_lim = 2500
    # nfe_cal_intermediate_reward_lim = 3000
    # nfe_cal_intermediate_reward_lim = 5000
    # nfe_cal_intermediate_reward_lim = 6000
    # nfe_cal_intermediate_reward_lim = 7000
    # nfe_cal_intermediate_reward_lim = 8000
    # nfe_cal_intermediate_reward_lim = 9000
    # nfe_cal_intermediate_reward_lim = 10000
    # nfe_cal_final_reward_lim = 1000
    # nfe_cal_final_reward_lim = 2000
    # nfe_cal_final_reward_lim = 3000
    nfe_cal_final_reward_lim = 5000

    # ---------= [Optimal Control] =---------
    optimal_control_online_update = True
    # optimal_control_online_update = False
    # optimal_control_update_reward_threshold = 1e-8
    optimal_control_update_reward_threshold = 0.001
    optimal_control_omega_z = 0.5
    optimal_control_omega_eta = 0.01

    # ---------= [Optimal Control Beta] =---------
    optimal_control_beta_online_update = True
    # optimal_control_beta_online_update = False
    # NOTE(review): a scalar despite the `_list` name — presumably expanded
    # per timestep inside `MyMCTS`; confirm.
    optimal_control_beta_zeta_list = 10

    mcts = MyMCTS(
        mdp = mdp, 
        init_state = init_state, 

        # ---------= [Upper Confidence Bound (UCB)] =---------
        exploration_coef = exploration_coef, 
        depth_coef = depth_coef, 

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy = "uniform", 
        # expansion_action_sampling_policy = "optimal_control", 
        # expansion_action_sampling_policy = "optimal_control_beta", 
        expansion_default_action_list = None, 
        expansion_enable_importance_sampling = expansion_enable_importance_sampling, 
        expansion_importance_sampling_J_star_scaling_factor = expansion_importance_sampling_J_star_scaling_factor, 
        expansion_importance_sampling_eps = 1e-8, 
        expansion_importance_sampling_verbose = True, 
        # expansion_importance_sampling_verbose = False, 
        num_per_iteration_selection = num_per_iteration_selection, 
        per_iteration_expansion_lim = per_iteration_expansion_lim, 

        # ---------= [Simulation Policy] =---------
        simulation_action_sampling_policy = "uniform", 
        simulation_default_action_list = None, 

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim = nfe_cal_dynamics_lim, 
        nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim, 
        nfe_cal_final_reward_lim = nfe_cal_final_reward_lim, 

        # ---------= [Optimal Control] =---------
        optimal_control_online_update = optimal_control_online_update, 
        optimal_control_update_reward_threshold = optimal_control_update_reward_threshold, 
        optimal_control_omega_z = optimal_control_omega_z, 
        optimal_control_omega_eta = optimal_control_omega_eta, 
        optimal_control_finite_difference_accuracy_order = "SECOND", 
        optimal_control_finite_difference_eps = 1e-8, 
        optimal_control_force_positive_semi_definite_max_tolerance = 1e-8, 
        optimal_control_force_positive_definite_max_tolerance = 1e-8, 

        # ---------= [Optimal Control Beta] =---------
        optimal_control_beta_online_update = optimal_control_beta_online_update, 
        optimal_control_beta_zeta_list = optimal_control_beta_zeta_list, 
        optimal_control_clamp_eps = 1e-8
    )

    mcts.run()

    # mcts.display_tree()

    # mcts.display_result(display_trajectory = False)
    mcts.display_result(display_trajectory = True)

    # mcts.display_info()


def test_torch_MCTS(
    cfg: DictConfig
):
    """Run batched `MyMCTS` on the toy MDP with the torch backend.

    Builds (3, 4)-shaped states and (1,)-shaped actions (eta in [0, 1]) in
    `dtype` on `device`, using a horizon-5 `MyMDP`. It creates a batch of 2
    random initial states and an `LRUCache` with `num_gpu_resident_lim = 100`.

    It then runs MCTS with the "optimal_control_beta" expansion policy and
    uniform simulation, displaying state, action and reward results.

    Args:
        cfg: run configuration; currently unused.
    """
    import torch

    from .toy.state_space import StateSpace
    from .toy.action_space import ActionSpace
    from .toy.my_mdp import MyMDP

    from OT_MCTS.src.monte_carlo_tree_search.lru_cache import LRUCache
    from .toy.my_mcts import MyMCTS

    dtype = "float16"
    device = "cuda"

    # Resolve the string dtype once so every tensor built here agrees with
    # the dtype handed to the state/action spaces and to `MyMCTS`.
    torch_dtype = getattr(torch, dtype)

    ver = "torch"

    state_shape = (3, 4)
    action_shape = (1, )

    state_space = StateSpace(
        shape = state_shape, 

        dtype = dtype, 
        device = device, 

        ver = ver
    )
    action_space = ActionSpace(
        eta_low = 0.0, 
        eta_high = 1.0, 

        shape = action_shape, 

        dtype = dtype, 
        device = device, 

        ver = ver
    )

    n = 5

    cal_dynamics_batch_size = 100
    cal_intermediate_reward_batch_size = 100
    cal_final_reward_batch_size = 100

    reward_shape = (1, )
    
    mdp = MyMDP(
        state_space = state_space, 
        action_space = action_space, 
        time_horizon = n, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size = cal_dynamics_batch_size, 
        cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
        cal_final_reward_batch_size = cal_final_reward_batch_size, 

        reward_shape = reward_shape
    )
    
    batch_size = 2

    # Sample on the default (CPU) generator, then move to `device`.
    init_state_list = [
        torch.rand(
            size = state_shape, 
            dtype = torch_dtype
        ).to(device) \
            for _ in range(batch_size)
    ]
    
    # init_state_list.shape = (batch_size, *state_shape)
    init_state_list = torch.stack(init_state_list)
    
    def test_MDP(
    ):
        """Roll the batched MDP forward with one shared uniformly sampled action sequence."""
        # state_list.shape = (batch_size, *state_shape)
        state_list = init_state_list.clone()

        action_list = [
            action_space.sample_uniform_element() \
                for _ in range(n)
        ]
        action_list = torch.stack(action_list)

        # The same action sequence is repeated for every batch element.
        # assumes action_list_list.shape = (batch_size, n, *action_shape) — TODO confirm
        # against `ActionSpace.sample_uniform_element()`.
        action_list_list = [action_list] * batch_size
        action_list_list = torch.stack(action_list_list)
        
        (
            state_list_list, 
            reward_list_list
        ) = mdp.batch_simulate_step_list(
            state_list = state_list, 
            action_list_list = action_list_list, 

            verbose = True
        )

        # logger(f"state_list_list: {state_list_list}")
        # logger(f"reward_list_list: {reward_list_list}")

        # `test_MDP()` done
        pass

    # test_MDP()
    # return 

    # ---------= [LRU Cache] =---------
    lru_cache = LRUCache(
        num_gpu_resident_lim = 100, 

        device = device
    )

    # ---------= [Upper Confidence Bound (UCB)] =---------
    exploration_coef = 2.0
    depth_coef = 2.0

    # ---------= [Expansion Policy] =---------
    expansion_enable_importance_sampling = False
    # expansion_enable_importance_sampling = True
    # Results can be considered insensitive to
    # `expansion_importance_sampling_J_star_scaling_factor`.
    expansion_importance_sampling_J_star_scaling_factor = 0.95
    # expansion_importance_sampling_J_star_scaling_factor = 0.96
    # expansion_importance_sampling_J_star_scaling_factor = 0.97
    # expansion_importance_sampling_J_star_scaling_factor = 0.98
    # expansion_importance_sampling_J_star_scaling_factor = 0.99
    num_per_iteration_selection = 1
    per_iteration_expansion_lim = 1
    # per_iteration_expansion_lim = 2
    # per_iteration_expansion_lim = 3

    # ---------= [NFE Limit] =---------
    # nfe_cal_dynamics_lim = 1000
    # nfe_cal_dynamics_lim = 2000
    # nfe_cal_dynamics_lim = 3000
    nfe_cal_dynamics_lim = 5000
    # nfe_cal_dynamics_lim = 10000
    # nfe_cal_intermediate_reward_lim = 50
    # nfe_cal_intermediate_reward_lim = 100
    # nfe_cal_intermediate_reward_lim = 500
    # nfe_cal_intermediate_reward_lim = 1000
    # nfe_cal_intermediate_reward_lim = 1500
    # nfe_cal_intermediate_reward_lim = 2000
    # nfe_cal_intermediate_reward_lim = 2500
    # nfe_cal_intermediate_reward_lim = 3000
    # nfe_cal_intermediate_reward_lim = 5000
    # nfe_cal_intermediate_reward_lim = 6000
    # nfe_cal_intermediate_reward_lim = 7000
    nfe_cal_intermediate_reward_lim = 8000
    # nfe_cal_intermediate_reward_lim = 9000
    # nfe_cal_intermediate_reward_lim = 10000
    # nfe_cal_intermediate_reward_lim = 20000
    # nfe_cal_final_reward_lim = 1000
    # nfe_cal_final_reward_lim = 2000
    # nfe_cal_final_reward_lim = 3000
    nfe_cal_final_reward_lim = 5000
    # nfe_cal_final_reward_lim = 10000

    # ---------= [Optimal Control] =---------
    optimal_control_online_update = True
    # optimal_control_online_update = False
    # optimal_control_update_reward_threshold = 1e-8
    optimal_control_update_reward_threshold = 0.001
    optimal_control_omega_z = 0.5
    optimal_control_omega_eta = 0.01

    # ---------= [Optimal Control Beta] =---------
    optimal_control_beta_online_update = True
    # optimal_control_beta_online_update = False
    # optimal_control_beta_zeta_list = 10
    # optimal_control_beta_zeta_list = 9
    # optimal_control_beta_zeta_list = 8
    # optimal_control_beta_zeta_list = 7
    # optimal_control_beta_zeta_list = 6
    optimal_control_beta_zeta_list = 5
    # optimal_control_beta_zeta_list = 4
    # optimal_control_beta_zeta_list = 3
    # optimal_control_beta_zeta_list = 2
    # optimal_control_beta_zeta_list = 1

    mcts = MyMCTS(
        mdp = mdp, 
        init_state_list = init_state_list, 

        # ---------= [Upper Confidence Bound (UCB)] =---------
        exploration_coef = exploration_coef, 
        depth_coef = depth_coef, 

        # ---------= [Expansion Policy] =---------
        # expansion_action_sampling_policy = "uniform", 
        # expansion_action_sampling_policy = "optimal_control", 
        expansion_action_sampling_policy = "optimal_control_beta", 
        expansion_default_action_list = None, 
        expansion_enable_importance_sampling = expansion_enable_importance_sampling, 
        expansion_importance_sampling_J_star_scaling_factor = expansion_importance_sampling_J_star_scaling_factor, 
        expansion_importance_sampling_eps = 1e-8, 
        expansion_importance_sampling_verbose = True, 
        # expansion_importance_sampling_verbose = False, 
        num_per_iteration_selection = num_per_iteration_selection, 
        per_iteration_expansion_lim = per_iteration_expansion_lim, 

        # ---------= [Simulation Policy] =---------
        simulation_action_sampling_policy = "uniform", 
        simulation_default_action_list = None, 

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim = nfe_cal_dynamics_lim, 
        nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim, 
        nfe_cal_final_reward_lim = nfe_cal_final_reward_lim, 

        # ---------= [Optimal Control] =---------
        optimal_control_online_update = optimal_control_online_update, 
        optimal_control_update_reward_threshold = optimal_control_update_reward_threshold, 
        optimal_control_omega_z = optimal_control_omega_z, 
        optimal_control_omega_eta = optimal_control_omega_eta, 
        optimal_control_finite_difference_accuracy_order = "SECOND", 
        optimal_control_finite_difference_eps = 1e-8, 
        optimal_control_force_positive_semi_definite_max_tolerance = 1e-8, 
        optimal_control_force_positive_definite_max_tolerance = 1e-8, 

        # ---------= [Optimal Control Beta] =---------
        optimal_control_beta_online_update = optimal_control_beta_online_update, 
        optimal_control_beta_zeta_list = optimal_control_beta_zeta_list, 
        optimal_control_clamp_eps = 1e-8, 

        # ---------= [LRU Cache] =---------
        lru_cache = lru_cache, 

        # ---------= [Dtype] =---------
        dtype = dtype
    )

    mcts.run(
        display_result = True, 

        display_trajectory = False, 
        # display_trajectory = True, 
        display_state = True, 
        display_action = True, 
        display_reward = True
    )

    # mcts.display_info()

    # `test_torch_MCTS()` done
    pass


def test_mscoco_2014(
    cfg: DictConfig
):
    """Load the MSCOCO 2014 5K test prompts and log per-image prompt counts.

    Builds the "MSCOCO_2014_5K_Test" prompt manager without a per-image
    prompt limit and prepares it unshuffled. It logs the total number of
    prompts and the largest and smallest number of prompts attached to a
    single image, then drops into the debugger.

    Args:
        cfg: run configuration; currently unused.
    """
    from prompt_manager.util import get_prompt_manager

    # num_prompt_lim_per_img = 5
    num_prompt_lim_per_img = None

    prompt_manager_dict = {
        "prompt_manager_type": "MSCOCO_2014_5K_Test", 

        "num_prompt_lim_per_img": num_prompt_lim_per_img
    }

    prompt_manager = get_prompt_manager(
        prompt_manager_dict = prompt_manager_dict
    )

    prompt_manager.load_prompt_list()
    prompt_manager.prepare_everything(
        shuffle = False
    )

    # Kept as locals so they can be inspected at the `breakpoint()` below.
    prompt_list = prompt_manager.prompt_list
    folder_name_list = prompt_manager.folder_name_list

    num_prompt = len(prompt_list)

    logger(f"num_prompt: {num_prompt}")

    img_filename_prompt_list_tuple_list \
        = prompt_manager.img_filename_prompt_list_tuple_list

    # Count prompts per image once. The distinct loop name avoids reusing
    # `prompt_list`, which above holds the flat list of all prompts.
    num_prompt_per_img_list = [
        len(img_prompt_list) \
            for (_img_filename, img_prompt_list) \
                in img_filename_prompt_list_tuple_list
    ]

    # `default = 0` keeps an empty dataset from raising ValueError.
    max_num_prompt = max(num_prompt_per_img_list, default = 0)
    min_num_prompt = min(num_prompt_per_img_list, default = 0)

    logger(f"max_num_prompt: {max_num_prompt}")
    logger(f"min_num_prompt: {min_num_prompt}")

    breakpoint()

    # `test_mscoco_2014()` done
    pass


def test_hpsv2_reward_model(
    cfg: DictConfig
):
    """Score a handful of saved SD-Turbo samples with the HPSv2 reward model.

    For samples 0..4 of each sample folder, it reads the prompt from
    ``cfg/<idx>.yaml`` (key path ``["sample"]["prompt"]``) and the image from
    ``png/<idx>.png``. All pairs are scored in a single `cal_final_reward`
    call (float16, CUDA, reward shape (1,)). The scores are logged before
    dropping into the debugger.

    Args:
        cfg: run configuration; currently unused.
    """
    from util.yaml_util import load_yaml
    from util.image_util import load_img_path

    from reward_model.hps_v2 import HumanPreferenceScore_v2_RewardModel

    sample_root_list = [
        Path("./tmp/run_sample_scheduled/sd-turbo/no_eta_eps/3d_human_face_made"), 
        Path("./tmp/run_sample_scheduled/sd-turbo/no_eta_eps/a_3d_illustrated_chubby")
    ]
    num_sample_per_folder = 5

    # (cfg_path, png_path) for every sample, folder by folder.
    sample_path_pair_list = [
        (
            sample_root / "cfg" / f"{idx}.yaml", 
            sample_root / "png" / f"{idx}.png"
        )
        for sample_root in sample_root_list
        for idx in range(num_sample_per_folder)
    ]

    prompt_list = []
    img_pil_list = []

    for cfg_path, png_path in sample_path_pair_list:
        prompt_list.append(load_yaml(cfg_path)["sample"]["prompt"])
        img_pil_list.append(load_img_path(png_path))

    hpsv2_reward_model = HumanPreferenceScore_v2_RewardModel(
        reward_shape = (1, ), 
        reward_dtype = "float16", 
        reward_device = "cuda"
    )

    # final_reward_list.shape = (num_img, 1)
    final_reward_list = hpsv2_reward_model.cal_final_reward(
        img_pil_list = img_pil_list, 
        prompt_list = prompt_list
    )

    logger(f"final_reward_list: {final_reward_list}")

    breakpoint()

    # `test_hpsv2_reward_model()` done
    pass


def cal_avg_mpd_lpips(
    cfg: DictConfig
):
    """Log per-run and per-style averages of the MPD (LPIPS) metrics.

    Each sampling style is looked up under
    ``./tmp/run_sample_scheduled/<sd_type>``. For every run of that style it
    reads ``<run>/metric/mpd_lpips.yaml`` and logs the run's ``mpd_prompt``
    and ``mpd_all`` values. It then logs the per-style averages.

    A run whose metric file is missing is reported and left out of the
    average. Counting it as 0.0 would silently bias the average low. If no
    run of a style has a metric file, its averages are logged as ``nan``.

    Args:
        cfg: run configuration; currently unused.
    """
    from util.yaml_util import load_yaml

    # sd_type = "sd-turbo"
    sd_type = "sd"

    pipeline_root_path = Path(f"./tmp/run_sample_scheduled/{sd_type}")

    style_list = [
        "no_eta_eps", 
        "fixed_eps-fixed_eta", 
        "fixed_eps-random_eta", 

        "fixed_init_noise-fixed_eta", 

        "fixed_init_noise-fixed_eps", 

        "fixed_init_noise-fixed_eps_1"
    ]

    for style in style_list:
        if style == "no_eta_eps":
            # A single run, stored under the bare style name.
            style_name_list = [style]
        elif style != "fixed_init_noise-fixed_eps_1":
            # Three repeated runs: `<style>_0` .. `<style>_2`.
            style_name_list = [f"{style}_{i}" for i in range(3)]
        else:
            # This label aggregates runs `fixed_init_noise-fixed_eps_3` .. `_5`.
            style_name_list = [
                f"fixed_init_noise-fixed_eps_{i}" \
                    for i in range(3, 6)
            ]

        sum_mpd_prompt = 0.0
        sum_mpd_all = 0.0
        num_found = 0

        for style_name in style_name_list:
            mpd_lpips_path = pipeline_root_path / style_name / "metric" / "mpd_lpips.yaml"

            logger(f"style_name: {style_name}")

            if not mpd_lpips_path.is_file():
                logger(f"    missing metric file: {mpd_lpips_path}")
                continue

            mpd_lpips_dict = load_yaml(mpd_lpips_path)

            mpd_prompt = mpd_lpips_dict["mpd_prompt"]
            mpd_all = mpd_lpips_dict["mpd_all"]

            logger(f"    mpd_prompt: {mpd_prompt:.4f}")
            logger(f"    mpd_all: {mpd_all:.4f}")

            sum_mpd_prompt += mpd_prompt
            sum_mpd_all += mpd_all
            num_found += 1

        # Average only over runs that actually produced metrics.
        if num_found > 0:
            avg_mpd_prompt = sum_mpd_prompt / num_found
            avg_mpd_all = sum_mpd_all / num_found
        else:
            avg_mpd_prompt = float("nan")
            avg_mpd_all = float("nan")

        logger(f"style: {style}")
        logger(f"    num_found: {num_found} / {len(style_name_list)}")
        logger(f"    avg_mpd_prompt: {avg_mpd_prompt:.4f}")
        logger(f"    avg_mpd_all: {avg_mpd_all:.4f}")

        logger("\n")

    # `cal_avg_mpd_lpips()` done
    pass


def test_polar(
    cfg: DictConfig
):
    """Print the first component (eta) of every action in a saved trajectory.

    Loads one saved ``action_list`` YAML from an optimal-control MCTS run and
    prints the list formed by the first element of each stored action.

    Args:
        cfg: run configuration; currently unused.
    """
    from util.yaml_util import load_yaml

    yaml_path = "./tmp/run_optimal_control_mcts/sd/1.2_1.35/uniform/beta-4/immediate_posterior_mean/no_difference/an_english_woman_plays/action_list/1/1_4.yaml"

    stored_action_list = load_yaml(yaml_path)["action_list"]

    # Each stored action is indexable; keep only its leading element.
    eta_list = list(map(lambda stored_action: stored_action[0], stored_action_list))

    print(eta_list)

    # `test_polar()` done
    pass


def test_hpsv2_prompt(
    cfg: DictConfig
):
    """Print prompts 520..523 of the HPDv2 "concept-art" category.

    Reads ``category_path_dict`` from ``./config/dataset/hpd_v2.yaml``. If it
    contains a "concept-art" entry, the JSON prompt file at that path is
    loaded and the slice ``[520: 524]`` is printed; otherwise an empty list
    is printed.

    Args:
        cfg: run configuration; currently unused.
    """
    from util.yaml_util import load_yaml
    from util.json_util import load_json

    cfg_yaml_path = Path("./config/dataset/hpd_v2.yaml")
    category_path_dict = load_yaml(cfg_yaml_path)["category_path_dict"]

    target_category_name = "concept-art"

    prompt_list = []

    # Direct lookup: only this one category is ever read.
    if target_category_name in category_path_dict:
        target_json_path = category_path_dict[target_category_name]
        prompt_list += load_json(target_json_path)[520: 524]

    print(prompt_list)

    # `test_hpsv2_prompt()` done
    pass


def test_depth_reward(
    cfg: DictConfig
):
    """Run Depth-Anything-V2 on a reference image and sanity-check the L1 metric.

    Estimates a depth map for ``./tmp/ref_img/an_anime_girl.png`` and saves
    it as ``depth.png`` in the same folder. It then converts the map to a
    tensor twice and logs the L1 (mean absolute error) between the two
    tensors. Both tensors come from the same depth map, so the logged MAE
    should be 0; this only checks that the metric path runs end to end.

    Args:
        cfg: run configuration; currently unused.

    Raises:
        ValueError: if the two depth tensors differ in shape.
    """
    from transformers import pipeline
    import numpy as np  # only needed by the commented-out normalization below
    import torch.nn.functional as F
    from torchvision import transforms

    depth_anything_v2_pipeline_root_path = "/mnt/d/hytidel/model/depth-anything/Depth-Anything-V2-Small-hf"
    ref_img_root_path = Path("./tmp/ref_img")

    img_filename = "an_anime_girl.png"
    img_path = ref_img_root_path / img_filename

    pipe = pipeline(
        task = "depth-estimation", 
        model = depth_anything_v2_pipeline_root_path
    )

    img_pil = load_img_path(img_path)

    depth_pil = pipe(img_pil)["depth"]

    # norm to [0, 1]
    # depth_np = np.array(depth_pil) / 255.0

    save_pil_as_png(
        depth_pil, 
        png_root_path = ref_img_root_path, 
        png_filename = "depth.png"
    )

    tsfm = transforms.ToTensor()

    # depth_tensor_1.shape = (1, height, width) for a single-channel depth image
    depth_tensor_1 = tsfm(depth_pil)
    depth_tensor_2 = tsfm(depth_pil)

    if depth_tensor_1.shape != depth_tensor_2.shape:
        raise ValueError(
            f"Depth maps must be provided with the same size, got "
            f"{tuple(depth_tensor_1.shape)} and {tuple(depth_tensor_2.shape)}."
        )

    mae_loss = F.l1_loss(depth_tensor_1, depth_tensor_2) \
        .item()

    # Report the metric; previously the value was computed and discarded.
    logger(f"mae_loss: {mae_loss}")
    
    # `test_depth_reward()` done
    pass


def test_mpd(
    cfg: DictConfig
):
    """Print the average of per-folder best MPD (LPIPS) values over a group of runs.

    For every run in ``setting_name_list`` (under ``setting_root_path``) this
    reads ``<run>/_metric/mpd_lpips.yaml``, a mapping from folder name to MPD
    value. For each folder the maximum MPD across the runs is kept. The mean of
    those per-folder maxima is printed. The active experiment is chosen by
    (un)commenting one ``setting_root_path`` / ``setting_name_list`` pair below.
    ``cfg`` is unused.

    Raises:
        ValueError: if the selected runs contain no per-folder MPD entries.
            Previously this case failed with a bare ``ZeroDivisionError``.
    """
    from util.yaml_util import load_yaml

    # ---------= [2-step SD-Turbo] =---------
    # setting_root_path = Path("./tmp/run_sample_scheduled/sd-turbo/2")

    # a.0)
    # setting_name_list = [
    #     "no_eta_eps_2025-04-24_14-57-24"
    # ]

    # a.1)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_0.5_2025-04-25_09-29-27"
    # ]

    # a.2)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_1.0_2025-04-25_09-41-30"
    # ]

    # b.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-40-16", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-44-55", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-49-40"
    # ]

    # b.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_20-40-21", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_20-45-09", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_20-50-03"
    # ]

    # c.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.0_2025-04-25_10-41-00"
    # ]

    # c.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.5_2025-04-25_10-52-44"
    # ]

    # ---------= [15-step SD v1.4] =---------
    # setting_root_path = Path("./tmp/run_sample_scheduled/sd_v1_4/15")

    # a.0)
    # setting_name_list = [
    #     "no_eta_eps_2025-04-25_09-26-20"
    # ]

    # a.1)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_0.5_2025-04-25_10-13-52"
    # ]

    # a.2)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_1.0_2025-04-25_09-38-09"
    # ]

    # b.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-40-51", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-59-51", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_21-17-55"
    # ]

    # b.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_20-40-55", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_21-00-17", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_21-18-32"
    # ]

    # c.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.0_2025-04-26_11-53-53"
    # ]

    # c.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.5_2025-04-25_12-34-03"
    # ]

    # ---------= [20-step SD v1.4] =---------
    # setting_root_path = Path("./tmp/run_sample_scheduled/sd_v1_4/20")

    # a.0)
    # setting_name_list = [
    #     "no_eta_eps_2025-04-25_09-26-23"
    # ]

    # a.1)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_0.5_2025-04-25_10-25-45"
    # ]

    # a.2)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_1.0_2025-04-25_09-41-08"
    # ]

    # b.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_21-38-11", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_22-03-37", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_22-27-57"
    # ]

    # b.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_22-42-17", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_23-08-03", 
    #     "fixed_init_noise-fixed_eta_1.0_2025-05-23_23-32-58"
    # ]

    # c.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.0_2025-04-26_11-53-56"
    # ]

    # c.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.5_2025-04-25_13-20-43"
    # ]

    # ---------= [25-step SD v1.4] =---------
    setting_root_path = Path("./tmp/run_sample_scheduled/sd_v1_4/25")

    # a.0)
    # setting_name_list = [
    #     "no_eta_eps_2025-04-25_09-26-26"
    # ]

    # a.1)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_0.5_2025-04-25_10-36-22"
    # ]

    # a.2)
    # setting_name_list = [
    #     "fixed_eps-fixed_eta_1.0_2025-04-25_09-43-55"
    # ]

    # b.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_20-56-05", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_21-27-16", 
    #     "fixed_init_noise-fixed_eta_0.5_2025-05-23_21-57-31"
    # ]

    # b.1)
    setting_name_list = [
        "fixed_init_noise-fixed_eta_1.0_2025-05-23_22-41-58", 
        "fixed_init_noise-fixed_eta_1.0_2025-05-23_23-11-36", 
        "fixed_init_noise-fixed_eta_1.0_2025-05-23_23-41-32"
    ]

    # c.0)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.0_2025-04-26_11-53-59"
    # ]

    # c.1)
    # setting_name_list = [
    #     "fixed_init_noise-fixed_eps_0.5_2025-04-25_14-02-53"
    # ]

    # folder name -> largest MPD seen for that folder across the selected runs
    mpd_dict = {}

    for setting_name in setting_name_list:
        res_yaml_path = setting_root_path / setting_name / "_metric" / "mpd_lpips.yaml"
        res_dict = load_yaml(res_yaml_path)

        for folder_name, mpd in res_dict.items():
            # `mpd_prompt` is not a folder entry; presumably a precomputed
            # summary stored alongside the per-folder values. Skip it.
            if folder_name == "mpd_prompt":
                continue

            if folder_name in mpd_dict:
                mpd_dict[folder_name] = max(mpd_dict[folder_name], mpd)
            else:
                mpd_dict[folder_name] = mpd

    if not mpd_dict:
        raise ValueError(
            f"No per-folder MPD entries found in `_metric/mpd_lpips.yaml` under "
            f"{setting_root_path} for settings {setting_name_list}. "
        )

    avg_mpd = sum(mpd_dict.values()) / len(mpd_dict)

    print(f"avg_mpd: {avg_mpd:.4f}")


def test_scaling_line_chart(
    cfg: DictConfig
):
    """Print the averaged reward-scaling curve of several MCTS runs and its turning points.

    The runs listed below (all under one shared result directory) are averaged
    by ``get_reward_scaling_list_from_result``. ``get_turning_point_tuple_list``
    then extracts the turning points of that curve. Both results are printed.
    ``cfg`` is unused.
    """
    from .util.result_util import (
        get_reward_scaling_list_from_result, 
        get_turning_point_tuple_list
    )

    # Alternative result roots are kept commented out for quick switching.
    result_root = Path(
        # "./tmp/run_optimal_control_mcts/sd_v1_4/HumanPreferenceDataset_v2/hps_v2/latent_reward/max_reward_max/pseudo_latent_as_final/15/2.0/12/beta-soft-9/immediate_posterior_mean", 
        "./tmp/run_optimal_control_mcts/sd_v1_4/HumanPreferenceDataset_v2/hps_v2/latent_reward/max_reward_max/pseudo_latent_as_final-5/15/2.0/12/beta-soft-3/immediate_posterior_mean"
    )

    run_names = [
        # NFE 500
        # "max_max_2025-05-22_23-08-43", 
        # "max_max_2025-05-23_03-40-09", 
        # "max_max_2025-05-23_07-25-44"

        # NFE 999
        # "max_max_2025-05-23_17-52-22", 
        # "max_max_2025-05-23_17-52-21", 
        # "max_max_2025-05-23_17-52-19"
        "max_max_999_2025-06-06_08-13-36", 
        "max_max_999_2025-06-06_08-13-38", 
        "max_max_999_2025-06-06_08-13-41"
    ]

    avg_reward_scaling_list = get_reward_scaling_list_from_result(
        setting_root_path = result_root, 
        setting_name_list = run_names, 
        num_sample = 2, 
        use_merged_reward = True
    )

    turning_point_tuple_list = get_turning_point_tuple_list(
        reward_scaling_list = avg_reward_scaling_list
    )

    for title, value in (
        ("avg_reward_scaling_list", avg_reward_scaling_list), 
        ("turning_point_tuple_list", turning_point_tuple_list)
    ):
        print(f"{title}: {value}")


def plot_scaling_line_chart(
    cfg: DictConfig
):
    """Plot HPS v2 reward against dynamics NFE for Ours / DDPM / DDIM and save the figure.

    Each series is a list of ``(nfe, reward)`` turning points, hard-coded below
    from earlier runs. The chart is saved to
    ``./tmp/scaling_line_chart/ours_ddpm_ddim.png``. ``cfg`` is unused.
    """
    import matplotlib.pyplot as plt

    from util.plot_util import save_plot

    # Series are plain list literals; they were previously `eval`-ed from f-strings,
    # which added nothing but an unnecessary code-execution path.

    # beta_soft_9_500_turning_point_tuple_list = [(0, 0.0), (1, 22.1703125), (2, 28.59453125), (16, 29.0375), (30, 29.1640625), (43, 29.18671875), (44, 29.21640625), (55, 29.23828125), (57, 29.2765625), (58, 29.28046875), (68, 29.28984375), (69, 29.30078125), (70, 29.3046875), (77, 29.33984375), (79, 29.34453125), (80, 29.346875), (81, 29.35078125), (82, 29.3546875), (83, 29.35859375), (87, 29.3640625), (100, 29.37265625), (104, 29.3921875), (106, 29.396875), (109, 29.40390625), (113, 29.40859375), (114, 29.41015625), (118, 29.41640625), (122, 29.42109375), (126, 29.42265625), (130, 29.42890625), (131, 29.4328125), (132, 29.4421875), (137, 29.44609375), (147, 29.496875), (151, 29.5125), (152, 29.51875), (158, 29.534375), (164, 29.53515625), (179, 29.5375), (184, 29.53984375), (188, 29.54375), (203, 29.5609375), (208, 29.56328125), (209, 29.57421875), (211, 29.59140625), (219, 29.59375), (224, 29.59765625), (260, 29.6046875), (261, 29.6078125), (275, 29.609375), (304, 29.6125), (305, 29.62734375), (316, 29.62890625), (325, 29.63203125), (345, 29.6359375), (364, 29.6375), (385, 29.65234375), (386, 29.65546875), (394, 29.6671875), (397, 29.6703125), (402, 29.67421875), (408, 29.67578125), (999, 29.67578125)]

    beta_soft_9_turning_point_tuple_list = [(0, 0.0), (1, 22.1703125), (2, 28.59453125), (16, 29.0375), (30, 29.1640625), (43, 29.18671875), (44, 29.21640625), (55, 29.23828125), (56, 29.240625), (57, 29.27890625), (58, 29.2828125), (68, 29.2921875), (69, 29.303125), (70, 29.30703125), (77, 29.3421875), (79, 29.34765625), (80, 29.35), (81, 29.35390625), (82, 29.3578125), (83, 29.36171875), (87, 29.371875), (100, 29.38046875), (104, 29.4), (106, 29.4046875), (109, 29.4109375), (113, 29.415625), (114, 29.4171875), (118, 29.4234375), (122, 29.4265625), (126, 29.428125), (130, 29.434375), (131, 29.43828125), (132, 29.44765625), (137, 29.4515625), (147, 29.50234375), (151, 29.51796875), (152, 29.52421875), (158, 29.53984375), (164, 29.540625), (179, 29.54375), (184, 29.54609375), (188, 29.55), (195, 29.5546875), (196, 29.55703125), (199, 29.5578125), (203, 29.575), (208, 29.5765625), (209, 29.5875), (211, 29.6046875), (219, 29.609375), (239, 29.61328125), (260, 29.6203125), (261, 29.62109375), (275, 29.62265625), (293, 29.62578125), (304, 29.628125), (305, 29.64296875), (306, 29.64921875), (316, 29.65078125), (325, 29.65703125), (345, 29.659375), (364, 29.665625), (374, 29.6703125), (397, 29.67109375), (413, 29.67265625), (439, 29.675), (461, 29.69375), (467, 29.69453125), (484, 29.6953125), (508, 29.69765625), (519, 29.70546875), (557, 29.70859375), (562, 29.71015625), (567, 29.7109375), (584, 29.721875), (588, 29.72578125), (616, 29.72890625), (632, 29.73046875), (670, 29.73203125), (695, 29.734375), (807, 29.73515625), (999, 29.73515625)]

    ddpm_turning_point_tuple_list = [(0, 0.0), (15, 28.0461), (20, 28.1703), (25, 28.3680), (30, 28.7742), (50, 28.9844), (100, 28.7711), (150, 28.6812), (200, 28.9469), (250, 29.0656), (300, 28.7219), (350, 28.1695), (400, 28.5727), (450, 28.7125), (500, 29.3750), (550, 25.4492), (600, 26.8977), (650, 27.7992), (700, 27.9320), (750, 28.1438), (800, 28.4672), (850, 28.8313), (900, 28.9578), (950, 28.9867), (999, 29.0266)]

    ddim_turning_point_tuple_list = [(0, 0.0), (15, 27.6383), (20, 27.7133), (25, 28.0758), (30, 28.1234), (50, 28.1695), (100, 28.1094), (150, 28.0031), (200, 28.1687), (250, 28.1461), (300, 28.0883), (350, 27.4187), (400, 27.6812), (450, 28.0859), (500, 27.9586), (550, 23.6062), (600, 24.5219), (650, 26.6187), (700, 27.3766), (750, 27.0867), (800, 27.6711), (850, 27.5352), (900, 27.5352), (950, 28.1164), (999, 27.9312)]

    # Set to 'o' to draw a marker at every turning point.
    marker = None

    # (turning points, legend label, line colour), drawn in this order.
    series_list = [
        (beta_soft_9_turning_point_tuple_list, "Ours", "#fda99f"),  # light red
        (ddpm_turning_point_tuple_list, "DDPM", "#a0acfe"),  # light blue
        (ddim_turning_point_tuple_list, "DDIM", "#9be3aa"),  # light green
    ]

    figsize = (14, 8)
    fig, ax = plt.subplots(figsize = figsize)

    for turning_point_tuple_list, label, color in series_list:
        x_list = [x for (x, _) in turning_point_tuple_list]
        y_list = [y for (_, y) in turning_point_tuple_list]

        ax.plot(
            x_list, y_list, 

            marker = marker, 
            label = label, 
            color = color, 

            linewidth = 3
        )

    ax.set_xlabel(
        "NFE Dynamics", 
        fontsize = 20
    )
    ax.set_ylabel(
        "HPS v2", 
        fontsize = 20
    )

    # The (0, 0.0) starting points fall below this range and are clipped on purpose.
    ax.set_ylim(22, 30)

    ax.legend(
        loc = "lower right", 
        fontsize = 20
    )

    ax.grid(True)

    fig.tight_layout()

    save_plot(
        fig = fig, 

        save_plot_root_path = "./tmp/scaling_line_chart", 
        save_plot_filename = "ours_ddpm_ddim.png"
    )


def plot_polar(
    cfg: DictConfig
):
    """Draw per-step eta schedules on a half-polar chart and save the figure.

    Each schedule has ``n`` eta values (those used here lie in [0, 1]).
    - Step ``i`` is drawn at angle ``eta_i * pi`` (eta 0 -> 0, eta 1 -> pi).
    - Its radius grows linearly from 0 to 1.
    - An extra point at radius 0 reuses the first step's angle.

    DDPM (eta = 1 everywhere), deterministic DDIM (eta = 0 everywhere) and two
    of our schedules are compared. The figure is saved to
    ``./tmp/polar/ours_ddpm_ddim.png``. ``cfg`` is unused.
    """
    import matplotlib.pyplot as plt

    from util.plot_util import save_plot

    import numpy as np

    n = 15

    ddpm_eta_list = [1.0] * n

    deterministic_ddim_eta_list = [0.0] * n

    # 1_1
    ours_eta_list_1_1 = [
        0.38064852356910706, 
        0.6287875771522522, 
        0.5687167644500732,
        0.536539614200592,
        0.2381202131509781,
        0.43918555974960327,
        0.48655834794044495,
        0.11110340803861618,
        0.6663500666618347,
        0.47652822732925415,
        0.46543702483177185,
        0.40869998931884766,
        0.8437999486923218,
        0.22041979432106018,
        0.7386772632598877
    ]

    # 1_9
    ours_eta_list_1_9 = [
        0.38064852356910706, 
        0.3883456885814667, 
        0.27751514315605164, 
        0.9556837677955627, 
        0.1273747831583023, 
        0.32568246126174927, 
        0.8463485240936279, 
        0.10854637622833252, 
        0.9469668865203857, 
        0.29507192969322205, 
        0.19123908877372742, 
        0.09541206061840057, 
        0.14919282495975494, 
        0.1985294222831726, 
        0.9189828634262085
    ]

    # n + 1 radii: the extra leading 0 is the shared origin of every curve.
    r_list = np.linspace(
        0, 1, 
        n + 1
    )

    figsize = (10, 8)
    fig = plt.figure(
        figsize = figsize
    )
    ax = fig.add_subplot(
        111, 
        projection = "polar"
    )

    # Hide the semicircle frame
    # ax.spines['polar'].set_visible(False)

    # Counter-clockwise [0, \pi]
    ax.set_thetamin(180)
    ax.set_thetamax(0)

    ax.set_rlabel_position(180)

    # Hide the radial tick labels
    ax.set_rticks([])
    ax.set_yticklabels([])

    x_tick_rad_str_list = [
        r"0", 
        r"$\pi$ / 6", 
        r"$\pi$ / 3", 
        r"$\pi$ / 2", 
        r"2 $\pi$ / 3", 
        r"5 $\pi$ / 6", 
        r"$\pi$", 
    ]

    # Pin the theta ticks to the multiples of pi/6 named by the labels. Without a
    # fixed tick set, `set_xticklabels` attaches the labels to whatever ticks the
    # default theta locator picked, so they could land on the wrong angles.
    ax.set_xticks(
        np.linspace(
            0.0, np.pi, 
            len(x_tick_rad_str_list)
        )
    )
    ax.set_xticklabels(
        x_tick_rad_str_list, 

        fontsize = 15
    )

    # One radial grid circle per denoising step.
    ax.set_rgrids(
        np.arange(
            0.0, 1.0, 
            1.0 / n
        )
    )

    def plot_with_eta_list(
        eta_list: List[float], 

        marker: str, 
        label: str, 

        color: str, 
    ):
        """Plot one eta schedule as a polar curve (angle = eta * pi) on `ax`."""
        theta_list = [
            eta * np.pi \
                for eta in eta_list
        ]

        # Prepend the first angle so the curve starts at the origin (r = 0).
        theta_list = [theta_list[0]] + theta_list

        ax.plot(
            theta_list, r_list, 

            marker = marker, 
            color = color, 
            label = label, 

            linewidth = 2, 
            markersize = 5
        )

        # Radius 1
        ax.set_rmax(1)

    # DDPM
    plot_with_eta_list(
        eta_list = ddpm_eta_list, 

        marker = 'o', 
        color = "#a0acfe",  # light blue
        label = "DDPM"
    )

    # det. DDIM
    plot_with_eta_list(
        eta_list = deterministic_ddim_eta_list, 

        marker = 'o', 
        color = "#9be3aa",  # light green
        label = "DDIM"
    )

    # Ours 1_1
    plot_with_eta_list(
        eta_list = ours_eta_list_1_1, 

        marker = 'o', 
        color = "#fda99f",  # light red
        label = "Ours-1"
    )

    # Ours 1_9
    plot_with_eta_list(
        eta_list = ours_eta_list_1_9, 

        marker = 'o', 
        color = "#f2d7a6",  # light beige
        label = "Ours-2"
    )

    ax.legend(
        loc = "upper right", 

        fontsize = 20, 

        ncol = 2
    )

    fig.tight_layout()

    save_plot(
        fig = fig, 

        save_plot_root_path = "./tmp/polar", 
        save_plot_filename = "ours_ddpm_ddim.png"
    )


def test_image_reward(
    cfg: DictConfig
):
    """Score five sample images against one prompt with a local ImageReward checkpoint.

    Prints the batch ranking/rewards from ``inference_rank``, then the
    per-image ``score`` for each image, all under ``torch.no_grad()``.
    ``cfg`` is unused.
    """
    import ImageReward
    import torch

    png_dir = Path("./tmp/run_sample_scheduled/sd_v1_4/15/fixed_eps-fixed_eta_0.5_2025-04-25_10-13-52/a_3d_illustrated_chubby/png/")

    num_img = 5
    images = [load_img_path(png_dir / f"{idx}.png") for idx in range(num_img)]

    prompt = "A 3D illustrated chubby room with studio lighting."

    reward_model = ImageReward.load("/mnt/d/hytidel/model/THUDM/ImageReward/ImageReward.pt")

    with torch.no_grad():
        ranking_list, reward_list = reward_model.inference_rank(prompt, images)

        print(f"ranking_list: {ranking_list}")
        print(f"reward_list: {reward_list}")

        for img_idx, image in enumerate(images):
            score = reward_model.score(prompt, image)
            print(f"[Img {img_idx}] score: {score}")


def test_pick_score(
    cfg: DictConfig
):
    """Compute PickScore for five sample images against one prompt and print it.

    Image and text embeddings are L2-normalised. The score is
    ``logit_scale.exp() * (image_emb @ text_emb.T)``, giving one row per image.
    Runs on CUDA. ``cfg`` is unused.
    """
    from transformers import AutoProcessor, AutoModel
    import torch

    device = "cuda"

    png_dir = Path("./tmp/run_sample_scheduled/sd_v1_4/15/fixed_eps-fixed_eta_0.5_2025-04-25_10-13-52/a_3d_illustrated_chubby/png/")

    num_img = 5
    images = [load_img_path(png_dir / f"{idx}.png") for idx in range(num_img)]

    prompt = "A 3D illustrated chubby room with studio lighting."

    # The processor comes from the base CLIP checkpoint, the scoring model from PickScore.
    # processor_dir = "/mnt/d/hytidel/model/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    processor_dir = "/mnt/d/hytidel/model/laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    model_dir = "/mnt/d/hytidel/model/yuvalkirstain/PickScore_v1"

    processor = AutoProcessor.from_pretrained(processor_dir)
    model = AutoModel.from_pretrained(model_dir).eval().to(device)

    # Same preprocessing options for both modalities.
    processor_kwargs = dict(
        padding = True, 
        truncation = True, 
        max_length = 77, 
        return_tensors = "pt"
    )
    img_inputs = processor(images = images, **processor_kwargs).to(device)
    text_inputs = processor(text = prompt, **processor_kwargs).to(device)

    def _unit(emb):
        # Normalise each embedding to unit L2 norm along the feature axis.
        return emb / torch.norm(emb, dim = -1, keepdim = True)

    with torch.no_grad():
        img_emb = _unit(model.get_image_features(**img_inputs))
        text_emb = _unit(model.get_text_features(**text_inputs))

        score_list = model.logit_scale.exp() * (img_emb @ text_emb.T)

    print(f"score_list: {score_list}")


def get_50_prompts_str(
    cfg: DictConfig, 
    folder_root_path: Optional[Union[str, Path]] = None
):
    """Print the per-prompt folder names of one run as a bracketed, comma-joined list.

    Args:
        cfg: Unused; kept for the common ``(cfg)`` task signature.
        folder_root_path: Run directory whose entries are listed. Defaults to
            the hard-coded sparse-average SD v1.4 run below.

    The ``_metric`` entry (where metric files such as ``mpd_lpips.yaml`` live)
    is skipped, matching ``test_mscoco``, so only prompt folders are listed.
    Names appear in ``Path.iterdir`` order, which is filesystem-dependent.
    """
    if folder_root_path is None:
        # folder_root_path = "./tmp/run_optimal_control_mcts/sd_v1_4/HumanPreferenceDataset_v2/hps_v2/latent_reward/max_reward_max/pseudo_latent_as_final-5/15/2.0/12/beta-soft-3/immediate_posterior_mean/max_max_999_2025-06-06_08-13-36"
        folder_root_path = "./tmp/run_optimal_control_mcts/sd_v1_4/HumanPreferenceDataset_v2/hps_v2/disabled/sparse_reward_average/final_latent_only/15/2.0/12/uniform/immediate_posterior_mean/sparse_average_2025-06-02_07-29-16"
    folder_root_path = Path(folder_root_path)

    # Use `.name`, not `.stem`: `.stem` drops everything after the last '.',
    # which truncates folder names that contain a period.
    folder_name_list = [
        folder_path.name \
            for folder_path in folder_root_path.iterdir() \
                if folder_path.name != "_metric"
    ]

    # To quote each name instead: [f"\'{folder_name}\'" for folder_name in folder_name_list]

    folder_name_list_str = ",".join(folder_name_list)

    print(f"folder_name_list_str: ")
    print(f"[{folder_name_list_str}]")


def test_csv(
    cfg: DictConfig
):
    """Step through the MSCOCO 2014 5k test CSV row by row in the debugger.

    For every row this reads the ``filename`` column and parses the ``raw``
    column into ``prompt_list``, then stops at ``breakpoint()`` so both values
    can be inspected. ``cfg`` is unused.
    """
    import ast

    from util.csv_util import (
        load_csv, 
        get_element
    )

    csv_path = "/mnt/d/hytidel/dataset/nlphuji/mscoco_2014_5k_test_image_text_retrieval/test_5k_mscoco_2014.csv"

    data_frame = load_csv(csv_path)

    num_row, _ = data_frame.shape

    for row_idx in range(num_row):
        img_filename = get_element(
            data_frame = data_frame, 

            row_idx = row_idx, 
            col_name = "filename"
        )

        raw_prompt_str = get_element(
            data_frame = data_frame, 

            row_idx = row_idx, 
            col_name = "raw"
        )
        # `raw` holds a Python-literal list (presumably of caption strings) stored
        # as text. `ast.literal_eval` parses literals only; `eval` on file
        # contents would execute arbitrary code.
        prompt_list = ast.literal_eval(raw_prompt_str)

        # Deliberate: inspect `img_filename` / `prompt_list` for each row.
        breakpoint()


def test_mscoco(
    cfg: DictConfig
):
    """Print the prompt-folder names of one run as two comma-joined strings.

    Every entry of the run directory except ``_metric`` is listed. Names
    containing a comma are wrapped in single quotes so they stay one item. The
    list is printed joined by ``","`` and by ``", "``. ``cfg`` is unused.
    """
    # Previous experiment (prompt-manager loading), kept for reference:
    # from prompt_manager.util import get_prompt_manager
    # prompt_manager_dict = {
    #     "prompt_manager_type": "MSCOCO_2014_5K_Test",
    #     "cfg_yaml_path": None, 
    #     "num_prompt_lim_per_img": None
    # }
    # prompt_manager = get_prompt_manager(
    #     prompt_manager_dict = prompt_manager_dict
    # )
    # prompt_manager.load_prompt_list()
    # breakpoint()

    # run_dir = Path("./tmp/run_sample_scheduled/sdxl/DrawBench/30/ddpm_3072")
    run_dir = Path("./tmp/run_optimal_control_mcts/sd_v1_4/HumanPreferenceDataset_v2/hps_v2/latent_reward/max_reward_max/pseudo_latent_as_final-5/15/2.0/12/beta-soft-3/immediate_posterior_mean/max_max_999_3072")

    def _quote_if_needed(name):
        # Quote names with commas so joining by ',' stays unambiguous.
        return f"'{name}'" if ',' in name else name

    names = [
        _quote_if_needed(entry.name)
        for entry in run_dir.iterdir()
        if entry.name != "_metric"
    ]

    compact_str = ",".join(names)
    spaced_str = ", ".join(names)

    print(f"folder_name_list_str: {compact_str}")
    print()
    print(f"folder_name_list_str_1: {spaced_str}")


def test_pixart_alpha(
    cfg: DictConfig
):
    """Generate a few PixArt-alpha images with a chosen scheduler and save them as PNGs.

    - Loads the pipeline at ``pipeline_path`` in ``pipeline_torch_dtype``.
    - Swaps in ``scheduler_type`` and enables model CPU offload.
    - Samples 1024x1024 images from explicitly seeded initial latents.
    - Writes them to ``./tmp/test/test_<pipeline type name>/<scheduler_type>/<idx>.png``.

    ``cfg`` is unused.
    """
    from util.pipeline_util import (
        load_pipeline, load_scheduler, 
        get_pipeline_category_and_type
    )
    from util.image_util import save_pil_as_png
    from util.torch_util import get_latent

    device = get_global_variable("device")

    # pipeline_type = "StableDiffusion3Pipeline"
    pipeline_type = "PixArtAlphaPipeline"
    # pipeline_path = "/home/skl/hytidel/model/stabilityai/stable-diffusion-3.5-medium"
    pipeline_path = "/home/skl/hytidel/model/PixArt-XL-2-1024-MS"
    pipeline_torch_dtype = "float32"
    pipeline_variant = None

    scheduler_type = "DPMSolverMultistepScheduler"
    # scheduler_type = "DDIMScheduler"

    (
        _pipeline_category_name, 
        pipeline_type_name
    ) = get_pipeline_category_and_type(
        pipeline_path = pipeline_path
    )

    pipeline = load_pipeline(
        pipeline_type = pipeline_type, 
        pipeline_path = pipeline_path, 
        torch_dtype = pipeline_torch_dtype, 
        variant = pipeline_variant
    )

    pipeline.scheduler = load_scheduler(
        pipeline = pipeline, 
        scheduler_type = scheduler_type
    )

    # No `pipeline.to(device)` here. `enable_model_cpu_offload()` keeps the
    # sub-models on CPU and moves each to the GPU only while it runs. Moving the
    # whole pipeline to the GPU first just adds a full round-trip, which the
    # diffusers docs advise against.
    pipeline.enable_model_cpu_offload()

    # Seeds repeat on purpose: identical seeds give identical initial latents.
    init_latent_seed_list = [
        0, 1, 0
    ]

    init_latent_list = [
        get_latent(
            shape = (4, 128, 128), 
            # shape = (16, 128, 128), 

            seed = seed, 

            device = device, 

            dtype = pipeline_torch_dtype
        ) \
            for seed in init_latent_seed_list
    ]
    init_latent_list = torch.stack(init_latent_list)

    prompt_list = [
        "A man on a boat crossing a hellish body of water with soul-like creatures swimming around.", 
        "A man on a boat crossing a hellish body of water with soul-like creatures swimming around.", 
        "A monster coming out of a cellphone screen.", 
        # "A monster coming out of a cellphone screen.", 
        # "An eye level counter-view shows blue tile, a faucet, dish scrubbers, bowls, a squirt bottle and similar kitchen items. ", 
        # "An eye level counter-view shows blue tile, a faucet, dish scrubbers, bowls, a squirt bottle and similar kitchen items. "
    ]
    negative_prompt_list = [
        "low quality, blurry, ugly, oversaturated"
    ] * len(prompt_list)

    # num_inference_step = 5
    num_inference_step = 15
    guidance_scale = 4.5

    max_sequence_length = 120
    # max_sequence_length = 256
    # max_sequence_length = 512

    img_pil_list = pipeline(
        height = 1024, 
        width = 1024, 

        latents = init_latent_list, 

        prompt = prompt_list, 
        negative_prompt = negative_prompt_list, 

        num_inference_steps = num_inference_step, 

        guidance_scale = guidance_scale, 

        max_sequence_length = max_sequence_length, 

        clean_caption = False
    ).images

    # save_png_root_path = Path("./tmp/test_sd_v3") / f"{max_sequence_length}"
    save_png_root_path = Path(f"./tmp/test/test_{pipeline_type_name}") / scheduler_type

    for img_pil_idx, img_pil in enumerate(img_pil_list):
        save_pil_as_png(
            pil = img_pil, 

            png_root_path = save_png_root_path, 
            png_filename = f"{img_pil_idx}.png"
        )


def add_optimized_prompt(
    cfg: DictConfig
):
    """Stub with no implementation yet.

    Does nothing and returns ``None``; ``cfg`` is unused.
    """
    return None


# discarded
# def test_hydit(
#     cfg: DictConfig
# ):
#     import torch

#     from util.torch_util import get_latent
#     from task.util.seed_list_util import prepare_seed_list
#     from util.pipeline_util import (
#         load_pipeline, 
#         load_scheduler, 
#         get_inference_step_minus_one
#     )

#     from my_diffusers.scheduling_ddim import register_scheduling_ddim
#     # from my_diffusers.pipeline_hunyuandit import register_pipeline_hunyuandit
#     from my_diffusers.pipeline_stable_diffusion_xl import register_pipeline_stable_diffusion_xl

#     dtype = "float32"
#     device = "cuda"

#     pipeline = load_pipeline(
#         # pipeline_type = "HunyuanDiTPipeline", 
#         # pipeline_path = "/mnt/d/hytidel/model/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", 

#         pipeline_type = "StableDiffusionXLPipeline", 
#         pipeline_path = "/mnt/d/hytidel/model/stabilityai/stable-diffusion-xl-base-1.0", 

#         torch_dtype = dtype, 
#         variant = None
#     )

#     scheduler_type = "DDIMScheduler"
#     pipeline.scheduler = load_scheduler(
#         scheduler_type = scheduler_type, 
#         pipeline = pipeline
#     )

#     register_scheduling_ddim(
#         scheduler = pipeline.scheduler
#     )

#     # register_pipeline_hunyuandit(
#     register_pipeline_stable_diffusion_xl(
#         pipeline = pipeline
#     )

#     pipeline.to(device)

#     pipeline.enable_model_cpu_offload()

#     prompt = "a blue dog and a white cat"
#     nagative_prompt = ""
#     height, width = 1024, 1024
#     num_inference_step = 5
#     guidance_scale = 4.5

#     eta_list = torch.tensor(
#         [
#             [1.0] * num_inference_step
#         ], 

#         dtype = torch.float32, 
#         device = device
#     )

#     eps_seed_list = 3072
#     eps_seed_list = prepare_seed_list(
#         seed_list = eps_seed_list, 
#         target_length = num_inference_step, 

#         auto_inrement = True
#     )

#     latent_height, latent_width = 128, 128
#     eps_list = [
#         get_latent(
#             shape = (4, latent_height, latent_width), 
            
#             seed = seed, 

#             device = device, 

#             dtype = dtype
#         ) \
#             for seed in eps_seed_list
#     ]

#     save_inference_process_dict = {
#         "noise_pred": False, 
#         "latent": True, 
#         "pil": False
#     }

#     inference_step_minus_one = get_inference_step_minus_one(
#         scheduler_type = scheduler_type
#     )

#     tmp = pipeline.forward(
#         prompt = prompt, 

#         height = height, 
#         width = width, 

#         num_inference_steps = num_inference_step, 

#         guidance_scale = guidance_scale, 

#         nagative_prompt = nagative_prompt, 

#         num_images_per_prompt = 5, 

#         inference_step_minus_one = inference_step_minus_one, 

#         eta_list = eta_list, 
#         eps_list = eps_list, 

#         save_inference_process_dict = save_inference_process_dict
#     )
    
#     breakpoint()

#     # `test_hydit()` done
#     pass


def test_implement(
    cfg: DictConfig
):
    """Load the shared global variables, then run the currently enabled experiment.

    Experiments are switched by (un)commenting the calls in the task list below;
    at the moment only ``get_50_prompts_str`` is enabled. ``exp_name``,
    ``device`` and ``seed`` are fetched but not used here.
    """
    # ---------= [Global Variables] =---------
    logger(f"[Global Variables] Loading started. ")

    exp_name, device, seed = (
        get_global_variable(var_name) \
            for var_name in ("exp_name", "device", "seed")
    )

    logger(f"[Global Variables] Loading finished. ")

    # ---------= [Task] =---------
    # test_MDP(cfg)
    # test_MCTS(cfg)
    # test_controllability_gramian(cfg)
    # test_lqr(cfg)
    # test_batch_MCTS(cfg)
    # test_torch_MCTS(cfg)
    # test_mscoco_2014(cfg)
    # test_hpsv2_reward_model(cfg)
    # cal_avg_mpd_lpips(cfg)
    # test_polar(cfg)
    # test_hpsv2_prompt(cfg)
    # test_depth_reward(cfg)
    # test_mpd(cfg)
    # test_scaling_line_chart(cfg)
    # plot_scaling_line_chart(cfg)
    # plot_polar(cfg)
    # test_image_reward(cfg)
    # test_pick_score(cfg)
    # test_hydit(cfg)

    get_50_prompts_str(cfg)

    # test_csv(cfg)
    # test_mscoco(cfg)
    # test_pixart_alpha(cfg)
    # add_optimized_prompt(cfg)

def test(
    cfg: DictConfig
):
    """Task entry point: run whatever experiment ``test_implement`` has enabled."""
    test_implement(cfg)
