import torch
import torch.nn.functional as F
import pandas as pd, os, matplotlib.pyplot as plt

@torch.no_grad()
def ReweightARank_B_Avg(self, route_aggregation, params):
    """Aggregate LoRA-B (``lora_B0``) weights with inverse A-similarity reweighting.

    Steps:
      1. Flatten and L2-normalize each client's ``lora_A0`` tensors and build
         the cosine-similarity matrix ``S_A`` (clamped to ``[1e-6, 1]``).
      2. Inverse-similarity weights ``W[i, j] ∝ 1 / (S_A[i, j] + eps)`` with a
         zero diagonal, row-normalized: the less similar client ``j``'s A is
         to client ``i``'s, the larger its weight.
      3. For each client ``i`` and each ``lora_A0`` key, form
         ``M_i = sum_{j != i} W[i, j] * A_j @ A_j.T``, take its SVD and build
         ``P_i = U_weak @ U_weak.T`` from the left singular vectors after the
         leading ``subspace_rank`` ones (a zero matrix if none remain).
      4. ``B_final = mean_i(B_i) + diff_lambda * mean_i((mean(B) - B_i) @ P_i)``.
         A projector is matched to a B key by replacing ``"lora_A0"`` with
         ``"lora_B0"`` in the A key; B keys without a projector use the raw
         difference. ``(mean(B) - B_i) @ P_i`` requires the last dimension of
         a B tensor to equal the first dimension of its A tensor
         (presumably A is (r, in) and B is (out, r) — confirm with the model).
      5. Every client's ``lora_B0`` entries are replaced by ``B_final`` (cast
         back to the original dtype); all other entries are returned as-is.

    Args:
        self: Holder of optional settings read via attribute lookup:
            ``device`` (default CPU), ``subspace_rank`` (1), ``diff_lambda``
            (1.0), ``eps`` (1e-6), ``log_flag`` (False) and ``personal_dir``
            ("./logs"). ``self.log_flag`` is set to ``False`` once aggregation
            has run, so logging happens at most once per enable.
        route_aggregation: Unused; kept for interface compatibility.
        params: List of per-client ``{name: tensor}`` dicts. The key set is
            taken from the first client.

    Returns:
        List of per-client ``{name: tensor}`` dicts on ``device``. ``[]`` when
        ``params`` is ``None``/empty; the device-moved inputs unchanged when no
        ``lora_A0`` or no ``lora_B0`` key exists.
    """
    if params is None or len(params) == 0:
        return []

    num_clients = len(params)
    device = self.device if hasattr(self, "device") else torch.device("cpu")

    # Move every client's tensors onto the working device.
    gpu_params = [{k: v.to(device) for k, v in client.items()} for client in params]
    param_names = list(gpu_params[0].keys())

    def is_lora_A0(k): return "lora_A0" in k
    def is_lora_B0(k): return "lora_B0" in k

    A_keys = [k for k in param_names if is_lora_A0(k)]
    B_keys = [k for k in param_names if is_lora_B0(k)]
    if not A_keys or not B_keys:
        return gpu_params

    subspace_rank = getattr(self, "subspace_rank", 1)
    diff_lambda = getattr(self, "diff_lambda", 1.0)
    eps = getattr(self, "eps", 1e-6)
    log_flag = getattr(self, "log_flag", False)
    save_dir = getattr(self, "personal_dir", "./logs")

    # 1) Per-client float32 copies of the A and B tensors.
    A_list = [{k: client[k].detach().float() for k in A_keys} for client in gpu_params]
    B_list = [{k: client[k].detach().float() for k in B_keys} for client in gpu_params]

    # 2) Cosine similarity between clients' concatenated A tensors.
    A_mat = torch.stack(
        [F.normalize(torch.cat([v.flatten() for v in A.values()], dim=0), p=2, dim=0)
         for A in A_list],
        dim=0,
    )
    S_A = (A_mat @ A_mat.T).clamp(min=1e-6, max=1.0)

    # 3) Inverse-similarity weights; a client never weights itself.
    inv_sim = 1.0 / (S_A + eps)
    inv_sim.fill_diagonal_(0.0)
    W = inv_sim / (inv_sim.sum(dim=1, keepdim=True) + eps)

    # A_j A_j^T does not depend on the target client i: compute it once per
    # client instead of once per (i, j) pair.
    grams = [{k: A[k] @ A[k].T for k in A_keys} for A in A_list]

    def weak_projectors(i):
        """Return {B key: projector onto the weak (small-singular) subspace of M_i}."""
        projectors = {}
        for k in A_keys:
            M = torch.zeros_like(grams[i][k])
            for j in range(num_clients):
                if j != i:
                    M += W[i, j] * grams[j][k]
            U = torch.linalg.svd(M, full_matrices=False).U
            if subspace_rank < U.shape[1]:
                U_weak = U[:, subspace_rank:]
                P = U_weak @ U_weak.T
            else:
                P = torch.zeros_like(M)
            projectors[k.replace("lora_A0", "lora_B0")] = P
        return projectors

    # 4) FedAvg mean of B plus the averaged, projected per-client differences.
    avg_B = {k: torch.stack([B[k] for B in B_list]).mean(dim=0) for k in B_keys}
    diff_B_sum = {k: torch.zeros_like(avg_B[k]) for k in B_keys}
    for i in range(num_clients):
        projectors = weak_projectors(i)
        for k in B_keys:
            delta = avg_B[k] - B_list[i][k]
            if k in projectors:
                diff_B_sum[k] += delta @ projectors[k]
            else:
                diff_B_sum[k] += delta

    final_B = {k: avg_B[k] + diff_lambda * diff_B_sum[k] / num_clients for k in B_keys}

    # 5) Every client gets the same aggregated B; other entries pass through.
    aggregated_results = [
        {k: (final_B[k].to(dtype=v.dtype) if is_lora_B0(k) else v) for k, v in client.items()}
        for client in gpu_params
    ]

    # 6) Optional one-shot logging of S_A and W (Excel + heatmaps).
    if log_flag:
        os.makedirs(save_dir, exist_ok=True)
        S_np, W_np = S_A.cpu().numpy(), W.cpu().numpy()
        pd.DataFrame(S_np).to_excel(os.path.join(save_dir, "A_similarity.xlsx"))
        pd.DataFrame(W_np).to_excel(os.path.join(save_dir, "Inverse_weights.xlsx"))

        def save_heatmap(mat, title, filename):
            fig = plt.figure(figsize=(5, 4))
            try:
                plt.imshow(mat, cmap="coolwarm", interpolation="nearest")
                plt.colorbar()
                plt.title(title)
                plt.xlabel("Client j")
                plt.ylabel("Client i")
                plt.tight_layout()
                plt.savefig(filename, dpi=300)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)

        save_heatmap(S_np, "A Similarity Matrix", os.path.join(save_dir, "A_similarity.png"))
        save_heatmap(W_np, "Inverse-weight Matrix", os.path.join(save_dir, "Inverse_weights.png"))
        print(f"[ReweightARank_B_Avg] 已保存 A 相似矩阵与权重图到 {save_dir}")

    # Logging is one-shot: the caller must re-enable it explicitly.
    self.log_flag = False

    return aggregated_results


@torch.no_grad()
def ReweightBRank_B_Avg(self, route_aggregation, params):
    """Aggregate LoRA-B (``lora_B0``) weights with inverse B-similarity reweighting.

    Steps:
      1. ``S_B_direct``: cosine similarity of each client's flattened,
         concatenated ``lora_B0`` tensors (clamped to ``[1e-8, 1]``, diagonal 1).
      2. Inverse-similarity weights ``W[i, j] ∝ 1 / (S_B_direct[i, j] + eps)``
         with a zero diagonal, row-normalized: the less similar client ``j``'s
         B is to client ``i``'s, the larger its weight.
      3. For each client ``i`` and each ``lora_A0`` key, form
         ``M_i = sum_{j != i} W[i, j] * A_j @ A_j.T``, take its SVD and build
         ``P_i = U_weak @ U_weak.T`` from the left singular vectors after the
         leading ``subspace_rank`` ones (a zero matrix if none remain).
      4. ``B_final = mean_i(B_i) + diff_lambda * mean_i((mean(B) - B_i) @ P_i)``.
         A projector is matched to a B key by replacing ``"lora_A0"`` with
         ``"lora_B0"``; B keys without a projector use the raw difference.
      5. Every client's ``lora_B0`` entries are replaced by ``B_final`` (cast
         back to the original dtype); all other entries are returned as-is.

    When ``log_flag`` is set, three N x N similarity matrices are written to
    ``personal_dir`` as Excel files and heatmap PNGs:
      * ``B_similarity_direct``    - the ``S_B_direct`` used for weighting;
      * ``B_rowlength_similarity`` - cosine of the concatenated per-row L2
        norms of B (diagnostic only);
      * ``BA_similarity_exact``    - cosine of the full updates
        ``concat_l(B_l @ A_l)``, computed per layer from r x r Gram products
        (diagnostic only).
    The two diagnostic matrices are only computed when logging is enabled.

    Args:
        self: Holder of optional settings read via attribute lookup:
            ``device`` (default CPU), ``subspace_rank`` (1), ``diff_lambda``
            (1.0), ``eps`` (1e-6), ``log_flag`` (False) and ``personal_dir``
            ("./logs"). ``self.log_flag`` is set to ``False`` once aggregation
            has run, so logging happens at most once per enable.
        route_aggregation: Unused; kept for interface compatibility.
        params: List of per-client ``{name: tensor}`` dicts. The key set is
            taken from the first client.

    Returns:
        List of per-client ``{name: tensor}`` dicts on ``device``. ``[]`` when
        ``params`` is falsy; the device-moved inputs unchanged when no
        ``lora_A0`` or no ``lora_B0`` key exists.
    """
    if not params:
        return []

    device = getattr(self, "device", torch.device("cpu"))
    num_clients = len(params)

    # Move every client's tensors onto the working device.
    gpu_params = [{k: v.to(device) for k, v in p.items()} for p in params]
    param_names = list(gpu_params[0].keys())

    def is_A(k): return "lora_A0" in k
    def is_B(k): return "lora_B0" in k

    A_keys = [k for k in param_names if is_A(k)]
    B_keys = [k for k in param_names if is_B(k)]
    if not A_keys or not B_keys:
        return gpu_params

    subspace_rank = getattr(self, "subspace_rank", 1)
    diff_lambda = getattr(self, "diff_lambda", 1.0)
    eps = getattr(self, "eps", 1e-6)
    log_flag = getattr(self, "log_flag", False)
    save_dir = getattr(self, "personal_dir", "./logs")

    # -------- helpers --------
    def cosine_sim_from_vectors(vec_list):
        """Pairwise cosine similarity of 1-D vectors, clamped to [1e-8, 1], diagonal 1."""
        mat = torch.stack([F.normalize(v, p=2, dim=0) for v in vec_list], dim=0)  # [N, D]
        S = (mat @ mat.T).clamp(min=1e-8, max=1.0)
        S.fill_diagonal_(1.0)
        return S

    def B_rowlen_vector(B_dict):
        """Concatenate, over layers, the L2 norm of each B tensor along its last dim."""
        parts = [torch.linalg.norm(v, ord=2, dim=-1).reshape(-1) for v in B_dict.values()]
        return torch.cat(parts, dim=0).float()

    def cosine_BA_exact():
        """Cosine similarity of the full LoRA updates concat_l(B_l @ A_l) for all client pairs.

        Per layer, <B_i A_i, B_j A_j>_F = tr((B_i^T B_j) (A_j A_i^T)), so only
        r x r products are needed. Inner products are summed layer by layer;
        summing the Gram factors across layers first would add spurious
        cross-layer terms.
        """
        # Pair each A layer with its B layer by key name.
        layers = [
            (kA, kA.replace("lora_A0", "lora_B0"))
            for kA in A_keys
            if kA.replace("lora_A0", "lora_B0") in B_dicts[0]
        ]

        def inner(i, j):
            total = torch.zeros((), device=device)
            for kA, kB in layers:
                G = B_dicts[i][kB].T @ B_dicts[j][kB]  # r x r
                H = A_dicts[j][kA] @ A_dicts[i][kA].T  # r x r
                total = total + torch.trace(G @ H)
            return total

        norms = torch.stack(
            [inner(i, i).clamp_min(1e-12).sqrt() for i in range(num_clients)]
        )
        S_cos = torch.eye(num_clients, device=device, dtype=torch.float32)
        for i in range(num_clients):
            for j in range(i + 1, num_clients):
                den = (norms[i] * norms[j]).clamp_min(1e-12)
                S_cos[i, j] = S_cos[j, i] = (inner(i, j) / den).clamp(-1.0, 1.0)
        return S_cos

    # 1) Per-client float32 copies of the A and B tensors.
    A_dicts = [{k: p[k].detach().float() for k in A_keys} for p in gpu_params]
    B_dicts = [{k: p[k].detach().float() for k in B_keys} for p in gpu_params]

    # 2) Direct cosine similarity of flattened B; this drives the weighting.
    #    (Row-length similarity was tried first; direct B similarity worked best.)
    B_flat_vecs = [torch.cat([v.reshape(-1) for v in B.values()], dim=0) for B in B_dicts]
    S_B_direct = cosine_sim_from_vectors(B_flat_vecs)

    # 3) Inverse-similarity weights; a client never weights itself.
    inv_sim = 1.0 / (S_B_direct + eps)
    inv_sim.fill_diagonal_(0.0)
    W = inv_sim / (inv_sim.sum(dim=1, keepdim=True) + eps)  # row-normalized

    # A_j A_j^T does not depend on the target client i: compute it once per client.
    grams = [{k: A[k] @ A[k].T for k in A_keys} for A in A_dicts]

    def weak_projectors(i):
        """Return {B key: projector onto the weak (small-singular) subspace of M_i}."""
        projectors = {}
        for k in A_keys:
            M = torch.zeros_like(grams[i][k])
            for j in range(num_clients):
                if j != i:
                    M += W[i, j] * grams[j][k]
            U = torch.linalg.svd(M, full_matrices=False).U
            if subspace_rank < U.shape[1]:
                U_weak = U[:, subspace_rank:]
                P = U_weak @ U_weak.T
            else:
                P = torch.zeros_like(M)
            projectors[k.replace("lora_A0", "lora_B0")] = P
        return projectors

    # 4) FedAvg mean of B plus the averaged, projected per-client differences.
    avg_B = {k: torch.stack([B[k] for B in B_dicts], dim=0).mean(dim=0) for k in B_keys}
    diff_B_sum = {k: torch.zeros_like(avg_B[k]) for k in B_keys}
    for i in range(num_clients):
        projectors = weak_projectors(i)
        for k in B_keys:
            delta = avg_B[k] - B_dicts[i][k]
            if k in projectors:
                diff_B_sum[k] += delta @ projectors[k]
            else:
                diff_B_sum[k] += delta

    final_B = {k: avg_B[k] + diff_lambda * diff_B_sum[k] / num_clients for k in B_keys}

    # 5) Every client gets the same aggregated B; other entries pass through.
    aggregated_results = [
        {k: (final_B[k].to(dtype=v.dtype) if is_B(k) else v) for k, v in p.items()}
        for p in gpu_params
    ]

    # 6) Optional one-shot logging; diagnostic matrices are computed only here.
    if log_flag:
        os.makedirs(save_dir, exist_ok=True)

        mats = {
            "B_similarity_direct": S_B_direct,
            "B_rowlength_similarity": cosine_sim_from_vectors(
                [B_rowlen_vector(B) for B in B_dicts]
            ),
            "BA_similarity_exact": cosine_BA_exact(),
        }

        for name, S in mats.items():
            mat = S.detach().float().cpu().numpy()
            pd.DataFrame(mat).to_excel(os.path.join(save_dir, f"{name}.xlsx"))
            fig = plt.figure(figsize=(5, 4))
            try:
                plt.imshow(mat, cmap="coolwarm", interpolation="nearest")
                plt.colorbar()
                plt.title(name.replace("_", " "))
                plt.xlabel("Client j")
                plt.ylabel("Client i")
                plt.tight_layout()
                plt.savefig(os.path.join(save_dir, f"{name}.png"), dpi=300)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)

        print(f"[ReweightBRank_B_Avg] Logs saved to: {save_dir}")

    # Logging is one-shot: the caller must re-enable it explicitly.
    self.log_flag = False

    return aggregated_results