import numpy as np

from bgp.rlkit.data_management.replay_buffer import ReplayBuffer


class SimpleReplayBuffer(ReplayBuffer):
    """Fixed-capacity circular replay buffer backed by pre-allocated NumPy arrays.

    Observations, next observations and actions are stored as flat rows whose
    width is the product of ``observation_dim`` / ``action_dim``. Once
    ``max_replay_buffer_size`` transitions have been written, new transitions
    overwrite the oldest ones.
    """

    def __init__(
            self, max_replay_buffer_size, observation_dim, action_dim,
    ):
        """Allocate storage for ``max_replay_buffer_size`` transitions.

        Args:
            max_replay_buffer_size: Number of transitions the buffer can hold.
            observation_dim: Observation size, either an int or a shape
                tuple/list (flattened to the product of its entries).
            action_dim: Action size, either an int or a shape tuple/list
                (flattened to the product of its entries).
        """
        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._max_replay_buffer_size = max_replay_buffer_size

        # Rows are flat because add_sample flattens every observation and
        # action before storing it.
        # NOTE(review): these widths used to be hard-coded to 96 (obs) and
        # 1 (action); callers must pass dims whose product matches their data.
        flat_obs_dim = self._flat_dim(observation_dim)
        flat_action_dim = self._flat_dim(action_dim)
        self._observations = np.zeros((max_replay_buffer_size, flat_obs_dim))
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros((max_replay_buffer_size, flat_obs_dim))
        self._actions = np.zeros((max_replay_buffer_size, flat_action_dim))

        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
        # _top: next row to write; _size: number of valid rows (<= capacity).
        self._top = 0
        self._size = 0

    @staticmethod
    def _flat_dim(dim):
        """Return the flattened length of one sample whose size/shape is ``dim``."""
        if isinstance(dim, (tuple, list)):
            return int(np.prod(dim, dtype=np.int64))
        return dim

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        """Store a batch of transitions, one per element of the arguments.

        The arguments are iterated in lockstep with ``zip``, so iteration stops
        at the shortest one. Observations, next observations and actions are
        flattened before being written. Extra keyword arguments are ignored.
        """
        for o, a, r, t, n in zip(observation, action, reward, terminal,
                                 next_observation):
            self._observations[self._top] = np.ravel(o)
            self._actions[self._top] = np.ravel(a)
            self._rewards[self._top] = r
            self._terminals[self._top] = t
            self._next_obs[self._top] = np.ravel(n)
            self._advance()

    def terminate_episode(self):
        """Do nothing; terminals are already recorded per transition."""
        pass

    def _advance(self):
        """Move the write cursor forward, wrapping at capacity."""
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def random_batch(self, batch_size):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns:
            dict with keys ``observations``, ``actions``, ``rewards``,
            ``terminals`` and ``next_observations``; each value is a 2-D array
            whose first axis has length ``batch_size``.

        Raises:
            ValueError: if the buffer holds no transitions yet.
        """
        if self._size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        # Rows [0, _size) are always valid: writes start at row 0 and only
        # wrap around once the buffer is full.
        indices = np.random.randint(0, self._size, batch_size)
        return dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=self._next_obs[indices],
        )

    def num_steps_can_sample(self):
        """Return how many transitions are currently stored (at most capacity)."""
        return self._size
