import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    """
    Compute clustering accuracy with the Hungarian algorithm (best matching).

    Predicted cluster ids are matched one-to-one to true class ids so that the
    number of agreeing samples is maximised; accuracy is that count divided by
    the number of samples.

    Args:
        y_true: 1D array-like of ground-truth labels (cast to int64).
        y_pred: 1D array-like of predicted cluster labels (cast to int64).

    Returns:
        float: accuracy under the best cluster->class mapping, in [0, 1].

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in size or are empty.

    Labels need not be contiguous or non-negative: they are remapped to compact
    indices first, so e.g. a ``-1`` noise label or sparse ids like
    ``{0, 1000}`` are handled correctly without allocating a huge matrix.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same size "
            f"(got {y_true.size} and {y_pred.size})"
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    # Remap arbitrary integer labels to 0..K-1. Using raw labels as indices
    # would silently wrap around for negative values.
    true_ids, true_idx = np.unique(y_true, return_inverse=True)
    pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)

    # Contingency matrix: rows = predicted clusters, columns = true classes.
    w = np.zeros((pred_ids.size, true_ids.size), dtype=np.int64)
    np.add.at(w, (pred_idx, true_idx), 1)

    # linear_sum_assignment accepts rectangular matrices; clusters/classes
    # left unmatched contribute nothing, exactly as zero padding would.
    row_ind, col_ind = linear_sum_assignment(w, maximize=True)
    return float(w[row_ind, col_ind].sum()) / y_pred.size

def purity_score(y_true, y_pred):
    """
    Purity metric.

    Each predicted cluster is credited with the number of samples belonging to
    its most frequent true class; purity is the sum of those counts divided by
    the total number of samples. No one-to-one matching is enforced, so
    several clusters may be credited with the same true class.

    Args:
        y_true: 1D array-like of ground-truth labels (cast to int64).
        y_pred: 1D array-like of predicted cluster labels (cast to int64).

    Returns:
        float: purity in [0, 1].

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in size or are empty.

    Labels need not be contiguous or non-negative; they are remapped to compact
    indices before building the contingency matrix.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    # Without this check a longer y_true would be silently truncated.
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same size "
            f"(got {y_true.size} and {y_pred.size})"
        )
    if y_pred.size == 0:
        raise ValueError("purity_score requires at least one sample")

    # Remap labels to 0..K-1 so negative ids do not wrap around when indexing.
    true_ids, true_idx = np.unique(y_true, return_inverse=True)
    pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)

    # Rows = predicted clusters, columns = true classes.
    contingency_matrix = np.zeros((pred_ids.size, true_ids.size), dtype=np.int64)
    np.add.at(contingency_matrix, (pred_idx, true_idx), 1)

    # Majority-class count per predicted cluster, summed over clusters.
    return float(contingency_matrix.max(axis=1).sum()) / y_pred.size

def get_cluster_labels_and_metrics(fused_embedding, true_labels, n_clusters, device,
                                   *, n_init=20, random_state=None):
    """
    Cluster fused features with K-means and report clustering metrics.

    Args:
        fused_embedding: fused feature matrix, presumably shape [N, d]. Either a
            torch.Tensor (detached and copied to CPU) or anything accepted by
            ``np.asarray``.
        true_labels: ground-truth labels, array-like of length N.
        n_clusters: number of clusters for K-means.
        device: unused; kept so existing callers keep working.
        n_init: number of K-means restarts (default 20, as before).
        random_state: seed forwarded to ``KMeans`` for reproducible runs;
            ``None`` keeps the previous unseeded behaviour.

    Returns:
        tuple: ``(pred_labels, metrics)`` where ``pred_labels`` are the K-means
        cluster assignments and ``metrics`` is a dict with keys ``'ACC'``,
        ``'NMI'``, ``'ARI'`` and ``'PUR'``.

    The four scores are also printed to stdout.
    """
    # Accept torch tensors (detach from the graph, move to host) or plain arrays.
    if hasattr(fused_embedding, "detach"):
        fused_np = fused_embedding.detach().cpu().numpy()
    else:
        fused_np = np.asarray(fused_embedding)
    true_labels = np.asarray(true_labels)

    kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=random_state)
    pred_labels = kmeans.fit_predict(fused_np)

    acc = cluster_acc(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    ari = adjusted_rand_score(true_labels, pred_labels)
    pur = purity_score(true_labels, pred_labels)

    print(f'Clustering ACC: {acc:.4f}')
    print(f'Clustering NMI: {nmi:.4f}')
    print(f'Clustering ARI: {ari:.4f}')
    print(f'Clustering PUR: {pur:.4f}')

    return pred_labels, {'ACC': acc, 'NMI': nmi, 'ARI': ari, 'PUR': pur}
