import torch
import torch.nn.functional as F
import pandas as pd, os, matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering

def _cosine_similarity_matrix(vectors):
    """Return the pairwise cosine-similarity matrix of a list of 1-D tensors.

    Entries are clamped to [1e-8, 1] (so negative similarities become 1e-8 and
    the matrix is usable as a non-negative affinity) and the diagonal is
    forced to exactly 1.
    """
    unit_rows = torch.stack([F.normalize(v, p=2, dim=0) for v in vectors], dim=0)
    sim = (unit_rows @ unit_rows.T).clamp(min=1e-8, max=1.0)
    sim.fill_diagonal_(1.0)
    return sim


def _row_norm_vector(b_dict):
    """Concatenate the L2 norms taken along the last dim of every tensor in ``b_dict``."""
    parts = [torch.linalg.norm(v, ord=2, dim=-1).reshape(-1) for v in b_dict.values()]
    return torch.cat(parts, dim=0).float()


def _ba_product_cosine(a_layers, b_layers, device):
    """Return a client-by-client cosine-like similarity built from B^T B and A A^T.

    ``a_layers[i]`` / ``b_layers[i]`` are the per-layer A / B tensors of client i.

    NOTE(review): the Gram matrices are summed over layers *before* being
    multiplied, so with more than one layer the trace also contains
    cross-layer terms; it equals the Frobenius inner product of the B·A
    products only for a single layer. Kept as-is because it only feeds the
    diagnostic logs — confirm whether a per-layer sum was intended.
    """
    n = len(a_layers)
    norms = []
    for a_list, b_list in zip(a_layers, b_layers):
        gram_b = torch.zeros_like(b_list[0].T @ b_list[0])
        gram_a = torch.zeros_like(a_list[0] @ a_list[0].T)
        for b in b_list:
            gram_b += b.T @ b
        for a in a_list:
            gram_a += a @ a.T
        norms.append(torch.trace(gram_b @ gram_a).clamp_min(1e-12).sqrt())
    norms = torch.stack(norms)

    sim = torch.eye(n, device=device)
    for i in range(n):
        for j in range(i + 1, n):
            cross_b = sum(b_i.T @ b_j for b_i, b_j in zip(b_layers[i], b_layers[j]))
            cross_a = sum(a_j @ a_i.T for a_j, a_i in zip(a_layers[j], a_layers[i]))
            num = torch.trace(cross_b @ cross_a)
            den = (norms[i] * norms[j]).clamp_min(1e-12)
            sim[i, j] = sim[j, i] = (num / den).clamp(-1.0, 1.0)
    return sim


@torch.no_grad()
def ClusterRank_B_Avg(self, route_aggregation, params):
    """Aggregate LoRA B matrices with a cluster-level differential correction.

    Steps:
      1. Spectral-cluster the clients on the cosine similarity of their
         flattened ``lora_B0`` tensors (between 2 and ``num_clients`` clusters
         requested).
      2. For every client i:
           (a) take the mean B of its cluster, ``B_cluster``;
           (b) form the difference ``delta = B_avg - B_cluster``;
           (c) average ``A_j @ A_j.T`` with equal weights over the clients
               *outside* i's cluster and keep the projector onto everything
               but the top ``subspace_rank`` left-singular directions;
           (d) right-multiply ``delta`` by that projector.
      3. ``B_new = B_avg + diff_lambda * sum_i(delta_i @ P_i)`` (the sum is
         not divided by the number of clients).
      4. If ``self.log_flag`` is truthy, save three similarity matrices as
         ``.xlsx`` and ``.png`` files under ``self.personal_dir``.

    Args:
        self: object supplying optional attributes ``device``,
            ``subspace_rank``, ``diff_lambda``, ``log_flag``, ``personal_dir``
            and ``num_clusters`` (defaults are used when absent).
        route_aggregation: unused by this function.
        params: list of per-client ``{name: tensor}`` dicts. Keys containing
            ``"lora_A0"`` / ``"lora_B0"`` are treated as A / B matrices; the
            key set of the first client is used for all clients. The code
            requires ``B @ (A @ A.T)`` to be a valid product.

    Returns:
        ``[]`` when ``params`` is empty; the device-moved input dicts when no
        A or no B keys exist; otherwise one dict per client in which every B
        entry is the aggregated B (cast to that entry's original dtype) and
        all other entries are the device-moved inputs.

    Raises:
        ValueError: if exactly one client is supplied.

    Side effects:
        On the full path ``self.log_flag`` is set to ``False`` before
        returning (the early returns leave it untouched).
    """
    if not params:
        return []

    device = getattr(self, "device", torch.device("cpu"))
    num_clients = len(params)
    if num_clients < 2:
        raise ValueError("需要至少两个客户端。")

    gpu_params = [{k: v.to(device) for k, v in p.items()} for p in params]
    param_names = list(gpu_params[0].keys())

    def is_A(k):
        return "lora_A0" in k

    def is_B(k):
        return "lora_B0" in k

    A_keys = [k for k in param_names if is_A(k)]
    B_keys = [k for k in param_names if is_B(k)]
    if not A_keys or not B_keys:
        return gpu_params

    subspace_rank = getattr(self, "subspace_rank", 1)
    diff_lambda = getattr(self, "diff_lambda", 1.0)
    log_flag = getattr(self, "log_flag", False)
    save_dir = getattr(self, "personal_dir", "./logs")
    num_clusters = max(2, min(getattr(self, "num_clusters", 4), num_clients))

    # === Extract A / B parameters (float32 copies for the math below) ===
    A_dicts = [{k: p[k].detach().float() for k in A_keys} for p in gpu_params]
    B_dicts = [{k: p[k].detach().float() for k in B_keys} for p in gpu_params]

    # === Affinity used for clustering: cosine similarity of flattened B ===
    B_flat_vecs = [
        torch.cat([v.reshape(-1) for v in b.values()], dim=0).float() for b in B_dicts
    ]
    S_B_direct = _cosine_similarity_matrix(B_flat_vecs)

    # === Spectral clustering ===
    clustering = SpectralClustering(
        n_clusters=num_clusters, affinity="precomputed", assign_labels="kmeans", random_state=0
    )
    labels = clustering.fit_predict(S_B_direct.cpu().numpy())
    # Plain Python ints: avoids a device round-trip on every label lookup.
    label_list = [int(lab) for lab in labels]

    # === Step 1: global and per-cluster mean of B ===
    avg_B = {k: torch.stack([b[k] for b in B_dicts], dim=0).mean(dim=0) for k in B_keys}
    # Group by the labels that were actually assigned. The k-means step may
    # return fewer distinct labels than requested, and stacking an empty
    # member list would raise.
    clusters = {}
    for idx, lab in enumerate(label_list):
        clusters.setdefault(lab, []).append(idx)
    B_cluster = {
        c: {k: torch.stack([B_dicts[i][k] for i in members], dim=0).mean(dim=0) for k in B_keys}
        for c, members in clusters.items()
    }

    # === Step 2: projector built from the other clusters (equal weights) ===
    others_by_cluster = {
        c: [j for j in range(num_clients) if label_list[j] != c] for c in clusters
    }
    projector_cache = {}

    def build_projectors(ref_idx, others):
        """Return {B key: projector} from the equal-weight mean of A_j A_j^T over ``others``."""
        count = len(others)
        gram_mean = {
            k: torch.zeros_like(A_dicts[ref_idx][k] @ A_dicts[ref_idx][k].T) for k in A_keys
        }
        for j in others:
            for k in A_keys:
                Aj = A_dicts[j][k]
                gram_mean[k] += (Aj @ Aj.T) / count
        projectors = {}
        for k, M in gram_mean.items():
            U, _, _ = torch.linalg.svd(M, full_matrices=False)
            if subspace_rank < U.shape[1]:
                # Drop the leading `subspace_rank` directions, keep the rest.
                tail = U[:, subspace_rank:]
                Pk = tail @ tail.T
            else:
                Pk = torch.zeros_like(M)
            projectors[k.replace("lora_A0", "lora_B0")] = Pk
        return projectors

    def weak_space_projectors(i):
        """Projectors for client i; shared by all clients of the same cluster."""
        c = label_list[i]
        others = others_by_cluster[c]
        if not others:
            # Every client landed in one cluster: fall back to "all but i".
            return build_projectors(i, [j for j in range(num_clients) if j != i])
        if c not in projector_cache:
            projector_cache[c] = build_projectors(i, others)
        return projector_cache[c]

    # === Step 3: aggregate + differential compensation (not divided by N) ===
    diff_B_sum = {k: torch.zeros_like(avg_B[k]) for k in B_keys}
    for i in range(num_clients):
        c_i = label_list[i]
        P_i = weak_space_projectors(i)
        for k in B_keys:
            Pk = P_i.get(k)
            if Pk is None:
                # No matching A key -> no projector -> zero contribution.
                continue
            delta_cluster = avg_B[k] - B_cluster[c_i][k]
            diff_B_sum[k] += delta_cluster @ Pk

    final_B = {k: (avg_B[k] + diff_lambda * diff_B_sum[k]) for k in B_keys}

    # === Step 4: write back ===
    # NOTE: `.to(dtype=...)` returns the same tensor when no cast is needed,
    # so clients may share one aggregated B object.
    aggregated_results = []
    for i in range(num_clients):
        newp = {}
        for k, v in gpu_params[i].items():
            newp[k] = final_B[k].to(dtype=v.dtype) if is_B(k) else v
        aggregated_results.append(newp)

    # === Step 5: diagnostic logs (3 figures + 3 Excel files) ===
    if log_flag:
        os.makedirs(save_dir, exist_ok=True)

        # These two matrices are only ever logged, so build them lazily.
        B_rowlen_vecs = [F.normalize(_row_norm_vector(b), p=2, dim=0) for b in B_dicts]
        S_B_rowlen = _cosine_similarity_matrix(B_rowlen_vecs)
        A_layers = [list(a.values()) for a in A_dicts]
        B_layers = [list(b.values()) for b in B_dicts]
        S_BA_exact = _ba_product_cosine(A_layers, B_layers, device)

        def _to_np(x):
            return x.detach().float().cpu().numpy()

        mats = {
            "B_similarity_direct": _to_np(S_B_direct),
            "B_rowlength_similarity": _to_np(S_B_rowlen),
            "BA_similarity_exact": _to_np(S_BA_exact),
        }
        for name, mat in mats.items():
            pd.DataFrame(mat).to_excel(os.path.join(save_dir, f"{name}.xlsx"))
            plt.figure(figsize=(5, 4))
            plt.imshow(mat, cmap="coolwarm", interpolation="nearest")
            plt.colorbar()
            plt.title(name.replace("_", " "))
            plt.xlabel("Client j")
            plt.ylabel("Client i")
            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=300)
            plt.close()
        print(f"[ReweightBRank_B_Avg] 已保存相似矩阵到 {save_dir}")

    self.log_flag = False
    return aggregated_results
