from src.client.fedavg import FedAvgClient
import torch
import copy
from collections import OrderedDict


class FedAvgHEALClient(FedAvgClient):
    """FedAvg client that also produces a HEAL sign-consistency mask.

    After each local training pass the client compares the model's
    ``state_dict`` before and after training, keeps a per-client running
    record of how often each entry's difference was non-negative, and masks
    out entries whose current sign disagrees with that record. The mask and
    the masked difference are attached to the package returned by
    :meth:`package`.

    Attributes:
        increase_history: ``{client_id: {state_key: tensor}}`` where each
            tensor holds the running fraction of observed rounds in which the
            corresponding difference was ``>= 0``.
        mask: Mask from the most recent :meth:`fit` call (set by ``fit``).
        masked_update: ``update_diff * mask`` from the most recent
            :meth:`fit` call (set by ``fit``).
    """

    def __init__(self, **commons):
        """Initialise the base client and the HEAL bookkeeping state.

        Args:
            **commons: Forwarded unchanged to ``FedAvgClient.__init__``.
        """
        super().__init__(**commons)

        # Attributes required by the HEAL algorithm.
        self.increase_history = {}
        # client_id -> number of rounds already folded into increase_history.
        # Tracked explicitly because the running mean must be weighted by the
        # number of observations, which differs from the epoch index whenever
        # a client does not train in every epoch.
        self._history_counts = {}

    def fit(self):
        """Train locally, then compute this round's mask and masked update.

        The model's ``state_dict`` is snapshotted before delegating to the
        parent ``fit``; the per-entry difference ``before - after`` is passed
        to :meth:`consistency_mask`, and the results are stored on
        ``self.mask`` and ``self.masked_update``. A one-line summary of the
        fraction of unmasked entries is printed.

        Returns:
            None. NOTE(review): whatever the parent ``fit`` returns is
            discarded, as in the original code - confirm no caller needs it.
        """
        # Clone each tensor: state_dict() entries alias the live parameters
        # and would otherwise change during training.
        # Presumably these are the global weights received from the server -
        # verify against the caller.
        global_params = {
            key: tensor.clone()
            for key, tensor in self.model.state_dict().items()
        }

        super().fit()

        current_params = self.model.state_dict()
        update_diff = {
            key: before - current_params[key]
            for key, before in global_params.items()
        }

        mask = self.consistency_mask(update_diff)

        self.mask = mask
        self.masked_update = {
            key: diff * mask[key] for key, diff in update_diff.items()
        }

        # Report how many entries survive the mask.
        total_params = sum(m.numel() for m in mask.values())
        masked_params = sum(torch.sum(m).item() for m in mask.values())
        mask_ratio = masked_params / total_params if total_params > 0 else 0
        print(f"Epoch:{self.current_epoch}  Client {self.client_id} - Active parameters ratio: {mask_ratio:.4f} ({masked_params}/{total_params})")

    def consistency_mask(self, update_diff):
        """Build the sign-consistency mask and update the sign history.

        For every entry, the stored history is the fraction of previously
        observed rounds in which the difference was ``>= 0``. An entry's
        consistency is that fraction when the current difference is ``>= 0``
        and one minus that fraction otherwise; the mask is 1.0 where the
        consistency is strictly greater than
        ``self.args.fedavgheal.threshold`` and 0.0 elsewhere. The history is
        then updated with the current signs.

        Args:
            update_diff: Mapping from state-dict key to the tensor difference
                computed in :meth:`fit`.

        Returns:
            A dict with the same keys as ``update_diff``. On the first round
            seen for ``self.client_id`` (or when its stored history is empty)
            every mask is all ones, created with ``torch.ones_like``;
            afterwards each mask is a float tensor of zeros and ones.

        Raises:
            KeyError: If ``update_diff`` contains a key missing from the
                stored history of this client.
        """
        client_id = self.client_id
        history = self.increase_history.get(client_id)

        if not history:
            # First observation: nothing to compare against yet, so record
            # the signs and let every entry through.
            self.increase_history[client_id] = {
                key: (diff >= 0).float() for key, diff in update_diff.items()
            }
            self._history_counts[client_id] = 1
            return {
                key: torch.ones_like(diff) for key, diff in update_diff.items()
            }

        threshold = self.args.fedavgheal.threshold
        # Fall back to the epoch index when a history exists without a count
        # (e.g. it was populated outside consistency_mask).
        seen = self._history_counts.get(client_id, self.current_epoch)

        mask = {}
        refreshed = {}
        for key, diff in update_diff.items():
            non_negative = diff >= 0
            positive_fraction = history[key]

            consistency = torch.where(
                non_negative,
                positive_fraction,
                1 - positive_fraction,
            )
            mask[key] = (consistency > threshold).float()

            # Running mean over the rounds actually observed for this client.
            refreshed[key] = (
                positive_fraction * seen + non_negative.float()
            ) / (seen + 1)

        # Commit only after every key succeeded so a failure part-way through
        # cannot leave the history half updated.
        history.update(refreshed)
        self._history_counts[client_id] = seen + 1
        return mask

    def package(self):
        """Return the parent's package, extended with the HEAL results.

        ``'mask'`` and ``'masked_update'`` are added (with tensors moved to
        CPU) only once :meth:`fit` has set ``self.mask``.

        Returns:
            The object returned by the parent ``package``, updated in place.
        """
        package = super().package()

        # Add the HEAL-specific data.
        if hasattr(self, 'mask'):
            package['mask'] = {k: v.cpu() for k, v in self.mask.items()}
            package['masked_update'] = {
                k: v.cpu() for k, v in self.masked_update.items()
            }

        return package
