import numpy as np
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import json, os
from collections import defaultdict

def get_format_output(names, labels, file_name):
    """Group ``names`` by their cluster label.

    Args:
        names: Item names, aligned position-by-position with ``labels``.
        labels: Cluster label for each name. Any hashable values work.
        file_name: Only the text after its last "_" is used, as the single
            key of the returned dict.

    Returns:
        A one-entry dict ``{suffix: groups}``. ``groups`` is a list of name
        lists, one per distinct label, ordered by each label's first appearance.
    """
    buckets = defaultdict(list)
    for label, name in zip(labels, names):
        buckets[label].append(name)
    suffix = file_name.rsplit("_", 1)[-1]
    return {suffix: [*buckets.values()]}

def cluster_and_visualize(data_maps, file_name, save_path, n_clusters=3, tag=None):
    """Cluster embeddings with K-means and agglomerative clustering; save a plot and JSON groupings.

    Args:
        data_maps: Mapping of name -> embedding. Each value must expose
            ``.numpy()``; presumably a CPU torch tensor (confirm with callers).
            The embeddings must all have the same shape so they can be stacked.
        file_name: Base name inserted into every output file name. Its last
            "_"-separated segment also becomes the key of the JSON output.
        save_path: Output directory. It is created if missing, and JSON files
            go into ``{save_path}/json_files``.
        n_clusters: Number of clusters for both algorithms. It was previously
            hard-coded to 3, which remains the default.
        tag: Suffix inserted into the output file names. Defaults to a
            module-level ``num_clusterd`` if one is defined (this preserves the
            original naming), otherwise to ``n_clusters``.

    Side effects:
        Writes one PNG with a subplot per algorithm (points projected to 2-D
        with PCA), plus one JSON file per algorithm.
    """
    if tag is None:
        # NOTE(review): the original referenced a bare global `num_clusterd`
        # that is not defined in the visible file. Use it when present;
        # otherwise fall back to n_clusters instead of raising NameError.
        tag = globals().get("num_clusterd", n_clusters)

    # Extract names and their embedding vectors in a fixed order.
    names = list(data_maps.keys())
    embeds = np.stack([data_maps[name].numpy() for name in names])

    # Clustering algorithms: file suffix -> (plot title, labels).
    results = {
        "kmeans": ("K-means", KMeans(n_clusters=n_clusters, random_state=42).fit_predict(embeds)),
        "agglo": ("Agglomerative", AgglomerativeClustering(n_clusters=n_clusters).fit_predict(embeds)),
    }

    # The 2-D PCA projection is only for visualisation; clustering above ran
    # in the original embedding space.
    embeds_2d = PCA(n_components=2).fit_transform(embeds)

    # Create output directories *before* saving the figure. The original
    # called savefig first, which fails when save_path does not exist yet.
    json_dir = f"{save_path}/json_files"
    os.makedirs(json_dir, exist_ok=True)

    plt.figure(figsize=(10, 4))
    try:
        for idx, (title, labels) in enumerate(results.values(), start=1):
            plt.subplot(1, len(results), idx)
            plt.scatter(embeds_2d[:, 0], embeds_2d[:, 1], c=labels, cmap='tab10')
            plt.title(title)
            for (x, y), name in zip(embeds_2d, names):
                plt.text(x, y, name, fontsize=8)
        plt.tight_layout()
        plt.savefig(f"{save_path}/dry_{file_name}_embedding_clustering_{tag}.png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    for suffix, (_, labels) in results.items():
        output = get_format_output(names, labels, file_name)
        with open(f"{json_dir}/dry_{file_name}_embedding_clustering_{tag}_{suffix}.json", "w") as f:
            json.dump(output, f, indent=4)

def _best_cluster_for_extra(labels, extra_sims):
    """Return the label whose members have the highest mean value in ``extra_sims``.

    ``labels[i]`` and ``extra_sims[i]`` refer to the same item. The mean for
    each cluster and the chosen cluster are printed.
    """
    cluster_sims = defaultdict(list)
    for label, sim in zip(labels, extra_sims):
        cluster_sims[label].append(sim)

    # Compute the mean similarity for each cluster.
    cluster_mean_sim = {}
    for label, member_sims in cluster_sims.items():
        cluster_mean_sim[label] = np.mean(member_sims)
        print(f"对于簇 {label}，相似度为 {cluster_mean_sim[label]}")

    # Pick the most similar cluster.
    best_cluster = max(cluster_mean_sim, key=cluster_mean_sim.get)
    print(f"将 'nothing' 归入簇 {best_cluster}，平均相似度为 {cluster_mean_sim[best_cluster]}")
    return best_cluster


def cluster_similarity(sims, names, file_name, save_path, extra_sims=None, n_clusters=3):
    """Spectral-cluster items from a precomputed similarity matrix and save the grouping.

    Args:
        sims: Square ``(n, n)`` similarity matrix, passed directly as the
            precomputed affinity for ``SpectralClustering``.
            NOTE(review): precomputed affinities are expected to be
            non-negative; confirm that ``sims`` satisfies this.
        names: Item names. Without ``extra_sims``, they align with the rows of
            ``sims``. With ``extra_sims``, the list must contain "nothing". The
            first "nothing" is removed from a local copy (the caller's list is
            not modified), and the remaining n names align with ``sims``.
        file_name: Base name used in output file names. Its last "_"-separated
            segment becomes the key of the JSON output.
        save_path: Output directory. JSON goes into ``{save_path}/json_files``,
            which is created if missing.
        extra_sims: Optional length-n similarities between "nothing" and each
            item. When given, "nothing" is assigned to the cluster with the
            highest mean similarity and appended to the output. No plot is
            produced in that case.
        n_clusters: Number of spectral clusters. Defaults to 3, the previously
            hard-coded value.

    Raises:
        ValueError: If ``sims`` is not square, or if ``extra_sims`` does not
            have one entry per row of ``sims``.
    """
    use_extra_sims = ""
    if extra_sims is not None:
        use_extra_sims = "Ex"
        # Copy before removing so the caller's list is left intact.
        names = list(names)
        names.remove("nothing")

    n = sims.shape[0]
    if sims.shape != (n, n):
        raise ValueError("输入必须是 n×n 方阵")
    if extra_sims is not None and len(extra_sims) != n:
        # zip() would silently truncate a mismatched extra_sims.
        raise ValueError(f"extra_sims has {len(extra_sims)} entries, expected {n}")

    model = SpectralClustering(
        n_clusters=n_clusters,
        affinity="precomputed",
        random_state=42
    )
    labels = model.fit_predict(sims)

    # Only attach "nothing" when its similarities were supplied. The original
    # did this unconditionally, which crashed on zip(labels, None).
    if extra_sims is not None:
        best_cluster = _best_cluster_for_extra(labels, extra_sims)
        labels = list(labels) + [best_cluster]
        names = names + ["nothing"]

    os.makedirs(f"{save_path}/json_files", exist_ok=True)

    # Plot only the n clustered items (no "nothing" point exists to plot).
    if extra_sims is None:
        # Project the rows of the distance matrix to 2-D for visualisation.
        dist = 1.0 - sims
        np.fill_diagonal(dist, 0.0)
        coords_2d = PCA(n_components=2).fit_transform(dist)

        plt.figure(figsize=(6, 5))
        try:
            plt.scatter(coords_2d[:, 0], coords_2d[:, 1], c=labels, cmap='tab10')
            for (x, y), name in zip(coords_2d, names):
                plt.text(x, y, f"{name}", fontsize=8)
            plt.title("Spectral Clustering (Similarity Matrix)")
            plt.tight_layout()
            plt.savefig(f"{save_path}/dry_{file_name}_similarity_clustering_{use_extra_sims}.png")
        finally:
            plt.close()

    output = get_format_output(names, labels, file_name)
    with open(f"{save_path}/json_files/dry_{file_name}_similarity_clustering_{use_extra_sims}.json", "w") as f:
        json.dump(output, f, indent=4)