import math
import time
import copy
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch

from src.client.fedequilibria import FedEquilibriaClient
from src.server.fedavg import FedAvgServer
from src.utils.constants import FLBENCH_ROOT


class FedEquilibriaServer(FedAvgServer):
    """FedEquilibria: a standalone implementation of FedFisher's method6.

    Idea:
    - Clients return masked parameter updates (`masked_update`) and
      Fisher/gradient info.
    - The server mixes multi-objective (QP) weights with distance-based
      weights.
    - Masked updates are aggregated with the mixed weights to update the
      global model.
    """

    algorithm_name = "FedEquilibria"
    all_model_params_personalized = False
    return_diff = False
    client_cls = FedEquilibriaClient

    def __init__(self, args, init_trainer=True, init_model=True):
        """Initialise the base server and the per-round caches."""
        super().__init__(args, init_trainer, init_model)
        # Per-round caches, all keyed by client id and reset every round.
        self.euclidean_distance: Dict[int, float] = {}
        self.client_update: Dict[int, Dict[str, torch.Tensor]] = {}
        self.mask_dict: Dict[int, Dict[str, torch.Tensor]] = {}
        self.fisher_info: Dict[int, Dict[str, torch.Tensor]] = {}
        self.gradient_info: Dict[int, Dict[str, torch.Tensor]] = {}
        # Lazily loaded `domain_distribution` table (None = not loaded yet).
        self._domain_distribution = None

    @staticmethod
    def get_hyperparams(args_list=None) -> Namespace:
        """Parse the FedEquilibria-specific hyperparameters.

        Args:
            args_list: Optional list of CLI-style tokens; when None, argparse
                falls back to `sys.argv`.

        Returns:
            Namespace with `threshold`, `weight_t` and `multi_grad`.
        """
        parser = ArgumentParser()
        parser.add_argument("--threshold", type=float, default=0.95,
                            help="Mask ratio threshold for important params at client side")
        parser.add_argument("--weight_t", type=float, default=0.7,
                            help="Mixture coefficient t: t*multi-objective + (1-t)*distance")
        parser.add_argument("--multi_grad", type=int, default=0,
                            help="Use gradients instead of Fisher for multi-objective (1/0)")
        return parser.parse_args(args_list)

    def _hyperparam(self, name, default):
        """Return `self.args.fedequilibria.<name>`, or `default` if unavailable."""
        if not hasattr(self.args, "fedequilibria"):
            return default
        return getattr(self.args.fedequilibria, name, default)

    def _uniform_weights(self):
        """Return equal weights over the currently selected clients."""
        share = 1.0 / max(1, len(self.selected_clients))
        return {cid: share for cid in self.selected_clients}

    def train_one_round(self):
        """Run one round: train clients, weight them, update the global model."""
        # Reset per-round caches
        self.client_update = {}
        self.mask_dict = {}
        self.euclidean_distance = {}
        self.fisher_info = {}
        self.gradient_info = {}

        # Client training
        client_packages = self.trainer.train()

        # Collect client returns and compute distances
        for client_id, package in client_packages.items():
            # Packages without both mask and masked update are ignored here.
            if "mask" not in package or "masked_update" not in package:
                continue

            mask = package["mask"]
            self.mask_dict[client_id] = mask
            self.client_update[client_id] = package["masked_update"]

            if "fisher_info" in package:
                self.fisher_info[client_id] = package["fisher_info"]
            if "gradient_info" in package:
                self.gradient_info[client_id] = package["gradient_info"]

            # Log the active mask ratio for debugging
            total_params = sum(m.numel() for m in mask.values())
            masked_params = sum(torch.sum(m).item() for m in mask.values())
            ratio = masked_params / total_params if total_params > 0 else 0.0
            domain = self.get_client_domain(client_id)
            self.logger.log(
                f"Client {client_id} ({domain}) - Active parameters ratio: {ratio:.4f} ({masked_params}/{total_params})"
            )

            # Distance is derived from the masked update
            self.compute_distance(client_id, package["masked_update"])

        # Compute mixed weights from either gradients or Fisher information
        use_grad = int(self._hyperparam("multi_grad", 0)) == 1
        objective_info = self.gradient_info if use_grad else self.fisher_info
        mixed_weights = self.compute_mixed_weights(objective_info, self.euclidean_distance)

        # Aggregate with mask
        global_params_new = self.aggregate_client_updates_with_mask(
            client_packages, mixed_weights, self.public_model_params
        )
        if global_params_new:
            self.public_model_params = global_params_new
            self.model.load_state_dict(self.public_model_params, strict=False)

    def _load_domain_distribution(self):
        """Read `domain_distribution` from the dataset's all_stats.json.

        Best-effort: any failure (missing file, bad JSON, unexpected layout)
        yields an empty dict so callers fall back to "unknown".
        """
        import json

        try:
            stats_path = FLBENCH_ROOT / "data" / self.args.dataset.name / "all_stats.json"
            if not stats_path.exists():
                return {}
            with open(stats_path, "r") as f:
                stats = json.load(f)
            distribution = stats.get("domain_distribution", {})
        except Exception:
            # Domain names are only used in log messages; never fail a round
            # because the stats file is unreadable.
            return {}
        return distribution if isinstance(distribution, dict) else {}

    def get_client_domain(self, client_id):
        """Return the first domain listed for `client_id`, or "unknown".

        The stats file is parsed once and cached, instead of being re-read
        for every client on every call.
        """
        distribution = getattr(self, "_domain_distribution", None)
        if distribution is None:
            distribution = self._load_domain_distribution()
            self._domain_distribution = distribution
        try:
            return distribution[str(client_id)][0]
        except (KeyError, IndexError, TypeError):
            return "unknown"

    def compute_distance(self, client_id, update_diff: Dict[str, torch.Tensor]):
        """Store the distance score of one client's update.

        The score is the sum of the per-tensor L2 norms of every entry of
        `update_diff` whose key names a model parameter; other entries
        (e.g. buffers) are skipped. The result is written to
        `self.euclidean_distance[client_id]`.
        """
        # A set makes the membership test O(1) per key.
        param_names = {name for name, _ in self.model.named_parameters()}
        euclidean_distance = 0.0
        for key, value in update_diff.items():
            if key in param_names:
                euclidean_distance += torch.norm(value.to(self.device)).item()
        self.euclidean_distance[client_id] = float(euclidean_distance)
        if self.verbose:
            domain = self.get_client_domain(client_id)
            self.logger.log(f"Client {client_id} ({domain}) distance: {euclidean_distance:.4f}")

    # ===== Core of method6: mix multi-objective weights with distance weights =====
    def compute_mixed_weights(self, fisher_or_grad_info, euclidean_distance):
        """Blend multi-objective and distance weights into one weight per client.

        Args:
            fisher_or_grad_info: Per-client dicts of tensors used to build the
                multi-objective weights.
            euclidean_distance: Per-client distance scores.

        Returns:
            Dict mapping every selected client to
            `t * objective + (1 - t) * distance`, renormalised to sum to ~1,
            where `t` is the `weight_t` hyperparameter (default 0.7).
        """
        weight_t = float(self._hyperparam("weight_t", 0.7))

        multi_obj_weights = self.compute_multi_objective_weights(fisher_or_grad_info)

        # Distance weights: directly normalize the distance scores
        if not euclidean_distance:
            distance_weights = self._uniform_weights()
        else:
            total_distance = sum(euclidean_distance.values()) + 1e-8
            distance_weights = {
                cid: float(euclidean_distance.get(cid, 0.0) / total_distance)
                for cid in self.selected_clients
            }

        mixed_weights = {
            cid: weight_t * multi_obj_weights.get(cid, 0.0)
            + (1 - weight_t) * distance_weights.get(cid, 0.0)
            for cid in self.selected_clients
        }

        total = sum(mixed_weights.values()) + 1e-8
        for cid in mixed_weights:
            mixed_weights[cid] /= total

        if self.verbose:
            detail = []
            for cid in self.selected_clients:
                domain = self.get_client_domain(cid)
                detail.append(
                    f"{cid}({domain}): obj={multi_obj_weights.get(cid, 0):.3f}, dist={distance_weights.get(cid, 0):.3f}, final={mixed_weights.get(cid, 0):.3f}"
                )
            self.logger.log("Mixed weights: " + " | ".join(detail))
        return mixed_weights

    def compute_multi_objective_weights(self, fisher_or_grad_info):
        """Solve a min-norm QP over the clients' normalised feature vectors.

        Each client's tensors are flattened (absolute values), concatenated
        and L2-normalised; the QP minimises `w^T (X X^T) w` subject to
        `w >= 0` and `sum(w) = 1`.

        Args:
            fisher_or_grad_info: Per-client dicts of tensors.

        Returns:
            Dict mapping every selected client to a weight. Selected clients
            without info get 0. Uniform weights are returned when no info is
            available or when cvxopt is missing / the solve fails.
        """
        # Fall back to uniform weights if no info
        if not fisher_or_grad_info:
            return self._uniform_weights()

        # Remember which clients contribute a row so the QP solution can be
        # mapped back to the right client even when some clients lack info.
        cids = [cid for cid in self.selected_clients if cid in fisher_or_grad_info]
        if not cids:
            return self._uniform_weights()

        # Build normalized flattened feature vectors. reshape() tolerates
        # non-contiguous tensors; float().cpu() lets rows coming from
        # different devices/dtypes be stacked.
        mats = [
            torch.cat([
                v.detach().abs().reshape(-1).float().cpu()
                for v in fisher_or_grad_info[cid].values()
            ])
            for cid in cids
        ]

        stack = torch.stack(mats)
        norms = torch.norm(stack, dim=1, keepdim=True)
        stack = stack / (norms + 1e-8)
        X = stack.numpy()

        n = X.shape[0]
        P = X @ X.T
        q = np.zeros(n)
        # Constraints: w >= 0, sum w = 1
        G = -np.eye(n)
        h = np.zeros(n)
        A = np.ones((1, n))
        b = np.array([1.0])

        # Prefer cvxopt QP solver, otherwise fallback to uniform
        try:
            import cvxopt
            from cvxopt import matrix

            P = 0.5 * (P + P.T)  # enforce exact symmetry for the solver
            sol = cvxopt.solvers.qp(matrix(P.astype(np.double)), matrix(q.astype(np.double)),
                                    matrix(G), matrix(h), matrix(A), matrix(b))
            if "optimal" not in sol["status"]:
                raise RuntimeError("QP not optimal")
            # Clip tiny negative values left by the numerical solve.
            w = np.clip(np.array(sol["x"]).reshape((n,)), 0.0, None)
        except Exception as exc:
            self.logger.log(
                f"Multi-objective QP unavailable ({type(exc).__name__}: {exc}); using uniform weights"
            )
            return self._uniform_weights()

        weight_dict = {cid: 0.0 for cid in self.selected_clients}
        for cid, value in zip(cids, w):
            weight_dict[cid] = float(value)
        total = sum(weight_dict.values()) + 1e-8
        return {cid: value / total for cid, value in weight_dict.items()}

    def aggregate_client_updates_with_mask(self, client_packages, weights, public_model_params):
        """Subtract the weighted masked updates from the global parameters.

        Args:
            client_packages: Per-client packages returned by the trainer.
            weights: Per-client aggregation weights; when None they are
                derived from each package's `"weight"` entry.
            public_model_params: Current global parameters.

        Returns:
            A new parameter mapping. When no masks/updates were collected
            this round, the result of `aggregate_regular` is returned.
        """
        if weights is None:
            client_weights = [pkg["weight"] for pkg in client_packages.values()]
            ws = torch.tensor(client_weights) / sum(client_weights)
            weights = {cid: w.item() for cid, w in zip(client_packages.keys(), ws)}

        if not self.mask_dict or not self.client_update:
            return self.aggregate_regular(client_packages, weights, public_model_params)

        # Shallow copy: entries are rebound to new tensors, never modified
        # in place, so the caller's tensors stay untouched.
        global_params_new = public_model_params.copy()
        for k, param in public_model_params.items():
            if not torch.is_tensor(param):
                continue
            updated = param
            for cid in self.selected_clients:
                if cid not in self.client_update:
                    continue
                if k not in self.client_update[cid]:
                    continue
                # Align devices; a no-op when they already match.
                masked_update = self.client_update[cid][k].to(updated.device)
                updated = updated - masked_update * weights.get(cid, 0.0)
            global_params_new[k] = updated
        return global_params_new

    def aggregate_regular(self, client_packages, weights, public_model_params):
        """Weighted-delta aggregation over full client parameters.

        For every tensor entry, adds `weight * (client_param - global_param)`
        for each selected client whose package provides that key under
        `"regular_model_params"`.

        Returns:
            A deep-copied parameter mapping with the deltas applied.
        """
        global_params = public_model_params
        global_params_new = copy.deepcopy(global_params)
        for k in global_params:
            if not torch.is_tensor(global_params[k]):
                continue
            for cid in self.selected_clients:
                if cid not in client_packages:
                    continue
                client_params = client_packages[cid]["regular_model_params"]
                if k not in client_params:
                    continue
                # Align devices; a no-op when they already match.
                client_tensor = client_params[k].to(global_params[k].device)
                delta = (client_tensor - global_params[k]) * weights.get(cid, 0.0)
                global_params_new[k] = global_params_new[k] + delta
        return global_params_new
