import copy

from .fedbase import BasicServer, BasicClient
import numpy as np
from utils import fmodule
import utils.system_simulator as ss

class Server(BasicServer):
    """Server that blends per-client updates with a loss-reweighted shared update.

    Two algorithm hyper-parameters are registered via ``init_algo_para`` and then
    read as ``self.feda`` / ``self.T``:

    * ``feda`` (default 0.1) -- mixing coefficient towards the shared / uniform
      quantity in both the pseudo-gradient blend and the final aggregation.
    * ``T`` (default 0.5) -- temperature of the softmax over client training
      losses; weights are proportional to ``exp(loss / T)``, so higher-loss
      clients get more weight and a smaller ``T`` sharpens that preference.
    """

    def __init__(self, option, model, clients, test_data = None):
        super(Server, self).__init__(option, model, clients, test_data)
        self.init_algo_para({'feda': 0.1, 'T': 0.5})
        # Run identifier built from seed, task, algorithm and hyper-parameters
        # (presumably consumed by the framework when saving results -- not visible here).
        self.save_name = (
            str(option['seed']) + '_' + option['task'] + '_model_' + option['algorithm']
            + '_lr_' + str(option['learning_rate'])
            + '_feda_' + str(self.feda)
            + '_T_' + str(self.T)
            + '_momentum_' + str(option['momentum'])
        )

    def _loss_softmax(self, losses):
        """Return ``softmax(losses / T)`` as a list of Python floats.

        The largest scaled loss is subtracted before exponentiating. This leaves
        the normalised weights mathematically unchanged. Without it, ``exp``
        overflows to ``inf`` (and every weight becomes NaN) when ``loss / T`` is
        large. Returns ``[]`` for empty input.
        """
        if len(losses) == 0:
            return []
        # float() also accepts 0-d tensors, wherever they live.
        scaled = np.array([float(l) for l in losses]) / self.T
        weights = np.exp(scaled - scaled.max())
        # Plain floats so the weights multiply model objects the same way a
        # Python scalar would.
        return (weights / weights.sum()).tolist()

    # @ss.time_step
    def iterate(self):
        """Run one communication round.

        1. Sample clients. Each reports its training loss (measured before local
           training) and its locally trained model.
        2. Convert each returned model ``w_k`` into a pseudo-gradient
           ``g_k = (w - w_k) / lr`` and blend it with the loss-softmax-weighted
           mean pseudo-gradient ``g_tilde``:
           ``g_k' = (1 - feda) * g_k + feda * g_tilde``.
        3. Send each client the starting point ``w - lr * g_k'``, collect the
           retrained models and losses, and aggregate them into ``self.model``.

        Leaves ``self.model`` unchanged if no models are returned in step 1.
        """
        self.selected_clients = self.sample()
        res = self.communicate(self.selected_clients)
        models, train_losses = res['model'], res['loss']
        if not models:
            return
        grads = [(self.model - model) / self.lr for model in models]
        # Weighted mean of the pseudo-gradients via the framework helper (also
        # used by aggregate), rather than numpy matmul over a list of models.
        p_tilde = self._loss_softmax(train_losses)
        g_tilde = fmodule._model_average(grads, p_tilde)
        m_update = [
            self.model - self.lr * (g * (1.0 - self.feda) + g_tilde * self.feda)
            for g in grads
        ]
        res = self.communicate(self.selected_clients, w=m_update, update=True)
        self.model = self.aggregate(res['model'], res['loss'])
        return

    def aggregate(self, ms, losses):
        """Blend a loss-reweighted average of ``ms`` with their plain average.

        Computes ``(1 - feda) * sum_k p_k * m_k + feda * avg(ms)``, where
        ``p = softmax(losses / T)`` and ``avg`` is the unweighted
        ``fmodule._model_average``.

        Args:
            ms: models returned by the clients.
            losses: one training loss per model, aligned with ``ms``.

        Returns:
            The new global model, or ``self.model`` unchanged when ``ms`` is empty.
        """
        if not ms:
            return self.model
        m_now = fmodule._model_average(ms)
        m_reweight = fmodule._model_average(ms, self._loss_softmax(losses))
        return fmodule._model_average([m_reweight, m_now], [(1 - self.feda), self.feda])


class Client(BasicClient):
    """Client that reports its pre-training loss together with its trained model."""

    def reply(self, svr_pkg):
        """Handle one server request.

        The loss is evaluated on the client's 'train' split with the model as
        received, i.e. *before* local training. Then ``self.train`` is called on
        that same model object; its return value is unused, so it is assumed to
        update the model in place. The model and the pre-training loss are then
        packed for the server.
        """
        received = self.unpack(svr_pkg)
        loss_before_training = self.test(received, 'train')['loss']
        self.train(received)
        return self.pack(received, loss_before_training)

    def pack(self, model, loss):
        """Bundle the model and its loss into the reply dict expected by the server."""
        package = dict(model=model, loss=loss)
        return package
