import torch
import random

class Net:
    """Fixed-size replay memory that keeps a class-balanced subset of a stream.

    While the memory has free slots, incoming samples are stored as-is. Once it
    is full, every new batch is merged with the memory and a greedy procedure
    repeatedly drops the sample whose removal leaves the most balanced label
    distribution (highest entropy of the per-class counts) until ``mem_size``
    samples remain.
    """

    def __init__(self, mem_size: int, shape: tuple, device=None):
        """
        Args:
            mem_size: number of samples the memory can hold; values <= 0
                disable the memory (``update`` becomes a no-op).
            shape: shape of a single input sample (e.g. an image).
            device: device for both memory tensors. Defaults to ``'cuda'``
                when available, otherwise ``'cpu'``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.mem_size = mem_size  # capacity of the memory
        self.img_size = shape  # shape of a single input sample

        capacity = max(self.mem_size, 0)
        # Inputs and labels live on the same device so that torch.cat in
        # compute_indices never mixes devices.
        self.mem_x = torch.zeros((capacity, *shape), device=device)  # stored inputs
        self.mem_y = torch.zeros(capacity, dtype=torch.long, device=device)  # stored labels

        self.mem_cursize = 0  # number of samples currently in memory
        self.n = 0  # number of samples the memory has been offered so far

    def update(self, x, y):
        """Offer a new batch ``(x, y)`` to the memory.

        Args:
            x: batch of inputs, shape ``(B, *shape)``.
            y: batch of labels, shape ``(B,)``.
        """
        # Ignore an invalid (non-positive) memory size.
        if not self.mem_size > 0:
            return
        nbatch = len(x)

        # Fill free slots first, but only with as many samples as fit, so a
        # batch that straddles the capacity boundary cannot index past the end.
        n_fill = min(self.mem_size - self.mem_cursize, nbatch)
        if n_fill > 0:
            start, stop = self.mem_cursize, self.mem_cursize + n_fill
            self.mem_x[start:stop] = x[:n_fill]
            self.mem_y[start:stop] = y[:n_fill]
            self.mem_cursize = stop

        # Samples that did not fit compete with the (now full) memory.
        if n_fill < nbatch:
            self.mem_x, self.mem_y = self.compute_indices(x[n_fill:], y[n_fill:])

        self.n += nbatch

    def compute_indices(self, x, y):
        """Merge ``(x, y)`` with the memory and greedily keep ``mem_size`` samples.

        Each step removes the candidate whose removal maximises the entropy of
        the remaining per-class label counts; ties are broken uniformly at
        random.

        Args:
            x: new inputs, shape ``(B, *shape)``.
            y: new labels, shape ``(B,)`` (class indices) or ``(B, C)``
                (e.g. multi-hot).

        Returns:
            ``(inputs, labels)`` of the kept samples, in candidate order.
        """
        device = self.mem_x.device
        candidate_x = torch.cat([self.mem_x, x.to(device)])
        candidate_y = torch.cat([self.mem_y, y.to(device)])
        # The greedy criterion is defined on labels, not on the inputs.
        label_matrix = self._label_matrix(candidate_y)
        remain_indices = torch.arange(len(candidate_x), device=device)

        # One removal per sample in excess of the capacity.
        for _ in range(len(candidate_x) - self.mem_size):
            labels = label_matrix[remain_indices]
            # Row i: class counts that would remain if sample i were removed.
            score = labels.sum(dim=0) - labels
            total = score.sum(dim=1, keepdim=True)
            total[total == 0] = 1  # all-zero rows: avoid 0/0 -> NaN
            prob = score / total
            # sum(p * log p) == -entropy; the (p == 0) term makes 0 * log 0 = 0.
            neg_entropy = (prob * torch.log((prob == 0).float() + prob)).sum(dim=1)
            # Lowest -entropy == highest entropy after removal.
            ties = (neg_entropy == neg_entropy.min()).nonzero().view(-1)
            pos = int(ties[torch.randperm(len(ties))[0]])
            remain_indices = torch.cat([remain_indices[:pos], remain_indices[pos + 1:]])

        return candidate_x[remain_indices], candidate_y[remain_indices]

    @staticmethod
    def _label_matrix(labels):
        """Return a float ``(N, num_classes)`` label matrix used for scoring.

        1-D integer class labels are one-hot encoded (extra all-zero class
        columns do not change the entropy); 2-D label tensors are used as-is.
        """
        if labels.dim() == 1:
            num_classes = int(labels.max()) + 1 if labels.numel() else 1
            labels = torch.nn.functional.one_hot(labels.long(), num_classes)
        return labels.float()

    def _onehot_to_slab(self, onehot):
        """Return the column index of every entry equal to 1 in ``onehot``."""
        return (onehot == 1).nonzero()[:, -1]

    def _multihot_to_idxlist(self, multihot):
        """Return the flattened indices of entries equal to 1, as a Python list."""
        idcs = (multihot == 1).nonzero().flatten().tolist()
        return idcs

    def __len__(self):
        """Number of samples currently stored."""
        return self.mem_cursize

    def sample(self, num):
        """Draw up to ``num`` distinct stored samples uniformly at random.

        Returns:
            ``(inputs, labels)``. Fewer than ``num`` samples are returned when
            the memory holds fewer than ``num``.
        """
        num = min(num, len(self))
        idx = torch.tensor(random.sample(range(len(self)), num), dtype=torch.long)
        sample_x, sample_y = self.mem_x[idx], self.mem_y[idx]
        return sample_x, sample_y


