import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.svm import OneClassSVM
from utils.finch import FINCH
import torch
from sklearn.metrics import pairwise_distances_argmin_min

def k_means_clustering(feas, clusters_num):
    """Pick representative samples with K-Means and average them.

    Each sample in ``feas`` is flattened to a vector, K-Means is fitted on
    the flattened data, and for every centroid the *actual* sample closest
    to it is selected (so the representatives are real data points, not
    synthetic centroids).

    Args:
        feas: Array-like whose first axis indexes samples; the remaining
            axes (e.g. ``(16, 768)``) are the per-sample shape.
        clusters_num: Number of K-Means clusters, and therefore the number
            of representative samples returned.

    Returns:
        A tuple ``(average_center, centers_reshaped)`` where
        ``centers_reshaped`` has shape ``(clusters_num, *feas.shape[1:])``
        and holds the sample nearest to each centroid, and
        ``average_center`` is the element-wise mean of those samples.
    """
    # Flatten every sample to one row: (N, d1, d2, ...) -> (N, d1*d2*...).
    flat = feas.reshape(len(feas), -1)

    # Honour the requested cluster count instead of a hard-coded 5.
    kmeans = KMeans(n_clusters=clusters_num)
    kmeans.fit(flat)

    # Index of the real sample closest to each centroid.
    nearest_idx, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, flat)

    # Restore the original per-sample shape rather than assuming (16, 768).
    sample_shape = tuple(feas.shape[1:])
    centers_reshaped = np.stack(
        [flat[idx].reshape(sample_shape) for idx in nearest_idx]
    )
    average_center = np.mean(centers_reshaped, axis=0)

    return average_center, centers_reshaped

def SVDD_clustering(feas, clusters_num):
    """Estimate per-cluster centers from One-Class SVM support vectors.

    The data is first partitioned with K-Means (fixed ``random_state=42``).
    An RBF One-Class SVM is then fitted on each partition, and the mean of
    that model's support vectors is taken as the cluster's center.

    Args:
        feas: 2-D NumPy array of shape ``(n_samples, n_features)``; it is
            indexed with a boolean mask, so it must support NumPy-style
            boolean indexing.
        clusters_num: Number of K-Means clusters / One-Class SVM models.

    Returns:
        ``np.ndarray`` with one row per cluster, each row being the mean of
        the support vectors of that cluster's One-Class SVM.
    """
    # Step 1: initial hard assignment of every sample to a cluster.
    kmeans = KMeans(n_clusters=clusters_num, random_state=42)
    cluster_labels = kmeans.fit_predict(feas)

    # Step 2: fit one One-Class SVM per cluster and summarise it by the
    # mean of its support vectors.
    # NOTE: an earlier version also re-assigned every sample to the model
    # with the largest decision value; that result was never used or
    # returned, so the per-sample pass has been dropped.
    centers = []
    for cluster_id in range(clusters_num):
        model = OneClassSVM(kernel='rbf', gamma='auto')
        model.fit(feas[cluster_labels == cluster_id])
        centers.append(np.mean(model.support_vectors_, axis=0))
    return np.array(centers)

import numpy as np
from sklearn.cluster import DBSCAN

def DBSCAN_clustering(feas, eps=0.5, min_samples=5):
    """Cluster ``feas`` with DBSCAN and describe every non-noise cluster.

    Args:
        feas: Feature matrix passed to ``DBSCAN.fit_predict`` and indexed
            with a boolean mask (assumed to be a 2-D NumPy array, one row
            per sample -- TODO confirm with callers).
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN ``min_samples`` parameter.

    Returns:
        A tuple ``(cluster_centers, top_3_farthest, labels)``:
        ``cluster_centers`` maps each cluster label to the mean of its
        members; ``top_3_farthest`` maps each label to the (up to) three
        members farthest from that mean, in order of increasing distance;
        ``labels`` is the raw per-sample label array, where ``-1`` marks
        noise.
    """
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(feas)

    cluster_centers = {}
    top_3_farthest = {}
    for cluster_id in set(labels):
        # DBSCAN labels noise as -1; noise does not form a cluster.
        if cluster_id == -1:
            continue

        members = feas[labels == cluster_id]
        centroid = np.mean(members, axis=0)
        cluster_centers[cluster_id] = centroid

        # argsort is ascending, so the last three indices are the farthest.
        dist_to_centroid = np.linalg.norm(members - centroid, axis=1)
        top_3_farthest[cluster_id] = members[np.argsort(dist_to_centroid)[-3:]]

    return cluster_centers, top_3_farthest, labels

def FINCH_clustering(feas):
    """Cluster ``feas`` with FINCH and build mean / boundary features.

    Args:
        feas: Feature tensor, one row per sample (assumed to be a 2-D
            ``torch.Tensor`` accepted by the project's ``FINCH`` helper --
            TODO confirm).

    Returns:
        A tuple ``(local_mean_feature, local_features)``:
        ``local_mean_feature`` is the mean over samples of each sample's
        cluster center; ``local_features`` is a list holding, for every
        cluster and each of its (up to) three members farthest from the
        center, the midpoint ``0.5*center + 0.5*member``.
    """
    partitions, _, _ = FINCH(feas, initial_rank=None, req_clust=None, distance='cosine',
                             ensure_early_exit=False, verbose=False)
    # Use the last column of the partition matrix as the cluster labels
    # (presumably the coarsest partition -- verify against utils.finch).
    labels = partitions[:, -1]

    cluster_centers = {}
    local_features = []
    for cluster_id in set(labels):
        members = feas[labels == cluster_id]
        centroid = torch.mean(members, dim=0)
        cluster_centers[cluster_id] = centroid

        # argsort is ascending, so the last three entries are the members
        # farthest from the centroid.
        spread = torch.norm(members - centroid, dim=1)
        outliers = members[torch.argsort(spread)[-3:]]

        # Blend each far-away member half-way towards its cluster center.
        local_features.extend(0.5*centroid + 0.5*outlier for outlier in outliers)

    # NOTE(review): this iterates over the per-sample labels, not the
    # unique ones, so every center is weighted by its cluster size.
    # Confirm that a size-weighted mean (rather than a plain mean of the
    # centers) is what callers expect.
    per_sample_centers = torch.stack([cluster_centers[label] for label in labels])
    local_mean_feature = torch.mean(per_sample_centers, dim=0)

    return local_mean_feature, local_features