import numpy as np


class Buffer:
    """Keep ``n_elements`` parallel lists that are updated together and capped in length.

    Each call to :meth:`update_buffer` adds one entry (or one batch of entries)
    to every list; once the lists grow past ``max_buffer_size`` the oldest
    entries are dropped from the front.
    """

    def __init__(self, n_elements, max_buffer_size, reset_on_query):
        """
        Args:
            n_elements: number of parallel lists to maintain.
            max_buffer_size: length the lists are trimmed back to after each update.
            reset_on_query: default for the ``reset`` argument of :meth:`read_buffer`.
        """
        self.reset_on_query = reset_on_query
        self.max_buffer_size = max_buffer_size
        self.buffers = [[] for _ in range(n_elements)]

    def update_buffer(self, datas):
        """Add new data to every buffer, then drop the oldest overflow.

        Args:
            datas: one item per buffer. If ``datas[0]`` is a ``list``, every item
                is treated as a batch and extended into its buffer; otherwise
                each item is appended as a single entry.
        """
        if isinstance(datas[0], list):
            for buffer, data in zip(self.buffers, datas):
                buffer.extend(data)
        else:
            for buffer, data in zip(self.buffers, datas):
                buffer.append(data)

        # Drop the same number of oldest entries from every buffer with a single
        # slice deletion instead of repeated O(n) ``del buffer[0]`` calls.
        excess = len(self) - self.max_buffer_size
        if excess > 0:
            for buffer in self.buffers:
                del buffer[:excess]

    def read_buffer(self, reset=None):
        """Return the buffered data as a tuple of lists, one per buffer.

        Args:
            reset: if true, empty the buffers after reading; ``None`` falls back
                to ``reset_on_query``.

        Returns:
            A tuple of lists owned by the caller; later updates do not modify it.
        """
        if reset is None:
            reset = self.reset_on_query

        if reset:
            # Hand the current lists over and start fresh ones, so no copy is needed.
            res = tuple(self.buffers)
            self.buffers[:] = [[] for _ in res]
        else:
            # Copy: returning the live lists would let later appends/trims
            # silently change the caller's snapshot.
            res = tuple(list(buffer) for buffer in self.buffers)
        return res

    def __len__(self):
        """Length of the first buffer, or 0 when there are no buffers.

        Every update appends to all buffers, so this is their common length
        as long as each update supplies one item per buffer.
        """
        return len(self.buffers[0]) if self.buffers else 0


class Discretizer:
    """Stochastically round a continuous sample onto the row index of a sorted context grid.

    ``contexts`` is indexed as ``contexts[i, :]``, so it must be a 2-D array.
    The comparisons in ``__call__`` are used as truth values, which only works
    when each row holds a single value, i.e. shape ``(n, 1)``. Rows are assumed
    to be sorted ascending. NOTE(review): confirm both against callers.
    """

    def __init__(self, contexts):
        self.contexts = contexts

    def __call__(self, continuous_sample):
        """Map ``continuous_sample`` to the index of a neighbouring context row.

        Samples at or beyond either end of the grid are clamped to the first or
        last index. A sample strictly between rows ``l`` and ``l + 1`` rounds
        up with probability ``(x - c_l) / (c_{l+1} - c_l)`` and down otherwise,
        so the expected context value equals the sample for any grid spacing.

        Returns:
            A 0-d ``np.ndarray`` holding the chosen row index.
        """
        # In the discrete case we expect to always obtain a discrete sample, so we discretize it
        if continuous_sample >= self.contexts[-1, :]:
            return np.array(self.contexts.shape[0] - 1)

        if continuous_sample <= self.contexts[0, :]:
            return np.array(0)

        # First row strictly greater than the sample; the row before it is <= sample,
        # so the gap below is strictly positive for a sorted grid.
        idx = np.argmax(self.contexts > continuous_sample)
        l_idx = idx - 1
        # Normalize the offset by the gap between neighbours. The raw offset is
        # only a valid probability when the contexts are unit-spaced (and for
        # unit spacing dividing by 1.0 leaves it unchanged).
        gap = self.contexts[idx] - self.contexts[l_idx]
        frac_up = (continuous_sample - self.contexts[l_idx]) / gap
        if np.random.uniform(0., 1.) >= frac_up:
            return np.array(l_idx)
        return np.array(idx)
