import copy
import datetime
from functools import partial
import os
import pickle
import setproctitle
import socket
import sys
import time
from typing import Mapping
import wandb
import lz4framed
import numpy as np

import SMOS
import SMOS_utils
import zmq
import pickle
import queue

from enum import Enum
from torch.utils.tensorboard import SummaryWriter

from core.watchdog import get_watchdog_server
from core.storage_config import StorageConfig
from core.shared_storage import get_shared_storage, read_weights
from core.replay_buffer import get_replay_buffer
from core.log import _log, _log_worker
from core.utils import MappingThread, TimeTicker
from core.meta_data_manager import MetaDataSharedMemoryManager
from core.game import GameHistory

# If True, packets would not be sent from worker node. This simulates an extremely fast trainer.
# Environment values arrive as strings, so explicit "off" spellings are mapped to False here;
# a bare truthiness check would treat VIRTUAL_TEST=0 or VIRTUAL_TEST=false as enabled.
VIRTUAL_TEST = os.environ.get("VIRTUAL_TEST", "").strip().lower() not in ("", "0", "false", "no", "off")

class PacketType(Enum):
    """Tag carried by a SignalPacket telling the receiver how to interpret its payload."""
    # Payload is (weights, version); signal_subscriber stores it as the target weights.
    TARGET_MODEL = 9999
    # Payload is (weights, version); signal_subscriber stores it as the current weights.
    MODEL = 8888
    # Payload is (config, step_count, log_data, lr, vis_result), forwarded to _log.
    LOG = 7777
    # NOTE(review): the handler in signal_subscriber is commented out, so receiving this
    # type there currently ends in NotImplementedError.
    SIG_REMOVE_TO_FIT = 6666
    # Makes signal_subscriber bump the shared counter and the replay buffer version.
    SIG_INCREASE_COUNTER = 5555
    # Must be the first packet signal_subscriber receives; it sets the start signal.
    START_TRAINING = 3333


class BatchPacket:
    """Plain holder that keeps a node id next to a batch."""

    def __init__(self, node_id, batch):
        # Values are stored as given: no copying, no validation.
        self.node_id, self.batch = node_id, batch


class PriorityPacket:
    """Plain holder for a slice of priority updates: indices, new priorities, creation times."""

    def __init__(self, indices, new_priority, make_time):
        # Values are stored as given: no copying, no validation.
        self.indices, self.new_priority, self.make_time = indices, new_priority, make_time


class SignalPacket:
    """Plain holder for one training signal: its type tag, the step it was created at, and a payload."""

    def __init__(self, packet_type: PacketType, create_steps, data):
        # Values are stored as given; the payload layout depends on packet_type.
        self.packet_type, self.create_steps, self.data = packet_type, create_steps, data



def batch_sender(sender_id, worker_node_id, batch_multi, config, storage_config: StorageConfig):
    """
    send batch from smos queue on worker node to trainer node using zmq
    position: worker
    multiple: yes

    :param sender_id: index of this sender; selects the SMOS batch storage and the trainer
        port (both modulo their respective counts). Sender 0 also performs the ready handshake.
    :param worker_node_id: id of the hosting worker node, embedded in the ready message.
    :param batch_multi: not read by this function.
    :param config: run configuration, read for the start-up thresholds.
    :param storage_config: SMOS / zmq / shared-server connection settings.

    Does not return: after the start signal is observed it forwards batches forever.
    """
    # initialize smos client and zmq socket
    # smos client
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    batch_storage_name = storage_config.batch_storage_name + f"{sender_id % storage_config.batch_storage_count_worker}"
    # watchdog
    watchdog_server = get_watchdog_server(storage_config=storage_config)
    ticker = TimeTicker(3)  # presumably a 3 second period -- TODO confirm against TimeTicker
    # zmq socket
    zmq_context = zmq.Context()
    batch_socket = zmq_context.socket(zmq.PUSH)
    server_port = storage_config.zmq_batch_port_list_worker[sender_id % storage_config.batch_receiver_count]
    batch_queue_address = f"tcp://{storage_config.zmq_batch_ip_worker}:{server_port}"
    batch_socket.set_hwm(2)
    batch_socket.connect(batch_queue_address)
    print(f"[Batch Sender {sender_id} at node {worker_node_id}]"
          f" Batch sender {sender_id} has been initialized at process {os.getpid()}, port {server_port}.")

    shared_storage = get_shared_storage(storage_config=storage_config)
    replay_buffer = get_replay_buffer(storage_config=storage_config)

    # Nothing is read from the manager below; it is still constructed as in the original
    # in case construction attaches shared memory -- TODO confirm whether it can be dropped.
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

    # send forever
    total_batches = 0
    total_time = 0

    # drop ratio for reanalyzed batches to control policy version
    # (refreshed from the watchdog below, but not consumed inside this function)
    drop_ratio = 0

    # send start when enough data is collected in the replay buffer
    if sender_id == 0:
        while not (replay_buffer.get_total_len() >= config.start_transitions
                   and shared_storage.get_num_ready_cpu_workers() == config.cpu_actor
                   and shared_storage.get_num_ready_gpu_workers() == config.gpu_actor
                   and shared_storage.get_num_ready_value_updaters() == config.value_updater
                   and shared_storage.get_num_ready_priority_updaters() == config.priority_updater):
            time.sleep(1)
        print("Ready to strat training!! All workers are set up now.")

    # wait until start
    while not shared_storage.get_start_signal():
        if sender_id == 0 and not VIRTUAL_TEST:
            # send ready signal to trainer; repeated every second until the start signal is set
            compressed_stream = lz4framed.compress(SMOS_utils.serialize(f"ready-{worker_node_id}"))
            batch_socket.send(compressed_stream)
        time.sleep(1)

    while True:
        start_time = time.time()
        # fetch data
        status, handle, batch = smos_client.pop_from_object(name=batch_storage_name)
        if not status == SMOS.SMOS_SUCCESS:
            time.sleep(0.05)
            continue

        # From here on a handle is held; release it even if serialising or sending fails.
        try:
            # The trailing element is an index that only the disabled global-obs path used.
            batch = batch[:-1]

            # send
            time0 = time.time()
            data_stream = SMOS_utils.serialize(batch)
            time1 = time.time()
            compressed_stream = lz4framed.compress(data_stream)
            time2 = time.time()
            if not VIRTUAL_TEST:
                batch_socket.send(compressed_stream)
            elif not os.path.exists(storage_config.compressed_batch_path):
                # Virtual test: keep one compressed batch on disk for the receiver to replay.
                # The context manager closes the file, which the bare open() never did.
                with open(storage_config.compressed_batch_path, "wb") as snapshot_file:
                    pickle.dump(compressed_stream, snapshot_file)
            time3 = time.time()

            # logging
            end_time = time.time()
            total_time += end_time - start_time
            total_batches += 1
            if total_batches % 20 == 0:
                print('[Batch Sender {}] Avg={:.2f}, Lst={:.2f}. '
                      .format(sender_id, total_time / total_batches, end_time - start_time) +
                      'Serialize={:.2f}, Compress={:.2f}, Send={:.2f}'.format(time1-time0, time2-time1, time3-time2)
                      )

            if ticker.tick():
                drop_ratio = watchdog_server.get_drop_ratio()
        finally:
            smos_client.free_handle(object_handle=handle)


def batch_receiver(receiver_id, config, storage_config: StorageConfig):
    """
    receive batch from worker nodes and put them into smos queue on trainer node
    position: trainer
    multiple: yes

    :param receiver_id: index of this receiver; selects the trainer port and (modulo the
        storage count) the SMOS batch storage that receives the batches.
    :param config: run configuration; ``max_version_gap`` decides which batches are dropped.
    :param storage_config: SMOS / zmq / shared-server connection settings.

    Does not return. Messages that deserialize to a ``str`` are worker "ready-<id>"
    announcements; everything else is treated as a batch.
    """
    setproctitle.setproctitle(f"EfficientZero-batch_receiver{receiver_id}")
    # initialize smos client and zmq socket
    # smos client
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    batch_storage_name = storage_config.batch_storage_name + f"{receiver_id % storage_config.batch_storage_count_trainer}"
    # shared storage
    shared_storage = get_shared_storage(storage_config=storage_config)
    # zmq socket
    zmq_context = zmq.Context()
    batch_socket = zmq_context.socket(zmq.PULL)
    server_port = storage_config.zmq_batch_port_list_trainer[receiver_id]
    batch_queue_address = f"tcp://{storage_config.zmq_batch_ip_trainer}:{server_port}"
    batch_socket.set_hwm(2)
    batch_socket.bind(batch_queue_address)
    print(f"[Batch Receiver {receiver_id}] Batch receiver {receiver_id} has been initialized "
          f"at process {os.getpid()}, port {server_port}.", flush=True)

    # if virtual test, read compressed batch from local -- once, and close the file.
    # Previously the preloaded stream was shadowed inside receive_batch, so the snapshot
    # was reopened (and never explicitly closed) for every single batch.
    virtual_stream = None
    if VIRTUAL_TEST:
        with open(storage_config.compressed_batch_path, "rb") as snapshot_file:
            virtual_stream = pickle.load(snapshot_file)
        # Decode once up front so a corrupt snapshot fails here instead of in the thread.
        SMOS_utils.deserialize(lz4framed.decompress(virtual_stream))

    # receiver thread
    def receive_batch(_ignore):
        # Virtual test replays the cached snapshot; decompression and deserialisation are
        # still done per call, as in the real path.
        compressed_stream = virtual_stream if VIRTUAL_TEST else batch_socket.recv()
        return SMOS_utils.deserialize(lz4framed.decompress(compressed_stream))
    batch_queue = queue.Queue(4)
    thread = MappingThread(receive_batch, False, None, batch_queue)
    thread.start()

    # receive forever
    total_batches = 0
    total_time = 0
    total_drops = 0
    ready_prefix_len = len("ready-")
    while True:
        start_time = time.time()

        try:
            batch = batch_queue.get_nowait()
        except queue.Empty:
            time.sleep(0.001)
            continue

        if isinstance(batch, str):
            # "ready-<worker_node_id>" announcement sent by batch sender 0 of a worker node
            worker_node_id = batch[ready_prefix_len:]
            shared_storage.add_ready_worker_node(worker_node_id)
            print(f"Worker {worker_node_id} is ready. Current ready nums: {shared_storage.get_num_ready_worker_nodes()}", flush=True)
            continue

        # presumably batch[-2] holds the model versions the batch was built with -- TODO confirm
        version_gap = shared_storage.get_counter() - batch[-2].mean()

        if version_gap > config.max_version_gap:
            # too stale: drop it
            total_drops += 1
            continue

        # push to local batch queue until success
        while True:
            status, _ = smos_client.push_to_object(name=batch_storage_name, data=batch)
            if status == SMOS.SMOS_SUCCESS:
                break
            time.sleep(0.01)

        # logging
        end_time = time.time()
        total_time += end_time - start_time
        total_batches += 1
        if total_batches % 20 == 0:
            print('[Batch Receiver {}] Avg={:.2f}, Lst={:.2f}. Avg. drop={:.2f}.'
                  .format(receiver_id, total_time / total_batches, end_time - start_time, total_drops / (total_drops + total_batches)), flush=True)


def signal_publisher(storage_config: StorageConfig):
    """
    pops training signals from the smos signal queue and publishes them on a zmq PUB socket
    position: trainer
    multiple: no

    Does not return. In virtual test mode packets are still popped and released,
    but nothing is sent.
    """
    # smos side: where the signals come from
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    queue_name = storage_config.signal_queue_name

    # zmq side: where the signals go
    context = zmq.Context()
    publisher = context.socket(zmq.PUB)
    publisher.bind(f"tcp://{storage_config.zmq_signal_ip_trainer}:{storage_config.zmq_signal_port_trainer}")
    print(f"[Signal Publisher] Signal Publisher has been initialized at process {os.getpid()},"
          f" port {storage_config.zmq_signal_port_trainer}.")

    # publish forever
    handled = 0
    while True:
        status, handle, signal_packet = smos_client.pop_from_object(name=queue_name)
        if status != SMOS.SMOS_SUCCESS:
            # queue empty (or pop failed): back off briefly and retry
            time.sleep(0.05)
            continue

        if not VIRTUAL_TEST:
            publisher.send(SMOS_utils.serialize(signal_packet))

        # give the entry back to smos, then count it
        smos_client.free_handle(object_handle=handle)
        handled += 1


def signal_subscriber(exp_path, config, storage_config: StorageConfig):
    """
    receives training signals from the trainer and applies them to the local storages
    position: worker
    multiple: no

    Flow: open a wandb run, connect the SUB socket, wait until the local workers are
    ready, then either simulate a trainer (virtual test) or require a START_TRAINING
    packet and process signal packets forever.
    """
    # wandb run used by the logging helpers
    args = config.all_args
    stamp = datetime.datetime.now().strftime("%m-%d-%H")
    wandb_run = wandb.init(
        config=config,
        project='Atari-EfficientZero',
        entity='speedyzero',
        notes=socket.gethostname(),
        name=f'log-node{args.worker_node_id}',
        group=f"{args.env}-bs{args.eff_batch_size}-{args.info}-seed{args.seed}-time{stamp}-run{config.run_id}",
        dir=exp_path,
        job_type="train",
        reinit=True,
        tags=args.wandb_tags,
        mode='online',
    )

    # local storages
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    shared_storage = get_shared_storage(storage_config=storage_config)
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    watchdog_server = get_watchdog_server(storage_config=storage_config)
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)

    # zmq: subscribe to every topic, then connect
    context = zmq.Context()
    signal_socket = context.socket(zmq.SUB)
    signal_socket.setsockopt_string(zmq.SUBSCRIBE, "")
    signal_socket.connect(f"tcp://{storage_config.zmq_signal_ip_worker}:{storage_config.zmq_signal_port_worker}")
    print(f"[Signal Subscriber] Signal Subscriber has been initialized at process {os.getpid()}, "
          f"port {storage_config.zmq_signal_port_worker}.")

    def workers_ready():
        # enough transitions collected and every kind of local worker fully reported in
        return (replay_buffer.get_total_len() >= config.start_transitions
                and shared_storage.get_num_ready_cpu_workers() == config.cpu_actor
                and shared_storage.get_num_ready_gpu_workers() == config.gpu_actor
                and shared_storage.get_num_ready_value_updaters() == config.value_updater
                and shared_storage.get_num_ready_priority_updaters() == config.priority_updater)

    # wait until collecting enough data to start
    while not workers_ready():
        time.sleep(1)

    # Virtual Test: no trainer is listened to; steps are generated locally on a timer
    if VIRTUAL_TEST:
        shared_storage.set_start_signal()
        time.sleep(5.)
        ticker = TimeTicker(0.06)
        step_count = 0
        weights, _ = read_weights(meta_data_manager, smos_client, shared_storage, storage_config)
        while True:
            if not ticker.tick():
                continue
            step_count += 1
            shared_storage.incr_counter()
            if step_count % 10 == 0:
                # re-publish the same weights under a new version number
                shared_storage.set_weights(weights, step_count)
                print(f"[Signal Subscriber] Model {step_count}")
            if step_count % 100 == 0:
                _log_worker(config, step_count, watchdog_server, replay_buffer, shared_storage, wandb_run)

            if step_count > config.training_steps + config.last_steps:
                return

    # the very first packet has to be START_TRAINING
    first_packet = SMOS.deserialize(signal_socket.recv())
    if first_packet.packet_type != PacketType.START_TRAINING:
        raise RuntimeError(f"Expected 'START_TRAINING' packet, but got {first_packet}")
    shared_storage.set_start_signal()
    print(f"[Signal Subscriber] Start training.")

    print(f'************* Begin Training *************')

    # subscribe forever
    while True:
        signal_packet = SMOS.deserialize(signal_socket.recv())
        kind = signal_packet.packet_type

        # announce every packet, except counter bumps, which are only announced
        # when created at a multiple of 20 steps
        if kind != PacketType.SIG_INCREASE_COUNTER or signal_packet.create_steps % 20 == 0:
            print(f"[Signal Subscriber] Receiving packet {signal_packet.packet_type} created"
                  f" at {signal_packet.create_steps}.")

        # parse packet
        if kind == PacketType.TARGET_MODEL:
            weights, version = signal_packet.data
            shared_storage.set_target_weights(weights, version)
            print(f"[Signal Subscriber] Target Model {version}")
        elif kind == PacketType.MODEL:
            weights, version = signal_packet.data
            shared_storage.set_weights(weights, version)
            print(f"[Signal Subscriber] Model {version}")
        elif kind == PacketType.SIG_INCREASE_COUNTER:
            shared_storage.incr_counter()
            replay_buffer.update_version(shared_storage.get_counter())
        elif kind == PacketType.LOG:
            # the packet carries its own config, which is the one handed to _log
            packet_config, step_count, log_data, lr, vis_result = signal_packet.data
            print(f"[Signal Subscriber] Try LOG")
            _log(packet_config, step_count, log_data, replay_buffer, lr, shared_storage,
                 wandb_run, vis_result)
            print(f"[Signal Subscriber] Finish LOG")
        else:
            # includes SIG_REMOVE_TO_FIT, which has no handler here
            raise NotImplementedError
        
def replay_buffer_publisher(storage_config: StorageConfig, config):
    """
    broadcasts replay buffer and priority from master worker to other worker node
    position: master worker
    multiple: no

    :param storage_config: SMOS / zmq connection settings; must describe a data worker.
    :param config: run configuration (history_length, worker_node_id, buffer_snapshot_path).

    Does not return. Two message kinds are published, both lz4-compressed serialised lists:
    ``["game", [metadata, data]]`` for each new game and
    ``["priorities", [step, priorities, valid_entries, target_values]]`` periodically.
    """
    assert storage_config.is_data_worker
    # zmq socket
    zmq_context = zmq.Context()
    replay_buffer_socket = zmq_context.socket(zmq.PUB)
    replay_buffer_queue_address = f"tcp://{storage_config.zmq_replay_buffer_ip_master}:" \
                                  f"{storage_config.zmq_replay_buffer_port_master}"
    replay_buffer_socket.bind(replay_buffer_queue_address)
    print(f"[Replay Buffer Publisher] Replay Buffer Publisher has been initialized at process {os.getpid()}, "
          f"port {storage_config.zmq_replay_buffer_port_master}.")
    # storage
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
    global_priorities = meta_data_manager.get("priorities")
    global_valid_entries = meta_data_manager.get("valid_entries")
    global_target_values = meta_data_manager.get("target_values")
    # The remaining arrays are not read below; the lookups are retained from the original
    # in case get() has attach side effects -- TODO confirm whether they can be dropped.
    global_trained_steps = meta_data_manager.get("trained_steps")
    global_rewards = meta_data_manager.get("rewards")
    global_values = meta_data_manager.get("values")
    global_death_masks = meta_data_manager.get("death_masks")
    global_game_ids = meta_data_manager.get("game_ids")

    # presumably gives subscribers time to connect before the first publish -- TODO confirm
    time.sleep(20.)

    # publish forever
    last_publish_priorities_time = time.time()
    total_new_games, total_new_game_time = 0, 0
    total_upd_priorities, total_upd_priorities_gap = 0, 0
    if config.worker_node_id == 0:
        # exist_ok: a restarted run must not crash because the directory is already there
        os.makedirs(config.buffer_snapshot_path, exist_ok=True)
    while True:
        # sync game
        new_games = replay_buffer.get_new_games()
        for game_metadata in new_games:
            start_time = time.time()
            game = GameHistory.from_metadata(max_length=config.history_length, config=config, metadata=game_metadata)
            # restore data for game
            status, handle_batch, reconstructed_batch = smos_client.batch_read_from_object(
                name=storage_config.replay_buffer_name, entry_idx_batch=[game.entry_idx])

            # check whether replay is deleted; if so this game is skipped, not retried
            if not status == SMOS.SMOS_SUCCESS:
                time.sleep(0.05)
                continue

            data = ["game", [game_metadata, reconstructed_batch[0]]]
            compressed_data = lz4framed.compress(SMOS_utils.serialize(data))

            # clean up batch from replay buffer (the payload is already serialised)
            smos_client.batch_release_entry(object_handle_batch=handle_batch)

            replay_buffer_socket.send(compressed_data)

            end_time = time.time()

            total_new_games += 1
            total_new_game_time += end_time - start_time

            # Gate on the game counter, as the subscriber does; the original gated on the
            # priority counter, so the cadence of this line had nothing to do with games.
            if total_new_games % 5 == 0:
                print("[Replay Buffer Publisher] #Update Game# Avg. Upd. Pri. Gap={}, Avg. Pub New Game={}".format(total_upd_priorities_gap / max(1, total_upd_priorities), total_new_game_time / max(1, total_new_games)))

        # sync priorities and target values
        current_time = time.time()
        if current_time - last_publish_priorities_time > 0.03:  # at most once every 0.03 sec
            # copies are taken so that one consistent snapshot of each array is serialised
            data_stream = lz4framed.compress(SMOS_utils.serialize(["priorities", [total_upd_priorities, global_priorities[:].copy(), global_valid_entries.copy(), global_target_values[:].copy()]]))

            replay_buffer_socket.send(data_stream)

            total_upd_priorities_gap += current_time - last_publish_priorities_time
            last_publish_priorities_time = current_time
            total_upd_priorities += 1

            if total_upd_priorities % 20 == 0:
                print("[Replay Buffer Publisher] #Update Pri.# Avg. Upd. Pri. Gap={}, Avg. Pub New Game={}".format(total_upd_priorities_gap / max(1, total_upd_priorities), total_new_game_time / max(1, total_new_games)), total_upd_priorities, global_priorities.sum(), replay_buffer.get_sampling_log())


def replay_buffer_subscriber(worker_node_id, storage_config: StorageConfig, config):
    """
    listens broadcast and updates Replaybuffer
    position: worker
    multiple: no

    :param worker_node_id: not read by this function.
    :param storage_config: SMOS / zmq connection settings; must NOT describe a data worker.
    :param config: run configuration (history_length is used to rebuild games).

    Does not return normally.

    :raises RuntimeError: when a message is processed and the last priorities update is
        more than 60 seconds old.
    """
    assert not storage_config.is_data_worker
    # zmq socket
    zmq_context = zmq.Context()
    replay_buffer_socket = zmq_context.socket(zmq.SUB)
    replay_buffer_queue_address = f"tcp://{storage_config.zmq_replay_buffer_ip_worker}:{storage_config.zmq_replay_buffer_port_worker}"
    replay_buffer_socket.setsockopt_string(zmq.SUBSCRIBE, "")
    replay_buffer_socket.connect(replay_buffer_queue_address)
    print(f"[Replay Buffer Subscriber] Replay Buffer Subscriber has been initialized at process {os.getpid()}, "
          f"port {storage_config.zmq_replay_buffer_port_worker}.")
    # storage (the replay buffer is fetched once; the original fetched it twice)
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    replay_buffer = get_replay_buffer(storage_config=storage_config)
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
    global_priorities = meta_data_manager.get("priorities")
    global_valid_entries = meta_data_manager.get("valid_entries")
    global_target_values = meta_data_manager.get("target_values")

    # subscribe forever
    total_new_games, total_new_game_time = 0, 0
    total_upd_priorities, total_upd_priorities_gap = 0, 0
    last_upd_priorities_time = time.time()
    while True:
        data_stream = replay_buffer_socket.recv()
        data_stream = lz4framed.decompress(data_stream)
        data_type, data = SMOS.deserialize(data_stream)

        if data_type == "game":
            # push new game
            start_time = time.time()
            game_metadata, reconstructed_batch = data
            game = GameHistory.from_metadata(max_length=config.history_length, config=config, metadata=game_metadata)
            game.restore_data(reconstructed_object=reconstructed_batch, docopy=False)

            game.game_over(smos_client, storage_config)  # call game_over() to write data in shared memory

            replay_buffer.save_new_game(game)

            end_time = time.time()
            total_new_games += 1
            total_new_game_time += end_time - start_time
            if total_new_games % 5 == 0:
                # tag fixed: these lines used to claim they came from the publisher
                print("[Replay Buffer Subscriber] #Update Game# Avg. Upd. Pri. Gap={}, Avg. Pub New Game={}".format(total_upd_priorities_gap / max(1, total_upd_priorities), total_new_game_time / max(1, total_new_games)))

        elif data_type == "priorities":
            # sync priorities: overwrite the local shared arrays in place
            upd_priorities_step, new_priorities, new_valid_entries, target_values = data
            global_priorities[:] = new_priorities[:]
            global_valid_entries[:] = new_valid_entries[:]
            global_target_values[:] = target_values[:]
            current_time = time.time()
            total_upd_priorities_gap += current_time - last_upd_priorities_time
            last_upd_priorities_time = current_time
            total_upd_priorities += 1

            if total_upd_priorities % 20 == 0:
                print("[Replay Buffer Subscriber] #Update Pri.# Avg. Upd. Pri. Gap={}, Avg. Pub New Game={}".format(total_upd_priorities_gap / max(1, total_upd_priorities), total_new_game_time / max(1, total_new_games)), upd_priorities_step, global_priorities.sum(), new_priorities.sum(), replay_buffer.get_sampling_log())

        # NOTE(review): recv() above blocks, so this check only runs after some message
        # arrived; if the publisher goes completely silent it never fires.
        if time.time() - last_upd_priorities_time > 60:
            raise RuntimeError("[Replay Buffer Subscriber] Priorities are not updated for 60 sec!!")

def priority_publisher(storage_config: StorageConfig):
    """
    pops priority updates from the smos priority queue, cuts them into one equal slice per
    worker node and publishes every slice with a "Node<id>" header
    position: trainer
    multiple: no

    Does not return.
    """
    # smos side
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    queue_name = storage_config.priority_queue_name

    # zmq side
    context = zmq.Context()
    priority_socket = context.socket(zmq.PUB)
    priority_socket.bind(f"tcp://{storage_config.zmq_priority_ip_trainer}:"
                         f"{storage_config.zmq_priority_port_trainer}")
    print(f"[Priority Publisher] Priority Publisher has been initialized at process {os.getpid()}, "
          f"port {storage_config.zmq_priority_port_trainer}.")

    # publish forever
    published, elapsed = 0, 0
    while True:
        tic = time.time()

        status, handle, data = smos_client.pop_from_object(name=queue_name)
        if status != SMOS.SMOS_SUCCESS:
            # nothing to pop: back off briefly and retry
            time.sleep(0.05)
            continue

        indices, new_priority, make_time, worker_node_id = data
        num_workers = len(worker_node_id)
        num_entries = len(indices)
        # integer division: entries beyond num_workers * chunk are not sent
        chunk = num_entries // num_workers
        for slot in range(num_workers):
            lo = slot * chunk
            hi = lo + chunk
            packet = PriorityPacket(indices=indices[lo:hi],
                                    new_priority=new_priority[lo:hi],
                                    make_time=make_time[lo:hi])
            header = f"Node{worker_node_id[slot]}_ao~wu~ao~wu~_".encode()
            priority_socket.send(header + SMOS_utils.serialize(packet))

        # give the entry back to smos
        smos_client.free_handle(object_handle=handle)

        # bookkeeping; the queue size is fetched every 20 batches but not reported
        toc = time.time()
        published += 1
        elapsed += toc - tic
        if published % 20 == 0:
            _, _queue_size = smos_client.get_entry_count(name=queue_name)


def priority_subscriber(worker_node_id, config, storage_config: StorageConfig):
    """
    listens broadcast and updates Replaybuffer
    position: worker
    multiple: no

    :param worker_node_id: only node 0 subscribes; on every other node this function
        just sleeps forever.
    :param config: run configuration, passed to the metadata manager.
    :param storage_config: SMOS / zmq connection settings.

    Does not return.

    NOTE(review): the received packet is decoded and timed, but nothing below applies it
    to the replay buffer or the shared priorities -- confirm whether that is intended.
    """
    if worker_node_id != 0:
        while True:
            time.sleep(10.)
    # replay buffer
    replay_buffer = get_replay_buffer(storage_config)
    # global priorities
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
    global_priorities = meta_data_manager.get("priorities")
    # zmq socket
    zmq_context = zmq.Context()
    priority_socket = zmq_context.socket(zmq.SUB)
    priority_queue_address = f"tcp://{storage_config.zmq_priority_ip_worker}:{storage_config.zmq_priority_port_worker}"
    priority_socket.setsockopt_string(zmq.SUBSCRIBE, "")
    priority_socket.connect(priority_queue_address)
    print(f"[Priority Subscriber] Priority Subscriber has been initialized at process {os.getpid()}, "
          f"port {storage_config.zmq_priority_port_worker}.")

    # subscribe forever
    header_separator = b"_ao~wu~ao~wu~_"
    total_time, total_batches = 0, 0
    while True:
        start_time = time.time()

        # get signal, remove header and reconstruct
        data_stream = priority_socket.recv()
        # maxsplit=1: cut at the first separator only. An unbounded split would truncate
        # the payload if the serialised bytes happened to contain the separator too.
        data_stream = data_stream.split(header_separator, 1)[1]
        priority_packet = SMOS.deserialize(data_stream)

        # logging (the measured time includes the blocking recv)
        end_time = time.time()
        total_time += end_time - start_time
        total_batches += 1
        if total_batches % 20 == 0:
            print('[Priority Subscriber] Avg.={:.2f}, LST.={:.2f}'
                  .format(total_time / total_batches, end_time - start_time))
