import time
from tqdm import tqdm
import numpy as np
import torch
from torch.nn import functional as F
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

def one_hot(labels, num_classes):
    """Return a float one-hot encoding of integer class labels.

    Args:
        labels: 1-D integer tensor of class indices.
        num_classes: width of the encoding (number of columns).

    Returns:
        Tensor of shape ``(labels.size(0), num_classes)`` on the same device
        as ``labels`` with a 1 in column ``labels[i]`` of row ``i``.
    """
    encoded = torch.zeros(labels.size(0), num_classes)
    encoded = encoded.to(labels.device)
    column_index = labels.unsqueeze(1)
    # scatter_ writes the value 1 at (row, labels[row]) in place.
    return encoded.scatter_(1, column_index, 1)

class DIST(object):
    """Callable distance between two batches of row vectors.

    Supported ``dist_type`` values are ``'cos'`` and ``'euclidean'``.

    With ``cross=False`` the distance is computed row-by-row (both inputs
    must have the same shape) and a 1-D tensor is returned. With
    ``cross=True`` every row of ``pointA`` is compared with every row of
    ``pointB`` and a ``(NA, NB)`` matrix is returned.
    """

    def __init__(self, dist_type='cos'):
        self.dist_type = dist_type

    def __call__(self, pointA, pointB, cross=False):
        """Dispatch to the metric selected by ``dist_type``.

        Raises:
            ValueError: if ``dist_type`` is not a supported metric (previously
                this silently returned ``None``).
        """
        if self.dist_type == 'cos':
            return self.cos(pointA, pointB, cross)
        if self.dist_type == 'euclidean':
            return self.euclidean(pointA, pointB, cross)
        raise ValueError(
            "unsupported dist_type {!r}; expected 'cos' or 'euclidean'".format(self.dist_type))

    def euclidean(self, pointA, pointB, cross=False):
        """Euclidean distance divided by the feature dimension ``pointA.size(1)``."""
        if not cross:
            return torch.sqrt(torch.sum((pointA - pointB)**2, dim=1)) / pointA.size(1)
        else:
            # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
            AA = torch.sum(pointA*pointA, dim=1).unsqueeze(1)
            BB = torch.sum(pointB*pointB, dim=1).unsqueeze(0)
            AB = torch.matmul(pointA, pointB.transpose(0,1))
            dist = AA+BB-2*AB
            # Round-off can leave tiny or slightly negative values; zero them
            # so the square root below stays real.
            dist[dist<1e-5] = 0
            return torch.sqrt(dist) / pointA.size(1)

    def cos(self, pointA, pointB, cross=False):
        """Cosine distance ``0.5 * (1 - cos_sim)``, which lies in [0, 1]."""
        pointA = F.normalize(pointA, dim=1)
        pointB = F.normalize(pointB, dim=1)
        if not cross:
            return 0.5 * (1.0 - torch.sum(pointA * pointB, dim=1))
        else:
            assert(pointA.size(1) == pointB.size(1))
            return 0.5 * (1.0 - torch.matmul(pointA, pointB.transpose(0, 1)))

def class_centroids(dataset, num_classes):
    """Compute one mean feature vector per class.

    Args:
        dataset: tuple ``(features, outputs, labels)``; ``features`` is a 2-D
            tensor with one row per sample and ``labels`` holds one class
            index per sample. ``outputs`` is not used here.
        num_classes: number of centroids to produce.

    Returns:
        Tensor of shape ``(num_classes, d)``. A class with no samples gets an
        all-ones vector as its centroid.
    """
    features, _outputs, labels = dataset
    num_samples, _dim = features.shape
    rows = []
    for cls in range(num_classes):
        # Boolean row mask selecting the samples that belong to this class.
        member_mask = (labels == cls).view(num_samples)
        members = features[member_mask]
        if members.size(0) < 1:
            rows.append(torch.ones_like(features[0]))
        else:
            rows.append(members.mean(dim=0))
    return torch.stack(rows, dim=0)

def kmeans_clustering(dataset, init_centroids, dist_type='euclidean', realign=True, random_init=False):
    """Cluster features with scikit-learn KMeans and align clusters to classes.

    Args:
        dataset: tuple ``(features, outputs, labels)``; only ``features``
            (a 2-D CPU tensor) is used.
        init_centroids: tensor with one row per class. It fixes the number of
            clusters, seeds KMeans unless ``random_init`` is set, and is the
            reference for the cluster-to-class alignment.
        dist_type: metric passed to ``DIST`` for label assignment, alignment
            and the reported statistics.
        realign: if True, match clusters to the rows of ``init_centroids``
            with the Hungarian algorithm, then reorder the centroids and
            relabel the samples accordingly.
        random_init: if True, use ``k-means++`` initialisation with 10
            restarts instead of ``init_centroids``.

    Returns:
        ``(centroids, pseudo_labels, center_dist, center_change)``:
        the (aligned) centroids, the per-sample labels, the mean distance of
        samples to their nearest centroid, and the mean row-wise distance
        between the returned centroids and ``init_centroids``.
    """
    features, _outputs, _ = dataset
    num_classes = len(init_centroids)
    Dist = DIST(dist_type)
    features_np = features.numpy()
    if random_init:
        kmeans = KMeans(n_clusters=num_classes, init="k-means++", n_init=10,
                        max_iter=300, tol=1e-4, verbose=False).fit(features_np)
    else:
        kmeans = KMeans(n_clusters=num_classes, init=init_centroids.numpy(), n_init=1,
                        max_iter=300, tol=1e-4, verbose=False).fit(features_np)
    # Keep the centroids in the feature dtype so the matmul inside Dist does
    # not fail should KMeans hand back a different float precision.
    centroids = torch.from_numpy(kmeans.cluster_centers_).to(dtype=features.dtype)
    # Labels are re-derived under the requested metric rather than taken from
    # kmeans.labels_, which always reflects Euclidean assignment.
    dists = Dist(features, centroids, cross=True)
    dist2center, pseudo_labels = torch.min(dists, dim=1)
    center_dist = torch.mean(dist2center)
    if realign:
        # cost[i, j] = distance(cluster i, class j); the assignment pairs
        # cluster row_ind[k] with class col_ind[k].
        cost = Dist(centroids, init_centroids, cross=True)
        row_ind, col_ind = linear_sum_assignment(cost.numpy())
        cluster_idx = torch.as_tensor(row_ind, dtype=torch.long)
        class_idx = torch.as_tensor(col_ind, dtype=torch.long)
        # Reorder the centers: the centroid of cluster i must land in the row
        # of its assigned class. (Indexing centroids[col_ind] would apply the
        # inverse permutation and disagree with the relabelled samples.)
        aligned = torch.empty_like(centroids)
        aligned[class_idx] = centroids[cluster_idx]
        centroids = aligned
        # Re-label the data: cluster id -> class id. The cost matrix is
        # square, so row_ind is 0..K-1 and class_idx is indexed by cluster id.
        pseudo_labels = class_idx[pseudo_labels]
    center_change = Dist(centroids, init_centroids).mean()
    return centroids, pseudo_labels, center_dist, center_change

class Clustering(object):
    """Iterative (k-means style) clustering of features around class centers.

    Args:
        eps: stop once the mean distance between consecutive center
            estimates falls below this value.
        max_iter: upper bound on the number of update iterations.
        dist_type: metric name passed to ``DIST``.
        soft_assign: if True, update centers with softmax-weighted soft
            assignments instead of hard one-hot assignments.
        prob_init: if True, initialise the centers as probability-weighted
            feature means computed from the softmax of the model outputs.
        proto_dist: if True, report the center change as a softmax-weighted
            distance between initial and final centers.
    """

    def __init__(self, eps=0.005, max_iter=300, dist_type='cos', soft_assign=False, prob_init=False, proto_dist=False):
        self.eps = eps
        self.max_iter = max_iter
        self.Dist = DIST(dist_type)
        self.soft_assign = soft_assign
        # Temperature applied to the similarities before the softmax.
        self.soft_T = 20
        self.prob_init = prob_init
        self.proto_dist = proto_dist

    def clustering_stop(self, centers):
        """Return True when ``centers`` moved less than ``eps`` on average."""
        if centers is None:
            return False
        dist = self.Dist(centers, self.centers)
        dist = torch.mean(dist, dim=0)
        return dist.item() < self.eps

    def assign_labels(self, feats):
        """Assign every feature to its nearest current center.

        Returns:
            ``(dist2center, labels, labels_onehot)``: distance to the nearest
            center, its index, and the assignment weights (one-hot, or a
            softmax over scaled similarities when ``soft_assign`` is set).
        """
        dists = self.Dist(feats, self.centers, cross=True)
        dist2center, labels = torch.min(dists, dim=1)
        if self.soft_assign:
            # For the cosine distance, 1 - 2*dist is the cosine similarity.
            sim_mat = self.soft_T*(1-2*dists)
            labels_onehot = torch.softmax(sim_mat, dim=1)
        else:
            labels_onehot = one_hot(labels, self.num_classes)
        return dist2center, labels, labels_onehot

    def align_centers(self):
        """Match current centers to the initial ones (Hungarian algorithm).

        Returns:
            NumPy array ``col_ind`` where ``col_ind[i]`` is the index of the
            initial center assigned to current center ``i``.
        """
        cost = self.Dist(self.centers, self.init_centers, cross=True)
        # detach/cpu so this also works for non-CPU or grad-tracking tensors.
        _, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
        return col_ind

    @staticmethod
    def _realign(centers, labels, cluster2label):
        """Apply a cluster -> label mapping to both centers and sample labels.

        The center of cluster ``i`` is moved to row ``cluster2label[i]`` so
        that row ``c`` of the result is the center of the samples labelled
        ``c``. (Indexing ``centers[cluster2label]`` would apply the inverse
        permutation and disagree with the relabelled samples.)
        """
        mapping = torch.as_tensor(cluster2label, dtype=torch.long)
        aligned = torch.empty_like(centers)
        aligned[mapping.to(centers.device)] = centers
        relabelled = mapping.to(labels.device)[labels]
        return aligned, relabelled

    def feature_clustering(self, dataset, init_centers):
        """Cluster ``dataset`` features starting from ``init_centers``.

        Args:
            dataset: tuple ``(feature, output, label)``; ``label`` is unused
                and ``output`` is only read when ``prob_init`` is set.
            init_centers: tensor with one row per class.

        Returns:
            ``(centers, labels, center_cost, center_change)``: the aligned
            centers, per-sample labels, mean distance of samples to their
            nearest center, and a measure of how far the centers moved from
            ``init_centers``.
        """
        # set_init_centers
        feature, output, _label = dataset
        if self.prob_init:
            prob = torch.softmax(output, dim=1)
            self.centers = torch.zeros_like(init_centers)
            for i in range(init_centers.size(0)):
                # Probability-weighted mean of all features for class i.
                self.centers[i] = torch.sum(feature*prob[:, i].unsqueeze(1), dim=0) / torch.sum(prob[:, i])
            self.init_centers = init_centers
        else:
            self.centers = init_centers
            self.init_centers = init_centers
        self.num_classes = self.centers.size(0)
        centers = None

        for _ in tqdm(range(self.max_iter), total=self.max_iter, desc="clustering target features"):
            _, labels, labels_onehot = self.assign_labels(feature)
            # update centers
            centers = torch.zeros_like(self.init_centers)
            count = torch.sum(labels_onehot, dim=0)

            for i in range(self.num_classes):
                if count[i].item() > 0:
                    center = torch.sum(feature*labels_onehot[:, i].unsqueeze(1), dim=0)
                    centers[i] = center / count[i]
                else:
                    # Empty cluster: fall back to the initial center.
                    centers[i] = self.init_centers[i]

            # Compare against the previous centers before replacing them.
            stop = self.clustering_stop(centers)
            self.centers = centers
            if stop:
                break

        dist2center, labels, _ = self.assign_labels(feature)

        # Reorder the centers and re-label the data so that index c refers to
        # the same class as row c of the initial centers.
        cluster2label = self.align_centers()
        self.centers, labels = self._realign(self.centers, labels, cluster2label)

        center_cost = torch.mean(dist2center)

        if self.proto_dist:
            if self.prob_init:
                sim_mat = torch.matmul(self.init_centers, self.centers.T)
            else:
                sim_mat = self.soft_T*(1-2*self.Dist(self.init_centers, self.centers, cross=True))
            # Softmax over the initial centers for each final center, used to
            # weight the pairwise distances.
            real_dist = F.softmax(sim_mat, dim=0)
            cost_mat = self.Dist(self.init_centers, self.centers, cross=True)
            center_change = (cost_mat*real_dist).sum(0).mean()
        else:
            center_change = torch.mean(self.Dist(self.centers, self.init_centers))

        return self.centers, labels, center_cost, center_change
