import time
from collections import deque, namedtuple
from typing import Optional, Any, List

import numpy as np
import torch
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \
    allreduce_data
from ding.worker.collector.base_serial_collector import ISerialCollector
from torch.nn import L1Loss

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation


@SERIAL_COLLECTOR_REGISTRY.register('episode_muzero')
class MuZeroCollector(ISerialCollector):
    """
    Overview:
        The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero.
        It manages the data collection process for training these algorithms using a serial mechanism.
    Interfaces:
        ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``,
        ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close``
    Properties:
        ``envstep``
    """

    # TO be compatible with ISerialCollector
    config = dict()

    def __init__(
            self,
            collect_print_freq: int = 100,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector',
            policy_config: 'policy_config' = None,  # noqa
    ) -> None:
        """
        Overview:
            Initialize the MuZeroCollector with the given parameters.
        Arguments:
            - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information.
            - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager.
            - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API.
            - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance.
            - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes.
            - instance_name (:obj:`str`): Unique identifier for this collector instance.
            - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy.
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = collect_print_freq
        self._timer = EasyTimer()
        self._end_flag = False

        self._rank = get_rank()
        self._world_size = get_world_size()
        if self._rank == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name),
                    name=self._instance_name,
                    need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
                )
        else:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = None

        self.policy_config = policy_config
        self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy

        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset or replace the environment managed by this collector.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
        Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided.
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset or replace the policy used by this collector.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
            self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None)
            self._logger.debug(
                'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num)
            )
        self._policy.reset() # do nothing for non-conv nets

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the collector with the given policy and/or environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
                env_manager(BaseEnvManager)
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

        # A game_segment_pool implementation based on the deque structure.
        self.game_segment_pool = deque(maxlen=int(1e6))
        self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \
            and env_info. Reset these states according to env_id. You can refer to base_serial_collector\
            to get more messages.
        Arguments:
            - env_id (:obj:`int`): the id where we need to reset the collector's state
        """
        self._env_info[env_id] = {'time': 0., 'step': 0}

    @property
    def envstep(self) -> int:
        """
        Overview:
            Get the total number of environment steps collected.
        Returns:
            - envstep (:obj:`int`): Total number of environment steps collected.
        """
        return self._total_envstep_count

    def close(self) -> None:
        """
        Overview:
            Close the collector. If end_flag is False, close the environment, flush the tb_logger \
            and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the collector. __del__ is automatically called to \
            destroy the collector instance when the collector finishes its work
        """
        self.close()

    # ==============================================================
    # MCTS+RL related core code
    # ==============================================================
    def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray:
        """
        Overview:
            Compute the priorities for transitions based on prediction and search value discrepancies.
        Arguments:
            - i (:obj:`int`): Index of the values in the list to compute the priority for.
            - pred_values_lst (:obj:`List[float]`): List of predicted values.
            - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS.
        Returns:
            - priorities (:obj:`np.ndarray`): Array of computed priorities.
        """
        if self.policy_config.use_priority:
            # Calculate priorities. The priorities are the L1 losses between the predicted
            # values and the search values. We use 'none' as the reduction parameter, which
            # means the loss is calculated for each element individually, instead of being summed or averaged.
            # A small constant (1e-6) is added to the results to avoid zero priorities. This
            # is done because zero priorities could potentially cause issues in some scenarios.
            pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1)
            search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device
                                                                                ).float().view(-1)
            # print(pred_values.shape, search_values.shape) # torch.Size([11]) torch.Size([11])

            priorities = L1Loss(reduction='none'
                                )(pred_values,
                                  search_values).detach().cpu().numpy() + 1e-6
        else:
            # priorities is None -> use the max priority for all newly collected data
            priorities = None

        return priorities

    def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment],
                                     last_game_priorities: List[np.ndarray],
                                     game_segments: List[GameSegment], done: np.ndarray) -> None:
        """
        Overview:
            Save the game segment to the pool if the current game is finished, padding it if necessary.
        Arguments:
            - i (:obj:`int`): Index of the current game segment.
            - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved.
            - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments.
            - game_segments (:obj:`List[GameSegment]`): List of the current game segments.
            - done (:obj:`np.ndarray`): Array indicating whether each game is done.
        Note:
            (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True
        """
        # pad over last segment trajectory
        beg_index = self.policy_config.model.frame_stack_num
        end_index = beg_index + self.policy_config.num_unroll_steps

        # the start <frame_stack_num> obs is init zero obs, so we take the
        # [<frame_stack_num> : <frame_stack_num>+<num_unroll_steps>] obs as the pad obs
        # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs
        pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index]
        pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps]
        # EfficientZero original repo bug:
        # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index]

        beg_index = 0
        # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps
        end_index = beg_index + self.unroll_plus_td_steps - 1

        pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index]
        if self.policy_config.use_ture_chance_label_in_chance_encoder:
            chance_lst = game_segments[i].chance_segment[beg_index:end_index]

        beg_index = 0
        end_index = beg_index + self.unroll_plus_td_steps

        pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index]

        if self.policy_config.gumbel_algo:
            pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index]

        # pad over and save
        if self.policy_config.gumbel_algo:
            last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst,
                                           next_segment_improved_policy=pad_improved_policy_prob)
        else:
            if self.policy_config.use_ture_chance_label_in_chance_encoder:
                last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst,
                                               next_chances=chance_lst)
            else:
                last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst)
        """
        Note:
            game_segment element shape:
            obs: game_segment_length + stack + num_unroll_steps, 20+4 +5
            rew: game_segment_length + stack + num_unroll_steps + td_steps -1  20 +5+3-1
            action: game_segment_length -> 20
            root_values:  game_segment_length + num_unroll_steps + td_steps -> 20 +5+3
            child_visits： game_segment_length + num_unroll_steps -> 20 +5
            to_play: game_segment_length -> 20
            action_mask: game_segment_length -> 20
        """

        last_game_segments[i].game_segment_to_array()

        # put the game segment into the pool
        self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i]))

        # reset last game_segments
        last_game_segments[i] = None
        last_game_priorities[i] = None

        return None

    def collect(self,
                n_episode: Optional[int] = None,
                train_iter: int = 0,
                policy_kwargs: Optional[dict] = None,
                collect_with_pure_policy: bool = False) -> List[Any]:
        """
        Overview:
            Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations.
        Arguments:
            - n_episode (:obj:`Optional[int]`): Number of episodes to collect.
            - train_iter (:obj:`int`): Number of training iterations completed so far.
            - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy.
            - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS.
        Returns:
            - return_data (:obj:`List[Any]`): Collected data in the form of a list.
        """
        # TODO: collect_with_pure_policy as a separate collector
        if n_episode is None:
            if self._default_n_episode is None:
                raise RuntimeError("Please specify collect n_episode")
            else:
                n_episode = self._default_n_episode
        assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num)
        if policy_kwargs is None:
            policy_kwargs = {}
        temperature = policy_kwargs['temperature']
        epsilon = policy_kwargs['epsilon']

        collected_episode = 0
        collected_step = 0
        env_nums = self._env_num

        # initializations
        init_obs = self._env.ready_obs

        retry_waiting_time = 0.001
        while len(init_obs.keys()) != self._env_num:
            # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to
            # len(self._env.ready_obs), especially in tictactoe env.
            self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
            self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
            time.sleep(retry_waiting_time)
            self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10)
            self._logger.info(
                'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states)
            )
            init_obs = self._env.ready_obs

        action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
        to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
        if self.policy_config.use_ture_chance_label_in_chance_encoder:
            chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
        if self.policy_config.get('offline', False):
            offline_action_dict = {i: to_ndarray(init_obs[i]['action']) for i in range(env_nums)}

        game_segments = [
            GameSegment(
                self._env.action_space,
                game_segment_length=self.policy_config.game_segment_length,
                config=self.policy_config
            ) for _ in range(env_nums)
        ]
        # stacked observation windows in reset stage for init game_segments
        observation_window_stack = [[] for _ in range(env_nums)]
        for env_id in range(env_nums):
            observation_window_stack[env_id] = deque(
                [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)],
                maxlen=self.policy_config.model.frame_stack_num
            )

            game_segments[env_id].reset(observation_window_stack[env_id])

        dones = np.array([False for _ in range(env_nums)])
        last_game_segments = [None for _ in range(env_nums)]
        last_game_priorities = [None for _ in range(env_nums)]
        # for priorities in self-play
        search_values_lst = [[] for _ in range(env_nums)]
        pred_values_lst = [[] for _ in range(env_nums)]
        # improved policy is only for the gumbel algorithm
        if self.policy_config.gumbel_algo:
            improved_policy_lst = [[] for _ in range(env_nums)]

        # some logs
        eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums)
        if self.policy_config.gumbel_algo:
            completed_value_lst = np.zeros(env_nums)
        self_play_moves = 0.
        self_play_episodes = 0.
        self_play_moves_max = 0
        self_play_visit_entropy = []
        total_transitions = 0

        ready_env_id = set()
        remain_episode = n_episode
        if collect_with_pure_policy:
            temp_visit_list = [0.0 for i in range(self._env.action_space.n)]

        while True:
            with self._timer:
                # Get current ready env obs.
                obs = self._env.ready_obs
                new_available_env_id = set(obs.keys()).difference(ready_env_id)
                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                remain_episode -= min(len(new_available_env_id), remain_episode)

                stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id}
                stack_obs = list(stack_obs.values())

                action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
                to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
                action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
                to_play = [to_play_dict[env_id] for env_id in ready_env_id]
                if self.policy_config.use_ture_chance_label_in_chance_encoder:
                    chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
                if self.policy_config.get('offline', False):
                    offline_action_dict = {env_id: offline_action_dict[env_id] for env_id in ready_env_id}
                
                stack_obs = to_ndarray(stack_obs)
                # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96]
                stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type)
                stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device)

                # ==============================================================
                # Key policy forward step
                # ==============================================================
                # print(f'ready_env_id:{ready_env_id}')
                if self.policy_config.get('offline', False):
                    policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, offline_action_dict=offline_action_dict, ready_env_id=ready_env_id)
                else:
                    policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id)

                # Extract relevant policy outputs
                actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
                value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
                pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}

                if self.policy_config.sampled_algo:
                    root_sampled_actions_dict_with_env_id = {
                        k: v['root_sampled_actions'] for k, v in policy_output.items()
                    }

                if not collect_with_pure_policy:
                    distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in
                                                      policy_output.items()}
                    visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in
                                                      policy_output.items()}

                    if self.policy_config.gumbel_algo:
                        improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in
                                                            policy_output.items()}
                        completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()}

                # Initialize dictionaries to store results
                actions = {}
                value_dict = {}
                pred_value_dict = {}

                if not collect_with_pure_policy:
                    distributions_dict = {}
                    visit_entropy_dict = {}

                    if self.policy_config.sampled_algo:
                        root_sampled_actions_dict = {}

                    if self.policy_config.gumbel_algo:
                        improved_policy_dict = {}
                        completed_value_dict = {}

                # Populate the result dictionaries
                for env_id in ready_env_id:
                    actions[env_id] = actions_with_env_id.pop(env_id)
                    value_dict[env_id] = value_dict_with_env_id.pop(env_id)
                    pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)

                    if not collect_with_pure_policy:
                        distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)

                        if self.policy_config.sampled_algo:
                            root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id)

                        visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id)

                        if self.policy_config.gumbel_algo:
                            improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id)
                            completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id)

                # ==============================================================
                # Interact with the environment
                # ==============================================================
                timesteps = self._env.step(actions) #???

            interaction_duration = self._timer.value / len(timesteps)

            for env_id, timestep in timesteps.items():
                with self._timer:
                    if timestep.info.get('abnormal', False):
                        # If there is an abnormal timestep, reset all the related variables(including this env).
                        # suppose there is no reset param, reset this env
                        self._env.reset({env_id: None})
                        self._policy.reset([env_id])
                        self._reset_stat(env_id)
                        self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info))
                        continue
                    obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info

                    if collect_with_pure_policy:
                        game_segments[env_id].store_search_stats(temp_visit_list, 0)
                    else:
                        if self.policy_config.sampled_algo:
                            game_segments[env_id].store_search_stats(
                                distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]
                            )
                        elif self.policy_config.gumbel_algo:
                            game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id],
                                                                     improved_policy=improved_policy_dict[env_id])
                        else:
                            game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id])

                    # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}
                    # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment``
                    if self.policy_config.use_ture_chance_label_in_chance_encoder:
                        game_segments[env_id].append(
                            actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
                            to_play_dict[env_id], chance_dict[env_id]
                        )
                    else:
                        game_segments[env_id].append(
                            actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
                            to_play_dict[env_id]
                        ) #???

                    # NOTE: the position of code snippet is very important.
                    # the obs['action_mask'] and obs['to_play'] are corresponding to the next action
                    action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
                    to_play_dict[env_id] = to_ndarray(obs['to_play'])
                    if self.policy_config.use_ture_chance_label_in_chance_encoder:
                        chance_dict[env_id] = to_ndarray(obs['chance'])
                    if self.policy_config.get('offline', False):
                        offline_action_dict[env_id] = to_ndarray(obs['action'])

                    if self.policy_config.ignore_done:
                        dones[env_id] = False
                    else:
                        dones[env_id] = done

                    if not collect_with_pure_policy:
                        visit_entropies_lst[env_id] += visit_entropy_dict[env_id]
                        if self.policy_config.gumbel_algo:
                            completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id]))

                    eps_steps_lst[env_id] += 1
                    if self._policy.get_attribute('cfg').type == 'unizero':
                        # only for UniZero now
                        self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False)

                    total_transitions += 1

                    if self.policy_config.use_priority:
                        pred_values_lst[env_id].append(pred_value_dict[env_id])
                        search_values_lst[env_id].append(value_dict[env_id])
                        if self.policy_config.gumbel_algo and not collect_with_pure_policy:
                            improved_policy_lst[env_id].append(improved_policy_dict[env_id])

                    # append the newest obs
                    observation_window_stack[env_id].append(to_ndarray(obs['observation']))

                    # ==============================================================
                    # we will save a game segment if it is the end of the game or the next game segment is finished.
                    # ==============================================================

                    # if game segment is full, we will save the last game segment
                    if game_segments[env_id].is_full():
                        # pad over last segment trajectory
                        if last_game_segments[env_id] is not None:
                            # TODO(pu): return the one game segment
                            self.pad_and_save_last_trajectory(
                                env_id, last_game_segments, last_game_priorities, game_segments, dones
                            )

                        # calculate priority
                        priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)
                        pred_values_lst[env_id] = []
                        search_values_lst[env_id] = []
                        if self.policy_config.gumbel_algo and not collect_with_pure_policy:
                            improved_policy_lst[env_id] = []

                        # the current game_segments become last_game_segment
                        last_game_segments[env_id] = game_segments[env_id]
                        last_game_priorities[env_id] = priorities

                        # create new GameSegment
                        game_segments[env_id] = GameSegment(
                            self._env.action_space,
                            game_segment_length=self.policy_config.game_segment_length,
                            config=self.policy_config
                        )
                        game_segments[env_id].reset(observation_window_stack[env_id])

                    self._env_info[env_id]['step'] += 1
                    collected_step += 1

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration
                if timestep.done:
                    reward = timestep.info['eval_episode_return']
                    info = {
                        'reward': reward,
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                    }
                    if not collect_with_pure_policy:
                        info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id]
                        if self.policy_config.gumbel_algo:
                            info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id]

                    collected_episode += 1
                    self._episode_info.append(info)

                    # ==============================================================
                    # if it is the end of the game, we will save the game segment
                    # ==============================================================

                    # NOTE: put the penultimate game segment in one episode into the trajectory_pool
                    # pad over 2th last game_segment using the last game_segment
                    if last_game_segments[env_id] is not None:
                        self.pad_and_save_last_trajectory(
                            env_id, last_game_segments, last_game_priorities, game_segments, dones
                        )

                    # store current segment trajectory
                    priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)

                    # NOTE: put the last game segment in one episode into the trajectory_pool
                    game_segments[env_id].game_segment_to_array()

                    # assert len(game_segments[env_id]) == len(priorities)
                    # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null
                    if len(game_segments[env_id].reward_segment) != 0:
                        self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id]))

                    # print(game_segments[env_id].reward_segment)
                    # reset the finished env and init game_segments
                    if n_episode > self._env_num:
                        # Get current ready env obs.
                        init_obs = self._env.ready_obs
                        retry_waiting_time = 0.001
                        while len(init_obs.keys()) != self._env_num:
                            # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to
                            # len(self._env.ready_obs), especially in tictactoe env.
                            self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
                            self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
                            time.sleep(retry_waiting_time)
                            self._logger.info(
                                '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10
                            )
                            self._logger.info(
                                'After sleeping {}s, the current _env_states is {}'.format(
                                    retry_waiting_time, self._env._env_states
                                )
                            )
                            init_obs = self._env.ready_obs

                        new_available_env_id = set(init_obs.keys()).difference(ready_env_id)
                        ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                        remain_episode -= min(len(new_available_env_id), remain_episode)

                        action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
                        to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
                        if self.policy_config.use_ture_chance_label_in_chance_encoder:
                            chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
                        if self.policy_config.get('offline', False):
                            offline_action_dict[env_id] = to_ndarray(init_obs[env_id]['action'])

                        game_segments[env_id] = GameSegment(
                            self._env.action_space,
                            game_segment_length=self.policy_config.game_segment_length,
                            config=self.policy_config
                        )
                        observation_window_stack[env_id] = deque(
                            [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)],
                            maxlen=self.policy_config.model.frame_stack_num
                        )
                        game_segments[env_id].reset(observation_window_stack[env_id])
                        last_game_segments[env_id] = None
                        last_game_priorities[env_id] = None

                    # log
                    self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id])
                    if not collect_with_pure_policy:
                        self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id])
                    self_play_moves += eps_steps_lst[env_id]
                    self_play_episodes += 1

                    pred_values_lst[env_id] = []
                    search_values_lst[env_id] = []
                    eps_steps_lst[env_id] = 0
                    visit_entropies_lst[env_id] = 0

                    # Env reset is done by env_manager automatically
                    self._policy.reset([env_id])  # NOTE: reset the policy for the env_id. Default reset_init_data=True.
                    self._reset_stat(env_id)
                    ready_env_id.remove(env_id)

            if collected_episode >= n_episode:
                # [data, meta_data]
                return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [
                    {
                        'priorities': self.game_segment_pool[i][1],
                        'done': self.game_segment_pool[i][2],
                        'unroll_plus_td_steps': self.unroll_plus_td_steps
                    } for i in range(len(self.game_segment_pool))
                ]
                self.game_segment_pool.clear()
                break

        collected_duration = sum([d['time'] for d in self._episode_info])
        # reduce data when enables DDP
        if self._world_size > 1:
            collected_step = allreduce_data(collected_step, 'sum')
            collected_episode = allreduce_data(collected_episode, 'sum')
            collected_duration = allreduce_data(collected_duration, 'sum')
        self._total_envstep_count += collected_step
        self._total_episode_count += collected_episode
        self._total_duration += collected_duration

        # log
        self._output_log(train_iter)
        return return_data

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Log the collector's data and output the log information.
        Arguments:
            - train_iter (:obj:`int`): Current training iteration number for logging context.
        """
        if self._rank != 0:
            return
        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
            self._last_train_iter = train_iter
            episode_count = len(self._episode_info)
            envstep_count = sum([d['step'] for d in self._episode_info])
            duration = sum([d['time'] for d in self._episode_info])
            episode_reward = [d['reward'] for d in self._episode_info]
            if not self.collect_with_pure_policy:
                visit_entropy = [d['visit_entropy'] for d in self._episode_info]
            else:
                visit_entropy = [0.0]
            if self.policy_config.gumbel_algo:
                completed_value = [d['completed_value'] for d in self._episode_info]
            self._total_duration += duration
            info = {
                # 'episode_count': episode_count,
                # 'envstep_count': envstep_count,
                # 'avg_envstep_per_episode': envstep_count / episode_count,
                # 'avg_envstep_per_sec': envstep_count / duration,
                # 'avg_episode_per_sec': episode_count / duration,
                # 'collect_time': duration,
                'reward_mean': np.mean(episode_reward),
                # 'reward_std': np.std(episode_reward),
                # 'reward_max': np.max(episode_reward),
                # 'reward_min': np.min(episode_reward),
                # 'total_envstep_count': self._total_envstep_count,
                # 'total_episode_count': self._total_episode_count,
                # 'total_duration': self._total_duration,
                # 'visit_entropy': np.mean(visit_entropy),
            }
            if self.policy_config.gumbel_algo:
                info['completed_value'] = np.mean(completed_value)
            self._episode_info.clear()
            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
            for k, v in info.items():
                if k in ['each_reward']:
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                if k in ['total_envstep_count']:
                    continue
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)