"""
@Description :   标签传播图聚类
@Author      :   tqychy 
@Time        :   2025/02/28 20:28:40
"""

import random
from collections import Counter, defaultdict
from tqdm import tqdm

import numpy as np


# def graph_cluster(v_num: int, e_pairs: np.ndarray, max_iter=100):
#     """
#     标签传播算法，优化后使每个聚类大小尽量平衡。
    
#     Args:
#         v_num: 图中点的数量
#         e_pairs: 图中的边，形状 N * 2，每行是一个边 (u, v)
#         max_iter: 最大迭代次数
    
#     Returns:
#         clusters: 聚类结果，list of lists，每个子列表是一个聚类中的节点
#         edges: 聚类内的边，list of np.ndarray，每个数组是形状 N * 2 的边集
#     """
#     # 构建邻接表
#     adj_list = defaultdict(list)
#     for u, v in e_pairs:
#         adj_list[u].append(v)
#         adj_list[v].append(u)  # 无向图，双向添加边
    
#     # 初始化标签和聚类大小
#     labels = np.arange(v_num)  # 每个节点初始标签为其索引
#     cluster_size = {i: 1 for i in range(v_num)}  # 初始每个标签对应一个节点

#     # 迭代更新
#     for _ in range(max_iter):
#         changed = False
#         # 随机打乱节点顺序
#         nodes = np.arange(v_num)
#         np.random.shuffle(nodes)
        
#         for node in nodes:
#             # 如果节点没有邻居，跳过
#             neighbors = adj_list.get(node, [])
#             if not neighbors:
#                 continue
            
#             # 统计邻居标签的频率
#             neighbor_labels = [labels[n] for n in neighbors]
#             label_freq = Counter(neighbor_labels)
            
#             # 计算每个候选标签的得分
#             scores = {}
#             for label, freq in label_freq.items():
#                 size = cluster_size.get(label, 0)
#                 # scores[label] = freq / (size + 1)  # 得分公式，平衡聚类大小
#                 scores[label] = freq
            
#             # 选择得分最高的标签
#             max_score = max(scores.values())
#             candidates = [label for label, score in scores.items() if score == max_score]
#             new_label = random.choice(candidates)
            
#             # 更新标签和聚类大小
#             if new_label != labels[node]:
#                 changed = True
#                 old_label = labels[node]
#                 cluster_size[old_label] -= 1  # 旧标签聚类大小减1
#                 if new_label not in cluster_size:
#                     cluster_size[new_label] = 0
#                 cluster_size[new_label] += 1  # 新标签聚类大小加1
#                 labels[node] = new_label
        
#         # 如果没有标签变化，提前结束
#         if not changed:
#             break
    
#     # 生成聚类结果
#     clusters = defaultdict(list)
#     for node in range(v_num):
#         clusters[labels[node]].append(node)
#     clusters = list(clusters.values())
    
#     # 生成每个聚类的边集
#     edges = []
#     for cluster in clusters:
#         cluster_set = set(cluster)  # 将聚类转为集合，便于快速检查
#         cluster_edges = []
#         for u, v in e_pairs:
#             # 如果边的两个端点都在当前聚类内，加入边集
#             if u in cluster_set and v in cluster_set:
#                 cluster_edges.append([u, v])
#         # 转换为 NumPy 数组
#         edges.append(np.array(cluster_edges) if cluster_edges else np.empty((0, 2), dtype=int))
    
#     return clusters, edges


def _propagate_labels(v_num, adj_list, max_iter, py_rng, np_rng, balanced):
    """Run one label-propagation pass and return the per-node label array.

    Every node starts with its own index as label. In each sweep nodes are
    visited in the order given by ``np_rng.shuffle`` and each node with
    neighbours adopts the best-scoring label among them; ties are broken with
    ``py_rng.choice``. Stops after ``max_iter`` sweeps or as soon as a whole
    sweep changes no label.

    Args:
        v_num: number of nodes; nodes are ``0 .. v_num - 1``.
        adj_list: mapping node -> list of neighbour nodes.
        max_iter: maximum number of sweeps.
        py_rng: ``random.Random`` used for tie breaking.
        np_rng: ``np.random.RandomState`` used for the visiting order.
        balanced: if True score a label by ``freq / (size + 1)`` (penalising
            large clusters) instead of the raw neighbour frequency.

    Returns:
        np.ndarray of length ``v_num``; ``labels[i]`` is the label of node ``i``
        (always the index of some node).
    """
    labels = np.arange(v_num)
    # sizes[label] = number of nodes currently carrying that label; labels are
    # always node indices, so a flat list is enough.
    sizes = [1] * v_num

    for _ in range(max_iter):
        changed = False
        order = np.arange(v_num)
        np_rng.shuffle(order)

        for node in order:
            neighbors = adj_list.get(node)
            if not neighbors:  # isolated node keeps its own label
                continue

            # Counter preserves first-seen order, which fixes the candidate
            # order and therefore keeps tie breaking reproducible per seed.
            label_freq = Counter(labels[n] for n in neighbors)
            if balanced:
                scores = {lab: freq / (sizes[lab] + 1) for lab, freq in label_freq.items()}
            else:
                scores = label_freq

            best = max(scores.values())
            candidates = [lab for lab, score in scores.items() if score == best]
            new_label = py_rng.choice(candidates)

            old_label = labels[node]
            if new_label != old_label:
                changed = True
                sizes[old_label] -= 1
                sizes[new_label] += 1
                labels[node] = new_label

        if not changed:  # converged
            break

    return labels


def _canonicalize_labels(labels):
    """Relabel so each cluster is identified by its smallest member index.

    Raw label IDs are arbitrary (whichever node's label happened to win), so
    the same cluster gets different IDs in different runs. Canonical IDs make
    identical clusters from different runs vote for the same label.
    """
    first_member = {}
    for node, lab in enumerate(labels):
        first_member.setdefault(lab, node)
    return np.array([first_member[lab] for lab in labels], dtype=int)


def _intra_cluster_edges(clusters, e_pairs):
    """Return, per cluster, the edges whose two endpoints both lie in it.

    Edges keep their original order from ``e_pairs``; edges touching a node
    that belongs to no cluster are skipped. Empty clusters get an empty
    ``(0, 2)`` integer array.
    """
    cluster_of = {node: idx for idx, cluster in enumerate(clusters) for node in cluster}
    buckets = [[] for _ in clusters]
    for u, v in e_pairs:
        cu = cluster_of.get(u)
        if cu is not None and cu == cluster_of.get(v):
            buckets[cu].append([u, v])
    return [np.array(b) if b else np.empty((0, 2), dtype=int) for b in buckets]


def graph_cluster(v_num: int, e_pairs: np.ndarray, max_iter=100, num_runs=5, balanced=False):
    """
    Label-propagation graph clustering with a consensus vote over several runs.

    The propagation is run ``num_runs`` times with fixed, distinct seeds. Each
    run's labels are canonicalised (each cluster named by its smallest member
    index) and every node then takes its most common label across runs
    (ties go to the earliest run).

    Uses private RNG instances, so neither the global ``random`` nor the global
    ``numpy.random`` state is modified.

    Args:
        v_num: number of nodes in the graph (nodes are ``0 .. v_num - 1``).
        e_pairs: edges, shape N * 2, each row an undirected edge (u, v).
        max_iter: maximum number of propagation sweeps per run.
        num_runs: number of runs used for voting (must be >= 1). The first
            five runs use seeds 42, 123, 456, 789, 101112; further runs use
            101113, 101114, ...
        balanced: if True, score candidate labels by ``freq / (size + 1)`` to
            favour smaller clusters; default False scores by raw neighbour
            frequency.

    Returns:
        clusters: list of lists, each inner list the nodes of one cluster,
            ordered by each cluster's smallest node index.
        edges: list of np.ndarray, one per cluster, shape (k, 2), holding the
            edges with both endpoints in that cluster.

    Raises:
        ValueError: if ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    base_seeds = [42, 123, 456, 789, 101112]
    # Extend deterministically instead of silently truncating to 5 runs.
    extra_seeds = [base_seeds[-1] + k for k in range(1, num_runs - len(base_seeds) + 1)]
    seeds = base_seeds[:num_runs] + extra_seeds

    # Undirected adjacency list; identical for every run, so build it once.
    adj_list = defaultdict(list)
    for u, v in e_pairs:
        adj_list[u].append(v)
        adj_list[v].append(u)

    all_labels = []
    with tqdm(total=num_runs, leave=False, desc="聚类") as pbar:
        for seed in seeds:
            # Local generators reproduce exactly the sequences of seeding the
            # global ones, without clobbering the caller's global RNG state.
            labels = _propagate_labels(
                v_num,
                adj_list,
                max_iter,
                random.Random(seed),
                np.random.RandomState(seed),
                balanced,
            )
            all_labels.append(_canonicalize_labels(labels))
            pbar.update(1)

    # Per-node majority vote; Counter.most_common breaks ties by first seen,
    # i.e. by the earliest run.
    final_labels = np.array(
        [Counter(votes).most_common(1)[0][0] for votes in zip(*all_labels)],
        dtype=int,
    )

    groups = defaultdict(list)
    for node, lab in enumerate(final_labels):
        groups[lab].append(node)
    clusters = list(groups.values())

    edges = _intra_cluster_edges(clusters, e_pairs)
    return clusters, edges