import numpy as np
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from scipy.optimize import linear_sum_assignment
import torch
from tqdm import trange

def _spectral_cluster_per_class(K, embeddings, class_labels):
    """Cluster the embeddings of each class id 0..9 into K groups.

    For every class id the matching rows of ``embeddings`` are projected to
    two dimensions with a PCA fitted on that class alone, then partitioned
    with spectral clustering (nearest-neighbours affinity, fixed seed 42).

    Returns a list with one dict per class id, holding the 2-D projection
    under ``'embeddings'`` and the cluster ids under ``'labels'``.
    """
    per_class = []
    for class_id in trange(10):
        class_embeddings = embeddings[class_labels == class_id]
        reduced = PCA(n_components=2).fit_transform(class_embeddings)
        clusterer = SpectralClustering(
            n_clusters=K, affinity='nearest_neighbors', random_state=42
        )
        per_class.append(
            {'embeddings': reduced, 'labels': clusterer.fit_predict(reduced)}
        )
    return per_class


def Clustering(K, private_embeddings, private_labels, syn_embeddings, syn_labels):
    """Cluster private and synthetic embeddings per class and align the ids.

    Each of the 10 classes (label values 0..9, hard-coded) is clustered
    into ``K`` groups, first for the private data and then for the synthetic
    data. The synthetic cluster ids are then renumbered by ``match_clusters``
    so that they correspond to the private ones.

    Args:
        K: number of clusters to form inside every class.
        private_embeddings: private embeddings, indexable by a boolean mask
            built from ``private_labels``.
        private_labels: class label of every private embedding.
        syn_embeddings: synthetic embeddings, indexed the same way.
        syn_labels: class label of every synthetic embedding.

    Returns:
        The pair returned by ``match_clusters``: per-class cluster labels for
        the private data and the aligned per-class cluster labels for the
        synthetic data.
    """
    print("Start clustering...")

    # NOTE(review): PCA is fitted separately for the private and the
    # synthetic data, so the two 2-D projections do not share axes even
    # though their centroids are compared afterwards -- confirm intended.
    cluster_labels = _spectral_cluster_per_class(K, private_embeddings, private_labels)
    syn_cluster_labels = _spectral_cluster_per_class(K, syn_embeddings, syn_labels)

    return match_clusters(K, cluster_labels, syn_cluster_labels)

def _cluster_centroids(embeddings, labels, K):
    """Return the mean embedding of each cluster 0..K-1 and an empty-mask.

    Empty clusters get an all-zero placeholder centroid (instead of the NaN
    that the mean of an empty slice would produce) and are flagged ``True``
    in the returned boolean mask.
    """
    centroids = []
    is_empty = np.zeros(K, dtype=bool)
    for k in range(K):
        members = embeddings[labels == k]
        if len(members) == 0:
            is_empty[k] = True
            centroids.append(np.zeros(np.shape(embeddings)[1:]))
        else:
            centroids.append(members.mean(axis=0))
    return np.array(centroids), is_empty


def match_clusters(K, cluster_labels, syn_cluster_labels):
    """Renumber synthetic cluster ids so they line up with the private ones.

    For every class, the centroids of the K private clusters and the K
    synthetic clusters are computed, and a one-to-one assignment minimising
    the total Euclidean centroid distance is found with the Hungarian
    algorithm. Each synthetic label is then replaced by the id of the private
    cluster it was assigned to.

    The number of classes is taken from the inputs rather than fixed, so any
    number of per-class entries is accepted. Entries are paired positionally;
    if the lists differ in length only the common prefix is matched.

    Args:
        K: number of clusters per class; labels must lie in ``range(K)``.
        cluster_labels: one dict per class for the private data with keys
            ``'embeddings'`` (2-D array) and ``'labels'`` (cluster id per row).
        syn_cluster_labels: same structure for the synthetic data. The
            ``'labels'`` entry of every matched dict is overwritten in place
            with the renumbered labels.

    Returns:
        ``(cluster_labels_final, syn_cluster_labels_final)``: two lists of
        ``torch.LongTensor``, one tensor of cluster ids per class.
    """
    for private_data, syn_data in zip(cluster_labels, syn_cluster_labels):
        private_centroids, private_empty = _cluster_centroids(
            private_data['embeddings'], private_data['labels'], K
        )
        syn_centroids, syn_empty = _cluster_centroids(
            syn_data['embeddings'], syn_data['labels'], K
        )

        # Pairwise distances: rows are private clusters, columns synthetic.
        cost_matrix = np.linalg.norm(
            private_centroids[:, np.newaxis] - syn_centroids[np.newaxis, :], axis=2
        )

        # An empty cluster has no meaningful centroid. Give every pair that
        # involves one a cost larger than any sum of real distances, so the
        # solver first matches the populated clusters optimally and only
        # then pairs off the empty ones.
        involves_empty = private_empty[:, np.newaxis] | syn_empty[np.newaxis, :]
        if involves_empty.any():
            real_costs = cost_matrix[~involves_empty]
            largest = real_costs.max() if real_costs.size else 0.0
            cost_matrix = np.where(involves_empty, largest * K + 1.0, cost_matrix)

        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        # Synthetic cluster id (column) -> private cluster id (row).
        label_mapping = dict(zip(col_ind, row_ind))
        syn_data['labels'] = np.array(
            [label_mapping[label] for label in syn_data['labels']], dtype=np.int64
        )

    cluster_labels_final = [torch.LongTensor(data['labels']) for data in cluster_labels]
    syn_cluster_labels_final = [torch.LongTensor(data['labels']) for data in syn_cluster_labels]

    return cluster_labels_final, syn_cluster_labels_final
