"""
Data Cache Utils

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import os
import SharedArray

try:
    from multiprocessing.shared_memory import ShareableList
except ImportError:
    import warnings

    warnings.warn("Please update python version >= 3.8 to enable shared_memory")
import numpy as np


def shared_array(name, var=None):
    if var is not None:
        # check exist
        if os.path.exists(f"/dev/shm/{name}"):
            return SharedArray.attach(f"shm://{name}")
        # create shared_array
        data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype)
        data[...] = var[...]
        data.flags.writeable = False
    else:
        data = SharedArray.attach(f"shm://{name}").copy()
    return data


def shared_dict(name, var=None):
    name = str(name)
    assert "." not in name  # '.' is used as sep flag
    data = {}
    if var is not None:
        assert isinstance(var, dict)
        keys = var.keys()
        # current version only cache np.array
        keys_valid = []
        for key in keys:
            if isinstance(var[key], np.ndarray):
                keys_valid.append(key)
        keys = keys_valid

        ShareableList(sequence=keys, name=name + ".keys")
        for key in keys:
            if isinstance(var[key], np.ndarray):
                data[key] = shared_array(name=f"{name}.{key}", var=var[key])
    else:
        keys = list(ShareableList(name=name + ".keys"))
        for key in keys:
            data[key] = shared_array(name=f"{name}.{key}")
    return data
