import numpy as np
import torch
from sklearn import cluster
from sklearn.preprocessing import normalize
from sklearn.utils import check_random_state, check_array, check_symmetric
from scipy.linalg import orth
import scipy.sparse as sparse
from scipy.stats import entropy
from munkres import Munkres
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.cluster import _supervised as supervised

import numpy as np

def clustering_accuracy(labels_true, labels_pred):
    """Return the best-match clustering accuracy.

    Predicted clusters are matched one-to-one to ground-truth classes so that
    the number of agreeing samples is maximal (optimal linear assignment on
    the contingency table). The matched count is divided by the number of
    samples.
    """
    y_true, y_pred = supervised.check_clusterings(labels_true, labels_pred)
    table = supervised.contingency_matrix(y_true, y_pred)
    # linear_sum_assignment minimises cost, so negate the counts to maximise matches.
    row_ind, col_ind = linear_sum_assignment(-table)
    n_matched = table[row_ind, col_ind].sum()
    return n_matched / len(y_true)
  
 
def self_representation_loss(labels_true, representation_matrix):
    """Average fraction of each sample's representation mass placed on other-class samples.

    For row ``i`` of the representation matrix ``C`` the per-sample loss is
    ``sum_{j: y_j != y_i} |C_ij| / sum_j |C_ij|``; the mean over rows is returned.

    Args:
        labels_true: length-n array-like of ground-truth labels.
        representation_matrix: (n, n) array-like; row ``i`` holds the
            coefficients used to represent sample ``i``.

    Returns:
        numpy float in [0, 1]. Rows whose coefficients are all zero put no
        mass on other classes and contribute 0 (rather than a 0/0 NaN that
        would poison the mean). Returns 0.0 when there are no samples.
    """
    labels = np.asarray(labels_true)
    n_samples = labels.shape[0]
    if n_samples == 0:
        return 0.0
    coeffs = np.abs(np.asarray(representation_matrix))[:n_samples]
    # other_class[i, j] is True when sample j's label differs from sample i's.
    other_class = labels[:, None] != labels[None, :]
    off_class_mass = np.where(other_class, coeffs, 0.0).sum(axis=1)
    total_mass = coeffs.sum(axis=1)
    per_row = np.divide(off_class_mass, total_mass,
                        out=np.zeros(n_samples), where=total_mass != 0)
    return per_row.mean()


def regularizer_pnorm(c, p):
    """Return ``sum |c_ij| ** p`` as a 0-d tensor (p-th power of the entrywise p-norm)."""
    magnitudes = c.abs()
    return magnitudes.pow(p).sum()


def sklearn_predict(A, n_clusters):
    """Cluster with scikit-learn's SpectralClustering on the precomputed affinity ``A``.

    Returns the label array produced by ``fit_predict``.
    """
    model = cluster.SpectralClustering(affinity='precomputed',
                                       n_clusters=n_clusters)
    return model.fit_predict(A)


def accuracy(pred, labels):
    """Return clustering accuracy, defined as ``1 - err_rate(labels, pred)``."""
    return 1 - err_rate(labels, pred)


def subspace_preserving_error(A, labels, n_clusters):
    """Percentage of each row's L1 mass that falls outside the row's own cluster, averaged over rows.

    Args:
        A: (n, n) floating-point tensor; row ``i`` holds the coefficients of sample ``i``.
        labels: length-n sequence, array or tensor of cluster labels.
        n_clusters: number of clusters. Kept for interface compatibility; the
            same-cluster mask is built by comparing labels directly, which
            equals ``one_hot @ one_hot.T`` for labels in ``[0, n_clusters)``.

    Returns:
        0-d tensor: ``100 * mean_i(1 - ||A_i * same_cluster_i||_1 / ||A_i||_1)``.
    """
    # Build the mask on A's device/dtype so CUDA inputs don't hit a device mismatch.
    labels = torch.as_tensor(labels, device=A.device)
    mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).to(A.dtype)
    l1_norm = torch.norm(A, p=1, dim=1)
    masked_l1_norm = torch.norm(mask * A, p=1, dim=1)
    e = torch.mean((1. - masked_l1_norm / l1_norm)) * 100.
    return e


def normalized_laplacian(A):
    """Return the symmetric normalized Laplacian ``I - D^{-1/2} A D^{-1/2}`` of ``A``.

    ``D`` is the diagonal matrix of row sums of ``A``. Nodes with zero degree
    (isolated nodes) get an all-zero row and column instead of the NaNs that
    ``1 / sqrt(0)`` would produce, so each such node contributes a zero
    eigenvalue. All intermediate tensors are created on ``A``'s device.
    """
    deg = torch.sum(A, dim=1)
    isolated = deg == 0
    # Substitute 1 before rsqrt so no inf is produced even in the discarded
    # branch of torch.where (inf there would still poison autograd).
    safe_deg = torch.where(isolated, torch.ones_like(deg), deg)
    inv_sqrt = torch.where(isolated, torch.zeros_like(deg), torch.rsqrt(safe_deg))
    identity = torch.diag((~isolated).to(A.dtype))
    # Broadcasting is equivalent to diag(inv_sqrt) @ A @ diag(inv_sqrt).
    return identity - inv_sqrt.unsqueeze(1) * A * inv_sqrt.unsqueeze(0)


def connectivity(A, labels, n_clusters):
    """Return the smallest per-cluster algebraic connectivity.

    For each cluster ``i`` in ``range(n_clusters)``, the sub-matrix of ``A``
    restricted to samples with ``labels == i`` is turned into a normalized
    Laplacian, and its second-smallest eigenvalue is taken. The minimum over
    clusters is returned as a NumPy float.

    Args:
        A: (n, n) affinity tensor.
        labels: length-n tensor/array of cluster labels (must support ``== i``).
        n_clusters: number of clusters to scan.

    Note:
        Every scanned cluster needs at least two members; otherwise
        ``eig_vals[1]`` raises IndexError.
    """
    c = []
    for i in range(n_clusters):
        members = labels == i
        A_i = A[members][:, members]
        L_i = normalized_laplacian(A_i)
        # torch.symeig (deprecated since 1.9, later removed) read the upper
        # triangle by default; eigvalsh with UPLO='U' mirrors that and also
        # returns eigenvalues in ascending order.
        eig_vals = torch.linalg.eigvalsh(L_i, UPLO='U')
        # .item() so GPU / autograd tensors can be reduced with NumPy.
        c.append(eig_vals[1].item())
    return np.min(c)


def topK(A, k, sym=True):
    """Keep the ``k`` largest entries in every row of ``A`` and zero the rest.

    Args:
        A: 2-D tensor.
        k: number of entries kept per row.
        sym: if True, return the symmetrised matrix ``(C + C^T) / 2``.

    Returns:
        Tensor with the same shape as ``A``.
    """
    top_vals, top_idx = A.topk(k, dim=1)
    sparse_coef = torch.zeros_like(A)
    sparse_coef.scatter_(1, top_idx, top_vals)
    if not sym:
        return sparse_coef
    return (sparse_coef + sparse_coef.t()) / 2.0


def best_map(L1, L2):
    """
    Relabel the clustering result to minimise the error rate via optimal linear assignment.
    Adapted from https://github.com/panji1990/Deep-subspace-clustering-networks

    Args:
        L1 (array-like): ground truth label (1-D, numeric).
        L2 (array-like): clustering result, same length as L1.
    Return:
        (np.ndarray): float array shaped like L2 in which every predicted
        cluster is replaced by the ground-truth label it is matched to. When L2
        has more clusters than L1, the unmatched predicted clusters receive
        ``max(L1) + 1``, a value equal to no ground-truth label, so they are
        counted as errors (the previous square-padding version raised
        IndexError in that case).
    """
    L1 = np.asarray(L1)
    L2 = np.asarray(L2)
    Label1, idx1 = np.unique(L1, return_inverse=True)
    Label2, idx2 = np.unique(L2, return_inverse=True)
    idx1 = idx1.ravel()
    idx2 = idx2.ravel()

    # G[j, i] = number of samples in predicted cluster Label2[j] whose true label is Label1[i].
    G = np.zeros((len(Label2), len(Label1)))
    np.add.at(G, (idx2, idx1), 1)

    # Rectangular cost matrices are supported: each predicted cluster is
    # matched to at most one true label, maximising the matched sample count.
    rows, cols = linear_sum_assignment(-G)

    unmatched_label = Label1.max() + 1 if Label1.size else 0
    mapping = np.full(len(Label2), unmatched_label, dtype=float)
    mapping[rows] = Label1[cols]
    newL2 = mapping[idx2].reshape(L2.shape)
    return newL2


def err_rate(gt_s, s):
    """
    Get error rate of the cluster result.
    Adapted from https://github.com/panji1990/Deep-subspace-clustering-networks

    Args:
        gt_s (array-like): ground truth label.
        s (array-like): clustering result.
    Return:
        (float): fraction of samples whose best-mapped predicted label differs
        from the ground truth.
    """
    # Coerce so plain lists work too (the label-array code below needs ndarrays).
    gt_s = np.asarray(gt_s)
    s = np.asarray(s)
    c_x = best_map(gt_s, s)
    return np.mean(gt_s != c_x)


def p_normalize(x, p=2):
    """Scale each row of ``x`` to (approximately) unit p-norm.

    The 1e-6 added to the norm keeps all-zero rows from dividing by zero.
    """
    row_norms = x.norm(p=p, dim=1, keepdim=True)
    return x / (row_norms + 1e-6)


def minmax_normalize(x, p=2):
    """Rescale each row of ``x`` linearly onto [0, 1].

    Args:
        x: 2-D tensor.
        p: unused; kept so existing callers that pass it keep working.

    Returns:
        Tensor shaped like ``x``. Constant rows (max == min), which previously
        produced 0/0 = NaN, map to all zeros.
    """
    rmax, _ = torch.max(x, dim=1, keepdim=True)
    rmin, _ = torch.min(x, dim=1, keepdim=True)
    span = rmax - rmin
    # For constant rows the numerator is already 0, so any non-zero divisor yields 0.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - rmin) / span


import torch
from sklearn.utils.validation import check_symmetric
import scipy.sparse as sparse
from sklearn import cluster


def spectral_clustering(affinity_matrix_, n_clusters, k, seed=1, n_init=20):
    """Cluster samples from a precomputed affinity via a normalized spectral embedding + k-means.

    Args:
        affinity_matrix_: (n, n) affinity; NumPy array, SciPy sparse matrix or
            torch tensor (dense or sparse layout, any device, may require grad).
        n_clusters: number of clusters for k-means.
        k: number of eigenvectors used for the embedding.
        seed: int, None or ``np.random.RandomState``; seeds both the ARPACK
            start vector and k-means so results are reproducible.
        n_init: number of k-means restarts.

    Returns:
        Label array returned by ``cluster.k_means``.
    """
    from scipy.sparse import csgraph
    from scipy.sparse.linalg import eigsh

    if isinstance(affinity_matrix_, torch.Tensor):
        # Densify sparse layouts, then detach/move to CPU so NumPy conversion
        # also works for GPU tensors and tensors that require grad.
        if affinity_matrix_.layout != torch.strided:
            affinity_matrix_ = affinity_matrix_.to_dense()
        affinity_matrix_ = affinity_matrix_.detach().cpu().numpy()
    affinity_matrix_ = check_symmetric(affinity_matrix_)

    laplacian = csgraph.laplacian(affinity_matrix_, normed=True)
    n = laplacian.shape[0]
    rng = seed if isinstance(seed, np.random.RandomState) else np.random.RandomState(seed)
    # A fixed ARPACK start vector makes the eigenvectors (and therefore the
    # labels) depend only on `seed`; otherwise ARPACK starts from a random vector.
    v0 = rng.uniform(-1, 1, n)
    # The largest eigenvectors of I - L are the smallest eigenvectors of L.
    _, vec = eigsh(sparse.identity(n) - laplacian,
                   k=k, sigma=None, which='LA', v0=v0)
    embedding = normalize(vec)
    _, labels_, _ = cluster.k_means(embedding, n_clusters,
                                    random_state=seed, n_init=n_init)
    return labels_

def random_walk_correction(embeddings, t=3, alpha=0.2, sigma=1.0):
    """Build the lazy t-step random-walk target ``T = alpha * I + (1 - alpha) * M^t``.

    Args:
        embeddings: (N, d) tensor of sample embeddings.
        t: number of random-walk steps; ``t == 0`` gives ``M^0 = I``.
        alpha: weight of the identity (self-retention) term.
        sigma: scale of the Gaussian kernel ``exp(-sigma * ||x_i - x_j||^2)``.

    Returns:
        (N, N) tensor, where ``M`` is the row-normalized Gaussian affinity.

    Raises:
        ValueError: if ``t`` is negative.
    """
    if t < 0:
        raise ValueError(f"t must be non-negative, got {t}")
    N = embeddings.size(0)

    # Affinity matrix A: Gaussian kernel on pairwise squared Euclidean distances.
    dot_product = torch.mm(embeddings, embeddings.t())
    squared_norm = torch.diag(dot_product).unsqueeze(1)
    # Clamp tiny negatives caused by floating-point cancellation.
    dist_squared = (squared_norm - 2.0 * dot_product + squared_norm.t()).clamp_min(0.0)
    A = torch.exp(-sigma * dist_squared)

    # Row-normalize A into a transition matrix M (zero-mass rows stay zero).
    D = torch.sum(A, dim=1)
    D_inv = 1.0 / D
    D_inv[torch.isinf(D_inv)] = 0
    M = A * D_inv.unsqueeze(1)

    # t-step transition matrix M^t (matrix_power returns I for t == 0).
    Mt = torch.linalg.matrix_power(M, t)

    # Corrected target distribution T.
    I = torch.eye(N, device=embeddings.device, dtype=embeddings.dtype)
    T = alpha * I + (1 - alpha) * Mt
    return T

def _inverse_mean_row_entropy(mat):
    """Return 1 / (mean over rows of scipy.stats.entropy of the row's positive entries), or 1.0 if that mean is not positive."""
    csr = sparse.csr_matrix(mat, copy=True)
    csr.sum_duplicates()
    n_rows = csr.shape[0]
    row_entropies = np.zeros(n_rows)
    for i in range(n_rows):
        row = csr.data[csr.indptr[i]:csr.indptr[i + 1]]
        positive = row[row > 0]
        if positive.size > 0:
            row_entropies[i] = entropy(positive)
    avg_entropy = np.mean(row_entropies)
    return 1.0 / avg_entropy if avg_entropy > 0 else 1.0


def multi_scale_diffusion(Aff_norm, max_steps=3, top_k=10, temperature=0.2, diag_boost=1.2):
    """Fuse multi-step diffusions of a sparse affinity matrix into one symmetric affinity.

    Step 0 is ``Aff_norm`` itself. Each further step works as follows:
    - Multiply the previous step's matrix by a step operator: ``Aff_norm`` with
      entries raised to ``1 / temperature`` and L1 row-normalized, or
      ``Aff_norm`` itself when ``temperature == 1``.
    - Keep the ``top_k`` largest entries per row.
    - L1 row-normalize the result.

    The steps are summed with weights proportional to 1 / (mean row entropy).
    The sum is symmetrized, and its stored diagonal entries are multiplied by
    ``diag_boost``.

    Args:
        Aff_norm: SciPy sparse (n, n) matrix (``.data`` is modified on a copy).
        max_steps: number of diffusion steps after step 0.
        top_k: entries kept per row after each step.
        temperature: sharpening temperature; must be non-zero when != 1.0.
        diag_boost: multiplier for the diagonal of the result (default 1.2).

    Returns:
        SciPy sparse CSR matrix of shape (n, n).
    """
    diff_matrices = [Aff_norm.copy()]

    # The step operator does not change between iterations, so build it once.
    if temperature != 1.0:
        step_op = Aff_norm.copy()
        step_op.data = np.power(step_op.data, 1.0 / temperature)
        step_op = normalize(step_op, norm='l1', axis=1)
    else:
        step_op = Aff_norm

    current_aff = Aff_norm
    for _ in range(max_steps):
        next_aff = current_aff.dot(step_op)
        next_aff = keep_top_k(next_aff, top_k)
        next_aff = normalize(next_aff, norm='l1', axis=1)
        diff_matrices.append(next_aff)
        current_aff = next_aff

    # Lower-entropy (more concentrated) scales receive larger weights.
    weights = np.array([_inverse_mean_row_entropy(m) for m in diff_matrices])
    weights = weights / np.sum(weights)
    result = sparse.csr_matrix(diff_matrices[0].shape, dtype=np.float32)
    for mat, weight in zip(diff_matrices, weights):
        result = result + weight * mat

    result = (0.5 * (result + result.T)).tocsr()
    result.sum_duplicates()
    # Scale only the stored diagonal entries in place. Writing through
    # result[i, i] may insert explicit zeros for missing diagonal entries,
    # changing the sparsity structure (SciPy warns about this for CSR).
    row_of_entry = np.repeat(np.arange(result.shape[0]), np.diff(result.indptr))
    on_diag = row_of_entry == result.indices
    result.data[on_diag] *= diag_boost
    return result


def _top_k_of_row(cols, vals, k):
    """Select the k largest of a row's non-zero entries (ties at the threshold broken by ``np.argsort``)."""
    threshold = np.sort(vals)[-k]
    keep = vals >= threshold
    cols, vals = cols[keep], vals[keep]
    if cols.size > k:
        order = np.argsort(-vals)[:k]
        cols, vals = cols[order], vals[order]
    return cols, vals


def _nonzero_rows(matrix):
    """Yield ``(column_indices, values)`` of each row's non-zero entries, columns ascending."""
    if sparse.issparse(matrix):
        csr = sparse.csr_matrix(matrix, copy=True)
        csr.sum_duplicates()  # canonical form: sorted columns, duplicates summed
        for i in range(csr.shape[0]):
            start, stop = csr.indptr[i], csr.indptr[i + 1]
            cols, vals = csr.indices[start:stop], csr.data[start:stop]
            nonzero = vals != 0
            yield cols[nonzero], vals[nonzero]
    else:
        # np.asarray also turns np.matrix into a 2-D ndarray so rows are 1-D.
        for row in np.asarray(matrix):
            cols = np.nonzero(row)[0]
            yield cols, row[cols]


def keep_top_k(matrix, k):
    """Keep only the ``k`` largest non-zero entries in every row of ``matrix``.

    Rows with at most ``k`` non-zero entries are kept unchanged. When several
    entries tie at the k-th largest value, ``np.argsort`` decides which of
    them survive.

    Args:
        matrix: 2-D NumPy array / np.matrix or SciPy sparse matrix.
        k: number of entries kept per row.

    Returns:
        SciPy sparse CSR matrix with the same shape and dtype as ``matrix``.
    """
    row_parts = [np.empty(0, dtype=np.int64)]
    col_parts = [np.empty(0, dtype=np.int64)]
    val_parts = [np.empty(0, dtype=matrix.dtype)]
    for i, (cols, vals) in enumerate(_nonzero_rows(matrix)):
        if cols.size > k:
            cols, vals = _top_k_of_row(cols, vals, k)
        row_parts.append(np.full(cols.size, i, dtype=np.int64))
        col_parts.append(cols.astype(np.int64))
        val_parts.append(vals)
    # Assemble from COO triplets instead of element-wise lil_matrix writes.
    return sparse.csr_matrix(
        (np.concatenate(val_parts),
         (np.concatenate(row_parts), np.concatenate(col_parts))),
        shape=matrix.shape,
        dtype=matrix.dtype,
    )