import torch
from typing import List, Dict, Any, Optional

def _parse_lora_a_index(param_name: str) -> int:
    """Return the adapter index that follows 'lora_A' in *param_name*.

    The full run of digits directly after 'lora_A' is used, so 'lora_A10'
    yields 10 rather than 1. When no digit directly follows, every digit found
    in the remainder of the name is concatenated; 0 is returned if there are
    none.
    """
    tail = param_name.split('lora_A')[1]

    end = 0
    while end < len(tail) and '0' <= tail[end] <= '9':
        end += 1
    if end:
        return int(tail[:end])

    digits = ''.join(c for c in tail if '0' <= c <= '9')
    return int(digits) if digits else 0


def _mean_over_clients(
    client_params: List[Dict[str, torch.Tensor]],
    client_indices: List[int],
    param_name: str,
) -> Optional[torch.Tensor]:
    """Average *param_name* over the listed clients.

    Indices outside ``range(len(client_params))`` and clients that do not hold
    the parameter are skipped. Returns None when nothing is left to average,
    because ``torch.stack`` rejects an empty sequence.
    """
    tensors = [
        client_params[i][param_name]
        for i in client_indices
        if 0 <= i < len(client_params) and param_name in client_params[i]
    ]
    if not tensors:
        return None
    return torch.stack(tensors).mean(dim=0)


def Global_A_Avg(self, route_aggregation: bool, params: List[Dict[str, torch.Tensor]], lora_client_map: Optional[Dict[Any, List[int]]] = None) -> List[Dict[str, torch.Tensor]]:
    """
    Aggregate per-client parameters group-wise.

    For every parameter name found in the first client's dict:

    * names containing 'lora_route' are replaced by the mean over the client's
      own group when ``route_aggregation`` is true and the client belongs to a
      group; otherwise the client's own tensor is kept;
    * names containing 'lora_A' are replaced, for every client, by the mean
      over the group whose id equals the index parsed from the name (looked up
      first as ``str(index)``, then as ``index``); if that group is missing or
      yields no usable tensor, the client's own tensor is kept;
    * all other parameters are passed through unchanged.

    Args:
        self: Server instance (expects self.device and self.lora_client_map attributes).
        route_aggregation: whether to aggregate routing weights across groups.
        params: list of per-client param dicts (name -> tensor).
        lora_client_map: mapping from group_id -> list of client indices (optional).
            When given, it is stored on ``self.lora_client_map``.

    Returns:
        aggregated_results: list (len == num_clients) of dicts mapping param_name -> aggregated tensor.

    Raises:
        ValueError: if no client map was passed and ``self.lora_client_map`` is None.
    """
    if lora_client_map is not None:
        self.lora_client_map = lora_client_map

    if self.lora_client_map is None:
        raise ValueError("lora_client_map must be provided for aggregation after warmup phase")

    client_map = self.lora_client_map

    # client -> group lookup; a client listed in several groups keeps the last one
    client_to_group = {
        client: group_id
        for group_id, clients in client_map.items()
        for client in clients
    }

    # move incoming params to server device
    gpu_params = [
        {k: v.to(self.device) for k, v in client_params.items()}
        for client_params in params
    ]
    aggregated_results: List[Dict[str, torch.Tensor]] = [{} for _ in gpu_params]
    if not gpu_params:
        return aggregated_results

    # Parameter names are taken from the first client. Iterating names in the
    # outer loop lets each mean be computed once instead of once per client.
    for param_name in list(gpu_params[0]):
        if 'lora_route' in param_name and route_aggregation:
            group_means: Dict[Any, Optional[torch.Tensor]] = {}
            for client_idx, own_params in enumerate(gpu_params):
                group_id = client_to_group.get(client_idx)
                if group_id is None:
                    aggregated_results[client_idx][param_name] = own_params[param_name]
                    continue
                if group_id not in group_means:
                    group_means[group_id] = _mean_over_clients(
                        gpu_params, client_map[group_id], param_name
                    )
                mean = group_means[group_id]
                # clone so every client receives its own tensor, not a shared one
                aggregated_results[client_idx][param_name] = (
                    own_params[param_name] if mean is None else mean.clone()
                )

        elif 'lora_route' not in param_name and 'lora_A' in param_name:
            lora_idx = _parse_lora_a_index(param_name)
            # group ids may be stored as strings or as ints
            group_indices = client_map.get(str(lora_idx), [])
            if not group_indices:
                group_indices = client_map.get(lora_idx, [])

            mean = (
                _mean_over_clients(gpu_params, group_indices, param_name)
                if group_indices
                else None
            )
            for client_idx, own_params in enumerate(gpu_params):
                aggregated_results[client_idx][param_name] = (
                    own_params[param_name] if mean is None else mean.clone()
                )

        else:
            # non-LoRA parameters, and routing weights when aggregation is off
            for client_idx, own_params in enumerate(gpu_params):
                aggregated_results[client_idx][param_name] = own_params[param_name]

    return aggregated_results