import torch
import torch.nn as nn
import numpy as np
import time
import math
from collections import defaultdict
from flcore.clients.clientbase import Client, load_item, save_item


class clientDMPS(Client):
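    """
    FedDMPS client (distribution-based MPS with GMM prototypes).

    - Goal: fit a diagonal-covariance GMM to the local features of each class to
      capture its multi-modal distribution, yielding several "distribution
      prototypes" per class.
    - Dynamic prototype count: chosen adaptively by comparing the relative
      reconstruction rate (R(s-1) - R(s)) / R(s-1) with a threshold ε, where
      R(s) = Σ_x [-max_k log N(x | μ_k, Σ_k)] and, when args.epsilon <= 0,
      ε = 1 / (ρ·√(n·p)); n is the number of samples, p the feature dimension
      and ρ = args.mps_rho. The log-density is evaluated without the constant
      D·log(2π) term.
    - Training alignment: for every sample, take the maximum log-likelihood over
      all prototypes of its class and add its negative (the NLL) to the loss.
    - Inference: for every class, use the negative of the maximum log-likelihood
      (i.e. the minimum NLL) as the score and compare the scores across classes.

    All hyper-parameters come from main.py (args): lamda, max_prototypes,
    epsilon, mps_rho, gaussian_reg.
    """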
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        """Initialise the base client and cache the FedDMPS hyper-parameters.

        All positional and keyword arguments are forwarded unchanged to
        ``Client.__init__``. The global torch RNG is then seeded with 0.

        Attributes copied from ``args``:
            lamda: Weight of the prototype-alignment term added to the loss.
            max_prototypes: Upper bound on the number of prototypes per class.
            epsilon: Stopping threshold for the dynamic prototype count; a
                non-positive value selects the data-dependent default.
            mps_rho: ρ in the default threshold 1 / (ρ·√(n·p)).
            gaussian_reg: Lower bound / additive regulariser for variances.
        """
        super().__init__(args, id, train_samples, test_samples, **kwargs)
        torch.manual_seed(0)

        self.loss_mse = nn.MSELoss()

        # Mirror the relevant hyper-parameters onto the client, one attribute
        # per name, in the same order as they are listed here.
        for hyper_name in ('lamda', 'max_prototypes', 'epsilon', 'mps_rho', 'gaussian_reg'):
            setattr(self, hyper_name, getattr(args, hyper_name))

    # ---------- Gaussian utilities ----------
    def _sanitize(self, t: torch.Tensor) -> torch.Tensor:
        t = torch.nan_to_num(t, nan=0.0, posinf=1e6, neginf=-1e6)
        return torch.clamp(t, min=-1e6, max=1e6)

    def _prepare_matrix(self, feats: list[torch.Tensor]) -> torch.Tensor:
        cleaned = []
        for f in feats:
            f = self._sanitize(f)
            if torch.isfinite(f).all():
                cleaned.append(f)
        if len(cleaned) == 0:
            return torch.empty(0, 0, device=self.device, dtype=torch.float32)
        return torch.stack(cleaned).to(self.device, dtype=torch.float32)

    def _gmm_fit_diag(self, feats: list[torch.Tensor], n_components: int):
        X = self._prepare_matrix(feats)
        if X.numel() == 0 or X.shape[0] < 2:
            mean = X.mean(dim=0) if X.numel() > 0 else torch.zeros_like(feats[0])
            var = (X.var(dim=0, unbiased=False) if X.numel() > 0 else torch.ones_like(mean)) + float(self.gaussian_reg)
            return [{'weight': 1.0, 'mean': mean.detach().cpu(), 'covariance': var.detach().cpu()}]

        N, D = X.shape
        K = int(max(1, min(n_components, max(1, N // 2))))
        reg = torch.tensor(float(self.gaussian_reg), device=self.device, dtype=torch.float32)

        # 初始化
        perm = torch.randperm(N, device=self.device)
        means = X[perm[:K]].clone()
        global_var = X.var(dim=0, unbiased=False) + reg
        vars_diag = global_var.unsqueeze(0).repeat(K, 1).clone()
        weights = torch.full((K,), 1.0 / K, device=self.device)

        log2pi = math.log(2.0 * math.pi)
        prev_ll = None
        max_iter = 50
        tol = 1e-4

        for _ in range(max_iter):
            inv_vars = 1.0 / torch.clamp(vars_diag, min=reg)
            diff = X.unsqueeze(1) - means.unsqueeze(0)        # [N,K,D]
            quad = (diff * diff * inv_vars.unsqueeze(0)).sum(dim=2)  # [N,K]
            log_det = torch.log(torch.clamp(vars_diag, min=reg)).sum(dim=1)  # [K]
            log_prob = -0.5 * (quad + log_det.unsqueeze(0) + D * log2pi)
            log_weighted = torch.log(torch.clamp(weights, min=1e-12)).unsqueeze(0) + log_prob
            m = torch.max(log_weighted, dim=1, keepdim=True).values
            exp_shift = torch.exp(log_weighted - m)
            denom = exp_shift.sum(dim=1, keepdim=True)
            gamma = exp_shift / torch.clamp(denom, min=1e-12)  # [N,K]

            ll = (m.squeeze(1) + torch.log(torch.clamp(denom.squeeze(1), min=1e-12))).sum()
            if prev_ll is not None and torch.abs((ll - prev_ll) / (torch.abs(prev_ll) + 1e-12)) < tol:
                break
            prev_ll = ll

            Nk = gamma.sum(dim=0) + 1e-8
            weights = Nk / float(N)
            means = (gamma.t() @ X) / Nk.unsqueeze(1)
            diff2 = (X.unsqueeze(1) - means.unsqueeze(0)) ** 2
            vars_diag = (gamma.unsqueeze(2) * diff2).sum(dim=0) / Nk.unsqueeze(1)
            vars_diag = torch.clamp(vars_diag, min=reg)

        protos = []
        for k in range(K):
            protos.append({
                'weight': float(weights[k].item()),
                'mean': means[k].detach().cpu(),
                'covariance': vars_diag[k].detach().cpu()
            })
        return protos

    def _gaussian_loglik(self, x: torch.Tensor, proto: dict) -> torch.Tensor:
        mean = proto['mean'].to(x.device)
        cov = torch.clamp(proto['covariance'].to(x.device), min=self.gaussian_reg)
        diff = x - mean
        inv = 1.0 / cov
        log_det = torch.log(cov).sum()
        return -0.5 * ((diff * diff * inv).sum() + log_det)

    def _gaussian_R(self, X: torch.Tensor, protos: list[dict]) -> float:
        """重构误差 R(s) = sum_x -max_k log N(x|mu_k, Sigma_k)."""
        if X.numel() == 0 or len(protos) == 0:
            return float('inf')
        total = 0.0
        for i in range(X.shape[0]):
            max_ll = float('-inf')
            for p in protos:
                ll = float(self._gaussian_loglik(X[i], p))
                if ll > max_ll:
                    max_ll = ll
            total += (-max_ll)
        return float(total)

    def _gmm_fit_diag_dynamic(self, feats: list[torch.Tensor]):
        """按照相对重建率控制组件数：当 (R(s-1)-R(s))/R(s-1) <= epsilon 停止。"""
        X = self._prepare_matrix(feats)
        if X.numel() == 0 or X.shape[0] < 2:
            return self._gmm_fit_diag(feats, 1)

        N, D = X.shape
        eps = self.epsilon if (isinstance(self.epsilon, (int, float)) and self.epsilon > 0) \
            else 1.0 / (max(float(self.mps_rho), 1e-8) * (float(N * D)) ** 0.5)

        max_K = int(min(self.max_prototypes, max(1, N // 2)))
        prev_R = None
        prev_protos = None
        best = None
        for s in range(1, max_K + 1):
            protos = self._gmm_fit_diag(feats, s)
            R_s = self._gaussian_R(X, protos)
            if prev_R is not None and prev_R > 0:
                rel = (prev_R - R_s) / prev_R
                if rel <= eps:
                    best = prev_protos  # 上一次即为最优
                    break
            prev_R = R_s
            prev_protos = protos
            best = protos
        return best

    # ---------- training / metrics ----------
    def train(self):
        """Run local training for one round and publish this client's prototypes.

        Steps:
            1. Load the local model and the server's global prototypes.
            2. Train with SGD on ``self.loss`` plus, when global prototypes are
               available, ``self.lamda`` times the summed negative best
               log-likelihood of each sample's feature under the prototypes of
               its class.
            3. Collect the (detached) features of every processed sample per
               class and fit a dynamic-size diagonal GMM to each class.
            4. Save the local prototypes and the model, and update
               ``self.train_time_cost``.
        """
        trainloader = self.load_train_data()
        model = load_item(self.role, 'model', self.save_folder_name)
        global_protos = load_item('Server', 'global_protos', self.save_folder_name)
        optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate)
        model.train()

        start_time = time.time()
        max_local_epochs = self.local_epochs
        if self.train_slow:
            # randint requires high > low; keep the upper bound at least 2 so a
            # small local_epochs does not raise ValueError.
            max_local_epochs = np.random.randint(1, max(2, max_local_epochs // 2))

        use_protos = global_protos is not None and len(global_protos) > 0

        class_features = defaultdict(list)
        for _ in range(max_local_epochs):
            for x, y in trainloader:
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                if self.train_slow:
                    time.sleep(0.1 * np.abs(np.random.rand()))

                rep = self._sanitize(model.base(x))
                output = model.head(rep)
                loss = self.loss(output, y)

                # Decode the labels once per batch; they are needed both for
                # the alignment term and for feature collection.
                labels = [int(yy.item()) for yy in y]

                # Gaussian prototype alignment: minimise the negative of the
                # best log-likelihood. The terms are kept as tensors so that
                # the gradient reaches the feature extractor; converting them
                # to Python floats would turn the term into a constant.
                if use_protos:
                    nll_terms = []
                    for i, cls in enumerate(labels):
                        cls_protos = global_protos[cls] if cls in global_protos else None
                        if isinstance(cls_protos, list) and len(cls_protos) > 0:
                            logliks = torch.stack(
                                [self._gaussian_loglik(rep[i, :], proto) for proto in cls_protos]
                            )
                            nll_terms.append(-logliks.max())
                    if nll_terms:
                        loss = loss + self.lamda * torch.stack(nll_terms).sum()

                # Collect features for the local GMM fit.
                for i, cls in enumerate(labels):
                    class_features[cls].append(rep[i, :].detach())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Fit a GMM per class; the number of prototypes is chosen dynamically
        # (relative reconstruction rate compared with epsilon).
        local_protos = {}
        for cls, feats in class_features.items():
            local_protos[cls] = self._gmm_fit_diag_dynamic(feats)

        save_item(local_protos, self.role, 'protos', self.save_folder_name)
        save_item(model, self.role, 'model', self.save_folder_name)

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def test_metrics(self):
        """Evaluate prototype-based classification accuracy on the local test set.

        Each sample is assigned to the class whose prototypes give the smallest
        negative log-likelihood; classes without a non-empty prototype list
        keep a score of ``inf``.

        Returns:
            ``(correct, total, 0)``, or ``(0, 1e-5, 0)`` when no global
            prototypes are available.
        """
        testloader = self.load_test_data()
        model = load_item(self.role, 'model', self.save_folder_name)
        global_protos = load_item('Server', 'global_protos', self.save_folder_name)
        model.eval()

        # Without prototypes there is nothing to classify against.
        if global_protos is None or len(global_protos) == 0:
            return 0, 1e-5, 0

        # Classes that actually have a non-empty list of prototypes.
        scored_classes = [
            cls for cls in range(self.num_classes)
            if (cls in global_protos)
            and isinstance(global_protos[cls], list)
            and len(global_protos[cls]) > 0
        ]

        correct = 0
        seen = 0
        with torch.no_grad():
            for x, y in testloader:
                if type(x) is list:
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                rep = self._sanitize(model.base(x))

                # Score matrix of per-class minimum NLL; inf marks "no prototype".
                scores = torch.full((y.shape[0], self.num_classes), float('inf'), device=self.device)
                for row_idx, feature in enumerate(rep):
                    for cls in scored_classes:
                        logliks = [float('-inf')]
                        logliks.extend(
                            float(self._gaussian_loglik(feature, proto))
                            for proto in global_protos[cls]
                        )
                        scores[row_idx, cls] = -max(logliks)

                predictions = torch.argmin(scores, dim=1)
                correct += (predictions == y).sum().item()
                seen += y.shape[0]
        return correct, seen, 0
