#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author: Xiao Jin
Entry script for the stochastic deep leakage from gradients (DLG) attack.
This version updates the earlier code so that the program runs end to end.
"""
import os
# Expose only CUDA device 1 to this process. This must happen before
# TensorFlow initialises CUDA, which is why it sits above the TensorFlow imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Enable GPU memory growth (allocate on demand rather than reserving all GPU
# memory up front), using the TF1-compat session API.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
# Created for its side effect of applying the config above; `session` is not
# referenced again in this script.
session = InteractiveSession(config=config)

from config import *
from data_preprocess import train_datasets as train_ds
from model import local_embedding, server
from tqdm import tqdm
from utils import dummy_data_init, list_real_data, select_index, take_gradient, take_batch_data, DLG, optimize_DLG, Adam\
    , assign_to_dummy, record, PSNR, visual_data, save_data
import gc
import matplotlib.pyplot as plt
import math as m
import numpy as np
import random as r
import tensorflow as tf

# Load data
# ModelNet40: 3183 training images and 800 testing images
# Batching (currently disabled):
# train_ds = train_dataset.batch(train_image_count)
# test_ds = test_dataset.batch(test_image_count)

def vs_dlg():
    """
    Run the stochastic deep leakage from gradients (DLG) attack.

    Each outer iteration selects a set of sample indices and obtains the true
    gradients for the workers' real data via ``take_gradient``. It then runs
    ``max_dlg_iters`` inner DLG steps that update the corresponding slice of
    the dummy images/labels (presumably so that their gradients approach the
    true ones). When all iterations are done, the real and dummy data are
    visualised and the dummy data are saved.

    Hyper-parameters (``number_of_workers``, ``data_number``, ``batch_size``,
    ``max_iters``, ``max_dlg_iters``, ``dlg_learning_rate``, ``iter_decay``,
    ``decay_ratio``, ``filename``) come from the ``from config import *`` at
    the top of the file.

    Side effects: truncates ``filename + '.txt'`` and logs to it through
    ``record``. Rebinds the module-level ``dlg_learning_rate`` when it
    decays. Prints progress. Writes visualisations and saved data through
    the ``utils`` helpers.

    :return: None
    """
    # The decayed learning rate is written back to the module-level name.
    global dlg_learning_rate

    # One local embedding model per worker, plus a single server model.
    local_net = [local_embedding() for _ in range(number_of_workers)]
    server_net = server()

    optimizer_dlg = Adam(number_of_workers, data_number, dlg_learning_rate)

    # Collect all real data and initialise the dummy data to be reconstructed.
    real_images, real_labels = list_real_data(number_of_workers, train_ds, data_number)
    dummy_images, dummy_labels = dummy_data_init(number_of_workers, data_number, pretrain=False, true_label=None)

    # Truncate the log file so this run starts from an empty record.
    with open(filename + '.txt', 'w'):
        pass

    # Outer loop: one index selection / true-gradient computation per step.
    for outer_iter in range(max_iters):
        # Reset Keras global state to limit memory growth across iterations.
        tf.keras.backend.clear_session()
        random_lists = select_index(outer_iter, data_number, batch_size)

        true_gradient, batch_real_image, real_middle_input = take_gradient(
            number_of_workers, random_lists, real_images, real_labels, local_net, server_net)
        batch_dummy_image, batch_dummy_label = take_batch_data(
            number_of_workers, dummy_images, dummy_labels, random_lists)

        # Inner loop: DLG updates on the selected dummy batch.
        for _ in range(max_dlg_iters):
            D, dlg_gradient_x, dlg_gradient_y = DLG(number_of_workers, batch_dummy_image, batch_dummy_label,
                                                    local_net, server_net, true_gradient, real_middle_input)

            batch_dummy_image, batch_dummy_label = optimize_DLG(outer_iter, optimizer_dlg, random_lists,
                                                                dlg_gradient_x, dlg_gradient_y,
                                                                batch_dummy_image, batch_dummy_label)

            # Write the updated batch back into the full dummy data set.
            dummy_images, dummy_labels = assign_to_dummy(number_of_workers, batch_size, dummy_images,
                                                         dummy_labels, batch_dummy_image, batch_dummy_label,
                                                         random_lists)
            psnr = PSNR(batch_real_image, batch_dummy_image)
            if outer_iter % 100 == 0:
                record(filename, [D, psnr, outer_iter])
            print(D, outer_iter, dlg_learning_rate)

        # Learning-rate decay, applied once per outer iteration. (Previously this
        # sat inside the inner loop, and because the condition depends only on
        # the outer index it compounded max_dlg_iters times.)
        if outer_iter % iter_decay == iter_decay - 1:
            dlg_learning_rate = dlg_learning_rate * decay_ratio
            optimizer_dlg.lr = dlg_learning_rate

    visual_data(real_images, True)
    visual_data(dummy_images, False)
    save_data(dummy_images, False)
    save_data(dummy_labels, True)
    print('Done')



if __name__ == "__main__":
    # Launch the attack only when this file is executed directly, not on import.
    vs_dlg()


