
import random

class ReplayBuffer:
    """Fixed-capacity store of transitions with oldest-first eviction.

    Each entry is the tuple ``(obs, actions, next_obs, rewards, dones)``
    exactly as passed to :meth:`append`. Until ``max_size`` entries are
    held, new entries are appended to ``data``; after that, each new entry
    overwrites the slot at ``position``, which then advances cyclically.
    """

    def __init__(self, size=1000):
        """Create an empty buffer.

        Args:
            size: Maximum number of transitions kept at once. Must be >= 1.

        Raises:
            ValueError: If ``size`` is smaller than 1 (such a buffer could
                never store an entry).
        """
        super().__init__()
        if size < 1:
            raise ValueError(f"size must be at least 1, got {size}")
        self.data = []
        self.max_size = size
        # Index of the next slot to overwrite once the buffer is full.
        # It must start at 0 (the oldest entry): starting at -1 would make
        # the first overwrite clobber the most recently added transition.
        self.position = 0

    def append(self, obs, actions, next_obs, rewards, dones):
        """Store one transition, evicting the oldest one when full."""
        transition = (obs, actions, next_obs, rewards, dones)
        if len(self.data) < self.max_size:
            self.data.append(transition)
            return
        # Buffer is full: overwrite the oldest slot and advance the cursor.
        self.data[self.position] = transition
        self.position = (self.position + 1) % self.max_size

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions, chosen at random.

        Sampling is without replacement (``random.sample``).

        Raises:
            ValueError: If ``batch_size`` is negative or exceeds the number
                of stored transitions.
        """
        return random.sample(self.data, batch_size)

    def clear(self):
        """Drop all stored transitions and reset the overwrite cursor."""
        self.data = []
        self.position = 0

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.data)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.data)
        
    