import numpy as np
import scipy.stats
import random
from optimizer import *
from utils import projection

def softmax(x):
    """Return the row-wise softmax of a 2-D array of scores.

    Each row is shifted by its own maximum before exponentiating. The shift
    cancels in the ratio, so the result is mathematically unchanged, but
    ``np.exp`` can no longer overflow to ``inf`` (which turned whole rows into
    ``nan``), and at least one entry per row is ``exp(0) == 1`` so the
    denominator can never underflow to zero.

    Args:
        x: Array-like of shape (n_rows, n_cols).

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; also avoids taking a max over an empty axis.
        return np.exp(x)
    shifted = x - np.max(x, axis=1, keepdims=True)
    ex = np.exp(shifted)
    return ex / np.sum(ex, axis=1, keepdims=True)

class FedAvg(FederatedOptimizer):
    def __init__(self, args, alg, lr, bs, glr, train_data_dir, test_data_dir,
                 cp_random, bo_cp, cp, sample_ratio, etamu, val_losses,
                 local_models, unseen_val_losses, unseen_local_models):
        """Build the optimizer by delegating entirely to the parent class.

        No state is set up here: every argument is forwarded, unchanged and
        in the same positional order, to the parent initializer.
        """
        super().__init__(
            args, alg, lr, bs, glr, train_data_dir, test_data_dir,
            cp_random, bo_cp, cp, sample_ratio, etamu, val_losses,
            local_models, unseen_val_losses, unseen_local_models,
        )

    def compute_gradient(self, x, i, type='global'):
        """Return the gradient tuple for client ``i`` evaluated at ``x``.

        Args:
            x: Parameter value at which the gradient is evaluated.
            i: Client index.
            type: Selects the routine to dispatch to (the name shadows the
                builtin but is kept because callers may pass it by keyword):

                * ``'global'`` -- ``compute_gradient_template`` plus an
                  algorithm-specific correction for ``fedprox``/``scaffold``.
                * ``'localtune'`` -- ``compute_gradient_template`` with no
                  correction.
                * ``'unseen_localtune'`` -- ``unseen_compute_gradient_template``.
                * anything else -- ``val_compute_gradient_template``.

        Returns:
            Whatever the selected template returns. Callers in this class
            unpack it as a 3-tuple ``(gradient, loss, local_loss)``; the
            meaning of the last two entries is defined by the templates,
            which live outside this file.
        """
        if type == 'global':

            if self.alg == 'fedprox':
                grad, loss, local = self.compute_gradient_template(x, self.local_models[i], i)
                # Proximal term: pulls the local iterate back towards the
                # current central parameter with strength ``etamu``.
                grad += self.etamu * (x - self.central_parameter)
                return grad, loss, local

            elif self.alg == 'scaffold':
                grad, loss, local = self.compute_gradient_template(x, self.local_models[i], i)
                # SCAFFOLD corrected direction is g - c_i + c. The previous
                # expression subtracted (c_i + c), i.e. applied the server
                # control variate with the wrong sign, which is inconsistent
                # with the control-variate refresh done in ``local_update``
                # (c_i - c + (x - y_i) / (K * lr)).
                grad -= self.control_locals[i] - self.control_c
                return grad, loss, local

            else:
                return self.compute_gradient_template(x, self.local_models[i], i)

        elif type == 'localtune':
            return self.compute_gradient_template(x, self.local_models[i], i)

        elif type == 'unseen_localtune':
            return self.unseen_compute_gradient_template(x, i)

        else:
            return self.val_compute_gradient_template(x, i)



    def local_update(self, sel_client=None):
        """Run local training on the selected clients and return their updates.

        For every selected client the central parameter is copied, trained
        for that client's number of local steps (taken from
        ``self.generate_cp(self.cp)``), and the weighted difference to the
        central parameter is collected.

        Side effects:
            * ``self.scale_tmp`` is reset and receives one weight per client
              (when ``self.optype`` is 1 or 2); ``aggregate`` divides by its sum.
            * ``self.control_variates`` is reset and, for ``scaffold``,
              receives the change of each client's control variate;
              ``self.control_locals[i]`` is updated as well.
            * for ``ditto``, ``self.local_parameters_dict[i]`` takes one
              regularised personal step.

        Args:
            sel_client: Iterable of client indices to train. When ``None``,
                ``self.sample_ratio`` clients are drawn without replacement
                from ``range(self.size)``.

        Returns:
            ``np.ndarray`` stacking ``scale * (local - central)`` for each
            selected client, in selection order.
        """
        cp_List = self.generate_cp(self.cp)

        # Identity test: ``sel_client == None`` is element-wise for NumPy
        # arrays and makes the ``if`` raise on multi-element selections.
        if sel_client is None:
            worker_set = random.sample(range(self.size), self.sample_ratio)
        else:
            worker_set = sel_client

        # Uniform weight unless a loss-gap rule below overrides it. It used to
        # be set only for a fixed list of algorithm names, which left the
        # variable unbound (UnboundLocalError) for every other name.
        scale = 1
        ditto_lambda = 0.05  # strength of ditto's pull towards the central model

        Delta = []
        self.control_variates = []
        self.scale_tmp = []
        for i in worker_set:
            cp = cp_List[i]
            lr = self.lr

            # ``+ 0`` yields a fresh array so the in-place steps below never
            # touch the central parameter.
            local_parameters = self.central_parameter + 0

            if self.optype == 2:
                # Weight decided once per client from the validation-loss gap
                # between the central model and the stored local loss.
                loss = self.val_loss(local_parameters, i)
                loss_local = self.val_losses[i]
                scale = self._loss_gap_scale(loss - loss_local, scale)
                self.scale_tmp.append(scale)

            for _ in range(cp):
                if self.optype == 2:
                    if self.alg == 'perfedavg':
                        other_lr = lr / 2
                        gradient_1, _, _ = self.compute_gradient(local_parameters, i)
                        hessian_1 = self.compute_hessian_template(local_parameters, i)
                        # Gradient is re-evaluated after one inner step, then
                        # pre-multiplied by (I - other_lr * H).
                        meta_local_parameters = local_parameters - other_lr * gradient_1
                        gradient_2, _, _ = self.compute_gradient(meta_local_parameters, i)
                        local_parameters -= lr * (np.identity(self.dim) - other_lr * hessian_1) @ gradient_2
                    else:
                        gradient, _, _ = self.compute_gradient(local_parameters, i)
                        local_parameters -= lr * gradient

                # optype == 1 deprecated: the weight is recomputed at every
                # step from the losses returned with the gradient.
                elif self.optype == 1:
                    gradient, loss, loss_local = self.compute_gradient(local_parameters, i)
                    scale = self._loss_gap_scale(loss - loss_local, scale)
                    local_parameters -= lr * scale * gradient

            if self.alg == 'ditto':
                # Second (personal) round; the update is in place on the array
                # stored in ``local_parameters_dict``.
                ditto_local_parameter = self.local_parameters_dict[i]
                gradient, _, _ = self.compute_gradient(ditto_local_parameter, i)
                ditto_local_parameter -= lr * (
                    gradient + ditto_lambda * (ditto_local_parameter - self.central_parameter)
                )
                self.local_parameters_dict[i] = ditto_local_parameter

            if self.optype == 1:
                # The per-step weight was already applied inside the loop, so
                # the update itself is unweighted and sum(scale_tmp) becomes
                # the number of workers.
                scale = 1
                self.scale_tmp.append(scale)

            Delta.append(scale * (local_parameters - self.central_parameter))

            if self.alg == 'scaffold':
                updated_central_local = (
                    self.control_locals[i] - self.control_c
                    + (self.central_parameter - local_parameters) * (1 / (cp * lr))
                )
                self.control_variates.append(updated_central_local - self.control_locals[i])
                self.control_locals[i] = updated_central_local

        self.control_variates = np.array(self.control_variates)
        return np.array(Delta)

    def _loss_gap_scale(self, loss_gap, default):
        """Map a loss gap (central loss minus local loss) to a client weight.

        Returns ``default`` untouched when ``self.alg`` has no gap-based rule.
        """
        if self.alg in ('sigm', 'softplus'):
            # tanh form of the logistic function. The direct form
            # exp(g) / (1 + exp(g)) evaluates to inf / inf = nan for large g.
            sigmoid = 0.5 * (1.0 + np.tanh(0.5 * loss_gap))
            if self.alg == 'sigm':
                return sigmoid * (1 - sigmoid)
            return sigmoid
        if self.alg == 'leave':
            return 1 if loss_gap > 0 else 0
        if self.alg == 'stay':
            return 1 if loss_gap <= 0 else 0
        return default

    def local_update_other(self, sel_client=None):
        """Local training where each client's step count follows a running weight.

        First pass: for every selected client, ``self.scale_list[i]`` is
        multiplied by ``self.mwfed_c`` when the central model's validation
        loss is not below the stored local validation loss. The selected
        weights are then normalised to sum to 1.

        Second pass: client ``i`` runs
        ``int(self.cp * self.sample_ratio * normalised_weight)`` plain gradient
        steps starting from the central parameter.

        Side effects:
            Updates ``self.scale_list`` in place, resets
            ``self.control_variates`` to an empty list and fills
            ``self.scale_tmp`` with the per-client step counts.

        Args:
            sel_client: Sequence of client indices. It is traversed twice, so
                it must be re-iterable. When ``None``, ``self.sample_ratio``
                clients are drawn without replacement from ``range(self.size)``.

        Returns:
            ``np.ndarray`` stacking ``steps * (local - central)`` per client.
        """
        # Identity test: ``sel_client == None`` is element-wise for NumPy
        # arrays and makes the ``if`` raise on multi-element selections.
        if sel_client is None:
            worker_set = random.sample(range(self.size), self.sample_ratio)
        else:
            worker_set = sel_client

        scale_tmp_mw = []
        for i in worker_set:
            loss = self.val_loss(self.central_parameter, i)
            loss_local = self.val_losses[i]

            if loss - loss_local >= 0:
                self.scale_list[i] *= self.mwfed_c

            scale_tmp_mw.append(self.scale_list[i])

        scale_tmp_mw = np.array(scale_tmp_mw) / sum(scale_tmp_mw)

        Delta = []
        self.control_variates = []
        self.scale_tmp = []
        lr = self.lr
        for i, weight in zip(worker_set, scale_tmp_mw):
            # Truncation towards zero: a client with a small weight may get
            # zero steps and then contributes a zero update.
            cp = int(self.cp * self.sample_ratio * weight)
            self.scale_tmp.append(cp)

            # ``+ 0`` yields a fresh array so the central parameter is not
            # modified by the in-place steps.
            local_parameters = self.central_parameter + 0
            for _step in range(cp):
                gradient, _, _ = self.compute_gradient(local_parameters, i)
                local_parameters -= lr * gradient

            Delta.append(cp * (local_parameters - self.central_parameter))

        return np.array(Delta)


    def aggregate(self, Delta):
        """Fold the collected client updates into the central state.

        The central parameter moves by ``glr`` times the sum of the rows of
        ``Delta``, divided by the sum of ``self.scale_tmp``. Nothing is
        applied when that sum is not positive. For ``scaffold`` the server
        control variate additionally absorbs the summed client
        control-variate changes divided by ``self.size``.

        Args:
            Delta: Stacked per-client updates as produced by the local-update
                methods of this class.
        """
        total_weight = sum(self.scale_tmp)
        if total_weight > 0:
            summed_update = np.sum(Delta, axis=0)
            self.central_parameter += self.glr * summed_update / total_weight

        if self.alg != 'scaffold':
            return
        self.control_c += np.sum(self.control_variates, axis=0) / self.size


    def participate(self, i, other_models=None):
        """Tell whether the central model beats client ``i``'s reference model.

        Args:
            i: Client index.
            other_models: Optional indexable collection of alternative
                per-client models. When ``None``, ``self.local_models`` is
                used as the reference.

        Returns:
            ``1`` if ``self.test_loss`` of the central parameter on client
            ``i`` is strictly lower than that of the reference model, else ``0``.
        """
        # Identity test: ``other_models == None`` is element-wise when the
        # models are passed as a NumPy array and would make the ``if`` raise.
        reference_models = self.local_models if other_models is None else other_models

        central_loss = self.test_loss(self.central_parameter, i)
        reference_loss = self.test_loss(reference_models[i], i)
        return 1 if central_loss < reference_loss else 0

    def evaluate_test(self):
        """Evaluate on ``self.test_data`` using ``self.pr`` as participation flags.

        Clients with ``self.pr[i] == 1`` are scored with the central model,
        all others with ``self.local_models[i]``.

        Returns:
            Tuple ``(mixed_loss, central_loss, mixed_correct, central_correct)``:
            per-client averages of the loss and of the *number* of correctly
            classified samples (a count, not a ratio). "mixed" uses the
            participation rule above, "central" always uses the central model.
        """
        return self._evaluate_clients(self.test_data, self.local_models, self.pr)

    def evaluate_unseentest(self, unseen_pr):
        """Evaluate on ``self.unseen_test_data`` with caller-supplied flags.

        Args:
            unseen_pr: Indexable by client index; clients with
                ``unseen_pr[i] == 1`` are scored with the central model, all
                others with ``self.unseen_local_models[i]``.

        Returns:
            Same 4-tuple layout as ``evaluate_test``.
        """
        return self._evaluate_clients(self.unseen_test_data, self.unseen_local_models, unseen_pr)

    def _evaluate_clients(self, data, local_models, participation):
        """Shared evaluation loop behind ``evaluate_test``/``evaluate_unseentest``.

        Args:
            data: Mapping from client name (``'f_00000'``, ...) to a dict with
                ``'x'`` (features) and ``'y'`` (labels).
            local_models: Indexable collection of per-client models.
            participation: Indexable flags; ``1`` selects the central model.
        """
        losses, all_losses = [], []
        accs, all_accs = [], []
        for i in range(self.size):
            uname = 'f_{0:05d}'.format(i)
            A = np.array(data[uname]['x'])
            y = np.array(data[uname]['y'])

            # ``self.loss`` is defined outside this file; NOTE(review): it is
            # presumably the loss of the central parameter -- confirm.
            central_loss = self.loss(A, y)
            central_pred = np.argmax(softmax(A @ self.central_parameter), axis=1)
            central_correct = np.sum(y == central_pred)

            if participation[i] == 1:
                losses.append(central_loss)
                accs.append(central_correct)
            else:
                probs = softmax(A @ local_models[i])
                # Cross-entropy from the true-class probability only. The
                # one-hot product used before evaluated 0 * log(0) = nan as
                # soon as any other class probability underflowed, and it
                # hard-coded the number of classes to 10.
                true_class_probs = probs[np.arange(len(y)), y.astype('int')]
                losses.append(-np.sum(np.log(true_class_probs)) / A.shape[0])
                accs.append(np.sum(y == np.argmax(probs, axis=1)))

            all_losses.append(central_loss)
            all_accs.append(central_correct)

        return np.average(losses), np.average(all_losses), np.average(accs), np.average(all_accs)


    def evaluate_val(self):
        """Score every client's validation data with ``self.loss``.

        Returns:
            Tuple ``(total, diff)`` where ``total`` is the sum of the
            per-client losses and ``diff`` is the element-wise difference
            between those losses and ``self.val_losses``.
        """
        client_names = ['f_{0:05d}'.format(idx) for idx in range(self.size)]
        per_client = np.array([
            self.loss(np.array(self.val_data[name]['x']),
                      np.array(self.val_data[name]['y']))
            for name in client_names
        ])

        gap_to_reference = per_client - np.array(self.val_losses)
        return np.sum(per_client), gap_to_reference



