# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time

import faiss
import numpy as np
from scipy.sparse import csr_matrix, find
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import logging 
import pdb
import torch.nn.functional as F

import matplotlib.pyplot as plt

class Kmeans(object):
    """k-means clustering with faiss on GPU 0, returning results as CUDA tensors.

    Args:
        nmb_clusters (int): number of centroids to fit.
        min_points_per_cluster (int): stored on the instance; not currently
            forwarded to faiss by ``run_kmeans``.
        max_points_per_cluster (int): forwarded to faiss as
            ``Clustering.max_points_per_centroid`` (faiss uses it to cap the
            number of training points per centroid).
    """

    def __init__(self, nmb_clusters, min_points_per_cluster = 0, max_points_per_cluster = 10000000):
        self.nmb_clusters = nmb_clusters
        self.min_points_per_cluster = min_points_per_cluster
        self.max_points_per_cluster = max_points_per_cluster

    def cluster(self, data, remove_outliers = False, outlier_threshold = 3):
        """Performs k-means clustering on row-wise L2-normalized features.

        Args:
            data (torch.Tensor N * dim): features to cluster. Each row is
                L2-normalized with ``F.normalize(data, dim=1)`` first.
            remove_outliers (bool): forwarded to ``run_kmeans``.
            outlier_threshold (float): forwarded to ``run_kmeans``.
        Returns:
            dict: the result of ``run_kmeans``.
        """
        # PCA-whitening (preprocess_features) is not applied here; only
        # row-wise L2 normalization is.
        data_norm = F.normalize(data, dim = 1)
        return self.run_kmeans(data_norm, remove_outliers, outlier_threshold)

    def preprocess_features(self, npdata, pca=256):
        """Preprocess an array of features.

        Args:
            npdata (np.array N * ndim): features to preprocess.
            pca (int): dim of output.
        Returns:
            np.array of dim N * pca: data PCA-reduced, whitened and
            L2-normalized. Rows that are all-zero after PCA stay zero
            instead of becoming NaN.
        """
        _, ndim = npdata.shape
        npdata = npdata.astype('float32')

        # Apply PCA-whitening with Faiss
        mat = faiss.PCAMatrix(ndim, pca, eigen_power=-0.5)
        mat.train(npdata)
        assert mat.is_trained
        npdata = mat.apply_py(npdata)

        # L2 normalization; guard zero-norm rows against 0/0 -> NaN.
        row_sums = np.linalg.norm(npdata, axis=1)
        row_sums[row_sums == 0] = 1
        npdata = npdata / row_sums[:, np.newaxis]

        return npdata

    @staticmethod
    def _as_faiss_input(x):
        """Convert ``x`` (torch tensor or array-like) to a C-contiguous float32 ndarray.

        Works for CUDA tensors and tensors that require grad, which a plain
        ``x.numpy()`` call rejects.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return np.ascontiguousarray(x, dtype=np.float32)

    @staticmethod
    def _confidently_assigned(sq_dists, outlier_threshold):
        """Return indices of points whose nearest centroid clearly beats the runner-up.

        Args:
            sq_dists (np.array N * K): squared distance of each point to each
                centroid.
            outlier_threshold (float): required relative margin.
        Returns:
            np.array of int indices i for which
            ``second_nearest[i] > (1 + outlier_threshold) * nearest[i]``.
            For K == 2 this equals the ratio test
            ``max(d0 / d1, d1 / d0) > 1 + outlier_threshold``, but it avoids
            dividing by zero. When K < 2 there is no runner-up, so every
            index is returned.
        """
        n_data, n_clusters = sq_dists.shape
        if n_clusters < 2:
            return np.arange(n_data)
        # After partitioning at kth=1, column 0 holds the smallest value and
        # column 1 the second smallest.
        two_smallest = np.partition(sq_dists, 1, axis=1)
        nearest, second = two_smallest[:, 0], two_smallest[:, 1]
        return np.flatnonzero(second > (1 + outlier_threshold) * nearest)

    def run_kmeans(self, x, remove_outliers = False, outlier_threshold = 3):
        """Runs faiss k-means on GPU 0.

        Args:
            x (torch.Tensor or np.array, N * d): data to cluster.
            remove_outliers (bool): if True, ``select_idx`` keeps only points
                whose squared distance to the second-nearest centroid exceeds
                ``(1 + outlier_threshold)`` times the squared distance to the
                nearest centroid. If False, it covers every point.
            outlier_threshold (float): margin used when ``remove_outliers``.
        Returns:
            dict of tensors on cuda:0:
                'im2cluster' (N,) int64: id of the nearest centroid per point.
                'centroids' (nmb_clusters, d) float32: the fitted centroids.
                'd2cluster' (N, 1): distance returned by the 1-NN index search.
                'select_idx' (M,) int64: indices of retained points.
        """
        n_data, d = x.shape
        x = self._as_faiss_input(x)
        torch.cuda.empty_cache()

        # faiss implementation of k-means
        clus = faiss.Clustering(d, self.nmb_clusters)
        # Fixed seed: the initial centroids are chosen identically on every call.
        clus.seed = 0
        clus.niter = 20
        clus.verbose = True
        # Honour the constructor argument instead of a hard-coded constant.
        clus.max_points_per_centroid = self.max_points_per_cluster

        res = faiss.StandardGpuResources()
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.useFloat16 = False
        flat_config.device = 0
        index = faiss.GpuIndexFlatL2(res, d, flat_config)

        # perform the training; afterwards the index is searched against the
        # learned centroids to obtain each point's assignment.
        clus.train(x, index)
        D, I = index.search(x, 1)
        # Computed once rather than once per cluster inside the loop below.
        centroids = faiss.vector_to_array(clus.centroids).reshape(self.nmb_clusters, d)

        if remove_outliers:
            sq_dists = np.empty((n_data, self.nmb_clusters))
            for clus_id, centroid in enumerate(centroids):
                sq_dists[:, clus_id] = np.linalg.norm(x - centroid, axis = 1) ** 2
            select_idx = self._confidently_assigned(sq_dists, outlier_threshold)
        else:
            select_idx = np.arange(n_data)

        cluster_result = {
            'im2cluster': torch.from_numpy(I[:, 0].astype(np.int64)).cuda(0),
            'centroids': torch.Tensor(centroids).cuda(0),
            'd2cluster': torch.Tensor(D).cuda(0),
            'select_idx': torch.from_numpy(np.asarray(select_idx, dtype=np.int64)).cuda(0),
        }
        return cluster_result

    def arrange_clustering(self, images_lists):
        """Convert per-cluster member lists into one label per image.

        Args:
            images_lists (list of lists): ``images_lists[c]`` holds the image
                indexes assigned to cluster ``c``.
        Returns:
            np.array: ``labels[i]`` is the cluster of image ``i``, ordered by
            ascending image index.
        """
        pseudolabels = []
        image_indexes = []
        for cluster, images in enumerate(images_lists):
            image_indexes.extend(images)
            pseudolabels.extend([cluster] * len(images))
        indexes = np.argsort(image_indexes)
        return np.asarray(pseudolabels)[indexes]

    
class ClusterDataset(data.Dataset):
    """In-memory subset of ``base_dataset`` restricted to ``image_indexes``.

    Every selected item is read from ``base_dataset`` once, at construction
    time. Items are then served in the order given by ``image_indexes``.
    """

    def __init__(self, image_indexes, base_dataset):
        self.imgs = self.make_dataset(image_indexes, base_dataset)

    def make_dataset(self, image_indexes, base_dataset):
        """Return a list holding ``base_dataset[i]`` for each ``i`` in ``image_indexes``."""
        return [base_dataset[i] for i in image_indexes]

    def __getitem__(self, index):
        sample = self.imgs[index]
        return sample

    def __len__(self):
        return len(self.imgs)

