import numpy as np

from .list_replay_buffer import SimpleReplayBuffer


class NumpyReplayBuffer(SimpleReplayBuffer):
    """Replay buffer that stores each field in a preallocated NumPy array.

    Storage is provided through the ``_buf_init`` / ``_buf_add`` /
    ``_buf_sample`` hooks. These are presumably invoked by ``SimpleReplayBuffer``
    (defined elsewhere); ``_buf_init`` is called from its ``__init__``.
    """

    def __init__(self, env, *args, **kwargs):
        """Store ``env`` and forward the remaining arguments to the base class.

        Args:
            env: Environment whose ``observation_space.shape`` and
                ``action_space.shape`` give the per-entry array shapes.
            *args: Passed through unchanged to ``SimpleReplayBuffer``.
            **kwargs: Passed through unchanged to ``SimpleReplayBuffer``.
        """
        # Must be assigned before super().__init__, which calls `_buf_init`,
        # and `_buf_init` reads `self.env`.
        self.env = env
        super().__init__(*args, **kwargs)

    def _buf_init(self, key, max_buf_size):
        """Allocate storage for ``max_buf_size`` entries of field ``key``.

        Args:
            key: One of ``'state'``, ``'action'``, ``'next_state'``,
                ``'reward'``, ``'done'``, ``'timeout'``.
            max_buf_size: Number of entries (length of the first axis).

        Returns:
            An uninitialized ``np.ndarray`` of shape
            ``(max_buf_size, *per_entry_shape)`` with the field's dtype.

        Raises:
            KeyError: If ``key`` is not one of the known fields.
        """
        obs_shape = self.env.observation_space.shape
        # TODO: infer dtypes from the env spaces instead of hard-coding float32.
        # `np.bool_` is used rather than the `np.bool` alias, which was
        # removed in NumPy 1.24 (AttributeError on those versions).
        specs = {
            'state': (np.float32, obs_shape),
            'action': (np.float32, self.env.action_space.shape),
            'next_state': (np.float32, obs_shape),
            'reward': (np.float32, ()),
            'done': (np.bool_, ()),
            'timeout': (np.bool_, ()),
        }
        try:
            dtype, shape = specs[key]
        except KeyError:
            raise KeyError(
                f"unknown replay buffer key {key!r}; expected one of {sorted(specs)}"
            ) from None
        # np.empty does not initialize memory: an entry holds garbage until
        # it has been written via `_buf_add`.
        return np.empty((max_buf_size, *shape), dtype=dtype)

    def _buf_add(self, buf, idx, data):
        """Write ``data`` into ``buf`` at position ``idx`` in place.

        NumPy converts ``data`` to ``buf``'s dtype on assignment.
        """
        buf[idx] = data

    def _buf_sample(self, buf, indices):
        """Return ``buf[indices]``.

        With an integer list/array of indices this is NumPy advanced indexing
        and returns a copy; the type of ``indices`` supplied by the caller is
        not visible here.
        """
        return buf[indices]
