import torch
from client import *
from .server import Server

class FedSAM(Server):
    """Server that drives ``fedsam`` clients with one of two aggregation modes.

    The mode is chosen by ``args.server_method``:

    * ``"avg"`` -- the new global vector is the current global vector plus
      ``args.global_learning_rate`` times the averaged client update.
    * ``"dyn"`` -- a per-client accumulator ``h_params_list`` is kept; the
      new global vector is the averaged client model plus the mean of all
      accumulators, and each client also receives a ``'Local_dual_correction'``
      vector (presumably FedDyn-style; confirm against the client code).

    When ``args.use_RI`` is true, the parameters sent to a client are
    extrapolated away from that client's last parameters by ``args.beta``.
    """

    # Values of ``args.server_method`` this server knows how to handle.
    _SERVER_METHODS = ("avg", "dyn")

    def __init__(self, device, model_func, init_model, init_par_list, datasets, method, args):
        """Initialise the base server, then the mode-specific communication buffers.

        Raises:
            ValueError: if ``args.server_method`` is not ``"avg"`` or ``"dyn"``.
                Previously an unknown mode left ``comm_vecs`` unset and made
                ``global_update`` return ``None``, failing far from the cause.
        """
        super().__init__(device, model_func, init_model, init_par_list, datasets, method, args)

        if self.args.server_method not in self._SERVER_METHODS:
            raise ValueError(
                f"Unsupported server_method {self.args.server_method!r}; "
                f"expected one of {self._SERVER_METHODS}"
            )

        if self.args.server_method == "avg":
            print("AVG")
            self.comm_vecs = {
                'Params_list': init_par_list.clone().detach(),
            }
        elif self.args.server_method == "dyn":
            print("DYN")
            # One accumulator row per client, same length as the flat parameter vector.
            self.h_params_list = torch.zeros((args.total_client, init_par_list.shape[0]))
            self.comm_vecs = {
                'Params_list': init_par_list.clone().detach(),
                'Local_dual_correction': torch.zeros((init_par_list.shape[0])),
            }

        self.Client = fedsam

    def _params_to_send(self, client):
        """Return the parameter vector to send to ``client``.

        Without ``args.use_RI`` this is the global vector itself. With it, the
        global vector is pushed further away from the client's last stored
        parameters: ``g + beta * (g - clients_params_list[client])``.
        """
        global_params = self.server_model_params_list
        if not self.args.use_RI:
            return global_params
        return global_params + self.args.beta * (
            global_params - self.clients_params_list[client]
        )

    def process_for_communication(self, client, Averaged_update):
        """Fill ``self.comm_vecs`` in place with the payload for ``client``.

        ``Averaged_update`` is accepted for interface compatibility and is
        not used here. Buffers are updated with ``copy_`` so any existing
        references to the ``comm_vecs`` tensors stay valid.
        """
        self.comm_vecs['Params_list'].copy_(self._params_to_send(client))

        if self.args.server_method == "dyn":
            # Correction computed against the (possibly extrapolated) params just sent.
            self.comm_vecs['Local_dual_correction'].copy_(
                self.h_params_list[client] - self.comm_vecs['Params_list']
            )

    def global_update(self, selected_clients, Averaged_update, Averaged_model):
        """Return the new global parameter vector for this round.

        ``Averaged_update`` / ``Averaged_model`` are presumably the mean client
        update and mean client model computed by the base ``Server`` (verify
        against the caller). ``selected_clients`` is not used here.
        """
        if self.args.server_method == "dyn":
            # Averaged client model plus the mean of all per-client accumulators.
            return Averaged_model + torch.mean(self.h_params_list, dim=0)

        # "avg" (the only other mode accepted by __init__): step along the averaged update.
        return self.server_model_params_list + self.args.global_learning_rate * Averaged_update

    def postprocess(self, client, received_vecs):
        """Update per-client server state after ``client`` has reported back.

        In ``"dyn"`` mode, adds the client's entry of
        ``clients_updated_params_list`` (maintained by the base ``Server``) to
        its accumulator. ``"avg"`` mode keeps no per-client state.
        ``received_vecs`` is not used here.
        """
        if self.args.server_method == "dyn":
            self.h_params_list[client] += self.clients_updated_params_list[client]