# -*- coding: utf-8 -*-
# in this version we add other functional utilities
import copy as cp
import matplotlib.pyplot as plt
import math as m
import numpy as np
import tensorflow as tf
import random as r
import gc
# import tensorflow_gan as tfgan
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input


# Prepare the Inception v3 model that fid_score uses to extract activations.
# NOTE: the model is built at import time, so importing this module already
# constructs the network. include_top=False together with pooling='avg' asks
# Keras for the pooled feature output instead of the classification layer;
# inputs must be 299x299x3.
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))

# define cross entropy loss function
def compute_loss(true, pred):
    """
    Binary cross-entropy between the targets and the predictions.

    ``tf.losses.binary_crossentropy`` is evaluated first and its result is
    then averaged along its last axis.
    :param true: target values
    :param pred: predicted values
    :return: the averaged binary cross-entropy
    """
    cross_entropy = tf.losses.binary_crossentropy(true, pred)
    return tf.reduce_mean(cross_entropy, axis=-1)

# compute accuracy
def compute_accuracy(true, pred):
    """
    Mean categorical accuracy of the predictions.

    :param true: target labels
    :param pred: predicted scores
    :return: mean of ``tf.keras.metrics.categorical_accuracy(true, pred)``
    """
    matches = tf.keras.metrics.categorical_accuracy(true, pred)
    return tf.reduce_mean(matches)

def fid_score(image1, image2):
    """
    Calculate the Frechet Inception Distance (FID) between two image batches.

    Each batch is resized to 299x299 with bilinear interpolation, passed
    through ``preprocess_input`` and fed to the module-level Inception v3
    ``model``. The score is computed from the mean and covariance of the two
    sets of activations.
    :param image1: first batch of images
    :param image2: second batch of images
    :return: FID score
    """
    activations = []
    for batch in (image1, image2):
        # resize, normalize, then extract the Inception activations
        resized = tf.image.resize(batch, [299, 299], method=tf.image.ResizeMethod.BILINEAR)
        normalized = preprocess_input(resized)
        activations.append(model.predict(normalized.numpy()))
    act1, act2 = activations

    # first and second order statistics of each activation set
    mu1 = act1.mean(axis=0)
    mu2 = act2.mean(axis=0)
    sigma1 = cov(act1, rowvar=False)
    sigma2 = cov(act2, rowvar=False)

    # squared distance between the two means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # matrix square root of the covariance product; sqrtm may return a
    # complex matrix, in which case only the real part is kept
    covmean = sqrtm(sigma1.dot(sigma2))
    if iscomplexobj(covmean):
        covmean = covmean.real

    return ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)

def Top_MSE(k, real_data, dummy_data):
    """
    Greedily select the k closest (real image, dummy image) pairs.

    Despite the name, the distance between a dummy image and a real image is
    the Euclidean (L2) norm of their difference. In every round the globally
    closest remaining pair is picked and both of its images are removed from
    further consideration, so each image appears at most once in the output.
    Ties are resolved in favour of the lowest dummy index, then the lowest
    real index.

    The pairwise distances are computed once up front instead of once per
    round, and the output shape follows the input image shape rather than
    being fixed to 32x32x3.
    :param k: The number of image pairs with the minimum distance
    :param real_data: object whose ``numpy()`` returns the real images,
        first axis indexing the images
    :param dummy_data: object whose ``numpy()`` returns the dummy images,
        first axis indexing the images
    :return: (output_real_data, output_dummy_data); row i of both arrays
        holds the i-th closest pair
    :raises ValueError: if k is larger than the number of real or dummy images
    """
    input_real_data = real_data.numpy()
    input_dummy_data = dummy_data.numpy()
    real_count = input_real_data.shape[0]
    dummy_count = input_dummy_data.shape[0]
    if k > min(real_count, dummy_count):
        raise ValueError(
            'k=%d exceeds the number of available images (%d real, %d dummy)'
            % (k, real_count, dummy_count))

    output_dummy_data = np.zeros((k,) + input_dummy_data.shape[1:])
    output_real_data = np.zeros((k,) + input_real_data.shape[1:])
    if k == 0:
        return output_real_data, output_dummy_data

    # distances[d, r] = || dummy[d] - real[r] ||_2, computed one dummy row at a time
    distances = np.empty((dummy_count, real_count))
    for dummy_index in range(dummy_count):
        difference = input_dummy_data[dummy_index] - input_real_data
        distances[dummy_index] = np.linalg.norm(difference.reshape(real_count, -1), axis=1)

    # original indices of the rows/columns still present in `distances`
    dummy_left = np.arange(dummy_count)
    real_left = np.arange(real_count)
    for top_k in range(k):
        # argmin on the flattened table returns the first minimum in row-major
        # order: lowest dummy index first, then lowest real index
        dummy_pos, real_pos = np.unravel_index(np.argmin(distances), distances.shape)
        output_dummy_data[top_k] = input_dummy_data[dummy_left[dummy_pos]]
        output_real_data[top_k] = input_real_data[real_left[real_pos]]
        # drop the matched pair so neither image can be selected again
        distances = np.delete(np.delete(distances, dummy_pos, axis=0), real_pos, axis=1)
        dummy_left = np.delete(dummy_left, dummy_pos)
        real_left = np.delete(real_left, real_pos)
    return output_real_data, output_dummy_data

def dummy_data_init(number_of_workers, data_number, pretrain=False):
    '''
    Create the per-worker dummy images and dummy labels.

    With ``pretrain`` set, worker ``i`` loads its arrays from
    ``result/<i>_dummy_data.npy`` and ``result/<i>_dummy_labels.npy``.
    Otherwise both are drawn with ``tf.random.uniform`` using seed ``i + 1``,
    with shapes ``[data_number, 32, 32, 3]`` and ``[data_number, 5]``.
    :param number_of_workers: number of workers to create dummy data for
    :param data_number: number of dummy samples per worker (random
        initialization only)
    :param pretrain: load previously saved dummy data instead of sampling
    :return: dummy_images, dummy_labels (lists with one tf.Variable per worker)
    '''
    dummy_images, dummy_labels = [], []
    for worker_index in range(number_of_workers):
        if pretrain:
            # resume from the arrays stored under result/
            prefix = 'result/' + str(worker_index)
            image_values = tf.convert_to_tensor(np.load(prefix + '_dummy_data.npy'))
            label_values = tf.convert_to_tensor(np.load(prefix + '_dummy_labels.npy'))
        else:
            # fresh uniform noise; images are sampled before labels
            worker_seed = worker_index + 1
            image_values = tf.random.uniform(shape=[data_number, 32, 32, 3], seed=worker_seed)
            label_values = tf.random.uniform(shape=[data_number, 5], seed=worker_seed)
        dummy_images.append(tf.Variable(image_values))
        dummy_labels.append(tf.Variable(label_values))
    return dummy_images, dummy_labels


def list_real_data(number_of_workers, train_datasets):
    '''
    Collect the first (images, labels) element of every worker's dataset.

    Only the first element of each dataset is read; the rest of the dataset
    is not iterated.
    :param number_of_workers: number of workers to read
    :param train_datasets: indexable collection holding, per worker, an
        iterable of (images, labels) pairs
    :return: real_images, real_labels (lists with one entry per worker)
    :raises ValueError: if a worker's dataset yields no element
    '''
    real_images = []
    real_labels = []
    for n in range(number_of_workers):
        try:
            first_element = next(iter(train_datasets[n]))
        except StopIteration:
            raise ValueError('the dataset of worker %d is empty' % n) from None
        temp_images, temp_labels = first_element
        real_images.append(temp_images)
        real_labels.append(temp_labels)
    return real_images, real_labels

def take_gradient(number_of_workers, random_lists, real_images, real_labels, net):
    '''
    Compute every worker's gradient on its current batch of real data.

    For each worker the rows listed in ``random_lists[n]`` are gathered from
    the real images and labels and pushed through ``net.forward``; the
    gradient of the loss with respect to ``net.trainable_variables`` is then
    taken. The worker-averaged total variation, the worker-averaged loss and
    the mean accuracy are printed.
    :param number_of_workers: number of workers to process
    :param random_lists: per-worker lists of batch row indices
    :param real_images: per-worker real images
    :param real_labels: per-worker real labels
    :param net: model exposing ``forward`` (returning the prediction and the
        middle input) and ``trainable_variables``
    :return: true_gradient, batch_real_data, real_middle_input (lists with
        one entry per worker)
    '''
    true_gradient, batch_real_data, real_middle_input = [], [], []
    tv_per_worker, loss_per_worker, accuracy_per_worker = [], [], []

    for worker_index in range(number_of_workers):
        with tf.GradientTape() as tape:
            # rows of this worker's batch
            batch_rows = random_lists[worker_index]
            worker_images = tf.gather(real_images[worker_index], batch_rows, axis=0)
            worker_labels = tf.gather(real_labels[worker_index], batch_rows, axis=0)
            # forward pass, loss and accuracy
            prediction, middle_input = net.forward(worker_images)
            real_middle_input.append(middle_input)
            worker_loss = compute_loss(worker_labels, prediction)
            worker_accuracy = compute_accuracy(worker_labels, prediction)
            loss_per_worker.append(worker_loss)
            accuracy_per_worker.append(worker_accuracy)
        true_gradient.append(tape.gradient(worker_loss, net.trainable_variables))

        # total variation of the real batch, averaged over its images
        worker_tv = tf.reduce_mean(tf.image.total_variation(worker_images), axis=0)
        tv_per_worker.append(worker_tv)
        batch_real_data.append(worker_images)

    # average the TV norm and the loss over the workers
    tv_total = tv_per_worker[0]
    loss_total = loss_per_worker[0]
    for extra_tv, extra_loss in zip(tv_per_worker[1:], loss_per_worker[1:]):
        tv_total = tv_total + extra_tv
        loss_total = loss_total + extra_loss
    real_tv_norm_aggregated = tv_total / number_of_workers
    real_aggregated_loss = loss_total / number_of_workers

    print('real TV norm', real_tv_norm_aggregated.numpy(), end='\t')
    print('real loss', real_aggregated_loss.numpy(), end='\t')
    real_aggregated_accuracy = tf.reduce_mean(accuracy_per_worker)
    print('real accuracy', real_aggregated_accuracy.numpy())
    return true_gradient, batch_real_data, real_middle_input

def select_index(iter, number_of_workers, data_number, batchsize):
    '''
    Draw the batch indices of every worker for one iteration.

    Worker ``n`` (counted from 0) reseeds the module-level ``random``
    generator with ``iter * (n + 1)`` and then samples ``batchsize`` distinct
    indices from ``range(data_number)``. The global generator is left in the
    state reached after the last worker's draw.
    :param iter: iteration number used to derive the seeds
    :param number_of_workers: number of index lists to generate
    :param data_number: size of the index population
    :param batchsize: batch size
    :return: random_lists
    '''
    random_lists = []
    for seed_factor in range(1, number_of_workers + 1):
        r.seed(iter * seed_factor)
        population = list(range(data_number))
        random_lists.append(r.sample(population, batchsize))
    return random_lists

def aggregate(gradients, number_of_workers):
    """
    Sum the per-worker gradients layer by layer.

    The result is the plain sum over the workers; it is not divided by
    ``number_of_workers``.
    :param gradients: the gradients list, indexed as gradients[worker][layer]
    :param number_of_workers: number of workers whose gradients are summed
    :return: aggregated gradient (one tensor per layer)
    """
    aggregated_gradient = []
    for layer_index, reference_gradient in enumerate(gradients[0]):
        # start from zeros shaped like the first worker's gradient
        layer_total = tf.Variable(tf.zeros(reference_gradient.numpy().shape))
        for worker_index in range(number_of_workers):
            layer_total = layer_total + gradients[worker_index][layer_index]
        aggregated_gradient.append(layer_total)
    return aggregated_gradient

def take_batch_data(number_of_workers, dummy_images, dummy_labels, random_lists):
    '''
    Extract the current batch of dummy data for every worker.

    The rows listed in ``random_lists[n]`` are gathered from worker ``n``'s
    dummy images and dummy labels and wrapped in new ``tf.Variable`` objects.
    :param number_of_workers: number of workers to process
    :param dummy_images: per-worker dummy images
    :param dummy_labels: per-worker dummy labels
    :param random_lists: per-worker lists of batch row indices
    :return: batch_dummy_image, batch_dummy_label
    '''
    batch_dummy_image, batch_dummy_label = [], []
    for worker_index in range(number_of_workers):
        image_rows = tf.gather(dummy_images[worker_index], random_lists[worker_index], axis=0)
        image_variable = tf.Variable(image_rows)
        label_rows = tf.gather(dummy_labels[worker_index], random_lists[worker_index], axis=0)
        label_variable = tf.Variable(label_rows)
        batch_dummy_image.append(image_variable)
        batch_dummy_label.append(label_variable)
    return batch_dummy_image, batch_dummy_label

def DLG(number_of_workers, batch_dummy_image, batch_dummy_label, net, aggregated_real_gradient, real_middle_input,
        iter, init_real_gradient_norm = None, current_real_gradient_norm = None, init_gradient_norm = None):
    '''
    Core part of the algorithm: DLG
    Computes the gradient-matching loss D between the aggregated real gradient and the
    aggregated gradient produced by the dummy data, then differentiates it (plus a
    middle-input matching term and, conditionally, a total-variation term) with respect
    to the dummy images and dummy labels. Progress values are printed along the way.
    :param number_of_workers: number of workers to process
    :param batch_dummy_image: per-worker dummy image batches (watched by the tape)
    :param batch_dummy_label: per-worker dummy label batches (watched by the tape); a softmax
        is applied to them before the loss
    :param net: model exposing forward() (returning the prediction and the middle input) and
        trainable_variables
    :param aggregated_real_gradient: per-layer aggregated real gradient to be matched
    :param real_middle_input: per-worker middle inputs obtained from the real data
    :param iter: not used by the active code (only referenced in the disabled normalization block)
    :param init_real_gradient_norm: not used by the active code
    :param current_real_gradient_norm: not used by the active code
    :param init_gradient_norm: not used by the active code
    :return: D (as a numpy value), dlg_gradient_x, dlg_gradient_y
    '''
    # compute fake gradient
    # the tape is persistent because t.gradient is called several times below
    with tf.GradientTape(persistent=True) as t:
        t.reset()
        # go through all the workers
        fake_gradient = []
        dummy_middle_input = []
        for n in range(number_of_workers):
            t.watch([batch_dummy_image[n], batch_dummy_label[n]])
            # input image
            pred, temp_middle_input = net.forward(batch_dummy_image[n])
            # the dummy labels go through a softmax before being used as targets
            true = tf.nn.softmax(batch_dummy_label[n])
            loss = compute_loss(true, pred)
            # called inside the tape context, so this gradient computation is itself recorded;
            # that is what lets D be differentiated with respect to the dummy data further down
            temp_fake_gradient = t.gradient(loss, net.trainable_variables)
            fake_gradient.append(temp_fake_gradient)
            dummy_middle_input.append(temp_middle_input)
        # NOTE: temp_fake_gradient is only bound inside the loop above
        del temp_fake_gradient
        gc.collect()

        # compute aggregated fake gradient
        aggregated_fake_gradient = aggregate(fake_gradient, number_of_workers)
        '''
        if iter == 0:
            init_fake_gradient_norm = gradient_normalize(aggregated_fake_gradient)
            current_fake_gradient_norm = init_fake_gradient_norm
        else:
            init_fake_gradient_norm = init_gradient_norm
            current_fake_gradient_norm = gradient_normalize(aggregated_fake_gradient)
        '''

        # compute D loss
        D = 0
        # D_compress = 0
        D_input_norm = 0
        # gradient_index is only referenced by the disabled per-layer normalization below
        gradient_index = 0
        # D = sum over layers of the squared norm of (real gradient - fake gradient)
        for gr, gf in zip(aggregated_real_gradient, aggregated_fake_gradient):
            # normalize gradient
            gr = tf.reshape(gr, [-1, 1]) # / current_real_gradient_norm[gradient_index] * init_real_gradient_norm[gradient_index]
            gf = tf.reshape(gf, [-1, 1]) # / current_fake_gradient_norm[gradient_index] * init_fake_gradient_norm[gradient_index]
            D += tf.norm(gr - gf) ** 2
            '''
            pruned_gr = top_K_value(gr, 10)
            pruned_gf = top_K_value(gf, 10)
            D_compress += tf.norm(pruned_gr - pruned_gf) ** 2
            '''
        # D_input_norm = sum over workers of the squared norm of (real - dummy) middle input
        for r_input, f_input in zip(real_middle_input, dummy_middle_input):
            temp_input_norm = tf.norm(r_input - f_input) ** 2
            D_input_norm += temp_input_norm

        # D_total = D + 100 * D_compress
        print("DLG loss: %.5f" % D.numpy(), end = '\t')
        print('Input norm:', D_input_norm.numpy(), end = '\t')
        # print("DLG compressed loss: %.5f" % D_compress.numpy(), end = '\t')


        # total variation of every worker's dummy batch, averaged over the images of the batch
        tv_norm = []
        for n in range(number_of_workers):
            temp_tv_norm = tf.image.total_variation(batch_dummy_image[n])
            temp_tv_norm = tf.reduce_mean(temp_tv_norm, axis=0)
            tv_norm.append(temp_tv_norm)
        del temp_tv_norm
        gc.collect()

        # compute aggregated TV norm
        tv_norm_aggregated = tv_norm[0]
        for n in range(1, number_of_workers):
            tv_norm_aggregated += tv_norm[n]
        tv_norm_aggregated = tv_norm_aggregated / number_of_workers
        tv_norm_aggregated = tf.reduce_mean(tv_norm_aggregated)
        # D += tv_norm_aggregated
        print('with Tv norm', tv_norm_aggregated.numpy(), end = '\t')

        '''
        DLG optimization
        '''
        dlg_gradient_x = []
        dlg_gradient_y = []

        for n in range(number_of_workers):
            # temp_tv_norm = t.gradient(tv_norm[n], batch_dummy_image[n])
            temp_tv_norm_gradient = t.gradient(tv_norm[n],batch_dummy_image[n])
            temp_dlg_gradient_x = t.gradient(D, batch_dummy_image[n])
            temp_dlg_gradient_y = t.gradient(D, batch_dummy_label[n])
            temp_middle_input_gradient = t.gradient(D_input_norm, batch_dummy_image[n])
            # the middle-input matching term is always added to the image gradient, weighted by 8e-4
            temp_dlg_gradient_x = temp_dlg_gradient_x + 8e-4 * temp_middle_input_gradient
            # add Tv norm gradient
            # only while the aggregated TV norm is above 420, weighted by 8e-6
            if tv_norm_aggregated.numpy() > 420:
                temp_dlg_gradient_x = temp_dlg_gradient_x + 8e-6 * temp_tv_norm_gradient
            # resize gradients
            # temp_dlg_gradient_x = tf.reshape(temp_dlg_gradient_x, [1, batchsize, 32, 32, 3])
            # temp_dlg_gradient_y = tf.reshape(temp_dlg_gradient_y, [1, batchsize, 10])
            dlg_gradient_x.append(temp_dlg_gradient_x)
            dlg_gradient_y.append(temp_dlg_gradient_y)

    # the trailing comma is harmless: a 3-tuple is returned
    return D.numpy(), dlg_gradient_x, dlg_gradient_y, # init_fake_gradient_norm

def optimize_DLG(iter, optimizer_dlg, random_lists,dlg_gradient_x, dlg_gradient_y, batch_dummy_image,
                 batch_dummy_label):
    '''
    Run one optimizer step on the dummy images, then on the dummy labels.

    Both updates are delegated to ``optimizer_dlg.apply_gradients``; the
    image update is requested with ``data=True`` and the label update with
    ``data=False``.
    :param iter: iteration number forwarded to the optimizer
    :param optimizer_dlg: optimizer exposing ``apply_gradients``
    :param random_lists: per-worker batch indices forwarded to the optimizer
    :param dlg_gradient_x: gradients with respect to the dummy images
    :param dlg_gradient_y: gradients with respect to the dummy labels
    :param batch_dummy_image: dummy image batches to update
    :param batch_dummy_label: dummy label batches to update
    :return: the updated dummy images and dummy labels, as returned by the
        optimizer
    '''
    update_plan = (
        (dlg_gradient_x, batch_dummy_image, True),
        (dlg_gradient_y, batch_dummy_label, False),
    )
    updated = []
    for gradient, current_value, is_image in update_plan:
        updated.append(
            optimizer_dlg.apply_gradients(iter, random_lists, gradient, current_value, data=is_image))
    new_images, new_labels = updated
    return new_images, new_labels

def assign_to_dummy(number_of_workers, batchsize, dummy_images, dummy_labels, batch_dummy_image, batch_dummy_label,
                    random_lists):
    '''
    Write the optimized batch back into the full dummy images and labels.

    Row ``b`` of worker ``n``'s batch is assigned to row
    ``random_lists[n][b]`` of that worker's dummy images and dummy labels.
    :param number_of_workers: number of workers to process
    :param batchsize: number of batch rows written per worker
    :param dummy_images: per-worker dummy images (updated in place)
    :param dummy_labels: per-worker dummy labels (updated in place)
    :param batch_dummy_image: per-worker batch of dummy images
    :param batch_dummy_label: per-worker batch of dummy labels
    :param random_lists: per-worker lists of destination row indices
    :return: dummy_images, dummy_labels
    '''
    slots = ((worker, position) for worker in range(number_of_workers) for position in range(batchsize))
    for worker_index, batch_index in slots:
        target_row = random_lists[worker_index][batch_index]
        new_image = batch_dummy_image[worker_index][batch_index, :, :, :]
        dummy_images[worker_index][target_row, :, :, :].assign(new_image)
        new_label = batch_dummy_label[worker_index][batch_index, :]
        dummy_labels[worker_index][target_row, :].assign(new_label)
    return dummy_images, dummy_labels

def record(filename, record_list):
    '''
    Append one line of tab-separated values to ``filename + '.txt'``.

    Every entry is converted with ``str``; the entries are joined with tabs
    and terminated by a newline. An empty ``record_list`` writes nothing,
    although the file is still opened (and therefore created if missing).
    :param filename: path of the text file without the '.txt' extension
    :param record_list: iterable of values to record
    :return:
    '''
    # convert first so that a failing str() cannot leave a partial line behind
    fields = [str(entry) for entry in record_list]
    # the context manager closes the file even if the write raises
    with open(filename + '.txt', 'a+') as file:
        if fields:
            file.write('\t'.join(fields) + '\n')

class Adam():
    '''
    Adam-style optimizer for the dummy images and dummy labels.

    NOTE: no moment state is carried from one call to the next. The first
    and second moment terms are rebuilt from the current gradient alone on
    every step, and the ``h_*`` / ``v_*`` lists are created empty and never
    filled.
    '''
    def __init__(self, number_of_workers, data_number, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-7):
        '''
        :param number_of_workers: number of entries processed by apply_gradients
        :param data_number: accepted but not used
        :param learning_rate: base learning rate
        :param beta1: coefficient of the first moment term
        :param beta2: coefficient of the second moment term
        :param epsilon: constant added to the denominator of the update
        '''
        self.number_of_workers = number_of_workers
        self.lr, self.beta1, self.beta2 = learning_rate, beta1, beta2
        self.epsilon = epsilon
        # placeholders for per-sample moments; they stay empty
        self.h_data, self.v_data = [], []
        self.h_label, self.v_label = [], []

    def _step_size(self, iter):
        '''Learning rate with the bias-correction factors for step ``iter``.'''
        return self.lr * m.sqrt(1 - self.beta2 ** (iter + 1)) / (1 - self.beta1 ** (iter + 1))

    def _updated(self, iter, gradient, theta):
        '''Return ``theta`` after one step along ``gradient``.'''
        temp_lr = self._step_size(iter)
        # moment terms built from the current gradient only
        h = (1 - self.beta1) * gradient
        v = (1 - self.beta2) * tf.math.square(gradient)
        return theta - temp_lr * h / (tf.math.sqrt(v) + self.epsilon)

    def apply_gradients(self, iter, random_lists, gradient, theta, data=True):
        '''
        In this function, we optimize theta
        :param iter: iteration number, used for the learning-rate correction
        :param random_lists: accepted but not used
        :param gradient: a list with one gradient per worker
        :param theta: a list with one value to update per worker
        :param data: optimize data or label; the update rule is the same for both
        :return: theta_new
        '''
        return [
            self._updated(iter, gradient[n], theta[n])
            for n in range(self.number_of_workers)
        ]

def visual_data(data, real):
    '''
    Save every image of every worker as a PNG file under ``result/``.

    Files are named ``<worker>_<index>real.png`` when ``real`` is true and
    ``<worker>_<index>dummy.png`` otherwise.
    :param data: data to be visualized (list with one image batch per worker)
    :param real: True for real images, False for dummy images
    :return:
    '''
    suffix = 'real.png' if real else 'dummy.png'
    for worker_index, worker_batch in enumerate(data):
        image_count = worker_batch.numpy().shape[0]
        for image_index in range(image_count):
            picture = worker_batch[image_index, :, :, :].numpy()
            plt.imshow(picture)
            plt.savefig('result/' + str(worker_index) + '_' + str(image_index) + suffix)
            plt.close()

def top_K_value(tensor, K):
    '''
    Keep the entries with the K largest absolute values and zero the rest.

    The threshold is the K-th largest absolute value; every entry whose
    absolute value is at least the threshold is kept, so ties at the
    threshold can leave more than K non-zero entries.
    :param tensor: the input tensor
    :param K: number of largest-magnitude entries that define the threshold
    :return: tensor of the same shape with the remaining entries set to 0
    '''
    # flatten to a column of magnitudes and rank them, largest first
    magnitudes = tf.math.abs(tf.reshape(tensor, [-1, 1]))
    ranked = tf.sort(magnitudes, axis=0, direction='DESCENDING')
    threshold = ranked[K - 1, 0]
    keep_mask = tf.math.abs(tensor) >= threshold
    return tf.where(keep_mask, x=tensor, y=0)

def PSNR(batch_real_image, batch_dummy_image):
    '''
    Compute the mean PSNR between the real and the dummy batches.

    The dummy images are clipped to [0, 1] first; the PSNR is computed with a
    maximum value of 1.0, averaged per worker and then over the workers. The
    result is also printed.
    :param batch_real_image: per-worker batches of real images
    :param batch_dummy_image: per-worker batches of dummy images
    :return: aggregated PSNR as a numpy value
    '''
    worker_scores = []
    for worker_index, real_batch in enumerate(batch_real_image):
        clipped_dummy = tf.clip_by_value(batch_dummy_image[worker_index], 0, 1)
        worker_scores.append(tf.reduce_mean(tf.image.psnr(real_batch, clipped_dummy, 1.0)))
    aggregated_psnr = tf.reduce_mean(worker_scores)
    psnr_value = aggregated_psnr.numpy()
    print('psnr value:', psnr_value, end='\t')
    return psnr_value

def accuracy(number_of_workers, real_images, real_labels, net):
    '''
    Print the accuracy of ``net`` averaged over the workers.

    Each worker's images are passed through ``net.forward`` and scored with
    ``compute_accuracy``; nothing is returned.
    :param number_of_workers: number of workers to evaluate
    :param real_images: per-worker images
    :param real_labels: per-worker labels
    :param net: model exposing ``forward`` (returning the prediction and the
        middle input)
    '''
    worker_scores = []
    for worker_index in range(number_of_workers):
        predicts, _ = net.forward(real_images[worker_index])
        worker_score = compute_accuracy(real_labels[worker_index], predicts)
        worker_scores.append(worker_score.numpy())
    print('training accuracy: %.2f' % np.mean(worker_scores))

def gradient_normalize(gradients):
    '''
    Compute the norm of every entry of ``gradients``.

    :param gradients: list of gradient tensors
    :return: list with ``tf.norm`` of each entry, in the same order
    '''
    return [tf.norm(layer_gradient) for layer_gradient in gradients]

def save_data(data, labels):
    '''
    Save every worker's array under ``result/`` in npy format.

    Worker ``i`` is written to ``result/<i>_dummy_labels.npy`` when
    ``labels`` is true and to ``result/<i>_dummy_data.npy`` otherwise.
    :param data: list with one object per worker exposing ``numpy()``
    :param labels: True when ``data`` holds labels, False when it holds images
    :return:
    '''
    suffix = '_dummy_labels.npy' if labels else '_dummy_data.npy'
    for worker_index, worker_values in enumerate(data):
        np.save('result/' + str(worker_index) + suffix, worker_values.numpy())





