import numpy as np
import scipy.stats
import random
from optimizer import *
from utils import projection
import copy

def softmax(x):
    """Return the row-wise softmax of a 2-D array of scores.

    Each row is shifted by its own maximum before exponentiation. The shift
    cancels out in the ratio, so the result is mathematically unchanged, but
    it keeps ``np.exp`` from overflowing to ``inf`` (and the row from turning
    into ``nan``) when the scores are large.

    Args:
        x: Array-like of shape ``(n_rows, n_cols)``.

    Returns:
        ``numpy.ndarray`` of the same shape whose rows are non-negative and
        sum to 1.
    """
    x = np.asarray(x)
    # A max-reduction over a zero-length axis raises, so only shift when
    # there is data; empty inputs fall through and stay empty.
    if x.size:
        x = x - np.max(x, axis=1, keepdims=True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis=1, keepdims=True)

class local_train(FederatedOptimizer):
    """Optimizer in which each worker runs gradient steps on its own model.

    Every worker starts from a copy of ``self.central_parameter`` and is
    updated independently; no averaging between workers happens in this
    class.
    """

    def __init__(self, args, alg, lr, bs, glr, train_data_dir, test_data_dir, cp_random, bo_cp, cp, sample_ratio, etamu):
        """Forward the configuration to ``FederatedOptimizer``.

        The four trailing positional arguments of the base constructor are
        filled with ``None``.
        """
        super(local_train, self).__init__(
            args, alg, lr, bs, glr, train_data_dir, test_data_dir,
            cp_random, bo_cp, cp, sample_ratio, etamu,
            None, None, None, None,
        )

    def compute_gradient(self, x, i, type='global'):
        """Evaluate one of the base-class gradient templates for worker ``i``.

        Args:
            x: Parameters at which to evaluate.
            i: Worker index.
            type: ``'global'`` calls ``compute_gradient_template`` with the
                stored ``self.local_models[i]``; any other value calls
                ``val_compute_gradient_template``.

        Returns:
            Whatever the selected template returns.
        """
        if type != 'global':
            return self.val_compute_gradient_template(x, i)
        return self.compute_gradient_template(x, self.local_models[i], i)

    def burn_out_local_update(self):
        """Train every worker locally and record intermediate/final models.

        For each of the ``self.size`` workers, a copy of
        ``self.central_parameter`` is updated for ``self.better_cp`` steps
        using the ``'local'`` gradient. The step size starts at ``self.lr``
        and is divided by 5 at steps 500 and 1000. At step
        ``int(self.bo_cp - 1)`` the worker's validation loss and a snapshot
        of its parameters are stored; the parameters after the last step are
        stored separately.

        Resets and fills ``self.val_losses``, ``self.local_models`` and
        ``self.better_local_models``.

        Returns:
            Tuple ``(local_losses, val_losses, local_models,
            better_local_models)`` where ``local_losses[t]`` is the step-``t``
            loss averaged over all workers.
        """
        self.val_losses = []
        self.local_models = []
        self.better_local_models = []
        mean_step_losses = np.zeros(self.better_cp)

        for worker in range(self.size):
            # NOTE(review): if self.lr is a mutable array rather than a plain
            # number, the in-place division below also modifies self.lr --
            # confirm this is intended.
            step_size = self.lr
            # "+ 0" produces a fresh object so the central parameters are not
            # modified by the in-place updates below.
            params = self.central_parameter + 0

            for step in range(self.better_cp):
                # Fixed step-size decay schedule.
                if step in (500, 1000):
                    step_size /= 5

                gradient, loss = self.compute_gradient(params, worker, type='local')
                params -= step_size * gradient
                mean_step_losses[step] += loss / self.size

                # Snapshot the model reached after bo_cp local steps.
                if step == int(self.bo_cp - 1):
                    self.val_losses.append(self.val_loss(params, worker))
                    self.local_models.append(copy.deepcopy(params))

            # Model after the full better_cp steps.
            self.better_local_models.append(copy.deepcopy(params))

        return mean_step_losses, self.val_losses, self.local_models, self.better_local_models

