from util.logger import logger

from typing import Optional, Union, List, Dict, Tuple

import numpy as np

import torch

from tqdm import tqdm

import time

import gc

from pathlib import Path

from types import MethodType

from util.torch_util import tsfm_to_1d_array
from util.image_util import save_pil_as_png
from util.yaml_util import (
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.pipeline_util import img_latent_to_pil

from OT_MCTS.src.monte_carlo_tree_search.mcts_node import MCTSNode
from OT_MCTS.src.monte_carlo_tree_search.mcts import MonteCarloTreeSearch

from OT_MCTS.src.unimodal_beta_distribution import UnimodalBetaDistribution
# from OT_MCTS.src.optimal_control.optimal_control_solver import OptimalControlSolver

from .mdp.diffusion_mdp import DiffusionMDP


# TODO: Add a set of optimal default values for the parameters.
class DiffusionOTMCTS(MonteCarloTreeSearch):
    def __init__(
        self, 

        is_eps_action: Optional[bool] = False, 

        mdp: DiffusionMDP = None, 

        init_state_list: Union[torch.Tensor, List[torch.Tensor]] = None, 

        # ---------= [Mode] =---------
        mdp_modeling: str = "max_reward", 
        value_policy: str = "max_reward", 
        pseudo_latent_as_final: Optional[bool] = False, 
        enable_pseudo_latent_as_final_depth: Optional[int] = 1, 

        # ---------= [Upper Confidence Bound (UCB)] =---------
        exploration_coef: float = None, 

        # ---------= [Selection Policy] =---------
        selection_depth_lim: int = None, 

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy: str = "uniform",  # ["uniform", "beta"]

        # ---------= [Simulation Policy] =---------
        simulation_action_sampling_policy: str = None, 
        simulation_default_action_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim: int = None, 
        nfe_cal_intermediate_reward_lim: int = None, 
        nfe_cal_final_reward_lim: int = None, 

        # ---------= [Beta Distribution Parameterization] =---------
        beta_online_update: Optional[bool] = True, 
        beta_update_policy: Optional[str] = "hard",  # ["hard", "soft", "value_gradient"]
        beta_value_gradient_update_time: Optional[str] = "best_trajectory_updated",  # ["best_trajectory_updated", "back_propagation"]
        beta_action_bias: Optional[float] = 1e-2, 
        beta_update_step_size: Optional[float] = 1e-4, 
        beta_max_update_bias: Optional[float] = 1e-3, 
        beta_zeta_list: Optional[Union[float, List[float]]] = None, 
        beta_default_zeta: Optional[float] = 10.0, 
        beta_update_reward_threshold: Optional[float] = 1e-8, 
        beta_clamp_eps: Optional[float] = 1e-8, 
        beta_direction_length_eps: Optional[float] = 1e-8, 

        display_beta_mode_update: Optional[bool] = False, 

        # ---------= [LRU Cache] =---------
        lru_cache: "LRUCache" = None, 

        # ---------= [Dtype] =---------
        dtype: Optional[str] = "float32", 

        # ---------= [Save Root Path] =---------
        expansion_policy_root_path: Optional[Union[str, Path]] = None, 
        folder_name_list: List[str] = None, 
        cfg_dict: Dict = None, 

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Set up the searcher in four stages:
                1. Before the parent constructor runs, create the per-sample Beta distribution
                   containers (only when `expansion_action_sampling_policy = "beta"`) and store `lru_cache`.
                2. Forward the mode / UCB / selection / expansion / simulation / dtype arguments
                   and `**arg_dict` to the parent constructor.
                3. Create the per-sample bookkeeping lists, store the NFE limits and the Beta
                   hyper-parameters (the latter stay `None` unless the expansion policy is "beta").
                4. Derive the per-prompt output folders and write one `cfg.yaml` per prompt.

            The importance-sampling and optimal-control options are currently disabled;
            their discarded helpers are kept as commented-out code further down the class.

        Args:
            `nfe_cal_dynamics_lim`, `nfe_cal_intermediate_reward_lim`, `nfe_cal_final_reward_lim`: 
                Limits stored on the instance and compared against the NFE counters in 
                `_get_is_time_to_stop()` and `_get_is_cost_legal()`. 
            `beta_*`, `display_beta_mode_update`: Stored only when the expansion policy is "beta". 
                `beta_value_gradient_update_time` is additionally stored only when 
                `beta_update_policy = "value_gradient"`. 
            `expansion_policy_root_path`: Root folder; a `str` is converted to `Path`. 
            `folder_name_list`: One sub-folder name per prompt, joined onto the root folder. 
            `cfg_dict`: Configuration saved as `cfg.yaml` in every prompt folder; must contain a "sample" entry. 
            All remaining arguments are passed through to the parent constructor unchanged. 
        """

        # ---------= [State Needed Before `super().__init__()`] =---------
        # These two assignments deliberately precede the parent constructor, as in the
        # original ordering. NOTE(review): presumably the parent constructor (or a hook it
        # calls) reads them - confirm before moving them.
        if expansion_action_sampling_policy == "beta":
            per_sample_beta_distribution_list = []
            for _ in range(len(init_state_list)):
                per_sample_beta_distribution_list.append([])
            self._unimodal_beta_distribution_list_list = per_sample_beta_distribution_list

        self.lru_cache = lru_cache

        # ---------= [Parent Constructor] =---------
        super().__init__(
            is_eps_action = is_eps_action, 
            mdp = mdp, 
            init_state_list = init_state_list, 

            # mode
            mdp_modeling = mdp_modeling, 
            value_policy = value_policy, 
            pseudo_latent_as_final = pseudo_latent_as_final, 
            enable_pseudo_latent_as_final_depth = enable_pseudo_latent_as_final_depth, 

            # upper confidence bound
            exploration_coef = exploration_coef, 

            # selection / expansion / simulation policies
            selection_depth_lim = selection_depth_lim, 
            expansion_action_sampling_policy = expansion_action_sampling_policy, 
            simulation_action_sampling_policy = simulation_action_sampling_policy, 
            simulation_default_action_list = simulation_default_action_list, 

            dtype = dtype, 

            **arg_dict
        )

        # From here on `self.num_sample`, `self.mdp` and `self.expansion_action_sampling_policy`
        # are read; presumably they are provided by the parent constructor.
        def make_per_sample_list_list() -> List[List]:
            # One independent empty list per sample.
            return [[] for _ in range(self.num_sample)]

        # ---------= [Best Trajectory Updated Records] =---------
        self.best_trajectory_updated_mcts_loop_idx_list_list = make_per_sample_list_list()
        self.best_trajectory_updated_wall_clock_time_list_list = make_per_sample_list_list()
        self.best_trajectory_updated_nfe_cal_dynamics_list_list = make_per_sample_list_list()
        self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list = make_per_sample_list_list()
        self.best_trajectory_updated_nfe_cal_final_reward_list_list = make_per_sample_list_list()

        # ---------= [NFEs] =---------
        self._nfe_cal_dynamics_lim = nfe_cal_dynamics_lim
        self._nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim
        self._nfe_cal_final_reward_lim = nfe_cal_final_reward_lim

        self._nfe_cal_dynamics_list_list = make_per_sample_list_list()
        self._nfe_cal_intermediate_reward_list_list = make_per_sample_list_list()
        self._nfe_cal_final_reward_list_list = make_per_sample_list_list()

        # ---------= [Beta Distribution Parameterization] =---------
        # Every Beta attribute always exists; it carries the caller's value only when the
        # expansion policy is "beta", otherwise it is `None`.
        use_beta = (self.expansion_action_sampling_policy == "beta")
        use_value_gradient = use_beta and (beta_update_policy == "value_gradient")

        self._beta_online_update = beta_online_update if use_beta else None
        self._beta_update_policy = beta_update_policy if use_beta else None
        self._beta_value_gradient_update_time = (
            beta_value_gradient_update_time if use_value_gradient else None
        )
        self._beta_action_bias = beta_action_bias if use_beta else None
        self._beta_update_step_size = beta_update_step_size if use_beta else None
        self._beta_max_update_bias = beta_max_update_bias if use_beta else None
        self._beta_default_zeta = beta_default_zeta if use_beta else None
        self._beta_update_reward_threshold = beta_update_reward_threshold if use_beta else None
        self._beta_clamp_eps = beta_clamp_eps if use_beta else None
        self._beta_direction_length_eps = beta_direction_length_eps if use_beta else None
        self._display_beta_mode_update = display_beta_mode_update if use_beta else None

        # Must come last: `_prepare_beta_zeta_list()` reads `self._beta_default_zeta`.
        self._beta_zeta_list = None
        if use_beta:
            self._beta_zeta_list = self._prepare_beta_zeta_list(
                beta_zeta_list = beta_zeta_list
            )

        # ---------= [Folder Root Path] =---------
        # NOTE(review): `expansion_policy_root_path`, `folder_name_list` and `cfg_dict` default
        # to `None` but are used unconditionally below, so callers must always supply them.
        if isinstance(expansion_policy_root_path, str):
            expansion_policy_root_path = Path(expansion_policy_root_path)
        self._expansion_policy_root_path = expansion_policy_root_path

        self._folder_root_path_list = [
            expansion_policy_root_path / folder_name
                for folder_name in folder_name_list
        ]

        num_prompt = len(self._folder_root_path_list)
        self._num_sample_per_prompt = self.num_sample // num_prompt

        self._save_png_root_path_list = [
            folder_root_path / "png"
                for folder_root_path in self._folder_root_path_list
        ]
        self._save_action_list_root_path_list = [
            folder_root_path / "action_list"
                for folder_root_path in self._folder_root_path_list
        ]
        self._save_result_root_path_list = [
            folder_root_path / "result"
                for folder_root_path in self._folder_root_path_list
        ]

        # ---------= [Save Task Cfg] =---------
        if self.mdp.optimized_prompt_list is None:
            self.mdp.optimized_prompt_list = [None] * num_prompt

        # NOTE(review): `cfg_dict` is mutated in place. An "optimized_prompt" written for one
        # prompt stays in the dict for later prompts whose optimized prompt is `None` -
        # confirm this is intended.
        per_prompt_iter = zip(
            self.mdp.prompt_list, 
            self.mdp.optimized_prompt_list, 
            self._folder_root_path_list
        )
        for (prompt, optimized_prompt, folder_root_path) in per_prompt_iter:
            cfg_dict["sample"]["prompt"] = prompt

            if optimized_prompt is not None:
                cfg_dict["sample"]["optimized_prompt"] = optimized_prompt

            save_yaml(
                cfg_dict, 

                yaml_root_path = folder_root_path, 
                yaml_filename = "cfg.yaml"
            )

        # `__init__()` done

    
    def _get_nfe(
        self, 

        sample_idx: int
    ) -> Tuple[int, int, int]:
        """
        Func: 
            Get the NFEs of Sample `sample_idx`. 

        Ret: 
            The NFEs of Sample `sample_idx`. 
                `nfe_cal_dynamics` (`int`). 
                `nfe_cal_intermediate_reward` (`int`). 
                `nfe_cal_final_reward` (`int`). 
        """

        nfe_cal_dynamics = self.mdp.reward_model.nfe_cal_dynamics_list[sample_idx]
        nfe_cal_intermediate_reward = self.mdp.reward_model.nfe_cal_intermediate_reward_list[sample_idx]
        nfe_cal_final_reward = self.mdp.reward_model.nfe_cal_final_reward_list[sample_idx]

        # `_get_nfe()` done
        return (
            nfe_cal_dynamics, 
            nfe_cal_intermediate_reward, 
            nfe_cal_final_reward
        )


    def _get_is_time_to_stop(
        self, 

        sample_idx: int
    ) -> bool:
        """
        Func:
            Simulate the trajectory from the expanded node to terminal state. 

        Ret:
            `is_time_to_stop` (`bool`): Whether it is time to stop. 
        """

        (
            nfe_cal_dynamics, 
            nfe_cal_intermediate_reward, 
            nfe_cal_final_reward
        ) = self._get_nfe(sample_idx = sample_idx)

        is_time_to_stop = (nfe_cal_dynamics >= self._nfe_cal_dynamics_lim) \
            or (nfe_cal_intermediate_reward >= self._nfe_cal_intermediate_reward_lim) \
            or (nfe_cal_final_reward >= self._nfe_cal_final_reward_lim)

        # `get_is_time_to_stop()` done
        return is_time_to_stop
    

    def _get_is_cost_legal(
        self, 

        sample_idx: int
    ) -> bool:
        """
        Func:
            Determine whether the cost of MCTS is legal. 

        Ret:
            `is_cost_legal` (`bool`): Whether the cost of MCTS is legal. 
        """

        (
            nfe_cal_dynamics, 
            nfe_cal_intermediate_reward, 
            nfe_cal_final_reward
        ) = self._get_nfe(sample_idx = sample_idx)

        # `get_is_time_to_stop()` done
        is_cost_legal = (nfe_cal_dynamics <= self._nfe_cal_dynamics_lim) \
            and (nfe_cal_intermediate_reward <= self._nfe_cal_intermediate_reward_lim) \
            and (nfe_cal_final_reward <= self._nfe_cal_final_reward_lim)
        
        # `_get_is_cost_legal()` done
        return is_cost_legal


    @torch.no_grad()
    def _prepare_beta_zeta_list(
        self, 

        beta_zeta_list: Union[float, List[float]] = None
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Normalize the ζ values used by the Beta distribution parameterization. 
            When `beta_zeta_list = None`, start from a scalar tensor holding `self._beta_default_zeta`. 
            The value is then handed to `tsfm_to_1d_array()` with 
            `target_length = self.mdp.time_horizon`; presumably that helper produces a 1-D array 
            of that length (its implementation is not visible here - verify). 

        Ret:
            `beta_zeta_list` (`Union[torch.Tensor, np.ndarray]`): Whatever `tsfm_to_1d_array()` returns 
                for the given (or default) ζ values. 
        """

        target_dtype = self.dtype
        target_device = self.device

        zeta_source = beta_zeta_list
        if zeta_source is None:
            # No ζ supplied: fall back to the configured default.
            zeta_source = torch.tensor(
                self._beta_default_zeta, 

                dtype = target_dtype, 
                device = target_device
            )

        # `_prepare_beta_zeta_list()` done
        return tsfm_to_1d_array(
            array = zeta_source, 
            target_length = self.mdp.time_horizon, 

            dtype = target_dtype, 
            device = target_device
        )


    # discarded
    # @torch.no_grad()
    # def _init_optimal_control_sovler(
    #     self, 

    #     sample_idx: int, 

    #     # ---------= [Cost Gamma] =---------
    #     omega_z: float, 
    #     omega_eta: float, 

    #     # ---------= [Finite Difference] =---------
    #     finite_difference_accuracy_order: Optional[str] = "SECOND",  # ["SECOND", "FOURTH", "SIXTH", "EIGHTH"]
    #     finite_difference_eps: Optional[float] = 1e-8, 

    #     # ---------= [Force Positive (Semi-)definite] =---------
    #     force_positive_semi_definite_max_tolerance: Optional[float] = 1e-8, 
    #     force_positive_definite_max_tolerance: Optional[float] = 1e-8
    # ):
    #     """
    #     Func:
    #         Initialize `self.optimal_control_solver`. 
    #     """

    #     if self.optimal_control_solver is not None:
    #         # ---------= [Clean Up] =---------
    #         del self.optimal_control_solver
    #         gc.collect()
    #         if self.ver == "torch":
    #             torch.cuda.empty_cache()

    #     self.optimal_control_solver = OptimalControlSolver(
    #         mdp = self.mdp, 

    #         # ---------= [Cost Gamma] =---------
    #         omega_z = omega_z, 
    #         omega_eta = omega_eta, 

    #         # ---------= [Finite Difference] =---------
    #         finite_difference_accuracy_order = finite_difference_accuracy_order, 
    #         finite_difference_eps = finite_difference_eps,

    #         # ---------= [Force Positive (Semi-)definite] =---------
    #         force_positive_semi_definite_max_tolerance = force_positive_semi_definite_max_tolerance, 
    #         force_positive_definite_max_tolerance = force_positive_definite_max_tolerance
    #     )

    #     # ---------= [Initialization] =---------
    #     action_list = [
    #         self.mdp.action_space.sample_uniform_element() \
    #             for _ in range(self.mdp.time_horizon)
    #     ]

    #     init_state = self.root.get_state(
    #         sample_idx_list = sample_idx
    #     )[0]
        
    #     (
    #         state_list, 
    #         reward_list
    #     ) = self.optimal_control_solver.init_everything_around_a_trajectory(
    #         init_state = init_state, 
    #         action_list = action_list
    #     )
    #     self.optimal_control_solver.update_cost_gamma()
    #     self.optimal_control_solver.update_dare_P()

    #     cated_reward_list = torch.cat(reward_list)
    #     tot_reward = torch.sum(cated_reward_list)
    #     self.optimal_control_best_tot_reward = tot_reward

    #     logger(f"Optimal control solver initialized. ")
        
    #     # clean up
    #     del state_list, reward_list
    #     del cated_reward_list
    #     gc.collect()
    #     torch.cuda.empty_cache()

    #     # `_init_optimal_control_sovler()` done
    #     pass


    # discarded
    # @torch.no_grad()
    # def _update_optimal_control_solver(
    #     self, 

    #     sample_idx: int
    # ):
    #     """
    #     Func:
    #         Update `self.optimal_control_solver`. 
    #     """
        
    #     init_state = self.root.get_state(
    #         sample_idx_list = sample_idx
    #     )[0]
    #     action_list = self.best_trajectory_list[sample_idx].action_list
        
    #     self.optimal_control_solver.update_everything_around_a_trajectory(
    #         init_state = init_state, 
    #         action_list = action_list
    #     ) 
    #     self.optimal_control_solver.update_cost_gamma()
    #     self.optimal_control_solver.update_dare_P()
        
    #     logger(f"Optimal control solver updated. ")

    #     # `_update_optimal_control_solver()` done
    #     pass


    # (discarded) serial ver
    # @torch.no_grad()
    # def _cal_value_gradient_list(
    #     self, 

    #     node: MCTSNode, 
    #     sample_idx: int, 

    #     best_trajectory: Optional["Trajectory"] = None, 

    #     state: Optional[torch.Tensor] = None, 
    #     action_list: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, 
    #     merged_reward: Optional[torch.Tensor] = None
    # ) -> Union[
    #     List[torch.Tensor], 
    #     List[np.ndarray]
    # ]:
    #     """
    #     Func: 
    #         Compute value gradients from root to `node`. 

    #     Ret:
    #         `value_gradient_list` (`Union[List[torch.Tensor], List[np.ndarray]]`): The list of value gradients, from root to `node`. 
    #             value_gradient_list.shape = (node.depth, )
    #     """

    #     time_horizon = self.mdp.time_horizon
    #     beta_action_bias = self._beta_action_bias

    #     if state is None:
    #         state = best_trajectory.state_list[0]

    #     if action_list is None:
    #         action_list = best_trajectory.action_list
    #     num_action = len(action_list)

    #     if merged_reward is None:
    #         merged_reward = best_trajectory.merged_reward
    #     if not isinstance(merged_reward, float):
    #         merged_reward = merged_reward.item()

    #     def implement_single(
    #         bias_idx: int
    #     ) -> float:
    #         if self.ver == "torch":
    #             last_state = state.clone()
    #         else:
    #             last_state = state.copy()

    #         if self._mdp_modeling == "max_reward":
    #             new_merged_reward = float("-inf")
    #         elif self._mdp_modeling in [
    #             "sparse_reward", 
    #             "cumulative_reward"
    #         ]:
    #             new_merged_reward = 0.0
            
    #         for timestep_idx in range(num_action):
    #             action = action_list[timestep_idx]
    #             if timestep_idx == bias_idx:
    #                 action += beta_action_bias
                
    #             cur_state = self.mdp.batch_cal_dynamics(
    #                 sample_idx_list = [sample_idx], 

    #                 state_list = [last_state], 
    #                 action_list = [action], 

    #                 timestep_idx_list = [timestep_idx]
    #             )[0]

    #             intermediate_reward = None
    #             final_reward = None
                
    #             if timestep_idx < time_horizon - 1:
    #                 (
    #                     intermediate_reward_list, 
    #                     pseudo_final_latent_list
    #                 ) = self.mdp.batch_cal_intermediate_reward(
    #                     state_list = [cur_state], 
    #                     prev_action_list = [action], 

    #                     timestep_idx_list = [timestep_idx + 1], 

    #                     prev_latent_list = [last_state], 

    #                     sample_idx_list = [sample_idx]
    #                 )

    #                 intermediate_reward = intermediate_reward_list[0]

    #                 # ---------= [Clean Up] =---------
    #                 del intermediate_reward_list, pseudo_final_latent_list
    #             else:
    #                 final_reward = self.mdp.batch_cal_final_reward(
    #                     sample_idx_list = [sample_idx], 

    #                     state_list = [cur_state]
    #                 )[0]
                
    #             if self._mdp_modeling == "max_reward":
    #                 if timestep_idx < time_horizon - 1:
    #                     new_merged_reward = max(new_merged_reward, intermediate_reward)
    #                 else:
    #                     new_merged_reward = max(new_merged_reward, final_reward)
    #             elif self._mdp_modeling in [
    #                 "sparse_reward", 
    #                 "cumulative_reward"
    #             ]:
    #                 if timestep_idx < time_horizon - 1:
    #                     new_merged_reward = new_merged_reward + intermediate_reward
    #                 else:
    #                     new_merged_reward = new_merged_reward + final_reward

    #             last_state = cur_state

    #             # ---------= [Clean Up] =---------
    #             del action
    #             if timestep_idx < time_horizon - 1:
    #                 del intermediate_reward
    #             else:
    #                 del final_reward
    #             gc.collect()
    #             if self.ver == "torch":
    #                 torch.cuda.empty_cache()

    #             # goto `for timestep_idx`
    #             pass

    #         value_gradient = (new_merged_reward - merged_reward) / beta_action_bias
    #         # if not isinstance(value_gradient, float):
    #         #     value_gradient = value_gradient.item()

    #         # ---------= [Clean Up] =---------
    #         del last_state
    #         del cur_state
    #         gc.collect()
    #         if self.ver == "torch":
    #             torch.cuda.empty_cache()

    #         # `implement_single()` done
    #         return value_gradient

    #     value_gradient_list = [
    #         implement_single(bias_idx = depth) \
    #             for depth in range(node.depth)
    #     ]
        
    #     # ---------= [Clean Up] =---------
    #     del action_list
    #     gc.collect()
    #     if self.ver == "torch":
    #         torch.cuda.empty_cache()

    #     # `_cal_value_gradient_list()` done
    #     return value_gradient_list


    # discarded
    # @torch.no_grad()
    # def _batch_cal_value_gradient_list(
    #     self, 

    #     node: MCTSNode, 
    #     sample_idx: int, 

    #     best_trajectory: Optional["Trajectory"] = None, 

    #     state: Optional[torch.Tensor] = None, 
    #     action_list: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, 
    #     merged_reward: Optional[torch.Tensor] = None
    # ) -> Union[
    #     List[torch.Tensor], 
    #     List[np.ndarray]
    # ]:
    #     """
    #     Func: 
    #         Compute value gradients from root to `node`. 

    #     Ret:
    #         `value_gradient_list` (`Union[List[torch.Tensor], List[np.ndarray]]`): The list of value gradients, from root to `node`. 
    #             value_gradient_list.shape = (node.depth, )
    #     """

    #     time_horizon = self.mdp.time_horizon
    #     beta_action_bias = self._beta_action_bias

    #     depth = node.depth

    #     if state is None:
    #         state = best_trajectory.state_list[0]

    #     if action_list is None:
    #         action_list = best_trajectory.action_list

    #     if not isinstance(action_list, torch.Tensor):
    #         action_list = torch.tensor(
    #             action_list, 

    #             dtype = self.dtype, 
    #             device = self.device
    #         )
        
    #     num_action = len(action_list)

    #     if merged_reward is None:
    #         merged_reward = best_trajectory.merged_reward
    #     if not isinstance(merged_reward, float):
    #         merged_reward = merged_reward.item()
        
    #     # last_state_list.shape = (depth, *state_shape)
    #     last_state_list = state.unsqueeze(0) \
    #         .repeat(depth, 1, 1, 1)
        
    #     sample_idx_list = [sample_idx] * depth

    #     if self._mdp_modeling == "max_reward":
    #         init_merged_reward = float("-inf")
    #     elif self._mdp_modeling in [
    #         "sparse_reward", 
    #         "cumulative_reward"
    #     ]:
    #         init_merged_reward = 0.0

    #     new_merged_reward_list = torch.tensor(
    #         [init_merged_reward] * depth, 

    #         dtype = self.dtype, 
    #         device = self.device
    #     )

    #     # new_merged_reward_list.shape = (depth, *reward_shape)
    #     new_merged_reward_list = new_merged_reward_list.reshape(
    #         (depth, *self.reward_shape)
    #     )

    #     action_list_list = action_list.unsqueeze(0) \
    #         .repeat(depth, 1)
        
    #     action_list_list += torch.eye(
    #         n = depth, m = num_action, 

    #         dtype = action_list_list.dtype, 
    #         device = action_list_list.device
    #     ) * beta_action_bias

    #     for timestep_idx in range(num_action):
    #         tmp_action_list = action_list_list[:, timestep_idx]
    #         timestep_idx_list = torch.tensor(
    #             [timestep_idx] * depth, 

    #             dtype = torch.int32, 
    #             device = "cpu"
    #         )

    #         cur_state_list = self.mdp.batch_cal_dynamics(
    #             sample_idx_list = sample_idx_list, 

    #             state_list = last_state_list, 
    #             action_list = tmp_action_list, 

    #             timestep_idx_list = timestep_idx_list
    #         )

    #         next_timestep_idx_list = timestep_idx_list + 1

    #         intermediate_reward_list = None
    #         final_reward_list = None

    #         if timestep_idx < time_horizon - 1:
    #             (
    #                 intermediate_reward_list, 
    #                 pseudo_final_latent_list
    #             ) = self.mdp.batch_cal_intermediate_reward(
    #                 state_list = cur_state_list, 
    #                 prev_action_list = tmp_action_list, 

    #                 timestep_idx_list = next_timestep_idx_list, 

    #                 prev_latent_list = last_state_list, 

    #                 sample_idx_list = sample_idx_list
    #             )

    #             # ---------= [Clean Up] =---------
    #             del pseudo_final_latent_list
    #         else:
    #             final_reward_list = self.mdp.batch_cal_final_reward(
    #                 sample_idx_list = sample_idx_list, 

    #                 state_list = cur_state_list
    #             )
            
    #         if self._mdp_modeling == "max_reward":
    #             if timestep_idx < time_horizon - 1:
    #                 new_merged_reward_list = torch.max(new_merged_reward_list, intermediate_reward_list)
    #             else:
    #                 new_merged_reward_list = torch.max(new_merged_reward_list, final_reward_list)
    #         elif self._mdp_modeling in [
    #             "sparse_reward", 
    #             "cumulative_reward"
    #         ]:
    #             if timestep_idx < time_horizon - 1:
    #                 new_merged_reward_list = new_merged_reward_list + intermediate_reward_list
    #             else:
    #                 new_merged_reward_list = new_merged_reward_list + final_reward_list

    #         last_state_list = cur_state_list

    #         # ---------= [Clean Up] =---------
    #         del tmp_action_list
    #         del timestep_idx_list, next_timestep_idx_list
    #         if timestep_idx < time_horizon - 1:
    #             del intermediate_reward_list
    #         else:
    #             del final_reward_list
    #         gc.collect()
    #         torch.cuda.empty_cache()

    #         # goto `for timestep_idx`
    #         pass
        
    #     value_gradient_list = (new_merged_reward_list - merged_reward) / beta_action_bias
        
    #     # ---------= [Clean Up] =---------
    #     del action_list
    #     del last_state_list
    #     del sample_idx_list
    #     del action_list_list
    #     del new_merged_reward_list
    #     gc.collect()
    #     torch.cuda.empty_cache()

    #     # `_batch_cal_value_gradient_list()` done
    #     return value_gradient_list

    # TODO
    # Remove VG (the value-gradient update path).
    def _backpropagate(
        self, 

        node: MCTSNode, 

        sample_idx: int, 

        reward: float
    ):
        """
        Func:
            Backpropagate the reward from the simulation to update the path. 
            Update the Beta policies of tree portion if `beta_update_policy = "value_gradient"`
                and `beta_value_gradient_update_time = "back_propagation"`. 
        """

        vg_update = False
        if (self._beta_update_policy == "value_gradient") \
            and (self._beta_value_gradient_update_time == "back_propagation"):
            vg_update = True
        
        action_list = None

        if vg_update:
            depth_lim = node.depth

            beta_max_update_bias = self._beta_max_update_bias
            beta_update_step_size = self._beta_update_step_size

            init_state = self.root.get_state(
                sample_idx_list = [sample_idx]
            )[0]

            action_list = []
            
            tmp_node = node
            while tmp_node.parent:
                action_list.append(
                    tmp_node.info_list[sample_idx].prev_action
                )

                tmp_node = tmp_node.parent

                # goto `while tmp_node`
                
            action_list.reverse()

            merged_reward = node.info_list[sample_idx].merged_reward_to_root

            value_gradient_list = self._batch_cal_value_gradient_list(
                node = node, 
                sample_idx = sample_idx, 

                state = init_state, 
                action_list = action_list, 
                merged_reward = merged_reward
            )

            value_gradient_idx = -1

            # --------= [Clean Up] =---------
            del init_state
            gc.collect()
            torch.cuda.empty_cache()

        while node:
            info = node.info_list[sample_idx]
            info.num_vis += 1

            last_value = info.value
            
            if self._value_policy == "max":
                info.value = max(last_value, reward)
            elif self._value_policy == "average":
                info.value = last_value + (reward - info.value) / info.num_vis

            # ---------= [update Beta Policy] =---------
            if vg_update and (node.depth < depth_lim):
                node_idx = node.node_idx
                timestep_idx = node.depth

                beta_distribution = self._unimodal_beta_distribution_list_list[sample_idx][node_idx]

                init_mode = beta_distribution.init_mode
                last_mode = beta_distribution.get_mode()

                value_gradient = value_gradient_list[value_gradient_idx]
                value_gradient_idx -= 1

                value_gradient = torch.clip(
                    value_gradient, 
                    -beta_update_step_size, beta_update_step_size
                )
                
                if last_mode is not None:
                    target_mode = last_mode + value_gradient
                else:
                    target_mode = action_list[timestep_idx] + value_gradient

                if init_mode is not None:
                    target_mode = torch.clip(
                        target_mode, 
                        init_mode - beta_max_update_bias, init_mode + beta_max_update_bias
                    )

                beta_distribution.update_with_mode(
                    mode = target_mode,  
                    zeta = self._beta_zeta_list[timestep_idx], 

                    clamp_eps = self._beta_clamp_eps
                )

                beta_distribution.initialized = True

                if self._display_beta_mode_update:
                    target_mode_scalar = target_mode.item()

                    if last_mode is None:
                        logger(
                            f"[Sample {sample_idx}, Node {node_idx}] "
                            f"Beta mode initialized to [{target_mode_scalar:.4f}]"
                        )
                    else:
                        logger(
                            f"[Sample {sample_idx}, Node {node_idx}] "
                            f"Beta mode: [{last_mode:.4f}] -> [{target_mode_scalar:.4f}]"
                        )

            node = node.parent

            # goto `while node`
            pass

        # ---------= [Clean Up] =---------
        del action_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_backpropagate()` done
        pass


    @torch.no_grad()
    def _update_beta_to_root(
        self, 

        node: MCTSNode, 
        sample_idx: int, 

        best_trajectory: "Trajectory"
    ):
        """
        Func:
            The callback function after the best trajectory is updated. 
        """
        
        beta_max_update_bias = self._beta_max_update_bias

        vg_update = False
        if self._beta_update_policy == "value_gradient":
            if self._beta_value_gradient_update_time == "best_trajectory_updated":
                vg_update = True
            else:  # back_propagation
                return

        if vg_update:
            value_gradient_list = self._batch_cal_value_gradient_list(
                node = node, 
                sample_idx = sample_idx, 

                best_trajectory = best_trajectory
            )

            value_gradient_idx = -1
        
        while node.parent:
            node = node.parent

            node_idx = node.node_idx
            timestep_idx = node.depth

            beta_distribution = self._unimodal_beta_distribution_list_list[sample_idx][node_idx]

            init_mode = beta_distribution.init_mode
            last_mode = beta_distribution.get_mode()

            best_action = best_trajectory.action_list[timestep_idx]
            beta_update_step_size = self._beta_update_step_size

            if (last_mode is None) or (self._beta_update_policy == "hard"):
                target_mode = best_action
            elif self._beta_update_policy == "soft":
                direction = best_action - last_mode
                direction_length = abs(direction)

                if direction_length < self._beta_direction_length_eps:
                    target_mode = last_mode
                else:
                    direction = direction / direction_length

                    target_mode = last_mode + direction * beta_update_step_size
                
                # dbg
                # best_action_scalar = best_action.item()
                # print(f"best_action: {best_action_scalar:.4f}, last_mode: {last_mode:.4f}, target_mode: {target_mode}")
                # breakpoint()
                # last_mode_str = f"{last_mode:.4f}"
                # target_mode_str = f"{target_mode.item():.4f}"
                # if last_mode_str != target_mode_str:
                #     print(f"last_mode: {last_mode_str}, target_mode: {target_mode_str}")
                #     breakpoint()
            elif vg_update:
                value_gradient = value_gradient_list[value_gradient_idx]
                value_gradient_idx -= 1

                value_gradient = torch.clip(
                    value_gradient, 
                    -beta_update_step_size, beta_update_step_size
                )

                target_mode = last_mode + value_gradient

            # if (last_mode is not None) \
            #     and (self._beta_max_update_bias is not None):

            #     if self.ver == "torch":
            #         target_mode = torch.clip(
            #             target_mode, 
            #             last_mode - beta_max_update_bias, last_mode + beta_max_update_bias
            #         )
            #     elif self.ver == "numpy":
            #         target_mode = np.clip(
            #             target_mode, 
            #             last_mode - beta_max_update_bias, last_mode + beta_max_update_bias
            #         )
            
            if not isinstance(target_mode, torch.Tensor):
                target_mode = torch.tensor(
                    target_mode, 

                    dtype = self.dtype, 
                    device = self.device
                )

            if init_mode is not None:
                target_mode = torch.clip(
                    target_mode, 
                    init_mode - beta_max_update_bias, init_mode + beta_max_update_bias
                )
            
            beta_distribution.update_with_mode(
                mode = target_mode,  
                zeta = self._beta_zeta_list[timestep_idx], 

                clamp_eps = self._beta_clamp_eps
            )

            beta_distribution.initialized = True

            if self._display_beta_mode_update:
                best_action_scalar = best_action.item()
                target_mode_scalar = target_mode.item()

                if last_mode is None:
                    logger(
                        f"[Sample {sample_idx}, Node {node_idx}] "
                        f"best_action: [{best_action_scalar:.4f}], Beta mode initialized to [{target_mode_scalar:.4f}]"
                    )
                else:
                    logger(
                        f"[Sample {sample_idx}, Node {node_idx}] "
                        f"best_action: [{best_action_scalar:.4f}], Beta mode: [{last_mode:.4f}] -> [{target_mode_scalar:.4f}]"
                    )
        
            # goto `while node.parent`
            pass

        # ---------= [Clean Up] =---------
        if vg_update:
            del value_gradient_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_update_beta_to_root()` done
        pass
    

    @torch.no_grad()
    def _best_trajectory_updated_callback(
        self, 

        caller: str, 
        pseudo: bool, 

        node: MCTSNode, 

        sample_idx: int, 

        timestep_idx: int
    ):
        """
        Func:
            Hook run after the best trajectory of sample `sample_idx` has been replaced. 
            In order, it: 
                1. logs the change of the merged reward and of the final reward; 
                2. refreshes the accumulated reward list of the best trajectory; 
                3. records the MCTS loop index, the elapsed wall-clock time and the NFE 
                   counters at the moment of the update; 
                4. updates the Beta policies towards the root when the "beta" expansion 
                   policy with online update is active and the merged reward gain reaches 
                   `beta_update_reward_threshold`; 
                5. saves the action list (YAML) and the last state of the trajectory (PNG). 

        Arg:
            `caller` (`str`): The name of the routine that triggered the update (logged and saved). 
            `pseudo` (`bool`): Saved as-is in the YAML record. 
            `node` (`MCTSNode`): The node at which the update happened. 
            `sample_idx` (`int`): The index of the sample whose best trajectory changed. 
            `timestep_idx` (`int`): Saved as-is in the YAML record. 
        """

        def as_python_float(value):
            # History entries that are not plain floats are unwrapped with `.item()`. 
            return value if isinstance(value, float) else value.item()

        # The last two history entries are the rewards before and after this update. 
        merged_history = self.history_best_merged_reward_list_list[sample_idx]
        prev_merged_reward = as_python_float(merged_history[-2])
        new_merged_reward = as_python_float(merged_history[-1])

        final_history = self.history_last_final_reward_list_list[sample_idx]
        prev_final_reward = as_python_float(final_history[-2])
        new_final_reward = as_python_float(final_history[-1])

        logger(
            f"[Sample {sample_idx}] Best trajectory updated by `{caller}` at node {node.node_idx} for sample {sample_idx}. "
        )
        logger(
            f"    merged_reward: [{prev_merged_reward:.4f}] -> [{new_merged_reward:.4f}], "
            f"final_reward: [{prev_final_reward:.4f}] -> [{new_final_reward:.4f}]"
        )

        best_trajectory = self.best_trajectory_list[sample_idx]
        best_trajectory.update_accumulated_reward_list(
            mdp_modeling = self._mdp_modeling
        )

        # ---------= [Record Best Trajectory Updated] =---------
        loop_idx = self._mcts_loop_idx_list[sample_idx]
        self.best_trajectory_updated_mcts_loop_idx_list_list[sample_idx].append(loop_idx)

        # Elapsed time since `self._time_st`. 
        elapsed = time.time() - self._time_st
        self.best_trajectory_updated_wall_clock_time_list_list[sample_idx].append(elapsed)

        nfe_dynamics, nfe_intermediate, nfe_final = self._get_nfe(sample_idx = sample_idx)

        self.best_trajectory_updated_nfe_cal_dynamics_list_list[sample_idx].append(nfe_dynamics)
        self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list[sample_idx].append(nfe_intermediate)
        self.best_trajectory_updated_nfe_cal_final_reward_list_list[sample_idx].append(nfe_final)

        # ---------= [Update Beta Distribution Parameterization] =---------
        beta_update_due = (
            (self.expansion_action_sampling_policy == "beta")
            and (self._beta_online_update)
            and (new_merged_reward - prev_merged_reward >= self._beta_update_reward_threshold)
        )

        if beta_update_due:
            self._update_beta_to_root(
                node = node, 
                sample_idx = sample_idx, 

                best_trajectory = best_trajectory
            )

        # ---------= [Save Trajectory] =---------
        record = convert_numpy_type_to_native_type(
            {
                "mcts_loop_idx": loop_idx, 

                "time_cost": elapsed, 

                "nfe_cal_dynamics": nfe_dynamics, 
                "nfe_cal_intermediate_reward": nfe_intermediate, 
                "nfe_cal_final_reward": nfe_final, 

                "caller": caller, 
                "pseudo": pseudo, 

                "timestep_idx": timestep_idx, 

                "action_list": best_trajectory.action_list, 

                "merged_reward": new_merged_reward, 
                "final_reward": new_final_reward
            }
        )

        # `sample_idx` is split into (prompt, sample-within-prompt). 
        prompt_idx, true_sample_idx = divmod(sample_idx, self._num_sample_per_prompt)

        # Number of history entries minus 2, i.e., 0 when the history holds exactly two entries. 
        img_pil_idx = len(merged_history) - 2

        sample_dir_name = f"{true_sample_idx}"
        file_stem = f"{true_sample_idx}_{img_pil_idx}"

        save_yaml(
            record, 

            yaml_root_path = self._save_action_list_root_path_list[prompt_idx] / sample_dir_name, 
            yaml_filename = f"{file_stem}.yaml"
        )

        # ---------= [Save Sample] =---------
        # A leading batch axis is added when the latent has fewer than 4 dimensions. 
        final_latent = best_trajectory.state_list[-1]
        if final_latent.ndim < 4:
            final_latent = final_latent.unsqueeze(0)

        img_pil = img_latent_to_pil(
            img_latent_list = final_latent, 
            pipeline = self.mdp.pipeline
        )[0]

        save_pil_as_png(
            pil = img_pil, 

            png_root_path = self._save_png_root_path_list[prompt_idx] / sample_dir_name, 
            png_filename = f"{file_stem}.png"
        )


    # @torch.no_grad()
    # def _get_root(
    #     self, 

    #     init_state_list: Union[
    #         torch.Tensor, List[torch.Tensor], 
    #         np.ndarray, List[np.ndarray]
    #     ]
    # ) -> MCTSNode:
    #     """
    #     Func:
    #         Get the root node. 

    #     Ret:
    #         `root` (`MCTSNode`): The root node. 
    #     """

    #     sample_idx_list = list(
    #         range(self.num_sample)
    #     )

    #     root = self._get_new_node(
    #         state_list = init_state_list, 
    #         sample_idx_list = sample_idx_list
    #     )

    #     for sample_idx in sample_idx_list:
    #         root.info_list[sample_idx].potential = 0
            
    #         # goto `for sample_idx`
    #         pass

    #     # `_get_root()` done
    #     return root


    @torch.no_grad()
    def _get_new_node(
        self, 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        parent: MCTSNode = None
    ) -> MCTSNode:
        """
        Func:
            Create a new `MCTSNode` that owns a private copy of `state_list`. 
            When the expansion policy is "beta", one fresh `UnimodalBetaDistribution` is 
            appended for every sample. 
            The node's `get_state` is replaced by a variant that notifies the LRU cache on 
            every access. 
            No reward is computed here. 

        Arg:
            `sample_idx_list` (`List[int]`): The sample indices passed to the node. 
            `state_list`: The states of the node; a list is stacked into a single tensor. 
            `prev_action_list`: The actions that led to the node, passed to the node as-is. 
            `parent` (`MCTSNode`): The parent node, `None` for a node without parent. 

        Ret:
            `new_node` (`MCTSNode`): The created node. 
        """

        # ---------= [Prepare `state_list`] =---------
        # NOTE(review): the annotations allow `np.ndarray`, but both paths below are 
        # torch-only (`torch.stack()` / `.clone()`) -- confirm numpy inputs never reach here. 
        if isinstance(state_list, list):
            # `torch.stack()` already allocates a new tensor, no extra copy is needed. 
            state_list = torch.stack(state_list)
        else:
            # Detach the node's storage from the caller's tensor. 
            state_list = state_list.clone()

        # ---------= [New Node] =---------
        new_node = MCTSNode(
            mcts_instance = self, 

            sample_idx_list = sample_idx_list, 

            state_list = state_list, 

            prev_action_list = prev_action_list, 

            parent = parent, 

            device = self.device
        )

        # ---------= [Beta Distribution Parameterization] =---------
        # NOTE(review): one distribution is appended for every sample, not only for those in 
        # `sample_idx_list`; presumably the position in each per-sample list equals the 
        # node's `node_idx` -- confirm. 
        if self.expansion_action_sampling_policy == "beta":
            for sample_idx in range(self.num_sample):
                unimodal_beta_distribution = UnimodalBetaDistribution(
                    dtype = self.dtype, 
                    device = self.device, 

                    ver = self.ver
                )

                self._unimodal_beta_distribution_list_list[sample_idx].append(unimodal_beta_distribution)

        # ---------= [LRU] =---------
        def _get_state(
            node_self, 

            sample_idx_list: Union[int, List[int]] = None
        ) -> Union[List[torch.Tensor], List[np.ndarray]]:
            """
            Func:
                Get the `state`s of the node and register every access in the LRU cache of 
                the owning MCTS instance. 

            Ret:
                `state_list` (`list`): One entry per requested sample, `None` for a sample 
                    the node holds no info for. 
                    (`sample_idx_list` is provided) len(state_list) = len(sample_idx_list). 
                    (`sample_idx_list` is not provided) len(state_list) = num_sample. 
            """

            # ---------= [Prepare `sample_idx_list`] =---------
            if sample_idx_list is None:
                sample_idx_list = list(
                    range(node_self.mcts_instance.num_sample)
                )

            # A single index is accepted as well. 
            if not isinstance(sample_idx_list, list):
                sample_idx_list = [sample_idx_list]

            # ---------= [Get State] =---------
            state_list = []

            for sample_idx in sample_idx_list:
                info = node_self.info_list[sample_idx]

                if info is None:
                    state_list.append(None)

                    continue

                # The cache is notified before the state is read. 
                node_self.mcts_instance.lru_cache.access(
                    node_self, 
                    sample_idx = sample_idx
                )

                state_list.append(info.get_state())

            return state_list

        # Bind the LRU-aware accessor to this node instance only. 
        new_node.get_state = MethodType(
            _get_state, 
            new_node
        )

        return new_node


    # @torch.no_grad()
    # def _run_discarded(
    #     self, 
    
    #     sample_idx: int, 

    #     display_state_value_list: Optional[bool] = False, 

    #     display_reward_sum_to_leaf: Optional[bool] = False, 

    #     **arg_dict: Optional[Dict]
    # ):
    #     """
    #     Func:
    #         Run MCTS for a sample until stopped. 
    #     """

    #     display_cal_state_value = arg_dict.pop("display_cal_state_value", False)
    #     display_selected_node_depth = arg_dict.pop("display_selected_node_depth", False)

    #     nfe_cal_dynamics_lim = self._nfe_cal_dynamics_lim
    #     nfe_cal_intermediate_reward_lim = self._nfe_cal_intermediate_reward_lim
    #     nfe_cal_final_reward_lim = self._nfe_cal_final_reward_lim

    #     with tqdm(total = nfe_cal_dynamics_lim, desc = "[NFE Cal Dynamics]", position = 0) as bar_1, \
    #         tqdm(total = nfe_cal_intermediate_reward_lim, desc = "[NFE Cal Intermediate Reward]", position = 1) as bar_2, \
    #         tqdm(total = nfe_cal_final_reward_lim, desc = "[NFE Cal Final Reward]", position = 2) as bar_3:
    
    #         # ---------= [Last NFEs] =---------
    #         last_nfe_cal_dynamics = self.mdp.get_nfe_cal_dynamics()
    #         last_nfe_cal_intermediate_reward = self.mdp.get_nfe_cal_intermediate_reward()
    #         last_nfe_cal_final_reward = self.mdp.get_nfe_cal_final_reward()

    #         while not self._get_is_time_to_stop():
    #             # ---------= [Run MCTS] =---------
    #             node_list = self._select(
    #                 sample_idx = sample_idx, 

    #                 display_state_value_list = display_state_value_list, 

    #                 display_cal_state_value = display_cal_state_value
    #             )

    #             if display_selected_node_depth:
    #                 mcts_loop_idx = self._mcts_loop_idx_list[sample_idx]

    #                 timestep_idx_list = [
    #                     node.depth \
    #                         for node in node_list
    #                 ]

    #                 logger(
    #                     f"[MCTS Loop {mcts_loop_idx}, Sample {sample_idx}] "
    #                     f"selected node timestep_idx_list: {timestep_idx_list}"
    #                 )

    #                 # ---------= [Clean Up] =---------
    #                 del timestep_idx_list
    #                 gc.collect()
                
    #             expanded_node_list = self._expand(
    #                 node_list = node_list, 

    #                 sample_idx = sample_idx
    #             )

    #             if expanded_node_list is None:
    #                 continue
                
    #             reward_sum_to_leaf_list = self._simulate(
    #                 node_list = expanded_node_list, 

    #                 sample_idx = sample_idx
    #             )

    #             if display_reward_sum_to_leaf:
    #                 logger(f"reward_sum_to_leaf_list: {reward_sum_to_leaf_list}")

    #             for expanded_node, reward_sum_to_leaf in zip(
    #                 expanded_node_list, 
    #                 reward_sum_to_leaf_list
    #             ):
    #                 self._backpropagate(
    #                     node = expanded_node, 
                        
    #                     sample_idx = sample_idx, 

    #                     reward = reward_sum_to_leaf
    #                 )

    #                 # goto `for expanded_node, reward_sum_to_leaf`
    #                 pass
            
    #             # ---------= [Save Current NFEs] =---------
    #             nfe_cal_dynamics = self.mdp.get_nfe_cal_dynamics()
    #             nfe_cal_intermediate_reward = self.mdp.get_nfe_cal_intermediate_reward()
    #             nfe_cal_final_reward = self.mdp.get_nfe_cal_final_reward()

    #             self._nfe_cal_dynamics_list_list[sample_idx].append(nfe_cal_dynamics)
    #             self._nfe_cal_intermediate_reward_list_list[sample_idx].append(nfe_cal_intermediate_reward)
    #             self._nfe_cal_final_reward_list_list[sample_idx].append(nfe_cal_final_reward)

    #             # ---------= [Update MCTS Loop Index] =---------
    #             self._mcts_loop_idx_list[sample_idx] += 1

    #             # ---------= [Cal Delta] =---------
    #             delta_nfe_cal_dynamics = nfe_cal_dynamics - last_nfe_cal_dynamics
    #             delta_nfe_cal_intermediate_reward = nfe_cal_intermediate_reward - last_nfe_cal_intermediate_reward
    #             delta_nfe_cal_final_reward = nfe_cal_final_reward - last_nfe_cal_final_reward

    #             # ---------= [Update Last NFEs] =---------
    #             last_nfe_cal_dynamics = nfe_cal_dynamics
    #             last_nfe_cal_intermediate_reward = nfe_cal_intermediate_reward
    #             last_nfe_cal_final_reward = nfe_cal_final_reward

    #             # ---------= [Update Tqdm] =---------
    #             bar_1.update(delta_nfe_cal_dynamics)
    #             bar_2.update(delta_nfe_cal_intermediate_reward)
    #             bar_3.update(delta_nfe_cal_final_reward)

    #             # ---------= [Display `num_gpu_resident`] =---------
    #             num_gpu_resident = self.lru_cache.num_gpu_resident

    #             logger(f"num_gpu_resident: {num_gpu_resident}")

    #             # goto `while not self._get_is_time_to_stop()`
    #             pass

    #     # `_run()` done
    #     pass


    @torch.no_grad()
    def _run_sample_pre_process(
        self, 

        **arg_dict
    ):
        """
        Func:
            The function called before starting a run of MCTS for a single sample. 
        """

        self._time_st = time.time()

        # # ---------= [Initialize `self.optimal_control_solver`] =---------
        # if self.expansion_action_sampling_policy == "optimal_control":
        #     self._init_optimal_control_sovler(
        #         sample_idx = sample_idx, 

        #         # ---------= [Cost Gamma] =---------
        #         omega_z = self.optimal_control_omega_z, 
        #         omega_eta = self.optimal_control_omega_eta, 

        #         # ---------= [Finite Difference] =---------
        #         finite_difference_accuracy_order = self.optimal_control_finite_difference_accuracy_order, 
        #         finite_difference_eps = self.optimal_control_finite_difference_eps, 

        #         # ---------= [Force Positive (Semi-)definite] =---------
        #         force_positive_semi_definite_max_tolerance = self.optimal_control_force_positive_definite_max_tolerance, 
        #         force_positive_definite_max_tolerance = self.optimal_control_force_positive_semi_definite_max_tolerance
        #     )

        # `_run_sample_pre_process()` done
        pass


    @torch.no_grad()
    def _run_sample_post_process(
        self, 

        sample_idx: int, 

        **arg_dict
    ):
        """
        Func:
            The function called after finishing a run of MCTS for a single sample. 
        """

        # ---------= [Compute Wall-clock Time Cost] =---------
        time_ed = time.time()

        time_cost = time_ed - self._time_st
        self._wall_clock_time_cost_list[sample_idx] = time_cost

        logger(
            f"[Sample {sample_idx}] Finished, wall-clock time cost: {round(time_cost)} second(s). "
        )

        # ---------= [Total Cost] =---------
        mcts_loop_idx = self._mcts_loop_idx_list[sample_idx]
        if not self._get_is_cost_legal(sample_idx = sample_idx):
            mcts_loop_idx -= 1

        nfe_cal_dynamics_list = self._nfe_cal_dynamics_list_list[sample_idx]
        nfe_cal_intermediate_reward_list = self._nfe_cal_intermediate_reward_list_list[sample_idx]
        nfe_cal_final_reward_list = self._nfe_cal_final_reward_list_list[sample_idx]

        # ---------= [Best Trajectory Updated] =---------
        best_trajectory_updated_mcts_loop_idx_list \
            = self.best_trajectory_updated_mcts_loop_idx_list_list[sample_idx]
        
        best_trajectory_updated_wall_clock_time_list \
            = self.best_trajectory_updated_wall_clock_time_list_list[sample_idx]

        best_trajectory_updated_nfe_cal_dynamics_list \
            = self.best_trajectory_updated_nfe_cal_dynamics_list_list[sample_idx]
        best_trajectory_updated_nfe_cal_intermediate_reward_list \
            = self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list[sample_idx]
        best_trajectory_updated_nfe_cal_final_reward_list \
            = self.best_trajectory_updated_nfe_cal_final_reward_list_list[sample_idx]

        # ---------= [Merged Reward] =---------
        best_merged_reward_list = self.history_best_merged_reward_list_list[sample_idx][1: ]
        best_merged_reward_list = [
            best_merged_reward if isinstance(best_merged_reward, float) \
                else best_merged_reward.item() \
                    for best_merged_reward in best_merged_reward_list
        ]

        # ---------= [Last Final Reward] =---------
        last_final_reward_list = self.history_last_final_reward_list_list[sample_idx][1: ]
        last_final_reward_list = [
            last_final_reward if isinstance(last_final_reward, float) \
                else last_final_reward.item() \
                    for last_final_reward in last_final_reward_list
        ]

        # ---------= [Best Final Reward] =---------
        best_final_reward = np.max(last_final_reward_list)

        # `best_final_reward_updated_idx` is the index in `last_final_reward_list`, i.e., `num_updated`
        best_final_reward_updated_idx = np.argmax(last_final_reward_list)

        # best_final_reward_merged_reward = best_merged_reward_list[best_final_reward_updated_idx]

        # best_final_reward_mcts_loop_idx \
        #     = best_trajectory_updated_mcts_loop_idx_list[best_final_reward_updated_idx]

        # ---------= [Misc] =---------
        mcts_loop_selected_node_tuple_list = self.mcts_loop_selected_node_tuple_list_list[sample_idx]

        mcts_loop_wall_time_cost_list = self._mcts_loop_wall_time_cost_list

        # ---------= [Save Result] =---------
        result_dict = {
            "mcts_loop_idx": mcts_loop_idx, 

            "wall_clock_time_cost": time_cost, 

            # ---------= [NFE] =---------
            "nfe_cal_dynamics_list": nfe_cal_dynamics_list, 
            "nfe_cal_intermediate_reward_list": nfe_cal_intermediate_reward_list, 
            "nfe_cal_final_reward_list": nfe_cal_final_reward_list, 

            # ---------= [Best Trajectory Updated] =---------
            "best_trajectory_updated_mcts_loop_idx_list": best_trajectory_updated_mcts_loop_idx_list, 
            "best_trajectory_updated_wall_clock_time_list": best_trajectory_updated_wall_clock_time_list, 

            "best_trajectory_updated_nfe_cal_dynamics_list": best_trajectory_updated_nfe_cal_dynamics_list, 
            "best_trajectory_updated_nfe_cal_intermediate_reward_list": best_trajectory_updated_nfe_cal_intermediate_reward_list, 
            "best_trajectory_updated_nfe_cal_final_reward_list": best_trajectory_updated_nfe_cal_final_reward_list, 
            
            # ---------= [Merged Reward] =---------
            "best_merged_reward_list": best_merged_reward_list, 

            # ---------= [Last Final Reward] =---------
            "last_final_reward_list": last_final_reward_list, 

            # ---------= [Best Final Reward] =---------
            "best_final_reward": best_final_reward, 
            "best_final_reward_updated_idx": best_final_reward_updated_idx, 

            # ---------= [Misc] =---------
            "mcts_loop_selected_node_tuple_list": mcts_loop_selected_node_tuple_list, 

            "mcts_loop_wall_time_cost_list": mcts_loop_wall_time_cost_list
        }
        result_dict = convert_numpy_type_to_native_type(result_dict)

        prompt_idx = sample_idx // self._num_sample_per_prompt
        true_sample_idx = sample_idx % self._num_sample_per_prompt

        save_yaml(
            result_dict, 

            yaml_root_path = self._save_result_root_path_list[prompt_idx], 
            yaml_filename = f"{true_sample_idx}.yaml"
        )

        # ---------= [Clean Up] =---------
        del nfe_cal_dynamics_list, nfe_cal_intermediate_reward_list, nfe_cal_final_reward_list
        del best_trajectory_updated_mcts_loop_idx_list
        del best_trajectory_updated_wall_clock_time_list
        del best_trajectory_updated_nfe_cal_dynamics_list, best_trajectory_updated_nfe_cal_intermediate_reward_list, best_trajectory_updated_nfe_cal_final_reward_list
        del best_merged_reward_list
        del last_final_reward_list
        del mcts_loop_selected_node_tuple_list
        del mcts_loop_wall_time_cost_list
        gc.collect()

        # `_run_sample_post_process()` done
        pass


    @torch.no_grad()
    def _mcts_loop_callback(
        self,

        local_var_dict: Dict,

        **arg_dict
    ):
        """
        Func:
            Hook executed after each MCTS loop. It appends the wall-clock time
            elapsed since `self._time_st` to `self._mcts_loop_wall_time_cost_list`
            and, for every sample index found in
            `local_var_dict["need_cal_sample_idx_list"]`, appends the counters
            returned by `self._get_nfe()` to the per-sample NFE histories.

        Args:
            `local_var_dict` (`Dict`): Must contain the key
                `"need_cal_sample_idx_list"` (an iterable of sample indices).
            `**arg_dict`: Extra keyword arguments; not read in this method.
        """

        # ---------= [Wall-Clock Time] =---------
        # Elapsed time is measured against `self._time_st`
        # (presumably the search start timestamp -- TODO confirm).
        elapsed = time.time() - self._time_st
        self._mcts_loop_wall_time_cost_list.append(elapsed)

        # ---------= [NFEs] =---------
        active_sample_idx_list = local_var_dict["need_cal_sample_idx_list"]

        for sample_idx in active_sample_idx_list:
            nfe_dynamics, nfe_intermediate, nfe_final = self._get_nfe(
                sample_idx = sample_idx
            )

            # One history entry per callback invocation for this sample
            self._nfe_cal_dynamics_list_list[sample_idx].append(nfe_dynamics)
            self._nfe_cal_intermediate_reward_list_list[sample_idx].append(nfe_intermediate)
            self._nfe_cal_final_reward_list_list[sample_idx].append(nfe_final)

            nfe_report = (
                f"[Sample {sample_idx}] "
                f"nfe_cal_dynamics: {nfe_dynamics}, "
                f"nfe_cal_intermediate_reward: {nfe_intermediate}, "
                f"nfe_cal_final_reward: {nfe_final}"
            )
            logger(nfe_report)

        # ---------= [Clean Up] =---------
        # Drop the local reference, then force a garbage-collection pass
        del active_sample_idx_list
        gc.collect()
    

    def display_sample_result(
        self,

        sample_idx: int,

        display_trajectory: bool = False,
        display_state: bool = True,
        display_action: bool = True,
        display_reward: bool = True
    ):
        """
        Func:
            Log a summary of the search results of one sample: the latest best
            merged reward, the maximum over the recorded last final rewards,
            and the current NFE counters. Optionally also display the best
            trajectory of that sample.

        Args:
            `sample_idx` (`int`): Index of the sample to report.
            `display_trajectory` (`bool`): If `True`, additionally call
                `display_trajectory()` on `self.best_trajectory_list[sample_idx]`.
            `display_state`, `display_action`, `display_reward` (`bool`):
                Forwarded unchanged to `display_trajectory()`.
        """

        # ---------= [Rewards] =---------
        # Non-`float` values are unwrapped through `.item()`
        merged_reward = self.history_best_merged_reward_list_list[sample_idx][-1]
        merged_reward = (
            merged_reward if isinstance(merged_reward, float)
            else merged_reward.item()
        )

        final_reward_history = self.history_last_final_reward_list_list[sample_idx]
        top_final_reward = max(final_reward_history)
        top_final_reward = (
            top_final_reward if isinstance(top_final_reward, float)
            else top_final_reward.item()
        )

        logger("    [Main Results]")
        logger(f"        sample_idx: {sample_idx}")
        logger(
            f"        best_merged_reward: {merged_reward:.4f}, "
            f"best_final_reward: {top_final_reward:.4f}"
        )

        # ---------= [NFEs] =---------
        nfe_dynamics, nfe_intermediate, nfe_final = self._get_nfe(
            sample_idx = sample_idx
        )

        logger(
            f"        nfe_cal_dynamics: {nfe_dynamics}, "
            f"nfe_cal_intermediate_reward: {nfe_intermediate}, "
            f"nfe_cal_final_reward: {nfe_final}"
        )

        # ---------= [Best Trajectory] =---------
        if not display_trajectory:
            return

        logger("[Best Trajectory]")

        best_trajectory = self.best_trajectory_list[sample_idx]
        best_trajectory.display_trajectory(
            display_state = display_state,
            display_action = display_action,
            display_reward = display_reward
        )
    
    
    @torch.no_grad()
    def _sample_expansion_action(
        self, 
        
        node: MCTSNode, 

        sample_idx: int
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Sample an action from the action space for `self._expand()`. 

        Ret:
            `action` (`Union[torch.Tensor, np.ndarray]`): The sampled action. 
        """

        action_sampling_policy = self.expansion_action_sampling_policy

        node_idx = node.node_idx
        # state = node.get_state(
        #     sample_idx_list = sample_idx
        # )[0]
        # timestep_idx = node.depth

        if action_sampling_policy == "uniform":
            action = self.mdp.action_space.sample_uniform_element()
        elif action_sampling_policy == "beta":
            unimodal_beta_distribution = self._unimodal_beta_distribution_list_list[sample_idx][node_idx]

            if unimodal_beta_distribution.initialized:
                action = unimodal_beta_distribution.sample(
                    shape = self.mdp.get_default_action().shape
                )
            else:
                action = self.mdp.action_space.sample_uniform_element()

                unimodal_beta_distribution.init_mode = action

                # unimodal_beta_distribution.update_with_mode(
                #     mode = action, 
                #     zeta = self._beta_zeta_list[timestep_idx], 

                #     clamp_eps = self._beta_clamp_eps
                # )

                # unimodal_beta_distribution.initialized = True
                
        # elif action_sampling_policy == "optimal_control_beta":
        #     action = self.optimal_control_solver.cal_optimal_action(
        #         state = state, 
        #         t = timestep_idx
        #     )

        #     if not unimodal_beta_distribution.initialized:
        #         unimodal_beta_distribution.update_with_mode(
        #             mode = action, 
        #             zeta = self._beta_zeta_list[timestep_idx], 

        #             clamp_eps = self._beta_clamp_eps
        #         )

        #         unimodal_beta_distribution.initialized = True
            
        #     action = unimodal_beta_distribution.sample(
        #         shape = self.mdp.get_default_action().shape
        #     )

        else:
            raise ValueError(
                f"Unsupported `self.expansion_action_sampling_policy`, got `{action_sampling_policy}`. "
            )

        # if self._beta_parameterization:
        #     self._unimodal_beta_distribution_list_list[sample_idx][node_idx] = unimodal_beta_distribution
        
        # `_sample_expansion_action()` done
        return action


    @torch.no_grad()
    def _batch_sample_expansion_action(
        self, 
        
        node: MCTSNode, 

        sample_idx_list: List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Batch sample actions from the action space for `self._expand()`. 

        Ret:
            `action_list` (`Union[torch.Tensor, np.ndarray]`): The batch of the sampled actions. 
        """

        action_list = [
            self._sample_expansion_action(
                node = node, 

                sample_idx = sample_idx
            ) \
                for sample_idx in sample_idx_list
        ]

        action_list = torch.stack(
            action_list, 
            dim = 0
        )

        # `_batch_sample_expansion_action()` done
        return action_list


    @torch.no_grad()
    def _sample_simulation_action(
        self, 
        
        timestep_idx: int

        # `action_sampling_policy` in [
        #     "deterministic", 
        #     "uniform"
        # ]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Sample an action from the action space for `self._simulate()`. 

        Ret:
            `action` (`Union[torch.Tensor, np.ndarray]`): The sampled action. 
        """

        action_sampling_policy = self.simulation_action_sampling_policy

        if action_sampling_policy == "deterministic":
            action = self.simulation_default_action_list[timestep_idx - 1]
        elif action_sampling_policy == "uniform":
            action = self.mdp.action_space.sample_uniform_element()

        else:
            raise ValueError(
                f"Unsupported `self.simulation_action_sampling_policy`, "
                f"got `{action_sampling_policy}`. "
            )
        
        # `_sample_simulation_action()` done
        return action


    @torch.no_grad()
    def _batch_sample_simulation_action(
        self, 
        
        timestep_idx_list: List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Batch sample actions from the action space for `self._simulate()`. 

        Ret:
            `action_list` (`Union[torch.Tensor, np.ndarray]`): The batch of the sampled actions. 
        """
        
        action_list = [
            self._sample_simulation_action(timestep_idx = timestep_idx) \
                for timestep_idx in timestep_idx_list
        ]
        action_list = torch.stack(action_list)

        # `_batch_sample_simulation_action()` done
        return action_list
