from re import T
import time
import os
import SMOS
import setproctitle
import torch
import copy
import numpy as np
import core.ctree.cytree as cytree

from torch.cuda.amp import autocast as autocast
from core.game import GameHistory
from core.mcts import MCTS
from core.model import concat_output, concat_output_value
from core.utils import prepare_observation_lst, LinearSchedule, get_gpu_memory, str_to_arr

from core.storage_config import StorageConfig
from core.replay_buffer import get_replay_buffer
from core.shared_storage import get_shared_storage, read_weights
from core.watchdog import get_watchdog_server
from core.meta_data_manager import MetaDataSharedMemoryManager
from core.utils import profile


class BatchWorker_CPU(object):
    """CPU side of target reanalysis.

    Samples transitions from the replay buffer, restores their game histories
    from shared memory, and prepares the observation / reward / policy contexts
    that a GPU batch worker needs to recompute value and policy targets. The
    packed contexts are pushed to an SMOS mcts storage.
    """

    def __init__(self, worker_id, replay_buffer, storage, smos_client, config, storage_config: StorageConfig):
        """CPU Batch Worker for reanalyzing targets, see Appendix.
        Prepare the context concerning CPU overhead
        Parameters
        ----------
        worker_id: int
            id of the worker; also selects (modulo mcts_storage_count) the mcts storage it pushes to
        replay_buffer: Any
            Replay buffer
        storage: Any
            The model storage
        smos_client: Any
            Used as queue between cpu worker and gpu worker
        config: Any
            Game / training configuration
        storage_config: StorageConfig
            Names and counts of the shared-memory storages
        """
        self.worker_id = worker_id
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.smos_client = smos_client
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
        # shared target-value array, indexed with replay-buffer transition indices in make_batch
        self.global_target_values = self.meta_data_manager.get("target_values")
        self.mcts_storage_name = storage_config.mcts_storage_name + f"{worker_id % storage_config.mcts_storage_count}"

        self.config = config
        self.storage_config = storage_config

        self.last_model_index = -1
        self.batch_max_num = 20
        # prioritized-replay beta schedule: from priority_prob_beta to 1.0 over training_steps + last_steps
        self.beta_schedule = LinearSchedule(config.training_steps + config.last_steps,
                                            initial_p=config.priority_prob_beta, final_p=1.0)

        self.trained_steps = 0

    @staticmethod
    def _off_policy_td_steps(td_steps, auto_td_steps, total_transitions, idx, max_td_steps=5):
        """Return the TD horizon to use for a (possibly stale) transition.

        Off-policy correction: the horizon shrinks by one step for every
        ``auto_td_steps`` transitions that ``idx`` lags behind
        ``total_transitions``, and is clipped to ``[1, max_td_steps]``.

        Returns
        -------
        numpy integer scalar
        """
        delta_td = (total_transitions - idx) // auto_td_steps
        # np.int was removed in NumPy 1.24; np.int_ is the supported integer type
        return np.clip(td_steps - delta_td, 1, max_td_steps).astype(np.int_)

    @staticmethod
    def _pad_unroll_actions(actions, unroll_len, action_space_size):
        """Pad ``actions`` up to ``unroll_len`` with uniformly random actions.

        Returns
        -------
        (padded_actions, mask): mask holds 1. for real actions and 0. for the
        random padding past the end of the trajectory.
        """
        missing = unroll_len - len(actions)
        mask = [1.] * len(actions) + [0.] * missing
        padded = list(actions) + [np.random.randint(0, action_space_size) for _ in range(missing)]
        return padded, mask

    def _prepare_reward_value_context(self, indices, games, state_index_lst, total_transitions):
        """prepare the context of rewards and values for reanalyzing part
        Parameters
        ----------
        indices: list
            transition index in replay buffer
        games: list
            list of game histories
        state_index_lst: list
            transition index in game
        total_transitions: int
            number of collected transitions
        Returns
        -------
        list: [value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens,
               td_steps_lst, [virtual_mask, virtual_bootstrap_values]]
            value_obs_lst / value_mask / td_steps_lst hold num_unroll_steps + 2
            entries per sampled transition.
        """
        zero_obs = games[0].zero_obs()
        config = self.config
        assert (len(zero_obs) == config.stacked_observations), (zero_obs.shape)
        value_obs_lst = []
        # the value is valid or not (out of trajectory)
        value_mask = []
        rewards_lst = []
        traj_lens = []
        virtual_bootstrap_values = []
        virtual_mask = []

        td_steps_lst = []
        for game, state_index, idx in zip(games, state_index_lst, indices):
            traj_len = len(game.actions)  # this is a fake length, 200 + num_unroll_steps + td_steps
            traj_lens.append(traj_len)

            # off-policy correction: shorter horizon of td steps
            td_steps = self._off_policy_td_steps(config.td_steps, config.auto_td_steps, total_transitions, idx)

            # prepare the corresponding observations for bootstrapped values o_{t+k}
            game_obs = game.obs(state_index + td_steps, config.num_unroll_steps + 1)
            rewards_lst.append(game.rewards)
            # one extra position past the unroll window (num_unroll_steps + 2 in total)
            for current_index in range(state_index, state_index + config.num_unroll_steps + 1 + 1):
                td_steps_lst.append(td_steps)
                bootstrap_index = current_index + td_steps

                if bootstrap_index < traj_len:
                    value_mask.append(1)
                    beg_index = bootstrap_index - (state_index + td_steps)
                    end_index = beg_index + config.stacked_observations
                    obs = game_obs[beg_index:end_index]
                    assert (len(obs) == (end_index - beg_index)), (len(game.obs_history), len(game.rewards), len(game.actions), len(game.root_values), len(game.child_visits), len(game))
                else:
                    value_mask.append(0)
                    obs = zero_obs
                    assert (len(obs) == config.stacked_observations), (obs.shape)

                # virtual value (placeholders, always zero here)
                virtual_mask.append(0)
                virtual_bootstrap_values.append(0)

                value_obs_lst.append(obs)

        value_obs_lst = prepare_observation_lst(value_obs_lst)

        reward_value_context = [value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens, td_steps_lst, [virtual_mask, virtual_bootstrap_values]]
        return reward_value_context

    def _prepare_policy_non_re_context(self, indices, games, state_index_lst):
        """prepare the context of policies for non-reanalyzing part, just return the policy in self-play
        Parameters
        ----------
        indices: list
            transition index in replay buffer (unused here)
        games: list
            list of game histories
        state_index_lst: list
            transition index in game
        Returns
        -------
        list: [state_index_lst, child_visits, traj_lens]
        """
        child_visits = [game.child_visits for game in games]
        # this is a fake length, 200 + num_unroll_steps + td_steps
        traj_lens = [len(game.actions) for game in games]
        return [state_index_lst, child_visits, traj_lens]

    def _prepare_policy_re_context(self, indices, games, state_index_lst):
        """prepare the context of policies for reanalyzing part
        Parameters
        ----------
        indices: list
            transition index in replay buffer
        games: list
            list of game histories
        state_index_lst: list
            transition index in game
        Returns
        -------
        list: [policy_obs_lst, policy_mask, state_index_lst, indices, actions,
               rewards, child_visits, traj_lens]; policy_obs_lst / policy_mask hold
               num_unroll_steps + 1 entries per transition. ``actions`` is left empty.
        """
        zero_obs = games[0].zero_obs()
        config = self.config

        with torch.no_grad():
            # for policy
            policy_obs_lst = []
            policy_mask = []  # 0 -> out of traj, 1 -> new policy
            actions, rewards, child_visits, traj_lens = [], [], [], []
            for game, state_index in zip(games, state_index_lst):
                traj_len = len(game.actions)  # this is a fake length, 200 + num_unroll_steps + td_steps
                traj_lens.append(traj_len)
                rewards.append(game.rewards)
                child_visits.append(game.child_visits)
                # prepare the corresponding observations
                game_obs = game.obs(state_index, config.num_unroll_steps)
                for current_index in range(state_index, state_index + config.num_unroll_steps + 1):

                    if current_index < traj_len:
                        policy_mask.append(1)
                        beg_index = current_index - state_index
                        end_index = beg_index + config.stacked_observations
                        obs = game_obs[beg_index:end_index]
                    else:
                        policy_mask.append(0)
                        obs = zero_obs
                    policy_obs_lst.append(obs)

        policy_obs_lst = prepare_observation_lst(policy_obs_lst)
        policy_re_context = [policy_obs_lst, policy_mask, state_index_lst, indices, actions, rewards, child_visits, traj_lens]
        return policy_re_context

    def make_batch(self, batch_context, ratio, weights=None, version=0, mcts_idx=-1):
        """prepare the context of a batch and push it to the mcts storage
        reward_value_context:        the context of reanalyzed value targets
        policy_re_context:           the context of reanalyzed policy targets
        inputs_batch:                the inputs of batch
        Parameters
        ----------
        batch_context: Any
            batch context from replay buffer
        ratio: float
            ratio of reanalyzed policy; currently ignored (every sampled transition is reanalyzed)
        weights: Any
            the target model weights (unused here)
        version: int
            target model version forwarded to the GPU worker
        mcts_idx: int
            slot index popped from the reanalyze control storage; must not be -1
        Returns
        -------
        (wait_time, restore_time, reward_value_time, policy_time), or None when the
        sampled entries were already deleted from the replay buffer. In that case the
        batch is discarded and ``mcts_idx`` is pushed back to the control storage.
        """
        assert mcts_idx != -1
        config = self.config
        unroll_len = config.num_unroll_steps + 1
        time0 = time.time()
        # obtain the batch context from replay buffer
        game_metadata_lst, game_pos_lst, indices_lst, weights_lst = batch_context
        game_lst = [GameHistory.from_metadata(max_length=config.history_length, config=config, metadata=metadata) for metadata in game_metadata_lst]
        make_time_lst = [time.time() for _ in range(config.reanalyze_batch_size)]

        indices_lst = np.asarray(indices_lst)

        # restore data for each game in game_lst
        idx_list = [game.entry_idx for game in game_lst]
        status, handle_batch, reconstructed_batch = self.smos_client.batch_read_from_object(name=self.storage_config.replay_buffer_name,
                                                                                            entry_idx_batch=idx_list)

        # check whether replay is deleted
        if not status == SMOS.SMOS_SUCCESS:
            print(f"[CPU Worker] Error: Data deleted from replay buffer, batch discarded.")
            # hand the mcts slot back so it can be reused
            status, _ = self.smos_client.push_to_object(name=self.storage_config.reanalyze_control_storage_name,
                                                    data=[mcts_idx])
            assert status == SMOS.SMOS_SUCCESS
            return None

        for game, reconstructed_object in zip(game_lst, reconstructed_batch):
            game.restore_data(reconstructed_object=reconstructed_object)

        restore_time = time.time() - time0

        batch_size = len(indices_lst)

        # target_values at current time
        target_values_old = np.zeros((batch_size, unroll_len))
        for i in range(unroll_len):
            target_values_old[:, i] = self.global_target_values[indices_lst + i]

        # prepare the inputs of a batch
        obs_lst, action_lst, mask_lst = [], [], []
        for game, game_pos in zip(game_lst, game_pos_lst):
            _actions = game.actions[game_pos:game_pos + unroll_len].tolist()
            # pad past the trajectory end with random actions, masked out
            _actions, _mask = self._pad_unroll_actions(_actions, unroll_len, game.action_space_size)
            assert len(_actions) == unroll_len

            # obtain the input observations
            obs_lst.append(game.obs(game_pos, extra_len=config.num_unroll_steps, padding=True))
            action_lst.append(_actions)
            mask_lst.append(_mask)

        re_num = batch_size  # int(batch_size * ratio): all transitions are reanalyzed
        # formalize the input observations
        obs_lst = prepare_observation_lst(obs_lst)

        # formalize the inputs of a batch
        inputs_batch = [np.asarray(item) for item in (obs_lst, action_lst, mask_lst, indices_lst, weights_lst, make_time_lst)]

        total_transitions = self.replay_buffer.get_total_len()

        # obtain the context of value targets
        time0 = time.time()
        reward_value_context = self._prepare_reward_value_context(indices_lst, game_lst, game_pos_lst, total_transitions)
        reward_value_time = time.time() - time0

        # 0:re_num -> reanalyzed policy, re_num:end -> non reanalyzed policy
        time0 = time.time()
        if re_num > 0:
            # obtain the context of reanalyzed policy targets
            policy_re_context = self._prepare_policy_re_context(indices_lst[:re_num], game_lst[:re_num], game_pos_lst[:re_num])
        else:
            policy_re_context = None
        policy_time = time.time() - time0

        # large numpy arrays (observations) are kept as separate items so SMOS can transfer them directly
        payload = [
            reward_value_context[0],                         # value observations
            reward_value_context[1:] + [target_values_old],  # rest of the value context
            policy_re_context[0],                            # policy observations
            policy_re_context[1:],                           # rest of the policy context
            -1,                                              # no non-reanalyzed policy context
            inputs_batch[0],                                 # input observations
            inputs_batch[1:],                                # rest of the inputs
            mcts_idx,
            version,
        ]

        # push to mcts storage, retrying until there is room
        wait_time = 0
        while True:
            status, _ = self.smos_client.push_to_object(name=self.mcts_storage_name, data=payload)
            if status == SMOS.SMOS_SUCCESS:
                break
            wait_time += 0.1
            time.sleep(0.1)

        # clean up batch from replay buffer
        self.smos_client.batch_release_entry(object_handle_batch=handle_batch)

        return wait_time, restore_time, reward_value_time, policy_time

    def run(self):
        """Worker main loop.

        Waits for the start signal, then repeatedly: pops an mcts slot index from
        the reanalyze control storage, samples a batch context, and builds/pushes
        the reanalysis context via ``make_batch``. Exits (after a 30s sleep) once
        ``trained_steps`` reaches ``training_steps + last_steps``. Prints averaged
        timings every 20 batches.
        """
        self.storage.add_ready_cpu_worker(self.worker_id)
        # start making mcts contexts to feed the GPU batch maker
        start = False
        total_time = 0
        batch_count = 0
        total_wait_time = 0
        total_prepare_batch_time = 0
        total_restore_time = 0
        total_reward_value_time = 0
        total_policy_time = 0
        while True:
            # wait for starting
            if not start:
                start = self.storage.get_start_signal()
                time.sleep(1)
                continue

            self.trained_steps = trained_steps = self.meta_data_manager.get("trained_steps")[0]

            # get mcts idx first to avoid getting blocked when writing data into shm
            while True:
                status, handle, mcts_idx = self.smos_client.pop_from_object(name=self.storage_config.reanalyze_control_storage_name)
                if status == SMOS.SMOS_SUCCESS:
                    mcts_idx = copy.deepcopy(mcts_idx)
                    self.smos_client.free_handle(object_handle=handle)
                    break
                time.sleep(0.1)

            start_time = time.time()

            beta = self.beta_schedule.value(trained_steps)
            # obtain the batch context from replay buffer
            time0 = time.time()
            batch_context = self.replay_buffer.prepare_batch_context(self.config.reanalyze_batch_size, beta)
            prepare_batch_time = time.time() - time0
            # break
            if trained_steps >= self.config.training_steps + self.config.last_steps:
                time.sleep(30)
                break

            target_weights, target_weights_version = None, 0

            # make batch
            timings = self.make_batch(batch_context, self.config.revisit_policy_search_rate, weights=target_weights, version=target_weights_version, mcts_idx=mcts_idx)
            if timings is None:
                # batch discarded (entries deleted); make_batch already returned the mcts slot
                continue
            wait_time, restore_time, reward_value_time, policy_time = timings
            end_time = time.time()
            total_time += end_time - start_time
            total_wait_time += wait_time
            total_prepare_batch_time += prepare_batch_time
            total_restore_time += restore_time
            total_reward_value_time += reward_value_time
            total_policy_time += policy_time
            batch_count += 1
            if batch_count % 20 == 0:
                _msg = '[CPU worker] Avg. CPU={:.3f}, Lst. CPU={:.3f}. Avg. Wait={:.3f}. Avg. Prepare={:.3f}. Avg. Restore={:.3f}. Avg. Val.={:.3f}. Avg. Policy={:.3f}'\
                    .format(total_time / batch_count, end_time - start_time, total_wait_time / batch_count, total_prepare_batch_time / batch_count, total_restore_time / batch_count, total_reward_value_time / batch_count, total_policy_time / batch_count)
                print(_msg)


def start_batch_worker_cpu(worker_id, config, storage_config: StorageConfig):
    """Entry point of a CPU batch-worker process; meant to be called remotely.

    Connects to the replay buffer, shared storage and SMOS server described by
    ``storage_config``, renames the process, builds a ``BatchWorker_CPU`` and
    runs its loop until that loop exits.
    """
    buffer = get_replay_buffer(storage_config=storage_config)
    shared = get_shared_storage(storage_config=storage_config)
    client = SMOS.Client(connection=storage_config.smos_connection)

    # make the process easy to identify in ps / top
    setproctitle.setproctitle(f"EfficientZero-cpu_worker{worker_id}")

    worker = BatchWorker_CPU(
        worker_id=worker_id,
        replay_buffer=buffer,
        storage=shared,
        smos_client=client,
        config=config,
        storage_config=storage_config,
    )
    print(f"[Batch Worker CPU] Starting batch worker CPU {worker_id} at process {os.getpid()}.")
    worker.run()


class BatchWorker_GPU(object):
    def __init__(self, worker_id, replay_buffer, storage, smos_client, watchdog_server,
                 config, storage_config: StorageConfig):
        """GPU Batch Worker for reanalyzing targets, see Appendix.
        receive the context from CPU maker and deal with GPU overheads
        Parameters
        ----------
        worker_id: int
            id of the worker
        replay_buffer: Any
            Replay buffer
        storage: Any
            The model storage
        batch_storage: Any
            The batch storage (batch queue)
        mcts_storage: Ant
            The mcts-related contexts storage
        """
        self.replay_buffer = replay_buffer
        self.config = config
        self.storage_config = storage_config
        self.worker_id = worker_id

        self.model = config.get_uniform_network(is_reanalyze_worker=True)
        self.version = 0
        # self.model.to(config.device)
        self.model.eval()

        self.storage = storage
        self.smos_client = smos_client
        self.watchdog_server = watchdog_server
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
        self.global_priorities = self.meta_data_manager.get("priorities")
        self.mcts_storage_name = storage_config.mcts_storage_name + f"{worker_id % storage_config.mcts_storage_count}"
        self.batch_storage_name = storage_config.batch_storage_name + f"{worker_id % storage_config.batch_storage_count_worker}"
        self.priority_storage_name = storage_config.priority_storage_name + f"{worker_id % storage_config.priority_storage_count}"

        self.last_model_index = 0
        self.empty_loop_count = 0
        self.trained_steps = 0

        self.config.set_transforms()
    
    def set_device(self, device):
        self.config.device = device
        self.device = device
        self.model.to(device)
        self.model.set_device(device)

    def _prepare_reward_value(self, reward_value_context, indices_lst):
        """prepare reward and value targets from the context of rewards and values
        """
        value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens, td_steps_lst, [virtual_mask, virtual_bootstrap_values], target_values_old = reward_value_context
        device = self.device
        batch_size = len(value_obs_lst)

        batch_values, batch_values_mask, batch_value_prefixs = [],[], []
        bootstrap_values = []
        with torch.no_grad():
            # value_obs_lst = prepare_observation_lst(value_obs_lst)
            # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
            m_batch = batch_size
            slices = np.ceil(batch_size / m_batch).astype(np.int_)
            all_value_lsts = []
            for _ in range(1):
                network_output = []
                for i in range(slices):
                    beg_index = m_batch * i
                    end_index = m_batch * (i + 1)
                    m_obs = torch.from_numpy(value_obs_lst[beg_index:end_index]).to(device).float() / 255.0
                    if self.config.amp_type == 'torch_amp':
                        with autocast():
                            m_output = self.model.initial_inference(m_obs)#self.config.transform(m_obs))
                    else:
                        m_output = self.model.initial_inference(m_obs)#self.config.transform(m_obs))
                    network_output.append(m_output)
                # use the predicted values
                value_lst = concat_output_value(network_output)
                all_value_lsts.append(value_lst.reshape(-1))
            value_lst = np.min(all_value_lsts, axis=0)

            # get last state value
            bootstrap_values = (value_lst.reshape(-1) * np.array(value_mask)).reshape(self.config.reanalyze_batch_size, self.config.num_unroll_steps + 1 + 1)
            value_lst = value_lst.reshape(-1)
            value_lst = value_lst * np.array(value_mask)
            value_lst = value_lst * (np.array([self.config.discount for _ in range(batch_size)]) ** td_steps_lst)

            value_index = 0
            for batch_idx, traj_len_non_re, reward_lst, state_index in zip(range(self.config.reanalyze_batch_size), traj_lens, rewards_lst, state_index_lst):
                # traj_len = len(game)
                target_values = []
                target_values_mask = []
                target_value_prefixs = []

                horizon_id = 0
                value_prefix = 0.0
                base_index = state_index
                for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                    td_steps = td_steps_lst[value_index]
                    value_lst[value_index] = bootstrap_values[batch_idx, current_index - state_index] * (self.config.discount ** td_steps)
                    bootstrap_index = current_index + td_steps
                    for i, reward in enumerate(reward_lst[current_index:bootstrap_index]):
                        value_lst[value_index] += reward * (self.config.discount ** i)

                    # reset every lstm_horizon_len
                    if horizon_id % self.config.lstm_horizon_len == 0:
                        value_prefix = 0.0
                        base_index = current_index
                    horizon_id += 1

                    if current_index < traj_len_non_re:
                        target_values.append(value_lst[value_index])
                        target_values_mask.append(1)
                        # Since the horizon is small and the discount is close to 1.
                        # Compute the reward sum to approximate the value prefix for simplification
                        value_prefix += reward_lst[current_index]  # * config.discount ** (current_index - base_index)
                        target_value_prefixs.append(value_prefix)
                    else:
                        target_values.append(0)
                        target_values_mask.append(0)
                        target_value_prefixs.append(value_prefix)
                    value_index += 1

                batch_value_prefixs.append(target_value_prefixs)
                batch_values.append(target_values)
                batch_values_mask.append(target_values_mask)

        batch_value_prefixs = np.asarray(batch_value_prefixs)
        batch_values = np.asarray(batch_values)
        batch_values_mask = np.asarray(batch_values_mask)

        return batch_value_prefixs, batch_values, bootstrap_values

    # @profile
    def _prepare_policy_re(self, policy_re_context):
        """prepare policy targets from the reanalyzed context of policies
        """
        batch_policies_re = []
        if policy_re_context is None:
            return batch_policies_re

        policy_obs_lst, policy_mask, state_index_lst, indices, actions, rewards, child_visits, traj_lens = policy_re_context
        batch_size = len(policy_obs_lst)
        device = self.device

        with torch.no_grad():
            time0 = time.time()
            # policy_obs_lst = prepare_observation_lst(policy_obs_lst)
            # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
            network_output = []
            m_obs = torch.from_numpy(policy_obs_lst).to(device).float() / 255.0
            if self.config.amp_type == 'torch_amp':
                with autocast():
                    m_output = self.model.initial_inference(m_obs, to_numpy=False)
            else:
                m_output = self.model.initial_inference(m_obs, to_numpy=False)
            network_output.append(m_output)
            time1 = time.time()

            root_values, value_prefix_pool, policy_logits_pool, hidden_state_roots, reward_hidden_roots = concat_output(network_output, state_gpu=True, reward_hidden_gpu=True)
            root_values = root_values.reshape(-1, self.config.num_unroll_steps + 1)
            value_prefix_pool = value_prefix_pool.squeeze().tolist()
            policy_logits_pool = policy_logits_pool - policy_logits_pool.max(axis=-1).reshape(-1, 1)
            policy_logits_pool = np.exp(policy_logits_pool)
            policy_logits_pool = policy_logits_pool / policy_logits_pool.sum(axis=-1).reshape(-1, 1)
            target_pred_policies = policy_logits_pool.reshape(self.config.reanalyze_batch_size, self.config.num_unroll_steps + 1, -1)
            policy_logits_pool = policy_logits_pool.tolist()

            roots = cytree.Roots(batch_size, self.config.action_space_size, self.config.reanalyze_num_simulations)
            noises = np.random.dirichlet(self.config.root_dirichlet_alpha * np.ones((self.config.action_space_size,)), (batch_size,)).tolist()
            roots.prepare(self.config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool)
            # do MCTS for a new policy with the recent target model
            cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time = MCTS(self.config, self.config.reanalyze_num_simulations).gpu_search(roots, self.model, hidden_state_roots, reward_hidden_roots, to_numpy=False)
            model_inference_time += time1 - time0

            roots_distributions = roots.get_distributions()

        policy_index = 0
        for state_index, game_idx in zip(state_index_lst, indices):
            target_policies = []

            for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                distributions = roots_distributions[policy_index]

                if policy_mask[policy_index] == 0:
                    target_policies.append([0 for _ in range(self.config.action_space_size)])
                else:
                    # game.store_search_stats(distributions, value, current_index)
                    sum_visits = sum(distributions)
                    policy = [visit_count / sum_visits for visit_count in distributions]
                    target_policies.append(policy)

                policy_index += 1

            batch_policies_re.append(target_policies)

        batch_policies_re = np.asarray(batch_policies_re)

        return root_values, target_pred_policies, batch_policies_re, cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time

    def _prepare_policy_non_re(self, policy_non_re_context):
        """prepare policy targets from the non-reanalyzed context of policies
        """
        batch_policies_non_re = []
        if policy_non_re_context is None:
            return batch_policies_non_re

        state_index_lst, child_visits, traj_lens = policy_non_re_context
        with torch.no_grad():
            # for policy
            policy_mask = []  # 0 -> out of traj, 1 -> old policy
            # for game, state_index in zip(games, state_index_lst):
            for traj_len, child_visit, state_index in zip(traj_lens, child_visits, state_index_lst):
                # traj_len = len(game)
                target_policies = []

                for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
                    if current_index < traj_len:
                        target_policies.append(child_visit[current_index])
                        policy_mask.append(1)
                    else:
                        target_policies.append([0 for _ in range(self.config.action_space_size)])
                        policy_mask.append(0)

                batch_policies_non_re.append(target_policies)
        batch_policies_non_re = np.asarray(batch_policies_non_re)
        return batch_policies_non_re

    # @profile
    def _prepare_target_gpu(self):
        # get batch from cpu
        # input_context = self.mcts_storage.pop()
        # print(f"[GPU Worker {self.worker_id}] Try get batch from cpu worker")
        status, handle, data = self.smos_client.pop_from_object(name=self.mcts_storage_name)
        # print(f"[GPU Worker {self.worker_id}] Finish get batch from cpu worker")

        wait_time = 0
        # check status
        if not status == SMOS.SMOS_SUCCESS:
            time0 = time.time()
            time.sleep(0.05)
            self.empty_loop_count += 1
            if self.empty_loop_count % 100 == 0:
                print(f"[GPU Worker {self.worker_id}] Warning: Waiting for CPU for 100 cycles!!")
            '''if self.empty_loop_count >= 100:
                raise RuntimeError(f"[GPU Worker {self.worker_id}] Waiting for CPU for over 100 cycles!!!")'''
            return False, time.time() - time0, 0., 0., 0., 0., 0. ,0., 0., 0., 0., 0.
        else:
            mcts_idx = copy.deepcopy(data[7])

            input_context = [[data[0]] + data[1], [data[2]] + data[3], data[4], [data[5]] + data[6], data[7], data[8]]

            # un-package input context
            reward_value_context, policy_re_context, policy_non_re_context, inputs_batch, mcts_idx, target_version = input_context
            obs_lst, action_lst, mask_lst, indices_lst, weights_lst, make_time_lst = inputs_batch
            value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens, td_steps_lst, [virtual_mask, virtual_target_values], target_values_old = reward_value_context
            
            time0 = time.time()
            self.trained_steps = self.meta_data_manager.get("trained_steps")[0]
            self.maybe_update_model()
            time1 = time.time()
            wait_time += time1 - time0 
            model_update_time = time1 - time0

            time0 = time.time()
            # target policy
            root_values, target_pred_policies, batch_policies_re, cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time = self._prepare_policy_re(policy_re_context)
            batch_policies = batch_policies_re
            assert root_values.shape[0] == self.config.reanalyze_batch_size
            reanalyze_policy_version = self.version
            time1 = time.time()
            re_policy_time = time1 - time0

            # target reward, value
            time0 = time.time()
            batch_value_prefixes, batch_values, bootstrap_values = self._prepare_reward_value(reward_value_context, indices_lst)
            reanalyze_value_version = self.version
            time1 = time.time()
            re_value_time = time1 - time0

            self.watchdog_server.increase_reanalyze_batch_count()

            time0 = time.time()
            if self.config.reanalyze_update_priority:
                # push to priority storage
                while True:
                    status, _ = self.smos_client.push_to_object(name=self.priority_storage_name,
                                                                data=[[self.version, indices_lst, make_time_lst, batch_policies, root_values, 
                                                                        batch_values.reshape(self.config.reanalyze_batch_size, self.config.num_unroll_steps + 1), 
                                                                        bootstrap_values, 
                                                                        np.array(value_mask).reshape(self.config.reanalyze_batch_size, 1 + self.config.num_unroll_steps), 
                                                                        np.array(td_steps_lst).reshape(self.config.reanalyze_batch_size, 1 + self.config.num_unroll_steps)]])
                    if not status == SMOS.SMOS_SUCCESS:
                        time.sleep(0.1)
                    else:
                        break
            time1 = time.time()
            buffer_update_time = time1 - time0

            # priority correction
            total_transitions = self.replay_buffer.get_total_len()
            current_probs = self.global_priorities[:total_transitions] ** self.config.priority_prob_alpha
            current_probs /= current_probs.sum()
            current_weights_lst = current_probs[indices_lst]
            weights_lst = current_weights_lst / np.asarray(weights_lst)
            fix_pri_clip_ratio = (abs(weights_lst - 1) > self.config.fix_pri_eps).astype(np.int32).sum() / self.config.reanalyze_batch_size
            weights_lst = np.clip(weights_lst, 1. - self.config.fix_pri_eps, 1. + self.config.fix_pri_eps)
            inputs_batch[4] = weights_lst

            # package into batch
            # a batch contains the inputs and the targets; inputs is prepared in CPU workers
            # input_batch[0] is obs stored as numpy, single out for zero copy transmission
            targets_batch = [batch_value_prefixes, batch_values, batch_policies, target_pred_policies, None]
            # targets_batch = [batch_value_prefixes, batch_values, transformed_target_value_prefix, transformed_target_value, batch_policies]
            batch = [inputs_batch[0], inputs_batch[1:], targets_batch, np.array([reanalyze_policy_version, reanalyze_value_version])]

            batch = [batch[0], batch[1], batch[2], np.array([batch[3]]), np.array([self.config.worker_node_id]), mcts_idx]

            # push into batch storage
            fail_count = 0
            time0 = time.time()
            while True:
                status, _ = self.smos_client.push_to_object(name=self.batch_storage_name,
                                                            data=batch)
                if not status == SMOS.SMOS_SUCCESS:
                    time.sleep(0.05)
                    fail_count += 1
                    print(f"[GPU Worker {self.worker_id}] Fail {fail_count} cycles")
                    if fail_count >= 5:
                        print(f"[GPU Worker {self.worker_id}] Fail to push batch for 5 cycles !!!")
                        push_batch_time = time.time() - time0
                        self.release_control(mcts_idx)
                        return False, wait_time + push_batch_time, buffer_update_time, re_policy_time, re_value_time, model_update_time, cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time, push_batch_time, time.time() - np.mean(make_time_lst), fix_pri_clip_ratio
                else:
                    push_batch_time = time.time() - time0
                    wait_time += push_batch_time
                    break
            # print(f"[GPU Worker {self.worker_id}] Finish push reanalyzed batch")
            # cleanup
            self.release_control(mcts_idx)
            self.smos_client.free_handle(object_handle=handle)
            # print(f"[GPU Worker {self.worker_id}] Done iteration", self.version)
            return True, wait_time, buffer_update_time, re_policy_time, re_value_time, model_update_time, cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time, push_batch_time, time.time() - np.mean(make_time_lst), fix_pri_clip_ratio

    def run(self):
        """Main loop of the GPU batch worker.

        Registers the worker as ready, waits for the global start signal, then
        repeatedly calls ``_prepare_target_gpu`` until ``trained_steps`` reaches
        ``training_steps + last_steps``. Every 5 successful batches it prints
        averaged timing statistics and the current entry counts of the MCTS and
        batch storages.
        """
        self.storage.add_ready_gpu_worker(self.worker_id)
        start = False
        total_time = 0
        batch_count = 0
        total_wait_time = 0
        total_buffer_update_time = 0
        total_re_policy_time = 0
        total_re_value_time = 0
        total_model_update_time = 0
        total_push_batch_time = 0
        total_cpu_to_gpu_time = 0
        total_gpu_to_cpu_time = 0
        total_model_inference_time = 0
        total_make_time_gap = 0
        total_fix_pri_clip_ratio = 0
        while True:
            # waiting for start signal
            if not start:
                start = self.storage.get_start_signal()
                time.sleep(5)
                continue

            self.trained_steps = trained_steps = self.meta_data_manager.get("trained_steps")[0]
            if trained_steps >= self.config.training_steps + self.config.last_steps:
                time.sleep(30)
                break

            start_time = time.time()
            (status, wait_time, buffer_update_time, re_policy_time, re_value_time,
             model_update_time, cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time,
             push_batch_time, make_time_gap, fix_pri_clip_ratio) = self._prepare_target_gpu()
            end_time = time.time()
            # Wait / model-update / push time are spent even when the batch push fails,
            # so they are accumulated unconditionally; the averages below divide by the
            # number of *successful* batches only.
            total_wait_time += wait_time
            total_model_update_time += model_update_time
            total_push_batch_time += push_batch_time
            if status:
                total_time += end_time - start_time
                total_buffer_update_time += buffer_update_time
                total_re_policy_time += re_policy_time
                # accumulate (previously overwritten, so the printed average was last/count)
                total_re_value_time += re_value_time
                total_cpu_to_gpu_time += cpu_to_gpu_time
                # accumulate the GPU->CPU time (previously added cpu_to_gpu_time by mistake)
                total_gpu_to_cpu_time += gpu_to_cpu_time
                total_model_inference_time += model_inference_time
                total_make_time_gap += make_time_gap
                total_fix_pri_clip_ratio += fix_pri_clip_ratio
                batch_count += 1
                if batch_count % 5 == 0:
                    _, mcts_storage_size = self.smos_client.get_entry_count(name=self.mcts_storage_name)
                    _, batch_storage_size = self.smos_client.get_entry_count(name=self.batch_storage_name)
                    print('[GPU worker {}] Avg. GPU={:.2f}, Lst. GPU={:.2f}, Avg. Wait={:.2f}, Avg. Buf. Update={:.2f}, Avg. Model Upd.={:.2f}, Avg. Push Batch={:.2f}, Avg. Re-Policy={:.2f}, Avg.Re-Value={:.2f}'
                          .format(self.worker_id, total_time / batch_count, end_time - start_time, total_wait_time / batch_count, total_buffer_update_time / batch_count,
                          total_model_update_time / batch_count, total_push_batch_time / batch_count, total_re_policy_time / batch_count, total_re_value_time / batch_count))
                    print('[GPU worker {}] Avg. Cpu2Gpu={:.2f}, Avg. Gpu2Cpu={:.2f}, Avg. Model Inference={:.2f}'.format(self.worker_id, total_cpu_to_gpu_time / batch_count, total_gpu_to_cpu_time / batch_count, total_model_inference_time / batch_count))
                    print('[GPU worker {}] Avg. Make Time Gap={:.2f}, Avg. Clip Ratio={:.4f}'.format(self.worker_id, total_make_time_gap / batch_count, total_fix_pri_clip_ratio / batch_count))
                    print('[GPU worker {}] MCTS storage size={}, Batch storage size={}'
                          .format(self.worker_id, mcts_storage_size, batch_storage_size))
    
    def maybe_update_model(self):
        """Load newer target-model weights once training crosses a target-model boundary.

        The boundary index is ``(trained_steps - 1 - 3) // target_model_interval``.
        The reason for the 4-step offset is not visible here.

        When that index exceeds ``self.last_model_index``, weights are read via
        ``read_weights``. If their version is still below
        ``target_model_interval * index``, the read is retried every 0.01 s, giving
        up after 10 retries. The index is marked as handled in either case.

        The weights are applied to ``self.model`` only if their version is newer than
        ``self.version`` and is an exact multiple of ``target_model_interval``.
        """
        steps = self.meta_data_manager.get("trained_steps")[0]
        self.trained_steps = steps
        interval = self.config.target_model_interval
        model_index = (steps - 1 - 3) // interval
        if not model_index > self.last_model_index:
            return

        wanted_version = interval * model_index
        weights, weights_version = read_weights(self.meta_data_manager, self.smos_client, self.storage, self.storage_config)
        retries = 0
        # the writer may not have published the boundary weights yet; poll briefly
        while weights_version < wanted_version:
            time.sleep(0.01)
            retries += 1
            weights, weights_version = read_weights(self.meta_data_manager, self.smos_client, self.storage, self.storage_config)
            if retries >= 10:
                print("Update model failed for 10 times", weights_version, wanted_version)
                break
        self.last_model_index = model_index

        if weights_version > self.version and weights_version % interval == 0:
            self.version = weights_version
            self.model.set_weights(weights)
            self.model.to(self.device)
            self.model.eval()

    def release_control(self, mcts_idx):
        """Push ``mcts_idx`` onto the reanalyze control storage.

        Presumably this hands the corresponding MCTS storage entry back so it can be
        reused. Confirm against the consumer of ``reanalyze_control_storage_name``.

        Raises
        ------
        RuntimeError
            If the SMOS push does not return ``SMOS.SMOS_SUCCESS``.
        """
        status, _ = self.smos_client.push_to_object(name=self.storage_config.reanalyze_control_storage_name,
                                                    data=[mcts_idx])
        # Explicit check rather than ``assert``: asserts are removed under ``python -O``,
        # which would silently drop a failed release.
        if status != SMOS.SMOS_SUCCESS:
            raise RuntimeError(f"[GPU Worker {self.worker_id}] Failed to release control "
                               f"for mcts_idx={mcts_idx} (status={status})")

def start_batch_worker_gpu(worker_id, config, storage_config: StorageConfig):
    """
    Start a GPU batch worker. Call this method remotely.

    The GPU is chosen by walking ``storage_config.gpu_worker_visible_devices``
    together with the per-device worker counts in
    ``storage_config.num_gpu_works_per_device``. The worker takes the first device
    whose cumulative count covers ``worker_id``. If none does, ``used_gpu_idx``
    stays 0.

    A warning is printed when the visible GPU with the most free memory reports less
    than 2000 (in the units returned by ``get_gpu_memory``). This call blocks inside
    the worker's ``run`` loop.
    """
    # start-up delay grows with worker_id (presumably to stagger the workers)
    time.sleep(1 * worker_id + 15)

    free_memory = get_gpu_memory()
    visible_devices = storage_config.gpu_worker_visible_devices
    # mask out devices GPU batch workers are not allowed to use
    for device_idx, _ in enumerate(free_memory):
        if device_idx not in visible_devices:
            free_memory[device_idx] = -1
    most_free = max(free_memory)
    if most_free < 2000:
        print(f"[Batch worker GPU]******************* Warning: Low video ram (max remaining "
              f"{most_free}) *******************")

    used_gpu_idx = 0
    assigned = 0
    for gpu_idx, workers_on_gpu in zip(visible_devices, storage_config.num_gpu_works_per_device):
        assigned += workers_on_gpu
        if assigned >= worker_id + 1:
            used_gpu_idx = gpu_idx
            break
    print(f"[Batch worker GPU] Batch worker GPU {worker_id} at process {os.getpid()}"
          f" will use GPU {used_gpu_idx}. Remaining memory before allocation {free_memory}")

    # set process name
    setproctitle.setproctitle(f"EfficientZero-gpu_worker{worker_id}")

    # connect to the shared storages
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    shared_storage = get_shared_storage(storage_config=storage_config)
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    watchdog_server = get_watchdog_server(storage_config=storage_config)

    gpu_worker = BatchWorker_GPU(worker_id=worker_id, replay_buffer=replay_buffer, storage=shared_storage,
                                 smos_client=smos_client, watchdog_server=watchdog_server, config=config,
                                 storage_config=storage_config)
    gpu_worker.set_device(torch.device(f'cuda:{used_gpu_idx}'))
    print(f"[Batch worker GPU] Starting batch worker GPU {worker_id} at process {os.getpid()}.")
    gpu_worker.run()

class ValueUpdater(object):
    """Periodically recomputes value estimates and priorities of stored games.

    Each updater visits the replay-buffer games with ids
    ``current_id * num_value_updaters + worker_id`` (see ``get_one_game``). It runs
    ``self.model.initial_inference`` on every position of a group of games, then
    writes root values, the games' target values and recomputed priorities into the
    arrays obtained from ``MetaDataSharedMemoryManager``.

    ``run`` throttles throughput so that
    ``total_reanalyzed_transitions <= trained_steps * transitions_per_step``.
    """

    def __init__(self, worker_id, config, storage_config, smos_client, replay_buffer, storage):
        self.config = config
        self.storage_config = storage_config

        self.smos_client = smos_client
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

        # games cached locally, indexed by this updater's local id (see get_one_game)
        self.games = []
        self.num_value_updaters = config.value_updater

        self.current_id = 0
        self.worker_id = worker_id

        self.device = config.device

        # this updater's share of the reanalysed-transition budget per trained step
        self.transitions_per_step = config.rev_transitions_per_step // config.value_updater
        self.total_reanalyzed_transitions = 0

        self.model = config.get_uniform_network(is_reanalyze_worker=True)
        try:
            self.model.to(self.device)
        except RuntimeError as e:
            print("[Value Updater]", self.device, torch.cuda.device_count())
            raise e
        self.model.set_device(self.device)
        self.model.eval()
        # NOTE(review): target_model is refreshed in maybe_update_model, but step() in
        # this class runs inference with self.model only.
        self.target_model = config.get_uniform_network(is_reanalyze_worker=True)
        self.target_model.to(self.device)
        self.target_model.set_device(self.device)
        self.target_model.eval()

        self.last_model_index = 0
        self.version = 0
        self.last_target_model_index = 0
        self.target_version = 0
        self.trained_steps = 0

        # incremented each time get_one_game finds no game for the next id and restarts at 0
        self.update_rounds = 0

        self.global_priorities = self.meta_data_manager.get("priorities")
        self.global_values = self.meta_data_manager.get("values")
        self.global_target_values = self.meta_data_manager.get("target_values")

        # When True, each cached game's observations are kept as a uint8 torch tensor
        # (moved between self.device and CPU by move_game_obs); otherwise step() rebuilds
        # observations from game.obs(t) every time.
        self.preload_obs_into_gpu = True
        self.num_batch_games = config.num_batch_games
        self.num_preload_games = config.num_preload_games
        self.games_obs = []
        # model version at which each cached game last wrote the shared arrays (-1 = never)
        self.refresh_version = []
        self.fresh = []

        self.config.set_transforms()

    def maybe_update_model(self):
        """Refresh ``self.model`` and ``self.target_model`` from the shared weights.

        ``self.model`` is reloaded once per ``checkpoint_interval`` trained steps, when
        the stored weights are newer than ``self.version``.

        ``self.target_model`` is reloaded once per ``target_model_interval`` trained
        steps, and only with weights whose version is a multiple of that interval.
        ``last_target_model_index`` is advanced only on a successful load, so the
        weights are re-read on every call until such a version appears.
        """
        self.trained_steps = self.meta_data_manager.get("trained_steps")[0]
        new_model_index = self.trained_steps // self.config.checkpoint_interval
        if new_model_index > self.last_model_index:
            self.last_model_index = new_model_index
            target_weights, target_weights_version = read_weights(self.meta_data_manager, self.smos_client, self.storage, self.storage_config)
            if target_weights_version > self.version:
                self.version = target_weights_version
                self.model.set_weights(target_weights)
                self.model.to(self.device)
                self.model.eval()
                del target_weights

        self.trained_steps = self.meta_data_manager.get("trained_steps")[0]
        new_model_index = self.trained_steps // self.config.target_model_interval
        if new_model_index > self.last_target_model_index:
            target_weights, target_weights_version = read_weights(self.meta_data_manager, self.smos_client, self.storage, self.storage_config)
            if target_weights_version > self.target_version and target_weights_version % self.config.target_model_interval == 0:
                self.last_target_model_index = new_model_index
                self.target_version = target_weights_version
                self.target_model.set_weights(target_weights)
                self.target_model.to(self.device)
                self.target_model.eval()
                del target_weights

    def get_one_game(self):
        """Return ``(local_id, game)`` for the next game handled by this updater.

        A game already in ``self.games`` is returned directly. Otherwise it is fetched
        from the replay buffer, its data is restored from SMOS, and it is appended to the
        local caches. When ``preload_obs_into_gpu`` is set, its observations are also
        cached as a uint8 tensor on ``self.device``.

        When the replay buffer has no game for the next id, iteration restarts at local
        id 0 and ``update_rounds`` is incremented.
        """
        game = None
        while game is None:
            game_id = self.current_id * self.num_value_updaters + self.worker_id
            if self.current_id < len(self.games):
                game = self.games[self.current_id]
            else:
                metadata = self.replay_buffer.get_game(game_id)
                if metadata is None:
                    self.current_id = 0
                    self.update_rounds += 1
                    continue
                game = GameHistory.from_metadata(max_length=self.config.history_length, config=self.config, metadata=metadata)
                # restore data for game
                status, handle_batch, reconstructed_batch = self.smos_client.batch_read_from_object(name=self.storage_config.replay_buffer_name,
                                                                                                    entry_idx_batch=[game.entry_idx])

                # check whether replay is deleted; if the read failed, retry the same id
                if not status == SMOS.SMOS_SUCCESS:
                    time.sleep(0.05)
                    continue

                game.restore_data(reconstructed_object=reconstructed_batch[0], docopy=True)

                # clean up batch from replay buffer
                self.smos_client.batch_release_entry(object_handle_batch=handle_batch)

                self.games.append(game)
                self.refresh_version.append(-1)
                self.fresh.append(1)

                if self.preload_obs_into_gpu:
                    try:
                        if self.config.cvt_string:
                            obs_history = np.asarray([str_to_arr(frame) for frame in game.obs_history]).astype(np.uint8)
                        else:
                            obs_history = np.asarray([frame for frame in game.obs_history]).astype(np.uint8)
                        self.games_obs.append(torch.from_numpy(obs_history).to(self.device))
                    except ValueError as e:
                        print(type(game.obs_history))
                        # obs_history may be a plain list without ``dtype``; don't let the
                        # diagnostic print mask the original ValueError.
                        print(getattr(game.obs_history, "dtype", None))
                        raise e
        self.current_id += 1
        return self.current_id - 1, game

    def get_games(self, num=1):
        """Return a list of ``num`` ``(local_id, game)`` pairs from ``get_one_game``."""
        games = [self.get_one_game() for _ in range(num)]
        return games

    def move_game_obs(self, start_idx, num_preload_games):
        """Keep up to ``num_preload_games`` cached observation tensors on ``self.device``.

        Starting at ``start_idx`` and wrapping around, the next ``num_preload_games``
        tensors are moved to ``self.device`` with ``non_blocking=True``. All remaining
        tensors are moved back to CPU. ``get_device() == -1`` identifies CPU tensors.
        """
        cur_idx = start_idx
        for i in range(num_preload_games):
            if self.games_obs[cur_idx].get_device() == -1:
                self.games_obs[cur_idx] = self.games_obs[cur_idx].to(self.device, non_blocking=True)
            cur_idx = (cur_idx + 1) % len(self.games)
            if cur_idx == start_idx:
                break
        while cur_idx != start_idx:
            if self.games_obs[cur_idx].get_device() != -1:
                self.games_obs[cur_idx] = self.games_obs[cur_idx].to(torch.device("cpu"), non_blocking=True)
            cur_idx = (cur_idx + 1) % len(self.games)

    def _compute_td_steps(self, total_transitions, game):
        """Return the TD horizon for ``game``, clipped to ``[1, 5]``.

        The horizon is ``config.td_steps`` minus one step for every
        ``config.auto_td_steps`` transitions added to the buffer since the game's first
        index (``game.idx_interval[0]``). Prints the game's metadata and re-raises on
        ``TypeError``.
        """
        try:
            delta_td = (total_transitions - game.idx_interval[0]) // self.config.auto_td_steps
            td_steps = self.config.td_steps - delta_td
            # ``np.int`` (an alias of builtin ``int``) was removed in NumPy 1.24.
            return np.clip(td_steps, 1, 5).astype(int)
        except TypeError as e:
            print("Game:", game.metadata())
            raise e

    def step(self):
        """Recompute values and priorities for ``num_batch_games`` games.

        Root values of every position are recomputed with ``self.model``, stored in
        each game's ``root_values``, and used to recompute priorities. The shared
        value / target-value / priority arrays are written for a game only the first
        time it is processed under a new ``self.version``.

        Returns
        -------
        bool
            Always ``True``.
        """
        games = self.get_games(self.num_batch_games)

        # update model
        self.trained_steps = self.meta_data_manager.get("trained_steps")[0]
        self.maybe_update_model()

        # game related
        traj_lens = [len(game.actions) for _, game in games]

        # compute value for current game
        if self.preload_obs_into_gpu:
            # load unnecessary ones into cpu, preload useful ones with nonblocking=True
            start_idx = games[0][0]
            self.move_game_obs(start_idx, self.num_preload_games)

            value_obs_lst = []
            for (idx, game), traj_len in zip(games, traj_lens):
                for t in range(traj_len):
                    value_obs_lst.append(self.games_obs[idx][t: t + self.config.stacked_observations])
            value_obs_lst = torch.stack(value_obs_lst, dim=0)
            # move the last (channel) axis next to the stack axis, then merge them;
            # assumes frames are stored channel-last — TODO confirm
            value_obs_lst = torch.movedim(value_obs_lst, -1, 2)
            shape = value_obs_lst.shape
            value_obs_lst = value_obs_lst.reshape(shape[0], -1, shape[-2], shape[-1])
            m_obs = value_obs_lst.float() / 255.0
        else:
            value_obs_lst = []
            for (idx, game), traj_len in zip(games, traj_lens):
                for t in range(traj_len):
                    value_obs_lst.append(game.obs(t))
            value_obs_lst = prepare_observation_lst(value_obs_lst)
            m_obs = torch.from_numpy(value_obs_lst).to(self.device).float() / 255.0
        with torch.no_grad():
            network_output = []
            with autocast():
                m_output = self.model.initial_inference(m_obs, to_numpy=False)
            network_output.append(m_output)
            values, _, _, _, _ = concat_output(network_output, state_gpu=True, reward_hidden_gpu=True)
            values = values.reshape(-1)
            target_values = values

        start = 0
        total_transitions = self.replay_buffer.get_total_len()
        for (idx, game), traj_len in zip(games, traj_lens):
            # td steps
            td_steps = self._compute_td_steps(total_transitions, game)

            # recompute-priority
            assert (len(game.root_values) == traj_len), (len(game.root_values), traj_len)
            game.root_values = target_values[start: start + traj_len]
            new_priorities = game.compute_priorities(self.config.discount, td_steps, eps=self.config.prioritized_replay_eps, values=values[start:start + game.length])

            # write the shared arrays only once per model version for each game
            if self.version > self.refresh_version[idx]:
                self.refresh_version[idx] = self.version
                self.global_values[game.idx_interval[0]: game.idx_interval[1]] = game.root_values[:game.length]
                self.global_target_values[game.idx_interval[0]: game.idx_interval[1]] = game.target_values[:game.length]
                self.global_priorities[game.idx_interval[0]: game.idx_interval[1]] = new_priorities

            start += traj_len

        self.total_reanalyzed_transitions += len(values)

        return True

    def run(self):
        """Main loop: wait for the start signal, then call ``step`` until training ends.

        Before each step, the loop sleeps while this updater is ahead of its budget of
        ``trained_steps * transitions_per_step`` reanalysed transitions.
        """
        self.storage.add_ready_value_updater(self.worker_id)

        start = False
        total_time = 0.
        update_count = 0
        while True:
            # waiting for start signal
            if not start:
                start = self.storage.get_start_signal()
                time.sleep(5)
                continue

            self.trained_steps = trained_steps = self.meta_data_manager.get("trained_steps")[0]
            if trained_steps >= self.config.training_steps + self.config.last_steps:
                time.sleep(30)
                break

            while self.total_reanalyzed_transitions > self.trained_steps * self.transitions_per_step:
                time.sleep(1.)
                self.trained_steps = trained_steps = self.meta_data_manager.get("trained_steps")[0]

            start_time = time.time()
            status = self.step()
            end_time = time.time()

            total_time += end_time - start_time
            if status:
                update_count += 1
                if update_count % 20 == 0 and self.worker_id == 0:
                    print('[Value Updater] Avg. Val. Upd.={:.2f}, Lst. Val. Upd.={:.2f}. Avg. Round.={:.2f}'
                            .format(total_time / update_count, end_time - start_time, total_time / (self.update_rounds + 1)))

def start_value_updater(worker_id, config, storage_config: StorageConfig):
    """
    Start a value updater. Call this method remotely.

    The device is picked round-robin from ``storage_config.value_updater_visible_devices``
    by ``worker_id`` and stored into ``config.device`` before the updater is built.
    A warning is printed when the visible GPU with the most free memory reports less
    than 2000 (in the units returned by ``get_gpu_memory``). This call blocks inside
    the updater's ``run`` loop.
    """
    # start-up delay grows with worker_id (presumably to stagger the workers)
    time.sleep(1 * worker_id + 15)

    free_memory = get_gpu_memory()
    visible_devices = storage_config.value_updater_visible_devices
    # mask out devices value updaters are not allowed to use
    for device_idx, _ in enumerate(free_memory):
        if device_idx not in visible_devices:
            free_memory[device_idx] = -1
    most_free = max(free_memory)
    if most_free < 2000:
        print(f"[Value Updater]******************* Warning: Low video ram (max remaining "
              f"{most_free}) *******************")

    used_gpu_idx = visible_devices[worker_id % len(visible_devices)]
    config.device = torch.device(f'cuda:{used_gpu_idx}')
    print(f"[Value Updater] Value updater {worker_id} at process {os.getpid()}"
          f" will use GPU {used_gpu_idx}. Remaining memory before allocation {free_memory}")

    # set process name
    setproctitle.setproctitle(f"EfficientZero-value_updater{worker_id}")

    # connect to the shared storages
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    shared_storage = get_shared_storage(storage_config=storage_config)
    smos_client = SMOS.Client(connection=storage_config.smos_connection)

    value_updater = ValueUpdater(worker_id=worker_id, replay_buffer=replay_buffer, storage=shared_storage,
                                 smos_client=smos_client, config=config, storage_config=storage_config)
    print(f"[Value Updater] Starting value updater {worker_id} at process {os.getpid()}.")
    value_updater.run()

class PriorityUpdater(object):
    """Applies reanalysed policy/value results from a priority storage to the replay buffer.

    Entries are popped from ``priority_storage_name`` + ``worker_id % priority_storage_count``.
    These are presumably the entries pushed by GPU batch workers when
    ``reanalyze_update_priority`` is enabled; confirm against the producer.
    """

    def __init__(self, worker_id, config, storage_config, smos_client, replay_buffer, storage):
        self.config = config
        self.storage_config = storage_config
        # workers share priority storages round-robin by worker_id
        storage_suffix = worker_id % storage_config.priority_storage_count
        self.priority_storage_name = storage_config.priority_storage_name + f"{storage_suffix}"

        self.smos_client = smos_client
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

        self.worker_id = worker_id

    def step(self):
        """Pop one entry (polling every 0.5 s until one is available) and apply it.

        The entry's policies/root values and values/bootstrap values are written via
        ``replay_buffer.update_policies`` and ``replay_buffer.update_values``; the SMOS
        handle is freed afterwards.
        """
        status, handle, batch = self.smos_client.pop_from_object(name=self.priority_storage_name)
        while status != SMOS.SMOS_SUCCESS:
            time.sleep(0.5)
            status, handle, batch = self.smos_client.pop_from_object(name=self.priority_storage_name)

        (version, indices_lst, make_time_lst, batch_policies, root_values,
         batch_values, bootstrap_values, value_mask, td_steps_lst) = batch

        buffer = self.replay_buffer
        buffer.update_policies(version, indices_lst, make_time_lst, batch_policies, root_values)
        buffer.update_values(version, indices_lst, make_time_lst, batch_values, bootstrap_values, value_mask, td_steps_lst)

        # release the popped entry
        self.smos_client.free_handle(object_handle=handle)

    def run(self):
        """Main loop: wait for the start signal, then call ``step`` until training ends.

        The termination condition is only checked between steps, and ``step`` blocks
        until an entry can be popped.
        """
        self.storage.add_ready_priority_updater(self.worker_id)

        started = False
        elapsed_total = 0
        steps_done = 0
        while True:
            if not started:
                started = self.storage.get_start_signal()
                time.sleep(5)
                continue

            current_steps = self.meta_data_manager.get("trained_steps")[0]
            self.trained_steps = current_steps
            if current_steps >= self.config.training_steps + self.config.last_steps:
                time.sleep(30)
                break

            tic = time.time()
            self.step()
            toc = time.time()

            elapsed_total += toc - tic
            steps_done += 1
            if steps_done % 20 == 0:
                print('[Priority Updater] Avg. Buf. Upd.={:.2f}, Lst. Buf. Upd.={:.2f}'
                      .format(elapsed_total / steps_done, toc - tic))

def start_priority_updater(worker_id, config, storage_config: StorageConfig):
    """
    Start a worker to update priority. Call this method remotely.

    Connects to the replay buffer, shared storage and SMOS, builds a
    ``PriorityUpdater`` and blocks inside its ``run`` loop.
    """
    # keyword arguments are evaluated left to right: replay buffer, shared storage, SMOS client
    priority_updater = PriorityUpdater(
        worker_id=worker_id,
        config=config,
        storage_config=storage_config,
        replay_buffer=get_replay_buffer(storage_config=storage_config),
        storage=get_shared_storage(storage_config=storage_config),
        smos_client=SMOS.Client(connection=storage_config.smos_connection),
    )
    print(f"[Priority Updater] Starting priority updater {worker_id} at process {os.getpid()}.")
    priority_updater.run()