from typing import Any, List, Tuple

import numpy as np
import torch
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree
from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer_efficientzero import EfficientZeroGameBuffer


@BUFFER_REGISTRY.register('game_buffer_sampled_efficientzero')
class SampledEfficientZeroGameBuffer(EfficientZeroGameBuffer):
    """
    Overview:
        The specific game buffer for Sampled EfficientZero policy.
    """

    def __init__(self, cfg: dict):
        """
        Overview:
            Build the buffer on top of the default configuration: ``cfg`` is merged over ``self.default_config()``, so
            every key given in ``cfg`` overrides the default of the same name while all other keys keep their defaults.
        Arguments:
            - cfg (:obj:`dict`): the user configuration.
        """
        super().__init__(cfg)
        merged_cfg = self.default_config()
        merged_cfg.update(cfg)
        self._cfg = merged_cfg
        assert merged_cfg.env_type in ['not_board_games', 'board_games']
        assert merged_cfg.action_type in ['fixed_action_space', 'varied_action_space']

        # sampling hyper-parameters taken from the merged config
        self.replay_buffer_size = merged_cfg.replay_buffer_size
        self.batch_size = merged_cfg.batch_size
        self._alpha, self._beta = merged_cfg.priority_prob_alpha, merged_cfg.priority_prob_beta

        # storage for game segments and their per-position priorities / lookup table
        self.game_segment_buffer, self.game_pos_priorities, self.game_segment_game_pos_look_up = [], [], []

        # bookkeeping counters
        self.keep_ratio, self.num_of_collected_episodes, self.base_idx, self.clear_time = 1, 0, 0, 0

    def sample(self, batch_size: int, policy: Any) -> List[Any]:
        """
        Overview:
            Sample data from the buffer and prepare the current (input) batch and the target batch for training.
            For the reanalyzed part of the batch, the stored root sampled actions are replaced by the ones drawn during
            reanalysis, because the reanalyzed target policies are distributions over those new actions.
        Arguments:
            - batch_size (:obj:`int`): number of samples to draw.
            - policy (:obj:`Any`): the policy; its ``_target_model`` is moved to ``self._cfg.device``, put in eval mode
              and used to compute the value targets and the reanalyzed policy targets.
        Returns:
            - train_data (:obj:`List`): ``[current_batch, target_batch]`` with
              ``current_batch = [obs, actions, root_sampled_actions, masks, batch_indices, weights, make_times]`` and
              ``target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies]``.
        """
        reanalyze_ratio = self._cfg.reanalyze_ratio
        # Validate the configuration before doing any (expensive) sampling or tree search.
        assert 0 <= reanalyze_ratio <= 1, f"reanalyze_ratio must be in [0, 1], got {reanalyze_ratio}"
        if reanalyze_ratio > 0:
            assert self._cfg.reanalyze_outdated is True, \
                "in sampled effiicientzero, if self._cfg.reanalyze_ratio>0, you must set self._cfg.reanalyze_outdated=True"

        policy._target_model.to(self._cfg.device)
        policy._target_model.eval()

        reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
            batch_size, reanalyze_ratio
        )

        # target value prefix and target value
        batch_value_prefixs, batch_target_values = self._compute_target_reward_value(
            reward_value_context, policy._target_model
        )

        # target policy of the non-reanalyzed part
        batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
            policy_non_re_context, self._cfg.model.num_of_sampled_actions
        )

        if policy_re_context is None:
            # No sample was reanalyzed: reanalyze_ratio == 0, or the reanalyzed count rounded down to 0.
            # (Dispatching on the context, not on the ratio, avoids unpacking the ``[]`` returned for a None context.)
            batch_target_policies = batch_target_policies_non_re
        else:
            batch_target_policies_re, root_sampled_actions = self._compute_target_policy_reanalyzed(
                policy_re_context, policy._target_model
            )
            # ==============================================================
            # fix reanalyze in sez:
            # use the latest root_sampled_actions after the reanalyze process,
            # because the batch_target_policies_re is corresponding to the latest root_sampled_actions
            # ==============================================================
            # Same count as in ``_make_batch``: the first ``reanalyze_num`` samples of the batch were reanalyzed.
            reanalyze_num = int(len(current_batch[2]) * reanalyze_ratio)
            # continuous actions are vectors of ``action_space_size``; discrete actions are single indices
            action_dim = self._cfg.model.action_space_size if self._cfg.model.continuous_action_space else 1
            current_batch[2][:reanalyze_num] = root_sampled_actions.reshape(
                reanalyze_num, self._cfg.num_unroll_steps + 1, self._cfg.model.num_of_sampled_actions, action_dim
            )
            if policy_non_re_context is None:
                # the whole batch was reanalyzed (reanalyze_ratio == 1)
                batch_target_policies = batch_target_policies_re
            else:
                # reanalyzed samples come first, matching the order used in ``_make_batch``
                batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])

        target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies]
        # a batch contains the current_batch and the target_batch
        train_data = [current_batch, target_batch]
        return train_data


    def _prepare_reward_value_context(
            self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any],
            total_transitions: int
    ) -> List[Any]:
        """
        Overview:
            Collect what is needed to compute the value-prefix and n-step value targets. For each of the
            ``num_unroll_steps + 1`` unrolled positions ``t`` of every sample it records:
                - the stacked observation at the bootstrap step ``t + td_steps``, or ``zero_obs`` with mask 0 when
                  the bootstrap step lies beyond the end of the game segment;
                - the root value stored at the bootstrap step (``[0.0]`` when masked);
                - the ``td_steps`` used.
        Arguments:
            - batch_index_list (:obj:`list`): index of the start transition of each sample in the replay buffer.
            - game_segment_list (:obj:`list`): the game segment each sample comes from.
            - pos_in_game_segment_list (:obj:`list`): start position of each sample inside its game segment.
            - total_transitions (:obj:`int`): number of collected transitions; currently unused (reserved for the
              off-policy td-steps correction described in the TODO below).
        Returns:
            - reward_value_context (:obj:`list`): ``[value_obs_list, value_mask, pos_in_game_segment_list, rewards_list,
              game_segment_lens, td_steps_list, action_mask_segment, to_play_segment, root_value_list]``.
        """
        zero_obs = game_segment_list[0].zero_obs()
        num_unroll_steps = self._cfg.num_unroll_steps

        # one entry per unrolled position
        value_obs_list, value_mask, root_value_list, td_steps_list = [], [], [], []
        # one entry per sample
        rewards_list, game_segment_lens = [], []
        # per-sample action masks and players (used by two-player board games)
        action_mask_segment, to_play_segment = [], []

        # ``batch_index_list`` only bounds the iteration here; its values are not read.
        for segment, start, _ in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
            segment_len = len(segment)
            game_segment_lens.append(segment_len)

            # TODO(pu): off-policy correction for atari (shorter td horizon for older data), e.g.
            #   td_steps = clip(config.td_steps - (total_transitions - idx) // config.auto_td_steps, 1, 5)
            # For now td_steps is just clipped to [1, max(1, segment_len - start)].
            td_steps = np.clip(self._cfg.td_steps, 1, max(1, segment_len - start)).astype(np.int32)

            # observations starting at o_{start + td_steps}, covering the bootstrap steps of every unrolled position
            bootstrap_obs = segment.get_unroll_obs(start + td_steps, num_unroll_steps)

            rewards_list.append(segment.reward_segment)
            action_mask_segment.append(segment.action_mask_segment)
            to_play_segment.append(segment.to_play_segment)

            for offset, current_index in enumerate(range(start, start + num_unroll_steps + 1)):
                td_steps_list.append(td_steps)
                bootstrap_index = current_index + td_steps

                if bootstrap_index >= segment_len:
                    # bootstrap step is outside the segment: no bootstrap value
                    value_mask.append(0)
                    value_obs_list.append(zero_obs)
                    root_value_list.append([0.0])
                    continue

                value_mask.append(1)
                # ``offset`` frames into ``bootstrap_obs`` is the stacked observation of the bootstrap step
                value_obs_list.append(bootstrap_obs[offset:offset + self._cfg.model.frame_stack_num])
                root_value_list.append([segment.root_value_segment[bootstrap_index]])

        return [
            value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
            action_mask_segment, to_play_segment, root_value_list
        ]


    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            Draw raw samples with ``_sample_orig_data()`` and assemble everything ``sample()`` needs:
                reward_value_context:   context of the value-prefix / value targets (values are always reanalyzed)
                policy_re_context:      context of the reanalyzed policy targets (first ``reanalyze_num`` samples)
                policy_non_re_context:  context of the non-reanalyzed policy targets (remaining samples)
                current_batch:          the network inputs of the batch
            For Sampled EfficientZero the inputs also carry the sampled root actions of every unrolled step.
        Arguments:
            - batch_size (:obj:`int`): number of samples requested from the replay buffer.
            - reanalyze_ratio (:obj:`float`): fraction of samples whose policy target is reanalyzed.
        Returns:
            - context (:obj:`Tuple`): ``(reward_value_context, policy_re_context, policy_non_re_context, current_batch)``
              where ``current_batch`` is ``[obs, actions, root_sampled_actions, masks, batch_indices, weights,
              make_times]``, each entry passed through ``np.asarray``. A policy context is ``None`` when it would cover
              no sample.
        """
        game_lst, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = self._sample_orig_data(
            batch_size
        )
        # the batch size actually returned by the sampler is the one used from here on
        batch_size = len(batch_index_list)
        unroll_steps = self._cfg.num_unroll_steps
        model_cfg = self._cfg.model

        obs_list, action_list, root_sampled_actions_list, mask_list = [], [], [], []
        for sample_idx in range(batch_size):
            game = game_lst[sample_idx]
            start = pos_in_game_segment_list[sample_idx]

            # ==============================================================
            # sampled related core code
            # ==============================================================
            # ``num_unroll_steps`` executed actions, but ``num_unroll_steps + 1`` sets of root sampled actions
            # (one per unrolled node, including the last one).
            actions = game.action_segment[start:start + unroll_steps].tolist()
            sampled_actions = game.root_sampled_actions[start:start + unroll_steps + 1]

            # 1 for a step inside the trajectory, 0 for a padded step beyond its end
            num_valid = len(sampled_actions)
            mask = [1.] * num_valid + [0.] * (unroll_steps + 1 - num_valid)

            # pad the steps beyond the trajectory end with random actions (they are masked out)
            if model_cfg.continuous_action_space:
                actions += [
                    np.random.randn(model_cfg.action_space_size) for _ in range(unroll_steps - len(actions))
                ]
                sampled_actions += [
                    np.random.rand(model_cfg.num_of_sampled_actions, model_cfg.action_space_size)
                    for _ in range(unroll_steps + 1 - len(sampled_actions))
                ]
            else:
                # one random action per missing step
                actions += generate_random_actions_discrete(
                    unroll_steps - len(actions), model_cfg.action_space_size, 1
                )
                # root sampled actions are stored with different shapes by the ctree and ptree MCTS,
                # hence the ``reshape`` switch
                sampled_actions += generate_random_actions_discrete(
                    unroll_steps + 1 - len(sampled_actions),
                    model_cfg.action_space_size,
                    model_cfg.num_of_sampled_actions,
                    reshape=bool(self._cfg.mcts_ctree)
                )

            # stacked observations for the whole unroll, padded if the segment is too short
            obs_list.append(game.get_unroll_obs(start, num_unroll_steps=unroll_steps, padding=True))
            action_list.append(actions)
            root_sampled_actions_list.append(sampled_actions)
            mask_list.append(mask)

        # formalize the input observations and the inputs of the batch
        obs_list = prepare_observation(obs_list, model_cfg.model_type)
        current_batch = [
            np.asarray(item) for item in (
                obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list,
                make_time_list
            )
        ]

        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_lst, pos_in_game_segment_list, self.get_num_of_transitions()
        )

        # Per the original design note: with ``reanalyze_outdated`` the samples are ordered by their generating env
        # step; samples [0, reanalyze_num) get a reanalyzed policy target, the rest keep the stored one.
        reanalyze_num = int(batch_size * reanalyze_ratio)
        if reanalyze_num > 0:
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_lst[:reanalyze_num], pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        if reanalyze_num < batch_size:
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_lst[reanalyze_num:], pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        return reward_value_context, policy_re_context, policy_non_re_context, current_batch
    
    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]:
        """
        Overview:
            Compute the value-prefix and value targets from the context built by ``_prepare_reward_value_context``.
            For each unrolled position ``t`` of a sample, the value target is the n-step return
            ``sum_{i < td_steps} discount_factor^i * r_{t+i} + discount_factor^td_steps * V(o_{t+td_steps})``.
            The bootstrap term is zeroed where ``value_mask`` is 0. ``V`` is either the stored root value
            (``use_root_value``) or the value predicted by ``model.initial_inference``.
            For two-player board games the signs of rewards and bootstrap values depend on the player to move (see
            the TODOs below). The value-prefix target is the undiscounted running sum of rewards; it is reset every
            ``lstm_horizon_len`` unroll steps of each sample.
        Arguments:
            - reward_value_context (:obj:`list`): the context returned by ``_prepare_reward_value_context``.
            - model (:obj:`Any`): the target model used to predict bootstrap values.
        Returns:
            - batch_value_prefixs (:obj:`np.ndarray`): object array with one list of ``num_unroll_steps + 1``
              value-prefix targets per sample.
            - batch_target_values (:obj:`np.ndarray`): object array with one list of ``num_unroll_steps + 1`` value
              targets per sample; positions beyond the end of the game segment get 0.
        """
        # ``action_mask_segment`` is part of the context but not needed here (no MCTS is run on the bootstrap states).
        value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, \
            action_mask_segment, to_play_segment, root_value_list = reward_value_context  # noqa

        # transition_batch_size = game_segment_batch_size * (num_unroll_steps + 1)
        transition_batch_size = len(value_obs_list)

        batch_target_values, batch_value_prefixs = [], []
        with torch.no_grad():
            if self._cfg.use_root_value:
                # reuse the root values stored at collection time as bootstrap values
                value_list = np.array(root_value_list)
            else:
                value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
                # split a full batch into slices of mini_infer_size to save GPU memory
                slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
                network_output = []
                for i in range(slices):
                    beg_index = self._cfg.mini_infer_size * i
                    end_index = self._cfg.mini_infer_size * (i + 1)
                    m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()

                    # calculate the target value
                    m_output = model.initial_inference(m_obs)

                    # TODO(pu)
                    if not model.training:
                        # if not in training, obtain the scalars of the value/reward
                        [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
                            [
                                m_output.latent_state,
                                inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
                                m_output.policy_logits
                            ]
                        )
                        m_output.reward_hidden_state = (
                            m_output.reward_hidden_state[0].detach().cpu().numpy(),
                            m_output.reward_hidden_state[1].detach().cpu().numpy()
                        )

                    network_output.append(m_output)

                # use the predicted values
                value_list = concat_output_value(network_output)

            is_two_player_board_game = self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]

            # discount the bootstrap value by discount_factor ** td_steps
            if is_two_player_board_game:
                # TODO(pu): for board_games, very important, to check
                # an odd number of steps ends at the opponent's turn, hence the sign flip
                value_list = value_list.reshape(-1) * np.array(
                    [
                        self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) %
                        2 == 0 else -self._cfg.discount_factor ** td_steps_list[i]
                        for i in range(transition_batch_size)
                    ]
                )
            else:
                value_list = value_list.reshape(-1) * (
                    np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
                )

            # no bootstrap value where the bootstrap step lies outside the game segment
            value_list = value_list * np.array(value_mask)
            value_list = value_list.tolist()

            value_index = 0
            for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
                                                                                       pos_in_game_segment_list,
                                                                                       to_play_segment):
                target_values = []
                target_value_prefixs = []

                # ``horizon_id`` counts the unroll steps of THIS sample (as in the original EfficientZero). Sharing it
                # across the batch made the value-prefix reset points depend on the sample's position in the batch.
                horizon_id = 0
                value_prefix = 0.0
                base_index = state_index
                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    bootstrap_index = current_index + td_steps_list[value_index]
                    # add the discounted rewards r_{t}, ..., r_{t + td_steps - 1}
                    for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
                        if is_two_player_board_game:
                            # TODO(pu): for board_games, very important, to check
                            # NOTE(review): ``i`` is an offset relative to ``current_index`` while ``base_index`` is an
                            # absolute segment position, so ``to_play_list[i]`` looks like it mixes the two -- confirm.
                            if to_play_list[base_index] == to_play_list[i]:
                                value_list[value_index] += reward * self._cfg.discount_factor ** i
                            else:
                                value_list[value_index] += -reward * self._cfg.discount_factor ** i
                        else:
                            value_list[value_index] += reward * self._cfg.discount_factor ** i

                    # reset the value prefix every lstm_horizon_len unroll steps
                    if horizon_id % self._cfg.lstm_horizon_len == 0:
                        value_prefix = 0.0
                        base_index = current_index
                    horizon_id += 1

                    if current_index < game_segment_len_non_re:
                        target_values.append(value_list[value_index])
                        # The horizon is small and discount_factor is close to 1, so the value prefix is approximated
                        # by the undiscounted reward sum.
                        value_prefix += reward_list[current_index]
                        target_value_prefixs.append(value_prefix)
                    else:
                        target_values.append(0)
                        target_value_prefixs.append(value_prefix)

                    value_index += 1

                batch_value_prefixs.append(target_value_prefixs)
                batch_target_values.append(target_values)

        batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object)
        batch_target_values = np.asarray(batch_target_values, dtype=object)

        return batch_value_prefixs, batch_target_values

    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
        """
        Overview:
            prepare policy targets from the reanalyzed context of policies
        Arguments:
            - policy_re_context (:obj:`List`): List of policy context to reanalyzed
        Returns:
            - batch_target_policies_re
        """
        if policy_re_context is None:
            return []
        batch_target_policies_re = []

        policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, \
        to_play_segment = policy_re_context  # noqa
        # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
        transition_batch_size = len(policy_obs_list)
        game_segment_batch_size = len(pos_in_game_segment_list)

        to_play, action_mask = self._preprocess_to_play_and_action_mask(
            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
        )
        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
            action_mask = [
                list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
            ]
            # NOTE: in continuous action space env, we set all legal_actions as -1
            legal_actions = [
                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
            ]
        else:
            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

        with torch.no_grad():
            policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
            # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
            self._cfg.mini_infer_size = self._cfg.mini_infer_size
            slices = np.ceil(transition_batch_size / self._cfg.mini_infer_size).astype(np.int_)
            network_output = []
            for i in range(slices):
                beg_index = self._cfg.mini_infer_size * i
                end_index = self._cfg.mini_infer_size * (i + 1)
                m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()

                m_output = model.initial_inference(m_obs)

                if not model.training:
                    # if not in training, obtain the scalars of the value/reward
                    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
                        [
                            m_output.latent_state,
                            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
                            m_output.policy_logits
                        ]
                    )
                    m_output.reward_hidden_state = (
                        m_output.reward_hidden_state[0].detach().cpu().numpy(),
                        m_output.reward_hidden_state[1].detach().cpu().numpy()
                    )

                network_output.append(m_output)

            _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
                network_output, data_type='efficientzero'
            )
            # (150,) (150, 6) (150, 256), 150 may be too large for batch tree search
            # print(value_prefix_pool.shape, policy_logits_pool.shape, latent_state_roots.shape)

            value_prefix_pool = value_prefix_pool.squeeze().tolist()
            policy_logits_pool = policy_logits_pool.tolist()
            # noises are not necessary for reanalyze
            noises = [
                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions
                                    ).astype(np.float32).tolist() for _ in range(transition_batch_size)
            ]
            if self._cfg.mcts_ctree:
                # ==============================================================
                # sampled related core code
                # ==============================================================
                # cpp mcts_tree
                roots = MCTSCtree.roots(
                    transition_batch_size, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )
                if self._cfg.reanalyze_noise:
                    roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
                else:
                    roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                # do MCTS for a new policy with the recent target model
                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
            else:
                # python mcts_tree
                roots = MCTSPtree.roots(
                    transition_batch_size, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )
                if self._cfg.reanalyze_noise:
                    roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
                else:
                    roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                # do MCTS for a new policy with the recent target model
                MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)

            roots_legal_actions_list = legal_actions
            roots_distributions = roots.get_distributions()
            roots_values = roots.get_values()

            # ==============================================================
            # fix reanalyze in sez
            # ==============================================================
            roots_sampled_actions = roots.get_sampled_actions()
            try:
                root_sampled_actions = np.array([action.value for action in roots_sampled_actions])
            except Exception:
                root_sampled_actions = np.array([action for action in roots_sampled_actions])
            
            policy_index = 0
            for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values):
                target_policies = []
                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    distributions = roots_distributions[policy_index]
                    searched_value = roots_values[policy_index]
                    # ==============================================================
                    # sampled related core code
                    # ==============================================================
                    if policy_mask[policy_index] == 0:
                        # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
                        target_policies.append([0 for _ in range(self._cfg.model.num_of_sampled_actions)])
                    else:
                        if distributions is None:
                            # if at some obs, the legal_action is None, then add the fake target_policy
                            target_policies.append(
                                list(
                                    np.ones(self._cfg.model.num_of_sampled_actions) /
                                    self._cfg.model.num_of_sampled_actions
                                )
                            )
                        else:
                            # Update the data in game segment:
                            # after the reanalyze search, new target policies and root values are obtained
                            # the target policies and root values are stored in the gamesegment, specifically, ``child_visit_segment`` and ``root_value_segment``
                            # we replace the data at the corresponding location with the latest search results to keep the most up-to-date targets
                            sim_num = sum(distributions)
                            child_visit[current_index] = [visit_count/sim_num for visit_count in distributions]
                            root_value[current_index] = searched_value
                            if self._cfg.action_type == 'fixed_action_space':
                                sum_visits = sum(distributions)
                                policy = [visit_count / sum_visits for visit_count in distributions]
                                target_policies.append(policy)
                            else:
                                # for two_player board games
                                policy_tmp = [0 for _ in range(self._cfg.model.num_of_sampled_actions)]
                                # to make sure target_policies have the same dimension
                                sum_visits = sum(distributions)
                                policy = [visit_count / sum_visits for visit_count in distributions]
                                for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
                                    policy_tmp[legal_action] = policy[index]
                                target_policies.append(policy_tmp)

                    policy_index += 1

                batch_target_policies_re.append(target_policies)

        batch_target_policies_re = np.array(batch_target_policies_re)

        return batch_target_policies_re, root_sampled_actions

    def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
        """
        Overview:
            Write the newly computed priorities back into ``self.game_pos_priorities`` at the sampled positions.
            Only samples whose ``make_time`` is strictly greater than ``self.clear_time`` are updated, i.e. only
            data that is still in the replay buffer; all other entries are left untouched.
        Arguments:
            - train_data (:obj:`List[np.ndarray]`): training data of the form ``[current_batch, target_batch]``.
              Only ``current_batch`` is read: element 4 is ``batch_index_list`` and element 6 is ``make_time_list``.
            - batch_priorities (:obj:`Any`): one priority per sampled position, aligned index-by-index with
              ``batch_index_list``.
        NOTE:
            train_data = [current_batch, target_batch]
            current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list]
        """
        current_batch = train_data[0]
        positions = current_batch[4]
        make_times = current_batch[6]

        for k, pos in enumerate(positions):
            # only data still in the replay buffer (made after the last clear) gets a new priority
            if not make_times[k] > self.clear_time:
                continue
            self.game_pos_priorities[pos] = batch_priorities[k]
