import torch
class Memory_Based_Certification(object):
    """Memory of previously certified inputs used to keep certificates consistent.

    Each call to :meth:`adjust_radius` records an input together with its
    radius and prediction. Before a new input is recorded, its radius (and
    possibly its prediction) is adjusted against the stored entries so that
    its Euclidean ball does not overlap the ball of a stored input that
    carries a different prediction.
    """

    def __init__(self):
        self.saved_radii = []        # radius recorded for each stored input
        self.saved_images = []       # stored input tensors
        self.saved_predictions = []  # prediction recorded for each stored input

    def _internal_adjustment(self, img, rad, pre):
        """Resolve conflicts between ``(img, rad, pre)`` and the stored memory.

        Only stored entries whose ball overlaps the new one
        (``distance < saved_radius + rad``) AND whose prediction differs from
        ``pre`` are considered. Among them, the entry with the smallest margin
        ``distance - saved_radius`` decides the outcome:

        * margin >= 0: ``img`` lies outside that entry's ball; the radius is
          shrunk to the margin so the two balls no longer overlap.
        * margin < 0: ``img`` lies inside that entry's ball; the prediction is
          taken from the stored entry and the radius becomes ``-margin``, the
          largest ball around ``img`` contained in the stored ball.

        Args:
            img: tensor to certify; flattened before Euclidean distances to the
                stored inputs are computed (must match their element count).
            rad: proposed radius (Python number or single-element tensor).
            pre: proposed prediction.

        Returns:
            ``(rad, pre)`` unchanged when there is no conflict; otherwise the
            adjusted radius as a Python float and the resulting prediction.
        """
        flat = img.reshape(1, -1)
        memory = torch.stack(self.saved_images).reshape(len(self.saved_images), -1)
        dist = torch.norm(flat - memory, dim=1)

        # Build per-entry tensors on the same device/dtype as the distances so
        # the element-wise comparisons below never mix CPU and GPU tensors.
        radii = torch.tensor(
            [float(r) for r in self.saved_radii], dtype=dist.dtype, device=dist.device
        )
        # bool(...) handles both plain ints and 0-d tensor predictions.
        differs = torch.tensor(
            [bool(p != pre) for p in self.saved_predictions],
            dtype=torch.bool,
            device=dist.device,
        )

        conflicting = (dist < radii + float(rad)) & differs
        if not conflicting.any():
            return rad, pre

        candidates = torch.nonzero(conflicting).flatten()
        margins = (dist - radii)[candidates]
        best = int(torch.argmin(margins))
        margin = margins[best].item()
        if margin < 0:
            # img sits inside a ball certified for another prediction: adopt it.
            pre = self.saved_predictions[int(candidates[best])]
        return abs(margin), pre

    def adjust_radius(self, img, rad, pre):
        """Return a certificate for ``img`` consistent with the memory, and store it.

        If a tensor equal to ``img`` (same shape and values) was stored before,
        its recorded radius and prediction are returned and nothing is added.
        Otherwise ``(rad, pre)`` is adjusted against the memory (when it is
        non-empty), appended to it, and returned.

        Args:
            img: input tensor being certified.
            rad: proposed radius.
            pre: proposed prediction.

        Returns:
            The ``(radius, prediction)`` pair recorded for ``img``.
        """
        for idx, saved in enumerate(self.saved_images):
            if torch.equal(saved, img):
                return self.saved_radii[idx], self.saved_predictions[idx]

        if self.saved_images:
            rad, pre = self._internal_adjustment(img, rad, pre)

        self.saved_radii.append(rad)
        self.saved_images.append(img)
        self.saved_predictions.append(pre)
        return rad, pre
        