import torch
import numpy as np
import faiss

# ==================================================================================================
# Module-level option flags for this metric: it declares no configurable options.
# NOTE(review): presumably read by whatever code loads metric modules — confirm against the loader.
OPTIONS = None
HAS_OPTIONS = False

class Metric():
    """Class-level mean average precision at R (``mAP_c``) for an embedding space.

    Every sample is used as a query against all samples. Neighbours are ranked
    by inner product (``faiss.IndexFlatIP``) and the first returned hit is
    discarded, on the assumption that it is the query itself. For a query in a
    class with ``freq`` members, the first ``freq`` remaining neighbours are
    scored: the precision at each position holding a same-class sample is
    summed and divided by ``freq``. The metric is the mean of these per-query
    scores.

    Label layouts:

    * ``exclusive=True`` (default): one class id per sample, shape ``(N,)`` or
      ``(N, 1)``.
    * ``exclusive=False``: a multi-hot matrix of shape ``(N, C)``; a sample is
      scored once for every class column that is non-zero in its row.

    NOTE(review): ``freq`` counts the query itself while the query is excluded
    from its own neighbour list, so even perfect retrieval scores
    ``(freq - 1) / freq`` per query rather than 1. Kept as-is so reported
    numbers stay comparable; confirm this is intended.
    """

    def __init__(self, **kwargs):
        """Configure the metric.

        Keyword Args:
            exclusive: ``True`` (default) for single-label class ids, ``False``
                for multi-hot label matrices. Other keyword arguments are
                accepted and ignored.
        """
        self.requires = ['features', 'target_labels']
        self.exclusive = kwargs.get("exclusive", True)
        self.name = 'mAP_c'

        if not self.exclusive:
            self.name += "(multi-label)"

    def __call__(self, target_labels, features):
        """Compute mAP_c.

        Args:
            target_labels: class ids ``(N,)``/``(N, 1)`` or a multi-hot
                ``(N, C)`` matrix (depending on ``self.exclusive``), as a numpy
                array or torch tensor.
            features: ``(N, D)`` embeddings as a numpy array or torch tensor.
                Torch tensors are searched with a faiss GPU index on device 0.

        Returns:
            The mean per-query score (NaN, with a numpy warning, if no query
            could be scored).
        """
        target_labels = self._to_numpy(target_labels)
        if self.exclusive:
            target_labels = target_labels.reshape(-1)
        labels, freqs = self._label_frequencies(target_labels)

        # One extra neighbour because the first hit (the query) is dropped. Never
        # request more than N: faiss pads missing results with index -1, which
        # numpy indexing would silently resolve to the last sample.
        k = min(int(freqs.max()) + 1, features.shape[0])
        nearest_neighbours = self._search_neighbours(features, k)
        return self._mean_r_precision(target_labels, nearest_neighbours, labels, freqs)

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a numpy array, moving torch tensors to host memory."""
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x)

    def _label_frequencies(self, target_labels):
        """Return ``(labels, freqs)``: the classes to score and their sample counts."""
        if self.exclusive:
            labels, freqs = np.unique(target_labels, return_counts=True)
        else:
            # One entry per class column (axis 1), not per sample (axis 0).
            labels = np.arange(target_labels.shape[1])
            # Integer counts so float multi-hot matrices still yield valid slice bounds.
            freqs = np.sum(target_labels, axis=0).astype(np.int64)
        return labels, freqs

    @staticmethod
    def _search_neighbours(features, k):
        """Return each sample's top-``k`` inner-product neighbour indices with the
        first hit dropped, i.e. an ``(N, k - 1)`` integer array."""
        faiss_search_index = faiss.IndexFlatIP(features.shape[-1])
        if isinstance(features, torch.Tensor):
            features = features.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        # faiss' Python bindings work on C-contiguous float32 arrays.
        features = np.ascontiguousarray(features, dtype=np.float32)
        faiss_search_index.add(features)
        return faiss_search_index.search(features, k)[1][:, 1:]

    def _mean_r_precision(self, target_labels, nearest_neighbours, labels, freqs):
        """Average the per-query scores given precomputed neighbours.

        Args:
            target_labels: numpy labels, ``(N,)`` if exclusive else ``(N, C)``.
            nearest_neighbours: ``(N, K)`` neighbour indices per query, best
                match first, query itself already removed.
            labels: classes to score.
            freqs: number of samples in each class of ``labels``.

        Returns:
            Mean over all scored queries of the precision-at-hit sum over the
            first ``freq`` neighbours, divided by ``freq``.
        """
        nn_labels = target_labels[nearest_neighbours, ...]  # (N, K) or (N, K, C)

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            if freq == 0:
                continue  # class with no members: no queries to score
            if self.exclusive:
                rows = np.flatnonzero(target_labels == label)
                hits = nn_labels[rows, :freq] == label
            else:
                rows = np.flatnonzero(target_labels[:, label])
                hits = nn_labels[rows, :freq, label]
            # Fewer than `freq` columns only when fewer neighbours exist; the
            # missing positions count as misses.
            ranks = np.arange(1, hits.shape[1] + 1)
            precision_at_hits = np.cumsum(hits, axis=1) * hits / ranks
            avg_r_precisions.extend(precision_at_hits.sum(axis=1) / freq)

        return np.mean(avg_r_precisions)

