import time
import numpy as np
import torch
from collections import defaultdict
from flcore.servers.serverbase import Server
from flcore.clients.clientbase import load_item, save_item
from flcore.clients.clientdmps import clientDMPS


class FedDMPS(Server):
    """Server for FedDMPS (Distribution-based Multi-Prototype Sampling, GMM prototypes).

    Every round the server loads the prototypes saved by the selected clients.
    Each prototype is a dict describing one GMM component with the keys
    ``'weight'``, ``'mean'`` and ``'covariance'`` (a diagonal covariance stored
    as a vector).  Prototypes are grouped per class and merged according to
    ``args.prototype_aggregation_method``:

    * ``'averaging'``: weight-average the means and diagonal covariances into
      a single global prototype per class.
    * anything else (``'clustering'``): run k-means on the mean vectors to get
      at most ``args.max_prototypes`` clusters, then weight-average the
      covariances inside every cluster.

    The resulting global prototypes are saved on the server side with
    ``save_item`` under the name ``'global_protos'``.

    Hyper-parameters read from ``args``: ``num_classes``,
    ``prototype_aggregation_method``, ``max_prototypes`` and ``gaussian_reg``
    (the lower bound applied to every merged covariance entry).
    """

    def __init__(self, args, times):
        super().__init__(args, times)

        # Create the clients.
        self.set_slow_clients()
        self.set_clients(clientDMPS)

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

        self.Budget = []  # wall-clock seconds spent in each round
        self.num_classes = args.num_classes
        self.agg_method = args.prototype_aggregation_method  # 'averaging' or 'clustering'
        self.max_prototypes = args.max_prototypes
        self.gaussian_reg = args.gaussian_reg

    def train(self):
        """Run the federated rounds: select, evaluate, train locally, aggregate."""
        for i in range(self.global_rounds + 1):
            s_t = time.time()
            self.selected_clients = self.select_clients()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate heterogeneous multi-prototype models (GMM)")
                self.evaluate()

            for client in self.selected_clients:
                client.train()

            self.receive_protos()

            self.Budget.append(time.time() - s_t)
            print('-' * 50, self.Budget[-1])

            if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt):
                break

        print("\nBest accuracy.")
        print(max(self.rs_test_acc))

    def receive_protos(self):
        """Load the selected clients' prototypes, aggregate them and save the result."""
        assert (len(self.selected_clients) > 0)

        self.uploaded_ids = []
        uploaded = []  # one entry per client: {class_id: [proto_dict, ...]}
        for client in self.selected_clients:
            self.uploaded_ids.append(client.id)
            uploaded.append(load_item(client.role, 'protos', client.save_folder_name))

        global_protos = self.aggregate(uploaded)
        save_item(global_protos, self.role, 'global_protos', self.save_folder_name)

    # ---------------- aggregation (GMM) ----------------
    def aggregate(self, local_list):
        """Merge per-client prototype dicts into global prototypes.

        Args:
            local_list: iterable of ``{class_id: [proto_dict, ...]}`` mappings,
                one per client.  Entries that are ``None`` or empty are skipped,
                as are class entries that are not non-empty lists.

        Returns:
            ``{class_id: [proto_dict, ...]}`` where every ``proto_dict`` has the
            keys ``'weight'`` (float), ``'mean'`` and ``'covariance'`` (CPU
            tensors).  Classes appear in the order they were first seen.
        """
        by_class = defaultdict(list)
        for local in local_list:
            if not local:
                # This client contributed nothing this round.
                continue
            for cls, plist in local.items():
                if isinstance(plist, list) and plist:
                    by_class[cls].extend(plist)

        if self.agg_method == 'averaging':
            merge = self._merge_averaging
        else:
            merge = self._merge_clustering
        return {cls: merge(plist) for cls, plist in by_class.items()}

    @staticmethod
    def _as_cpu_tensor(value):
        """Return ``value`` as a detached CPU tensor (converting array-likes)."""
        tensor = value if torch.is_tensor(value) else torch.tensor(value)
        # Work on CPU throughout so prototypes saved on different devices can be
        # stacked together; the merged results are stored on CPU anyway.
        return tensor.detach().cpu()

    @staticmethod
    def _raw_weights(plist):
        """Return the un-normalised component weights of ``plist`` (default 1.0)."""
        return torch.tensor([float(p.get('weight', 1.0)) for p in plist])

    def _merge_averaging(self, plist):
        """Collapse all prototypes of one class into a single weighted average."""
        weights = self._raw_weights(plist)
        weights = (weights / (weights.sum() + 1e-12)).unsqueeze(1)
        means = torch.stack([self._as_cpu_tensor(p['mean']) for p in plist])
        covs = torch.stack([self._as_cpu_tensor(p['covariance']) for p in plist])
        mean = (weights * means).sum(dim=0)
        cov = (weights * covs).sum(dim=0)
        return [{
            'weight': 1.0,
            'mean': mean,
            'covariance': torch.clamp(cov, min=self.gaussian_reg),
        }]

    def _merge_clustering(self, plist):
        """Cluster the prototypes of one class by mean and merge each cluster.

        The merged mean is the k-means centre, the merged covariance is the
        weight-average of the member covariances, and the merged weight is the
        cluster's share of the total uploaded weight (so the weights of one
        class sum to 1, matching the single weight-1.0 prototype produced by
        the averaging strategy).
        """
        means = torch.stack([self._as_cpu_tensor(p['mean']) for p in plist])
        # At most max_prototypes clusters, at most half as many clusters as
        # prototypes, and never fewer than one.
        num_clusters = max(1, min(self.max_prototypes, means.shape[0] // 2))
        centers, labels = self._kmeans(means, num_clusters)

        raw_weights = self._raw_weights(plist)
        total_mass = raw_weights.sum() + 1e-12
        assignments = labels.tolist()

        merged = []
        for k in range(centers.shape[0]):
            members = [i for i, label in enumerate(assignments) if label == k]
            if not members:
                # Re-seeded centre that attracted no prototype.
                continue
            member_weights = raw_weights[members]
            # Record the cluster mass BEFORE normalising; the normalised
            # weights always sum to ~1 and carry no information.
            cluster_mass = member_weights.sum()
            member_weights = (member_weights / (cluster_mass + 1e-12)).unsqueeze(1)
            covs = torch.stack([self._as_cpu_tensor(plist[i]['covariance']) for i in members])
            cov = (member_weights * covs).sum(dim=0)
            merged.append({
                'weight': float((cluster_mass / total_mass).item()),
                # clone() so the saved prototype owns its storage instead of
                # being a view into the full centres tensor.
                'mean': centers[k].clone(),
                'covariance': torch.clamp(cov, min=self.gaussian_reg),
            })
        return merged

    def _kmeans(self, X: torch.Tensor, K: int, iters: int = 25):
        """Plain Lloyd k-means with random initial centres drawn from ``X``.

        Args:
            X: points to cluster; indexed as ``(N, D)`` by ``torch.cdist``.
            K: requested number of clusters; clamped to ``[1, N]``.
            iters: maximum number of assignment/update iterations.

        Returns:
            ``(centers, labels)`` where ``centers`` has one row per cluster and
            ``labels`` is a long tensor with the cluster index of every point.
            A cluster that ends up empty keeps a randomly re-drawn point as its
            centre and has no label pointing at it.
        """
        N = X.shape[0]
        device = X.device
        K = max(1, min(K, N))
        perm = torch.randperm(N, device=device)
        centers = X[perm[:K]].clone()

        # No previous assignment yet: the first iteration must always update the
        # centres, even if every point happens to land in cluster 0.
        labels = None
        for _ in range(iters):
            new_labels = torch.cdist(X, centers, p=2).argmin(dim=1)
            if labels is not None and torch.equal(new_labels, labels):
                break
            labels = new_labels
            updated = []
            for k in range(K):
                mask = labels == k
                if mask.any():
                    updated.append(X[mask].mean(dim=0))
                else:
                    # Empty cluster: re-seed its centre with a random point.
                    updated.append(X[torch.randint(0, N, (1,), device=device)[0]])
            centers = torch.stack(updated)

        if labels is None:
            # iters <= 0: still return a valid assignment to the initial centres.
            labels = torch.cdist(X, centers, p=2).argmin(dim=1)
        return centers, labels
