import numpy as np
from multiprocessing import shared_memory

# Number of elements allocated for each per-entry replay-buffer array below.
MAX_LENGTH = 400000


class MetaDataSharedMemoryManager:
    """Named shared-memory segments exposed as fixed-size NumPy arrays.

    One process constructs the manager with ``create=True`` to allocate the
    segments; other processes construct it with ``create=False`` to attach to
    the same segments by name. Each item is available through :meth:`get` as
    an ``np.ndarray`` backed directly by the segment's buffer.

    Call :meth:`close` (or use the manager as a context manager) when done;
    the creating manager also unlinks the segments on close.

    NOTE(review): on Python < 3.13, attaching with ``create=False`` registers
    the segment with multiprocessing's resource tracker, which may unlink it
    when the attaching process exits — confirm callers account for this.
    """

    def __init__(self, config, storage_config, create=False, *, prefix=""):
        """Create or attach every segment in the fixed layout.

        Args:
            config: Currently unused; kept for interface compatibility.
            storage_config: Currently unused; kept for interface compatibility.
            create: If true, create (and later unlink) the segments; otherwise
                attach to segments that must already exist.
            prefix: String prepended to every segment name so independent
                managers can coexist on one machine. Defaults to ``""``
                (segment name == item name).

        Raises:
            FileExistsError: ``create`` is true and a segment already exists.
            FileNotFoundError: ``create`` is false and a segment is missing.
        """
        self.shm = dict()
        self.arr = dict()

        self.__create = create
        self._prefix = prefix

        # (item name, element count, dtype), registered in this order.
        layout = (
            ("weights_location", 1, np.int64),
            ("trained_steps", 1, np.int64),
            # replay buffer
            ("priorities", MAX_LENGTH, np.float32),
            ("rewards", MAX_LENGTH, np.float32),
            ("values", MAX_LENGTH, np.float32),
            ("target_values", MAX_LENGTH, np.float32),
            ("death_masks", MAX_LENGTH, np.float32),
            ("game_ids", MAX_LENGTH, np.float32),
            ("valid_entries", MAX_LENGTH, np.float32),
        )
        try:
            for item_name, size, dtype in layout:
                self.register_item(item_name, size, dtype)
        except BaseException:
            # Release whatever was already mapped (and unlink it when we are
            # the creator) so a failed construction does not leak segments.
            self.close()
            raise

    def register_item(self, item_name, size, dtype):
        """Create or attach the segment for ``item_name`` and wrap it as an array.

        Args:
            item_name: Key used with :meth:`get`; the segment name is
                ``prefix + item_name``.
            size: Element count (int) or array shape (sequence of ints).
            dtype: NumPy dtype of the array.

        Raises:
            ValueError: ``item_name`` is already registered.
        """
        if item_name in self.shm:
            raise ValueError(f"item {item_name!r} is already registered")
        shape = (size,) if isinstance(size, (int, np.integer)) else tuple(size)
        # Compute the byte size directly instead of allocating a throwaway
        # zero-filled array just to read its ``nbytes``.
        nbytes = int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize
        shm = shared_memory.SharedMemory(
            create=self.__create, name=self._prefix + item_name, size=nbytes
        )
        # Record the segment before building the view so close() still
        # releases it if the view cannot be built (e.g. segment too small).
        self.shm[item_name] = shm
        self.arr[item_name] = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

    def get(self, item_name):
        """Return the shared array for ``item_name``, or ``None`` if unknown."""
        return self.arr.get(item_name, None)

    def close(self):
        """Release every segment; the creating manager also unlinks them.

        Safe to call more than once. Every segment is attempted even if one
        fails; the first error is re-raised afterwards. Arrays previously
        returned by :meth:`get` must not be used after this call.

        Raises:
            BufferError: a segment's buffer is still exported (e.g. by a view
                the caller is still holding).
            OSError: a segment could not be unlinked.
        """
        # Drop our own views before the buffers are released/unmapped: a live
        # view would otherwise either block SharedMemory.close() or be left
        # pointing at unmapped memory via get().
        self.arr.clear()
        errors = []
        for shm in self.shm.values():
            try:
                shm.close()
            except BufferError as err:
                errors.append(err)
            if self.__create:
                # Unlink even if close() failed so the segment name is not leaked.
                try:
                    shm.unlink()
                except OSError as err:
                    errors.append(err)
        # Forget the segments so a repeated close() is a no-op.
        self.shm.clear()
        if errors:
            raise errors[0]

    def __close__(self):
        """Backward-compatible alias for :meth:`close`."""
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()