import numpy as np
import operator


class ReplayBuffer(object):
    """Fixed-capacity store of transitions with uniform random sampling.

    Each transition is a tuple whose first five fields are, in order,
    ``(observation, new_observation, action, reward, terminal)``. An optional
    sixth field is exposed by :meth:`sample` under the key ``'behaviors'``.

    While the buffer is below capacity, transitions are appended. Once it is
    full, ``storage`` is used as a ring: ``ptr`` is the index of the oldest
    transition, which is the next one to be overwritten.
    """

    # Dictionary keys used by ``sample``, in positional order of the fields.
    _SAMPLE_KEYS = (
        'observations',
        'new_observations',
        'actions',
        'rewards',
        'terminals',
        'behaviors',
    )

    def __init__(self, max_size=1e6):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of transitions kept. Floats such as
                ``1e6`` are accepted and truncated to an integer.

        Raises:
            ValueError: If ``max_size`` is smaller than 1.
        """
        capacity = int(max_size)
        if capacity < 1:
            raise ValueError(f"max_size must be at least 1, got {max_size!r}")
        self.storage = []
        # Stored as an int so that ``ptr`` stays a valid list index.
        self.max_size = capacity
        self.ptr = 0

    def add(self, data):
        """Insert one transition, overwriting the oldest one when full."""
        # The casts keep indexing valid even if a caller assigned a float to
        # ``max_size`` after construction.
        capacity = int(self.max_size)
        if len(self.storage) >= capacity:
            slot = int(self.ptr)
            self.storage[slot] = data
            self.ptr = (slot + 1) % capacity
        else:
            self.storage.append(data)

    def reset(self):
        """Drop every stored transition and rewind the write pointer."""
        self.storage = []
        self.ptr = 0

    def get_all_data(self):
        """Return the first five fields of every stored transition as arrays.

        Returns:
            A 5-tuple ``(x, y, u, r, d)`` of freshly allocated numpy arrays in
            storage order, or five ``None`` values when the buffer is empty.
            Any field beyond the fifth is not returned.
        """
        if not self.storage:
            return None, None, None, None, None

        columns = list(zip(*self.storage))
        x, y, u, r, d = columns[:5]

        # np.array already allocates new memory; no extra copy is required.
        return np.array(x), np.array(y), np.array(u), np.array(r), np.array(d)

    def add_batch_data(self, x, y, u, r, d, o=None):
        """Add transitions given as parallel per-field sequences.

        Args:
            x, y, u, r, d: Equal-length sequences holding, per transition,
                the observation, new observation, action, reward and
                terminal flag.
            o: Optional sequence supplying the sixth field of each
                transition (returned by ``sample`` as ``'behaviors'``).
        """
        columns = (x, y, u, r, d) if o is None else (x, y, u, r, d, o)
        # TODO: more efficient way: make storage as a fixed size and add everything
        for transition in zip(*columns):
            self.add(transition)

    def add_batch(self, batch):
        """Append an iterable of ready-made transitions, evicting the oldest.

        After the call ``storage`` is in chronological order (oldest first)
        and ``ptr`` is reset to 0, so later ``add`` calls keep overwriting
        the oldest transition.
        """
        capacity = int(self.max_size)
        head = int(self.ptr)
        # Unroll the ring so index 0 is the oldest entry; this is a no-op
        # while the buffer has never wrapped (ptr == 0).
        ordered = self.storage[head:] + self.storage[:head]
        ordered.extend(batch)

        overflow = len(ordered) - capacity
        if overflow > 0:
            del ordered[:overflow]

        self.storage = ordered
        self.ptr = 0

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            ``None`` if ``batch_size`` is not positive or exceeds the number
            of stored transitions. Otherwise a dict of numpy arrays keyed by
            ``'observations'``, ``'new_observations'``, ``'actions'``,
            ``'rewards'`` and ``'terminals'``, plus ``'behaviors'`` when the
            sampled transitions carry a sixth field.

        Raises:
            ValueError: If the sampled transitions have fewer than five fields.
        """
        if batch_size <= 0 or len(self.storage) < batch_size:
            return None

        ind = np.random.randint(0, len(self.storage), size=batch_size)
        # Index explicitly: operator.itemgetter returns a bare item instead of
        # a tuple when given a single index, which breaks batch_size == 1.
        rows = [self.storage[i] for i in ind]
        columns = list(zip(*rows))
        if len(columns) < 5:
            raise ValueError(
                f"transitions must have at least 5 fields, got {len(columns)}"
            )

        # zip stops at the shorter input, so 'behaviors' is only emitted when
        # a sixth column exists.
        return {key: np.array(col) for key, col in zip(self._SAMPLE_KEYS, columns)}

    def save(self, outfile):
        """Write the stored transitions to ``outfile`` with ``np.save``."""
        # NOTE(review): np.save converts the list to an array first; transitions
        # whose fields differ in shape may not convert cleanly on recent NumPy
        # releases -- confirm against the NumPy version in use.
        np.save(outfile, self.storage)
        print(f"* {outfile} succesfully saved..")
