import torch
from typing import List, Dict, Any, Optional

def Global_Avg(
    self,
    route_aggregation: bool,
    params: List[Dict[str, torch.Tensor]],
    lora_client_map: Optional[Dict[Any, List[int]]] = None
) -> List[Dict[str, torch.Tensor]]:
    """
    Average every client's ``lora_A1``/``lora_B1`` into a shared ``lora_A0``/``lora_B0``.

    Steps:
        - Move every client's tensors to ``self.device``.
        - For each parameter name of the first client that contains
          ``lora_A1`` or ``lora_B1``, take the unweighted mean across all
          clients (FedAvg).
        - In every client's result, store that mean under the name with
          ``lora_A1`` -> ``lora_A0`` / ``lora_B1`` -> ``lora_B0``. The
          ``*1`` keys themselves are not copied to the output.
        - Copy every other parameter unchanged. If a client already has a
          ``lora_A0``/``lora_B0`` entry with the target name, the averaged
          value replaces it regardless of the order of keys in the dict.

    Args:
        route_aggregation: Not used by this function; kept so the signature
            stays compatible with callers that pass it.
        params: One ``{name: tensor}`` dict per client. Every client is
            expected to have the same ``lora_A1``/``lora_B1`` keys as the
            first client.
        lora_client_map: If given, stored on ``self.lora_client_map``;
            otherwise not used here.

    Returns:
        A new list with one dict per client (empty list for no clients).
        Each averaged tensor is the same object in every client's dict
        (not copied per client).

    Raises:
        KeyError: if a client is missing an ``lora_A1``/``lora_B1`` key that
            the first client has, or has one that the first client lacks.
    """
    # Update the mapping (kept for compatibility).
    if lora_client_map is not None:
        self.lora_client_map = lora_client_map

    # Move parameters to the target device so torch.stack sees one device.
    # NOTE: Tensor.to returns the same tensor when no move is needed, so
    # untouched values may alias the caller's tensors.
    gpu_params = [
        {k: v.to(self.device) for k, v in client_params.items()}
        for client_params in params
    ]

    num_clients = len(gpu_params)
    param_names = list(gpu_params[0].keys()) if num_clients > 0 else []

    # (local-expert marker, global-expert marker); A1 is checked before B1,
    # matching the original if/elif order.
    renames = (('lora_A1', 'lora_A0'), ('lora_B1', 'lora_B0'))

    def _global_name(name: str) -> Optional[str]:
        """Return the lora_A0/B0 name for an A1/B1 parameter, else None."""
        # NOTE(review): plain substring match, so e.g. 'lora_A10' also matches
        # and becomes 'lora_A00' — confirm no such names exist.
        for local_tag, global_tag in renames:
            if local_tag in name:
                return name.replace(local_tag, global_tag)
        return None

    # Step 1: unweighted mean over all clients of each lora_A1 / lora_B1 tensor.
    averaged = {
        name: torch.stack([client[name] for client in gpu_params]).mean(dim=0)
        for name in param_names
        if _global_name(name) is not None
    }

    # Step 2: build each client's new parameter set (A1/B1 -> A0/B0).
    aggregated_results = []
    for client in gpu_params:
        result: Dict[str, torch.Tensor] = {}
        for name, value in client.items():
            new_name = _global_name(name)
            if new_name is not None:
                # The aggregated global expert always takes this slot.
                result[new_name] = averaged[name]
            else:
                # setdefault: if this is an existing lora_A0/B0 key and the
                # average was already written (A1/B1 key came first), keep
                # the average instead of restoring the stale local value.
                result.setdefault(name, value)
        aggregated_results.append(result)

    print(f"[Router_Avg] 聚合完成: 所有客户端的 lora_A1/B1 平均为 lora_A0/B0 并分发。")
    return aggregated_results