import torch
import numpy as np
import faiss

"""================================================================================================="""
# Module-level option hooks. Nothing in this file reads them; presumably an external
# metric loader checks HAS_OPTIONS before consulting OPTIONS -- verify against the caller.
HAS_OPTIONS, OPTIONS = False, None

class Metric():
    """Cosine mean average precision (``c_mAP``), optionally truncated at rank ``R``.

    Every sample queries an inner-product faiss index built over
    ``features_cosine``. Its top neighbours are scored with an
    average-precision formula once per label the sample carries; the first hit
    is assumed to be the query itself and is dropped. The metric is the mean of
    those per-(sample, label) scores.

    Args:
        R: Number of neighbours scored per query. When falsy (``None``/0) it
            defaults to 1023. In every case it is capped at ``n_samples - 1``,
            the most neighbours an index of ``n_samples`` vectors can return
            once the query itself is excluded.
        exclusive (keyword, default ``True``): ``True`` for single-label data
            (one class id per sample). ``False`` for a multi-hot
            ``(n_samples, n_classes)`` label matrix, where each sample is scored
            once per class it belongs to.
    """

    def __init__(self, R=None, **kwargs):
        self.R = R
        self.exclusive = kwargs.get("exclusive", True)
        # Inputs the caller is expected to pass to __call__.
        self.requires = ['features_cosine', 'target_labels']
        self.name = 'c_mAP@{}'.format(self.R) if self.R else 'c_mAP'
        if not self.exclusive:
            self.name += "(multi-label)"

    def __call__(self, target_labels, features_cosine):
        """Compute the metric.

        Args:
            target_labels: Exclusive mode: class ids, flattened before use.
                Multi-label mode: a ``(n_samples, n_classes)`` multi-hot matrix.
                A NumPy array, array-like, or torch tensor.
            features_cosine: ``(n_samples, dim)`` embeddings compared by inner
                product. Presumably L2-normalised so this equals cosine
                similarity; this is not checked here. A torch tensor moves the
                search onto GPU 0.

        Returns:
            Mean of the per-(sample, label) average precisions (``nan`` if no
            sample carries any label).

        Raises:
            ValueError: If fewer than two samples are given (no neighbours exist).
        """
        if isinstance(target_labels, torch.Tensor):
            target_labels = target_labels.detach().cpu().numpy()
        target_labels = np.asarray(target_labels)

        n_samples = len(features_cosine)
        if n_samples < 2:
            raise ValueError(
                "c_mAP needs at least two samples, got {}".format(n_samples))

        if self.exclusive:
            target_labels = target_labels.reshape(-1)
            labels, freqs = np.unique(target_labels, return_counts=True)
        else:
            # Class ids index the COLUMNS of the multi-hot matrix.
            labels = np.arange(target_labels.shape[1])
            freqs = np.sum(target_labels, axis=0)

        # R <= 1023 keeps k = R + 1 within the GPU k-limit of older faiss releases.
        # R <= n_samples - 1 stops faiss from padding results with index -1, which
        # would otherwise be read as target_labels[-1].
        R = int(min(self.R or 1023, n_samples - 1))

        faiss_search_index = faiss.IndexFlatIP(features_cosine.shape[-1])
        if isinstance(features_cosine, torch.Tensor):
            features_cosine = features_cosine.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        # faiss only accepts C-contiguous float32 arrays.
        features_cosine = np.ascontiguousarray(features_cosine, dtype=np.float32)
        faiss_search_index.add(features_cosine)
        # Column 0 is assumed to be the query itself (maximal self-similarity); drop it.
        nearest_neighbours = faiss_search_index.search(features_cosine, R + 1)[1][:, 1:]

        nn_labels = target_labels[nearest_neighbours, ...]
        # Rank positions 1..R: the precision denominators at each cut-off.
        ranks = np.arange(1, R + 1)

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            if self.exclusive:
                rows_with_label = np.where(target_labels == label)[0]
                hits = nn_labels[rows_with_label] == label          # (rows, R)
            else:
                rows_with_label = np.where(target_labels[:, label])[0]
                hits = nn_labels[rows_with_label, :, label]          # (rows, R)

            # Precision at every rank where a same-label neighbour appears.
            precision_at_hits = np.cumsum(hits, axis=1) * hits / ranks
            # NOTE: normalised by freq, which includes the query itself, so a
            # perfect ranking scores (freq - 1) / freq. Kept as the metric's definition.
            avg_r_precisions.extend(np.sum(precision_at_hits, axis=1) / freq)

        return np.mean(avg_r_precisions)
