import torch
import numpy as np


class ReplayMem:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Transitions are stored column-wise in preallocated float32 NumPy arrays.
    Once ``capacity`` entries have been written, each new entry overwrites
    the oldest one.
    """

    def __init__(self, do, da, capacity=int(1e6), device='cpu'):
        """Allocate the storage arrays.

        Args:
            do: Width of each stored observation row (used for both ``o``
                and ``op``).
            da: Width of each stored action row.
            capacity: Maximum number of transitions kept at once.
            device: Device passed to ``torch.as_tensor`` when sampling.
        """
        self.o_buffer = np.zeros((capacity, do), np.float32)
        self.op_buffer = np.zeros((capacity, do), np.float32)
        self.a_buffer = np.zeros((capacity, da), np.float32)
        self.r_buffer = np.zeros((capacity,), np.float32)
        self.d_buffer = np.zeros((capacity,), np.float32)
        self.device = device
        self.size = 0  # number of valid rows, always <= capacity
        self.p = 0  # index of the next row to be written
        self.capacity = capacity

    def add(self, o, a, r, op, d):
        """Store one transition, overwriting the oldest when full.

        Args:
            o: Observation row of width ``do``.
            a: Action row of width ``da``.
            r: Scalar reward.
            op: Second observation row of width ``do`` (presumably the
                observation that followed ``o``).
            d: Flag stored as ``float(d)`` (presumably episode termination).
        """
        self.o_buffer[self.p] = o
        self.a_buffer[self.p] = a
        self.r_buffer[self.p] = r
        self.op_buffer[self.p] = op
        self.d_buffer[self.p] = float(d)
        # Advance the cursor, wrapping to the start once the end is reached.
        self.p = (self.p + 1) % self.capacity
        if self.size < self.capacity:
            self.size += 1

    def sample(self, n):
        """Draw ``n`` transitions uniformly at random, with replacement.

        Args:
            n: Number of transitions to draw.

        Returns:
            A pair ``(data, None)`` where ``data`` is the tuple
            ``(o, a, r, op, d)`` of tensors on ``self.device``. The second
            element is always ``None`` (presumably a placeholder kept for
            interface parity with other buffers; confirm against callers).

        Raises:
            ValueError: Raised by ``np.random.randint`` when samples are
                requested from an empty memory.
        """
        idxes = np.random.randint(0, self.size, size=n)
        data = (
            torch.as_tensor(self.o_buffer[idxes], device=self.device),
            torch.as_tensor(self.a_buffer[idxes], device=self.device),
            torch.as_tensor(self.r_buffer[idxes], device=self.device),
            torch.as_tensor(self.op_buffer[idxes], device=self.device),
            torch.as_tensor(self.d_buffer[idxes], device=self.device),
        )
        return data, None

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

    def clear(self):
        """Forget all stored transitions without reallocating the arrays."""
        self.size = 0
        # The write cursor must restart at 0 as well: sample() draws indices
        # from [0, size), so new rows have to land in that prefix. Leaving
        # the cursor where it was would make sample() return stale rows.
        self.p = 0

