import torch
from typing import List, Dict, Any, Optional

@torch.no_grad()
def OnlyImportant_B_Avg(
    self,
    route_aggregation: bool,                       # 与 Global_B_Avg 同签名，内部未使用，仅保留兼容
    params: List[Dict[str, torch.Tensor]],
    # subspace_rank: int = 1,
    # diff_lambda: float = 1.0,
    # alpha: float = 1.0,
    # eps: float = 1e-6
) -> List[Dict[str, torch.Tensor]]:
    """
    Diff-Rank Averaging（仅对 LoRA-B0 聚合并做差异增强 + 弱相关子空间投影）
    - 与 Global_B_Avg 相同的输入输出风格（params in / params out）
    - 仅覆盖 'lora_B0'（或兼容 '_b_' 命名）键
    """
    # 超参信息
    subspace_rank = self.subspace_rank if hasattr(self, 'subspace_rank') else 1
    diff_lambda = self.diff_lambda if hasattr(self, 'diff_lambda') else 1.
    alpha = self.alpha if hasattr(self, 'alpha') else 1.0
    eps = self.eps if hasattr(self, 'eps') else 1e-6


    # -------------------- 基础与兼容 --------------------
    if params is None or len(params) == 0:
        return []

    # 设备搬运
    gpu_params: List[Dict[str, torch.Tensor]] = [
        {k: v.to(self.device) for k, v in client_params.items()}
        for client_params in params
    ]
    num_clients = len(gpu_params)
    aggregated_results = [{} for _ in range(num_clients)]
    param_names = list(gpu_params[0].keys())

    def is_lora_A0(k: str) -> bool:
        return 'lora_A0' in k

    def is_lora_B0(k: str) -> bool:
        return 'lora_B0' in k

    # -------------------- 计算加权平均的 LoRA-B0 --------------------
    # 收集所有 B0
    B0_keys = [k for k in param_names if is_lora_B0(k)]
    if len(B0_keys) == 0:
        # 无 B0，原样返回
        return gpu_params

    # 初始化
    avg_B0: Dict[str, torch.Tensor] = {
        k: torch.zeros_like(gpu_params[0][k], dtype=torch.float32) for k in B0_keys
    }

    # 加权叠加， 权重为1/k
    w = 1 / num_clients
    for i in range(num_clients):
        for k in B0_keys:
            avg_B0[k] += gpu_params[i][k].detach().float() * w

    # -------------------- 历史缓存（用于差异放大） --------------------
    prev_B0 = getattr(self, "_prev_B0_cache", None)
    if (prev_B0 is None) or (not isinstance(prev_B0, dict)):
        prev_B0 = {k: v.clone() for k, v in avg_B0.items()}

    # -------------------- 构造“弱相关子空间”投影（由 A0 决定） --------------------
    # 先预取每个客户端的 A0，用于构造 A A^T
    A0_per_client: List[Dict[str, torch.Tensor]] = []
    for i in range(num_clients):
        A0_i = {k: v.detach().float().clone() for k, v in gpu_params[i].items() if is_lora_A0(k)}
        A0_per_client.append(A0_i)

    def build_weak_space_for_except(except_index: int) -> Dict[str, torch.Tensor]:
        # """
        # 对除 except_index 外的客户端 A0 做加权平均的 A A^T，SVD 后取“弱相关”子空间投影
        # 返回的键与 B0 对齐：'lora_A0' -> 'lora_B0'
        # """
        # # 汇总 A A^T
        # aaT_sum: Dict[str, torch.Tensor] = {}
        # for j in range(num_clients):
        #     if j == except_index:
        #         continue
        #     A0_j = A0_per_client[j]
        #     if not A0_j:
        #         continue
        #     wj = 1 / (num_clients - 1) # 简单平均
        #     for k, A in A0_j.items():
        #         # A: [r, in], A A^T -> [r, r]
        #         M = A @ A.T
        #         if k not in aaT_sum:
        #             aaT_sum[k] = torch.zeros_like(M, dtype=torch.float32, device=M.device)
        #         aaT_sum[k] += wj * M.float()

        # proj: Dict[str, torch.Tensor] = {}
        # for kA, M in aaT_sum.items():
        #     # 数值稳定 SVD
        #     U, S, Vh = torch.linalg.svd(M.float(), full_matrices=False)
        #     # 取最小奇异值对应的子空间 U[:, subspace_rank:]U[:, subspace_rank:].T
        #     if subspace_rank < U.shape[1]:
        #         P = U[:, subspace_rank:] @ U[:, subspace_rank:].T
        #     else:
        #         # 若 rank >= r，退化为零投影
        #         P = torch.zeros_like(M, dtype=torch.float32)
        #     # 键名从 A0 对齐到 B0
        #     if 'lora_A0' in kA.lower():
        #         kB = kA.replace('lora_A0', 'lora_B0')
        #     else:
        #         # 不规则命名时，尝试原样键（大多数情况下不会命中）
        #         kB = kA
        #     proj[kB] = P
        # return proj

        # 每个客户端经过差异补偿后仅需乘以1/k
        return {}



    # -------------------- 差异计算 + 弱相关投影并累加 --------------------
    diff_B0_sum: Dict[str, torch.Tensor] = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in avg_B0.items()}

    for i in range(num_clients):
        # 当前客户端的 B0
        B0_i = {k: gpu_params[i][k].detach().float().clone() for k in B0_keys}

        # 元素级放大：使用历史差异
        # delta = avg - B_i ; grad_like = B_i - prev ; sensitivity = |grad_like|^2 + eps ; amplified = sign(delta)*(|delta|+eps)^alpha
        amplified_diff: Dict[str, torch.Tensor] = {}
        for k in B0_keys:
            delta = (avg_B0[k] - B0_i[k])
            grad_like = (B0_i[k] - prev_B0[k])
            sensitivity = torch.abs(grad_like) ** 2 + eps
            amplified = torch.sign(delta) * (torch.abs(delta) + eps) ** alpha
            amplified_diff[k] = amplified * sensitivity
            # amplified_diff[k] = B0_i[k]  # 简化为仅使用“类梯度”放大

        # # 构造对当前客户端的弱相关投影
        proj_dict = build_weak_space_for_except(i)

        # 将差异映射到弱相关子空间并累加
        for k in B0_keys:
            if k in proj_dict:
                # B0 的形状通常为 [out, r]，右乘 P[r, r] 保持形状一致
                diff_proj = amplified_diff[k] * 1/num_clients
                diff_B0_sum[k] += diff_proj
            else:
                diff_B0_sum[k] += amplified_diff[k] * 1/num_clients

    # -------------------- 组合最终 B0 并写回每个客户端 --------------------
    final_B0: Dict[str, torch.Tensor] = {k: (avg_B0[k] + diff_lambda * diff_B0_sum[k]).to(dtype=params[0][k].dtype)
                                         for k in B0_keys}

    for i in range(num_clients):
        for k, v in gpu_params[i].items():
            if is_lora_B0(k):
                aggregated_results[i][k] = final_B0[k]
            else:
                aggregated_results[i][k] = v  # 非 B0 保持原样

    # -------------------- 刷新历史缓存 --------------------
    self._prev_B0_cache = {k: t.detach().clone() for k, t in final_B0.items()}

    # 可选：若未来需支持非 LoRA 参数聚合，这里可以扩展（当前与 Global_B_Avg 一致，只改 B0）
    return aggregated_results
