from .fedbase import BasicServer, BasicClient
import numpy as np
from utils import fmodule

class Server(BasicServer):
    """Server that reweights sampled clients by a softmax over their losses.

    Each round, every sampled client's entry in ``self.weight`` is replaced by
    its share of ``softmax(loss / loss_temperature)``, scaled so that the total
    weight held by the sampled clients is unchanged. Clients reporting a
    higher loss therefore receive a larger share of that mass.
    """

    # Softmax temperature applied to client losses. 1.0 reproduces the
    # previous effective behaviour: the old code divided by 0.1 *after*
    # exponentiating, which cancelled between numerator and denominator.
    # NOTE(review): a value like 0.1 was presumably intended — confirm.
    loss_temperature = 1.0

    def __init__(self, option, model, clients, test_data=None):
        super().__init__(option, model, clients, test_data)

    def iterate(self):
        """Run one communication round.

        Samples clients, collects their returned models and losses,
        redistributes the sampled clients' weight mass according to their
        losses, and replaces ``self.model`` with ``self.aggregate(models)``.
        """
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        models, train_losses = res['model'], res['loss']
        # Assumes res['loss'][i] belongs to self.selected_clients[i].
        if len(train_losses) > 0:
            self._reweight_clients(self.selected_clients, train_losses)
        # NOTE(review): the updated self.weight is not passed to aggregate();
        # confirm whether BasicServer.aggregate reads self.weight itself or
        # whether the weights should be passed explicitly.
        self.model = self.aggregate(models)

    def _reweight_clients(self, client_ids, losses):
        """Set weight[cid] for each cid to its softmax share of the old total."""
        old_mass = sum(self.weight[cid] for cid in client_ids)
        scaled = np.asarray(losses, dtype=float) / self.loss_temperature
        # Shift by the max before exp so large losses cannot overflow to inf;
        # the shift cancels in the normalised ratio.
        exps = np.exp(scaled - scaled.max())
        shares = exps / exps.sum() * old_mass
        for cid, share in zip(client_ids, shares):
            self.weight[cid] = share



class Client(BasicClient):
    """Client that reports its loss on the received model alongside the update."""

    def reply(self, svr_pkg):
        """Train on the server's model and return it with the pre-training loss.

        The loss is measured on the 'train' split *before* local training, so
        it reflects how well the received model fits this client's data.
        """
        received = self.unpack(svr_pkg)
        pre_train_loss = self.test(received, 'train')['loss']
        self.train(received)
        return self.pack(received, pre_train_loss)

    def pack(self, model, loss):
        """Bundle the trained model and its reported loss for the server."""
        return dict(model=model, loss=loss)
