from util.logger import logger

from typing import Optional, Union, List, Dict, Tuple

from abc import ABC, abstractmethod

from tqdm.auto import tqdm

import time

import numpy as np

import torch

import gc

from util.basic_util import get_attr

# from ..importance_sampling import (
#     cal_q, 
#     get_num_split
# )
from .mcts_node import MCTSNode
from .info import Info
from ..markov_decision_process.trajectory import Trajectory


class MonteCarloTreeSearch(ABC):
    def __init__(
        self, 

        is_eps_action: Optional[bool] = False, 

        mdp: "MarkovDecisionProcess" = None, 

        init_state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        # ---------= [Mode] =---------
        mdp_modeling: str = "max_reward", 
        value_policy: str = "max", 
        pseudo_latent_as_final: Optional[bool] = False, 
        enable_pseudo_latent_as_final_depth: Optional[int] = 1, 

        # ---------= [Upper Confidence Bound (UCB)] =---------
        exploration_coef: float = 3.0, 
        # depth_coef: float = 0.2, 
        # exclude_last_intermediate_reward: bool = False, 

        # ---------= [Selection Policy] =---------
        selection_depth_lim: int = None, 

        # ---------= [Expansion Policy] =---------
        # (discarded) `action_sampling_policy` in [
        #     "deterministic", 
        #     "optimal_control", 
        #     "spectral_expansion", 
        #     "uniform", 
        #     "optimal_control_beta"
        # ]
        expansion_action_sampling_policy: str = "uniform",   # ["uniform", "beta"]
        # only used for `expansion_action_sampling_policy` in ["uniform", "optimal_control_beta"]
        # expansion_enable_importance_sampling: Optional[bool] = False, 
        # expansion_importance_sampling_J_star_scaling_factor: Optional[float] = None, 
        # expansion_importance_sampling_eps: Optional[float] = 1e-8, 
        # expansion_importance_sampling_verbose: Optional[bool] = False, 
        # per_iteration_expansion_lim: Optional[int] = 1, 

        # ---------= [Simulation Policy] =---------
        simulation_action_sampling_policy: str = "deterministic",  
        simulation_default_action_list: Optional[
            Union[
                Union[torch.Tensor, np.ndarray], 
                Union[List[torch.Tensor], List[np.ndarray]]
            ]
        ] = None, 

        dtype: Optional[str] = "float32", 

        **arg_dict: Optional[Dict]
    ):
        self._is_eps_action = is_eps_action

        self.mdp = mdp

        # ---------= [Mode] =---------
        if mdp_modeling not in [
            "max_reward", 
            "sparse_reward", 
            "cumulative_reward"
        ]:
            raise NotImplementedError(
                f"Unsupported `mdp_modeling`, got `{mdp_modeling}`. "
            )
        
        if value_policy not in [
            "max", 
            "average"
        ]:
            raise NotImplementedError(
                f"Unsupported `value_policy`, got `{value_policy}`. "
            )

        self._mdp_modeling = mdp_modeling
        self._value_policy = value_policy
        self._pseudo_latent_as_final = pseudo_latent_as_final
        
        if enable_pseudo_latent_as_final_depth is None:
            enable_pseudo_latent_as_final_depth = 1
        self._enable_pseudo_latent_as_final_depth = enable_pseudo_latent_as_final_depth

        if self._mdp_modeling == "cumulative_reward":
            if self._pseudo_latent_as_final:
                raise ValueError(
                   f"`pseudo_latent_as_final` can only be used when `mdp_modeling = \"cumulative_reward\"`. "
                )

        # ---------= [Expansion Policy 1] =---------
        if expansion_action_sampling_policy not in [
            "uniform", 
            "beta"
        ]:
            raise NotImplementedError(
                f"Unsupported `expansion_action_sampling_policy`, got `{expansion_action_sampling_policy}`. "
            )
        
        self.expansion_action_sampling_policy = expansion_action_sampling_policy

        # ---------= [MDP] =---------
        self.ver = self.mdp.ver

        self.dtype = dtype
        if isinstance(self.dtype, str):
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = self.mdp.device
        
        self.reward_shape = self.mdp.reward_shape

        self.num_sample = len(init_state_list)

        self.root = self._get_root(
            init_state_list = init_state_list
        )
        
        # ---------= [Prepare `self.`] =---------
        self.mcts_loop_selected_node_tuple_list_list = [
            [] \
                for _ in range(self.num_sample)
        ]

        # ---------= [Prepare `self.history_best_merged_reward_list_list` & `best_final_reward_list_list`] =---------
        # init_reward_list.shape = (num_sample, 1)
        if self.ver == "torch":
            init_reward_list = torch.tensor(
                [float("-inf") for _ in range(self.num_sample)], 

                dtype = self.dtype, 
                device = self.device
            ).reshape(self.num_sample, *self.reward_shape)
        elif self.ver == "numpy":
            init_reward_list = np.array(
                [float("-inf") for _ in range(self.num_sample)], 

                dtype = self.dtype
            ).reshape(self.num_sample, *self.reward_shape)
        
        # history_best_merged_reward_list_list.shape = (num_sample, num_reward, 1)
        self.history_best_merged_reward_list_list = [
            [init_reward_list[sample_idx]] \
                for sample_idx in range(self.num_sample)
        ]

        # history_last_final_reward_list_list.shape = (num_sample, num_reward, 1)
        self.history_last_final_reward_list_list = [
            [init_reward_list[sample_idx]] \
                for sample_idx in range(self.num_sample)
        ]

        # ---------= [MCTS Loop Index List] =---------
        self._mcts_loop_idx_list = [
            0 \
                for _ in range(self.num_sample)
        ]

        # ---------= [Prepare `self.best_trajectory`] =---------
        self.best_trajectory_list = [
            Trajectory(
                reward_shape = self.reward_shape, 

                dtype = self.dtype, 
                device = self.device, 

                ver = self.ver
            ) \
                for sample_idx in range(self.num_sample)
        ]

        # ---------= [UCB] =---------
        self._exploration_coef = exploration_coef
        # self._depth_coef = depth_coef
        # self._exclude_last_intermediate_reward = exclude_last_intermediate_reward
        
        # ---------= [Selection Policy] =---------
        _selection_depth_lim = max(self.mdp.time_horizon - 1, 0)

        if selection_depth_lim is None:
            logger(
                f"`selection_depth_lim` is not provided, "
                f"set default to `max(time_horizon - 1, 0)` ({_selection_depth_lim}). "
            )

            selection_depth_lim = _selection_depth_lim

        if selection_depth_lim > _selection_depth_lim:
            raise ValueError(
                f"`selection_depth_lim` is larger than the maximum allowed value "
                f"`max(time_horizon - 1, 0)` ({_selection_depth_lim}). "
            )
        
        self._selection_depth_lim = selection_depth_lim
        
        # ---------= [Expansion Policy 2] =---------
        # self.expansion_enable_importance_sampling = expansion_enable_importance_sampling
        # self.expansion_importance_sampling_J_star_scaling_factor = expansion_importance_sampling_J_star_scaling_factor
        # self.expansion_importance_sampling_eps = expansion_importance_sampling_eps
        # self.expansion_importance_sampling_verbose = expansion_importance_sampling_verbose
        # self.per_iteration_expansion_lim = per_iteration_expansion_lim

        # ---------= [Simulation Policy] =---------
        self.simulation_action_sampling_policy = simulation_action_sampling_policy
        self.simulation_default_action_list = self._prepare_default_action_list(
            default_action_list = simulation_default_action_list
        )

        # ---------= [Timer] =---------
        self._time_st = None
        self._mcts_loop_wall_time_cost_list = []
        self._wall_clock_time_cost_list = [0] * self.num_sample  # unit: second

        # `__init__()` done
        pass


    def _get_root(
        self, 

        init_state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ]
    ) -> MCTSNode:
        """
        Func:
            Get the root node. 

        Ret:
            `root` (`MCTSNode`): The root node. 
        """

        sample_idx_list = list(
            range(self.num_sample)
        )

        root = self._get_new_node(
            state_list = init_state_list, 
            sample_idx_list = sample_idx_list
        )

        # for sample_idx in sample_idx_list:
        #     root.info_list[sample_idx].potential = 0
            
        #     # goto `for sample_idx`
        #     pass

        # `_get_root()` done
        return root

    
    def _cal_node_reward(
        self, 

        node: MCTSNode, 

        parent: Optional[MCTSNode] = None, 

        sample_idx_list: List[int] = None, 

        timestep_idx_list: List[int] = None
    ):
        """
        Func: 
            Compute the `reward_sum_to_root` and `final_reward` (if terminal) for `node`. 
        """

        num_sample = len(sample_idx_list)

        if self._mdp_modeling == "max_reward":
            init_merged_reward = float("-inf")
        elif self._mdp_modeling in [
            "sparse_reward", 
            "cumulative_reward"
        ]:
            init_merged_reward = 0.0
        
        # merged_reward_to_root_list.shape = (num_sample, )
        merged_reward_to_root_list = torch.tensor(
            [init_merged_reward] * num_sample, 

            dtype = self.dtype, 
            device = "cpu"
        )

        if parent:
            for i, sample_idx in enumerate(sample_idx_list):
                parent_info = parent.info_list[sample_idx]

                if parent_info is not None:
                    merged_reward_to_root_list[i] = parent_info.merged_reward_to_root

                # goto `for i, sample_idx`
                pass

        depth = node.depth

        time_horizon = self.mdp.time_horizon

        pseudo_final_latent_list = None
        
        if 0 < depth < time_horizon:
            # intermediate_reward_list.shape = (num_sample, )
            # pseudo_final_latent_list.shape = (num_sample, *self.state_shape)
            # potential_list.shape = (num_sample, )
            (
                intermediate_reward_list, 
                pseudo_final_latent_list
                # potential_list
            ) = self._cal_intermediate_reward(
                node = node, 

                sample_idx_list = sample_idx_list, 
                
                timestep_idx_list = timestep_idx_list
            )
        
            # for sample_idx_idx, sample_idx in enumerate(sample_idx_list):
            #     node.info_list[sample_idx].potential = potential_list[sample_idx_idx]

            #     # goto `for sample_idx`
            #     pass
            
            if not isinstance(intermediate_reward_list, torch.Tensor):
                intermediate_reward_list = torch.vstack(intermediate_reward_list)
            
            pseudo_final_latent_list_is_none = False
            if (pseudo_final_latent_list is None) \
                or (pseudo_final_latent_list == [None] * num_sample):
                pseudo_final_latent_list_is_none = True

            if (not pseudo_final_latent_list_is_none) \
                and (not isinstance(pseudo_final_latent_list, torch.Tensor)):
                
                if pseudo_final_latent_list[0].ndim == 3:
                    pseudo_final_latent_list = torch.stack(pseudo_final_latent_list)
                elif pseudo_final_latent_list[0].ndim == 4:
                    pseudo_final_latent_list = torch.vstack(pseudo_final_latent_list)

            intermediate_reward_list = torch.tensor(
                [
                    intermediate_reward.item() \
                        for intermediate_reward in intermediate_reward_list
                ], 

                dtype = self.dtype, 
                device = "cpu"
            )

            if self._mdp_modeling == "max_reward":
                merged_reward_to_root_list = torch.max(merged_reward_to_root_list, intermediate_reward_list)
            elif self._mdp_modeling in [
                "sparse_reward", 
                "cumulative_reward"
            ]:
                merged_reward_to_root_list = merged_reward_to_root_list + intermediate_reward_list

            # if self._mdp_modeling == "max_reward":
            if self._pseudo_latent_as_final and (depth >= self._enable_pseudo_latent_as_final_depth):
                # ---------= [Update Best Trajectory] =---------
                for (intermediate_reward, merged_reward_to_root, pseudo_final_latent, sample_idx) in zip(
                    intermediate_reward_list, 
                    merged_reward_to_root_list, 
                    pseudo_final_latent_list, 
                    sample_idx_list
                ):
                    merged_reward = merged_reward_to_root
                    
                    best_trajectory = self.best_trajectory_list[sample_idx]

                    if self._get_is_cost_legal(sample_idx = sample_idx) \
                        and (merged_reward > best_trajectory.merged_reward):
                        
                        self.history_last_final_reward_list_list[sample_idx].append(intermediate_reward)
                        self.history_best_merged_reward_list_list[sample_idx].append(merged_reward)

                        best_trajectory.final_reward = intermediate_reward
                        best_trajectory.merged_reward = merged_reward
                        best_trajectory.incomplete = True

                        node.info_list[sample_idx].final_reward = intermediate_reward
                        
                        (
                            state_list, 
                            action_list, 
                            reward_list
                        ) = best_trajectory.get_trajectory_to_root(
                            node = node, 
                            sample_idx = sample_idx, 

                            include_final_reward = True
                        )

                        state_list.append(pseudo_final_latent)

                        best_trajectory.state_list = state_list
                        best_trajectory.action_list = action_list
                        best_trajectory.reward_list = reward_list

                        # ---------= [Clean Up] =---------
                        trash_best_trajectory = self.best_trajectory_list[sample_idx]
                        self.best_trajectory_list[sample_idx] = None
                        del trash_best_trajectory
                        gc.collect()
                        if self.ver == "torch":
                            torch.cuda.empty_cache()

                        # ---------= [Update] =---------
                        self.best_trajectory_list[sample_idx] = best_trajectory

                        timestep_idx = node.depth

                        self._best_trajectory_updated_callback(
                            caller = "_cal_node_reward()", 
                            pseudo = True, 

                            node = node, 

                            sample_idx = sample_idx, 

                            timestep_idx = timestep_idx
                        )

                # goto `for (merged_reward_to_root, sample_idx)`
                pass
            
            if (pseudo_final_latent_list is not None) \
                and (not pseudo_final_latent_list_is_none):

                for i, sample_idx in enumerate(sample_idx_list):
                    node.info_list[sample_idx].pseudo_final_latent = pseudo_final_latent_list[i].cpu()

                    # goto `for i, sample_idx`
                    pass

            # ---------= [Clean Up] =---------
            del intermediate_reward_list
            if pseudo_final_latent_list is not None:
                del pseudo_final_latent_list
            # if potential_list is not None:
            #     del potential_list
            gc.collect()
            if self.ver == "torch":
                torch.cuda.empty_cache()
        
        for i, sample_idx in enumerate(sample_idx_list):
            node.info_list[sample_idx].merged_reward_to_root = merged_reward_to_root_list[i].item()

            # goto `for i, sample_idx`
            pass

        if depth == time_horizon:
            final_reward_list = self._cal_final_reward(
                node = node, 

                sample_idx_list = sample_idx_list
            )

            final_reward_list = torch.tensor(
                [
                    final_reward.item() \
                        for final_reward in final_reward_list
                ], 

                dtype = self.dtype, 
                device = "cpu"
            )

            for i, (sample_idx, final_reward) in enumerate(
                zip(
                    sample_idx_list, 
                    final_reward_list
                )
            ):
                node.info_list[sample_idx].final_reward = final_reward

                # goto `for i, (sample_idx, final_reward)`
                pass

            # if self._mdp_modeling == "max_reward":
            #     merged_reward_to_root_list = torch.max(merged_reward_to_root_list, final_reward_list)
            # elif self._mdp_modeling in [
            #     "sparse_reward", 
            #     "cumulative_reward"
            # ]:
            #     merged_reward_to_root_list = merged_reward_to_root_list + final_reward_list

            # for i, (sample_idx, final_reward, merged_reward_to_root) in enumerate(
            #     zip(
            #         sample_idx_list, 
            #         final_reward_list, 
            #         merged_reward_to_root_list
            #     )
            # ):
            #     node.info_list[sample_idx].final_reward = final_reward
            #     # node.info_list[sample_idx].merged_reward_to_root = merged_reward_to_root

            #     # goto `for i, sample_idx`
            #     pass
            
            # ---------= [Clean Up] =---------
            del final_reward_list
            gc.collect()
            if self.ver == "torch":
                torch.cuda.empty_cache()

        # `_cal_node_reward()` done
        pass


    def _get_new_node(
        self, 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        parent: MCTSNode = None, 

        # timestep_idx_list: List[int] = None
    ) -> MCTSNode:
        """
        Func:
            Create a new `MCTSNode` for `sample_idx_list` from a private copy of `state_list`. 
            A Python list of states is first stacked into one array/tensor for the active backend 
            (`self.ver`). No reward is computed here; `_cal_node_reward()` is the method that writes 
            `merged_reward_to_root` / `final_reward`. 

        Args:
            `sample_idx_list` (`List[int]`): The samples this node belongs to. 
            `state_list`: One state per sample, as a batch or a Python list. 
            `prev_action_list`: The actions leading to this node (passed through unchanged). 
            `parent` (`MCTSNode`, optional): The parent node (passed through unchanged). 

        Ret:
            `new_node` (`MCTSNode`): The created node. 
        """

        backend = self.ver

        # Copy so that the node never aliases the caller's buffer. 
        if backend == "torch":
            batched_states = torch.stack(state_list) if isinstance(state_list, list) else state_list
            owned_states = batched_states.clone()
        elif backend == "numpy":
            batched_states = np.stack(state_list) if isinstance(state_list, list) else state_list
            owned_states = batched_states.copy()
        else:
            owned_states = state_list

        return MCTSNode(
            mcts_instance = self, 

            sample_idx_list = sample_idx_list, 

            state_list = owned_states, 

            prev_action_list = prev_action_list, 

            parent = parent, 

            device = self.device
        )


    def _prepare_default_action_list(
        self, 

        default_action_list: Union[
            Union[torch.Tensor, np.ndarray], 
            Union[List[torch.Tensor], List[np.ndarray]]
        ] = None
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Normalise `default_action_list` to one default action per timestep. 
            If `default_action_list = None`, `self.mdp.get_default_action()` is used instead. 
            The value is converted to the active backend (`self.ver`), passed through the backend's 
            `tsfm_to_1d_array(target_length = self.mdp.time_horizon)`, and finally reshaped to 
            `(time_horizon, *action_space.shape)`. 

        Args:
            `default_action_list` (optional): A single default action or a per-timestep list of them. 

        Ret:
            `default_action_list` (`torch.Tensor` or `np.ndarray`): 
                default_action_list.shape = (time_horizon, *self.mdp.action_space.shape). 
        """
        
        if default_action_list is None:
            default_action_list = self.mdp.get_default_action()

        if self.ver == "torch":
            from util.torch_util import tsfm_to_1d_array
            
            if not isinstance(default_action_list, torch.Tensor):
                default_action_list = torch.tensor(
                    default_action_list, 

                    dtype = self.dtype, 
                    device = self.device
                )

            default_action_list = tsfm_to_1d_array(
                array = default_action_list, 
                target_length = self.mdp.time_horizon, 

                dtype = self.dtype, 
                device = self.device
            )
        elif self.ver == "numpy":
            from util.numpy_util import tsfm_to_1d_array

            # `self.device` is deliberately not forwarded: `np.asarray` only accepts a `device` 
            # keyword on NumPy >= 2.0, and even there only `"cpu"`. 
            if not isinstance(default_action_list, np.ndarray):
                default_action_list = np.asarray(
                    default_action_list, 

                    dtype = self.dtype
                )
        
            default_action_list = tsfm_to_1d_array(
                array = default_action_list, 
                target_length = self.mdp.time_horizon, 

                dtype = self.dtype
            )
        
        default_action_list = default_action_list.reshape(
            (self.mdp.time_horizon, *self.mdp.action_space.shape)
        )

        # `_prepare_default_action_list()` done
        return default_action_list


    def _cal_intermediate_reward(
        self, 

        node: MCTSNode, 

        sample_idx_list: Union[int, List[int]] = None, 

        timestep_idx_list: List[int] = None
    ) -> Union[
        Tuple[List[torch.Tensor], List[torch.Tensor]], 
        Tuple[List[np.ndarray], List[np.ndarray]]
    ]:
        """
        Func:
            Compute the intermediate reward list for transitioning from the parent node to `node`. 

        Ret:
            `intermediate_reward_list` (`List[torch.Tensor]` or `List[np.ndarray]`): 
                The list of the intermediate rewards. 
                (`sample_idx_list` is provided) intermediate_reward_list.shape = (len(sample_idx_list), *reward_shape). 
                (`sample_idx_list` is not provided) intermediate_reward_list.shape = (num_sample, *reward_shape). 
            
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 
        """

        if node.depth == 0:
            raise ValueError(
                f"Can not compute the intermediate reward for the root node. "
            )

        # ---------= [Prepare `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(
                range(self.num_sample)
            )
        
        if not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        # ---------= [Compute Intermediate Reward] =---------
        intermediate_reward_list = []
        pseudo_final_latent_list = []
        # potential_list = []

        need_cal_idx_list = []

        for i, sample_idx in enumerate(sample_idx_list):
            info = node.info_list[sample_idx]

            if info is not None:
                intermediate_reward = info.intermediate_reward
                pseudo_final_latent = info.pseudo_final_latent
                # potential = info.potential

                if intermediate_reward is not None:
                    intermediate_reward = torch.tensor(
                        intermediate_reward, 

                        dtype = self.dtype, 
                        device = self.device
                    )
                    intermediate_reward = intermediate_reward.reshape(self.reward_shape)

                    intermediate_reward_list.append(intermediate_reward)
                    pseudo_final_latent_list.append(pseudo_final_latent)

                    # potential_list.append(potential)
                else:
                    intermediate_reward_list.append(None)
                    pseudo_final_latent_list.append(None)

                    # potential_list.append(None)

                    need_cal_idx_list.append(i)
            else:
                intermediate_reward_list.append(None)
                pseudo_final_latent_list.append(None)
                # potential_list.append(None)

            # goto `for i, sample_idx`
            pass
        
        need_cal_sample_idx_list = None

        if len(need_cal_idx_list) > 0:
            need_cal_sample_idx_list = [
                sample_idx_list[i] \
                    for i in need_cal_idx_list
            ]

            # state_list.shape = (num_need_cal, *state_shape)
            state_list = [
                node.get_state(
                    sample_idx_list = sample_idx
                )[0] \
                    for sample_idx in need_cal_sample_idx_list
            ]
            state_list = torch.stack(state_list)

            # prev_action_list.shape = (num_need_cal, *action_shape)
            prev_action_list = [
                node.info_list[sample_idx].prev_action \
                    for sample_idx in need_cal_sample_idx_list
            ]

            # need_timestep_idx_list.shape = (num_need_cal, )
            need_timestep_idx_list = [
                timestep_idx_list[i] \
                    for i in need_cal_idx_list
            ]

            # prev_latent_list = None
            # prev_potential_list = None

            # if self.mdp.reward_model.reward_shaping_policy == "skipping":
            #     # prev_latent_list: `z_{t_{i - 1}}`
            #     # prev_latent_list.shape = (batch_size, 4, latent_height, latent_width)
            #     prev_latent_list = [
            #         node.parent.get_state(
            #             sample_idx_list = sample_idx
            #         )[0] \
            #             for sample_idx in sample_idx_list
            #     ]
            #     prev_latent_list = torch.stack(prev_latent_list)
            # elif self.mdp.reward_model.reward_shaping_policy == "potential_based":
            #     prev_potential_list = torch.tensor(
            #         [
            #             node.parent.info_list[sample_idx].potential \
            #                 for sample_idx in sample_idx_list
            #         ]
            #     )
            
            # need_cal_intermediate_reward_list.shape = (num_need_cal, )
            # need_cal_potential_list.shape = (num_need_cal, )
            # need_cal_pseudo_final_latent_list.shape = (num_need_cal, *state_shape)
            (
                need_cal_intermediate_reward_list, 
                # need_cal_potential_list
                need_cal_pseudo_final_latent_list
            ) = self.mdp.batch_cal_intermediate_reward(
                state_list = state_list, 
                prev_action_list = prev_action_list, 

                timestep_idx_list = need_timestep_idx_list, 

                # prev_latent_list = prev_latent_list, 
                
                # prev_potential_list = prev_potential_list, 

                sample_idx_list = need_cal_sample_idx_list
            )
            
            for i, (need_cal_idx, need_cal_sample_idx) in enumerate(
                zip(
                    need_cal_idx_list, 
                    need_cal_sample_idx_list
                )
            ):
                intermediate_reward = need_cal_intermediate_reward_list[i]
                
                pseudo_final_latent = None
                if need_cal_pseudo_final_latent_list is not None:
                    pseudo_final_latent = need_cal_pseudo_final_latent_list[i]

                intermediate_reward_list[need_cal_idx] = intermediate_reward
                pseudo_final_latent_list[need_cal_idx] = pseudo_final_latent

                node.info_list[need_cal_sample_idx].intermediate_reward = intermediate_reward
                node.info_list[need_cal_sample_idx].pseudo_final_latent = pseudo_final_latent

                # goto `for i, (need_cal_idx, need_cal_sample_idx)`
                pass
            
            # ---------= [Clean Up] =---------
            del need_cal_sample_idx_list
            del state_list, prev_action_list
            del need_cal_intermediate_reward_list
            if need_cal_pseudo_final_latent_list is not None:
                del need_cal_pseudo_final_latent_list
        
        # ---------= [Clean Up] =---------
        del need_cal_idx_list
        gc.collect()
        if self.ver == "torch":
            torch.cuda.empty_cache()

        # `_cal_intermediate_reward()` done
        return (
            intermediate_reward_list, 
            # potential_list
            pseudo_final_latent_list
        )


    def _cal_final_reward(
        self, 

        node: MCTSNode, 

        sample_idx_list: Union[int, List[int]] = None
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Look up, or compute and cache, the final reward of each requested sample at 
            the terminal node `node`. 

            - Samples whose `Info` already carries a `final_reward` are served from that 
              cached value. 
            - Samples that have an `Info` but no cached value are computed together in one 
              `self.mdp.batch_cal_final_reward()` call and written back to 
              `node.info_list`. 
            - Samples with no `Info` at `node` yield `None`. 

        Args:
            `node` (`MCTSNode`): A node whose depth equals `self.mdp.time_horizon`. 
            `sample_idx_list` (`int` or `List[int]`, optional): Samples to evaluate. 
                A single int is wrapped in a list. Defaults to every sample 
                `0 .. self.num_sample - 1`. 

        Ret:
            `final_reward_list` (`list`): One entry per requested sample, in request order. 
                - Cached entries are tensors reshaped to `self.reward_shape`. 
                - Freshly computed entries are the items returned by 
                  `self.mdp.batch_cal_final_reward()`. 
                - Samples without an `Info` are `None`. 

        Raises:
            `ValueError`: If `node.depth != self.mdp.time_horizon`. 
        """

        if node.depth != self.mdp.time_horizon:
            raise ValueError(
                f"Can not compute the final reward for a non-terminal node. "
            )

        # ---------= [Normalize `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(range(self.num_sample))
        elif not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        # ---------= [Serve Cached Rewards] =---------
        final_reward_list = [None] * len(sample_idx_list)

        # positions (into `sample_idx_list`) that have an `Info` but no cached reward yet
        missing_pos_list = []

        for pos, sample_idx in enumerate(sample_idx_list):
            info = node.info_list[sample_idx]
            if info is None:
                continue

            cached_reward = info.final_reward
            if cached_reward is None:
                missing_pos_list.append(pos)
                continue

            final_reward_list[pos] = torch.tensor(
                cached_reward, 

                dtype = self.dtype, 
                device = self.device
            ).reshape(self.reward_shape)

        # ---------= [Compute Missing Rewards In One Batch] =---------
        if missing_pos_list:
            missing_sample_idx_list = [
                sample_idx_list[pos] for pos in missing_pos_list
            ]

            # state_list.shape = (num_missing, *state_shape)
            state_list = torch.stack(
                [
                    node.get_state(sample_idx_list = sample_idx)[0]
                    for sample_idx in missing_sample_idx_list
                ]
            )

            computed_reward_list = self.mdp.batch_cal_final_reward(
                sample_idx_list = missing_sample_idx_list, 

                state_list = state_list
            )

            for k, (pos, sample_idx) in enumerate(
                zip(missing_pos_list, missing_sample_idx_list)
            ):
                reward = computed_reward_list[k]

                final_reward_list[pos] = reward

                # cache on the node so later calls skip the MDP evaluation
                node.info_list[sample_idx].final_reward = reward

            del missing_sample_idx_list, state_list, computed_reward_list

        # ---------= [Clean Up] =---------
        del missing_pos_list
        gc.collect()
        if self.ver == "torch":
            torch.cuda.empty_cache()

        # `_cal_final_reward()` done
        return final_reward_list


    def _cal_state_value(
        self, 

        node: MCTSNode, 

        sample_idx_list: Union[int, List[int]] = None, 

        display_cal_state_value: Optional[bool] = False
    ) -> Union[
        List[torch.Tensor], 
        List[np.ndarray]
    ]:
        """
        Func:
            Score `node` for MCTS selection, per requested sample, with an 
            upper-confidence-bound (UCB) style value: 

                state_value = 1.0 * info.value 
                              + self._exploration_coef * sqrt(log(N_parent + 1) / (n + 1))

            Here `n` is `info.num_vis` of `node` for that sample. `N_parent` is the 
            parent's `num_vis` for the same sample, or `0` when `node` has no parent, 
            which makes the root's exploration term `0`. 

        Args:
            `node` (`MCTSNode`): The node to score. 
            `sample_idx_list` (`int` or `List[int]`, optional): Samples to score. 
                A single int is wrapped in a list. Defaults to every sample. 
            `display_cal_state_value` (`bool`, optional): Log each value and its parts. 

        Ret:
            `state_value_list` (`List[torch.Tensor]` or `List[np.ndarray]`): 
                One entry per requested sample. 
                - `None` where `node` holds no `Info` for that sample. 
                - Otherwise the value cast to `self.dtype` (and moved to `self.device` 
                  for the torch backend). 
        """

        # ---------= [Normalize `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(range(self.num_sample))
        elif not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        # ---------= [Score Each Sample] =---------
        state_value_list = []

        for sample_idx in sample_idx_list:
            info = node.info_list[sample_idx]

            if info is None:
                state_value_list.append(None)
                continue

            parent_num_vis = (
                0 if node.parent is None
                else node.parent.info_list[sample_idx].num_vis
            )

            # exploration term: sqrt(log(N_parent + 1) / (n + 1))
            if self.ver == "torch":
                log_parent_vis = torch.log(
                    torch.tensor(
                        parent_num_vis + 1, 

                        dtype = self.dtype, 
                        device = self.device
                    )
                )
                exploration = torch.sqrt(log_parent_vis / (info.num_vis + 1))
            elif self.ver == "numpy":
                log_parent_vis = np.log(
                    np.array(
                        parent_num_vis + 1, 
                        dtype = self.dtype
                    )
                )
                exploration = np.sqrt(log_parent_vis / (info.num_vis + 1))

            # exploitation term, with its coefficient fixed at 1.0
            exploitation_coef = 1.0
            exploitation_part = exploitation_coef * info.value

            exploration_part = self._exploration_coef * exploration

            state_value = exploitation_part + exploration_part

            if display_cal_state_value:
                logger(
                    f"[Node {node.node_idx}, Sample {sample_idx}] state_value: {state_value}"
                )
                logger(
                    f"    exploitation_part: {exploitation_part}, "
                    f"exploration_part: {exploration_part}, "
                    f"num_vis: {info.num_vis}, "
                )

            # normalize dtype / device of the score
            if self.ver == "torch":
                state_value = state_value.to(
                    dtype = self.dtype, 
                    device = self.device
                )
            elif self.ver == "numpy":
                state_value = state_value.astype(self.dtype)

            state_value_list.append(state_value)

        # `_cal_state_value()` done
        return state_value_list
    

    def _select(
        self, 

        sample_idx_list: List[int], 

        display_cal_state_value: Optional[bool] = False, 

        display_state_value_list: Optional[bool] = False
    ) -> List[MCTSNode]:
        """
        Func:
            For every sample, pick the node to expand next. 

            1. Walk the tree from `self.root` in pre-order and collect candidates. 
               - A child is only entered when its `info_list[sample_idx]` is truthy. 
               - A node with `depth > self._selection_depth_lim - 1` is collected but 
                 its children are not visited. 
            2. Score every candidate with `_cal_state_value()`. 
            3. Keep the first candidate with the strictly highest score. A `None` score 
               counts as `-inf`. 

        Args:
            `sample_idx_list` (`List[int]`): Samples to select for; a single int is 
                wrapped in a list. 
            `display_cal_state_value` (`bool`, optional): Forwarded to 
                `_cal_state_value()`. 
            `display_state_value_list` (`bool`, optional): Log each candidate's 
                `node_idx` and score. 

        Ret:
            `selected_node_list` (`List[MCTSNode]`): One entry per sample. An entry is 
                `None` when no candidate scored above `-inf`. 
        """

        if not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        def collect_candidates(sample_idx: int) -> list:
            # Iterative pre-order traversal. Children are pushed in reverse so they are 
            # popped in `children_list` order, matching a recursive DFS. 
            candidates = []
            stack = [self.root]

            while stack:
                current = stack.pop()
                candidates.append(current)

                if current.depth > self._selection_depth_lim - 1:
                    continue

                reachable = [
                    child for child in current.children_list
                    if child.info_list[sample_idx]
                ]
                stack.extend(reversed(reachable))

            return candidates

        def select_for_sample(sample_idx: int):
            # ---------= [Candidates & Scores] =---------
            candidate_node_list = collect_candidates(sample_idx)

            state_value_list = [
                self._cal_state_value(
                    node = candidate, 

                    sample_idx_list = [sample_idx], 

                    display_cal_state_value = display_cal_state_value
                )[0]
                for candidate in candidate_node_list
            ]

            if display_state_value_list:
                node_idx_list = [
                    candidate.node_idx for candidate in candidate_node_list
                ]
                state_value_list_str = ", ".join(
                    f"{value:.4f}" for value in state_value_list
                )

                logger(f"node_idx_list: {node_idx_list}, state_value_list: {state_value_list_str}")

                del node_idx_list
                gc.collect()

            # ---------= [Arg-Max (First Wins On Ties)] =---------
            selected_node = None
            best_value = float("-inf")

            for candidate, value in zip(candidate_node_list, state_value_list):
                if value is None:
                    value = float("-inf")

                if value > best_value:
                    best_value = value
                    selected_node = candidate

            # ---------= [Clean Up] =---------
            del candidate_node_list, state_value_list
            gc.collect()
            if self.ver == "torch":
                torch.cuda.empty_cache()

            return selected_node

        # `_select()` done
        return [
            select_for_sample(sample_idx) for sample_idx in sample_idx_list
        ]


    def _best_trajectory_updated_callback(
        self, 

        caller: str, 
        pseudo: bool, 

        node: MCTSNode, 

        sample_idx: int, 

        timestep_idx: int
    ):
        """
        Func:
            Hook run right after the best trajectory of sample `sample_idx` was updated. 

            It logs the last two entries of the per-sample merged-reward and 
            final-reward histories. It then refreshes the trajectory's accumulated 
            reward list under `self._mdp_modeling`. 

        Args:
            `caller` (`str`): Name of the method that performed the update (logged). 
            `pseudo` (`bool`): Not used here; kept for interface compatibility. 
            `node` (`MCTSNode`): Node at which the update happened; its `node_idx` is 
                logged. 
            `sample_idx` (`int`): Sample whose best trajectory changed. 
            `timestep_idx` (`int`): Not used here; kept for interface compatibility. 

        Note:
            Both `self.history_best_merged_reward_list_list[sample_idx]` and 
            `self.history_last_final_reward_list_list[sample_idx]` must hold at least 
            two entries. The last two of each are read. 
        """

        def as_float(value) -> float:
            # Python floats (including subclasses such as `np.float64`) pass through. 
            # Tensors / arrays are unwrapped with `.item()`. Anything else (e.g. a plain 
            # `int` reward) goes through `float()` rather than crashing on `.item()`. 
            if isinstance(value, float):
                return value
            if hasattr(value, "item"):
                return value.item()
            return float(value)

        best_merged_reward_list = self.history_best_merged_reward_list_list[sample_idx]
        last_final_reward_list = self.history_last_final_reward_list_list[sample_idx]

        second_last_merged_reward = as_float(best_merged_reward_list[-2])
        last_merged_reward = as_float(best_merged_reward_list[-1])

        second_last_final_reward = as_float(last_final_reward_list[-2])
        last_final_reward = as_float(last_final_reward_list[-1])

        logger(
            f"[Sample {sample_idx}] Best trajectory updated by `{caller}` at node {node.node_idx} for sample {sample_idx}. "
        )
        logger(
            f"    merged_reward: [{second_last_merged_reward:.4f}] -> [{last_merged_reward:.4f}], "
            f"final_reward: [{second_last_final_reward:.4f}] -> [{last_final_reward:.4f}]"
        )

        self.best_trajectory_list[sample_idx].update_accumulated_reward_list(
            mdp_modeling = self._mdp_modeling
        )

        # `_best_trajectory_updated_callback()` done
        pass


    def _get_node_empty_child(
        self, 

        node: MCTSNode,

        sample_idx: int
    ) -> MCTSNode:
        """
        Func:
            Claim the first child of `node` (in `node.children_list` order) that has no 
            entry for sample `sample_idx`. 

            The claimed slot is set to the placeholder `-1`, so it counts as occupied 
            until the caller stores a real `Info` there. 

            After claiming, one left-to-right adjacent-swap pass reorders 
            `node.children_list` by occupancy. Occupancy is the number of non-`None` 
            slots across `self.num_sample` samples. The pass moves more-occupied 
            children toward the end. 

        Args:
            `node` (`MCTSNode`): The parent whose children are searched. 
            `sample_idx` (`int`): The sample that needs a free child slot. 

        Ret:
            `child_node` (`MCTSNode` or `None`): The claimed child, or `None` when every 
                child already holds an entry for `sample_idx` (or there are no children). 
        """

        children = node.children_list

        claimed = next(
            (child for child in children if child.info_list[sample_idx] is None), 
            None
        )

        if claimed is None:
            # `_get_node_empty_child()` done
            return None

        # placeholder that marks the slot as taken; the caller overwrites it
        claimed.info_list[sample_idx] = -1

        occupancy = [
            sum(
                child.info_list[k] is not None
                for k in range(self.num_sample)
            )
            for child in children
        ]

        # NOTE(review): this is a single bubble pass, so it only moves larger counts to 
        # the right. Children appended elsewhere with low occupancy (presumably new nodes 
        # created by `_expand()`) are never moved left, so the list is not guaranteed 
        # to end up sorted. Confirm whether a full sort was intended. 
        for k in range(len(children) - 1):
            if occupancy[k] > occupancy[k + 1]:
                occupancy[k], occupancy[k + 1] = occupancy[k + 1], occupancy[k]
                children[k], children[k + 1] = children[k + 1], children[k]

        # `_get_node_empty_child()` done
        return claimed


    def _expand(
        self, 

        node_list: List[MCTSNode], 

        sample_idx_list: List[int]
    ) -> Tuple[
        List[MCTSNode], 
        List[int]
    ]:
        """
        Func:
            Expand each `(node_list[i], sample_idx_list[i])` pair by one step. 

            1. Sample an expansion action for each pair. 
            2. Roll the MDP dynamics forward with `self.mdp.batch_cal_dynamics()`. 
            3. Store the next state in a child of `node_list[i]`. A child with a free 
               slot for that sample is reused; otherwise a new child is created. 
            4. Compute the expanded children's rewards with `_cal_node_reward()`. 
            5. For children at depth `self.mdp.time_horizon`, possibly update the 
               per-sample best trajectory. 

            Pairs are grouped by node index, so state lookup, action sampling and reward 
            computation are batched per node. 

        Args:
            `node_list` (`List[MCTSNode]`): Non-terminal nodes to expand; a node may 
                appear several times. 
            `sample_idx_list` (`List[int]`): Sample index paired with each node. 

        Ret:
            `expanded_node_list` (`List[MCTSNode]`): The expanded children, sorted by 
                `node_idx`. 
            `_sample_idx_list` (`List[int]`): The sample index paired with each 
                expanded child. 
                - Its order may differ from the provided `sample_idx_list`. 
                - Both lists are empty when `node_list` is empty. 

        Raises:
            `ValueError`: 
                - If any node is already at depth `self.mdp.time_horizon`. 
                - If a terminal child is reached while `self._mdp_modeling` is not one 
                  of `"max_reward"`, `"sparse_reward"`, `"cumulative_reward"`. 
        """

        num_node = len(node_list)

        # nothing to expand: avoid `max()` / `torch.stack()` on empty input below
        if num_node == 0:
            return ([], [])

        # ---------= [Sort By Node Index] =---------
        # Stable sort: entries sharing a node become one contiguous run, and their 
        # relative order is kept. 
        ord_list = sorted(
            range(num_node), 
            key = lambda ord_idx: node_list[ord_idx].node_idx
        )
        node_list = [node_list[ord_idx] for ord_idx in ord_list]
        sample_idx_list = [sample_idx_list[ord_idx] for ord_idx in ord_list]

        # timestep_idx_list.shape = (num_node, )
        timestep_idx_list = torch.tensor(
            [node.depth for node in node_list], 

            dtype = torch.int32, 
            device = "cpu"
        )

        # terminal node
        if max(timestep_idx_list) >= self.mdp.time_horizon:
            raise ValueError(
                f"Can not expand a terminal node. "
            )

        # ---------= [Current State] =---------
        # one batched `get_state()` call per run of identical nodes
        state_list = []
        for (run_start, run_end) in self._expand_iter_node_runs(node_list):
            state_list += node_list[run_start].get_state(
                sample_idx_list = sample_idx_list[run_start: run_end]
            )

        # state_list.shape = (num_node, *state_shape)
        state_list = torch.stack(state_list)

        # NOTE: an importance-sampling variant of the expansion (`cal_q` / 
        # `get_num_split`, duplicating states per node) is currently disabled; see the 
        # commented-out `..importance_sampling` import at the top of the file. 

        # ---------= [Expansion] =---------
        # one batched action-sampling call per run of identical nodes
        action_list = []
        for (run_start, run_end) in self._expand_iter_node_runs(node_list):
            action_list += self._batch_sample_expansion_action(
                node = node_list[run_start], 
                sample_idx_list = sample_idx_list[run_start: run_end]
            )

        # action_list.shape = (num_node, *action_shape)
        action_list = torch.stack(action_list)

        # next_state_list.shape = (num_node, *state_shape)
        next_state_list = self.mdp.batch_cal_dynamics(
            sample_idx_list = sample_idx_list, 

            state_list = state_list, 
            action_list = action_list, 

            timestep_idx_list = timestep_idx_list
        )

        # ---------= [Fill In] =---------
        expanded_node_list = []

        for i, (node, sample_idx) in enumerate(zip(node_list, sample_idx_list)):
            state = next_state_list[i]
            action = action_list[i]

            # reuse a child that has a free slot for this sample, if any
            expanded_node = self._get_node_empty_child(
                node = node, 
                sample_idx = sample_idx
            )

            if expanded_node is not None:
                expanded_node.info_list[sample_idx] = Info(
                    state = state, 
                    prev_action = action
                )
            else:
                expanded_node = self._get_new_node(
                    sample_idx_list = [sample_idx], 

                    state_list = [state], 
                    prev_action_list = [action], 

                    parent = node
                )

                node.children_list.append(expanded_node)

            expanded_node_list.append(expanded_node)

        # ---------= [Sort By Expanded Node Index] =---------
        num_expanded_node = len(expanded_node_list)

        ord_list = sorted(
            range(num_expanded_node), 
            key = lambda ord_idx: expanded_node_list[ord_idx].node_idx
        )
        node_list = [node_list[ord_idx] for ord_idx in ord_list]
        expanded_node_list = [expanded_node_list[ord_idx] for ord_idx in ord_list]
        sample_idx_list = [sample_idx_list[ord_idx] for ord_idx in ord_list]

        # ---------= [Cal Node Reward] =---------
        for (run_start, run_end) in self._expand_iter_node_runs(expanded_node_list):
            # All entries of a run share one expanded child, and therefore one parent 
            # (assuming `node_idx` is unique per node). 
            parent = node_list[run_start]
            expanded_node = expanded_node_list[run_start]
            expanded_node_sample_idx_list = sample_idx_list[run_start: run_end]

            expanded_node_timestep_idx_list = torch.tensor(
                [expanded_node.depth] * (run_end - run_start), 

                dtype = torch.int32, 
                device = "cpu"
            )

            self._cal_node_reward(
                node = expanded_node, 

                parent = parent, 

                sample_idx_list = expanded_node_sample_idx_list,

                timestep_idx_list = expanded_node_timestep_idx_list
            )

            # ---------= [Update Best Trajectory] =---------
            timestep_idx = expanded_node.depth

            if timestep_idx == self.mdp.time_horizon:
                self._expand_update_best_trajectory(
                    expanded_node = expanded_node, 

                    sample_idx_list = expanded_node_sample_idx_list, 

                    timestep_idx = timestep_idx
                )

            # ---------= [Clean Up] =---------
            del expanded_node_sample_idx_list
            del expanded_node_timestep_idx_list
            gc.collect()
            if self.ver == "torch":
                torch.cuda.empty_cache()

        # ---------= [Clean Up] =---------
        del state_list, action_list, next_state_list
        del ord_list
        gc.collect()
        if self.ver == "torch":
            torch.cuda.empty_cache()

        # `_expand()` done
        return (
            expanded_node_list, 
            sample_idx_list
        )


    @staticmethod
    def _expand_iter_node_runs(
        node_list: List[MCTSNode]
    ):
        """
        Func:
            Helper of `_expand()`. Yields half-open `(run_start, run_end)` index pairs 
            that cover `node_list`. Each run is a maximal stretch of consecutive entries 
            with equal `node_idx`. 
        """

        num_node = len(node_list)
        run_start = 0

        while run_start < num_node:
            run_end = run_start + 1

            while (run_end < num_node) \
                and (node_list[run_end].node_idx == node_list[run_start].node_idx):
                run_end += 1

            yield (run_start, run_end)

            run_start = run_end


    def _expand_update_best_trajectory(
        self, 

        expanded_node: MCTSNode, 

        sample_idx_list: List[int], 

        timestep_idx: int
    ):
        """
        Func:
            Helper of `_expand()`, called for samples that just reached the terminal 
            node `expanded_node`. 

            For each sample it merges the reward-to-root with the final reward. A new 
            best trajectory (root -> `expanded_node`) is recorded when both hold: 
            - the sample's cost is legal; 
            - the merged reward beats the stored best. 

        Raises:
            `ValueError`: If `self._mdp_modeling` is not a supported mode. 
        """

        # ---------= [Collect Final Rewards] =---------
        # tensors / arrays are unwrapped to Python scalars with `.item()`
        final_reward_list = []
        for sample_idx in sample_idx_list:
            final_reward = expanded_node.info_list[sample_idx].final_reward

            if isinstance(final_reward, (np.ndarray, torch.Tensor)):
                final_reward = final_reward.item()

            final_reward_list.append(final_reward)

        # ---------= [Update Best Trajectory] =---------
        for (final_reward, sample_idx) in zip(final_reward_list, sample_idx_list):
            merged_reward_to_root = expanded_node.info_list[sample_idx].merged_reward_to_root

            if self._mdp_modeling == "max_reward":
                merged_reward = max(merged_reward_to_root, final_reward)
            elif self._mdp_modeling in [
                "sparse_reward", 
                "cumulative_reward"
            ]:
                merged_reward = merged_reward_to_root + final_reward
            else:
                # fail loudly instead of leaving `merged_reward` unbound
                raise ValueError(
                    f"Unsupported `mdp_modeling`: {self._mdp_modeling}. "
                )

            best_trajectory = self.best_trajectory_list[sample_idx]

            if not (
                self._get_is_cost_legal(sample_idx = sample_idx) 
                and (merged_reward > best_trajectory.merged_reward)
            ):
                continue

            self.history_best_merged_reward_list_list[sample_idx].append(merged_reward)
            self.history_last_final_reward_list_list[sample_idx].append(final_reward)

            best_trajectory.final_reward = final_reward
            best_trajectory.merged_reward = merged_reward
            best_trajectory.incomplete = False

            (
                trajectory_state_list, 
                trajectory_action_list, 
                trajectory_reward_list
            ) = best_trajectory.get_trajectory_to_root(
                node = expanded_node, 
                sample_idx = sample_idx, 

                include_final_reward = True
            )

            best_trajectory.state_list = trajectory_state_list
            best_trajectory.action_list = trajectory_action_list
            best_trajectory.reward_list = trajectory_reward_list

            # `best_trajectory` is the object stored in `self.best_trajectory_list` and 
            # was updated in place, so it does not need to be re-assigned. 

            self._best_trajectory_updated_callback(
                caller = "_expand()", 
                pseudo = False, 

                node = expanded_node, 

                sample_idx = sample_idx, 

                timestep_idx = timestep_idx
            )

        del final_reward_list
    

    def _simulate(
        self, 

        node_list: List[MCTSNode], 
        sample_idx_list: List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Simulate (roll out) a trajectory from each expanded node to a terminal 
            node, i.e. until the timestep index reaches `self.mdp.time_horizon`. 
            Actions come from `self._batch_sample_simulation_action()` and states 
            from `self.mdp.batch_cal_dynamics()`. All unfinished rollouts are 
            advanced together as one batch; a rollout leaves the batch as soon as 
            it reaches the time horizon. 

            When a rollout finishes and `self._get_is_cost_legal()` allows it, 
            its total merged reward (root -> node -> leaf) is compared with the 
            stored best trajectory of its sample. If it is larger, the best 
            trajectory is updated in place and 
            `self._best_trajectory_updated_callback()` is called. 

        Args:
            `node_list` (`List[MCTSNode]`): 
                The expanded nodes to simulate from. 
            `sample_idx_list` (`List[int]`, `np.ndarray` or `torch.Tensor`): 
                `sample_idx_list[i]` is the sample index of `node_list[i]`. 

        Ret: 
            `merged_reward_to_leaf_list` (`torch.Tensor` or `np.ndarray`): 
                Shape `(num_node, *self.reward_shape)`. For a node already at the 
                time horizon this is its stored `merged_reward_to_root`. For any 
                other node it is the merge of the rewards collected from the node 
                to the leaf: the maximum for `"max_reward"`, the sum for 
                `"sparse_reward"` / `"cumulative_reward"`. 

        Raises:
            `ValueError`: a rollout has to be merged and `self._mdp_modeling` is 
                not one of the modes listed above. 
        """

        time_horizon = self.mdp.time_horizon

        num_node = len(node_list)

        if not isinstance(sample_idx_list, (np.ndarray, torch.Tensor)):
            if self.ver == "torch":
                sample_idx_list = torch.tensor(
                    sample_idx_list, 
                    
                    dtype = torch.int32, 
                    device = "cpu"
                )
            elif self.ver == "numpy":
                sample_idx_list = np.array(
                    sample_idx_list, 

                    dtype = np.int32
                )

        # ---------= [Prepare Current Timestep List] =---------
        # The node depth is used as the current timestep index. 
        cur_timestep_idx_list = [
            node.depth \
                for node in node_list
        ]

        # ---------= [Prepare Everything] =---------
        # One entry per node. Entries are converted to `self.reward_shape` 
        # tensors/arrays and stacked at the end. 
        merged_reward_to_leaf_list = [None] * num_node
        
        # _need_cal_sample_idx_list.shape = (num_need_cal, )
        _need_cal_sample_idx_list = []
        # need_cal_idx_to_abs_idx_list.shape = (num_need_cal, ): 
        # maps a "need calculation" index back to its index in `node_list`
        need_cal_idx_to_abs_idx_list = []

        for i, sample_idx in enumerate(sample_idx_list):
            if cur_timestep_idx_list[i] < time_horizon:
                _need_cal_sample_idx_list.append(sample_idx)
                need_cal_idx_to_abs_idx_list.append(i)
            else:
                # Already terminal: nothing to simulate, so report the reward 
                # merged on the path from the root. Converted to the backend 
                # type (torch or numpy) when stacking. 
                merged_reward_to_leaf_list[i] \
                    = node_list[i].info_list[sample_idx].merged_reward_to_root

        if self.ver == "torch":
            _need_cal_sample_idx_list = torch.tensor(
                _need_cal_sample_idx_list, 
                
                dtype = torch.int32, 
                device = self.device
            )
        elif self.ver == "numpy":
            # `np.asarray`, NOT `np.ndarray`: the latter interprets its first 
            # argument as a shape and returns uninitialized memory. 
            _need_cal_sample_idx_list = np.asarray(
                _need_cal_sample_idx_list, 

                dtype = np.int32
            )

        if len(_need_cal_sample_idx_list) <= 0:
            # Every node is already terminal. Still return the documented 
            # stacked `(num_node, *self.reward_shape)` result. 
            return self._simulate_stack_merged_reward_list(
                merged_reward_list = merged_reward_to_leaf_list
            )

        num_need_cal = len(need_cal_idx_to_abs_idx_list)

        # Indices (into the need-cal arrays) of rollouts that are still running. 
        # Finished rollouts are removed progressively. 
        need_cal_idx_list = list(
            range(num_need_cal)
        )
        
        # last_state_list.shape = (num_need_cal, *self.state_shape)
        last_state_list = [
            node_list[need_cal_idx_to_abs_idx].get_state(
                sample_idx_list = need_cal_sample_idx
            )[0] \
                for need_cal_idx_to_abs_idx, need_cal_sample_idx in zip(
                    need_cal_idx_to_abs_idx_list, 
                    _need_cal_sample_idx_list
                )
        ]

        # Identity check. `None in [...]` would compare each state with `==`, 
        # which raises for multi-element numpy arrays. 
        assert all(
            last_state is not None \
                for last_state in last_state_list
        )

        if self.ver == "torch":
            last_state_list = torch.stack(
                last_state_list, 
                dim = 0
            )
        elif self.ver == "numpy":
            last_state_list = np.stack(
                last_state_list, 
                axis = 0
            )

        # cur_timestep_idx_list.shape = (num_need_cal, )
        cur_timestep_idx_list = [
            cur_timestep_idx_list[need_cal_idx_to_abs_idx] \
                for need_cal_idx_to_abs_idx in need_cal_idx_to_abs_idx_list
        ]

        if self.ver == "torch":
            from util.torch_util import tsfm_to_1d_array

            cur_timestep_idx_list = tsfm_to_1d_array(
                array = cur_timestep_idx_list, 
                target_length = num_need_cal, 

                dtype = torch.int32, 
                device = "cpu"
            )
        elif self.ver == "numpy":
            from util.numpy_util import tsfm_to_1d_array

            cur_timestep_idx_list = tsfm_to_1d_array(
                array = cur_timestep_idx_list, 
                target_length = num_need_cal, 

                dtype = np.int32
            )

        # Per-rollout (node -> leaf) histories, indexed by need-cal index
        action_list_to_leaf_list_list = [
            [] \
                for _ in range(num_need_cal)
        ]
        reward_list_to_leaf_list_list = [
            [] \
                for _ in range(num_need_cal)
        ]
        state_list_to_leaf_list_list = [
            [] \
                for _ in range(num_need_cal)
        ]

        last_min_timestep_idx = 0

        # ---------= [Do Simulation] =---------
        with tqdm(
            total = self.mdp.time_horizon, 
            desc = "[Simulation]"
        ) as bar:
            while len(need_cal_idx_list) > 0:
                # need_cal_sample_idx_list.shape = (num_running, )
                need_cal_sample_idx_list = _need_cal_sample_idx_list[need_cal_idx_list]

                # need_cal_last_state_list.shape = (num_running, *self.state_shape)
                need_cal_last_state_list = last_state_list[need_cal_idx_list]

                # need_cal_timestep_idx_list.shape = (num_running, )
                need_cal_timestep_idx_list = cur_timestep_idx_list[need_cal_idx_list]

                # need_cal_action_list.shape = (num_running, *action_shape)
                need_cal_action_list = self._batch_sample_simulation_action(
                    timestep_idx_list = need_cal_timestep_idx_list
                )

                # need_cal_cur_state_list.shape = (num_running, *state_shape)
                need_cal_cur_state_list = self.mdp.batch_cal_dynamics(
                    sample_idx_list = need_cal_sample_idx_list, 

                    state_list = need_cal_last_state_list, 
                    action_list = need_cal_action_list, 

                    timestep_idx_list = need_cal_timestep_idx_list
                )

                # need_cal_next_timestep_idx_list.shape = (num_running, )
                need_cal_next_timestep_idx_list = need_cal_timestep_idx_list + 1

                # Split the running rollouts by whether the step just taken 
                # reached the time horizon. The `*_idx_list`s index into this 
                # iteration's `need_cal_*` arrays. 
                non_terminal_idx_list = []
                non_terminal_sample_idx_list = []
                terminal_idx_list = []
                terminal_sample_idx_list = []
                
                for i, (sample_idx, next_timestep_idx) in enumerate(
                    zip(
                        need_cal_sample_idx_list, 
                        need_cal_next_timestep_idx_list
                    )
                ):
                    if next_timestep_idx < time_horizon:
                        non_terminal_idx_list.append(i)
                        non_terminal_sample_idx_list.append(sample_idx)
                    else:
                        terminal_idx_list.append(i)
                        terminal_sample_idx_list.append(sample_idx)

                # ---------= [Cal Intermediate Reward] =---------
                if len(non_terminal_idx_list) > 0:
                    non_terminal_cur_state_list = need_cal_cur_state_list[non_terminal_idx_list]
                    non_terminal_action_list = need_cal_action_list[non_terminal_idx_list]
                    non_terminal_timestep_idx_list = need_cal_next_timestep_idx_list[non_terminal_idx_list]
                    non_terminal_last_state_list = need_cal_last_state_list[non_terminal_idx_list]
                    
                    (
                        # non_terminal_intermediate_reward_list.shape = (num_non_terminal, )
                        non_terminal_intermediate_reward_list, 
                        # non_terminal_pseudo_final_latent_list.shape = (num_non_terminal, *self.state_shape)
                        non_terminal_pseudo_final_latent_list
                    ) = self.mdp.batch_cal_intermediate_reward(
                        state_list = non_terminal_cur_state_list, 
                        prev_action_list = non_terminal_action_list, 

                        timestep_idx_list = non_terminal_timestep_idx_list, 

                        prev_latent_list = non_terminal_last_state_list, 

                        sample_idx_list = non_terminal_sample_idx_list
                    )

                    for non_terminal_idx, state, action, reward in zip(
                        non_terminal_idx_list, 

                        non_terminal_cur_state_list, 
                        non_terminal_action_list, 
                        non_terminal_intermediate_reward_list
                    ):
                        need_cal_idx = need_cal_idx_list[non_terminal_idx]

                        state_list_to_leaf_list_list[need_cal_idx].append(state)
                        action_list_to_leaf_list_list[need_cal_idx].append(action)
                        reward_list_to_leaf_list_list[need_cal_idx].append(reward)

                    # ---------= [Clean Up] =---------
                    del non_terminal_idx_list, non_terminal_sample_idx_list
                    del non_terminal_cur_state_list
                    del non_terminal_action_list
                    del non_terminal_timestep_idx_list
                    del non_terminal_last_state_list
                    del non_terminal_intermediate_reward_list
                    del non_terminal_pseudo_final_latent_list
                    gc.collect()
                    if self.ver == "torch":
                        torch.cuda.empty_cache()

                # ---------= [Cal Final Reward] =---------
                if len(terminal_idx_list) > 0:
                    terminal_state_list = need_cal_cur_state_list[terminal_idx_list]
                    terminal_action_list = need_cal_action_list[terminal_idx_list]
                    
                    # terminal_final_reward_list.shape = (num_terminal_node, *reward_shape)
                    terminal_final_reward_list = self.mdp.batch_cal_final_reward(
                        sample_idx_list = terminal_sample_idx_list, 

                        state_list = terminal_state_list
                    )

                    for terminal_idx, state, action, reward in zip(
                        terminal_idx_list, 

                        terminal_state_list, 
                        terminal_action_list, 
                        terminal_final_reward_list
                    ):
                        need_cal_idx = need_cal_idx_list[terminal_idx]

                        state_list_to_leaf_list_list[need_cal_idx].append(state)
                        action_list_to_leaf_list_list[need_cal_idx].append(action)
                        reward_list_to_leaf_list_list[need_cal_idx].append(reward)
                    
                    # ---------= [Update Best Trajectory] =---------
                    for terminal_idx, sample_idx, final_reward in zip(
                        terminal_idx_list, 
                        terminal_sample_idx_list, 
                        terminal_final_reward_list
                    ):
                        need_cal_idx = need_cal_idx_list[terminal_idx]
                        abs_idx = need_cal_idx_to_abs_idx_list[need_cal_idx]
                        node = node_list[abs_idx]

                        # Merge every reward of the rollout (intermediate ones 
                        # plus the final one). This raises for an unsupported 
                        # `self._mdp_modeling`. 
                        merged_reward_to_leaf = self._simulate_merge_reward_list(
                            reward_list = reward_list_to_leaf_list_list[need_cal_idx]
                        )

                        merged_reward_to_leaf_list[abs_idx] = merged_reward_to_leaf

                        merged_reward_to_root = node.info_list[sample_idx].merged_reward_to_root

                        # `self._mdp_modeling` has been validated by the call above
                        if self._mdp_modeling == "max_reward":
                            merged_reward = max(merged_reward_to_root, merged_reward_to_leaf)
                        else:
                            merged_reward = merged_reward_to_root + merged_reward_to_leaf
                        
                        if isinstance(merged_reward, torch.Tensor):
                            merged_reward = merged_reward.cpu()

                        best_trajectory = self.best_trajectory_list[sample_idx]
                        
                        if self._get_is_cost_legal(sample_idx = sample_idx) \
                            and (merged_reward > best_trajectory.merged_reward):

                            self.history_best_merged_reward_list_list[sample_idx].append(merged_reward)
                            self.history_last_final_reward_list_list[sample_idx].append(final_reward)

                            best_trajectory.merged_reward = merged_reward
                            best_trajectory.final_reward = final_reward
                            best_trajectory.incomplete = False
                            
                            (
                                state_list, 
                                action_list, 
                                reward_list
                            ) = best_trajectory.get_trajectory_to_root(
                                node = node, 
                                
                                sample_idx = sample_idx, 

                                include_final_reward = False
                            )
                            
                            # Full trajectory = (root -> node) + (node -> leaf)
                            state_list += state_list_to_leaf_list_list[need_cal_idx]
                            action_list += action_list_to_leaf_list_list[need_cal_idx]
                            reward_list += reward_list_to_leaf_list_list[need_cal_idx]
                            
                            best_trajectory.state_list = state_list
                            best_trajectory.action_list = action_list
                            best_trajectory.reward_list = reward_list

                            # `best_trajectory` is the object already stored at this 
                            # slot and was updated in place. There is nothing to free 
                            # here, so no per-sample `gc.collect()`. 
                            self.best_trajectory_list[sample_idx] = best_trajectory

                            timestep_idx = node.depth

                            self._best_trajectory_updated_callback(
                                caller = "_simulate()", 
                                pseudo = False, 

                                node = node, 

                                sample_idx = sample_idx, 

                                timestep_idx = timestep_idx
                            )
                
                    # ---------= [Clean Up] =---------
                    del terminal_idx_list, terminal_sample_idx_list
                    del terminal_state_list
                    del terminal_action_list
                    del terminal_final_reward_list
                    gc.collect()
                    if self.ver == "torch":
                        torch.cuda.empty_cache()

                # ---------= [Update Timestep] =---------
                min_timestep_idx = time_horizon

                finished_need_cal_idx_list = []

                for need_cal_idx in need_cal_idx_list:
                    cur_timestep_idx_list[need_cal_idx] += 1

                    cur_timestep_idx = cur_timestep_idx_list[need_cal_idx]
                    if not isinstance(cur_timestep_idx, int):
                        cur_timestep_idx = cur_timestep_idx.item()

                    min_timestep_idx = min(min_timestep_idx, cur_timestep_idx)

                    if cur_timestep_idx >= time_horizon:
                        finished_need_cal_idx_list.append(need_cal_idx)

                # ---------= [Update Tqdm] =---------
                # The bar tracks the slowest running rollout
                del_min_timestep_idx \
                    = min_timestep_idx - last_min_timestep_idx
                bar.update(del_min_timestep_idx)

                # ---------= [Update Last] =---------
                last_min_timestep_idx = min_timestep_idx

                for i, need_cal_idx in enumerate(need_cal_idx_list):
                    last_state_list[need_cal_idx] = need_cal_cur_state_list[i]

                for need_cal_idx in finished_need_cal_idx_list:
                    need_cal_idx_list.remove(need_cal_idx)

                # ---------= [Clean Up] =---------
                del need_cal_last_state_list, need_cal_timestep_idx_list, need_cal_action_list
                del need_cal_sample_idx_list, need_cal_cur_state_list, need_cal_next_timestep_idx_list
                del finished_need_cal_idx_list
                gc.collect()
                if self.ver == "torch":
                    torch.cuda.empty_cache()

            # goto `with tqdm`
            pass

        merged_reward_to_leaf_list = self._simulate_stack_merged_reward_list(
            merged_reward_list = merged_reward_to_leaf_list
        )

        # ---------= [Clean Up] =---------
        del cur_timestep_idx_list
        del last_state_list
        del action_list_to_leaf_list_list, reward_list_to_leaf_list_list, state_list_to_leaf_list_list
        gc.collect()
        if self.ver == "torch":
            torch.cuda.empty_cache()

        # `_simulate()` done
        return merged_reward_to_leaf_list


    def _simulate_merge_reward_list(
        self, 

        reward_list: List
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Merge the rewards collected along one rollout into a single scalar. 
            `"max_reward"` takes the maximum over all elements; 
            `"sparse_reward"` and `"cumulative_reward"` take the sum. 

        Args:
            `reward_list` (`list`): 
                Per-step rewards (tensors for the torch backend, arrays or 
                scalars for the numpy backend). 

        Ret: 
            The merged scalar reward (0-d tensor or numpy scalar). 

        Raises:
            `ValueError`: if `self._mdp_modeling` is not a supported mode. 
        """
        if self._mdp_modeling == "max_reward":
            use_max = True
        elif self._mdp_modeling in [
            "sparse_reward", 
            "cumulative_reward"
        ]:
            use_max = False
        else:
            raise ValueError(
                f"Unsupported `mdp_modeling`: {self._mdp_modeling!r}"
            )

        if self.ver == "torch":
            # Flatten each piece so 0-d rewards can be concatenated too. 
            # Max/sum over all elements is unaffected by flattening. 
            cated_reward_list = torch.cat(
                [
                    reward.reshape(-1) \
                        for reward in reward_list
                ]
            )

            if use_max:
                return torch.max(cated_reward_list)
            return torch.sum(cated_reward_list)

        if use_max:
            return np.max(reward_list)
        return np.sum(reward_list)


    def _simulate_stack_merged_reward_list(
        self, 

        merged_reward_list: List
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Stack per-node merged rewards into one tensor/array of shape 
            `(num_node, *self.reward_shape)`. For the torch backend the result 
            has `self.dtype` and lives on `self.device`. Each entry must hold 
            exactly `prod(self.reward_shape)` elements. An empty input yields 
            an empty `(0, *self.reward_shape)` result. 
        """
        if self.ver == "torch":
            if len(merged_reward_list) == 0:
                return torch.empty(
                    (0, *self.reward_shape), 

                    dtype = self.dtype, 
                    device = self.device
                )

            return torch.stack(
                [
                    torch.as_tensor(
                        merged_reward, 

                        dtype = self.dtype, 
                        device = self.device
                    ).reshape(self.reward_shape) \
                        for merged_reward in merged_reward_list
                ], 
                dim = 0
            )

        if len(merged_reward_list) == 0:
            return np.empty((0, *self.reward_shape))

        return np.stack(
            [
                np.asarray(merged_reward).reshape(self.reward_shape) \
                    for merged_reward in merged_reward_list
            ], 
            axis = 0
        )


    # (discarded)
    # def _simulate_(
    #     self, 

    #     node_list: MCTSNode, 

    #     # sample_idx_list: List[int]
    # ) -> Union[torch.Tensor, np.ndarray]:
    #     """
    #     Func:
    #         Simulate the trajectory from the expanded node to terminal node. 
        
    #     Ret: 
    #         `merged_reward_to_leaf_list` (`torch.Tensor` or `np.ndarray`): 
    #             The list of the merged rewards from `node` to a terminal node. 
    #         # `batch_reward_sum_to_leaf` (`torch.Tensor` or `np.ndarray`): 
    #         #       The list of accumulated rewards from `node` to a terminal node. 
    #         # `final_reward_list` (`torch.Tensor` or `np.ndarray`): 
    #         #     The list of the final rewards from `node` to a terminal node. 
    #     """

    #     time_horizon = self.mdp.time_horizon

    #     num_node = len(node_list)

    #     if not isinstance(sample_idx_list, torch.Tensor):
    #         sample_idx_list = torch.tensor(
    #             sample_idx_list, 

    #             dtype = torch.int32, 
    #             device = "cpu"
    #         )

    #     # ---------= [Prepare Current Timestep List] =---------
    #     cur_timestep_idx_list = [
    #         node.depth \
    #             for node in node_list
    #     ]
        
    #     # flag = False
    #     # if cur_timestep_idx_list != [cur_timestep_idx_list[0]] * num_node:
    #     #     flag = True

    #     # if flag:
    #     #     breakpoint()

    #     if self.ver == "torch":
    #         from util.torch_util import tsfm_to_1d_array

    #         cur_timestep_idx_list = tsfm_to_1d_array(
    #             array = cur_timestep_idx_list, 
    #             target_length = num_node, 

    #             dtype = torch.int32, 
    #             device = "cpu"
    #         )
    #     elif self.ver == "numpy":
    #         from util.numpy_util import tsfm_to_1d_array

    #         cur_timestep_idx_list = tsfm_to_1d_array(
    #             array = cur_timestep_idx_list, 
    #             target_length = num_node, 

    #             dtype = np.int32
    #         )

    #     # ---------= [Determine Finished] =---------
    #     simulation_finished_list = [False] * num_node
    #     num_simulation_finished = 0

    #     # ---------= [Prepare `merged_reward_to_leaf_list`] =---------
    #     # if self.ver == "torch":
    #     #     init_reward = torch.tensor(
    #     #         0.0, 

    #     #         dtype = self.dtype, 
    #     #         # device = self.device
    #     #         device = "cpu"
    #     #     )
    #     # elif self.ver == "numpy":
    #     #     init_reward = np.array(
    #     #         0.0, 

    #     #         dtype = self.dtype
    #     #     )

    #     # merged_reward_to_leaf_list.shape = (num_node, )
    #     # merged_reward_to_leaf_list = [init_reward] * num_node
    #     merged_reward_to_leaf_list = [None] * num_node

    #     # # ---------= [Prepare `final_reward_list`] =---------
    #     # # final_reward_list.shape = (num_node, )
    #     # final_reward_list = [None] * num_node

    #     # for i, (node, sample_idx) in enumerate(
    #     #     zip(
    #     #         node_list, 
    #     #         sample_idx_list
    #     #     )
    #     # ):
    #     #     if node.depth == self.mdp.time_horizon:
    #     #         if self.ver == "torch":
    #     #             final_reward = torch.tensor(
    #     #                 node.info_list[sample_idx].final_reward, 

    #     #                 dtype = self.dtype, 
    #     #                 device = "cpu"
    #     #             )
    #     #         elif self.ver == "numpy":
    #     #             final_reward = np.array(
    #     #                 node.info_list[sample_idx].final_reward, 

    #     #                 dtype = self.dtype
    #     #             )

    #     #         final_reward = final_reward.reshape(self.reward_shape) \
    #     #             .to(self.device)

    #     #         merged_reward_to_leaf_list[i] = final_reward

    #     #         # final_reward_list[i] = final_reward

    #     #         simulation_finished_list[i] = True
    #     #         num_simulation_finished += 1

    #     #     # goto `for node`
    #     #     pass

    #     # ---------= [Prepare Params] =---------
    #     # last_state_list.shape = (num_need_cal, *state_shape)
    #     # last_state_list = [
    #     #     node_list[need_cal_idx].get_state(
    #     #         sample_idx_list = sample_idx_list[need_cal_idx]
    #     #     )[0] \
    #     #         for need_cal_idx in need_cal_idx_list
    #     # ]

    #     last_state_list = [
    #         node.get_state(
    #             sample_idx_list = [sample_idx]
    #         )[0] \
    #             for (node, sample_idx) in zip(
    #                 node_list, 
    #                 sample_idx_list
    #             )
    #     ]

    #     if self.ver == "torch":
    #         last_state_list = torch.stack(
    #             last_state_list, 
    #             dim = 0
    #         )
    #     elif self.ver == "numpy":
    #         last_state_list = np.stack(
    #             last_state_list, 
    #             axis = 0
    #         )

    #     action_list_to_leaf_list = [
    #         [] \
    #             for _ in range(num_node)
    #     ]
    #     reward_list_to_leaf_list = [
    #         [] \
    #             for _ in range(num_node)
    #     ]
    #     state_list_to_leaf_list = [
    #         [] \
    #             for _ in range(num_node)
    #     ]

    #     last_min_timestep_idx = 0

    #     # prev_potential_list = None
    #     # if self.mdp.reward_model.reward_shaping_policy == "potential_based":
    #     #     prev_potential_list = [
    #     #         node.info_list[sample_idx].potential \
    #     #             for (node, sample_idx) in zip(
    #     #                 node_list, 
    #     #                 sample_idx_list
    #     #             )
    #     #     ]

    #     #     if self.ver == "numpy":
    #     #         prev_potential_list = np.array(prev_potential_list)
    #     #     elif self.ver == "torch":
    #     #         prev_potential_list = torch.tensor(prev_potential_list)
                
    #     #     prev_potential_list = prev_potential_list.reshape(
    #     #         (num_node, *self.reward_shape)
    #     #     )

    #     with tqdm(
    #         total = self.mdp.time_horizon, 
    #         desc = "[Simulation]"
    #     ) as bar:
    #         while num_simulation_finished < num_node:
    #             need_cal_idx_list = [
    #                 i \
    #                     for i, simulation_finished in enumerate(simulation_finished_list) \
    #                         if (not simulation_finished)
    #             ]

    #             # need_cal_node_list = [
    #             #     node_list[need_cal_idx] \
    #             #         for need_cal_idx in need_cal_idx_list
    #             # ]

    #             need_cal_sample_idx_list = sample_idx_list[need_cal_idx_list]

    #             # need_cal_sample_idx_list = torch.tensor(
    #             #     [
    #             #         sample_idx_list[need_cal_idx] \
    #             #             for need_cal_idx in need_cal_idx_list
    #             #     ], 

    #             #     dtype = torch.int32, 
    #             #     device = "cpu"
    #             # )

    #             need_cal_last_state_list = last_state_list[need_cal_idx_list]

    #             # need_cal_last_state_list = torch.stack(
    #             #     [
    #             #         last_state_list[need_cal_idx] \
    #             #             for need_cal_idx in need_cal_idx_list
    #             #     ]
    #             # )

    #             # need_cal_timestep_idx_list.shape = (num_need_cal, )
    #             need_cal_timestep_idx_list = cur_timestep_idx_list[need_cal_idx_list]

    #             # need_cal_timestep_idx_list = torch.tensor(
    #             #     [
    #             #         cur_timestep_idx_list[need_cal_idx] \
    #             #             for need_cal_idx in need_cal_idx_list
    #             #     ]
    #             # )

    #             # need_cal_action_list.shape = (num_need_cal, *action_shape)
    #             need_cal_action_list = self._batch_sample_simulation_action(
    #                 timestep_idx_list = need_cal_timestep_idx_list
    #             )
                
    #             # need_cal_cur_state_list.shape = (num_need_cal, *state_shape)
    #             need_cal_cur_state_list = self.mdp.batch_cal_dynamics(
    #                 sample_idx_list = need_cal_sample_idx_list, 

    #                 state_list = need_cal_last_state_list, 
    #                 action_list = need_cal_action_list, 

    #                 timestep_idx_list = need_cal_timestep_idx_list
    #             )

    #             need_cal_next_timestep_idx_list = need_cal_timestep_idx_list + 1

    #             # need_cal_prev_potential_list = None
    #             # if self.mdp.reward_model.reward_shaping_policy == "potential_based":
    #             #     need_cal_prev_potential_list = torch.stack(
    #             #         [
    #             #             prev_potential_list[need_cal_idx] \
    #             #                 for need_cal_idx in need_cal_idx_list
    #             #         ]
    #             #     )

    #             non_terminal_idx_list = []
    #             terminal_idx_list = []
                
    #             for i, (need_cal_idx, next_timestep_idx) in enumerate(
    #                 zip(
    #                     need_cal_idx_list, 
    #                     need_cal_next_timestep_idx_list
    #                 )
    #             ):
    #                 if next_timestep_idx < time_horizon:
    #                     # non_terminal_idx_list.append(need_cal_idx)
    #                     non_terminal_idx_list.append(i)
    #                 else:
    #                     # terminal_idx_list.append(need_cal_idx)
    #                     terminal_idx_list.append(i)
                    
    #                 # goto `for i, timestep_idx`
    #                 pass

    #             non_terminal_idx_list = torch.tensor(
    #                 non_terminal_idx_list, 

    #                 dtype = torch.int32, 
    #                 device = "cpu"
    #             )
    #             terminal_idx_list = torch.tensor(
    #                 terminal_idx_list, 

    #                 dtype = torch.int32, 
    #                 device = "cpu"
    #             )

    #             # ---------= [Cal Intermediate Reward] =---------
    #             if len(non_terminal_idx_list) > 0:
    #                 non_terminal_cur_state_list = need_cal_cur_state_list[non_terminal_idx_list]
    #                 non_terminal_action_list = need_cal_action_list[non_terminal_idx_list]
    #                 non_terminal_timestep_idx_list = need_cal_next_timestep_idx_list[non_terminal_idx_list]
    #                 non_terminal_last_state_list = need_cal_last_state_list[non_terminal_idx_list]
    #                 non_terminal_sample_idx_list = need_cal_sample_idx_list[non_terminal_idx_list]
                    
    #                 (
    #                     # non_terminal_intermediate_reward_list.shape = (num_non_terminal, )
    #                     non_terminal_intermediate_reward_list, 
    #                     # non_terminal_pseudo_final_latent_list.shape = (num_non_terminal, *self.state_shape)
    #                     non_terminal_pseudo_final_latent_list
    #                     # # non_terminal_potential_list.shape = (num_non_terminal, )
    #                     # non_terminal_potential_list
    #                 ) = self.mdp.batch_cal_intermediate_reward(
    #                     state_list = non_terminal_cur_state_list, 
    #                     prev_action_list = non_terminal_action_list, 

    #                     timestep_idx_list = non_terminal_timestep_idx_list, 

    #                     prev_latent_list = non_terminal_last_state_list, 

    #                     # prev_potential_list = need_cal_prev_potential_list, 

    #                     # node = node, 

    #                     sample_idx_list = non_terminal_sample_idx_list
    #                 )

    #                 for i, (
    #                     non_terminal_idx, 
    #                     # non_terminal_sample_idx, 
    #                     action, 
    #                     reward, 
    #                     state
    #                 ) in enumerate(
    #                     zip(
    #                         non_terminal_idx_list, 
    #                         # non_terminal_sample_idx_list, 
    #                         non_terminal_action_list, 
    #                         non_terminal_intermediate_reward_list, 
    #                         non_terminal_cur_state_list
    #                     )
    #                 ):
    #                     need_cal_idx = need_cal_idx_list[non_terminal_idx]
                        
    #                     action_list_to_leaf_list[need_cal_idx].append(action)
    #                     reward_list_to_leaf_list[need_cal_idx].append(reward)
    #                     state_list_to_leaf_list[need_cal_idx].append(state)

    #                     # non_terminal_sample_idx = need_cal_sample_idx_list[non_terminal_idx]
                        
    #                     # action_list_to_leaf_list[non_terminal_sample_idx].append(action)
    #                     # reward_list_to_leaf_list[non_terminal_sample_idx].append(reward)
    #                     # state_list_to_leaf_list[non_terminal_sample_idx].append(state)

    #                     # goto `for i, ()`
    #                     pass
                
    #                 # ---------= [Clean Up] =---------
    #                 del non_terminal_idx_list
    #                 del non_terminal_cur_state_list
    #                 del non_terminal_action_list
    #                 del non_terminal_timestep_idx_list
    #                 del non_terminal_last_state_list
    #                 del non_terminal_sample_idx_list
    #                 del non_terminal_intermediate_reward_list
    #                 del non_terminal_pseudo_final_latent_list
    #                 gc.collect()
    #                 if self.ver == "torch":
    #                     torch.cuda.empty_cache()
                
    #             # print(f"intermediate_reward_list: {intermediate_reward_list}")
    #             # print(f"potential_list: {potential_list}")

    #             # tmp_potential_list = [
    #             #     node.info_list[sample_idx].potential \
    #             #         for sample_idx in sample_idx_list
    #             # ]

    #             # print(f"tmp_potential_list: {tmp_potential_list}")
                
    #             # ---------= [Cal Final Reward] =---------
    #             if len(terminal_idx_list) > 0:
    #                 # terminal_state_list.shape = (num_terminal_node, *state_shape)
    #                 terminal_state_list = need_cal_cur_state_list[terminal_idx_list]

    #                 # terminal_action_list.shape = (num_terminal_node, *self.reward_shape)
    #                 terminal_action_list = need_cal_action_list[terminal_idx_list]
                    
    #                 # terminal_final_reward_list.shape = (num_terminal_node, *reward_shape)
    #                 terminal_final_reward_list = self.mdp.batch_cal_final_reward(
    #                     sample_idx_list = terminal_sample_idx_list, 

    #                     state_list = terminal_state_list
    #                 )

    #                 for i, (
    #                     terminal_idx, 
    #                     # terminal_sample_idx,  
    #                     action, 
    #                     reward, 
    #                     state
    #                 ) in enumerate(
    #                     zip(
    #                         terminal_idx_list, 
    #                         # terminal_sample_idx_list, 
    #                         terminal_action_list, 
    #                         terminal_final_reward_list, 
    #                         terminal_state_list
    #                     )
    #                 ):
    #                     need_cal_idx = need_cal_idx_list[terminal_idx]

    #                     action_list_to_leaf_list[need_cal_idx].append(action)
    #                     reward_list_to_leaf_list[need_cal_idx].append(reward)
    #                     state_list_to_leaf_list[need_cal_idx].append(state)

    #                     # terminal_sample_idx = need_cal_sample_idx_list[terminal_idx]

    #                     # action_list_to_leaf_list[terminal_sample_idx].append(action)
    #                     # reward_list_to_leaf_list[terminal_sample_idx].append(reward)
    #                     # state_list_to_leaf_list[terminal_sample_idx].append(state)

    #                     # goto `for i, ()`
    #                     pass
                    
    #                 for i, (terminal_idx, sample_idx, final_reward) in enumerate(
    #                     zip(
    #                         terminal_idx_list, 
    #                         terminal_sample_idx_list, 
    #                         terminal_final_reward_list
    #                     )
    #                 ):
    #                     # ---------= [Update Best Trajectory] =---------
    #                     # terminal_sample_idx = need_cal_sample_idx_list[terminal_idx]
    #                     # node = node_list[terminal_sample_idx]

    #                     # node = node_list[sample_idx]

    #                     need_cal_idx = need_cal_idx_list[terminal_idx]
    #                     node = node_list[need_cal_idx]
    #                     sample_idx = sample_idx_list[need_cal_idx]
                        
    #                     if self.ver == "torch":
    #                         cated_reward_list_to_leaf = torch.cat(reward_list_to_leaf_list[sample_idx])

    #                         if self._mdp_modeling == "max_reward":
    #                             merged_reward_to_leaf = torch.max(cated_reward_list_to_leaf)
    #                         elif self._mdp_modeling in [
    #                             "sparse_reward", 
    #                             "cumulative_reward"
    #                         ]:
    #                             merged_reward_to_leaf = torch.sum(cated_reward_list_to_leaf)
    #                     elif self.ver == "numpy":
    #                         if self._mdp_modeling == "max_reward":
    #                             merged_reward_to_leaf = np.max(cated_reward_list_to_leaf)
    #                         elif self._mdp_modeling in [
    #                             "sparse_reward", 
    #                             "cumulative_reward"
    #                         ]:
    #                             merged_reward_to_leaf = np.sum(reward_list_to_leaf_list[sample_idx])

    #                     merged_reward_to_leaf = merged_reward_to_leaf.reshape(self.reward_shape)

    #                     # merged_reward_to_leaf_list[sample_idx] = merged_reward_to_leaf
    #                     merged_reward_to_leaf_list[need_cal_idx] = merged_reward_to_leaf

    #                     # final_reward_list[terminal_idx] = final_reward

    #                     merged_reward_to_root = node.info_list[sample_idx].merged_reward_to_root
                        
    #                     # breakpoint()
    #                     # print(f"node.depth: {node.depth}, merged_reward_to_root: {merged_reward_to_root}")
    #                     # # print(f"node.depth: {node.depth}, cated_reward_list_to_leaf: {cated_reward_list_to_leaf}")
    #                     # reward_list_to_root = []
    #                     # tmp_node = node
    #                     # while tmp_node.parent:
    #                     #     reward_list_to_root.append(
    #                     #         tmp_node.info_list[sample_idx].intermediate_reward
    #                     #     )

    #                     #     tmp_node = tmp_node.parent
    #                     # reward_list_to_root.reverse()
    #                     # reward_list_to_root = torch.tensor(
    #                     #     reward_list_to_root, 
    #                     #     device = cated_reward_list_to_leaf.device
    #                     # )

    #                     # tmp_reward_list = torch.cat(
    #                     #     (reward_list_to_root, cated_reward_list_to_leaf)
    #                     # )
    #                     # print(f"    tmp_reward_list: {tmp_reward_list}")

    #                     # print(f"    node.depth: {node.depth}")

    #                     if self._mdp_modeling == "max_reward":
    #                         merged_reward = max(merged_reward_to_root, merged_reward_to_leaf)
    #                     elif self._mdp_modeling in [
    #                         "sparse_reward", 
    #                         "cumulative_reward"
    #                     ]:
    #                         merged_reward = merged_reward_to_root + merged_reward_to_leaf
                        
    #                     best_trajectory = self.best_trajectory_list[sample_idx]
                        
    #                     if self._get_is_cost_legal(sample_idx = sample_idx) \
    #                         and (merged_reward > best_trajectory.merged_reward):

    #                         self.history_best_merged_reward_list_list[sample_idx].append(merged_reward)
    #                         self.history_last_final_reward_list_list[sample_idx].append(final_reward)

    #                         best_trajectory.merged_reward = merged_reward
    #                         best_trajectory.final_reward = final_reward
    #                         best_trajectory.incomplete = False
                            
    #                         (
    #                             state_list, 
    #                             action_list, 
    #                             reward_list
    #                         ) = best_trajectory.get_trajectory_to_root(
    #                             node = node, 
                                
    #                             sample_idx = sample_idx, 

    #                             include_final_reward = False
    #                         )

    #                         state_list_to_leaf = state_list_to_leaf_list[terminal_idx]
    #                         action_list_to_leaf = action_list_to_leaf_list[terminal_idx]
    #                         reward_list_to_leaf = reward_list_to_leaf_list[terminal_idx]

    #                         state_list += state_list_to_leaf
    #                         action_list += action_list_to_leaf
    #                         reward_list += reward_list_to_leaf

    #                         # best_trajectory.concat_trajectory_by_list(
    #                         #     state_list = state_list_to_leaf, 
    #                         #     action_list = action_list_to_leaf, 
    #                         #     reward_list = reward_list_to_leaf
    #                         # )
                            
    #                         best_trajectory.state_list = state_list
    #                         best_trajectory.action_list = action_list
    #                         best_trajectory.reward_list = reward_list

    #                         # ---------= [Clean Up] =---------
    #                         trash_best_trajectory = self.best_trajectory_list[sample_idx]
    #                         self.best_trajectory_list[sample_idx] = None
    #                         del trash_best_trajectory
    #                         gc.collect()
    #                         if self.ver == "torch":
    #                             torch.cuda.empty_cache()

    #                         # ---------= [Update] =---------
    #                         self.best_trajectory_list[sample_idx] = best_trajectory

    #                         timestep_idx = node_list[terminal_idx].depth

    #                         self._best_trajectory_updated_callback(
    #                             caller = "_simulate()", 
    #                             pseudo = False, 

    #                             node = node, 

    #                             sample_idx = sample_idx, 

    #                             timestep_idx = timestep_idx
    #                         )
                
    #                 # ---------= [Clean Up] =---------
    #                 del terminal_idx_list
    #                 del terminal_state_list
    #                 del terminal_action_list
    #                 del terminal_sample_idx_list
    #                 del terminal_final_reward_list
    #                 gc.collect()
    #                 if self.ver == "torch":
    #                     torch.cuda.empty_cache()

    #             # ---------= [Update Timestep] =---------
    #             min_timestep_idx = time_horizon

    #             for need_cal_idx in need_cal_idx_list:
    #                 cur_timestep_idx_list[need_cal_idx] += 1

    #                 cur_timestep_idx = cur_timestep_idx_list[need_cal_idx]
    #                 if not isinstance(cur_timestep_idx, int):
    #                     cur_timestep_idx = cur_timestep_idx.item()

    #                 min_timestep_idx = min(min_timestep_idx, cur_timestep_idx)

    #                 if cur_timestep_idx >= time_horizon:
    #                     simulation_finished_list[need_cal_idx] = True
    #                     num_simulation_finished += 1
                    
    #                 # goto `for need_cal_idx`
    #                 pass
                    
    #             # ---------= [Update Tqdm] =---------
    #             del_min_timestep_idx \
    #                 = min_timestep_idx - last_min_timestep_idx
    #             bar.update(del_min_timestep_idx)

    #             # ---------= [Update Last] =---------
    #             last_min_timestep_idx = min_timestep_idx

    #             for i, need_cal_idx in enumerate(need_cal_idx_list):
    #                 last_state_list[need_cal_idx] = need_cal_cur_state_list[i]

    #                 # if self.mdp.reward_model.reward_shaping_policy == "potential_based":
    #                 #     prev_potential_list[need_cal_idx] = need_cal_potential_list[i]

    #                 # goto `for need_cal_idx`
    #                 pass

    #             # ---------= [Clean Up] =---------
    #             del need_cal_idx_list, need_cal_action_list, need_cal_timestep_idx_list, need_cal_last_state_list, need_cal_sample_idx_list
    #             # if self.mdp.reward_model.reward_shaping_policy == "potential_based":
    #             #     del need_cal_prev_potential_list
    #             # del need_cal_potential_list
    #             gc.collect()
    #             if self.ver == "torch":
    #                 torch.cuda.empty_cache()

    #             # goto `while num_simulation_finished < num_node`
    #             pass

    #     if self.ver == "torch":
    #         merged_reward_to_leaf_list = torch.stack(
    #             merged_reward_to_leaf_list, 
    #             dim = 0
    #         )

    #         # final_reward_list = torch.stack(
    #         #     final_reward_list, 
    #         #     dim = 0
    #         # )
    #     elif self.ver == "numpy":
    #         merged_reward_to_leaf_list = np.stack(
    #             merged_reward_to_leaf_list, 
    #             axis = 0
    #         )

    #         # final_reward_list = np.stack(
    #         #     final_reward_list, 
    #         #     axis = 0
    #         # )

    #     # clean up
    #     del cur_timestep_idx_list, simulation_finished_list
    #     del last_state_list
    #     del action_list_to_leaf_list, reward_list_to_leaf_list, state_list_to_leaf_list
    #     # if prev_potential_list is not None:
    #     #     del prev_potential_list
    #     gc.collect()
    #     if self.ver == "torch":
    #         torch.cuda.empty_cache()

    #     # `_simulate()` done
    #     return merged_reward_to_leaf_list
    #     # return final_reward_list


    def _backpropagate(
        self, 

        node: MCTSNode, 

        sample_idx: int, 

        reward: float
    ):
        """
        Func:
            Backpropagate the reward of a simulation from `node` up to the
            root, updating the statistics of sample `sample_idx` on every
            node of the path. 

            On each node of the path, `info.num_vis` is incremented and
            `info.value` is updated according to `self._value_policy`:
                - "max": `value` becomes `max(previous value, reward)`.
                - "average": `value` becomes the running mean of all rewards
                  backpropagated through the node.
            Any other policy only increments `num_vis`. 

        Args:
            `node` (`MCTSNode`): The node the simulated reward belongs to
                (in `_step()`, the freshly expanded node). 
            `sample_idx` (`int`): Index of the sample whose per-node `Info`
                (`node.info_list[sample_idx]`) is updated. 
            `reward` (`float`): The reward to backpropagate. 
        """

        # Compare against `None` explicitly: a plain truthiness test would
        # stop the walk early if the node class ever defines `__len__` or
        # `__bool__` (e.g. an empty leaf evaluating as falsy).
        while node is not None:
            info = node.info_list[sample_idx]
            info.num_vis += 1

            if self._value_policy == "max":
                info.value = max(info.value, reward)
            elif self._value_policy == "average":
                # Incremental mean: v_n = v_{n-1} + (r - v_{n-1}) / n
                info.value = info.value + (reward - info.value) / info.num_vis

            node = node.parent

            # goto `while node is not None`
            pass

        # `_backpropagate()` done
        pass


    # def _run_discarded(
    #     self, 

    #     sample_idx: int, 

    #     display_state_value_list: bool = False, 

    #     display_reward_sum_to_leaf: Optional[bool] = False, 

    #     **arg_dict: Optional[Dict]
    # ):
    #     """
    #     Func:
    #         Run MCTS for a sample until stopped. 
    #     """
        
    #     while not self._get_is_time_to_stop():
    #         node_list = self._select(
    #             sample_idx = sample_idx, 

    #             display_state_value_list = display_state_value_list
    #         )

    #         expanded_node_list = self._expand(
    #             node_list = node_list, 

    #             sample_idx = sample_idx
    #         )

    #         if expanded_node_list is None:
    #             continue
            
    #         reward_sum_to_leaf_list = self._simulate(
    #             node_list = expanded_node_list, 

    #             sample_idx = sample_idx
    #         )

    #         if display_reward_sum_to_leaf:
    #             logger(f"reward_sum_to_leaf_list: {reward_sum_to_leaf_list}")

    #         for expanded_node, reward_sum_to_leaf in zip(
    #             expanded_node_list, 
    #             reward_sum_to_leaf_list
    #         ):
    #             self._backpropagate(
    #                 node = expanded_node, 
                    
    #                 sample_idx = sample_idx, 

    #                 reward = reward_sum_to_leaf
    #             )

    #             # goto `for expanded_node, reward_sum_to_leaf`
    #             pass

    #         # goto `while not self._get_is_time_to_stop()`
    #         pass

    #     # `_run()` done
    #     pass


    def _step(
        self, 

        sample_idx_list: List[int], 

        display_state_value_list: Optional[bool] = False, 

        display_reward_sum_to_leaf: Optional[bool] = False, 

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Perform ONE MCTS iteration (selection -> expansion -> simulation
            -> backpropagation) for every sample in `sample_idx_list`. 

            This method does not loop by itself; `run()` calls it repeatedly
            until every sample's stop criterion is met. 

        Args:
            `sample_idx_list` (`List[int]`): Indices of the samples to advance. 
            `display_state_value_list` (`bool`): Forwarded to `self._select()`. 
            `display_reward_sum_to_leaf` (`bool`): Whether to log the rewards
                returned by `self._simulate()`. 
            `**arg_dict`: Optional flags:
                - `display_cal_state_value` (`bool`, default `False`):
                  forwarded to `self._select()`. 
                - `display_selected_node_depth` (`bool`, default `False`):
                  log the index and depth of every selected node. 
                Any other key is ignored. 
        """

        display_cal_state_value = arg_dict.pop("display_cal_state_value", False)
        display_selected_node_depth = arg_dict.pop("display_selected_node_depth", False)

        # ---------= [Selection] =---------
        node_list = self._select(
            sample_idx_list = sample_idx_list, 

            display_state_value_list = display_state_value_list, 

            display_cal_state_value = display_cal_state_value
        )

        node_idx_list = [node.node_idx for node in node_list]
        timestep_idx_list = [node.depth for node in node_list]

        # Record (node_idx, depth) of the node selected for each sample.
        for i, sample_idx in enumerate(sample_idx_list):
            self.mcts_loop_selected_node_tuple_list_list[sample_idx].append(
                (node_idx_list[i], timestep_idx_list[i])
            )

        if display_selected_node_depth:
            logger(
                f"[Selected] node_idx_list: {node_idx_list}, "
                f"timestep_idx_list: {timestep_idx_list}"
            )

        # ---------= [Expansion] =---------
        # `_expand()` returns the expanded nodes together with a sample index
        # list; from here on that returned list is the one paired with
        # `expanded_node_list` (see the `zip` below).
        (
            expanded_node_list, 
            sample_idx_list
        ) = self._expand(
            node_list = node_list, 

            sample_idx_list = sample_idx_list
        )

        # Nothing was expanded: there is nothing to simulate or backpropagate.
        if (expanded_node_list is None) or (len(expanded_node_list) == 0):
            return

        # ---------= [Simulation] =---------
        merged_reward_to_leaf_list = self._simulate(
            node_list = expanded_node_list, 
            sample_idx_list = sample_idx_list
        )

        if display_reward_sum_to_leaf:
            logger(f"merged_reward_to_leaf_list: {merged_reward_to_leaf_list}")

        # ---------= [Backpropagation] =---------
        for (expanded_node, merged_reward_to_leaf, sample_idx) in zip(
            expanded_node_list, 
            merged_reward_to_leaf_list, 
            sample_idx_list
        ):
            self._backpropagate(
                node = expanded_node, 

                sample_idx = sample_idx, 

                reward = merged_reward_to_leaf
            )

            # goto `for (expanded_node, merged_reward_to_leaf, sample_idx)`
            pass

        # `_step()` done
        pass


    def _run_sample_pre_process(
        self, 

        **arg_dict
    ):
        """
        Func:
            Hook executed right before an MCTS run begins. 

            Saves the current wall-clock timestamp (`time.time()`) into
            `self._time_st`; elapsed-time bookkeeping later subtracts it. 
            Extra keyword arguments are accepted and ignored. 
        """

        start_timestamp = time.time()
        self._time_st = start_timestamp

        # `_run_sample_pre_process()` done
        pass


    def _run_sample_post_process(
        self, 

        sample_idx: int, 

        **arg_dict
    ):
        """
        Func:
            Hook executed once sample `sample_idx` has finished its MCTS run. 

            Stores the seconds elapsed since `self._time_st` in
            `self._wall_clock_time_cost_list[sample_idx]` and logs the value
            rounded to whole seconds. Extra keyword arguments are ignored. 
        """

        # ---------= [Compute Wall-clock Time Cost] =---------
        elapsed_seconds = time.time() - self._time_st
        self._wall_clock_time_cost_list[sample_idx] = elapsed_seconds

        finish_message = (
            f"[Sample {sample_idx}] Finished, wall-clock time cost: {round(elapsed_seconds)} second(s). "
        )
        logger(finish_message)

        # `_run_sample_post_process()` done
        pass

    
    def _mcts_loop_callback(
        self, 

        local_var_dict: Dict, 

        **arg_dict
    ):
        """
        Func:
            Hook executed at the end of every MCTS loop iteration of `run()`. 

            Appends the seconds elapsed since `self._time_st` (i.e. the
            cumulative time since the run started, not the duration of this
            single loop) to `self._mcts_loop_wall_time_cost_list`. 

        Args:
            `local_var_dict` (`Dict`): Snapshot of `run()`'s local variables;
                unused by this default implementation. 
        """

        self._mcts_loop_wall_time_cost_list.append(time.time() - self._time_st)

        # `_mcts_loop_callback()` done
        pass


    def run(
        self, 

        display_result: bool = False, 

        display_trajectory: bool = False, 
        display_state: bool = True, 
        display_action: bool = True, 
        display_reward: bool = True, 

        **arg_dict: Optional[Dict]
    ) -> None:
        """
        Func:
            Run MCTS for all `self.num_sample` samples in lock-step. 

            Every iteration of the main loop:
                1. increments `self._mcts_loop_idx_list[sample_idx]` of each
                   still-active sample;
                2. performs one `self._step()` on all active samples;
                3. retires every active sample for which
                   `self._get_is_time_to_stop(sample_idx = ...)` is truthy,
                   calling `self._run_sample_post_process()` and, when
                   `display_result` is set, `self.display_sample_result()`;
                4. calls `self._mcts_loop_callback()` with a snapshot of the
                   local variables.
            The loop ends once no active sample remains. 

        Args:
            `display_result` (`bool`): Whether to display each sample's result
                when it finishes, and the overall result at the end. 
            `display_trajectory`, `display_state`, `display_action`,
            `display_reward` (`bool`): Forwarded to the display methods. 
            `**arg_dict`: Forwarded unchanged to `self._step()` on every
                iteration. 
        """
        # ---------= [Pre-process] =---------
        # Called once for the whole batch, so every sample's wall-clock time
        # cost (computed in `_run_sample_post_process()`) is measured from
        # the start of this `run()`.
        self._run_sample_pre_process()

        num_sample = self.num_sample

        # Samples that still need MCTS iterations.
        need_cal_sample_idx_list = list(
            range(num_sample)
        )

        while len(need_cal_sample_idx_list) > 0:
            for sample_idx in need_cal_sample_idx_list:
                self._mcts_loop_idx_list[sample_idx] += 1

            self._step(
                sample_idx_list = need_cal_sample_idx_list, 
                
                **arg_dict
            )
    
            # Finished samples are collected first and removed afterwards so
            # the list is not mutated while being iterated.
            trash_need_cal_sample_idx_list = []

            for need_cal_sample_idx in need_cal_sample_idx_list:
                if self._get_is_time_to_stop(sample_idx = need_cal_sample_idx):
                    trash_need_cal_sample_idx_list.append(need_cal_sample_idx)

                    self._run_sample_post_process(sample_idx = need_cal_sample_idx)

                    if display_result:
                        self.display_sample_result(
                            sample_idx = need_cal_sample_idx, 

                            display_trajectory = display_trajectory, 

                            display_state = display_state, 
                            display_action = display_action, 
                            display_reward = display_reward
                        )

                # goto `for need_cal_sample_idx`
                pass

            # In-place removal keeps the identity of `need_cal_sample_idx_list`.
            for trash_need_cal_sample_idx in trash_need_cal_sample_idx_list:
                need_cal_sample_idx_list.remove(trash_need_cal_sample_idx)

                # goto `for trash_need_cal_sample_idx`
                pass

            # ---------= [Clean Up] =---------
            del trash_need_cal_sample_idx_list
            gc.collect()
            
            # NOTE(review): the frame's locals are handed to the callback,
            # presumably so subclass overrides can inspect loop state; renaming
            # or reordering locals in this method may therefore be observable.
            local_var_dict = locals()
            self._mcts_loop_callback(local_var_dict)
            
            # goto `while`
            pass

        if display_result:
            self.display_result(
                display_trajectory = display_trajectory, 

                display_state = display_state, 
                display_action = display_action, 
                display_reward = display_reward
            )

        # `run()` done
        pass


    @abstractmethod
    def _get_is_time_to_stop(
        self, 

        **arg_dict
    ) -> bool:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Decide whether MCTS should stop for a sample. `run()` calls it
            after every `_step()` as
            `self._get_is_time_to_stop(sample_idx = sample_idx)` and retires
            the sample once it returns a truthy value. 

        Args:
            `sample_idx` (`int`, keyword): Index of the sample to check. 

        Ret:
            `is_time_to_stop` (`bool`): Whether it is time to stop. 
        """

        # `_get_is_time_to_stop()` done
        pass


    def _get_is_cost_legal(
        self, 

        **arg_dict
    ) -> bool:
        """
        Func:
            Determine whether the cost of MCTS is legal. 

            Default implementation: no cost constraint is imposed, so the
            cost is always legal. Subclasses that enforce a budget should
            override this method. 

        Args:
            `**arg_dict`: Ignored by the default implementation; callers may
                pass e.g. `sample_idx`. 

        Ret:
            `is_cost_legal` (`bool`): Whether the cost of MCTS is legal. 
        """

        # The body used to fall through and return `None`. That violates the
        # `-> bool` contract, and because `None` is falsy it silently reported
        # every cost as illegal.
        return True


    @abstractmethod
    def _sample_expansion_action(
        self, 
        
        node: MCTSNode
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract; every concrete subclass must override it. 

        Func:
            Draw a single action from the action space, used by
            `self._expand()` when expanding `node`. 

        Args:
            `node` (`MCTSNode`): The node being expanded. 

        Ret:
            `action` (`torch.Tensor` or `np.ndarray`): The drawn action. 
        """

        ...


    @abstractmethod
    def _batch_sample_expansion_action(
        self, 
        
        node_list: List[MCTSNode]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract; every concrete subclass must override it. 

        Func:
            Batched counterpart of `_sample_expansion_action()`: draw actions
            from the action space for `self._expand()`, one batch for all of
            `node_list`. 

        Args:
            `node_list` (`List[MCTSNode]`): The nodes being expanded. 

        Ret:
            `action_list` (`torch.Tensor` or `np.ndarray`): The drawn actions. 
        """

        ...


    @abstractmethod
    def _sample_simulation_action(
        self, 
        
        timestep_idx: int
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract; every concrete subclass must override it. 

        Func:
            Draw a single action from the action space, used by
            `self._simulate()` at timestep `timestep_idx`. 

        Args:
            `timestep_idx` (`int`): The timestep the action is drawn for. 

        Ret:
            `action` (`torch.Tensor` or `np.ndarray`): The drawn action. 
        """

        ...


    @abstractmethod
    def _batch_sample_simulation_action(
        self, 
        
        timestep_idx_list: List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract; every concrete subclass must override it. 

        Func:
            Batched counterpart of `_sample_simulation_action()`: draw one
            action per entry of `timestep_idx_list` for `self._simulate()`. 

        Args:
            `timestep_idx_list` (`List[int]`): The timesteps to draw actions for. 

        Ret:
            `batch_action` (`torch.Tensor` or `np.ndarray`): The batch of drawn actions. 
        """

        ...


    def display_tree(
        self
    ):
        """
        Func:
            Display the search tree structure. 

            Logs every `[Node parent] -> [Node child]` edge in depth-first
            (pre-order) order. Each edge is followed by the `prev_action`
            stored for every sample on the child node. 
        """

        def iter_sample_info(node):
            # Prefer the `info_dict` mapping when the node has one. Otherwise
            # fall back to `info_list`, the per-sample container indexed by
            # `sample_idx` elsewhere in this class (e.g. `_backpropagate()`).
            info_container = getattr(node, "info_dict", None)
            if info_container is None:
                info_container = node.info_list

            if hasattr(info_container, "items"):
                return info_container.items()

            return enumerate(info_container)

        # Iterative DFS over (parent, iterator over its children). It visits
        # edges in the same order as a recursive DFS but is not bounded by
        # Python's recursion limit; tree depth presumably grows with the MDP
        # time horizon.
        exhausted = object()
        stack = [(self.root, iter(self.root.children_list))]

        while stack:
            parent, children_iter = stack[-1]
            child = next(children_iter, exhausted)

            if child is exhausted:
                stack.pop()
                continue

            logger(f"[Node {parent.node_idx}] -> [Node {child.node_idx}]")

            for sample_idx, info in iter_sample_info(child):
                # NOTE(review): skip samples without an `Info` on this node,
                # in case the per-sample container holds `None` placeholders.
                if info is None:
                    continue

                logger(f"    [Sample {sample_idx}] action: {info.prev_action}")

                # goto `for sample_idx, info`
                pass

            stack.append((child, iter(child.children_list)))

            # goto `while stack`
            pass

        # `display_tree()` done
        pass


    def display_info(
        self, 

        display_state: bool = True, 
        display_prev_action: bool = True, 
        display_num_vis: bool = True, 
        display_value: bool = True, 
        display_merged_reward_to_root: bool = True, 
        display_intermediate_reward: bool = True, 
        display_pseudo_final_latent: bool = True, 
        display_final_reward: bool = True
    ):
        """
        Func:
            Display the infomation of every node, ordered by DFS order. 
        """

        def dfs(
            node
        ):
            node.display_info(
                display_state = display_state, 
                display_prev_action = display_prev_action, 
                display_num_vis = display_num_vis, 
                display_value = display_value, 
                display_merged_reward_to_root = display_merged_reward_to_root, 
                display_intermediate_reward = display_intermediate_reward, 
                display_pseudo_final_latent = display_pseudo_final_latent, 
                display_final_reward = display_final_reward
            )

            for child in node.children_list:
                dfs(child)

            # `dfs()` done
            pass

        dfs(self.root)

        # `display_info()` done
        pass


    def display_sample_result(
        self, 

        sample_idx: int, 

        display_trajectory: bool = False, 
        display_state: bool = True, 
        display_action: bool = True, 
        display_reward: bool = True
    ):
        """
        Func:
            Display the search results of a single sample. 
            - `best_merged_reward` is the last entry of 
              `self.history_best_merged_reward_list_list[sample_idx]`. 
            - `best_final_reward` is the maximum of 
              `self.history_last_final_reward_list_list[sample_idx]`. 
            Optionally, the sample's entry in `self.best_trajectory_list` is 
            displayed as well. 

        Args:
            `sample_idx` (`int`): Index of the sample to report. 
            `display_trajectory` (`bool`): If True, also display the best 
                trajectory of this sample. 
            `display_state`, `display_action`, `display_reward` (`bool`): 
                Forwarded to the trajectory's `display_trajectory()`. 
        """

        def to_scalar(reward):
            # Tensors and numpy scalars/arrays expose `.item()`. Plain Python 
            # numbers (float, and also int, which the former 
            # `isinstance(..., float)` check wrongly sent to `.item()`) are 
            # used as-is.
            return reward.item() if hasattr(reward, "item") else reward

        best_merged_reward = to_scalar(
            self.history_best_merged_reward_list_list[sample_idx][-1]
        )
        best_final_reward = to_scalar(
            max(self.history_last_final_reward_list_list[sample_idx])
        )

        logger(f"[Main Results]")
        logger(f"    sample_idx: {sample_idx}")
        logger(f"    best_merged_reward: {best_merged_reward}, best_final_reward: {best_final_reward}")

        if display_trajectory:
            logger(f"[Best Trajectory]")

            self.best_trajectory_list[sample_idx].display_trajectory(
                display_state = display_state, 
                display_action = display_action, 
                display_reward = display_reward
            )

        # `display_sample_result()` done
        pass


    def display_result(
        self, 

        display_trajectory: bool = False, 
        display_state: bool = True, 
        display_action: bool = True, 
        display_reward: bool = True
    ):
        """
        Func:
            Display the search results of all `self.num_sample` samples. 
            For each sample this reports two values, each formatted to 4 
            decimals: 
            - the last entry of its `history_best_merged_reward_list_list` 
              row; 
            - the maximum of its `history_last_final_reward_list_list` row. 
            Optionally, each sample's entry in `self.best_trajectory_list` is 
            displayed afterwards. 

        Args:
            `display_trajectory` (`bool`): If True, also display every 
                sample's best trajectory. 
            `display_state`, `display_action`, `display_reward` (`bool`): 
                Forwarded to each trajectory's `display_trajectory()`. 
        """

        def to_scalar(reward):
            # Tensors and numpy scalars/arrays expose `.item()`. Plain Python 
            # numbers (float, and also int, which the former 
            # `isinstance(..., float)` check wrongly sent to `.item()`) are 
            # used as-is.
            return reward.item() if hasattr(reward, "item") else reward

        best_merged_reward_list = []
        last_final_reward_list = []

        for sample_idx in range(self.num_sample):
            best_merged_reward_list.append(
                to_scalar(self.history_best_merged_reward_list_list[sample_idx][-1])
            )
            last_final_reward_list.append(
                to_scalar(max(self.history_last_final_reward_list_list[sample_idx]))
            )

        best_merged_reward_list_str = ", ".join(
            f"{best_merged_reward:.4f}" for best_merged_reward in best_merged_reward_list
        )
        last_final_reward_list_str = ", ".join(
            f"{last_final_reward:.4f}" for last_final_reward in last_final_reward_list
        )

        logger(f"[Main Results]")
        logger(f"    best_merged_reward_list: [{best_merged_reward_list_str}]")
        logger(f"    last_final_reward_list: [{last_final_reward_list_str}]")

        if display_trajectory:
            for sample_idx in range(self.num_sample):
                logger(f"[Best Trajectory] sample_idx: {sample_idx}")

                self.best_trajectory_list[sample_idx].display_trajectory(
                    display_state = display_state, 
                    display_action = display_action, 
                    display_reward = display_reward
                )

                logger("")

        # `display_result()` done
        pass
