import numpy as np
import random


class ReplayBuffer:
    """Fixed-capacity circular store of transitions with uniform sampling.

    Each entry is a tuple ``(state, action, reward, next_state, gamma, other)``
    held in a NumPy object array. Once ``n_samples`` entries have been
    written, new entries overwrite the oldest slot.
    """

    def __init__(self, n_samples=1000000, n_batch=32):
        """Create an empty buffer.

        Args:
            n_samples: Maximum number of transitions kept at once.
            n_batch: Number of transitions drawn by each ``replay()`` call.
        """
        self.n_samples = n_samples
        self.n_batch = n_batch
        # Allocate storage right away so append()/replay()/all() work on a
        # freshly constructed buffer without requiring an explicit reset().
        self.reset()

    def reset(self):
        """Discard all stored transitions and reallocate the storage."""
        self.buffer = np.empty(self.n_samples, dtype=object)
        self.index = 0  # slot the next append() writes to
        self.size = 0   # number of valid entries, capped at n_samples

    def replay(self):
        """Draw a random batch of ``n_batch`` stored transitions.

        Indices are drawn independently and uniformly from the filled part
        of the buffer, so the same transition may appear more than once.

        Returns:
            ``None`` while fewer than ``n_batch`` transitions are stored;
            otherwise a tuple ``(states, actions, rewards, next_states,
            gammas)`` of arrays. States, rewards and next states are stacked
            row-wise with ``np.vstack``; actions and gammas are converted
            with ``np.array``. The ``other`` field is not returned.
        """
        if self.size < self.n_batch:
            return None
        indices = np.random.randint(low=0, high=self.size, size=(self.n_batch,))
        # Transpose the sampled tuples into per-field sequences.
        states, actions, rewards, next_states, gammas, _ = zip(*self.buffer[indices])
        states = np.vstack(states)
        actions = np.array(actions)
        rewards = np.vstack(rewards)
        next_states = np.vstack(next_states)
        gammas = np.array(gammas)
        return states, actions, rewards, next_states, gammas

    def append(self, state, action, reward, next_state, gamma, other=None):
        """Store one transition, overwriting the oldest slot when full.

        Args:
            state, action, reward, next_state, gamma: Transition fields,
                stored as given.
            other: Optional extra payload kept with the transition; it is
                only visible through ``all()``.
        """
        self.buffer[self.index] = (state, action, reward, next_state, gamma, other)
        self.size = min(self.size + 1, self.n_samples)
        # Wrap around so the next write lands on the oldest slot.
        self.index = (self.index + 1) % self.n_samples

    def all(self):
        """Return the filled portion of the storage array.

        Entries are in slot order, which stops being insertion order once
        the buffer has wrapped around.
        """
        return self.buffer[:self.size]
