import numpy as np
import torch
import torch.nn as nn
# from utils_nps import *

class NPSCalculator(nn.Module):
    """NPSCalculator: calculates the non-printability score of a patch.

    Module providing the functionality necessary to calculate the
    non-printability score (NPS) of an adversarial patch: for every pixel the
    Euclidean distance to the closest colour of a printable-colour palette
    (loaded from a text file) is taken, and these distances are summed and
    normalised by the number of elements of the patch tensor.

    """

    def __init__(self, printability_file, patch_size):
        super(NPSCalculator, self).__init__()
        # Frozen Parameter (requires_grad=False): registered on the module so it
        # moves with .to()/.cuda() and is saved in state_dict, but is never
        # updated by gradient descent.
        self.printability_array = nn.Parameter(
            self.get_printability_array(printability_file, patch_size),
            requires_grad=False,
        )

    def forward(self, adv_patch):
        """Return the non-printability score of ``adv_patch``.

        For each ``adv_patch[i]`` the Euclidean distance (summed over dim 1 of
        the palette, i.e. its three colour channels) to every palette colour is
        computed, the minimum over colours is kept per pixel, and those minima
        are summed. The total over all ``i`` is divided by
        ``adv_patch.numel()``.

        Args:
            adv_patch: tensor iterated along dim 0; each ``adv_patch[i]`` must
                broadcast against ``printability_array`` of shape
                (n_colors, 3, H, W). Presumably a batch of RGB patches shaped
                (B, 3, H, W) — TODO confirm with callers.

        Returns:
            A tensor (0-d for a (B, 3, H, W) input) that is differentiable with
            respect to ``adv_patch``.
        """
        nps = 0
        for patch in adv_patch:
            # The 1e-6 added under the square root keeps sqrt (and its
            # gradient 1 / (2 * sqrt(x))) finite when a pixel exactly matches a
            # palette colour; the offset on the difference is kept as in the
            # original formulation.
            diff = patch - self.printability_array + 0.000001
            dist = torch.sqrt(torch.sum(diff ** 2, 1) + 0.000001)
            # Only the distance to the closest printable colour counts.
            closest = torch.min(dist, 0)[0]
            # Two reductions over dim 0 collapse the per-pixel (H, W) map.
            nps = nps + closest.sum(0).sum(0)

        return nps / torch.numel(adv_patch)

    def get_printability_array(self, printability_file, size):
        """Load the printable-colour palette and tile it to the patch size.

        Args:
            printability_file: path to a text file with one comma-separated
                ``red,green,blue`` triplet per line. Blank lines are ignored.
            size: sequence whose ``size[0]`` and ``size[1]`` give the height
                and width of the tiled array.

        Returns:
            float32 tensor of shape (n_colors, 3, size[0], size[1]) where
            every element of ``pa[k, c]`` equals channel ``c`` of colour ``k``.

        Raises:
            ValueError: if a non-blank line does not hold exactly three numeric
                fields, or if the file contains no colours at all.
        """
        colors = []
        with open(printability_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Parse to float before tiling: filling full-size arrays with
                # the raw string tokens costs far more memory than float32.
                red, green, blue = (float(v) for v in line.split(","))
                colors.append((red, green, blue))

        if not colors:
            raise ValueError(
                f"no colour triplets found in {printability_file!r}"
            )

        rgb = np.asarray(colors, dtype=np.float32)[:, :, None, None]
        # broadcast_to returns a read-only, zero-stride view; copy() makes a
        # writable C-contiguous array that torch.from_numpy can share safely.
        tiled = np.broadcast_to(rgb, (len(colors), 3, size[0], size[1])).copy()
        pa = torch.from_numpy(tiled)

        return pa
