# -*- coding: utf-8 -*-
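# This version adds further utility functions: FID/PSNR metrics, dummy-data initialization, gradient matching (DLG), a simplified Adam optimizer, and result saving.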
import copy as cp
import matplotlib.pyplot as plt
import math as m
import numpy as np
import tensorflow as tf
import random as r
import gc
# import tensorflow_gan as tfgan
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input


# Module-level InceptionV3 feature extractor used by fid_score().
# It is built once, at import time, so importing this module constructs the
# network. include_top=False drops the classifier head and pooling='avg'
# makes predict() return one pooled feature vector per 299x299x3 image.
# NOTE(review): no `weights=` argument is passed, so Keras' default weights
# are used (presumably ImageNet, which may be downloaded on first import) —
# confirm this is intended.
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))

# define cross entropy loss function
def compute_loss(true, pred):
    """
    Return the mean binary cross-entropy between ``true`` and ``pred``.

    ``tf.losses.binary_crossentropy`` yields one value per sample (it reduces
    the last axis); the outer ``reduce_mean`` over the last remaining axis
    collapses those per-sample values.

    :param true: values passed to the loss as ``y_true``
    :param pred: values passed to the loss as ``y_pred``
    :return: loss tensor (a scalar for rank-2 inputs)
    """
    per_sample_loss = tf.losses.binary_crossentropy(true, pred)
    return tf.reduce_mean(per_sample_loss, axis=-1)

def fid_score(image1, image2):
    """
    Calculate the Frechet Inception Distance (FID) between two image batches.

    Each batch is bilinearly resized to 299x299, normalized with
    ``preprocess_input`` and embedded by the module-level InceptionV3
    ``model``. The mean and covariance of the activations are then compared
    with the closed-form Frechet distance
    ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))``.

    NOTE(review): Keras' InceptionV3 ``preprocess_input`` is designed for
    pixel values in [0, 255]; if these images are in [0, 1] they may need
    rescaling first — confirm.

    :param image1: first batch of images (assumed NHWC)
    :param image2: second batch of images (assumed NHWC)
    :return: FID score
    """
    def _activation_stats(images):
        # Resize to Inception's input resolution, normalize, extract features.
        resized = tf.image.resize(images, [299, 299], method=tf.image.ResizeMethod.BILINEAR)
        features = model.predict(preprocess_input(resized).numpy())
        return features.mean(axis=0), cov(features, rowvar=False)

    mu_a, sigma_a = _activation_stats(image1)
    mu_b, sigma_b = _activation_stats(image2)

    mean_term = np.sum((mu_a - mu_b) ** 2.0)
    cov_sqrt = sqrtm(sigma_a.dot(sigma_b))
    # sqrtm may return a complex matrix with tiny imaginary parts caused by
    # numerical error; only the real part is meaningful here.
    if iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    # fid = tfgan.eval.frechet_classifier_distance_from_activations(act1, act2)
    return mean_term + trace(sigma_a + sigma_b - 2.0 * cov_sqrt)

def Top_MSE(k, real_data, dummy_data):
    """
    Greedily pair dummy images with their closest real images and return the
    ``k`` closest pairs.

    The distance between two images is the Euclidean (L2) norm of their
    pixel-wise difference. Despite the function name this is not a mean
    squared error, but for equally sized images both give the same ranking.
    At each of the ``k`` steps the globally closest remaining (dummy, real)
    pair is selected and both images are excluded from later steps. Ties go
    to the lowest dummy index, then the lowest real index.

    :param k: number of pairs to return
    :param real_data: real images of shape (N, ...); a tensor/variable (anything
        with ``.numpy()``) or an array-like
    :param dummy_data: dummy images of shape (M, ...) with the same per-image
        shape as ``real_data``
    :return: ``(output_real_data, output_dummy_data)``, two float64 arrays of
        shape (k, ...); row i of both arrays forms the i-th selected pair
    :raises ValueError: if ``k`` is negative or larger than N or M
    """
    def _as_array(data):
        # Accept eager tensors/variables as well as plain array-likes.
        return np.asarray(data.numpy() if hasattr(data, 'numpy') else data)

    real = _as_array(real_data)
    dummy = _as_array(dummy_data)
    limit = min(real.shape[0], dummy.shape[0])
    if not 0 <= k <= limit:
        raise ValueError('k must be between 0 and %d, got %r' % (limit, k))

    # Output buffers follow the input image shape instead of a fixed 32x32x3.
    output_real_data = np.zeros((k,) + real.shape[1:])
    output_dummy_data = np.zeros((k,) + dummy.shape[1:])
    if k == 0:
        return output_real_data, output_dummy_data

    # distances[i, j] = ||dummy[i] - real[j]||, computed once rather than
    # once per selection step.
    real_flat = real.reshape(real.shape[0], -1)
    dummy_flat = dummy.reshape(dummy.shape[0], -1)
    distances = np.empty((dummy_flat.shape[0], real_flat.shape[0]))
    for dummy_index, image in enumerate(dummy_flat):
        distances[dummy_index] = np.linalg.norm(real_flat - image, axis=1)

    for top_k in range(k):
        # Row-major argmin: lowest dummy index wins ties, then lowest real index.
        dummy_index, real_index = np.unravel_index(np.argmin(distances), distances.shape)
        output_dummy_data[top_k] = dummy[dummy_index]
        output_real_data[top_k] = real[real_index]
        # Exclude both chosen images from all later selections.
        distances[dummy_index, :] = np.inf
        distances[:, real_index] = np.inf
    return output_real_data, output_dummy_data

def dummy_data_init(number_of_workers, data_number, pretrain=False, true_label=None,
                    image_shape=(16, 16, 3), num_classes=5):
    '''
    Initialize the dummy images (one tf.Variable per worker) and dummy labels.

    With ``pretrain`` the previously saved arrays are loaded from
    ``result/<worker>_dummy.npy`` and ``result/labels_.npy``. Otherwise each
    worker's images are drawn with ``tf.random.uniform`` (op seed
    ``worker + 1``). The labels are ``true_label`` when given, and otherwise
    uniform noise with op seed 0.

    :param number_of_workers: number of workers (one image variable each)
    :param data_number: number of dummy samples per worker
    :param pretrain: load saved dummy data instead of sampling it
    :param true_label: optional label tensor/array to use instead of random labels
    :param image_shape: per-sample image shape for random initialization
    :param num_classes: label width for random initialization
    :return: dummy_images (list of tf.Variable), dummy_labels (tf.Variable)
    '''
    if pretrain:
        dummy_images = []
        for worker_index in range(number_of_workers):
            saved_images = np.load('result/' + str(worker_index) + '_dummy.npy')
            dummy_images.append(tf.Variable(tf.convert_to_tensor(saved_images)))
        saved_labels = np.load('result/labels_.npy')
        return dummy_images, tf.Variable(tf.convert_to_tensor(saved_labels))

    dummy_images = [
        tf.Variable(tf.random.uniform(shape=[data_number, *image_shape], seed=n + 1))
        for n in range(number_of_workers)
    ]
    # `is None`: `== None` is ambiguous (raises) for numpy label arrays.
    if true_label is None:
        dummy_labels = tf.random.uniform(shape=[data_number, num_classes], seed=0)
    else:
        dummy_labels = true_label
    return dummy_images, tf.Variable(dummy_labels)

def list_real_data(number_of_workers, train_datasets, data_number):
    '''
    Sample ``data_number`` real examples from the first element of
    ``train_datasets``.

    ``train_datasets`` is unpacked and transposed with ``zip``. Entry ``n`` of
    its first element provides worker ``n``'s images, and the last zipped
    entry of that element provides the labels. (Assumed layout:
    ``(images_worker_0, ..., images_worker_W-1, labels)`` — confirm against
    the data loader.) The same ``data_number`` indices, drawn after reseeding
    the global ``random`` module with 0, are gathered from the labels and
    from every worker's images.

    :param number_of_workers: number of workers
    :param train_datasets: iterable of per-step tuples as described above
    :param data_number: number of samples to draw (without replacement)
    :return: real_images (list of per-worker tensors), real_labels
    '''
    # Transpose once. Re-running zip(*train_datasets) per worker re-iterates
    # the whole dataset each time. That is slow, exhausts one-shot iterators,
    # and can pair labels and images from different passes of a reshuffling
    # dataset.
    transposed = list(zip(*train_datasets))
    real_labels = transposed[-1][0]
    total_real_data = len(real_labels)
    r.seed(0)
    real_sample_list = r.sample(list(range(total_real_data)), data_number)
    real_labels = tf.gather(real_labels, real_sample_list)
    # real_labels = tf.reshape(tf.one_hot(real_labels, 5), (-1, 5))
    real_images = [
        tf.gather(transposed[n][0], real_sample_list, axis=0)
        for n in range(number_of_workers)
    ]
    return real_images, real_labels

def take_gradient(number_of_workers, random_lists, real_images, real_labels, local_net, server):
    '''
    Compute the true gradients of the split model on one real batch.

    Each worker's ``local_net[n].forward`` maps its batch slice to a
    ``(middle_input, local_output)`` pair. The middle inputs and local
    outputs are each concatenated along axis 1, the concatenated local
    outputs are fed to ``server.forward``, and the loss is
    ``compute_loss(predict, label)``. The mean total variation of the real
    batch is printed as a side effect.

    :param number_of_workers: number of local (worker) networks
    :param random_lists: indices of the batch rows to take
    :param real_images: list with one image tensor per worker
    :param real_labels: label tensor shared by all workers
    :param local_net: list of worker models exposing ``forward`` and ``trainable_variables``
    :param server: server model exposing ``forward`` and ``trainable_variables``
    :return: (true_gradient, batch_real_image, real_middle_input), where
        true_gradient is [server gradients, worker 0 gradients, ...,
        worker W-1 gradients]
    '''
    middle_input = []
    local_output = []
    batch_real_image = []
    with tf.GradientTape(persistent=True) as tape:
        label = tf.gather(real_labels, random_lists, axis=0)
        for n in range(number_of_workers):
            temp_image = tf.gather(real_images[n], random_lists, axis=0)
            temp_middle_input, temp_local_output = local_net[n].forward(temp_image)
            middle_input.append(temp_middle_input)
            local_output.append(temp_local_output)
            batch_real_image.append(temp_image)
        real_middle_input = tf.concat(middle_input, axis=1)
        real_local_output = tf.concat(local_output, axis=1)
        predict = server.forward(real_local_output)
        # NOTE(review): the prediction is passed as ``true`` and the label as
        # ``pred``; DLG() calls compute_loss in the same order, so real and
        # fake gradients stay comparable — confirm this is intended.
        loss = compute_loss(predict, label)

    true_gradient = [tape.gradient(loss, server.trainable_variables)]
    for work_index in range(number_of_workers):
        # Index with this loop's own variable so each worker contributes its
        # own gradient (not the last worker's, via a leftover ``n``).
        true_gradient.append(tape.gradient(loss, local_net[work_index].trainable_variables))
    # A persistent tape holds its resources until it is released.
    del tape

    # Mean total variation of the real batch, averaged over workers (logged only).
    real_tv_norm = [
        tf.reduce_mean(tf.image.total_variation(image), axis=0)
        for image in batch_real_image
    ]
    real_tv_norm_aggregated = real_tv_norm[0]
    for tv_value in real_tv_norm[1:]:
        real_tv_norm_aggregated += tv_value
    real_tv_norm_aggregated = real_tv_norm_aggregated / number_of_workers
    print('real TV norm', real_tv_norm_aggregated.numpy(), end='\t')
    return true_gradient, batch_real_image, real_middle_input

def select_index(iter, data_number, batchsize):
    '''
    Draw the batch indices for one iteration.

    The global ``random`` module is reseeded with ``iter``, so a given
    iteration always yields the same batch.

    :param iter: iteration number, used as the random seed
    :param data_number: size of the index pool; indices come from range(data_number)
    :param batchsize: number of distinct indices to draw
    :return: list of ``batchsize`` distinct ints
    '''
    r.seed(iter)
    return r.sample(range(data_number), batchsize)

def aggregate(gradients, number_of_workers):
    """
    Sum the per-worker gradients layer by layer.

    :param gradients: list indexed ``[worker][layer]`` of gradient tensors;
        every worker's list has the same layout as ``gradients[0]``
    :param number_of_workers: how many leading entries of ``gradients`` to sum
    :return: list with one summed (not averaged) dense tensor per layer
    """
    aggregated_gradient = []
    for layer, first in enumerate(gradients[0]):
        if number_of_workers <= 0:
            # Nothing to add: the result is an all-zero gradient.
            aggregated_gradient.append(tf.zeros_like(first))
            continue
        # Start from worker 0 rather than a fresh tf.Variable of float32
        # zeros. This avoids creating a variable and copying to host (.numpy())
        # per layer, and keeps the gradients' own dtype.
        total = tf.convert_to_tensor(first)
        for n in range(1, number_of_workers):
            total = total + gradients[n][layer]
        # total = total / number_of_workers
        aggregated_gradient.append(total)
    return aggregated_gradient

def take_batch_data(number_of_workers, dummy_images, dummy_labels, random_lists):
    '''
    Gather one batch of dummy data into fresh trainable variables.

    :param number_of_workers: number of workers (one image tensor each)
    :param dummy_images: list of per-worker dummy image tensors/variables
    :param dummy_labels: dummy label tensor/variable
    :param random_lists: row indices of the batch
    :return: batch_dummy_image (list of tf.Variable holding the selected image
        rows per worker), batch_dummy_label (tf.Variable of the selected label rows)
    '''
    batch_dummy_image = [
        tf.Variable(tf.gather(dummy_images[worker], random_lists, axis=0))
        for worker in range(number_of_workers)
    ]
    batch_dummy_label = tf.Variable(tf.gather(dummy_labels, random_lists, axis=0))
    return batch_dummy_image, batch_dummy_label

def DLG(number_of_workers, batch_dummy_image, batch_dummy_label, local_net, server, real_gradient,
        real_middle_input, gradient_weight=100, input_norm_weight=1e-3, tv_threshold=90, tv_weight=1e-4):
    '''
    Core part of the algorithm: one DLG gradient-matching step.

    The dummy batch is pushed through the split model (workers, then server).
    The resulting "fake" parameter gradients are compared with
    ``real_gradient``, and the gradients of that mismatch with respect to the
    dummy images and labels are returned for the optimizer. The loss values
    are printed as a side effect.

    Terms:
      * D = ``gradient_weight`` * sum of squared L2 distances between the real
        and fake gradients, both laid out as [server, worker 0, ..., worker W-1].
      * Input norm: squared L2 distance between ``real_middle_input`` and the
        dummy middle inputs. Its image gradient is added with weight
        ``input_norm_weight``.
      * Total variation: its image gradient is added with weight ``tv_weight``
        only while the workers' mean TV exceeds ``tv_threshold``.

    :param number_of_workers: number of local (worker) networks
    :param batch_dummy_image: list of per-worker dummy image variables
    :param batch_dummy_label: dummy label variable (softmax is applied here)
    :param local_net: list of worker models exposing ``forward`` and ``trainable_variables``
    :param server: server model exposing ``forward`` and ``trainable_variables``
    :param real_gradient: real gradients in the layout described above
    :param real_middle_input: concatenated middle inputs of the real batch
    :param gradient_weight: scale applied to the gradient-matching distance D
    :param input_norm_weight: weight of the input-norm gradient
    :param tv_threshold: mean TV above which the TV gradient is added
    :param tv_weight: weight of the TV gradient
    :return: D (numpy scalar), dlg_gradient_x (list, one per worker), dlg_gradient_y
    '''
    with tf.GradientTape(persistent=True) as t:
        fake_local_output = []
        fake_middle_input = []
        for n in range(number_of_workers):
            t.watch(batch_dummy_image[n])
            temp_middle_input, temp_local_output = local_net[n].forward(batch_dummy_image[n])
            fake_local_output.append(temp_local_output)
            fake_middle_input.append(temp_middle_input)
        dummy_middle_input = tf.concat(fake_middle_input, axis=1)
        dummy_local_output = tf.concat(fake_local_output, axis=1)

        # server part
        predict = server.forward(dummy_local_output)
        t.watch(batch_dummy_label)
        true = tf.nn.softmax(batch_dummy_label)
        # NOTE(review): the prediction is passed first, mirroring the call in
        # take_gradient(); keep both call sites in the same order.
        loss = compute_loss(predict, true)

        # The fake gradients are taken while recording on purpose: D depends
        # on them, so the dummy data needs second-order gradients through them.
        fake_gradient = [t.gradient(loss, server.trainable_variables)]
        for work_index in range(number_of_workers):
            # Each worker's own network (not a leftover loop variable).
            fake_gradient.append(t.gradient(loss, local_net[work_index].trainable_variables))

        # Gradient-matching distance. reduce_sum(square(.)) equals
        # tf.norm(.) ** 2, but keeps a finite gradient when the difference is
        # exactly zero (tf.norm's sqrt yields NaN there).
        D = 0
        for real_layer, fake_layer in zip(real_gradient, fake_gradient):
            for gr, gf in zip(real_layer, fake_layer):
                D += tf.reduce_sum(tf.square(tf.reshape(gr, [-1]) - tf.reshape(gf, [-1])))
        D *= gradient_weight

        # Squared L2 distance between real and dummy middle inputs. This
        # assumes both batches were built from the same indices (same shape).
        D_local_output_norm = tf.reduce_sum(tf.square(real_middle_input - dummy_middle_input))

        print("DLG loss: %.5f" % D.numpy(), end='\t')
        print('Input norm:', D_local_output_norm.numpy(), end='\t')

        # Mean total variation per worker, then averaged over workers.
        tv_norm = [
            tf.reduce_mean(tf.image.total_variation(batch_dummy_image[n]), axis=0)
            for n in range(number_of_workers)
        ]
        tv_norm_aggregated = tv_norm[0]
        for n in range(1, number_of_workers):
            tv_norm_aggregated += tv_norm[n]
        tv_norm_aggregated = tf.reduce_mean(tv_norm_aggregated / number_of_workers)
        print('with Tv norm', tv_norm_aggregated.numpy(), end='\t')

    # DLG optimization: the outer gradients need not be recorded, so they are
    # taken after the recording context (the tape is persistent).
    dlg_gradient_y = t.gradient(D, batch_dummy_label)
    add_tv = tv_norm_aggregated.numpy() > tv_threshold
    dlg_gradient_x = []
    for n in range(number_of_workers):
        grad_x = t.gradient(D, batch_dummy_image[n])
        grad_x = grad_x + input_norm_weight * t.gradient(D_local_output_norm, batch_dummy_image[n])
        if add_tv:
            grad_x = grad_x + tv_weight * t.gradient(tv_norm[n], batch_dummy_image[n])
        dlg_gradient_x.append(grad_x)
    D_value = D.numpy()
    # Release the persistent tape's resources.
    del t
    return D_value, dlg_gradient_x, dlg_gradient_y

def optimize_DLG(iter, optimizer_dlg, random_lists, dlg_gradient_x, dlg_gradient_y, batch_dummy_image,
                 batch_dummy_label):
    '''
    Apply one optimizer step to the dummy images and then to the dummy labels.

    Delegates to ``optimizer_dlg.apply_gradients`` twice: first for the images
    (``data=True``), then for the labels (``data=False``).

    :param iter: iteration counter forwarded to the optimizer
    :param optimizer_dlg: object exposing
        ``apply_gradients(iter, random_lists, gradient, theta, data=...)``
    :param random_lists: batch indices forwarded to the optimizer
    :param dlg_gradient_x: gradients for the dummy images
    :param dlg_gradient_y: gradient for the dummy labels
    :param batch_dummy_image: current dummy images
    :param batch_dummy_label: current dummy labels
    :return: (updated images, updated labels) exactly as returned by the optimizer
    '''
    updated_images = optimizer_dlg.apply_gradients(
        iter, random_lists, dlg_gradient_x, batch_dummy_image, data=True)
    updated_labels = optimizer_dlg.apply_gradients(
        iter, random_lists, dlg_gradient_y, batch_dummy_label, data=False)
    return updated_images, updated_labels

def assign_to_dummy(number_of_workers, batchsize, dummy_images, dummy_labels, batch_dummy_image, batch_dummy_label,
                    random_lists):
    '''
    Write an optimized batch back into the full dummy dataset, in place.

    For every batch position ``i < batchsize``, row ``random_lists[i]`` of
    each worker's image variable, and of the label variable, is overwritten
    with row ``i`` of the batch via ``assign``.

    :param number_of_workers: number of workers
    :param batchsize: number of batch rows to write back
    :param dummy_images: list of per-worker dummy image variables (updated in place)
    :param dummy_labels: dummy label variable (updated in place)
    :param batch_dummy_image: list of per-worker batch images
    :param batch_dummy_label: batch labels
    :param random_lists: destination row index for each batch position
    :return: (dummy_images, dummy_labels), the same objects that were passed in
    '''
    for position in range(batchsize):
        target_row = random_lists[position]
        for worker in range(number_of_workers):
            source_image = batch_dummy_image[worker][position, :, :, :]
            dummy_images[worker][target_row, :, :, :].assign(source_image)
        dummy_labels[target_row, :].assign(batch_dummy_label[position, :])
    return dummy_images, dummy_labels

def record(filename, record_list):
    '''
    Append one tab-separated line to ``<filename>.txt``.

    Every item is converted with ``str``. The items are joined by tabs and
    the line ends with a newline. The file is created if missing; an empty
    ``record_list`` writes nothing.

    :param filename: path of the log file without the ``.txt`` suffix
    :param record_list: iterable of values to write as one line
    :return: None
    '''
    # Build the whole line first so a failing str() cannot leave a partial line.
    items = [str(item) for item in record_list]
    # ``with`` guarantees the file is closed even if writing raises.
    with open(filename + '.txt', 'a+') as file:
        if items:
            file.write('\t'.join(items) + '\n')

class Adam():
    '''
    Simplified, stateless Adam-style optimizer for dummy images and labels.

    No moment estimates are carried between calls. Each step uses
    ``h = (1 - beta1) * g`` and ``v = (1 - beta2) * g**2`` computed from the
    current gradient only, with the bias-corrected step size
    ``lr * sqrt(1 - beta2**(iter + 1)) / (1 - beta1**(iter + 1))``.
    '''
    def __init__(self, number_of_workers, data_number, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-7):
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # Moment buffers: kept as attributes but never filled, since moments
        # are not stored between steps. ``data_number`` is currently unused.
        self.h_data = []
        self.v_data = []
        self.h_label = []
        self.v_label = []
        self.number_of_workers = number_of_workers

    def _step_size(self, iteration):
        # Bias-corrected learning rate for the (0-based) step ``iteration``.
        return self.lr * m.sqrt(1 - self.beta2 ** (iteration + 1)) / (1 - self.beta1 ** (iteration + 1))

    def _update(self, theta, gradient, step_size):
        # One stateless Adam-style update of a single tensor.
        first_moment = (1 - self.beta1) * gradient
        second_moment = (1 - self.beta2) * tf.math.square(gradient)
        return theta - step_size * first_moment / (tf.math.sqrt(second_moment) + self.epsilon)

    def apply_gradients(self, iter, random_lists, gradient, theta, data=True):
        '''
        Return updated parameters; the inputs are not modified.

        :param iter: 0-based iteration number (used for the bias correction)
        :param random_lists: accepted for interface compatibility; unused
        :param gradient: list of per-worker gradients when ``data`` is true,
            otherwise a single label gradient
        :param theta: list of per-worker tensors when ``data`` is true,
            otherwise a single label tensor
        :param data: True to update the per-worker images, False for the labels
        :return: a list of updated tensors (data) or a single updated tensor (labels)
        '''
        step_size = self._step_size(iter)
        if not data:
            return self._update(theta, gradient, step_size)
        return [
            self._update(theta[n], gradient[n], step_size)
            for n in range(self.number_of_workers)
        ]

def visual_data(data, real):
    '''
    Save every image in ``data`` as a PNG under ``result/<worker_index>/``.

    :param data: list with one image batch (tensor/variable, indexed as
        ``[i, :, :, :]``) per worker
    :param real: if truthy files are named ``<index>real.png``, otherwise
        ``<index>dummy.png``
    :return: None
    '''
    suffix = 'real.png' if real else 'dummy.png'
    for worker_index, worker_data in enumerate(data):
        data_number = worker_data.numpy().shape[0]
        for data_index in range(data_number):
            plt.imshow(worker_data[data_index, :, :, :].numpy())
            plt.savefig(f'result/{worker_index}/{data_index}{suffix}')
            plt.close()

def PSNR(batch_real_image, batch_dummy_image):
    '''
    Compute the mean PSNR between real and dummy images over all workers.

    The dummy images are clipped to [0, 1] and compared with ``max_val=1.0``.
    The per-worker mean PSNR values are averaged; the result is printed and
    returned.

    :param batch_real_image: list of per-worker real image batches
    :param batch_dummy_image: list of per-worker dummy image batches
    :return: mean PSNR as a numpy scalar
    '''
    per_worker_psnr = [
        tf.reduce_mean(tf.image.psnr(real_batch, tf.clip_by_value(batch_dummy_image[index], 0, 1), 1.0))
        for index, real_batch in enumerate(batch_real_image)
    ]
    mean_psnr = tf.reduce_mean(per_worker_psnr).numpy()
    print('psnr value:', mean_psnr, end='\t')
    return mean_psnr

def save_data(data, labels):
    '''
    Save dummy labels or dummy images in ``.npy`` format under ``result/``.

    :param data: the label tensor when ``labels`` is true, otherwise a list of
        per-worker dummy image tensors
    :param labels: True to write ``result/labels_.npy``; False to write
        ``result/<worker_index>_dummy.npy`` for every worker
    :return: None
    '''
    if not labels:
        for worker_index, worker_data in enumerate(data):
            np.save(f'result/{worker_index}_dummy.npy', worker_data.numpy())
        return
    np.save('result/labels_.npy', data.numpy())






