from argparse import ArgumentParser, Namespace

from src.client.fedavgheal import FedAvgHEALClient
from src.server.fedavg import FedAvgServer
import torch
import numpy as np


class FedAvgHEALServer(FedAvgServer):
    """FedAvg server with HEAL-style masked, distance-weighted aggregation.

    Each round, clients that return both a ``mask`` and a ``masked_update``
    (mappings from parameter name to tensor) take part in HEAL aggregation:
    the server measures the size of each masked update, turns those sizes
    into momentum-smoothed aggregation weights, and subtracts the weighted
    updates from the global parameters. If no client returns masked updates,
    the parent's ``aggregate_client_updates`` is used instead.
    """

    algorithm_name = "FedAvgHEAL"
    all_model_params_personalized = False
    return_diff = False
    client_cls = FedAvgHEALClient

    def __init__(self, args, init_trainer=True, init_model=True):
        super().__init__(args, init_trainer, init_model)

        # HEAL-specific state.
        # Reset every round: client_id -> update distance (see compute_distance).
        self.euclidean_distance = {}
        # Kept across rounds: client_id -> last unnormalized aggregation weight.
        self.previous_weights = {}
        # Kept across rounds: client_id -> last momentum-smoothed weight increment.
        self.previous_delta_weights = {}
        # Reset every round: client_id -> masked update received from the client.
        self.client_update = {}
        # Reset every round: client_id -> mask received from the client.
        self.mask_dict = {}

    @staticmethod
    def get_hyperparams(args_list=None) -> Namespace:
        """Parse the HEAL-specific hyperparameters.

        Args:
            args_list: Optional list of CLI-style strings; ``None`` parses
                ``sys.argv`` (standard ``argparse`` behaviour).

        Returns:
            Namespace with ``threshold`` (default 0.5) and ``beta`` (default 0.5).
        """
        parser = ArgumentParser()
        parser.add_argument("--threshold", type=float, default=0.5,
                            help="Parameter consistency threshold for mask")
        parser.add_argument("--beta", type=float, default=0.5,
                            help="Weight for combining previous and current weights")
        return parser.parse_args(args_list)

    def train_one_round(self):
        """Run one round of client training followed by HEAL aggregation.

        Resets the per-round state, collects client packages from the trainer,
        stores the mask and masked update of every package that has both, logs
        the fraction of non-zero mask entries, records each client's update
        distance, derives aggregation weights and applies them.
        """
        self.client_update = {}
        self.mask_dict = {}
        self.euclidean_distance = {}

        client_packages = self.trainer.train()

        for client_id, package in client_packages.items():
            # Packages without HEAL fields are skipped here; if none have
            # them, aggregation falls back to the parent implementation.
            if "mask" not in package or "masked_update" not in package:
                continue

            mask = package["mask"]
            masked_update = package["masked_update"]
            self.mask_dict[client_id] = mask
            self.client_update[client_id] = masked_update

            total_params = sum(m.numel() for m in mask.values())
            masked_params = sum(torch.sum(m).item() for m in mask.values())
            mask_ratio = masked_params / total_params if total_params > 0 else 0
            self.logger.log(f"Client {client_id} - Active parameters ratio: {mask_ratio:.4f} ({masked_params}/{total_params})")

            self.compute_distance(client_id, masked_update)

        freq = self.get_params_diff_weights()
        self.aggregate_client_updates_heal(client_packages, freq)

    def print_update_diff_sum(self, update_diff):
        """Print the sum of all elements across every tensor in ``update_diff``.

        Debugging helper; it is no longer called from the training path.
        """
        total_sum = sum(torch.sum(value).item() for value in update_diff.values())
        print(f"Total sum of update_diff: {total_sum:.6f}")

    def compute_distance(self, client_id, update_diff):
        """Record how far a client's update moved, in ``self.euclidean_distance``.

        The distance is the sum of per-tensor L2 norms over the entries of
        ``update_diff`` whose keys are names from ``self.model.named_parameters()``.
        It is not the L2 norm of the concatenated update. Other keys are ignored.

        Args:
            client_id: Key under which the distance is stored.
            update_diff: Mapping from parameter name to update tensor.
        """
        # A set gives O(1) membership tests; the name list was rebuilt per call anyway.
        param_names = {name for name, _ in self.model.named_parameters()}
        euclidean_distance = sum(
            torch.norm(update).item()
            for key, update in update_diff.items()
            if key in param_names
        )

        self.euclidean_distance[client_id] = euclidean_distance

        if self.verbose:
            self.logger.log(f"Client {client_id} distance: {euclidean_distance:.4f}")

    def get_params_diff_weights(self):
        """Turn this round's update distances into normalized aggregation weights.

        For each selected client::

            delta  = (1 - beta) * previous_delta + beta * distance / total_distance
            weight = previous_weight + delta

        ``previous_weight`` defaults to ``1 / len(selected_clients)`` and
        ``previous_delta`` defaults to 0. Clients with no recorded distance use 0.
        The unnormalized ``weight`` and ``delta`` are stored for the next round.
        The returned weights are divided by their sum (plus 1e-8).

        If no distances were recorded this round, uniform weights are returned
        and the stored history is left untouched.

        Returns:
            dict mapping client_id -> float weight.
        """
        selected_clients = self.selected_clients
        num_clients = len(selected_clients)

        if not self.euclidean_distance:
            return {client_id: 1.0 / num_clients for client_id in selected_clients}

        beta = self.args.fedavgheal.beta
        # Loop-invariant: the total distance is the same for every client.
        total_distance = sum(self.euclidean_distance.values()) + 1e-8

        weight_dict = {}
        for client_id in selected_clients:
            client_distance = self.euclidean_distance.get(client_id, 0)

            delta_weight = (1 - beta) * self.previous_delta_weights.get(client_id, 0) + beta * (
                client_distance / total_distance
            )
            new_weight = self.previous_weights.get(client_id, 1 / num_clients) + delta_weight
            weight_dict[client_id] = new_weight

            self.previous_weights[client_id] = new_weight
            self.previous_delta_weights[client_id] = delta_weight

        total_weight = sum(weight_dict.values()) + 1e-8
        for client_id in selected_clients:
            weight_dict[client_id] /= total_weight

        if self.verbose:
            weight_str = "\t\t".join(f"{i}:{weight_dict[i]:.3f}" for i in selected_clients)
            self.logger.log(f"Client weights: {weight_str}")

        return weight_dict

    def aggregate_client_updates_heal(self, client_packages, freq=None):
        """Apply the weighted masked client updates to the global parameters.

        If no masked updates were collected this round, defer to
        ``self.aggregate_client_updates(client_packages)``. Otherwise compute, for
        each global parameter, ``global -= sum_i freq[i] * update_i``, iterating
        over the selected clients that sent an update for that parameter. Then
        load the result into ``self.model``.

        Args:
            client_packages: Mapping client_id -> package dict. ``package["weight"]``
                is read only when ``freq`` is ``None``.
            freq: Optional mapping client_id -> aggregation weight. When ``None``,
                weights are taken from ``package["weight"]``, normalized to sum to 1.
        """
        if not self.mask_dict or not self.client_update:
            self.aggregate_client_updates(client_packages)
            return

        if freq is None:
            client_weights = [package["weight"] for package in client_packages.values()]
            weights = torch.tensor(client_weights) / sum(client_weights)
            freq = {client_id: weight.item() for client_id, weight in zip(client_packages.keys(), weights)}

        global_params_new = self.public_model_params.copy()

        with torch.no_grad():
            for param_key in global_params_new:
                for client_id in self.selected_clients:
                    client_update = self.client_update.get(client_id)
                    if client_update is None or param_key not in client_update:
                        continue
                    # NOTE(review): assumes masked_update = global_params - local_params,
                    # so subtracting moves the global model toward the clients;
                    # confirm against FedAvgHEALClient.
                    global_params_new[param_key] = (
                        global_params_new[param_key] - client_update[param_key] * freq[client_id]
                    )

        self.public_model_params = global_params_new
        self.model.load_state_dict(self.public_model_params, strict=False)

    def package(self, client_id: int):
        """Build the package for ``client_id``; currently identical to the parent's."""
        package = super().package(client_id)
        return package
