import torch
from typing import List, Dict, Any, Optional

def Global_B_Avg(
    self,
    route_aggregation: bool,
    params: List[Dict[str, torch.Tensor]],
    lora_client_map: Optional[Dict[Any, List[int]]] = None
) -> List[Dict[str, torch.Tensor]]:
    """
    Average every ``lora_B0`` parameter across clients and broadcast the result.

    Each tensor whose name contains the substring ``'lora_B0'`` is replaced,
    for every client that holds it, by the element-wise mean over all clients
    holding a tensor of that name (FedAvg with equal weights). All other
    parameters are passed through unchanged, apart from being moved to
    ``self.device``.

    Args:
        self: Object providing the ``device`` attribute. When
            ``lora_client_map`` is given it is stored on this object as
            ``self.lora_client_map``.
        route_aggregation: Accepted for signature compatibility; this
            function does not read it.
        params: One ``{parameter name: tensor}`` dict per client.
        lora_client_map: Optional mapping that, when not ``None``, overwrites
            ``self.lora_client_map``. It does not influence the averaging
            performed here.

    Returns:
        A list with one dict per client, in the same order as ``params``.
        An empty ``params`` yields an empty list.

    Notes:
        - The averaged tensor is NOT cloned per client: every client that
          holds a given ``lora_B0`` name receives the same tensor object, so
          an in-place modification through one client's dict is visible to
          the others.
        - Clients are not required to share an identical key set; a
          ``lora_B0`` name is averaged over the clients that actually have
          it. Tensors sharing a name must still have identical shapes,
          otherwise ``torch.stack`` raises.
    """
    # Keep the stored mapping in sync (retained for compatibility).
    if lora_client_map is not None:
        self.lora_client_map = lora_client_map

    # Move every client's tensors to the target device.
    gpu_params = [
        {k: v.to(self.device) for k, v in client_params.items()}
        for client_params in params
    ]

    # Step 1: group the lora_B0 tensors by name over ALL clients. Looking
    # only at the first client's keys would raise KeyError as soon as two
    # clients disagree on their parameter names.
    lora_B0_groups: Dict[str, List[torch.Tensor]] = {}
    for client_params in gpu_params:
        for param_name, param_value in client_params.items():
            if 'lora_B0' in param_name:
                lora_B0_groups.setdefault(param_name, []).append(param_value)

    # Equal-weight mean over the client axis introduced by torch.stack.
    lora_B0_avg = {
        param_name: torch.stack(tensors).mean(dim=0)
        for param_name, tensors in lora_B0_groups.items()
    }

    # Step 2: build each client's result, swapping in the global average for
    # lora_B0 entries and keeping every other parameter as it is.
    aggregated_results = [
        {
            param_name: (
                lora_B0_avg[param_name]
                if 'lora_B0' in param_name
                else param_value
            )
            for param_name, param_value in client_params.items()
        }
        for client_params in gpu_params
    ]

    print("[Global_B_Avg] 聚合完成: 所有客户端的 lora_B0 已平均并同步分发。")
    return aggregated_results
