import numpy as np

from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer

import torch.multiprocessing as mp
import ctypes

class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer):
    """
    Same as an ObsDictRelabelingBuffer but obs, next_obs, actions and
    terminals are backed by multiprocessing arrays. The replay buffer size
    is also shared. The intended use case is for if one wants these arrays
    to be shared between processes.

    NOTE: the NumPy views are built from ``mp.Array.get_obj()``, i.e. from the
    raw ctypes object, so reads/writes through the views do NOT go through the
    lock that ``mp.Array`` carries. Only ``_size`` (accessed through
    ``mp.Value.value``) is lock-protected. Callers that need consistency
    across processes must coordinate access themselves.

    This code also breaks a lot of functionality for the subprocess. For
    example, random_batch is incorrect as _idx_to_future_obs_idx is not
    shared. If the subprocess needs all of the functionality, a mp.Array
    must be used for all numpy arrays in the replay buffer (see
    ``_register_mp_array``).
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        # Must exist before the parent constructor runs: `_size` is a property
        # on this class, so any assignment to `self._size` made by the parent
        # (presumably it initializes the size there) is stored in this value.
        self._shared_size = mp.Value(ctypes.c_long, 0)
        ObsDictRelabelingBuffer.__init__(self, *args, **kwargs)

        # attribute name -> (mp.Array, dtype, shape) for registered arrays
        self._mp_array_info = {}
        # obs key -> (mp.Array, dtype, shape)
        self._shared_obs_info = {}
        self._shared_next_obs_info = {}

        # Only values of existing keys are replaced, so mutating the dict
        # while iterating over it is safe here.
        for obs_key, obs_arr in self._obs.items():
            # next_obs is allocated with the same dtype/shape as obs.
            self._shared_obs_info[obs_key] = self._allocate_shared(obs_arr)
            self._shared_next_obs_info[obs_key] = self._allocate_shared(obs_arr)

            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(*self._shared_next_obs_info[obs_key])
        self._register_mp_array("_actions")
        self._register_mp_array("_terminals")

    @staticmethod
    def _allocate_shared(arr):
        """
        Allocate a zeroed mp.Array able to hold an array like ``arr``.

        The ctypes element type is derived from ``arr.dtype`` so the shared
        buffer has exactly ``arr.nbytes`` bytes. (Hard-coding ``c_float`` for
        every non-uint8 dtype under-allocates e.g. float64 arrays and makes
        the later ``reshape`` in ``to_np`` fail.)

        :param arr: numpy array whose dtype, size and shape are mirrored.
        :return: tuple ``(shared_array, dtype, shape)`` accepted by ``to_np``.
        :raises NotImplementedError: if numpy has no ctypes equivalent for
            ``arr.dtype``.
        """
        ctype = np.ctypeslib.as_ctypes_type(arr.dtype)
        return mp.Array(ctype, arr.size), arr.dtype, arr.shape

    def _register_mp_array(self, arr_instance_var_name):
        """
        Use this function to register an array to be shared. This will wipe arr.

        The attribute named ``arr_instance_var_name`` is replaced by a
        zero-initialized numpy view over a new mp.Array with the same dtype
        and shape; the previous contents are not copied.
        """
        assert hasattr(self, arr_instance_var_name), arr_instance_var_name
        arr = getattr(self, arr_instance_var_name)

        info = self._allocate_shared(arr)
        self._mp_array_info[arr_instance_var_name] = info
        setattr(self, arr_instance_var_name, to_np(*info))

    def init_from_mp_info(
        self,
        mp_info,
    ):
        """
        The intended use is to have a subprocess serialize/copy a
        SharedObsDictRelabelingBuffer instance and call init_from on the
        instance's shared variables. This can't be done during serialization
        since multiprocessing shared objects can't be serialized and must be
        passed directly to the subprocess as an argument to the fork call.

        :param mp_info: the 4-tuple returned by ``get_mp_info``:
            ``(shared_obs_info, shared_next_obs_info, mp_array_info,
            shared_size)``.
        """
        shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info

        self._shared_obs_info = shared_obs_info
        self._shared_next_obs_info = shared_next_obs_info
        self._mp_array_info = mp_array_info

        # Rebuild the numpy views on top of the received shared arrays.
        for obs_key in self._shared_obs_info:
            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(*self._shared_next_obs_info[obs_key])

        for arr_instance_var_name, info in self._mp_array_info.items():
            setattr(self, arr_instance_var_name, to_np(*info))
        self._shared_size = shared_size

    def get_mp_info(self):
        """
        Return the shared objects backing this buffer, in the order expected
        by ``init_from_mp_info``.
        """
        return (
            self._shared_obs_info,
            self._shared_next_obs_info,
            self._mp_array_info,
            self._shared_size,
        )

    @property
    def _size(self):
        # Read through to the shared counter.
        return self._shared_size.value

    @_size.setter
    def _size(self, size):
        self._shared_size.value = size


def to_np(shared_arr, np_dtype, shape):
    """
    Build a numpy view over a multiprocessing array without copying.

    :param shared_arr: object exposing ``get_obj()`` (e.g. an ``mp.Array``)
        whose underlying buffer is reinterpreted.
    :param np_dtype: dtype used to interpret the raw buffer.
    :param shape: shape the flat buffer is reshaped to.
    :return: numpy array sharing memory with ``shared_arr``.
    """
    raw_buffer = shared_arr.get_obj()
    flat_view = np.frombuffer(raw_buffer, dtype=np_dtype)
    return flat_view.reshape(shape)

