import torch
from collections import OrderedDict

from src.client.fedavg import FedAvgClient
from src.utils.constants import FLBENCH_ROOT


class FedEquilibriaClient(FedAvgClient):
    """FedEquilibria client.

    After local training:
    - Compute Fisher information and mean gradients (one pass over the data)
    - Build important-parameter mask by a threshold
    - Produce masked_update for server-side aggregation

    Besides the regular FedAvg payload, ``package()`` adds (when available)
    ``mask``, ``masked_update``, ``fisher_info`` and ``gradient_info``; each is
    a dict keyed like ``self.model.state_dict()`` with tensors moved to CPU.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-key Fisher estimate from the most recent fit() (None before that).
        self.fisher_info = None
        # Per-key mean loss gradient from the most recent fit().
        self.gradient_info = None
        # Per-key {0., 1.} mask of the entries selected as important.
        self.important_params_mask = None

    def fit(self):
        """Train locally, then derive the Fisher mask and the masked update.

        Side effects: sets ``fisher_info``, ``gradient_info``,
        ``important_params_mask``, ``mask`` and ``masked_update`` on ``self``
        and prints the fraction of entries kept by the mask.
        """
        # Snapshot the weights received before local training overwrites them.
        global_params = {k: p.clone() for k, p in self.model.state_dict().items()}

        # Regular local training
        super().fit()

        # Sign convention used by this method: diff = global - local
        current_params = self.model.state_dict()
        update_diff = {k: global_params[k] - current_params[k] for k in global_params}

        # Both statistics come from the same per-batch gradients, so collect
        # them in a single pass instead of traversing the data twice.
        self.fisher_info, self.gradient_info = self._collect_gradient_statistics()

        # Mask from Fisher
        self.important_params_mask = self.create_importance_mask(self.fisher_info)

        # Apply mask to update
        masked_update = {k: update_diff[k] * self.important_params_mask[k] for k in update_diff}

        # Expose to package()
        self.mask = self.important_params_mask
        self.masked_update = masked_update

        # count_nonzero yields an exact integer; summing a float32 mask loses
        # precision once a single tensor exceeds 2**24 elements.
        total = sum(m.numel() for m in self.mask.values())
        active = sum(int(torch.count_nonzero(m).item()) for m in self.mask.values())
        ratio = active / total if total > 0 else 0.0
        print(f"Client {self.client_id} - Active parameters ratio: {ratio:.4f} ({active}/{total})")

    def _collect_gradient_statistics(self):
        """Return ``(fisher, grads)`` accumulated over one pass of ``self.trainloader``.

        For each batch of size ``bs`` (batches of one sample or fewer are
        skipped), the batch-mean cross-entropy gradient ``g`` of every
        parameter adds ``bs * g**2`` to ``fisher`` and ``bs * g`` to
        ``grads``. Both are then divided by the number of samples used.
        Keys follow ``self.model.state_dict()``; entries that receive no
        gradient (e.g. buffers) stay zero. Both dicts are float32.

        The model's train/eval mode is restored and its gradients are cleared
        afterwards, so the forward/backward passes leave no residue.
        """
        params = self.model.state_dict()
        fisher = {name: torch.zeros_like(t, dtype=torch.float32) for name, t in params.items()}
        grads = {name: torch.zeros_like(t, dtype=torch.float32) for name, t in params.items()}
        criterion = torch.nn.CrossEntropyLoss()
        was_training = self.model.training
        self.model.eval()
        total_samples = 0
        try:
            for images, labels in self.trainloader:
                if len(images) <= 1:
                    continue
                images, labels = images.to(self.device), labels.to(self.device)
                bs = images.size(0)
                total_samples += bs
                self.model.zero_grad()
                loss = criterion(self.model(images), labels)
                loss.backward()
                for name, p in self.model.named_parameters():
                    if p.grad is None:
                        continue
                    g = p.grad.detach()
                    fisher[name] += g.pow(2) * bs
                    grads[name] += g * bs
        finally:
            # Do not leak gradients or eval mode to the caller.
            self.model.zero_grad()
            self.model.train(was_training)

        if total_samples > 0:
            for name in fisher:
                fisher[name] /= total_samples
                grads[name] /= total_samples
        return fisher, grads

    def compute_fisher_information(self):
        """Return the per-key empirical Fisher estimate over ``self.trainloader``.

        Sample-weighted mean of squared batch-mean cross-entropy gradients,
        keyed like ``self.model.state_dict()``. Runs one data pass.
        """
        fisher, _ = self._collect_gradient_statistics()
        return fisher

    def compute_gradient_information(self):
        """Return the per-key sample-weighted mean gradient over ``self.trainloader``.

        Keyed like ``self.model.state_dict()``. Runs one data pass.
        """
        _, grads = self._collect_gradient_statistics()
        return grads

    def create_importance_mask(self, fisher_info: dict):
        """Mark the largest Fisher entries across all tensors with 1.0.

        ``threshold`` is read from ``self.args.fedequilibria.threshold``
        (default 0.95). The ``int(threshold * N)`` largest values over all
        tensors combined (N = total number of entries) get 1.0, the rest
        0.0. Ties are ordered arbitrarily by ``torch.sort``.

        Args:
            fisher_info: name -> tensor of importance scores.

        Returns:
            dict with the same keys and shapes as ``fisher_info``, holding
            0./1. values. Empty input yields an empty dict.
        """
        if not fisher_info:
            return {}
        # reshape (not view) so non-contiguous tensors are accepted as well.
        flat = torch.cat([v.reshape(-1) for v in fisher_info.values()])

        threshold = float(getattr(self.args.fedequilibria, 'threshold', 0.95)) if hasattr(self.args, 'fedequilibria') else 0.95
        total = flat.numel()
        k = int(threshold * total)
        mask_flat = torch.zeros_like(flat)
        if k > 0:
            # Index with the index tensor directly; materialising millions of
            # indices as a Python list is slow and memory hungry.
            top = torch.sort(flat, descending=True).indices[:k]
            mask_flat[top] = 1.0

        # Split the flat mask back into per-key tensors of the original shapes.
        mask = {}
        pos = 0
        for name, v in fisher_info.items():
            n = v.numel()
            mask[name] = mask_flat[pos:pos + n].reshape(v.shape)
            pos += n
        return mask

    def package(self):
        """Return the FedAvg package extended with FedEquilibria statistics.

        Adds ``mask``, ``masked_update``, ``fisher_info`` and
        ``gradient_info`` (each moved to CPU) for whichever are set.
        """
        pkg = super().package()
        extras = {
            'mask': getattr(self, 'mask', None),
            'masked_update': getattr(self, 'masked_update', None),
            'fisher_info': getattr(self, 'fisher_info', None),
            'gradient_info': getattr(self, 'gradient_info', None),
        }
        for key, tensors in extras.items():
            if tensors is not None:
                pkg[key] = {k: v.cpu() for k, v in tensors.items()}
        return pkg
