import SMOS
import SMOS_utils
import os
from core.node_config import ALL_NODE_CONFIGS

# Select the cluster layout for this process. NODE_CONFIG_ID must be present in
# the environment and parse as an int; it is used as the key/index into
# ALL_NODE_CONFIGS.
NODE_CONFIG_ID = int(os.environ["NODE_CONFIG_ID"])
node_config = ALL_NODE_CONFIGS[NODE_CONFIG_ID]

# First port of the ZMQ port range that listening sockets bind to
# (the "*"-addressed endpoints in StorageConfig).
BASE_INSIDE_PORT = 5010
# First port of the same range as dialed by remote peers. It currently equals
# BASE_INSIDE_PORT; an earlier value was 10002 (presumably for a deployment
# with port forwarding -- confirm before changing).
BASE_OUTSIDE_PORT = 5010

class StorageConfig:
    def __init__(self, label="", worker_id=-1):
        """
        Storage config contains everything needed for accessing storages
        used in EfficientZero_smos.

        :param label: optional prefix for every storage/queue name. When
            non-empty it is joined to each name with a "-" separator.
        :param worker_id: index of this node in ``node_config.worker_ip_list``.
            A negative value (the default) keeps the default multi-card device
            layout below. A non-negative value selects the data-worker or the
            reanalyze-worker layout, depending on membership in
            ``node_config.data_worker_ids``.
        :raises ValueError: for a non-data worker node when
            ``node_config.data_worker_ids`` is empty, since there is then no
            data worker to sync the replay buffer from.
        """
        # device allocation
        # Note that n DDP trainers will always occupy the first n cards!!
        self.gpu_worker_visible_devices = [0, 1, 2, 3, 4, 5, 6, 7]
        self.num_gpu_works_per_device = [6, 6, 6, 6, 6, 6, 6, 6]
        self.data_worker_visible_devices = [0, 1, 2]
        self.value_updater_visible_devices = [3, 4, 5, 6, 7]
        self.test_visible_device = 1

        # Worker-node overrides. data_worker_address stays None unless this
        # node is a non-data worker, which then needs a data worker to sync from.
        self.is_data_worker = False
        data_worker_address = None
        if worker_id >= 0:
            self.is_data_worker = worker_id in node_config.data_worker_ids
            if self.is_data_worker:
                # Data worker node: two cards, no GPU workers; card 0 runs the
                # data worker and card 1 the value updater.
                self.gpu_worker_visible_devices = [0, 1]
                self.num_gpu_works_per_device = [0, 0]
                self.data_worker_visible_devices = [0]
                self.value_updater_visible_devices = [1]
            else:
                # Reanalyze worker node: a single card hosting 6 GPU workers.
                # Its replay buffer is synced from a data worker chosen
                # round-robin (see _data_worker_address).
                self.gpu_worker_visible_devices = [0]
                self.num_gpu_works_per_device = [6]
                data_worker_address = self._data_worker_address(
                    worker_id, node_config.worker_ip_list, node_config.data_worker_ids)

        # Connections for the local storage servers. get_local_free_port is
        # assumed to return 4 currently-free local ports in [6000, 7000);
        # its exact semantics live in SMOS_utils.
        free_port_list = SMOS_utils.get_local_free_port(4, 6000, 7000)
        self.shared_storage_connection = SMOS.ConnectionDescriptor(ip="localhost", port=free_port_list[0],
                                                                   authkey=b"speedyzero")
        self.replay_buffer_connection = SMOS.ConnectionDescriptor(ip="localhost", port=free_port_list[1],
                                                                  authkey=b"speedyzero")
        self.watchdog_server_connection = SMOS.ConnectionDescriptor(ip="localhost", port=free_port_list[2],
                                                                    authkey=b"speedyzero")
        self.smos_connection = SMOS.ConnectionDescriptor(ip="localhost", port=free_port_list[3],
                                                         authkey=b"speedyzero")

        if label:
            label = label + "-"
        self.label = label

        # name and number of SharedMemoryObject for each storage
        # replay buffer
        # name of replay buffer
        self.replay_buffer_name = label + "replaybuffer"
        # max number of entries in replay buffer
        self.replay_buffer_capacity = 16384
        # max mcts child count, used for determine block size
        self.max_child_visits_count = 16
        # block size list (max_len = 400)
        self.replay_buffer_block_size_list = [4 * (1024 ** 2), 4096, 4096, 4096, 4096, 4096]
        # name of zombie queue (garbage collection)
        self.zombie_queue_name = label + "zombie"
        # max number of entries in zombie queue
        self.zombie_queue_capacity = 1024 * 128
        # block size of each entry in zombie queue
        self.zombie_queue_block_size = 32

        # mcts storage (cpu -> gpu queue)
        # name of mcts storage
        self.mcts_storage_name = label + "mcts_storage"
        # number of mcts storages
        self.mcts_storage_count = 1
        # capacity of each mcts storage
        self.mcts_storage_capacity = 36
        # block size of each entry in mcts storage
        self.mcts_block_size_list = [256 * (1024 ** 2), 32 * (1024 ** 2), 256 * (1024 ** 2), 32 * (1024 ** 2),
                                     32 * (1024 ** 2), 256 * (1024 ** 2), 32 * (1024 ** 2), 32 * 1024, 32 * 1024]

        # priority storage
        self.priority_storage_name = label + "priority"
        # number of priority storage
        self.priority_storage_count = 4
        # capacity of each priority storage
        self.priority_storage_capacity = 64
        # block size of each entry in priority storage
        self.priority_block_size_list = [32 * (1024 ** 2)]

        # batch storage (gpu -> training queue)
        # name of batch storage
        self.batch_storage_name = label + "batch"
        # number of batch storages
        self.batch_storage_count_trainer = 1
        self.batch_storage_count_worker = 8
        # capacity of each batch storage
        self.batch_storage_capacity_trainer = 20
        self.batch_storage_capacity_worker = 16
        # block size for each entry in batch storage
        self.batch_storage_block_size_list = [512 * (1024 ** 2), 512 * 1024, 512 * 1024, 32 * 1024]
        self.batch_storage_block_size_list_worker = [512 * (1024 ** 2), 512 * 1024, 512 * 1024, 32 * 1024, 1024, 1024]

        # weights storage
        self.weights_storage_name = label + "weights"
        self.weights_storage_cycle_size = 8

        # reanalyze control storage (same capacity as the mcts storage)
        self.reanalyze_control_storage_name = label + "reanalyze_control"
        self.reanalyze_control_storage_capacity = self.mcts_storage_capacity

        # Cross-node (ZMQ) storages. Port layout relative to the base port:
        #   base .. base + batch_receiver_count - 1  -> batch receivers
        #   base + batch_receiver_count              -> training signal
        #   base + batch_receiver_count + 1          -> priority
        # "*" endpoints are the listening side (BASE_INSIDE_PORT); the
        # address-specific endpoints are the dialing side (BASE_OUTSIDE_PORT).
        self.trainer_address = node_config.trainer_address  # e.g. machine 19
        # batch related
        self.per_worker_batch_sender_count = 16
        self.batch_receiver_count = 4
        self.zmq_batch_ip_trainer = "*"
        self.zmq_batch_port_list_trainer = list(range(BASE_INSIDE_PORT, BASE_INSIDE_PORT + self.batch_receiver_count))
        self.zmq_batch_ip_worker = self.trainer_address
        self.zmq_batch_port_list_worker = list(range(BASE_OUTSIDE_PORT, BASE_OUTSIDE_PORT + self.batch_receiver_count))
        # training signal related: first port after the batch receivers
        self.zmq_signal_ip_trainer = "*"
        self.zmq_signal_port_trainer = BASE_INSIDE_PORT + self.batch_receiver_count
        self.zmq_signal_ip_worker = self.trainer_address
        self.zmq_signal_port_worker = BASE_OUTSIDE_PORT + self.batch_receiver_count
        self.signal_queue_name = label + "signal_queue"
        self.signal_queue_capacity = 16
        self.signal_queue_block_size_list = [128 * (1024 ** 2)]
        # priority related: the port right after the signal port
        self.zmq_priority_ip_trainer = "*"
        self.zmq_priority_port_trainer = BASE_INSIDE_PORT + 1 + self.batch_receiver_count
        self.zmq_priority_ip_worker = self.trainer_address
        self.zmq_priority_port_worker = BASE_OUTSIDE_PORT + 1 + self.batch_receiver_count
        self.priority_queue_name = label + "priority_queue"
        self.priority_queue_capacity = 64
        self.priority_queue_block_size_list = [128 * (1024 ** 2), 128 * (1024 ** 2), 128 * (1024 ** 2), 128]

        # replay buffer sync: the master side listens, and non-data workers
        # dial their assigned data worker (data_worker_address is None otherwise).
        # NOTE(review): the master port equals the first batch-receiver port;
        # this only works if the sync master and the trainer never share a host.
        self.zmq_replay_buffer_ip_master = "*"
        self.zmq_replay_buffer_port_master = BASE_INSIDE_PORT
        self.zmq_replay_buffer_ip_worker = data_worker_address
        self.zmq_replay_buffer_port_worker = BASE_OUTSIDE_PORT

        # reanalyze speed control
        self.batches_per_10sec = 45

        # virtual test
        self.compressed_batch_path = "/workspace/batch.pt"

    @staticmethod
    def _data_worker_address(worker_id, worker_ip_list, data_worker_ids):
        """Pick the data worker IP a non-data worker syncs its replay buffer from.

        Non-data workers are ranked by their position in ``worker_ip_list``
        (1-based, counting only entries not in ``data_worker_ids``). The data
        workers are then assigned round-robin by that rank. The chosen data
        worker id is mapped back through ``worker_ip_list`` to its IP.

        :param worker_id: index of the calling (non-data) worker.
        :param worker_ip_list: IP of every worker, indexed by worker id.
        :param data_worker_ids: ids (indices into ``worker_ip_list``) of the
            data workers.
        :return: IP of the assigned data worker.
        :raises ValueError: if ``data_worker_ids`` is empty.
        """
        if not data_worker_ids:
            raise ValueError("node_config.data_worker_ids is empty; "
                             "non-data workers need a data worker to sync the replay buffer from")
        # 1-based rank of this worker among the non-data workers. The 1-based
        # offset is kept so existing worker -> data worker pairings are unchanged.
        rank = sum(1 for i in range(min(worker_id + 1, len(worker_ip_list)))
                   if i not in data_worker_ids)
        # Index through data_worker_ids so the result is a data worker's IP even
        # when the data workers are not the first entries of worker_ip_list.
        return worker_ip_list[data_worker_ids[rank % len(data_worker_ids)]]
