from brl.utils import *
import heapq

class Filter:
    """Bounded collection that keeps only the ``max_size`` largest items seen.

    Items are stored in a ``heapq`` min-heap, so the smallest retained item is
    always at ``data[0]``. That item is the one evicted when an append pushes
    the size past ``max_size``. Items only need to support ``<``.

    Iteration, ``str`` and ``repr`` expose the raw heap list. It is in heap
    order, not sorted order.
    """

    def __init__(self, max_size):
        self.data = []             # min-heap of retained items
        self.max_size = max_size   # capacity; extra items are evicted

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        # Heap order, not sorted order.
        return iter(self.data)

    def __str__(self):
        return str(self.data)

    def __repr__(self):
        return repr(self.data)

    def append(self, item):
        """Insert ``item`` and, if over capacity, drop the smallest item."""
        heap = self.data
        heapq.heappush(heap, item)
        overflow = len(heap) - self.max_size
        if overflow > 0:
            # Only one item is added per call, so one pop restores capacity.
            heapq.heappop(heap)


class FIFO(object):
    """Fixed-capacity ring buffer backed by a preallocated NumPy array.

    Logical index 0 is the oldest stored item and ``len(self) - 1`` is the
    newest. Once ``maxlen`` items are stored, each further ``append``
    overwrites the oldest item.

    Args:
        maxlen: Number of item slots. With ``maxlen == 0`` the modulo
            arithmetic below divides by zero, so presumably it must be > 0.
        shape: Shape of a single item; ``()`` stores scalars. Any sequence is
            accepted and normalized to a tuple.
        dtype: NumPy dtype of the backing storage.
    """

    def __init__(self, maxlen, shape, dtype=np.float32):
        self.maxlen = maxlen
        self.start = 0   # physical slot holding the oldest item
        self.length = 0  # number of valid items currently stored
        # Normalize to a tuple so ``(maxlen,) + shape`` works for lists, and so
        # the comparison with ``ndarray.shape`` (always a tuple) in ``append``
        # can succeed.
        self.shape = tuple(shape)
        self.dtype = dtype
        # Allocate directly in the target dtype rather than float64 + a copy.
        self.data = np.zeros((maxlen,) + self.shape, dtype=dtype)

    def __len__(self):
        return self.length

    def __iter__(self):
        """Yield stored items from oldest to newest.

        Defined explicitly because Python's fallback iteration protocol calls
        ``__getitem__`` with increasing indices until it gets ``IndexError``.
        ``__getitem__`` raises ``KeyError`` instead (kept for compatibility),
        so without this method ``for x in fifo`` / ``list(fifo)`` would raise
        after the last item.
        """
        for offset in range(self.length):
            yield self.data[(self.start + offset) % self.maxlen]

    def __getitem__(self, idx):
        """Return the item at logical index ``idx`` (0 = oldest).

        Raises:
            KeyError: if ``idx`` is outside ``[0, len(self))``. Negative
                indices are not supported.
        """
        if idx < 0 or idx >= self.length:
            raise KeyError()
        return self.data[(self.start + idx) % self.maxlen]

    def __setitem__(self, idx, v):
        """Overwrite the item at logical index ``idx`` (0 = oldest).

        Raises:
            KeyError: if ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= self.length:
            raise KeyError()
        self.data[(self.start + idx) % self.maxlen] = v

    def get_batch(self, idxs):
        """Return the items at logical indices ``idxs`` as one stacked array.

        ``idxs`` may be an integer ndarray or any integer sequence. Unlike
        ``__getitem__`` there is no bounds check: indices outside
        ``[0, len(self))`` wrap around the ring via ``% maxlen`` and may return
        stale or zero-filled slots.
        """
        idxs = np.asarray(idxs)  # lets plain lists work with ``start + idxs``
        return self.data[(self.start + idxs) % self.maxlen]

    def append(self, v):
        """Store ``v`` as the newest item, evicting the oldest when full.

        ``v`` must be a real scalar (Python or NumPy int/float) or an ndarray
        whose shape equals ``self.shape``. A scalar is broadcast across the
        whole slot when ``shape`` is non-empty.

        NOTE: validation uses ``assert`` and is skipped under ``python -O``.
        """
        assert isinstance(v, (int, float, np.integer, np.floating)) or (
            isinstance(v, np.ndarray) and v.shape == self.shape), v

        if self.length < self.maxlen:
            # We have space, simply increase the length.
            self.length += 1
        elif self.length == self.maxlen:
            # No space, "remove" the first item.
            self.start = (self.start + 1) % self.maxlen
        else:
            # Only reachable if ``length`` was corrupted externally.
            raise RuntimeError()

        self.data[(self.start + self.length - 1) % self.maxlen] = v

    def pop(self, num_items=1):
        """Discard the ``num_items`` oldest items; nothing is returned.

        NOTE: the range check uses ``assert`` and is skipped under
        ``python -O``; an oversized ``num_items`` would then corrupt state.
        """
        assert 0 <= num_items <= self.length, (num_items, self.length)
        self.start = (self.start + num_items) % self.maxlen
        self.length -= num_items


if __name__ == "__main__":
    from collections import namedtuple

    class Point(namedtuple('Point', ['x', 'y'])):
        """2-D point ordered by squared distance from the origin."""

        def __lt__(self, other):
            mine = self.x ** 2 + self.y ** 2
            theirs = other.x ** 2 + other.y ** 2
            return mine < theirs

        def __str__(self):
            return '({},{})'.format(self.x, self.y)

    top3 = Filter(3)
    # Squared norms are 10, 8, 5, 9; on the fourth append the heap exceeds
    # capacity and the smallest (5, i.e. Point(2,1)) is evicted.
    for px, py in [(3, 1), (2, 2), (2, 1), (3, 0)]:
        top3.append(Point(px, py))

    print(top3)
    for point in top3:
        print(point)


