import numpy as np
from sklearn.cluster import KMeans, AgglomerativeClustering, SpectralClustering
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import json, os
import torch
from collections import defaultdict

def get_format_output(names, labels, file_name):
    """Group names by their cluster label for JSON export.

    Args:
        names: Iterable of item names.
        labels: Iterable of cluster labels, paired positionally with
            ``names``; each label is converted with ``int()``.
        file_name: Key under which the grouped names are stored.

    Returns:
        A one-entry dict ``{file_name: [[names of one cluster], ...]}``.
        Clusters appear in the order their label is first encountered, and
        names keep their input order inside each cluster.
    """
    grouped = {}
    for raw_label, name in zip(labels, names):
        grouped.setdefault(int(raw_label), []).append(name)
    return {file_name: list(grouped.values())}

def cluster_similarity(sims, names, file_name, save_path=None, extra_sims=None, extra_name = "nothing"):
    """Spectral-cluster a similarity matrix into 3 groups and score compactness.

    The similarity matrix is used directly as the precomputed affinity for
    ``SpectralClustering``. Distances are derived as ``1 - sims`` with a zero
    diagonal. When ``extra_sims`` is given, one additional point named
    ``extra_name`` is attached to the cluster whose members have the highest
    mean similarity to it.

    Args:
        sims: Square ``(n, n)`` array of pairwise similarities.
        names: Item names. Presumably aligned with the rows of ``sims`` once
            ``extra_name`` has been removed -- TODO confirm against callers.
            The caller's list is not modified.
        file_name: Key used in the JSON payload and in the output file name.
        save_path: If truthy, the cluster assignment is written to
            ``{save_path}/json_files/``.
        extra_sims: Optional similarities between the extra point and each of
            the ``n`` clustered points.
        extra_name: Name of the extra point.

    Returns:
        ``sum_i(d_i ** 2 / sum_j d_j)`` where ``d_i`` is the mean distance from
        the members of cluster ``i`` (medoid included) to that cluster's
        medoid. Returns ``0.0`` when the distances sum to zero.
    """
    has_extra = extra_sims is not None
    use_extra_sims = "Ex" if has_extra else ""

    # Work on a copy so the caller's list survives repeated calls untouched.
    names = list(names)
    if has_extra and extra_name in names:
        names.remove(extra_name)

    n = sims.shape[0]
    assert sims.shape == (n, n), "输入必须是 n×n 方阵"

    # `1.0 - sims` allocates a new array, so zeroing its diagonal leaves
    # `sims` intact for the clustering call below.
    dist = 1.0 - sims
    np.fill_diagonal(dist, 0.0)

    model = SpectralClustering(
        n_clusters=3,
        affinity="precomputed",
        random_state=42
    )
    labels = list(model.fit_predict(sims))

    if has_extra:
        # Average the extra point's similarity per cluster and attach it to
        # the most similar cluster.
        cluster_sims = defaultdict(list)
        for label, sim in zip(labels, extra_sims):
            cluster_sims[label].append(sim)
        best_cluster = max(cluster_sims, key=lambda lbl: np.mean(cluster_sims[lbl]))

        labels.append(best_cluster)
        names.append(extra_name)

        # Extend the distance matrix with one row/column for the extra point.
        d_extra = 1.0 - extra_sims
        dist_full = np.zeros((n + 1, n + 1))
        dist_full[:n, :n] = dist
        dist_full[:n, n] = d_extra
        dist_full[n, :n] = d_extra
    else:
        dist_full = dist

    # Mean distance to the medoid, per cluster.
    cluster_avg_distances = []
    for label in sorted(set(labels)):
        indices = [i for i, lbl in enumerate(labels) if lbl == label]
        sub_dist = dist_full[np.ix_(indices, indices)]

        # The medoid is the member with the smallest total distance to the
        # other members of its cluster.
        medoid_idx = np.argmin(np.sum(sub_dist, axis=1))
        cluster_avg_distances.append(np.mean(sub_dist[medoid_idx, :]))

    # Weight each cluster's distance by its share of the total:
    #   weight_i = d_i / sum(d);  result = sum(weight_i * d_i)
    sum_distance = np.sum(cluster_avg_distances)
    if sum_distance != 0:
        total_weighted_distance = np.sum(
            [d * d / sum_distance for d in cluster_avg_distances]
        )
    else:
        # Every cluster has zero spread; avoid 0/0 producing NaN.
        total_weighted_distance = 0.0

    if save_path:
        os.makedirs(f"{save_path}/json_files", exist_ok=True)

        output = get_format_output(names, labels, file_name)
        with open(f"{save_path}/json_files/cluster_{file_name}_similarity_clustering_{use_extra_sims}.json", "w") as f:
            json.dump(output, f, indent=4)

    return total_weighted_distance


def cluster_similarity_virtual_point(
        sims, 
        names, 
        file_name, 
        data_maps, 
        save_path=None, 
        extra_sims=None, 
        extra_name = "nothing", 
        return_raw_distance_list=False
    ):
    """Cluster by similarity, then score clusters by distance to their centroid.

    Cluster labels come from ``SpectralClustering`` on the similarity matrix
    and are computed without gradients. The per-cluster loss is computed from
    the tensors in ``data_maps`` using torch operations only.

    Args:
        sims: Torch tensor of pairwise similarities. Presumably ``(n, n)`` and
            aligned with ``names`` once ``extra_name`` is removed -- TODO
            confirm against callers.
        names: Item names. The caller's list is not modified.
        file_name: Used in the plot title, the JSON key and output file names.
        data_maps: Mapping from name to embedding tensor. Embeddings are
            assumed to be ``[1, dim]`` or ``[dim]`` -- TODO confirm.
        save_path: If not ``None``, a 2-D PCA scatter plot and a JSON file of
            the cluster assignment are written into this directory.
        extra_sims: Optional tensor of similarities between ``extra_name`` and
            each clustered point. Only used when ``extra_name`` is in
            ``names``.
        extra_name: Name of the point assigned to a cluster after the fact.
        return_raw_distance_list: If true, return the list of per-cluster
            mean distances instead of the weighted scalar.

    Returns:
        Either the list of per-cluster mean centroid distances, or the scalar
        ``sum_i(d_i ** 2 / sum_j d_j)`` as a tensor.
    """
    # Work on a copy: removing from the caller's list would make the next
    # call with the same list silently skip the extra point.
    names = list(names)

    use_extra_sims = ""
    if (extra_sims is not None) and (extra_name in names):
        use_extra_sims = "Ex"
        names.remove(extra_name)

    with torch.no_grad():
        # With affinity="precomputed", larger entries must mean "more
        # similar", so the similarity matrix is passed, not 1 - sims.
        # astype() copies, so the tensor's storage is never touched.
        affinity = sims.detach().float().cpu().numpy().astype(np.float64)

        model = SpectralClustering(
            n_clusters=3,
            affinity="precomputed",
            random_state=42)
        labels = model.fit_predict(affinity)
        labels = torch.as_tensor(labels, dtype=torch.long)

    if use_extra_sims == "Ex":
        # Attach the extra point to the cluster with the highest mean
        # similarity to it.
        cluster_sims = defaultdict(list)
        for lbl, s in zip(labels.cpu().tolist(), extra_sims.detach()):
            cluster_sims[lbl].append(s.item())
        best_cluster = max(cluster_sims,
                           key=lambda k: np.mean(cluster_sims[k]))
        labels = torch.cat([labels,
                            torch.tensor([best_cluster],
                                         dtype=torch.long)])
        names = names + [extra_name]

    # Mean distance to the centroid, per cluster.
    unique_labels = labels.unique()
    total_loss = []

    for label in unique_labels:
        member_idx = (labels == label).nonzero().squeeze(1).tolist()

        # Stack this cluster's embeddings; squeeze(1) drops a singleton
        # axis if the stored embeddings carry one.
        cluster_embs = torch.stack([data_maps[names[i]] for i in member_idx]).squeeze(1)

        centroid = cluster_embs.mean(dim=0, keepdim=True)
        distances = torch.norm(cluster_embs - centroid, dim=1)

        # A larger mean distance means a less compact cluster.
        total_loss.append(distances.mean())

    loss_tensor = torch.stack(total_loss)
    sum_distance = loss_tensor.sum()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        embeds = torch.stack([data_maps[name] for name in names]).squeeze(1).detach().float().cpu().numpy()
        pca = PCA(n_components=2)
        embeds_2d = pca.fit_transform(embeds)
        plt.figure(figsize=(8, 6))
        for label in unique_labels:
            idx = (labels == label).detach().bool().cpu().numpy()
            plt.scatter(embeds_2d[idx, 0], embeds_2d[idx, 1], label=f'Cluster {label.item()}')
        for i, name in enumerate(names):
            plt.text(embeds_2d[i, 0], embeds_2d[i, 1], name, fontsize=8)
        plt.title(f'Clustering Visualization: {file_name}')
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"sims_{file_name}_Spectral.png"))
        plt.close()

        json_data = get_format_output(names, labels, file_name)
        with open(os.path.join(save_path, f"sims_{file_name}_Spectral.json"), "w") as f:
            json.dump(json_data, f, indent=4)

    if return_raw_distance_list:
        return total_loss

    # weight_i = d_i / sum(d);  result = sum(weight_i * d_i).
    # Clamping the denominator to the smallest positive normal value turns
    # the all-zero case into 0 instead of NaN and changes nothing otherwise.
    denominator = sum_distance.clamp_min(torch.finfo(loss_tensor.dtype).tiny)
    return ((loss_tensor ** 2) / denominator).sum()

def cluster_with_distance(
        data_maps, 
        file_name=None, 
        extra_name="None",
        save_path=None,
        algorithm="kmeans",
        return_raw_distance_list=False
    ):
    """Cluster embeddings directly and score clusters by centroid distance.

    Args:
        data_maps: Mapping from name to embedding tensor. Embeddings are
            assumed to be ``[1, dim]`` or ``[dim]`` -- TODO confirm.
        file_name: Used in the plot title, the JSON key and output file names.
        extra_name: If this is not the string ``"None"`` and is a key of
            ``data_maps``, one additional cluster is requested (4, not 3).
        save_path: If not ``None``, a 2-D PCA scatter plot and a JSON file of
            the cluster assignment are written into this directory.
        algorithm: ``"kmeans"`` or ``"hierarchical"``. Any other value raises
            ``KeyError``.
        return_raw_distance_list: If true, return the list of per-cluster
            mean distances instead of the weighted scalar.

    Returns:
        Either the list of per-cluster mean centroid distances, or the scalar
        ``sum_i(d_i ** 2 / sum_j d_j)`` as a tensor.
    """
    # Collect names and a detached NumPy copy of the embeddings for sklearn.
    names = list(data_maps.keys())
    embeds = torch.stack([data_maps[name] for name in names]).squeeze(1).detach().float().cpu().numpy()

    algorithm_option = {
        'kmeans': KMeans,
        'hierarchical': AgglomerativeClustering
    }

    with torch.no_grad():
        # Request one extra cluster when the extra point is in the data.
        flag = int(extra_name != "None" and extra_name in names)
        clustering = algorithm_option[algorithm](n_clusters=3+flag)
        labels = clustering.fit_predict(embeds)
        labels = torch.tensor(labels)

    unique_labels = labels.unique()
    total_loss = []

    # Mean distance to the centroid, per cluster.
    for label in unique_labels:
        member_idx = (labels == label).nonzero().squeeze(1).tolist()

        # squeeze(1) mirrors how `embeds` is built above. Without it, a
        # [k, 1, dim] stack would make the norm below reduce over the
        # singleton axis and return |x - c| per coordinate, not a distance.
        cluster_embs = torch.stack([data_maps[names[i]] for i in member_idx]).squeeze(1)

        centroid = cluster_embs.mean(dim=0, keepdim=True)
        distances = torch.norm(cluster_embs - centroid, dim=1)
        total_loss.append(distances.mean())

    loss_tensor = torch.stack(total_loss)
    sum_distance = loss_tensor.sum()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        pca = PCA(n_components=2)
        embeds_2d = pca.fit_transform(embeds)
        plt.figure(figsize=(8, 6))
        for label in unique_labels:
            idx = (labels == label).detach().bool().cpu().numpy()
            plt.scatter(embeds_2d[idx, 0], embeds_2d[idx, 1], label=f'Cluster {label.item()}')
        for i, name in enumerate(names):
            plt.text(embeds_2d[i, 0], embeds_2d[i, 1], name, fontsize=8)
        plt.title(f'Clustering Visualization: {file_name}')
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"dis_{file_name}_{algorithm}.png"))
        plt.close()

        json_data = get_format_output(names, labels, file_name)
        with open(os.path.join(save_path, f"dis_{file_name}_{algorithm}.json"), "w") as f:
            json.dump(json_data, f, indent=4, ensure_ascii=False)

    if return_raw_distance_list:
        return total_loss

    # weight_i = d_i / sum(d);  result = sum(weight_i * d_i).
    # Clamping the denominator to the smallest positive normal value turns
    # the all-zero case into 0 instead of NaN and changes nothing otherwise.
    denominator = sum_distance.clamp_min(torch.finfo(loss_tensor.dtype).tiny)
    return ((loss_tensor ** 2) / denominator).sum()