from util.logger import logger

from typing import Optional, Union, List, Dict, Tuple

from abc import ABC, abstractmethod

from tqdm.auto import tqdm

import time

import numpy as np

import torch

import gc

from util.basic_util import get_attr
from util.torch_util import tsfm_to_1d_array

from .bs_node import BSNode
from .info import Info


class BeamSearch(ABC):
    def __init__(
        self, 

        is_eps_action: Optional[bool] = True, 

        # ---------= [Beam Search] =---------
        num_beam: Optional[int] = 4, 
        num_candidate_per_beam: Optional[int] = 2,

        mdp: "MarkovDecisionProcess" = None, 

        init_state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        # ---------= [Mode] =---------
        mdp_modeling: str = "sparse_reward",  # ["max_reward", "sparse_reward"]

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy: str = "uniform", 

        dtype: Optional[str] = "float32", 

        **arg_dict: Optional[Dict]
    ):
        self._is_eps_action = is_eps_action

        self._num_beam = num_beam
        self._num_candidate_per_beam = num_candidate_per_beam

        self.mdp = mdp

        # ---------= [Mode] =---------
        self._mdp_modeling = mdp_modeling

        if mdp_modeling not in [
            "max_reward", 
            "sparse_reward"
        ]:
            raise NotImplementedError(
                f"Unsupported `mdp_modeling`, got `{mdp_modeling}`. "
            )

        # ---------= [Expansion Policy 1] =---------
        if expansion_action_sampling_policy != "uniform":
            raise NotImplementedError(
                f"Unsupported `expansion_action_sampling_policy`, got `{expansion_action_sampling_policy}`. "
            )
        
        self.expansion_action_sampling_policy = expansion_action_sampling_policy

        # ---------= [MDP] =---------
        self.ver = self.mdp.ver

        self.dtype = dtype
        if isinstance(self.dtype, str):
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = self.mdp.device
        
        self.reward_shape = self.mdp.reward_shape

        self.num_sample = len(init_state_list)

        self.root = self._get_root(
            init_state_list = init_state_list
        )

        self._last_depth_node_list_list = [
            [] \
                for _ in range(self.num_sample)
        ]

        # ---------= [Merged Reward List] =---------
        init_reward = float("-inf")

        self.best_merged_reward_list = [
            init_reward \
                for sample_idx in range(self.num_sample)
        ]

        # history_best_merged_reward_list_list.shape = (num_sample, num_reward)
        self.history_best_merged_reward_list_list = [
            [init_reward] \
                for sample_idx in range(self.num_sample)
        ]

        # ---------= [BS Loop Index] =---------
        self._bs_loop_idx = 0

        # ---------= [Timer] =---------
        self._time_st = None
        self._bs_loop_wall_time_cost_list = []
        self._wall_clock_time_cost_list = [0] * self.num_sample  # unit: second

        # `__init__()` done
        pass


    def _get_root(
        self, 

        init_state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ]
    ) -> BSNode:
        """
        Func:
            Get the root node. 

        Ret:
            `root` (`BSNode`): The root node. 
        """

        sample_idx_list = list(
            range(self.num_sample)
        )

        root = self._get_new_node(
            state_list = init_state_list, 
            sample_idx_list = sample_idx_list
        )

        # `_get_root()` done
        return root


    def _best_trajectory_updated_callback(
        self, 

        caller: str, 
        pseudo: bool, 

        node: BSNode, 

        sample_idx: int, 

        timestep_idx: int
    ):
        """
        Func:
            The callback function after the best trajectory is updated. 
        """

        best_merged_reward_list = self.history_best_merged_reward_list_list[sample_idx]

        second_last_merged_reward = best_merged_reward_list[-2]
        last_merged_reward = best_merged_reward_list[-1]
            
        logger(
            f"[Sample {sample_idx}] Best trajectory updated by `{caller}` at node {node.node_idx} for sample {sample_idx}. "
        )
        logger(
            f"    merged_reward: [{second_last_merged_reward:.4f}] -> [{last_merged_reward:.4f}]"
        )

        # `_best_trajectory_updated_callback()` done
        pass

    def _cal_node_reward(
        self, 

        node: BSNode, 

        parent: Optional[BSNode] = None, 

        sample_idx_list: List[int] = None, 

        timestep_idx_list: List[int] = None
    ):
        """
        Func: 
            Compute the `intermediate_reward` and `final_reward` (if terminal) for `node`. 
        """
        
        num_sample = len(sample_idx_list)

        depth = node.depth

        time_horizon = self.mdp.time_horizon

        pseudo_final_latent_list = None

        # ---------= [Cal Intermediate Reward] =---------
        if 0 < depth < time_horizon:
            # intermediate_reward_list.shape = (num_sample, )
            # pseudo_final_latent_list.shape = (num_sample, *self.state_shape)
            (
                intermediate_reward_list, 
                pseudo_final_latent_list
            ) = self._cal_intermediate_reward(
                node = node, 

                sample_idx_list = sample_idx_list, 
                
                timestep_idx_list = timestep_idx_list
            )

            if not isinstance(intermediate_reward_list, torch.Tensor):
                intermediate_reward_list = torch.vstack(intermediate_reward_list)
            
            pseudo_final_latent_list_is_none = False
            if (pseudo_final_latent_list is None) \
                or (pseudo_final_latent_list == [None] * num_sample):
                pseudo_final_latent_list_is_none = True

            if (not pseudo_final_latent_list_is_none) \
                and (not isinstance(pseudo_final_latent_list, torch.Tensor)):
                
                if pseudo_final_latent_list[0].ndim == 3:
                    pseudo_final_latent_list = torch.stack(pseudo_final_latent_list)
                elif pseudo_final_latent_list[0].ndim == 4:
                    pseudo_final_latent_list = torch.vstack(pseudo_final_latent_list)

            intermediate_reward_list = [
                intermediate_reward.item() \
                    for intermediate_reward in intermediate_reward_list
            ]

            for i, (
                sample_idx, 
                intermediate_reward   
            ) in enumerate(
                zip(
                    sample_idx_list, 
                    intermediate_reward_list
                )
            ):
                node.info_list[sample_idx].intermediate_reward = intermediate_reward
                
                # ---------= [Update Best Trajectory] =---------
                if (self._mdp_modeling == "max_reward") \
                    and self._get_is_cost_legal(sample_idx = sample_idx) \
                    and (intermediate_reward > self.best_merged_reward_list[sample_idx]):

                    self.best_merged_reward_list[sample_idx] = intermediate_reward
                    self.history_best_merged_reward_list_list[sample_idx].append(intermediate_reward)

                    self._best_trajectory_updated_callback(
                        caller = "_cal_node_reward()", 
                        pseudo = True, 

                        node = node, 

                        sample_idx = sample_idx, 

                        timestep_idx = depth
                    )

                # goto `for i, (sample_idx, intermediate_reward)`
                pass

            if (pseudo_final_latent_list is not None) \
                and (not pseudo_final_latent_list_is_none):

                for i, sample_idx in enumerate(sample_idx_list):
                    node.info_list[sample_idx].pseudo_final_latent = pseudo_final_latent_list[i].cpu()

                    # goto `for i, sample_idx`
                    pass

            # ---------= [Clean Up] =---------
            del intermediate_reward_list
            if pseudo_final_latent_list is not None:
                del pseudo_final_latent_list
            gc.collect()
            torch.cuda.empty_cache()

        # ---------= [Cal Final Reward] =---------
        if depth == time_horizon:
            final_reward_list = self._cal_final_reward(
                node = node, 

                sample_idx_list = sample_idx_list
            )

            final_reward_list = [
                final_reward.item() \
                    for final_reward in final_reward_list
            ]

            for i, (sample_idx, final_reward) in enumerate(
                zip(
                    sample_idx_list, 
                    final_reward_list
                )
            ):
                node.info_list[sample_idx].final_reward = final_reward

                # ---------= [Update Best Trajectory] =---------
                if self._get_is_cost_legal(sample_idx = sample_idx) \
                    and (final_reward > self.best_merged_reward_list[sample_idx]):

                    self.best_merged_reward_list[sample_idx] = final_reward
                    self.history_best_merged_reward_list_list[sample_idx].append(final_reward)

                    self._best_trajectory_updated_callback(
                        caller = "_cal_node_reward()", 
                        pseudo = False, 

                        node = node, 

                        sample_idx = sample_idx, 

                        timestep_idx = depth
                    )

                # goto `for i, (sample_idx, final_reward)`
                pass
            
            # ---------= [Clean Up] =---------
            del final_reward_list
            gc.collect()
            torch.cuda.empty_cache()

        # `_cal_node_reward()` done
        pass


    def _get_new_node(
        self, 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        parent: BSNode = None
    ) -> BSNode:
        """
        Func:
            Get a new node and compute its intermediate reward. 
            If terminal, compute its final reward, too. 

        Ret:
            `new_node` (`BSNode`): The created node. 
        """

        # ---------= [Prepare `state_list`] =---------
        if isinstance(state_list, list):
            state_list = torch.stack(state_list)

        state_list = state_list.clone()

        # ---------= [New Node] =---------
        new_node = BSNode(
            bs_instance = self, 

            sample_idx_list = sample_idx_list, 

            state_list = state_list, 

            prev_action_list = prev_action_list, 

            parent = parent, 

            device = self.device
        )

        # `_get_new_node()` done
        return new_node


    def _prepare_default_action_list(
        self, 

        default_action_list: Union[
            Union[torch.Tensor, np.ndarray], 
            Union[List[torch.Tensor], List[np.ndarray]]
        ] = None
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Check the input `default_action_list`, and ensure its length equal to `self.mdp.time_horizon`. 
            If `default_action_list = None`, set it to a list of default actions. 

        Ret:
            `default_action_list` (`torch.Tensor` or `np.ndarray`): The list of the default actions. 
        """
        
        if default_action_list is None:
            default_action = self.mdp.get_default_action()

            default_action_list = default_action
        
        if not isinstance(default_action_list, torch.Tensor):
            default_action_list = torch.tensor(
                default_action_list, 

                dtype = self.dtype, 
                device = self.device
            )

        default_action_list = tsfm_to_1d_array(
            array = default_action_list, 
            target_length = self.mdp.time_horizon, 

            dtype = self.dtype, 
            device = self.device
        )

        # action_space = self.mdp.action_space.shape
        default_action_list = default_action_list.reshape(
            (self.mdp.time_horizon, *self.mdp.action_space.shape)
        )

        # `_prepare_default_action_list()` done
        return default_action_list


    def _cal_intermediate_reward(
        self, 

        node: BSNode, 

        sample_idx_list: Union[int, List[int]] = None, 

        timestep_idx_list: List[int] = None
    ) -> Union[
        Tuple[List[torch.Tensor], List[torch.Tensor]], 
        Tuple[List[np.ndarray], List[np.ndarray]]
    ]:
        """
        Func:
            Compute the intermediate reward list for transitioning from the parent node to `node`. 

        Ret:
            `intermediate_reward_list` (`List[torch.Tensor]` or `List[np.ndarray]`): 
                The list of the intermediate rewards. 
                (`sample_idx_list` is provided) intermediate_reward_list.shape = (len(sample_idx_list), *reward_shape). 
                (`sample_idx_list` is not provided) intermediate_reward_list.shape = (num_sample, *reward_shape). 
            
            `pseudo_final_latent_list` (`torch.Tensor`): 
                The list of the predicted final latents used for computing intermediate rewards. 
                    pseudo_final_latent_list.shape = (num_latent, *self._state_shape). 
        """

        if node.depth == 0:
            raise ValueError(
                f"Can not compute the intermediate reward for the root node. "
            )

        # ---------= [Prepare `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(
                range(self.num_sample)
            )
        
        if not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        # ---------= [Compute Intermediate Reward] =---------
        intermediate_reward_list = []
        pseudo_final_latent_list = []

        need_cal_idx_list = []

        for i, sample_idx in enumerate(sample_idx_list):
            info = node.info_list[sample_idx]

            if info is not None:
                intermediate_reward = info.intermediate_reward
                pseudo_final_latent = info.pseudo_final_latent

                if intermediate_reward is not None:
                    intermediate_reward = torch.tensor(
                        intermediate_reward, 

                        dtype = self.dtype, 
                        device = self.device
                    )
                    intermediate_reward = intermediate_reward.reshape(self.reward_shape)

                    intermediate_reward_list.append(intermediate_reward)
                    pseudo_final_latent_list.append(pseudo_final_latent)
                else:
                    intermediate_reward_list.append(None)
                    pseudo_final_latent_list.append(None)

                    need_cal_idx_list.append(i)
            else:
                intermediate_reward_list.append(None)
                pseudo_final_latent_list.append(None)

            # goto `for i, sample_idx`
            pass
        
        need_cal_sample_idx_list = None

        if len(need_cal_idx_list) > 0:
            need_cal_sample_idx_list = [
                sample_idx_list[i] \
                    for i in need_cal_idx_list
            ]

            # state_list.shape = (num_need_cal, *state_shape)
            state_list = [
                node.get_state(
                    sample_idx_list = sample_idx
                )[0] \
                    for sample_idx in need_cal_sample_idx_list
            ]
            state_list = torch.stack(state_list)

            # prev_action_list.shape = (num_need_cal, *action_shape)
            prev_action_list = [
                node.info_list[sample_idx].prev_action \
                    for sample_idx in need_cal_sample_idx_list
            ]

            # need_timestep_idx_list.shape = (num_need_cal, )
            need_timestep_idx_list = [
                timestep_idx_list[i] \
                    for i in need_cal_idx_list
            ]
            
            # need_cal_intermediate_reward_list.shape = (num_need_cal, )
            # need_cal_pseudo_final_latent_list.shape = (num_need_cal, *state_shape)
            (
                need_cal_intermediate_reward_list, 
                need_cal_pseudo_final_latent_list
            ) = self.mdp.batch_cal_intermediate_reward(
                state_list = state_list, 
                prev_action_list = prev_action_list, 

                timestep_idx_list = need_timestep_idx_list, 

                sample_idx_list = need_cal_sample_idx_list
            )
            
            for i, (need_cal_idx, need_cal_sample_idx) in enumerate(
                zip(
                    need_cal_idx_list, 
                    need_cal_sample_idx_list
                )
            ):
                intermediate_reward = need_cal_intermediate_reward_list[i]
                
                pseudo_final_latent = None
                if need_cal_pseudo_final_latent_list is not None:
                    pseudo_final_latent = need_cal_pseudo_final_latent_list[i]

                intermediate_reward_list[need_cal_idx] = intermediate_reward
                pseudo_final_latent_list[need_cal_idx] = pseudo_final_latent

                node.info_list[need_cal_sample_idx].intermediate_reward = intermediate_reward
                node.info_list[need_cal_sample_idx].pseudo_final_latent = pseudo_final_latent

                # goto `for i, (need_cal_idx, need_cal_sample_idx)`
                pass
            
            # ---------= [Clean Up] =---------
            del need_cal_sample_idx_list
            del state_list, prev_action_list
            del need_cal_intermediate_reward_list
            if need_cal_pseudo_final_latent_list is not None:
                del need_cal_pseudo_final_latent_list
        
        # ---------= [Clean Up] =---------
        del need_cal_idx_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_cal_intermediate_reward()` done
        return (
            intermediate_reward_list, 
            pseudo_final_latent_list
        )


    def _cal_final_reward(
        self, 

        node: BSNode, 

        sample_idx_list: Union[int, List[int]] = None
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Compute the final reward for the terminal node `node`. 

        Ret:
            `final_reward` (`torch.Tensor` or `np.ndarray`): 
                The list of the final rewards.
                (`sample_idx_list` is provided) final_reward_list.shape = (len(sample_idx_list), *reward_shape). 
                (`sample_idx_list` is not provided) final_reward_list.shape = (num_sample, *reward_shape).  
        """

        if node.depth != self.mdp.time_horizon:
            raise ValueError(
                f"Can not compute the final reward for a non-terminal node. "
            )
        
        # ---------= [Prepare `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(
                range(self.num_sample)
            )
        
        if not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        # ---------= [Compute Final Reward] =---------
        final_reward_list = []

        need_cal_idx_list = []

        for i, sample_idx in enumerate(sample_idx_list):
            info = node.info_list[sample_idx]
            
            if info is not None:
                final_reward = info.final_reward

                if final_reward is not None:
                    final_reward = torch.tensor(
                        final_reward, 

                        dtype = self.dtype, 
                        device = self.device
                    )
                    final_reward = final_reward.reshape(self.reward_shape)

                    final_reward_list.append(final_reward)
                else:
                    final_reward_list.append(None)

                    need_cal_idx_list.append(i)
            else:
                final_reward_list.append(None)

            # goto `for i, sample_idx`
            pass

        if len(need_cal_idx_list) > 0:
            need_cal_sample_idx_list = [
                sample_idx_list[i] \
                    for i in need_cal_idx_list
            ]

            # state_list.shape = (num_need_cal, *state_shape)
            state_list = [
                node.get_state(
                    sample_idx_list = sample_idx
                )[0] \
                    for sample_idx in need_cal_sample_idx_list
            ]
            state_list = torch.stack(state_list)

            # need_cal_final_reward_list.shape = (num_need_cal, )
            need_cal_final_reward_list = self.mdp.batch_cal_final_reward(
                sample_idx_list = need_cal_sample_idx_list, 

                state_list = state_list
            )

            for i, (need_cal_idx, need_cal_sample_idx) in enumerate(
                zip(
                    need_cal_idx_list, 
                    need_cal_sample_idx_list
                )
            ):
                final_reward = need_cal_final_reward_list[i]

                final_reward_list[need_cal_idx] = final_reward

                node.info_list[need_cal_sample_idx].final_reward = final_reward

                # goto `for i, (need_cal_idx, need_cal_sample_idx)`
                pass

            # ---------= [Clean Up] =---------
            del need_cal_sample_idx_list
            del state_list
            del need_cal_final_reward_list
        
        # ---------= [Clean Up] =---------
        del need_cal_idx_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_cal_final_reward()` done
        return final_reward_list


    def _select(
        self, 

        sample_idx_list: List[int], 

        display_state_value_list: Optional[bool] = False
    ) -> List[List[BSNode]]:
        """
        Func:
            Select candidate nodes to expand. 

        Ret:
            `selected_node_list_list` (`List[List[BSNode]]`): A list of lists that contains nodes to expand for each sample. 
        """

        if not isinstance(sample_idx_list, list):
            sample_idx_list = [sample_idx_list]

        
        def implement_sample(
            sample_idx: int
        ) -> BSNode:
            # ---------= [Prepare `state_value_list`] =---------
            candidate_node_list = self._last_depth_node_list_list[sample_idx]

            # state_value_list.shape = (num_node, )
            state_value_list = [
                node.info_list[sample_idx].intermediate_reward
                    for node in candidate_node_list
            ]

            num_node = len(state_value_list)

            if display_state_value_list:
                node_idx_list = [
                    node.node_idx \
                        for node in candidate_node_list
                ]

                state_value_list_str = ", ".join(
                    [
                        f"{state_value:.4f}" \
                            for state_value in state_value_list
                    ]
                )
                
                logger(f"node_idx_list: {node_idx_list}, state_value_list: {state_value_list_str}")

                # ---------= [Clean Up] =---------
                del node_idx_list
                gc.collect()

            # ---------= [Select Node] =---------
            ord_list = list(range(num_node))
            ord_list.sort(
                key = lambda i: -state_value_list[i]
            )
            
            selected_node_list = []
            for i in range(self._num_beam):
                selected_node_list.append(
                    candidate_node_list[ord_list[i]]
                )
            
            # ---------= [Clean Up] =---------
            del state_value_list
            del ord_list
            gc.collect()
            torch.cuda.empty_cache()

            # `implement_sample()` done
            return selected_node_list
        

        selected_node_list_list = [
            implement_sample(sample_idx = sample_idx) \
                for sample_idx in sample_idx_list
        ]

        # `_select()` done
        return selected_node_list_list


    def _duplicate_node_list(
        self, 

        node_list: List[BSNode], 
        sample_idx_list: List[int], 

        num_duplicate_list:  List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Duplicate each `node_list[i]` / `sample_idx_list[i]` in `node_list` / `sample_idx_list` 
                for `num_duplicate_list[i]` times. 

        Ret:
            `node_list` (`List[BSNode]`): The duplicated list. 
            `sample_idx_list` (`List[int]`): The duplicated list. 
        """
        
        tmp_node_list = []
        tmp_sample_idx_list = []
        
        for (node, sample_idx, num_duplicate) in zip(
            node_list, 
            sample_idx_list, 
            num_duplicate_list
        ):
            tmp_node_list += [node] * num_duplicate
            tmp_sample_idx_list += [sample_idx] * num_duplicate

            # goto `for (node, num_duplicate)`
            pass
        
        # ---------= [Clean Up] =---------
        del node_list
        del sample_idx_list
        gc.collect()
        torch.cuda.empty_cache()
        
        node_list = tmp_node_list
        sample_idx_list = tmp_sample_idx_list

        # `_duplicate_node_list()` done
        return (
            node_list, 
            sample_idx_list
        )


    def _get_node_empty_child(
        self, 

        node: BSNode,

        sample_idx: int
    ) -> BSNode:
        """
        Func:
            Return a child of `node` with empty state at sample `sample_idx`. 
            If not exists, return `None`. 

        Ret:
            `child_node` (`BSNode`): A child of `node` with empty state at sample `sample_idx`. 
        """
    
        for child_idx, child in enumerate(node.children_list):
            if child.info_list[sample_idx] is None:
                child.info_list[sample_idx] = -1  # occupy

                child_node = child

                num_occupied_list = [
                    sum(
                        [
                            child.info_list[idx] is not None \
                                for idx in range(self.num_sample)
                        ]
                    ) for child in node.children_list
                ]

                num_son = len(num_occupied_list)

                for son_idx in range(0, num_son - 1):
                    if num_occupied_list[son_idx] > num_occupied_list[son_idx + 1]:
                        num_occupied_list[son_idx], num_occupied_list[son_idx + 1] \
                            = num_occupied_list[son_idx + 1], num_occupied_list[son_idx]
                    
                        node.children_list[son_idx], node.children_list[son_idx + 1] \
                            = node.children_list[son_idx + 1], node.children_list[son_idx]

                    # goto `for son_idx`
                    pass

                return child_node

            # goto `for child_idx, child`
            pass

        # `_get_node_empty_child()` done
        return None


    # def _expand(
    #     self, 

    #     node_list: List[BSNode], 

    #     sample_idx: int
    # ) -> List[BSNode]:
    #     """
    #     Func:
    #         Sample actions to expand the selected nodes, and create new child nodes. 
        
    #     Ret: 
    #         `expanded_node_list` (`List[BSNode]`): The list of expanded nodes. 
    #     """
        
    #     num_node = len(node_list)
    #     num_duplicate_list = [self._num_candidate_per_beam] * num_node

    #     node_list = self._duplicate_node_list(
    #         node_list = node_list, 
    #         num_duplicate_list = num_duplicate_list
    #     )

    #     num_node = len(node_list)

    #     sample_idx_list = [sample_idx] * num_node
        
    #     # ---------= [Timestep] =---------
    #     # timestep_idx_list.shape = (num_node, )
    #     timestep_idx_list = torch.tensor(
    #         [
    #             node.depth \
    #                 for node in node_list
    #         ], 

    #         dtype = torch.int32, 
    #         device = "cpu"
    #     )

    #     # terminal node
    #     if max(timestep_idx_list) >= self.mdp.time_horizon:
    #         raise ValueError(
    #             f"Can not expand a terminal node. "
    #         )

    #     # ---------= [Current State] =---------
    #     state_list = [
    #         node.get_state(
    #             sample_idx_list = [sample_idx]
    #         )[0] \
    #             for node in node_list
    #     ]

    #     # state_list.shape = (num_node, *state_shape)
    #     state_list = torch.stack(state_list)
        
    #     # ---------= [Expansion Action] =---------
    #     action_list = []
        
    #     node_i = 0
        
    #     while node_i < num_node:
    #         node_j = node_i

    #         while (node_j + 1 < num_node) \
    #             and (node_list[node_j + 1].node_idx == node_list[node_i].node_idx):
    #             node_j += 1

    #         node_sample_idx_list = [sample_idx] * (node_j - node_i + 1)

    #         node_action_list = self._batch_sample_expansion_action(
    #             node = node_list[node_i], 
    #             sample_idx_list = node_sample_idx_list
    #         )

    #         action_list += node_action_list

    #         node_i = node_j + 1

    #         # goto `while i < num_node`
    #         pass

    #     # action_list.shape = (num_state, *action_shape)
    #     action_list = torch.stack(action_list)
        
    #     # ---------= [Expansion] =---------
    #     # next_state_list.shape = (num_state, *state_shape)
    #     next_state_list = self.mdp.batch_cal_dynamics(
    #         sample_idx_list = sample_idx_list, 

    #         state_list = state_list, 
    #         action_list = action_list, 

    #         timestep_idx_list = timestep_idx_list
    #     )
        
    #     # ---------= [Fill In] =---------
    #     expanded_node_list = []
            
    #     for i, (node, sample_idx) in enumerate(
    #         zip(
    #             node_list, 
    #             sample_idx_list
    #         )
    #     ):
    #         state = next_state_list[i]
    #         action = action_list[i]

    #         expanded_node = self._get_node_empty_child(
    #             node = node, 
    #             sample_idx = sample_idx
    #         )

    #         if expanded_node is not None:
    #             expanded_node.info_list[sample_idx] = Info(
    #                 state = state, 
    #                 prev_action = action
    #             )
    #         else:
    #             expanded_node = self._get_new_node(
    #                 sample_idx_list = [sample_idx], 

    #                 state_list = [state], 
    #                 prev_action_list = [action], 

    #                 parent = node
    #             )

    #             node.children_list.append(expanded_node)

    #         expanded_node_list.append(expanded_node)

    #         # goto `for sample_idx_idx, (node, sample_idx)`
    #         pass
        
    #     # ---------= [Cal Node Reward] =---------
    #     num_expanded_node = len(expanded_node_list)

    #     ord_list = sorted(
    #         list(
    #             range(num_expanded_node)
    #         ), 
    #         key = lambda ord_idx: expanded_node_list[ord_idx].node_idx
    #     )

    #     _node_list = []
    #     _expanded_node_list = []

    #     for ord_idx in ord_list:
    #         _node_list.append(node_list[ord_idx])
    #         _expanded_node_list.append(expanded_node_list[ord_idx])

    #         # goto `for ord_idx`
    #         pass

    #     node_list = _node_list
    #     expanded_node_list = _expanded_node_list

    #     node_i = 0
        
    #     while node_i < num_expanded_node:
    #         node_j = node_i

    #         while (node_j + 1 < num_expanded_node) \
    #             and (expanded_node_list[node_j + 1].node_idx == expanded_node_list[node_i].node_idx):
    #             node_j += 1
            
    #         node = node_list[node_i]
    #         expanded_node = expanded_node_list[node_i]
    #         expanded_node_sample_idx_list = sample_idx_list[node_i: (node_j + 1)]

    #         expanded_node_timestep_idx_list = torch.tensor(
    #             [expanded_node.depth] * (node_j - node_i + 1), 

    #             dtype = torch.int32, 
    #             device = "cpu"
    #         )

    #         node_i = node_j + 1
            
    #         self._cal_node_reward(
    #             node = expanded_node, 

    #             parent = node, 

    #             sample_idx_list = expanded_node_sample_idx_list,
                
    #             timestep_idx_list = expanded_node_timestep_idx_list
    #         )

    #         # ---------= [Clean Up] =---------
    #         del expanded_node_timestep_idx_list
    #         gc.collect()
    #         torch.cuda.empty_cache()

    #         # goto `while node_i < num_expanded_node`
    #         pass

    #     # ---------= [Clean Up] =---------
    #     del timestep_idx_list
    #     del state_list, action_list, next_state_list
    #     del ord_list
    #     gc.collect()
    #     torch.cuda.empty_cache()

    #     # `_expand()` done
    #     return expanded_node_list


    def _expand(
        self, 

        node_list_list: List[List[BSNode]], 

        sample_idx_list: List[int]
        # sample_idx_list_list: List[List[int]]
    ) -> List[List[BSNode]]:
        """
        Func:
            Sample actions to expand the selected nodes, and create new child nodes. 
        
        Ret: 
            `expanded_node_list_list` (`List[List[BSNode]]`): The list of expanded nodes. 
        """
        
        node_list = sum(node_list_list, [])

        num_node = len(node_list)
        init_num_sample_idx = len(sample_idx_list)

        
        sample_idx_list = sum(
            [
                [sample_idx] * len(node_list_list[i]) \
                    for i, sample_idx in enumerate(sample_idx_list)
            ], 
            []
        )
        
        num_duplicate_list = [self._num_candidate_per_beam] * num_node
        
        (
            node_list, 
            sample_idx_list
        ) = self._duplicate_node_list(
            node_list = node_list, 
            sample_idx_list = sample_idx_list, 

            num_duplicate_list = num_duplicate_list
        )
        
        # ---------= [Sort By Node Index] =---------
        num_node = len(node_list)

        ord_list = sorted(
            list(
                range(num_node)
            ), 
            key = lambda ord_idx: node_list[ord_idx].node_idx
        )

        _node_list = []
        _sample_idx_list = []

        for ord_idx in ord_list:
            _node_list.append(node_list[ord_idx])
            _sample_idx_list.append(sample_idx_list[ord_idx])

            # goto `for ord_idx`
            pass
    
        node_list = _node_list
        sample_idx_list = _sample_idx_list
        
        # ---------= [Timestep] =---------
        # timestep_idx_list.shape = (num_node, )
        timestep_idx_list = torch.tensor(
            [
                node.depth \
                    for node in node_list
            ], 

            dtype = torch.int32, 
            device = "cpu"
        )

        # terminal node
        if max(timestep_idx_list) >= self.mdp.time_horizon:
            raise ValueError(
                f"Can not expand a terminal node. "
            )

        # ---------= [Current State] =---------
        state_list = []

        node_i = 0
        
        while node_i < num_node:
            node_j = node_i

            while (node_j + 1 < num_node) \
                and (node_list[node_j + 1].node_idx == node_list[node_i].node_idx):
                node_j += 1
            
            node_sample_idx_list = sample_idx_list[node_i: (node_j + 1)]
            node_state_list = node_list[node_i].get_state(
                sample_idx_list = node_sample_idx_list
            )

            state_list += node_state_list

            node_i = node_j + 1

            # goto `while i < num_node`
            pass

        # state_list.shape = (num_node, *state_shape)
        state_list = torch.stack(state_list)
        
        # ---------= [Expansion Action] =---------
        action_list = []
        
        node_i = 0
        
        while node_i < num_node:
            node_j = node_i

            while (node_j + 1 < num_node) \
                and (node_list[node_j + 1].node_idx == node_list[node_i].node_idx):
                node_j += 1

            node_sample_idx_list = sample_idx_list[node_i: (node_j + 1)]

            node_action_list = self._batch_sample_expansion_action(
                node = node_list[node_i], 
                sample_idx_list = node_sample_idx_list
            )

            action_list += node_action_list

            node_i = node_j + 1

            # goto `while i < num_node`
            pass

        # action_list.shape = (num_state, *action_shape)
        action_list = torch.stack(action_list)
        
        # ---------= [Expansion] =---------
        # next_state_list.shape = (num_state, *state_shape)
        next_state_list = self.mdp.batch_cal_dynamics(
            sample_idx_list = sample_idx_list, 

            state_list = state_list, 
            action_list = action_list, 

            timestep_idx_list = timestep_idx_list
        )
        
        # ---------= [Fill In] =---------
        expanded_node_list = []
            
        for i, (node, sample_idx) in enumerate(
            zip(
                node_list, 
                sample_idx_list
            )
        ):
            state = next_state_list[i]
            action = action_list[i]

            expanded_node = self._get_node_empty_child(
                node = node, 
                sample_idx = sample_idx
            )

            if expanded_node is not None:
                expanded_node.info_list[sample_idx] = Info(
                    state = state, 
                    prev_action = action
                )
            else:
                expanded_node = self._get_new_node(
                    sample_idx_list = [sample_idx], 

                    state_list = [state], 
                    prev_action_list = [action], 

                    parent = node
                )

                node.children_list.append(expanded_node)

            expanded_node_list.append(expanded_node)

            # goto `for sample_idx_idx, (node, sample_idx)`
            pass
        
        # ---------= [Cal Node Reward] =---------
        num_expanded_node = len(expanded_node_list)

        ord_list = sorted(
            list(
                range(num_expanded_node)
            ), 
            key = lambda ord_idx: expanded_node_list[ord_idx].node_idx
        )

        _node_list = []
        _expanded_node_list = []
        _sample_idx_list = []

        for ord_idx in ord_list:
            _node_list.append(node_list[ord_idx])
            _expanded_node_list.append(expanded_node_list[ord_idx])
            _sample_idx_list.append(sample_idx_list[ord_idx])

            # goto `for ord_idx`
            pass

        node_list = _node_list
        expanded_node_list = _expanded_node_list
        sample_idx_list = _sample_idx_list

        node_i = 0

        expanded_node_list_list = [
            [] for _ in range(init_num_sample_idx)
        ]
        
        while node_i < num_expanded_node:
            node_j = node_i

            while (node_j + 1 < num_expanded_node) \
                and (expanded_node_list[node_j + 1].node_idx == expanded_node_list[node_i].node_idx):
                node_j += 1
            
            node = node_list[node_i]
            expanded_node = expanded_node_list[node_i]
            expanded_node_sample_idx_list = sample_idx_list[node_i: (node_j + 1)]

            expanded_node_timestep_idx_list = torch.tensor(
                [expanded_node.depth] * (node_j - node_i + 1), 

                dtype = torch.int32, 
                device = "cpu"
            )

            node_i = node_j + 1
            
            self._cal_node_reward(
                node = expanded_node, 

                parent = node, 

                sample_idx_list = expanded_node_sample_idx_list,
                
                timestep_idx_list = expanded_node_timestep_idx_list
            )

            for expanded_node_sample_idx in expanded_node_sample_idx_list:
                expanded_node_list_list[expanded_node_sample_idx].append(
                    expanded_node
                )

                # goto `for expanded_node_sample_idx`
                pass

            # ---------= [Clean Up] =---------
            del expanded_node_sample_idx_list
            del expanded_node_timestep_idx_list
            gc.collect()
            torch.cuda.empty_cache()

            # goto `while node_i < num_expanded_node`
            pass

        # ---------= [Clean Up] =---------
        del timestep_idx_list
        del state_list, action_list, next_state_list
        del ord_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_expand()` done
        return expanded_node_list_list


    def _step(
        self, 

        sample_idx_list: List[int], 

        display_state_value_list: Optional[bool] = False, 

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Run BS for a sample until stopped. 
        """
        
        if self._bs_loop_idx == 0:
            node_list_list = [
                [self.root for _ in range(self._num_beam)] \
                    for sample_idx in sample_idx_list
            ]
        else:
            node_list_list = self._select(
                sample_idx_list = sample_idx_list, 

                display_state_value_list = display_state_value_list
            )

        # for (
        #     sample_idx, 
        #     node_list
        # ) in zip(
        #     sample_idx_list, 
        #     node_list_list
        # ):
        #     expanded_node_list = self._expand(
        #         node_list = node_list, 

        #         sample_idx = sample_idx
        #     )

        #     self._last_depth_node_list_list[sample_idx] = expanded_node_list

        #     # goto `for (sample_idx, node_list)`
        #     pass
        
        expanded_node_list_list = self._expand(
            node_list_list = node_list_list, 

            sample_idx_list = sample_idx_list.copy()
            # sample_idx_list_list = sample_idx_list_list
        )

        for i, expanded_node_list in enumerate(expanded_node_list_list):
            sample_idx = sample_idx_list[i]
            self._last_depth_node_list_list[sample_idx] = expanded_node_list

            # goto `for (expanded_node_list, sample_idx)`
            pass

        # ---------= [Clean Up] =---------
        del node_list_list
        gc.collect()
        torch.cuda.empty_cache()

        # `_step()` done
        pass


    def _run_sample_pre_process(
        self, 

        **arg_dict
    ):
        """
        Func:
            The function called before starting a run of BS for a single sample. 
        """

        self._time_st = time.time()

        # `_run_sample_pre_process()` done
        pass


    def _run_sample_post_process(
        self, 

        sample_idx: int, 

        **arg_dict
    ):
        """
        Func:
            Hook invoked after a BS run for one sample. Stores the wall-clock seconds elapsed since 
            `self._time_st` in `self._wall_clock_time_cost_list[sample_idx]` and logs them (rounded). 
            Extra keyword arguments are ignored. 
        """

        # ---------= [Compute Wall-clock Time Cost] =---------
        elapsed = time.time() - self._time_st

        self._wall_clock_time_cost_list[sample_idx] = elapsed

        logger(
            f"[Sample {sample_idx}] Finished, wall-clock time cost: {round(elapsed)} second(s). "
        )


    def _bs_loop_callback(
        self, 

        local_var_dict: Dict, 

        **arg_dict
    ):
        """
        Func:
            The callback function after each BS loop. 
        """

        time_ed = time.time()

        time_cost = time_ed - self._time_st

        self._bs_loop_wall_time_cost_list.append(time_cost)

        self._bs_loop_idx += 1

        # `_bs_loop_callback()` done
        pass


    def run(
        self, 

        display_result: bool = False, 

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Run BS for all `self.num_sample` samples jointly. Each loop expands one tree level, and there 
            are `self.mdp.time_horizon` loops in total. 

        Args:
            `display_result` (`bool`): If `True`, display every sample's result and then the overall 
                result once the search is finished. 
            `**arg_dict`: Forwarded to `self._step()` on every loop (e.g. `display_state_value_list`). 

        NB:
            `self._run_sample_pre_process()` is called once for the whole batch. The time recorded by 
            `self._run_sample_post_process()` for each sample is therefore the elapsed time since the 
            batch started, not a per-sample cost. 
        """
        # ---------= [Pre-process] =---------
        self._run_sample_pre_process()

        time_horizon = self.mdp.time_horizon

        # Every sample is advanced in every loop
        need_cal_sample_idx_list = list(
            range(self.num_sample)
        )
        
        # One loop per tree level; `self._expand()` raises on nodes at depth >= `time_horizon`
        for bs_loop_idx in tqdm(
            range(time_horizon), 
            desc = "[Beam Search]"
        ):
            self._step(
                sample_idx_list = need_cal_sample_idx_list, 
                
                **arg_dict
            )

            # ---------= [Clean Up] =---------
            gc.collect()
            torch.cuda.empty_cache()
            
            # The callback receives this frame's local variables (e.g. `bs_loop_idx`, `time_horizon`);
            # renaming locals in this method changes what overrides of `_bs_loop_callback()` see
            local_var_dict = locals()
            self._bs_loop_callback(local_var_dict)
            
            # goto `for bs_loop_idx`
            pass

        for sample_idx in range(self.num_sample):
            self._run_sample_post_process(sample_idx = sample_idx)

            if display_result:
                self.display_sample_result(
                    sample_idx = sample_idx
                )

            # goto `for sample_idx`
            pass

        if display_result:
            self.display_result()

        # `run()` done
        pass


    @abstractmethod
    def _get_is_time_to_stop(
        self, 

        **arg_dict
    ) -> bool:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Simulate the trajectory from the expanded node to terminal state. 

        Ret:
            `is_time_to_stop` (`bool`): Whether it is time to stop. 
        """

        # `_get_is_time_to_stop()` done
        pass


    def _get_is_cost_legal(
        self, 

        **arg_dict
    ) -> bool:
        """
        Func:
            Determine whether the cost of BS is legal. 

        Ret:
            `is_cost_legal` (`bool`): Whether the cost of BS is legal. 
        """

        # `_get_is_cost_legal()` done
        pass


    @abstractmethod
    def _sample_expansion_action(
        self, 
        
        node: BSNode
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Sample an action from the action space for `self._expand()`. 

        Ret:
            `action` (`torch.Tensor` or `np.ndarray`): The sampled action. 
        """

        # `_sample_expansion_action()` done
        pass


    @abstractmethod
    def _batch_sample_expansion_action(
        self, 
        
        node_list: List[BSNode]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Batch sample actions from the action space for `self._expand()`. 

        Ret:
            `action_list` (`torch.Tensor` or `np.ndarray`): The list of the sampled actions. 
        """

        # `_batch_sample_expansion_action()` done
        pass


    # @abstractmethod
    # def _sample_simulation_action(
    #     self, 
        
    #     timestep_idx: int
    # ) -> Union[torch.Tensor, np.ndarray]:
    #     """
    #     NB: 
    #         Abstract, should be implemented. 

    #     Func:
    #         Sample an action from the action space for `self._simulate()`. 

    #     Ret:
    #         `action` (`torch.Tensor` or `np.ndarray`): The sampled action. 
    #     """

    #     # `_sample_simulation_action()` done
    #     pass


    # @abstractmethod
    # def _batch_sample_simulation_action(
    #     self, 
        
    #     timestep_idx_list: List[int]
    # ) -> Union[torch.Tensor, np.ndarray]:
    #     """
    #     NB: 
    #         Abstract, should be implemented. 

    #     Func:
    #         Batch sample actions from the action space for `self._simulate()`. 

    #     Ret:
    #         `batch_action` (`torch.Tensor` or `np.ndarray`): The batch of the sampled actions. 
    #     """

    #     # `_batch_sample_simulation_action()` done
    #     pass


    def display_tree(
        self
    ):
        """
        Func:
            Display the search tree structure. 
        """

        def dfs(
            node
        ):
            for child in node.children_list:
                node_idx = node.node_idx
                child_idx = child.node_idx

                logger(f"[Node {node_idx}] -> [Node {child_idx}]")

                child_info_dict = child.info_dict

                for sample_idx, info in child_info_dict.items():
                    prev_action = info.prev_action

                    logger(f"    [Sample {sample_idx}] action: {prev_action}")

                    # goto `for sample_idx, info`
                    pass

                dfs(child)

                # goto `for child`
                pass

            # `dfs()` done
            pass

        dfs(self.root)

        # `display_tree()` done
        pass


    def display_info(
        self, 

        display_state: bool = True, 
        display_prev_action: bool = True, 
        display_value: bool = True
    ):
        """
        Func:
            Display the infomation of every node, ordered by DFS order. 
        """

        def dfs(
            node
        ):
            node.display_info(
                display_state = display_state, 
                display_prev_action = display_prev_action, 
                display_value = display_value
            )

            for child in node.children_list:
                dfs(child)

            # `dfs()` done
            pass

        dfs(self.root)

        # `display_info()` done
        pass


    def display_sample_result(
        self, 

        sample_idx: int
    ):
        """
        Func:
            Log the search result of one sample: the last entry of its 
            `self.history_best_merged_reward_list_list[sample_idx]` history. 
        """

        history = self.history_best_merged_reward_list_list[sample_idx]
        best_merged_reward = history[-1]

        for line in (
            f"[Main Results]", 
            f"    sample_idx: {sample_idx}", 
            f"    best_merged_reward: {best_merged_reward}", 
        ):
            logger(line)


    def display_result(
        self
    ):
        """
        Func:
            Log the final best merged reward of every sample (the last entry of each history in 
            `self.history_best_merged_reward_list_list`) as one list, formatted to 4 decimals. 
        """

        best_merged_reward_list = [
            self.history_best_merged_reward_list_list[sample_idx][-1] \
                for sample_idx in range(self.num_sample)
        ]

        best_merged_reward_list_str = ", ".join(
            f"{reward:.4f}" for reward in best_merged_reward_list
        )

        logger(f"[Main Results]")
        logger(f"    best_merged_reward_list: [{best_merged_reward_list_str}]")
