from sklearn.cluster import KMeans, Birch,  SpectralClustering, AgglomerativeClustering, DBSCAN, OPTICS
import numpy as np
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from matplotlib import pyplot as plt

def kmean_match(client_weights_list, n_concept, cluster_centers=None, random_state=None):
    """Run k-means on clients' flattened weights, optionally aligning labels to old centers.

    Each client's weights are flattened by ``flatten`` (even-indexed arrays only)
    and clustered with ``KMeans(n_clusters=n_concept, max_iter=1000)``.

    Args:
        client_weights_list: One weight list per client; entries must support
            ``.flatten()``.
        n_concept: Number of k-means clusters.
        cluster_centers: Optional centers from a previous clustering. When
            given, each new cluster ``i`` is mapped to the index of the closest
            old center (Euclidean distance). Several new clusters may map to the
            same old index.
        random_state: Seed forwarded to ``KMeans``. When None, a module-level
            ``sim_seed`` is used if one has been injected, else k-means is
            unseeded.

    Returns:
        ``(new_centers, labels)``. When ``cluster_centers`` is given, ``labels``
        are remapped into the old centers' index space.
    """
    flatten_weights = flatten(client_weights_list)
    if random_state is None:
        # The original code read a ``sim_seed`` global that this module never
        # defines; honour it only if a caller has set it on the module.
        random_state = globals().get("sim_seed")
    kmeans = KMeans(n_clusters=n_concept, random_state=random_state, verbose=1, max_iter=1000)
    kmeans.fit(flatten_weights)
    new_centers = kmeans.cluster_centers_
    y = kmeans.labels_
    if cluster_centers is None:
        return new_centers, y

    index_mapping = [0] * n_concept
    for i in range(n_concept):
        min_dis = float('inf')
        p1 = np.asarray(new_centers[i])
        for j, old_center in enumerate(cluster_centers):
            dist = np.linalg.norm(p1 - np.asarray(old_center))
            print('%s, %s, %s' % (i, j, dist))
            # Strict '<' keeps the first closest old center on ties.
            if dist < min_dis:
                min_dis = dist
                index_mapping[i] = j
    print(index_mapping)
    # Indexing a plain list with the numpy label array raises TypeError, so
    # convert the mapping to an array for vectorised relabelling.
    return new_centers, np.asarray(index_mapping)[y]

    
def kmean(client_weights_list, n_concept, cluster_centers=None, random_state=None):
    """Return k-means cluster labels for each client's flattened weights.

    Args:
        client_weights_list: One weight list per client; only even-indexed
            entries are flattened and concatenated (see ``flatten``).
        n_concept: Number of k-means clusters.
        cluster_centers: Unused; accepted for signature compatibility with
            ``kmean_match``.
        random_state: Optional seed forwarded to ``KMeans``. The default None
            matches the previous, unseeded behaviour.

    Returns:
        The ``labels_`` array of the fitted ``KMeans`` model, one label per client.
    """
    flatten_weights = flatten(client_weights_list)
    kmeans = KMeans(n_clusters=n_concept, random_state=random_state, verbose=0, max_iter=1000)
    kmeans.fit(flatten_weights)
    return kmeans.labels_

def agg(client_weights_list, threshold):
    """Cluster clients hierarchically, cutting the tree at a distance threshold.

    The previous body referenced an undefined ``n_concept`` and ignored
    ``threshold``, so it could never run. This implements the evident intent:
    agglomerative clustering whose number of clusters is determined by
    ``threshold`` rather than fixed in advance.

    Args:
        client_weights_list: One weight list per client; only even-indexed
            entries are flattened and concatenated (see ``flatten``).
        threshold: Linkage distance above which clusters are not merged
            (``AgglomerativeClustering(distance_threshold=...)``).

    Returns:
        The ``labels_`` array of the fitted model, one label per client.
    """
    flatten_weights = flatten(client_weights_list)
    # n_clusters must be None when distance_threshold is set.
    model = AgglomerativeClustering(n_clusters=None, distance_threshold=threshold)
    model.fit(flatten_weights)
    return model.labels_

def pca_kmean(client_weights_list, n_concept, cluster_centers=None, n_components=4):
    """Reduce clients' flattened weights with PCA, then cluster them with k-means.

    Args:
        client_weights_list: One weight list per client; only even-indexed
            entries are flattened and concatenated (see ``flatten``).
        n_concept: Number of k-means clusters.
        cluster_centers: Unused; accepted for signature compatibility with
            ``kmean_match``.
        n_components: Requested PCA dimensionality (default 4, as before). It
            is clamped to ``min(n_samples, n_features)``, which PCA requires.

    Returns:
        The ``labels_`` array of the fitted ``KMeans`` model, one label per client.
    """
    flatten_weights = flatten(client_weights_list)
    n_samples = len(flatten_weights)
    n_features = len(flatten_weights[0]) if flatten_weights else 0
    # A hard-coded 4 raised ValueError with fewer than 4 clients or features.
    pca = PCA(n_components=min(n_components, n_samples, n_features))
    X_pca = pca.fit_transform(flatten_weights)
    print(X_pca)
    kmeans = KMeans(n_clusters=n_concept, verbose=0, max_iter=1000)
    kmeans.fit(X_pca)
    return kmeans.labels_

def birch(client_weights_list, n_concept, threshold=0.01):
    """Cluster clients' flattened weights with BIRCH.

    Args:
        client_weights_list: One weight list per client; only even-indexed
            entries are flattened and concatenated (see ``flatten``).
        n_concept: Number of final clusters passed to ``Birch``.
        threshold: ``Birch`` subcluster radius threshold (default 0.01, the
            previously hard-coded value).

    Returns:
        ``model.predict`` labels for the same samples, one per client.
    """
    flatten_weights = flatten(client_weights_list)
    model = Birch(threshold=threshold, n_clusters=n_concept)
    model.fit(flatten_weights)
    return model.predict(flatten_weights)

def spectral(client_weights_list, n_concept):
    """Cluster clients by their flattened weights using spectral clustering.

    For every client, the even-indexed weight arrays are flattened and
    concatenated into one feature vector.

    Args:
        client_weights_list: One weight list per client; entries must support
            ``.flatten()``.
        n_concept: Number of clusters for ``SpectralClustering``.

    Returns:
        The labels from ``SpectralClustering.fit_predict``, one per client.
    """
    features = [
        [value for idx in range(0, len(weights), 2) for value in weights[idx].flatten()]
        for weights in client_weights_list
    ]
    clusterer = SpectralClustering(n_clusters=n_concept)
    return clusterer.fit_predict(features)

def agglomerative(client_weights_list, n_concept):
    """Cluster clients by their flattened weights with agglomerative clustering.

    For every client, the even-indexed weight arrays are flattened and
    concatenated into one feature vector.

    Args:
        client_weights_list: One weight list per client; entries must support
            ``.flatten()``.
        n_concept: Number of clusters for ``AgglomerativeClustering``.

    Returns:
        The fitted model's ``labels_`` array, one label per client.
    """
    features = [
        [value for idx in range(0, len(weights), 2) for value in weights[idx].flatten()]
        for weights in client_weights_list
    ]
    fitted = AgglomerativeClustering(n_clusters=n_concept).fit(features)
    return fitted.labels_


def flatten(client_weights_list):
    """Turn each client's weight list into one flat Python list of values.

    Only entries at even positions (0, 2, 4, ...) of each client's weight list
    are used. Presumably these are the kernels and the skipped odd positions are
    biases; confirm against the caller. Each selected entry is ``.flatten()``-ed,
    and the results are concatenated in order.

    Args:
        client_weights_list: Iterable of per-client sequences whose items
            support ``.flatten()`` (e.g. numpy arrays).

    Returns:
        A list with one flat list of values per client.
    """
    return [
        [value for idx in range(0, len(weights), 2) for value in weights[idx].flatten()]
        for weights in client_weights_list
    ]

def dbscan(client_weights_list, eps, min_samples=3):
    """Cluster clients' flattened weights with DBSCAN.

    Args:
        client_weights_list: One weight list per client (see ``flatten``).
        eps: DBSCAN neighbourhood radius.
        min_samples: Minimum neighbourhood size for a core point (default 3,
            the previously hard-coded value).

    Returns:
        The fitted model's ``labels_``; sklearn marks noise points with -1.
    """
    flatten_weights = flatten(client_weights_list)
    clustering = DBSCAN(eps=eps, min_samples=min_samples)
    clustering.fit(flatten_weights)
    return clustering.labels_

def optics(client_weights_list, min_samples=3):
    """Cluster clients' flattened weights with OPTICS.

    Args:
        client_weights_list: One weight list per client (see ``flatten``).
        min_samples: Minimum neighbourhood size (default 3, matching ``dbscan``
            and the previously hard-coded value).

    Returns:
        The fitted model's ``labels_``; sklearn marks noise points with -1.
    """
    flatten_weights = flatten(client_weights_list)
    clustering = OPTICS(min_samples=min_samples)
    clustering.fit(flatten_weights)
    return clustering.labels_

def run(client_weights_list, n_concept, cluster_algo, dbscan_eps=20):
    """Dispatch to the clustering routine named by ``cluster_algo``.

    Supported names are "kmeans", "agg" (``agglomerative`` with ``n_concept``
    clusters), "spectral", "birch", "dbscan" and "optics". "dbscan" and
    "optics" ignore ``n_concept``.

    Args:
        client_weights_list: One weight list per client.
        n_concept: Number of clusters for the algorithms that take one.
        cluster_algo: Algorithm name (see above).
        dbscan_eps: ``eps`` passed to ``dbscan`` (default 20, as before).

    Returns:
        The per-client labels produced by the selected routine.

    Raises:
        ValueError: If ``cluster_algo`` is not a supported name. Previously the
            function silently returned None in that case.
    """
    print(f"Cluster algorithm: {cluster_algo}")
    # Lambdas defer the lookup of the sibling functions until dispatch.
    dispatch = {
        "kmeans": lambda: kmean(client_weights_list, n_concept),
        "agg": lambda: agglomerative(client_weights_list, n_concept),
        "spectral": lambda: spectral(client_weights_list, n_concept),
        "birch": lambda: birch(client_weights_list, n_concept),
        "dbscan": lambda: dbscan(client_weights_list, dbscan_eps),
        "optics": lambda: optics(client_weights_list),
    }
    try:
        algorithm = dispatch[cluster_algo]
    except KeyError:
        raise ValueError(
            f"Unknown cluster algorithm: {cluster_algo!r}; expected one of {sorted(dispatch)}"
        ) from None
    return algorithm()