import random
import numpy as np
import os
import torch

class ReplayMemory:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Each transition is stored as the tuple
    ``(state, action, reward, next_state, mask, done)``. Once ``capacity``
    transitions have been stored, every new transition overwrites the oldest
    one. When ``store_all_data`` is true, every transition is additionally
    appended to ``buffer_full``, which is never truncated.
    """

    def __init__(self, capacity, seed, store_all_data=False):
        """Create an empty replay memory.

        Args:
            capacity: Maximum number of transitions kept in ``buffer``.
                Must be at least 1.
            seed: Seed passed to ``random.seed``. NOTE: this seeds the
                module-level ``random`` generator, so it affects every other
                user of ``random`` in the process.
            store_all_data: If true, also archive every pushed transition in
                ``buffer_full``.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        # Fail fast: a non-positive capacity would otherwise only surface as
        # an IndexError on the first push.
        if capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity!r}")
        random.seed(seed)
        self.capacity = capacity
        self.buffer = []
        self.buffer_full = []
        self.position = 0  # index in `buffer` the next push writes to
        self.position_full = 0  # number of transitions archived so far
        self.store_all_data = store_all_data

    def push(self, state, action, reward, next_state, mask, done):
        """Store one transition, overwriting the oldest entry once full."""
        transition = (state, action, reward, next_state, mask, done)
        if len(self.buffer) < self.capacity:
            # Still filling up: `position` always equals len(buffer) here.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity
        if self.store_all_data:
            self._push_full_buffer(state, action, reward, next_state, mask, done)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(state, action, reward, next_state, mask)`` in which
            each element is the ``np.stack`` of that field across the sampled
            transitions. The ``done`` field is not returned.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into per-field columns and stack
        # only the five fields that are returned (`done` is dropped).
        columns = list(zip(*batch))[:5]
        state, action, reward, next_state, mask = map(np.stack, columns)
        return state, action, reward, next_state, mask

    def __len__(self):
        """Return the number of transitions currently held in ``buffer``."""
        return len(self.buffer)

    def _push_full_buffer(self, state, action, reward, next_state, mask, done):
        """Archive a transition in ``buffer_full``, which is never truncated."""
        self.buffer_full.append((state, action, reward, next_state, mask, done))
        self.position_full += 1
