import time

import numpy as np
from scipy.sparse import isspmatrix
from scipy.special import expit as sigmoid
from utils import read_data, read_data_withval, main_plot


import numpy as np
import scipy.stats

def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    For the 2-D ``(n, k)`` logit matrices used in this file, this is a row-wise
    softmax: each row of the result sums to 1. Each row's maximum is subtracted
    before exponentiating. This does not change the result mathematically, but it
    keeps ``np.exp`` from overflowing to ``inf`` (which would produce ``nan``) for
    large logits.

    Args:
        x: array of logits; softmax is computed along the last axis.

    Returns:
        Array with the same shape as ``x``.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    ex = np.exp(shifted)
    return ex / np.sum(ex, axis=-1, keepdims=True)

class FederatedOptimizer(object):
    """Shared state and evaluation code for federated softmax-regression optimizers.

    The model is a ``(dim, num_classes)`` weight matrix ``x`` scored as
    ``softmax(A @ x)``. Per-client data is stored in dicts keyed ``'f_00000'``,
    ``'f_00001'``, ...; each entry has ``'x'`` (features) and ``'y'`` (labels cast
    to ``int`` for one-hot encoding).

    ``local_update`` and ``aggregate`` must be overridden. ``compute_gradient`` and
    ``participate`` are called by this class but not defined in it; presumably a
    subclass supplies them.
    """

    # Width of every weight matrix and one-hot target built by this class.
    num_classes = 10

    def __init__(self, args, alg, lr, bs, glr, train_data_dir, test_data_dir, cp_random, bo_cp, cp, sample_ratio, etamu,
                 val_losses, local_models, unseen_val_losses, unseen_local_models):
        """Load client data and initialise the shared optimizer state.

        Args:
            args: run configuration; this constructor reads ``numusers``, ``inc_bs``,
                ``mwfed_c``, ``q``, ``optype``, ``cp_loctune`` and ``better_cp``.
            alg: algorithm name, passed to ``read_data_withval`` and stored.
            lr: step size used by the local-tuning evaluation methods.
            bs: training mini-batch size; also sets ``cp_list = round(n_i / bs)``.
            glr: stored as ``self.glr`` (presumably a global learning rate).
            train_data_dir, test_data_dir: data locations for ``read_data_withval``.
            cp_random: if true, ``generate_cp`` returns the per-client ``cp_list``.
            bo_cp, sample_ratio, etamu, val_losses, unseen_val_losses: stored as-is.
            cp: value returned for every client by ``generate_cp`` when
                ``cp_random`` is false.
            local_models: stored; ``local_models[i]`` is the default fallback model
                for client ``i`` in the evaluation methods.
            unseen_local_models: as ``local_models``, for the clients loaded from
                ``./unseendata/``.
        """
        (_, _, self.train_data, self.test_data, self.val_data,
         self.stat_train, self.stat_val, self.stat_test,
         self.ratio_train, self.ratio_val, self.ratio_test) = read_data_withval(
            train_data_dir, test_data_dir, args, alg)

        # The held-out "unseen" clients are always read from ./unseendata/.
        (_, _, self.unseen_train_data, self.unseen_test_data, self.unseen_val_data,
         _, _, _, _, _, _) = read_data_withval('./unseendata/', './unseendata/', args, alg)

        self.size = args.numusers
        self.dim = np.array(self.train_data['f_00000']['x']).shape[1]
        print(self.dim)
        shape = (self.dim, self.num_classes)
        self.central_parameter = np.zeros(shape)

        # params for SCAFFOLD
        self.control_c = np.zeros(shape)
        self.control_locals = {i: np.zeros(shape) for i in range(args.numusers)}

        # params for Ditto: one personalised model per client
        self.local_parameters_dict = {i: np.zeros(shape) for i in range(args.numusers)}
        self.inc_bs = args.inc_bs

        # params for MW-Fed: client weights start uniform
        self.scale_list = [1 / self.size for _ in range(self.size)]
        self.mwfed_c = args.mwfed_c

        # params for q-FFL
        self.q = args.q

        self.pr = [None for _ in range(self.size)]
        self.alg = alg
        self.optype = args.optype
        self.cp_loctune = args.cp_loctune
        self.cp_random = cp_random
        self.cp = cp
        self.better_cp = args.better_cp
        self.val_losses = val_losses
        self.local_models = local_models
        self.unseen_val_losses = unseen_val_losses
        self.unseen_local_models = unseen_local_models
        self.bo_cp = bo_cp
        self.sample_ratio = sample_ratio
        self.bs = bs
        self.lr = lr
        self.glr = glr
        self.cp_list = None  # filled in by get_ratio()
        self.ratio = self.get_ratio()
        self.etamu = etamu
        self.iter = 0
        self.print_flg = True

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _uname(client_id):
        """Return the data-dict key for ``client_id`` (e.g. ``3`` -> ``'f_00003'``)."""
        return 'f_{0:05d}'.format(client_id)

    def _client_arrays(self, data, client_id):
        """Return ``(A, y)`` as numpy arrays for ``client_id`` from the split ``data``."""
        entry = data[self._uname(client_id)]
        return np.array(entry['x']), np.array(entry['y'])

    def _one_hot(self, y):
        """Return a ``(len(y), num_classes)`` one-hot matrix for the labels ``y``."""
        y_hat = np.zeros((len(y), self.num_classes))
        y_hat[np.arange(len(y)), y.astype('int')] = 1
        return y_hat

    @staticmethod
    def _cross_entropy(A, y_hat, x):
        """Mean cross-entropy of ``softmax(A @ x)`` against one-hot targets ``y_hat``."""
        return - np.sum(y_hat * np.log(softmax(A @ x))) / A.shape[0]

    @staticmethod
    def _count_correct(A, y, x):
        """Number of rows whose argmax prediction under ``x`` equals the label in ``y``."""
        predict = np.argmax(softmax(A @ x), axis=1)
        return np.sum(y == predict)

    def _select_and_score(self, data, loss_fn, make_candidate, fallbacks, with_accuracy=True):
        """For each client, use a candidate model only if it beats that client's fallback.

        For client ``i``: ``candidate = make_candidate(i)``. The client participates
        (1) when ``loss_fn(candidate, i) < loss_fn(fallbacks[i], i)``, otherwise 0.
        The "chosen" model is the candidate if the client participates and the
        fallback if not. On the client's split in ``data``, this records the loss
        of the chosen model, the loss of the candidate regardless of the choice,
        and optionally the number of correct predictions of the chosen model.

        Returns:
            ``(mean participation, mean chosen loss, mean candidate loss)``, with
            ``mean correct count`` appended when ``with_accuracy`` is true.
        """
        part, losses, all_losses, accs = [], [], [], []
        for i in range(self.size):
            A, y = self._client_arrays(data, i)
            y_hat = self._one_hot(y)
            candidate = make_candidate(i)
            parti = int(loss_fn(candidate, i) < loss_fn(fallbacks[i], i))
            part.append(parti)

            chosen = candidate if parti == 1 else fallbacks[i]
            loss = self._cross_entropy(A, y_hat, chosen)
            losses.append(loss)
            all_losses.append(loss if parti == 1 else self._cross_entropy(A, y_hat, candidate))
            if with_accuracy:
                accs.append(self._count_correct(A, y, chosen))

        result = (np.average(part), np.average(losses), np.average(all_losses))
        if with_accuracy:
            result += (np.average(accs),)
        return result

    # ------------------------------------------------------------------ bookkeeping

    def get_ratio(self):
        """Return each client's share of all training rows; also set ``self.cp_list``.

        ``self.cp_list[i]`` is ``round(n_i / self.bs)`` as an int, where ``n_i`` is
        client ``i``'s number of training rows.
        """
        sizes = np.array([np.array(self.train_data[self._uname(i)]['x']).shape[0]
                          for i in range(self.size)], dtype=float)
        self.cp_list = np.round(sizes / self.bs).astype('int')
        return sizes / np.sum(sizes)

    # ------------------------------------------------------------------ losses

    def loss(self, A, y):
        """Mean cross-entropy of the central model on features ``A`` and labels ``y``."""
        return self._cross_entropy(A, self._one_hot(y), self.central_parameter)

    def test_loss(self, x, client_id):
        """Mean cross-entropy of model ``x`` on ``client_id``'s test split."""
        A, y = self._client_arrays(self.test_data, client_id)
        return self._cross_entropy(A, self._one_hot(y), x)

    def unseen_test_loss(self, x, client_id):
        """Mean cross-entropy of model ``x`` on ``client_id``'s unseen-client test split."""
        A, y = self._client_arrays(self.unseen_test_data, client_id)
        return self._cross_entropy(A, self._one_hot(y), x)

    def val_loss(self, x, client_id):
        """Mean cross-entropy of model ``x`` on ``client_id``'s validation split."""
        A, y = self._client_arrays(self.val_data, client_id)
        return self._cross_entropy(A, self._one_hot(y), x)

    def train_loss(self, x, client_id):
        """Mean cross-entropy of model ``x`` on ``client_id``'s training split."""
        A, y = self._client_arrays(self.train_data, client_id)
        return self._cross_entropy(A, self._one_hot(y), x)

    # ------------------------------------------------------------------ abstract steps

    def local_update(self):
        """Run one round of client-side updates; must be overridden."""
        raise NotImplementedError

    def aggregate(self):
        """Combine client updates into the central model; must be overridden."""
        raise NotImplementedError

    # ------------------------------------------------------------------ gradients

    def compute_gradient_template(self, x, x_local, i):
        """Sample mini-batches for client ``i``; return a training gradient and two validation losses.

        A batch of ``self.bs`` training rows (with replacement) gives the
        cross-entropy gradient at ``x``. A batch of ``self.inc_bs`` validation rows
        is then sampled, and the losses of ``x`` and ``x_local`` are evaluated on it.

        Returns:
            ``(grad, loss, loss_local)``.
        """
        A, y = self._client_arrays(self.train_data, i)
        A_val, y_val = self._client_arrays(self.val_data, i)

        # Draw order (train first, then validation) is kept for seed reproducibility.
        sample_idx = np.random.choice(A.shape[0], self.bs)
        sample_idx_val = np.random.choice(A_val.shape[0], self.inc_bs)
        a = A[sample_idx]
        a_val = A_val[sample_idx_val]

        targets = self._one_hot(y[sample_idx])
        targets_val = self._one_hot(y_val[sample_idx_val])

        loss = self._cross_entropy(a_val, targets_val, x)
        loss_local = self._cross_entropy(a_val, targets_val, x_local)
        grad = - a.T @ (targets - softmax(a @ x)) / self.bs
        return grad, loss, loss_local

    def unseen_compute_gradient_template(self, x, i):
        """Return the cross-entropy gradient at ``x`` on a random ``bs``-row batch of unseen client ``i``."""
        A, y = self._client_arrays(self.unseen_train_data, i)
        sample_idx = np.random.choice(A.shape[0], self.bs)
        a = A[sample_idx]
        targets = self._one_hot(y[sample_idx])
        return - a.T @ (targets - softmax(a @ x)) / self.bs

    def val_compute_gradient_template(self, x, i):
        """Return ``(grad, loss)`` at ``x`` on a random ``bs``-row batch of client ``i``.

        Despite the name, the batch is drawn from the training split, not the
        validation split.
        """
        A, y = self._client_arrays(self.train_data, i)
        sample_idx = np.random.choice(A.shape[0], self.bs)
        a = A[sample_idx]
        targets = self._one_hot(y[sample_idx])

        probs = softmax(a @ x)
        loss = - np.sum(targets * np.log(probs)) / a.shape[0]
        grad = - a.T @ (targets - probs) / self.bs
        return grad, loss

    def compute_hessian_template(self, x, i):
        """Return a ``(dim, dim)`` Hessian-style matrix on a random training batch of client ``i``.

        With ``s = softmax(a @ x)`` for a ``bs``-row batch ``a``, each row ``j`` is
        weighted by ``w_j = sum_k s_jk * (1 - s_jk)``, and the result is
        ``a.T @ diag(w) @ a / bs``. The weights are applied by broadcasting rather
        than by building the ``(bs, bs)`` diagonal matrix.
        """
        A = np.array(self.train_data[self._uname(i)]['x'])
        a = A[np.random.choice(A.shape[0], self.bs)]
        s = softmax(a @ x)
        w = np.sum(s * (1 - s), axis=1)
        return a.T @ (w[:, None] * a) / self.bs

    # ------------------------------------------------------------------ evaluation

    def evaluate(self):
        """Return the mean over clients of the central model's training loss."""
        losses = [self.loss(*self._client_arrays(self.train_data, i)) for i in range(self.size)]
        return np.sum(losses) / self.size

    def evaluate_prtrain(self):
        """Return the mean training loss, with each client using the model selected in ``self.pr``.

        Clients with ``self.pr[i] == 1`` are scored with ``self.loss`` (central
        model); all others with ``self.local_models[i]``.
        """
        losses = []
        for i in range(self.size):
            A, y = self._client_arrays(self.train_data, i)
            if self.pr[i] == 1:
                losses.append(self.loss(A, y))
            else:
                losses.append(self._cross_entropy(A, self._one_hot(y), self.local_models[i]))
        return np.sum(losses) / self.size

    def evaluate_pr(self, other_models = None):
        """Recompute ``self.pr`` via ``participate``; also compare models on the unseen clients.

        Returns:
            ``(participation rate, participating client ids, unseen participation
            rate, unseen per-client flags)``. An unseen client's flag is 1 when the
            central model's unseen test loss is below that of
            ``self.unseen_local_models[i]``.
        """
        self.pr = []
        par_clients = []
        for i in range(self.size):
            # Call participate once so self.pr and par_clients always agree.
            par = self.participate(i, other_models)
            self.pr.append(par)
            if par == 1:
                par_clients.append(i)

        unseen_pr = [
            int(self.unseen_test_loss(self.central_parameter, i)
                < self.unseen_test_loss(self.unseen_local_models[i], i))
            for i in range(self.size)
        ]
        return sum(self.pr) / len(self.pr), par_clients, sum(unseen_pr) / len(unseen_pr), unseen_pr

    def evaluate_pr_loss(self, other_models=None):
        """Score each client's test split with the central model or its own model, as chosen by ``participate``.

        Args:
            other_models: per-client fallback models; defaults to ``self.local_models``.

        Returns:
            ``(mean participation, mean loss of the chosen model, mean central-model
            loss, mean count of correct predictions of the chosen model)``. The
            last value is a per-client count, not a fraction.
        """
        if other_models is None:  # `== None` is ambiguous for ndarrays
            other_models = self.local_models

        part, losses, all_losses, accs = [], [], [], []
        for i in range(self.size):
            A, y = self._client_arrays(self.test_data, i)
            y_hat = self._one_hot(y)

            par = self.participate(i, other_models)
            part.append(par)

            chosen = self.central_parameter if par == 1 else other_models[i]
            loss = self._cross_entropy(A, y_hat, chosen)
            losses.append(loss)
            all_losses.append(loss if par == 1 else self._cross_entropy(A, y_hat, self.central_parameter))
            accs.append(self._count_correct(A, y, chosen))

        return np.average(part), np.average(losses), np.average(all_losses), np.average(accs)

    def generate_cp(self, cp):
        """Return the per-client int counts: ``cp_list`` if ``cp_random``, else ``self.cp`` for everyone.

        The ``cp`` argument is accepted for interface compatibility but is not read.
        """
        if self.cp_random:
            return self.cp_list.astype('int')
        return self.cp * np.ones(self.size, dtype=int)

    def perfedavg_local(self, local_models_are = None):
        """Evaluate one local SGD step from the central model against each client's fallback.

        Args:
            local_models_are: per-client fallback models; defaults to ``self.local_models``.

        Returns:
            ``(mean participation, mean chosen test loss, mean adapted-model test loss)``.
        """
        if local_models_are is None:
            local_models_are = self.local_models

        def one_step(i):
            # Work on a copy: `-=` on the central model would leak into later clients.
            params = np.copy(self.central_parameter)
            gradient, _, _ = self.compute_gradient(params, i)
            params -= self.lr * gradient
            return params

        return self._select_and_score(self.test_data, self.test_loss, one_step,
                                      local_models_are, with_accuracy=False)

    def ditto_local(self, local_models_are = None):
        """Evaluate each client's Ditto model (``local_parameters_dict[i]``) against its fallback.

        Args:
            local_models_are: per-client fallback models; defaults to ``self.local_models``.

        Returns:
            ``(mean participation, mean chosen test loss, mean Ditto-model test loss)``.
        """
        if local_models_are is None:
            local_models_are = self.local_models
        return self._select_and_score(self.test_data, self.test_loss,
                                      lambda i: self.local_parameters_dict[i],
                                      local_models_are, with_accuracy=False)

    def local_tuning(self, local_models_are = None):
        """Fine-tune the central model for ``cp_loctune`` steps per client, then compare it with the fallback.

        Args:
            local_models_are: per-client fallback models; defaults to ``self.local_models``.

        Returns:
            ``(mean participation, mean chosen test loss, mean tuned-model test loss,
            mean count of correct predictions of the chosen model)``.
        """
        if local_models_are is None:
            local_models_are = self.local_models

        def tuned(i):
            # Tune a copy so the central model is not modified across clients.
            params = np.copy(self.central_parameter)
            for _ in range(self.cp_loctune):
                gradient, _, _ = self.compute_gradient(params, i, type = 'localtune')
                params -= self.lr * gradient
            return params

        return self._select_and_score(self.test_data, self.test_loss, tuned, local_models_are)

    def unseen_local_tuning(self, local_models_are = None):
        """Same as ``local_tuning``, but on the unseen clients' data and models.

        Args:
            local_models_are: per-client fallback models; defaults to
                ``self.unseen_local_models``.

        Returns:
            ``(mean participation, mean chosen test loss, mean tuned-model test loss,
            mean count of correct predictions of the chosen model)``.
        """
        if local_models_are is None:
            local_models_are = self.unseen_local_models

        def tuned(i):
            # Tune a copy so the central model is not modified across clients.
            params = np.copy(self.central_parameter)
            for _ in range(self.cp_loctune):
                gradient = self.compute_gradient(params, i, type = 'unseen_localtune')
                params -= self.lr * gradient
            return params

        return self._select_and_score(self.unseen_test_data, self.unseen_test_loss, tuned, local_models_are)
