import torch


def min_max_normalization(x):
    """Rescale a tensor to the [0, 1] range using its global min and max.

    Args:
        x (torch.Tensor): Values to normalize. The minimum and maximum are
            taken over all elements, not per dimension.

    Returns:
        torch.Tensor: Tensor shaped like ``x`` holding
        ``(x - min) / (max - min)`` clamped to [0, 1]. When every element of
        ``x`` is equal, zeros are returned instead of NaN.
    """
    x_min = torch.min(x)
    x_max = torch.max(x)
    span = x_max - x_min
    # Constant input gives a zero span and 0 / 0 = NaN. Dividing by 1 instead
    # yields zeros, because the numerator is zero everywhere in that case.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    norm = (x - x_min) / safe_span
    return torch.clamp(norm, 0, 1)


def mle_batch_pt(data, batch, k=20):
    """Estimate the local intrinsic dimensionality (LID) of each row of ``data``.

    Both inputs are flattened to 2-D (one row per sample). For every row of
    ``data`` the ``k + 1`` smallest Euclidean distances to the rows of
    ``batch`` are taken and the smallest one is discarded (when ``data`` and
    ``batch`` are the same set this is the distance of a sample to itself).
    The estimate is ``-k / sum(log(d_i / d_k))`` over the remaining ``k``
    distances, where ``d_k`` is the largest of them.

    Args:
        data (torch.Tensor): Query samples; first dimension indexes samples.
        batch (torch.Tensor): Reference samples; first dimension indexes
            samples. Must flatten to the same feature size as ``data``.
        k (int): Number of neighbours to use. Capped at
            ``batch.shape[0] - 1``.

    Returns:
        torch.Tensor: 1-D tensor with one LID estimate per row of ``data``.

    Raises:
        ValueError: If the effective ``k`` is smaller than 1, i.e. ``batch``
            has fewer than two samples or ``k`` itself is not positive.
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    data = data.reshape(data.shape[0], -1)
    batch = batch.reshape(batch.shape[0], -1)
    dists = torch.cdist(data, batch, p=2)
    k = min(k, batch.shape[0] - 1)
    if k < 1:
        raise ValueError(
            "mle_batch_pt needs k >= 1 and at least 2 samples in `batch` "
            f"(effective k={k}, batch size={batch.shape[0]})"
        )
    # One batched top-k over all rows; column 0 (the smallest distance) is
    # dropped, leaving k distances per row sorted in ascending order.
    nearest = torch.topk(dists, k=k + 1, dim=1, largest=False)[0][:, 1:]
    # Ratio of each distance to the k-th (largest kept) distance of its row.
    log_ratios = torch.log(nearest / nearest[:, -1:])
    return -k / torch.sum(log_ratios, dim=1)


class LIDAnalysis:
    """Score samples by their local intrinsic dimensionality (LID).

    Features are processed in chunks of ``batch_size`` samples; each chunk is
    used as its own reference set for the LID estimate.
    """

    # mle_batch_pt cannot produce a finite estimate from fewer than 3 samples:
    # it needs at least two neighbours besides the sample itself.
    _MIN_CHUNK = 3

    def __init__(self, batch_size=200, k=128):
        """Store the chunk size and the neighbour count.

        Args:
            batch_size (int): Number of samples per chunk.
            k (int): Number of neighbours passed to ``mle_batch_pt``.
        """
        self.batch_size = batch_size
        self.k = k

    def analysis(self, features):
        """Compute a normalized, inverted LID score for every sample.

        Args:
            features (torch.Tensor): Features extracted from the model; the
                first dimension indexes samples.

        Returns:
            torch.Tensor: 1-D tensor with one score per sample, equal to
            ``1 - min_max_normalization(lid)``. The sample with the lowest
            LID gets the highest score.
        """
        total = features.shape[0]
        starts = list(range(0, total, self.batch_size))
        # A tiny trailing chunk would crash mle_batch_pt (1 sample) or yield
        # -inf (2 samples), which turns every normalized score into NaN.
        # Fold it into the previous chunk instead.
        if len(starts) > 1 and total - starts[-1] < self._MIN_CHUNK:
            starts.pop()
        # Each chunk runs to the next start; the last one runs to the end.
        ends = starts[1:] + [total]

        lids = []
        for start, end in zip(starts, ends):
            chunk = features[start:end]
            lids.append(mle_batch_pt(chunk, chunk, k=self.k))

        lids = torch.cat(lids, dim=0)
        return 1 - min_max_normalization(lids)
