# -*- coding: utf-8 -*-
# This version completes the upgrade of the S-DLG (stochastic deep leakage from gradients) code.
import os
# Restrict CUDA to GPU "0". This is set here, before `import tensorflow`
# below, so the restriction is already in place when TensorFlow initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Let TensorFlow allocate GPU memory on demand (allow_growth) instead of
# reserving it all up front, and open an InteractiveSession with that config
# (an InteractiveSession installs itself as the default session).
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

import copy as cp
import matplotlib
matplotlib.use('agg')
import numpy as np
import random as r
import gc
from PIL import Image
import matplotlib.pyplot as plt
from utils import fid_score, Top_MSE, dummy_data_init, list_real_data, select_index, take_gradient, aggregate, \
    take_batch_data, DLG, optimize_DLG, assign_to_dummy, record, Adam, visual_data, PSNR, accuracy, gradient_normalize, \
    save_data
from keras import backend
from models import LeNet
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
from data_preprocess import train_datasets
from config import *
# from resnet import resnet_18


# Let TensorFlow grow GPU memory on demand. An identical InteractiveSession was
# already opened near the top of this file; close it first so only one
# interactive session stays active (TensorFlow warns about each additional
# open InteractiveSession and keeps the resources held by the older one).
session.close()
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# Alternative DLG optimizer / learning-rate schedule (unused, kept for reference):
# optimizer_dlg = tf.keras.optimizers.Adam(learning_rate=learning_rate_dlg)
# tf.keras.optimizers.schedules.ExponentialDecay(learning_rate_dlg, 300, 0.5)

# define Stochastic DLG
def s_dlg(record_every=100):
    """
    Run the stochastic deep leakage from gradients (S-DLG) experiment.

    For each of ``max_iters`` outer iterations:
      * pick random batch indices per worker and compute the real gradients,
        then aggregate them across workers;
      * take the matching batch of dummy images/labels and run
        ``max_dlg_iters`` DLG steps that update the dummy batch, writing it
        back into the full dummy set after every step;
      * every ``record_every`` outer iterations, append ``[D, psnr, iteration]``
        via ``record(filename, ...)``;
      * apply the aggregated real gradient to the model with SGD.
    Afterwards, accuracy is evaluated and real/dummy images are visualized.

    Hyper-parameters (``max_iters``, ``max_dlg_iters``, ``number_of_workers``,
    ``data_number``, ``batchsize``, ``learning_rate_fl``, ``learning_rate_dlg``,
    ``filename``) are read from module globals, presumably provided by
    ``from config import *`` (they may be rebound in ``__main__``).

    :param record_every: record progress every this many outer iterations
                         (default 100, the original hard-coded interval).
    :return: None
    """
    # define the model
    # (alternative: net = resnet_18(); net.build(input_shape=(None, 64, 64, 3)))
    net = LeNet()

    # optimizers: SGD for the shared model, project Adam for the dummy data
    optimizer_fl = tf.keras.optimizers.SGD(learning_rate=learning_rate_fl)
    optimizer_dlg = Adam(number_of_workers, data_number, learning_rate_dlg)

    # initialize dummy data & labels and collect all the real data
    dummy_images, dummy_labels = dummy_data_init(number_of_workers, data_number, pretrain=False)
    real_images, real_labels = list_real_data(number_of_workers, train_datasets)

    # truncate the record file so each run starts from an empty log
    with open(filename + '.txt', 'w'):
        pass

    # outer loop: one federated round per iteration
    for outer_iter in range(max_iters):
        # clear Keras global state between rounds
        tf.keras.backend.clear_session()
        random_lists = select_index(outer_iter, number_of_workers, data_number, batchsize)

        # real gradients of the selected batches, aggregated across workers
        true_gradient, batch_real_image, real_middle_input = take_gradient(
            number_of_workers, random_lists, real_images, real_labels, net)
        aggregated_real_gradient = aggregate(true_gradient, number_of_workers)

        batch_dummy_image, batch_dummy_label = take_batch_data(number_of_workers, dummy_images, dummy_labels,
                                                               random_lists)

        # D / psnr stay None if the inner loop does not run (max_dlg_iters <= 0);
        # previously that case raised NameError at record() time.
        D = psnr = None
        # inner loop: DLG optimization of the dummy batch
        for dlg_iter in range(max_dlg_iters):
            D, dlg_gradient_x, dlg_gradient_y = DLG(number_of_workers, batch_dummy_image, batch_dummy_label, net,
                                                    aggregated_real_gradient, real_middle_input, outer_iter,
                                                    None, None, None)

            batch_dummy_image, batch_dummy_label = optimize_DLG(outer_iter, optimizer_dlg, random_lists,
                                                                dlg_gradient_x, dlg_gradient_y,
                                                                batch_dummy_image, batch_dummy_label)

            psnr = PSNR(batch_real_image, batch_dummy_image)
            print(dlg_iter, outer_iter)
            # write the updated dummy batch back into the full dummy set
            dummy_images, dummy_labels = assign_to_dummy(number_of_workers, batchsize, dummy_images,
                                                         dummy_labels, batch_dummy_image, batch_dummy_label,
                                                         random_lists)

        if outer_iter % record_every == 0 and D is not None:
            record(filename, [D, psnr, outer_iter])
        optimizer_fl.apply_gradients(zip(aggregated_real_gradient, net.trainable_variables))

    accuracy(number_of_workers, real_images, real_labels, net)
    visual_data(real_images, True)
    visual_data(dummy_images, False)
    # save_data(dummy_images, False)
    # save_data(dummy_labels, True)
    print('Done')



if __name__ == "__main__":
    # Optional (disabled): make sure one output directory exists per worker, e.g.
    #     for n in range(number_of_workers):
    #         try:
    #             os.stat('/home/jinx2/Documents/DLG/test/recover/worker' + str(n+1))
    #         except OSError:
    #             os.mkdir('/home/jinx2/Documents/DLG/test/recover/worker' + str(n+1))
    #
    # Optional (disabled): sweep the DLG learning rate. Rebinding the
    # module-level `learning_rate_dlg` and `filename` is what s_dlg() reads.
    #     for learning_rate_dlg in [0.5, 0.3, 0.2, 0.1, 0.08, 0.05, 0.04, 0.03, 0.02]:
    #         filename = 'record_' + str(batchsize) + '_' + str(data_number) + '_' \
    #                    + str(number_of_workers) + '_' + str(learning_rate_dlg)
    #         s_dlg()

    # single run with the configured hyper-parameters
    s_dlg()

