from typing import List

import numpy as np


class ReplayMemory:
    """Fixed-capacity buffer of transitions with random batch sampling.

    Transitions are appended until ``capacity`` items are stored; after that
    each new transition overwrites the oldest stored one (ring-buffer
    behaviour driven by ``ptr``).
    """

    def __init__(self, capacity: int, batch_size: int) -> None:
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept. Must be positive.
            batch_size: Number of transitions returned by each ``sample()``.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        # A zero capacity would make the very first insert index into an
        # empty list, and a negative one would let the buffer grow forever.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.storage = []
        self.capacity = capacity
        self.batch_size = batch_size
        # Index of the slot that will be overwritten next once the buffer is
        # full; it stays at 0 while the buffer is still filling up.
        self.ptr = 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.storage)

    def add(self, transistions) -> None:
        """Store one transition, or every transition in a list.

        Only ``list`` instances are treated as a batch; any other object
        (including a tuple) is stored as a single transition.
        """
        if isinstance(transistions, list):
            for transition in transistions:
                self._add(transition)
        else:
            self._add(transistions)

    def put(self, transistions) -> None:
        """Alias of ``add``."""
        self.add(transistions)

    def _add(self, transistion) -> None:
        """Insert a single transition, overwriting the oldest one when full."""
        if len(self.storage) < self.capacity:
            self.storage.append(transistion)
            return
        # Buffer is full: replace the slot at ``ptr`` and advance circularly.
        # The int() cast is kept from the original in case ``ptr`` was
        # assigned a non-int value from outside this class.
        self.storage[int(self.ptr)] = transistion
        self.ptr = (self.ptr + 1) % self.capacity

    def get(self) -> list:
        """Alias of ``sample``."""
        return self.sample()

    def sample(self) -> list:
        """Return ``batch_size`` stored transitions drawn at random.

        Indices are drawn independently with ``np.random.randint``, so the
        same transition may appear more than once in a batch.

        Raises:
            ValueError: If the buffer is empty.
        """
        size = len(self.storage)
        if size == 0:
            raise ValueError("cannot sample from an empty ReplayMemory")
        indices = np.random.randint(0, size, size=self.batch_size)
        return [self.storage[i] for i in indices]
