import jax
import jax.numpy as jnp


class ReplayBuffer:
    """Fixed-capacity FIFO replay buffer of single transitions.

    Every transition is stored as one flat 1-D array laid out as
    ``[state, action, next_state, reward, absorbing]``. ``state`` and
    ``next_state`` each occupy ``len(sample)`` slots; ``action``, ``reward``
    and ``absorbing`` occupy one slot each. All fields therefore share the
    dtype that ``jnp.concatenate`` promotes them to. Once ``buffer_size``
    transitions are stored, each new one overwrites the oldest.
    """

    def __init__(self, buffer_size, batch_size, sample):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of transitions kept; must be >= 1.
            batch_size: Default number of transitions returned by ``sample``.
            sample: An example observation. Only ``len(sample)`` is used, to
                know how many slots a state occupies in a stored transition.

        Raises:
            ValueError: If ``buffer_size`` is smaller than 1.
        """
        if buffer_size < 1:
            raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")
        self.buffer_size = buffer_size
        self.buffer = []
        self.batch_size = batch_size
        # Next slot to overwrite once the buffer is full. Appends fill slots
        # 0..buffer_size-1 in order, so this always points at the oldest entry.
        self.idx = 0
        self.len_observation_space = len(sample)

    def add(self, state, action, next_state, reward, absorbing):
        """Store one transition, evicting the oldest one if the buffer is full.

        Args:
            state: 1-D array with ``len_observation_space`` entries.
            action: Scalar (or single-element) action.
            next_state: 1-D array with ``len_observation_space`` entries.
            reward: Scalar (or single-element) reward.
            absorbing: Scalar (or single-element) terminal flag.

        Raises:
            ValueError: If ``state`` or ``next_state`` does not have shape
                ``(len_observation_space,)``. A mismatched length would
                otherwise be stored silently and mis-sliced by ``sample``.
        """
        expected_shape = (self.len_observation_space,)
        for field_name, observation in (("state", state), ("next_state", next_state)):
            if jnp.shape(observation) != expected_shape:
                raise ValueError(
                    f"{field_name} must have shape {expected_shape}, "
                    f"got {jnp.shape(observation)}"
                )
        experience = jnp.concatenate(
            [
                state,
                jnp.reshape(jnp.asarray(action), (1,)),
                next_state,
                jnp.reshape(jnp.asarray(reward), (1,)),
                jnp.reshape(jnp.asarray(absorbing), (1,)),
            ],
            axis=0,
        )
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(experience)
        else:
            self.buffer[self.idx] = experience
            self.idx = (self.idx + 1) % self.buffer_size

    def sample(self, key, batch_size=None):
        """Draw transitions uniformly at random, with replacement.

        Args:
            key: JAX PRNG key used to draw the indices.
            batch_size: Number of transitions to draw; defaults to
                ``self.batch_size``.

        Returns:
            Tuple ``(states, actions, next_states, rewards, absorbings)``.
            ``states``/``next_states`` have shape
            ``(batch_size, len_observation_space)``. The other three have shape
            ``(batch_size,)``; ``actions`` and ``absorbings`` are cast to int.

        Raises:
            ValueError: If the buffer is empty.
        """
        batch_size = self.batch_size if batch_size is None else batch_size
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        indices = jax.random.randint(
            key,
            shape=(batch_size,),
            minval=0,
            maxval=len(self.buffer),
        )
        # Move all indices to the host in one transfer. Indexing the list with
        # JAX scalars would trigger a separate device->host sync per element.
        rows = [self.buffer[i] for i in indices.tolist()]
        if rows:
            batch = jnp.stack(rows)
        else:
            # batch_size == 0: keep the 2-D layout so the slicing below works.
            template = self.buffer[0]
            batch = jnp.zeros((0,) + template.shape, dtype=template.dtype)

        obs_len = self.len_observation_space
        states = batch[:, :obs_len]
        actions = batch[:, obs_len].astype(int)
        next_states = batch[:, obs_len + 1 : 2 * obs_len + 1]
        rewards = batch[:, 2 * obs_len + 1]
        absorbings = batch[:, 2 * obs_len + 2].astype(int)
        return states, actions, next_states, rewards, absorbings

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
