
import torch
import numpy as np
import random as r


class Buffer(torch.nn.Module):
    """Fixed-capacity replay memory holding samples and their labels.

    Storage lives in registered buffers: ``buffer_imgs`` (created only when
    ``shape`` has 1 or 3 dimensions) and ``buffer_labels``, whose empty slots
    hold ``-1``. ``n_seen_so_far`` and ``n_added_so_far`` are plain Python
    counters; ``n_added_so_far`` is incremented on every write and can
    therefore exceed ``max_size``. Subclasses implement :meth:`update`.
    """

    def __init__(self, max_size=200, shape=(3,32,32), n_classes=10, device="cuda"):
        """Allocate the zero-filled sample store and the ``-1``-filled label store.

        Args:
            max_size (int): number of slots in the buffer.
            shape (tuple | None): per-sample shape; only lengths 1 and 3
                allocate ``buffer_imgs``. ``None`` skips the allocation.
            n_classes (int): number of classes; used by ``dist_retrieve`` and
                ``get_labels_distribution``.
            device (str): device on which the model is evaluated in the
                ``dist_*`` methods.
        """
        super().__init__()
        self.n_classes = n_classes
        self.max_size = max_size
        self.shape = shape
        self.n_seen_so_far = 0
        self.n_added_so_far = 0
        self.device = device
        if self.shape is not None and len(self.shape) in (1, 3):
            self.register_buffer(
                'buffer_imgs',
                torch.zeros(self.max_size, *self.shape, dtype=torch.float32),
            )
        self.register_buffer(
            'buffer_labels',
            torch.full((self.max_size,), -1, dtype=torch.long),
        )

    def _n_valid(self):
        """Number of filled slots (``n_added_so_far`` clamped to ``max_size``)."""
        return min(self.n_added_so_far, self.max_size)

    def _sample_indexes(self, mask, n_imgs):
        """Draw up to ``n_imgs`` distinct buffer indexes where ``mask`` is True.

        Sampling uses numpy's global RNG and is done WITHOUT replacement, so
        the result never contains duplicates.
        """
        valid_indexes = mask.nonzero().view(-1)
        n_out = min(n_imgs, len(valid_indexes))
        if n_out <= 0:
            return valid_indexes[:0]
        picked = np.random.choice(len(valid_indexes), n_out, replace=False)
        return valid_indexes[torch.from_numpy(np.asarray(picked)).long()]

    def update(self, imgs, labels=None):
        """Insert stream data into the buffer; must be provided by subclasses."""
        raise NotImplementedError

    def stack_data(self, img, label):
        """Write ``img``/``label`` at slot ``n_seen_so_far`` while that slot exists.

        Nothing happens once ``n_seen_so_far`` reaches ``max_size``. The
        caller is responsible for incrementing ``n_seen_so_far``.
        """
        if self.n_seen_so_far < self.max_size:
            self.buffer_imgs[self.n_seen_so_far] = img
            self.buffer_labels[self.n_seen_so_far] = label
            self.n_added_so_far += 1

    def replace_data(self, idx, img, label):
        """Overwrite slot ``idx`` and count the write in ``n_added_so_far``."""
        self.buffer_imgs[idx] = img
        self.buffer_labels[idx] = label
        self.n_added_so_far += 1

    def is_empty(self):
        """Return True when nothing has been written yet."""
        return self.n_added_so_far == 0

    def random_retrieve(self, n_imgs=100):
        """Return ``n_imgs`` distinct stored samples chosen uniformly at random.

        When fewer than ``n_imgs`` writes have happened, every stored sample
        is returned in storage order. The sample size is capped at the number
        of filled slots, so a request larger than ``max_size`` returns the
        whole buffer instead of raising.

        Returns:
            tuple: ``(imgs, labels)``.
        """
        if self.n_added_so_far < n_imgs:
            return self.buffer_imgs[:self.n_added_so_far], self.buffer_labels[:self.n_added_so_far]

        n_valid = self._n_valid()
        # n_added_so_far may exceed max_size, so never ask for more than exists.
        ret_indexes = r.sample(range(n_valid), min(n_imgs, n_valid))
        return self.buffer_imgs[ret_indexes], self.buffer_labels[ret_indexes]

    def only_retrieve(self, n_imgs, desired_labels):
        """Retrieve samples whose label belongs to ``desired_labels``.

        Args:
            n_imgs (int): maximum number of samples to retrieve.
            desired_labels (list | torch.Tensor): labels to retrieve from.

        Returns:
            tuple: ``(imgs, labels)`` with at most ``n_imgs`` distinct samples.
        """
        wanted = torch.as_tensor(
            desired_labels, dtype=torch.long, device=self.buffer_labels.device
        ).view(-1)
        out_indexes = self._sample_indexes(torch.isin(self.buffer_labels, wanted), n_imgs)
        return self.buffer_imgs[out_indexes], self.buffer_labels[out_indexes]

    def except_retrieve(self, n_imgs, undesired_labels):
        """Retrieve samples whose label is NOT in ``undesired_labels``.

        Empty slots (label ``-1``) are always excluded.

        Args:
            n_imgs (int): maximum number of samples to retrieve.
            undesired_labels (list | tuple | torch.Tensor): labels to exclude.

        Returns:
            tuple: ``(imgs, labels)`` with at most ``n_imgs`` distinct samples.
        """
        dev = self.buffer_labels.device
        excluded = torch.cat([
            torch.as_tensor(undesired_labels, dtype=torch.long, device=dev).view(-1),
            torch.tensor([-1], dtype=torch.long, device=dev),
        ])
        out_indexes = self._sample_indexes(~torch.isin(self.buffer_labels, excluded), n_imgs)
        return self.buffer_imgs[out_indexes], self.buffer_labels[out_indexes]

    def dist_retrieve(self, means, model, n_imgs=100):
        """Retrieve, per class, the stored samples scoring highest against their class mean.

        The score of a sample is the dot product between the second output of
        ``model`` for that sample and the mean of the sample's own class.
        Each class present in memory contributes at most
        ``int(n_imgs / n_classes_in_memory)`` samples, highest score first.

        Args:
            means (dict): maps string keys parseable as a number (e.g.
                ``'3.0'``) to the mean vector of that class.
            model (nn.Module): callable returning a pair whose second element
                is a ``(n_samples, dim)`` tensor.
            n_imgs (int, optional): number of samples to retrieve. Defaults to 100.

        Returns:
            tuple: ``(imgs, labels)``. When fewer than ``n_imgs`` writes have
            happened, all stored samples are returned instead.
        """
        if self.n_added_so_far < n_imgs:
            return self.buffer_imgs[:self.n_added_so_far], self.buffer_labels[:self.n_added_so_far]

        n_valid = self._n_valid()
        with torch.no_grad():
            _, p_mem = model(self.buffer_imgs[:n_valid].to(self.device))

        # One column per class, filled from the provided means.
        m = torch.zeros((p_mem.shape[1], self.n_classes), dtype=p_mem.dtype, device=p_mem.device)
        for c in means:
            m[:, int(float(c))] = means[f'{c}']

        dists = p_mem @ m
        stored_labels = self.buffer_labels[:n_valid].to(dists.device)
        # Keep only the score against each sample's own class.
        dists = dists[torch.arange(dists.size(0), device=dists.device), stored_labels]
        sorted_idx = dists.sort(descending=True).indices
        sorted_labels = stored_labels[sorted_idx]

        classes = stored_labels.unique()
        per_class = int(n_imgs / len(classes))
        ret_indexes = []
        for c in classes:
            positions = torch.where(sorted_labels == c)[0][:per_class]
            # Positions refer to the sorted order; map them back to buffer slots.
            ret_indexes.append(sorted_idx[positions])
        ret_indexes = torch.cat(ret_indexes).to(self.buffer_labels.device)

        return self.buffer_imgs[ret_indexes], self.buffer_labels[ret_indexes]

    def dist_update(self, means, model, imgs, labels, **kwargs):
        """Insert stream samples, evicting from the most populated class once full.

        While fewer than ``max_size`` writes have happened, samples are
        appended. Afterwards, a sample whose class holds at most the average
        number of samples per class replaces the member of the most
        represented class with the smallest dot product against that class's
        mean (``means[f'{major_class}.0']``). Other samples are dropped.
        ``model`` is switched to eval mode and is not switched back.
        """
        for stream_data, stream_label in zip(imgs, labels):
            if self.n_added_so_far < self.max_size:
                self.stack_data(stream_data, stream_label)
            else:
                max_img_per_class = self.get_max_img_per_class()
                class_indexes = self.get_indexes_of_class(stream_label)
                if len(class_indexes) <= max_img_per_class:
                    major_class = self.get_major_class()
                    # view(-1) keeps a 1-D index even for a single-member class.
                    class_indexes = self.get_indexes_of_class(major_class).view(-1)

                    model.eval()
                    with torch.no_grad():
                        _, p_mem = model(self.buffer_imgs[class_indexes].to(self.device))

                    m = means[f'{major_class}.0'].to(self.device)
                    dists = p_mem @ m
                    idx = class_indexes[int(dists.argmin())]
                    self.replace_data(idx, stream_data, stream_label)
            self.n_seen_so_far += 1

    def bootstrap_retrieve(self, n_imgs=100):
        """Return ``n_imgs`` stored samples drawn uniformly WITH replacement.

        Returns two empty tensors when nothing has been written.
        """
        if self.n_added_so_far == 0:
            return torch.Tensor(), torch.Tensor()
        last = self._n_valid() - 1
        ret_indexes = [r.randint(0, last) for _ in range(n_imgs)]
        return self.buffer_imgs[ret_indexes], self.buffer_labels[ret_indexes]

    def n_data(self):
        """Number of slots holding a real label (``>= 0``)."""
        return len(self.buffer_labels[self.buffer_labels >= 0])

    def get_all(self):
        """Return the filled prefix of the sample and label stores."""
        n_valid = self._n_valid()
        return self.buffer_imgs[:n_valid], self.buffer_labels[:n_valid]

    def get_indexes_of_class(self, label):
        """Indexes (as an ``(n, 1)`` tensor) of slots labelled ``label``."""
        return torch.nonzero(self.buffer_labels == label)

    def get_indexes_out_of_class(self, label):
        """Indexes (as an ``(n, 1)`` tensor) of slots NOT labelled ``label``."""
        return torch.nonzero(self.buffer_labels != label)

    def is_full(self):
        """Return True when every slot holds a real label."""
        return self.n_data() == self.max_size

    def _valid_labels_np(self):
        """Stored real labels (``>= 0``) as an integer numpy array."""
        np_labels = self.buffer_labels.detach().cpu().numpy().astype(int)
        return np_labels[np_labels >= 0]

    def get_labels_distribution(self):
        """Per-class fraction of stored samples (raw zero counts when empty)."""
        valid = self._valid_labels_np()
        counts = np.bincount(valid, minlength=self.n_classes)
        if valid.size > 0:
            return counts / valid.size
        return counts

    def get_major_class(self):
        """Label of the most represented class (raises ValueError when empty)."""
        return np.bincount(self._valid_labels_np()).argmax()

    def get_max_img_per_class(self):
        """Average number of stored samples per class present in memory, floored.

        Empty slots are ignored; returns 0 when the buffer holds no sample.
        """
        valid = self.buffer_labels[self.buffer_labels >= 0]
        n_classes_in_memory = len(valid.unique())
        if n_classes_in_memory == 0:
            return 0
        return int(len(valid) / n_classes_in_memory)

class ProtoBuffer(Buffer):
    """Per-class prototype memory.

    Rows are addressed by class label rather than by arrival order: row ``l``
    of ``buffer_q`` / ``buffer_k`` / ``buffer_v`` / ``buffer_f`` holds an
    average of the queries / keys / values / features seen for label ``l``,
    ``buffer_l[l]`` stores the label itself and ``n_updated[l]`` /
    ``w_updated[l]`` the number of updates / accumulated weight. Labels must
    therefore be smaller than ``max_size``. Incoming samples are moved to CPU
    before being merged into the stores.
    """

    _QKVLF = ('buffer_q', 'buffer_k', 'buffer_v', 'buffer_l', 'buffer_f')

    def __init__(self, max_size=200, shape=32, device="cuda", **kwargs):
        """Allocate the per-class stores.

        Args:
            max_size (int): number of rows (must exceed the largest label).
            shape (int | tuple): feature dimension, either as an int or as a
                sequence whose first entry is the dimension.
            device (str): device on which retrieved tensors are returned.
            **kwargs: ``n_classes`` (default 10) and ``drop_method``
                (default ``'random'``) are read; other keys are ignored.
        """
        # An int (including the default) is normalised so len()/[0] work.
        if isinstance(shape, int):
            shape = (shape,)
        super().__init__(max_size=max_size, shape=shape, device=device, n_classes=kwargs.get('n_classes', 10))
        self.n_seen_so_far = 0
        self.n_added_so_far = 0
        dim = self.shape[0]
        self.register_buffer('buffer_q', torch.zeros(self.max_size, dim, dtype=torch.float32))
        self.register_buffer('buffer_l', torch.full((self.max_size,), -1, dtype=torch.long))
        self.register_buffer('buffer_f', torch.zeros(self.max_size, dim, dtype=torch.float32))
        # Keys and values carry a fixed second dimension of 20.
        self.register_buffer('buffer_k', torch.zeros(self.max_size, 20, dim, dtype=torch.float32))
        self.register_buffer('buffer_v', torch.zeros(self.max_size, 20, dim, dtype=torch.float32))
        self.register_buffer('n_updated', torch.zeros(self.max_size, dtype=torch.long))
        self.register_buffer('w_updated', torch.zeros(self.max_size, dtype=torch.float32))
        self.drop_method = kwargs.get('drop_method', 'random')

    # ------------------------------------------------------------------ helpers
    def _average_in(self, store, l, sample):
        """Fold ``sample`` into the running mean kept in ``store[l]``."""
        count = self.n_updated[l]
        store[l] = (store[l] * count + sample) / (count + 1)

    def _count_update(self, l):
        """Record the label and bump the update/seen/added counters."""
        self.buffer_l[l] = l
        self.n_updated[l] += 1
        self.n_seen_so_far += 1
        self.n_added_so_far += 1

    def _update_qkvf_mean(self, queries, keys, values, labels, features, *extras):
        """Running-mean update of q/k/v/f; ``extras`` only bound the zip length."""
        for q, k, v, l, f, *_ in zip(queries, keys, values, labels, features, *extras):
            q, k, v, l, f = q.cpu(), k.cpu(), v.cpu(), l.long().cpu(), f.cpu()
            self._average_in(self.buffer_q, l, q)
            self._average_in(self.buffer_k, l, k)
            self._average_in(self.buffer_v, l, v)
            self._average_in(self.buffer_f, l, f)
            self._count_update(l)

    def _weighted_feature_update(self, labels, features, kwargs, invert):
        """Weighted running mean of the features, shared by update2/update3."""
        weights = kwargs.get('weights', None)
        if weights is None:
            raise TypeError("a 'weights' keyword argument indexable by label is required")
        weights = weights[labels]
        for l, f, w in zip(labels, features, weights):
            l, f, w = l.long().cpu(), f.cpu(), w.cpu()
            if invert:
                w = 1 / (w + 1e-6)
            total = self.w_updated[l]
            self.buffer_f[l] = (self.buffer_f[l] * total + w * f) / (total + w)
            self.w_updated[l] += w
            self._count_update(l)

    def _gather(self, names, indices):
        """Index each named store with ``indices`` and move it to ``self.device``."""
        return tuple(getattr(self, name)[indices].to(self.device) for name in names)

    def _random_rows(self, names, n_imgs):
        """Up to ``n_imgs`` randomly chosen updated rows of the named stores.

        Returns empty tensors (on ``self.device``) when nothing was added.
        """
        if self.n_added_so_far == 0:
            return self._gather(names, slice(0, 0))
        indices = self.n_updated.nonzero().view(-1)
        indices = indices[torch.randperm(indices.size(0))[:n_imgs]]
        return self._gather(names, indices)

    # ------------------------------------------------------------------ updates
    def update(self, queries=None, keys=None, values=None, labels=None, features=None, **kwargs):
        """Merge a batch into the per-class running means.

        With ``queries`` given, the q/k/v/f means of each sample's class are
        updated; with ``queries=None`` only the feature mean is updated.

        Args:
            queries, keys, values, features: per-sample tensors.
            labels (torch.Tensor): per-sample class labels (row selectors).
        """
        if queries is not None:
            self._update_qkvf_mean(queries, keys, values, labels, features)
            return
        for l, f in zip(labels, features):
            l, f = l.long().cpu(), f.cpu()
            self._average_in(self.buffer_f, l, f)
            self._count_update(l)

    def update2(self, queries, keys, values, labels, features, **kwargs):
        """Weighted running mean of the features only.

        ``kwargs['weights']`` is indexed by ``labels`` to obtain one weight
        per sample; ``queries``/``keys``/``values`` are not used.

        Raises:
            TypeError: if ``weights`` is not supplied.
        """
        self._weighted_feature_update(labels, features, kwargs, invert=False)

    def update3(self, queries, keys, values, labels, features, **kwargs):
        """Like :meth:`update2`, but each weight ``w`` is replaced by ``1 / (w + 1e-6)``.

        Raises:
            TypeError: if ``weights`` is not supplied.
        """
        self._weighted_feature_update(labels, features, kwargs, invert=True)

    def update_ema(self, queries, keys, values, labels, features, **kwargs):
        """Exponential-moving-average update of the class features.

        The first feature seen for a class initialises its row; afterwards
        ``row = alpha * f + (1 - alpha) * row`` with ``alpha`` read from
        ``kwargs['ema_proto']`` (default 0.9). Only ``labels`` and
        ``features`` are used.
        """
        alpha = kwargs.get('ema_proto', 0.9)
        for l, f in zip(labels, features):
            l, f = l.long().cpu(), f.cpu()
            if self.n_updated[l] == 0:
                self.buffer_f[l] = f
            else:
                self.buffer_f[l] = alpha * f + (1 - alpha) * self.buffer_f[l]
            self._count_update(l)

    def update_aqk(self, queries, keys, values, labels, features, loss, aqk, **kwargs):
        """Running-mean update of q/k/v/f.

        ``loss`` and ``aqk`` are not stored; they only take part in the zip
        and so bound the number of processed samples.
        """
        self._update_qkvf_mean(queries, keys, values, labels, features, loss, aqk)

    # --------------------------------------------------------------- retrieval
    def random_retrieve(self, n_imgs=100):
        """Return up to ``n_imgs`` random prototypes.

        Returns:
            tuple: ``(q, k, v, labels, features, n_updated)`` on ``self.device``.
        """
        return self._random_rows(self._QKVLF + ('n_updated',), n_imgs)

    def random_retrieve_f(self, n_imgs=100):
        """Return the labels and features of ALL updated prototypes.

        ``n_imgs`` is accepted for interface symmetry but is not used.

        Returns:
            tuple: ``(labels, features)`` on ``self.device``.
        """
        if self.n_added_so_far == 0:
            return self._gather(('buffer_l', 'buffer_f'), slice(0, 0))
        indices = self.n_updated.nonzero().view(-1)
        return self._gather(('buffer_l', 'buffer_f'), indices)

    def random_retrieve_aqk(self, n_imgs=100):
        """Like :meth:`random_retrieve` with ``buffer_aqk`` inserted after the values.

        Raises:
            AttributeError: ``buffer_aqk`` is not registered by ``__init__``,
                so this method fails unless that buffer is added elsewhere.
        """
        if not hasattr(self, 'buffer_aqk'):
            raise AttributeError(
                "'ProtoBuffer' object has no attribute 'buffer_aqk'; "
                "random_retrieve_aqk requires that buffer to be registered"
            )
        names = ('buffer_q', 'buffer_k', 'buffer_v', 'buffer_aqk', 'buffer_l', 'buffer_f', 'n_updated')
        return self._random_rows(names, n_imgs)

    def random_retrieve2(self, n_imgs=100):
        """Same contract and implementation as :meth:`random_retrieve`."""
        return self._random_rows(self._QKVLF + ('n_updated',), n_imgs)

    def q_retrieve(self, n_imgs=10, queries=None, largest=True):
        """Return the prototypes ranked by mean dot product with ``queries``.

        Args:
            n_imgs (int): maximum number of prototypes to return.
            queries (torch.Tensor): 2-D batch of queries on ``self.device``.
            largest (bool): pick the highest (True) or lowest (False) scores.

        Returns:
            tuple: ``(q, k, v, labels, features, scores)``. With at most one
            sample added so far, the result of :meth:`random_retrieve` is
            returned instead, whose sixth element is ``n_updated``.
        """
        if self.n_added_so_far <= 1:
            return self.random_retrieve(n_imgs)

        indices = self.n_updated.nonzero().view(-1)
        all_queries = self.buffer_q[indices].to(self.device)
        sims = all_queries @ queries.T
        scores, order = sims.mean(1).topk(min(n_imgs, len(indices)), largest=largest, sorted=True)
        indices = indices[order.cpu()]
        return self._gather(self._QKVLF, indices) + (scores,)

    def only_retrieve(self, n_imgs, desired_labels):
        """Retrieve prototypes whose label belongs to ``desired_labels``.

        At most ``n_imgs`` distinct rows are drawn (numpy global RNG, without
        replacement).

        Args:
            n_imgs (int): maximum number of prototypes to retrieve.
            desired_labels (list | torch.Tensor): labels to retrieve from.

        Returns:
            tuple: ``(q, k, v, labels, features)`` on ``self.device``.
        """
        wanted = torch.as_tensor(
            desired_labels, dtype=torch.long, device=self.buffer_l.device
        ).view(-1)
        valid_indexes = torch.isin(self.buffer_l, wanted).nonzero().view(-1)
        n_out = min(n_imgs, len(valid_indexes))
        if n_out <= 0:
            out_indexes = valid_indexes[:0]
        else:
            picked = np.random.choice(len(valid_indexes), n_out, replace=False)
            out_indexes = valid_indexes[torch.from_numpy(np.asarray(picked)).long()]
        return self._gather(self._QKVLF, out_indexes)

class Reservoir(Buffer):
    """Sample memory updated with reservoir sampling."""

    def __init__(self, max_size=200, img_size=32, nb_ch=3, n_classes=10, **kwargs):
        """Build the reservoir.

        Args:
            max_size (int, optional): maximum buffer size. Defaults to 200.
            img_size (int, optional): image width/height (square images). Defaults to 32.
            nb_ch (int, optional): number of image channels. Defaults to 3.
            n_classes (int, optional): total number of expected classes. Defaults to 10.
            **kwargs: ``shape`` overrides the ``(nb_ch, img_size, img_size)``
                sample shape, ``device`` defaults to ``'cuda'`` and
                ``drop_method`` defaults to ``'random'``.
        """
        sample_shape = kwargs.get('shape', None)
        if sample_shape is None:
            sample_shape = (nb_ch, img_size, img_size)
        super().__init__(
            max_size,
            shape=sample_shape,
            n_classes=n_classes,
            device=kwargs.get('device', 'cuda'),
        )
        self.img_size = img_size
        self.nb_ch = nb_ch
        self.drop_method = kwargs.get('drop_method', 'random')

    def reset(self):
        """Set the ``n_seen_so_far`` counter back to zero.

        NOTE(review): ``n_added_so_far`` is left untouched, and ``update``
        writes at slot ``n_added_so_far`` while ``n_seen_so_far < max_size``;
        if that slot is already ``>= max_size`` nothing is stored during that
        phase. Confirm this is the intended post-reset behaviour.
        """
        self.n_seen_so_far = 0

    def update(self, imgs, labels, **kwargs):
        """Feed a batch of stream samples to the reservoir.

        Labels play no role in the selection; they are only stored next to
        the chosen images. One random number is drawn per sample, including
        while the buffer is still filling.

        Args:
            imgs (torch.tensor): stream images seen by the buffer
            labels (torch.tensor): stream labels seen by the buffer
        Raises:
            NotImplementedError: when a sample is selected for storage and
                ``drop_method`` is not ``'random'``.
        """
        for stream_img, stream_label in zip(imgs, labels):
            drawn_slot = int(r.random() * (self.n_seen_so_far + 1))
            still_filling = self.n_seen_so_far < self.max_size
            slot = self.n_added_so_far if still_filling else drawn_slot
            if slot < self.max_size:
                if self.drop_method != 'random':
                    raise NotImplementedError("Only random update is implemented here.")
                self.replace_data(slot, stream_img, stream_label)
            self.n_seen_so_far += 1
            
class QueryReservoir(Buffer):
    """Reservoir-sampled memory of queries with their keys, values, labels and features.

    Queries are kept in ``buffer_imgs``; ``buffer_k`` / ``buffer_v`` carry a
    fixed second dimension of 20. Retrieval methods return tensors moved to
    ``self.device``.
    """

    def __init__(self, max_size=200, shape=32, device="cuda", **kwargs):
        """Allocate the stores.

        Args:
            max_size (int): number of slots.
            shape (int | tuple): query/feature dimension, either as an int or
                as a sequence whose first entry is the dimension.
            device (str): device on which retrieved tensors are returned.
            **kwargs: ``n_classes`` (default 10) and ``drop_method``
                (default ``'random'``) are read.
        """
        # An int (including the default) is normalised so shape[0] works.
        if isinstance(shape, int):
            shape = (shape,)
        # shape=None keeps the base class from allocating an image store
        # that would be replaced right below.
        super().__init__(max_size=max_size, shape=None, n_classes=kwargs.get('n_classes', 10), device=device)
        self.max_size = max_size
        self.shape = shape
        self.n_seen_so_far = 0
        self.n_added_so_far = 0
        self.device = device
        dim = self.shape[0]
        self.register_buffer('buffer_imgs', torch.zeros(self.max_size, dim, dtype=torch.float32))
        self.register_buffer('buffer_labels', torch.full((self.max_size,), -1, dtype=torch.long))
        self.register_buffer('buffer_features', torch.zeros(self.max_size, dim, dtype=torch.float32))
        self.register_buffer('buffer_k', torch.zeros(self.max_size, 20, dim, dtype=torch.float32))
        self.register_buffer('buffer_v', torch.zeros(self.max_size, 20, dim, dtype=torch.float32))
        self.drop_method = kwargs.get('drop_method', 'random')

    def _store(self, idx, q, k, v, l, f):
        """Write one sample into slot ``idx`` of every store."""
        self.buffer_imgs[idx] = q
        self.buffer_k[idx] = k
        self.buffer_v[idx] = v
        self.buffer_labels[idx] = l
        self.buffer_features[idx] = f

    def _fetch(self, indexes):
        """Rows ``indexes`` of every store, moved to ``self.device``."""
        stores = (self.buffer_imgs, self.buffer_k, self.buffer_v,
                  self.buffer_labels, self.buffer_features)
        return tuple(store[indexes].to(self.device) for store in stores)

    def update(self, queries, keys, values, labels, features, **kwargs):
        """Feed a batch to the reservoir.

        While fewer than ``max_size`` samples have been seen the write slot
        is ``n_added_so_far``; afterwards a slot is drawn uniformly in
        ``[0, n_seen_so_far]`` and the sample is stored only if that slot
        exists. With ``query_update=True`` in ``kwargs`` the write goes
        through :meth:`query_replace` instead of :meth:`replace_data`.

        Args:
            queries (torch.tensor): stream queries seen by the buffer
            keys, values, labels, features: per-sample companions of ``queries``.
        """
        writer = self.query_replace if kwargs.get('query_update', False) else self.replace_data
        for q, k, v, l, f in zip(queries, keys, values, labels, features):
            reservoir_idx = int(r.random() * (self.n_seen_so_far + 1))
            if self.n_seen_so_far < self.max_size:
                reservoir_idx = self.n_added_so_far
            if reservoir_idx < self.max_size:
                writer(reservoir_idx, q, k, v, l, f)
            self.n_seen_so_far += 1

    def replace_data(self, idx, q, k, v, l, f):
        """Overwrite slot ``idx`` and count the write."""
        self._store(idx, q, k, v, l, f)
        self.n_added_so_far += 1

    def query_replace(self, idx, q, k, v, l, f):
        """Write into ``idx`` until the buffer is full, then evict by similarity.

        Once every slot holds a label, ``idx`` is ignored and the stored
        query with the smallest dot product with ``q`` is overwritten.
        """
        if self.is_full():
            n_valid = min(self.n_added_so_far, self.max_size)
            all_queries = self.buffer_imgs[:n_valid].to(self.device)
            # reshape(-1) replaces the deprecated `.T` on a 1-D query.
            sims = all_queries @ q.reshape(-1).to(all_queries.device)
            idx = int(sims.argmin())
        self._store(idx, q, k, v, l, f)
        self.n_added_so_far += 1

    def random_retrieve(self, n_imgs=100):
        """Return ``n_imgs`` distinct stored samples chosen uniformly at random.

        When fewer than ``n_imgs`` writes have happened, every stored sample
        is returned in storage order. The sample size is capped at the number
        of filled slots.

        Returns:
            tuple: ``(queries, keys, values, labels, features)``.
        """
        if self.n_added_so_far < n_imgs:
            return self._fetch(slice(0, self.n_added_so_far))

        n_valid = min(self.n_added_so_far, self.max_size)
        # n_added_so_far may exceed max_size, so never ask for more than exists.
        ret_indexes = r.sample(range(n_valid), min(n_imgs, n_valid))
        return self._fetch(ret_indexes)

    def q_retrieve(self, n_imgs=10, queries=None):
        """Return the stored samples LEAST similar, on average, to ``queries``.

        Similarity is the cosine similarity between stored queries and
        ``queries`` (a 2-D batch), averaged over the batch. Falls back to
        :meth:`random_retrieve` when the buffer is empty.

        Returns:
            tuple: ``(queries, keys, values, labels, features)``.
        """
        if self.n_added_so_far <= 0:
            return self.random_retrieve(n_imgs)

        n_valid = min(self.n_added_so_far, self.max_size)
        all_queries = self.buffer_imgs[:n_valid].to(self.device)
        normalize = torch.nn.functional.normalize
        sims = normalize(all_queries, dim=-1) @ normalize(queries.to(all_queries.device), dim=-1).T
        _, out_indexes = sims.mean(1).topk(min(n_imgs, n_valid), largest=False, sorted=True)
        return self._fetch(out_indexes.cpu())

    def only_retrieve(self, n_imgs, desired_labels):
        """Retrieve samples whose label belongs to ``desired_labels``.

        At most ``n_imgs`` distinct samples are drawn (numpy global RNG,
        without replacement).

        Args:
            n_imgs (int): maximum number of samples to retrieve.
            desired_labels (list | torch.Tensor): labels to retrieve from.

        Returns:
            tuple: ``(queries, keys, values, labels, features)``.
        """
        wanted = torch.as_tensor(
            desired_labels, dtype=torch.long, device=self.buffer_labels.device
        ).view(-1)
        valid_indexes = torch.isin(self.buffer_labels, wanted).nonzero().view(-1)
        n_out = min(n_imgs, len(valid_indexes))
        if n_out <= 0:
            out_indexes = valid_indexes[:0]
        else:
            picked = np.random.choice(len(valid_indexes), n_out, replace=False)
            out_indexes = valid_indexes[torch.from_numpy(np.asarray(picked)).long()]
        return self._fetch(out_indexes)