import numpy as np
import random
import torch
import torch.nn.functional as F

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and force deterministic cuDNN.

    :param seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed the current CUDA device, then every visible device, so GPU sampling is reproducible.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # Only use deterministic cuDNN kernels.
    cudnn.benchmark = False  # Disable kernel autotuning, which may choose differently between runs.


def distance(X, Y, square=True):
    """
    Pairwise Euclidean distances between the columns of X and the columns of Y (PyTorch).

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y; tiny negative values
    caused by rounding are clipped to zero before the optional square root.
    :param X: d * n, where d is dimensions and n is number of data points in X
    :param Y: d * m, where m is number of data points in Y
    :param square: if True (default) return squared distances, else plain distances
    :return: n * m, distance matrix
    """
    norm_x = torch.norm(X, dim=0)
    sq_norm_x = (norm_x * norm_x).unsqueeze(1)  # (n, 1), broadcasts across columns
    norm_y = torch.norm(Y, dim=0)
    sq_norm_y = (norm_y * norm_y).unsqueeze(0)  # (1, m), broadcasts across rows

    gram = X.t() @ Y  # (n, m) inner products
    dist = (sq_norm_x + sq_norm_y - 2 * gram).relu()
    return dist if square else torch.sqrt(dist)


def cal_weights_via_CAN(X, num_neighbors, links=0):
    """Build a k-nearest-neighbour similarity graph with the CAN closed-form weights.

    With ``d`` the squared Euclidean distance matrix (symmetrised with an
    element-wise max) and each row sorted ascending (position 0 is the sample
    itself), the weight from sample ``i`` to sample ``j`` is::

        relu((t_i - d_ij) / (k * t_i - s_i))

    where ``k = num_neighbors``, ``t_i`` is the sorted distance at position
    ``k`` plus 1e-10, and ``s_i`` is the sum of the first ``k`` sorted distances.

    :param X: d * n data matrix (columns are samples), as expected by ``distance``.
    :param num_neighbors: k; must be smaller than n because it indexes a sorted row.
    :param links: optional n * n array-like (presumably prior pairwise links --
        confirm with callers). When given, the identity and ``links`` are added
        to the weights and each row is re-normalised to sum to 1. ``0`` (the
        default) or ``None`` disables this step.
    :return: ``(weights, raw_weights)``, both n * n. ``raw_weights`` are the
        row-wise weights above and ``weights = (raw_weights + raw_weights.T) / 2``.
        Both are placed on CUDA when it is available, otherwise on X's device.
    """
    size = X.shape[1]
    distances = distance(X, X)
    distances = torch.max(distances, distances.t())
    # Keep only the sorted values; the n * n index tensor is never needed.
    sorted_distances = distances.sort(dim=1).values
    # Distance to the first point outside the neighbourhood (sorted position k,
    # counting the sample itself at position 0). It is kept as an (n, 1) column
    # that broadcasts over rows instead of materialising an n * n copy. The
    # epsilon keeps the denominator below strictly positive.
    top_k = sorted_distances[:, num_neighbors].unsqueeze(1) + 10**-10
    # Sum of distances to the k closest points (the sample itself included), shape (n, 1).
    sum_top_k = sorted_distances[:, :num_neighbors].sum(dim=1, keepdim=True)
    del sorted_distances
    torch.cuda.empty_cache()

    # T[i, j] is positive for points closer to i than top_k[i], non-positive otherwise.
    T = top_k - distances
    del distances
    torch.cuda.empty_cache()

    # Row-wise normalisation from the CAN closed-form solution.
    weights = T / (num_neighbors * top_k - sum_top_k)
    del T, top_k, sum_top_k
    torch.cuda.empty_cache()

    # Post-processing runs on the CPU (presumably to spare GPU memory for the
    # n * n symmetrisation below).
    weights = weights.relu().cpu()
    if links is not None and not (isinstance(links, (int, float)) and links == 0):
        # Build every operand on weights' device/dtype: mixing CPU and CUDA
        # tensors here would raise. Out-of-place ops leave relu's output intact.
        links = torch.as_tensor(links, dtype=weights.dtype, device=weights.device)
        weights = weights + torch.eye(size, dtype=weights.dtype, device=weights.device)
        weights = weights + links
        weights = weights / weights.sum(dim=1, keepdim=True)
    torch.cuda.empty_cache()

    raw_weights = weights
    weights = (weights + weights.t()) / 2
    # Results go to the (current) CUDA device as before; without CUDA, stay on X's device.
    target = torch.device("cuda") if torch.cuda.is_available() else X.device
    return weights.to(target), raw_weights.to(target)


def get_Laplacian_from_weights(weights):
    """Return the degree-normalised adjacency used as the graph operator.

    The value returned is ``(W * d).t() * d`` with ``d = rowsum(W) ** -0.5``.
    For a symmetric ``W`` this is ``D^{-1/2} W D^{-1/2}``; for a non-symmetric
    ``W`` it is that matrix transposed. Despite the name, this is the normalised
    *adjacency*, not ``I - D^{-1/2} W D^{-1/2}``.

    :param weights: square weight matrix W (2-D tensor).
    :return: tensor of the same shape. Rows/columns of zero-degree nodes are set
        to zero instead of becoming inf/NaN.
    """
    degree = torch.sum(weights, dim=1).pow(-0.5)
    # A zero row sum gives inf here and then NaN through 0 * inf; treat such nodes as isolated.
    degree = degree.masked_fill(torch.isinf(degree), 0.0)
    return (weights * degree).t() * degree


def update_graph(embedding_mv, num_neighbors):
    """Rebuild the CAN graph and its normalised adjacency for every view.

    :param embedding_mv: sequence of per-view embeddings. Each is transposed
        before being passed to ``cal_weights_via_CAN`` (which takes d * n), so
        presumably each is shaped (n_samples, dim) -- confirm with callers.
    :param num_neighbors: neighbourhood size k forwarded to ``cal_weights_via_CAN``.
    :return: ``(weights_mv, raw_weights_mv, laplacian_mv)``, three lists with one
        entry per view: symmetric weights, raw weights, and the output of
        ``get_Laplacian_from_weights`` on the symmetric weights.
    """
    weights_mv, raw_weights_mv, laplacian_mv = [], [], []
    # Graph construction is not differentiated through.
    with torch.no_grad():
        for x in embedding_mv:
            weights, raw_weights = cal_weights_via_CAN(x.t(), num_neighbors)
            weights_mv.append(weights)
            raw_weights_mv.append(raw_weights)
            laplacian_mv.append(get_Laplacian_from_weights(weights))
    return weights_mv, raw_weights_mv, laplacian_mv

def target_distribution(Q):
    """Build the DEC-style target distribution P from soft assignments Q.

    Squares Q, divides each column by its cluster frequency (column sum of Q),
    then re-normalises every row to sum to 1. This sharpens high-confidence
    assignments.

    :param Q: (N, K) tensor of soft assignments.
    :return: (N, K) tensor P with rows summing to 1.
    """
    sharpened = Q ** 2 / Q.sum(0)
    row_totals = sharpened.sum(1, keepdim=True)
    return sharpened / row_totals

def soft_kmeans_assign(Z, n_clusters, random_state=None):
    """Soft cluster labels from K-means centres via a Student-t kernel.

    Runs scikit-learn K-means on ``Z``, then sets
    ``q_ij = 1 / (1 + ||z_i - mu_j||^2)`` and normalises each row to sum to 1.

    :param Z: (N, D) embedding tensor. No gradient flows through this function.
    :param n_clusters: number of clusters K.
    :param random_state: forwarded to ``KMeans``. ``None`` (the default, as
        before) draws from NumPy's global RNG, which ``np.random.seed`` controls.
    :return: (N, K) tensor of soft labels with Z's dtype and device.
    """
    from sklearn.cluster import KMeans

    with torch.no_grad():
        # detach(): .numpy() refuses tensors that require grad. no_grad does not
        # clear that flag, and .cpu() is a no-op for a CPU tensor.
        Z_np = Z.detach().cpu().numpy()
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
        cluster_centers = kmeans.fit(Z_np).cluster_centers_
        centers = torch.tensor(cluster_centers, dtype=Z.dtype, device=Z.device)
        # (N, 1, D) - (1, K, D) -> (N, K) squared Euclidean distances.
        dist = ((Z.unsqueeze(1) - centers.unsqueeze(0)) ** 2).sum(-1)
        Q = 1.0 / (1.0 + dist)  # similarity kernel
        Q = Q / Q.sum(1, keepdim=True)  # normalise rows into probabilities
    return Q
