from metrics import c_mAP, e_recall, nmi, f1, mAP, mAP_c
from metrics import dists, rho_spectrum, rbf
from metrics import lineval_roc
from metrics import c_recall, c_nmi, c_f1, c_mAP, c_mAP_c
import numpy as np
import faiss
import torch
from sklearn.preprocessing import normalize
import itertools
from tqdm import tqdm
import warnings
import re
import copy

class RegexMap(object):
    """Read-only mapping whose keys are regular expressions.

    ``m[name]`` returns the value of the first pattern (in insertion order)
    for which ``re.match(pattern, name)`` succeeds. ``re.match`` anchors only
    at the start of ``name``; a pattern must end in ``$`` to require a full match.
    """

    def __init__(self, *args, **kwargs):
        # Same constructor signature as ``dict``; insertion order defines match priority.
        self._items = dict(*args, **kwargs)

    def __getitem__(self, key):
        """Return the value of the first pattern that matches ``key``.

        Raises:
            KeyError: if no pattern matches ``key`` (the key is carried in the exception).
        """
        for pattern, value in self._items.items():
            if re.match(pattern, key):
                return value
        raise KeyError(key)

# Metric-name pattern -> metric module.
# RegexMap returns the module of the first pattern (in insertion order) that
# re.match-es a metric name. re.match anchors only at the start of the name,
# so patterns without a trailing "$" also accept longer names.
_METRIC_PATTERNS = {
    # Euclidean-distance based metrics.
    r"e_recall@\d+": e_recall,
    "nmi": nmi,
    "mAP_c": mAP_c,
    r"mAP(?:@\d+)?$": mAP,
    "f1": f1,
    # Cosine-similarity based metrics.
    r"c_recall@\d+": c_recall,
    "c_nmi": c_nmi,
    "c_mAP_c": c_mAP_c,
    r"c_mAP(?:@\d+)?$": c_mAP,
    "c_f1": c_f1,
    # Generic embedding-space metrics.
    r"rbf@\d+": rbf,
    r"dists@.*": dists,
    r"rho_spectrum@[+-]\d+": rho_spectrum,
    r"lineval_roc(?:\_.*)?$": lineval_roc,
}
METRIC_DICT = RegexMap(_METRIC_PATTERNS)

def select(metricname, opt, **kwargs):
    """Build the metric object described by ``metricname``.

    Name families are tested in a fixed order and the first hit wins. Where a
    family takes a parameter, it is read from the text after the last ``'@'``
    and converted to ``int`` (recall ``k``, mAP ``R``, rho_spectrum mode),
    ``float`` (rbf length scale) or kept as ``str`` (dists mode). ``opt`` is
    only consulted for ``rho_spectrum``. ``**kwargs`` is forwarded to every
    metric constructor.

    Raises:
        NotImplementedError: if ``metricname`` matches none of the families.
    """
    def at_suffix():
        # Everything after the last '@' (the whole name if there is none).
        return metricname.rpartition('@')[2]

    # Euclidean-distance based metrics.
    if 'e_recall' in metricname:
        k = int(at_suffix())
        return e_recall.Metric(k, **kwargs)
    if metricname == 'nmi':
        return nmi.Metric(**kwargs)
    if metricname == 'mAP_c':
        return mAP_c.Metric(**kwargs)
    if metricname.startswith("mAP"):
        R = int(at_suffix()) if '@' in metricname else None
        return mAP.Metric(R, **kwargs)
    if metricname == 'f1':
        return f1.Metric(**kwargs)

    # Cosine-similarity based metrics.
    if 'c_recall' in metricname:
        k = int(at_suffix())
        return c_recall.Metric(k, **kwargs)
    if metricname == 'c_nmi':
        return c_nmi.Metric(**kwargs)
    if metricname == 'c_mAP_c':
        return c_mAP_c.Metric(**kwargs)
    if 'c_mAP' in metricname:
        R = int(at_suffix()) if '@' in metricname else None
        return c_mAP.Metric(R, **kwargs)
    if metricname == 'c_f1':
        return c_f1.Metric(**kwargs)

    # Generic embedding-space metrics.
    if 'rbf' in metricname:
        length_scale = float(at_suffix())
        return rbf.Metric(length_scale, **kwargs)
    if 'dists' in metricname:
        dist_mode = at_suffix()
        return dists.Metric(dist_mode, **kwargs)
    if 'rho_spectrum' in metricname:
        mode = int(at_suffix())
        embed_dim = opt.rho_spectrum_embed_dim
        return rho_spectrum.Metric(embed_dim, mode=mode, opt=opt, **kwargs)
    if 'lineval_roc' in metricname:
        pieces = metricname.split('_')
        # "lineval_roc" -> 'mean'; "lineval_roc_<x>" -> '<x>'.
        variant = pieces[-1] if len(pieces) > 2 else 'mean'
        return lineval_roc.Metric(variant, **kwargs)

    raise NotImplementedError("Metric {} not available!".format(metricname))



class MetricComputer:
    """Compute a configurable set of evaluation metrics on model embeddings.

    Each entry of ``metric_names`` is turned into one or more metric objects via
    ``select``. When ``opt.exclusive`` is set, metrics whose ``METRIC_DICT``
    module declares ``HAS_OPTIONS`` are instantiated once per combination of
    their ``OPTIONS`` values. ``self.requires`` is the union of every metric's
    ``requires`` list. It decides which intermediate quantities
    ``compute_standard`` builds (k-means centroids, cluster assignments,
    nearest-neighbour labels, ...).
    """

    # Requirement name (as listed in ``metric.requires``) -> keyword argument
    # under which the corresponding quantity is passed to the metric.
    _INPUT_KEYS = {
        'features': 'features',
        'target_labels': 'target_labels',
        'kmeans': 'centroids',
        'kmeans_nearest': 'computed_cluster_labels',
        'nearest_features': 'k_closest_classes',
        'features_cosine': 'features_cosine',
        'kmeans_cosine': 'centroids_cosine',
        'kmeans_nearest_cosine': 'computed_cluster_labels_cosine',
        'nearest_features_cosine': 'k_closest_classes_cosine',
    }

    def __init__(self, metric_names, opt):
        """Instantiate all requested metrics.

        Args:
            metric_names: metric name strings understood by ``select``.
            opt: option namespace. ``exclusive`` is read here; the object is kept
                as ``self.pars`` and used again by ``compute_standard``.
        """
        self.pars = opt
        self.metric_names = metric_names
        if opt.exclusive:
            self.options = {name: self._option_grid(name) for name in metric_names}
        else:
            self.options = dict.fromkeys(metric_names, [{}])
        self.list_of_metrics = [
            select(name, opt, exclusive=opt.exclusive, **options)
            for name in metric_names
            for options in self.options[name]
        ]
        self.requires = list({req for metric in self.list_of_metrics for req in metric.requires})

    @staticmethod
    def _option_grid(metric_name):
        """Return one kwargs dict per combination of the module's OPTIONS values (``[{}]`` without options)."""
        module = METRIC_DICT[metric_name]
        if not module.HAS_OPTIONS:
            return [{}]
        keys = list(module.OPTIONS.keys())
        return [dict(zip(keys, combo)) for combo in itertools.product(*module.OPTIONS.values())]

    def compute_standard(self, opt, model, dataloader, evaltypes, device, **kwargs):
        """Embed ``dataloader`` with ``model`` and evaluate every metric per evaltype.

        Args:
            opt: option namespace. ``n_classes`` (upper bound on the number of
                k-means clusters) and ``exclusive`` are read.
            model: called as ``model(**batch)`` on each batch moved to ``device``.
                It may return a tensor, a dict keyed by evaltype, or a 2-tuple
                whose first element is one of those.
            dataloader: yields mapping batches containing a ``'labels'`` tensor.
                ``dataloader.dataset.image_paths`` is copied into ``extra_infos``.
            evaltypes: names of the embeddings to evaluate (keys of the model's
                output dict when it returns one).
            device: device every batch entry is moved to before the forward pass.
            **kwargs: unused; accepted for call-site compatibility.

        Returns:
            ``(computed_metrics, extra_infos)``, both dicts keyed by evaltype.
            ``computed_metrics[evaltype][metric.name]`` is the metric's result.
            ``extra_infos[evaltype]`` holds the features, labels and image paths.
        """
        evaltypes = copy.deepcopy(evaltypes)
        n_classes = opt.n_classes

        feature_colls, target_labels = self._embed(model, dataloader, evaltypes, device)

        computed_metrics = {evaltype: {} for evaltype in evaltypes}
        extra_infos = {evaltype: {} for evaltype in evaltypes}

        faiss.omp_set_num_threads(self.pars.kernels)
        torch.cuda.empty_cache()
        res = faiss.StandardGpuResources() if self.pars.evaluate_on_gpu else None

        # Never request more clusters than distinct classes (exclusive labels)
        # or label columns (multi-label targets).
        if opt.exclusive:
            n_classes = min(n_classes, len(np.unique(target_labels)))
        else:
            n_classes = min(n_classes, target_labels.shape[1])

        for evaltype in evaltypes:
            features = np.vstack(feature_colls[evaltype]).astype('float32')
            features_cosine = normalize(features, axis=1)

            # faiss work is done on the numpy arrays, before any torch conversion.
            available = self._derived_quantities(
                features, features_cosine, target_labels, n_classes, res, opt.exclusive)

            if self.pars.evaluate_on_gpu:
                features = torch.from_numpy(features).to(self.pars.device)
                features_cosine = torch.from_numpy(features_cosine).to(self.pars.device)
            available['features'] = features
            available['features_cosine'] = features_cosine
            available['target_labels'] = target_labels

            for metric in self.list_of_metrics:
                input_dict = {key: available[req]
                              for req, key in self._INPUT_KEYS.items()
                              if req in metric.requires}
                computed_metrics[evaltype][metric.name] = metric(**input_dict)

            extra_infos[evaltype] = {
                'features': features,
                'target_labels': target_labels,
                'image_paths': dataloader.dataset.image_paths,
                'query_image_paths': None,
                'gallery_image_paths': None,
            }

        torch.cuda.empty_cache()
        return computed_metrics, extra_infos

    @staticmethod
    def _embed(model, dataloader, evaltypes, device):
        """Run ``model`` over ``dataloader``; return per-evaltype feature rows and the vstacked labels."""
        _ = model.eval()
        feature_colls = {evaltype: [] for evaltype in evaltypes}
        target_labels = []
        with torch.no_grad():
            for inp in tqdm(dataloader, desc='Embedding Data...'):
                target_labels.extend(inp['labels'].numpy().tolist())
                out = model(**{key: inp[key].to(device) for key in inp})
                if isinstance(out, tuple):
                    out, _ = out  # the second element is not used for evaluation
                for evaltype in evaltypes:
                    emb = out[evaltype] if isinstance(out, dict) else out
                    feature_colls[evaltype].extend(emb.cpu().detach().numpy().tolist())
            target_labels = np.vstack(target_labels)
        return feature_colls, target_labels

    def _derived_quantities(self, features, features_cosine, target_labels, n_classes, res, exclusive):
        """Compute the k-means / nearest-neighbour inputs listed in ``self.requires``.

        Returns a dict keyed by requirement name (see ``_INPUT_KEYS``). Only the
        requested entries are present.
        """
        requires = self.requires
        out = {}

        # k-means on the raw (euclidean) features.
        if 'kmeans' in requires or 'kmeans_nearest' in requires:
            out['kmeans'] = self._kmeans_centroids(features, n_classes, res)

        # k-means on the L2-normalised features. Normalise *these* centroids
        # (not the euclidean ones) back onto the unit sphere.
        if 'kmeans_cosine' in requires or 'kmeans_nearest_cosine' in requires:
            centroids_cosine = self._kmeans_centroids(features_cosine, n_classes, res)
            out['kmeans_cosine'] = normalize(centroids_cosine, axis=1)

        # Cluster assignment: id of the closest centroid for every sample.
        if 'kmeans_nearest' in requires:
            out['kmeans_nearest'] = self._search(faiss.IndexFlatL2, out['kmeans'], features, 1, res)
        if 'kmeans_nearest_cosine' in requires:
            out['kmeans_nearest_cosine'] = self._search(
                faiss.IndexFlatIP, out['kmeans_cosine'], features_cosine, 1, res)

        # Labels of the k nearest neighbours (k = largest recall@k, +1 for the self-match).
        if 'nearest_features' in requires:
            k = self._max_recall_k(self.metric_names) + 1
            ids = self._search(faiss.IndexFlatL2, features, features, k, res)
            out['nearest_features'] = self._neighbour_classes(target_labels, ids, exclusive)
        if 'nearest_features_cosine' in requires:
            # features_cosine is already row-normalised; no second normalisation needed.
            k = self._max_recall_k(self.metric_names) + 1
            ids = self._search(faiss.IndexFlatIP, features_cosine, features_cosine, k, res)
            out['nearest_features_cosine'] = self._neighbour_classes(target_labels, ids, exclusive)

        return out

    @staticmethod
    def _kmeans_centroids(data, n_classes, res):
        """Train faiss k-means (20 iterations) on ``data``; return centroids reshaped to (n_classes, dim)."""
        dim = data.shape[-1]
        cluster_idx = faiss.IndexFlatL2(dim)
        if res is not None:
            cluster_idx = faiss.index_cpu_to_gpu(res, 0, cluster_idx)
        kmeans = faiss.Clustering(dim, n_classes)
        kmeans.niter = 20
        kmeans.min_points_per_centroid = 1
        kmeans.max_points_per_centroid = 1000000000
        kmeans.train(data, cluster_idx)
        return faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, dim)

    @staticmethod
    def _search(index_cls, database, queries, k, res):
        """Index ``database`` in a flat faiss index (GPU 0 when ``res`` is set); return top-``k`` ids per query row."""
        index = index_cls(database.shape[-1])
        if res is not None:
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.add(database)
        _, ids = index.search(queries, k)
        return ids

    @staticmethod
    def _max_recall_k(metric_names):
        """Return the largest ``k`` among the ``*recall@k`` metric names.

        Raises:
            ValueError: if no metric name contains ``'recall'``.
        """
        ks = [int(name.split('@')[-1]) for name in metric_names if 'recall' in name]
        if not ks:
            raise ValueError("Nearest-neighbour features were requested, "
                             "but no 'recall@k' metric defines k.")
        return max(ks)

    @staticmethod
    def _neighbour_classes(target_labels, neighbours, exclusive):
        """Look up the labels of retrieved neighbours, dropping the first column.

        The first column is dropped because, for a self-search, the first hit is
        normally the query point itself. With exclusive labels, the trailing
        singleton label axis is squeezed away.
        """
        classes = target_labels[neighbours[:, 1:]]
        if exclusive:
            classes = np.squeeze(classes, axis=-1)
        return classes