import numpy as np
import threading


class ReplayBuffer:
    """Bounded FIFO store of items with uniform random batch sampling.

    Items are appended by ``store_episode``. Once more than ``size`` items are
    held, the oldest ones are discarded. ``sample`` expects stored items to be
    mappings, because it reads their ``keys()``/``items()``. Presumably there is
    one dict per transition; confirm this against the callers.
    """

    def __init__(self, buffer_size=1024):
        """Create an empty buffer.

        Args:
            buffer_size: maximum number of items retained.
        """
        self.size = buffer_size

        # Number of items currently stored; kept equal to len(self.buffers).
        self.current_size = 0
        # Stored items, oldest first.
        self.buffers = []
        # Guards self.buffers / self.current_size in both store and sample.
        self.lock = threading.Lock()

    def store_episode(self, episode_batch):
        """Append every item of ``episode_batch``, evicting the oldest overflow.

        Args:
            episode_batch: iterable of items (mappings) to append.
        """
        with self.lock:
            self.buffers += episode_batch
            # Compare the *new* length. Checking the stale current_size let
            # the buffer exceed capacity until the following call.
            overflow = len(self.buffers) - self.size
            if overflow > 0:
                # Explicit start index (not buffers[-size:]) so size == 0
                # empties the buffer instead of keeping everything.
                del self.buffers[:overflow]
            self.current_size = len(self.buffers)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored items uniformly at random.

        Args:
            batch_size: number of items to draw, without replacement.

        Returns:
            A dict mapping each key of the first sampled item to the list of
            that key's values across the batch. Returns ``{}`` when
            ``batch_size`` is 0.

        Raises:
            ValueError: raised by numpy when ``batch_size`` is larger than the
                number of stored items (sampling is without replacement).
        """
        # Snapshot the selection under the lock so the index range and the
        # list contents cannot disagree with a concurrent store_episode.
        with self.lock:
            indices = np.random.choice(self.current_size, batch_size, replace=False)
            samples = [self.buffers[idx] for idx in indices]

        if not samples:
            return {}

        temp_buffer = {key: [] for key in samples[0].keys()}
        for sample in samples:
            for key, value in sample.items():
                temp_buffer[key].append(value)

        return temp_buffer


