import numpy as np
import scipy.stats
import random
from optimizer import *

def softmax(x):
    """Row-wise softmax of a 2-D array (normalised along axis 1).

    Each row's maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf``. An overflow would otherwise turn the whole row into ``nan``
    (``inf / inf``) for large logits.

    Args:
        x: array-like of shape ``(n, k)``.

    Returns:
        ``np.ndarray`` of the same shape whose rows sum to 1.
    """
    shifted = x - np.max(x, axis=1, keepdims=True)
    ex = np.exp(shifted)
    return ex / np.sum(ex, axis=1, keepdims=True)

class FedProx(FederatedOptimizer):
    """Federated optimizer with FedAvg / FedProx-style local updates.

    Each round, ``local_update`` starts a random subset of workers from
    ``central_parameter``. It runs local gradient steps and returns each
    worker's scaled parameter delta. ``aggregate`` then adds the
    scale-normalised sum of those deltas, times ``glr``, to
    ``central_parameter``.

    Besides ``'fedavg'`` / ``'fedprox'``, ``local_update`` recognises
    loss-gap weighting schemes (``'sigm'``, ``'softplus'``, ``'leave'``,
    ``'stay'``). These scale a worker's delta by a function of
    ``val_loss(central, i) - val_losses[i]``.

    NOTE(review): ``compute_gradient`` only produces a gradient for
    ``'fedprox'`` / ``'fedavg'``. With the loss-gap ``alg`` values, the local
    step would receive ``None``. Confirm whether those schemes are meant to be
    used with this class.

    Attributes such as ``size``, ``optype``, ``pr``, ``lr``, ``glr``, ``cp``,
    ``test_data`` and ``val_data`` are set up by ``FederatedOptimizer``, which
    is not visible here.
    """

    def __init__(self, args, alg, lr, bs, glr, train_data_dir, test_data_dir, cp_random, bo_cp, cp,
                 sample_ratio, etamu, val_losses, local_models):
        # All state initialisation is delegated to the base class.
        super().__init__(args, alg, lr, bs, glr, train_data_dir, test_data_dir, cp_random, bo_cp,
                         cp, sample_ratio, etamu, val_losses, local_models)

    def compute_gradient(self, x, i, type='global'):
        """Return the gradient for worker ``i`` at parameters ``x``.

        Args:
            x: current local parameters.
            i: worker index.
            type: ``'global'`` for the training objective. Any other value
                delegates to ``val_compute_gradient_template(x, i)``.

        Returns:
            - For ``alg == 'fedavg'``: the result of ``compute_gradient_template``.
            - For ``alg == 'fedprox'``: that result plus the proximal term
              ``etamu * (x - central_parameter)``.
            - For any other ``alg`` with ``type == 'global'``: ``None``.
        """
        if type != 'global':
            return self.val_compute_gradient_template(x, i)

        if self.alg == 'fedprox':
            mu = self.etamu
            grad = self.compute_gradient_template(x, self.local_models[i], i)
            # NOTE(review): local_update unpacks compute_gradient's result as a
            # 3-tuple. If the template really returns a tuple, this ``+`` would
            # not add the proximal term to the gradient element alone. Confirm
            # the template's return type.
            return grad + mu * (x - self.central_parameter)

        if self.alg == 'fedavg':
            return self.compute_gradient_template(x, self.local_models[i], i)

        # NOTE(review): unsupported alg yields None (kept for compatibility).
        # A caller that unpacks the result will fail with a TypeError.
        return None

    @staticmethod
    def _sigmoid(t):
        """Return ``exp(t) / (1 + exp(t))`` for scalar ``t`` without overflow."""
        # Evaluated literally, that form is inf / inf = nan for large t.
        # Only ever exponentiate a non-positive argument instead.
        if t >= 0:
            return 1.0 / (1.0 + np.exp(-t))
        z = np.exp(t)
        return z / (1.0 + z)

    def local_update(self):
        """Run local training on a random subset of workers.

        ``sample_ratio`` distinct workers are drawn from ``range(size)``. Each
        starts from a copy of ``central_parameter``. When ``optype == 2``, each
        then takes ``cp_list[i]`` gradient steps of size ``lr``, where
        ``cp_list = generate_cp(self.cp)``.

        Side effect: resets ``self.scale_tmp``. When ``optype == 2`` it is
        filled with each sampled worker's scale, which ``aggregate`` uses as
        the normaliser.

        Returns:
            ``np.ndarray`` stacking ``scale * (local - central)`` for every
            sampled worker, in sampling order.
        """
        cp_list = self.generate_cp(self.cp)
        worker_set = random.sample(range(self.size), self.sample_ratio)

        # Plain averaging variants weight every worker equally. The loss-gap
        # variants below overwrite ``scale`` per worker.
        if self.alg in ('fedavg', 'fedprox', 'scaffold'):
            scale = 1

        deltas = []
        self.scale_tmp = []
        for i in worker_set:
            steps = cp_list[i]
            lr = self.lr
            local_parameters = self.central_parameter + 0  # copy; updated in place below

            if self.optype == 2:
                # Gap between the validation loss at the central parameters and
                # the stored reference val_losses[i] (presumably the worker's
                # local-model loss — confirm).
                gap = self.val_loss(local_parameters, i) - self.val_losses[i]
                if self.alg == 'sigm':
                    weight = self._sigmoid(gap)
                    scale = weight * (1 - weight)
                elif self.alg == 'softplus':
                    scale = self._sigmoid(gap)
                elif self.alg == 'leave':
                    scale = 1 if gap > 0 else 0
                elif self.alg == 'stay':
                    scale = 1 if gap <= 0 else 0
                self.scale_tmp.append(scale)

                # optype 1 is deprecated; only optype 2 performs local steps.
                for _step in range(steps):
                    gradient, _, _ = self.compute_gradient(local_parameters, i)
                    local_parameters -= lr * gradient

            deltas.append(scale * (local_parameters - self.central_parameter))

        return np.array(deltas)

    def aggregate(self, Delta):
        """Apply the scale-normalised sum of worker deltas to ``central_parameter``.

        Args:
            Delta: stacked per-worker deltas as returned by ``local_update``.

        When the scales collected in ``self.scale_tmp`` sum to zero or less
        (e.g. every worker was excluded), the central parameters are left
        unchanged.
        """
        total_scale = sum(self.scale_tmp)
        if total_scale > 0:
            self.central_parameter += self.glr * np.sum(Delta, axis=0) / total_scale

    def participate(self, i, other_models=None):
        """Return 1 if the central model has lower test loss for worker ``i``, else 0.

        The comparison model is ``other_models[i]`` when ``other_models`` is
        given, otherwise ``local_models[i]``.
        """
        # ``is None`` rather than ``== None``: for an ndarray of models, ``==``
        # is elementwise, and using the result in a condition raises ValueError.
        reference_models = self.local_models if other_models is None else other_models
        central_loss = self.test_loss(self.central_parameter, i)
        return 1 if central_loss < self.test_loss(reference_models[i], i) else 0

    def evaluate_test(self):
        """Sum test losses over all workers.

        Returns:
            ``(personalised_total, global_total)``.

            - ``global_total`` sums ``self.loss(A, y)`` (the central-parameter
              loss) over every worker.
            - ``personalised_total`` uses that same loss for workers with
              ``pr[i] == 1``. For the other workers it uses the mean
              cross-entropy of ``softmax(A @ local_models[i])``.
        """
        losses = []
        all_losses = []
        for i in range(self.size):
            uname = 'f_{0:05d}'.format(i)
            A = np.array(self.test_data[uname]['x'])
            y = np.array(self.test_data[uname]['y'])

            # Both totals need the central-parameter loss; compute it once.
            global_loss = self.loss(A, y)

            if self.pr[i] == 1:
                losses.append(global_loss)
            else:
                probs = softmax(A @ self.local_models[i])
                labels = y.astype('int')
                # Index the true-class probability directly instead of
                # multiplying by a hard-coded 10-column one-hot matrix. This
                # works for any class count and avoids 0 * log(0) = nan from
                # non-target columns.
                picked = probs[np.arange(len(labels)), labels]
                losses.append(-np.sum(np.log(picked)) / A.shape[0])

            all_losses.append(global_loss)

        return np.sum(losses), np.sum(all_losses)

    def evaluate_val(self):
        """Return the summed validation loss and per-worker loss gaps.

        Returns:
            ``(total, diff)``. ``total`` sums ``self.loss`` over every worker's
            validation data. ``diff[i]`` is that worker's loss minus
            ``val_losses[i]``.
        """
        per_worker = []
        for i in range(self.size):
            split = self.val_data['f_{0:05d}'.format(i)]
            per_worker.append(self.loss(np.array(split['x']), np.array(split['y'])))

        losses = np.array(per_worker)
        diff = losses - np.array(self.val_losses)
        return np.sum(losses), diff


