
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import csv
import os
from io import BytesIO
import time
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle
import copy
from nus_wide_data_util import *

# Mean-squared error used by RAE_split to score each reconstructed 32-dim
# client block against the observed embedding.
loss_RAE = tf.keras.losses.MeanSquaredError()

# Learning-rate schedule for the SGD optimizer that RAE_split builds on every
# call: 1e-2 * 0.99 ** (step / 10000).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_rate=0.99,
    decay_steps=10000,
)

# Number of output classes of the server head (presumably one per label in top_k).
class_num = 5
class newVFLActiveModelWithOneLayer(Model):
    """Server-side (active party) head of the VFL model.

    Concatenates the list of client embeddings, applies one 32-unit ReLU
    hidden layer and emits ``class_num`` softmax probabilities.
    """

    def __init__(self):
        super().__init__()
        # Keep these attribute/layer names: TF-format checkpoints restore
        # weights by attribute name, so renaming breaks load_weights.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(32, name="dense1", activation='relu')
        self.out = Dense(class_num, name="out", activation='softmax')

    def call(self, x):
        merged = self.concatenated(x)
        hidden = self.d1(merged)
        return self.out(hidden)

class VFLPassiveModel(Model):
    """Client-side (passive party) bottom model: flatten, then a 32-unit ReLU layer."""

    def __init__(self):
        super().__init__()
        # Attribute names double as checkpoint keys (TF format) -- do not rename.
        self.flatten = Flatten()
        self.d1 = Dense(32, name="dense1", activation='relu')

    def call(self, x):
        return self.d1(self.flatten(x))


# Wall-clock reference for the "all time spent" report at the end of the script.
all_start = time.time()

def RAE_split(LocalOutputs, rae, MISS_INDEX=-1, exp_type=4, embed_dim=32, num_steps=2):
    """Fit the RAE input so that its output matches the observed client embeddings.

    The per-client embeddings in ``LocalOutputs`` are concatenated along axis 1
    into ``h``; client ``k`` (0-based) occupies columns
    ``[k * embed_dim, (k + 1) * embed_dim)``. A trainable copy ``L`` of ``h``
    is passed through ``rae`` and updated with SGD (learning rate from the
    module-level ``lr_schedule``) for ``num_steps`` steps. The objective is the
    sum over observed clients of ``sqrt(MSE)`` between the RAE output block and
    the matching block of ``h``.

    Args:
        LocalOutputs: list of per-client embeddings, concatenated along axis 1.
        rae: model whose call returns ``(reconstruction, hidden)``.
        MISS_INDEX: 1-based index of the client whose block is excluded from
            the loss. Any value outside ``1..exp_type`` (e.g. -1) keeps every
            block.
        exp_type: number of client blocks (2 or 4 in this script).
        embed_dim: width of each client's embedding block.
        num_steps: number of SGD steps applied to ``L``.

    Returns:
        The RAE reconstruction from the last forward pass. NOTE(review): it is
        computed before the final gradient step, so that last update does not
        affect the returned value -- confirm this is intended.

    Raises:
        ValueError: if ``num_steps < 1`` or no client block remains to fit.
    """
    if num_steps < 1:
        raise ValueError('num_steps must be >= 1, got {}'.format(num_steps))
    # Blocks that contribute to the loss: every client except the missing one.
    observed = [k for k in range(exp_type) if k != MISS_INDEX - 1]
    if not observed:
        raise ValueError('no observed client blocks for exp_type={}, MISS_INDEX={}'
                         .format(exp_type, MISS_INDEX))

    h = np.concatenate(tuple(LocalOutputs), axis=1)
    L = tf.Variable(h, trainable=True)
    optimizer2 = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
    for _step in range(num_steps):
        with tf.GradientTape() as passive_tape:
            RAE_output, _layer_output = rae(L)
            loss = None
            for k in observed:
                lo, hi = k * embed_dim, (k + 1) * embed_dim
                block_rmse = tf.sqrt(loss_RAE(RAE_output[:, lo:hi], h[:, lo:hi]))
                # Accumulate left to right: (b0 + b1) + b2 + ...
                loss = block_rmse if loss is None else loss + block_rmse
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        optimizer2.apply_gradients(zip(passive_RAE_L_gradients, [L]))

    return RAE_output



# NUS-WIDE labels used for this experiment (class_num presumably equals len(top_k)).
top_k = ['buildings', 'grass', 'animal', 'water', 'person']
# TODO: change the 40000 argument to 60000.
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 40000, 'Test')
image_test = np.array(test_X_image).astype('float32')
text_test = np.array(test_X_text).astype('float32')
x_test = (image_test, text_test)
y_test = np.array(test_Y).astype('float32')

BATCH_SIZE = 64
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Selects the RAE-defended checkpoints and runs RAE_split on the embeddings.
RAE_DEFEND = False

runs_list = [1, 2, 3]  # trial ids; use [1] for a quick single run
num_runs = len(runs_list)
# 4: two image + two text clients; 2: one image + one text client.
exp_type = 2

if exp_type == 4:
    MISS_INDEX_list, RAE_Output_Dim = [-1, 2, 3, 4], 128
else:
    MISS_INDEX_list, RAE_Output_Dim = [-1, 2], 64

# Per-run accuracy (%) keyed by MISS_INDEX: -1 means no client missing,
# -2 holds the clean (unperturbed) accuracy.
miss_results_test_acc = {index: np.zeros(num_runs) for index in MISS_INDEX_list}
miss_results_test_acc[-2] = np.zeros(num_runs)

class RAE(Model):
    """Autoencoder-style model used by RAE_split to reconstruct client embeddings.

    Forward path: ``d3 -> d1 -> layer norm -> d4 -> d2``. ``call`` returns the
    ``RAE_Output_Dim``-wide output and the normalised hidden activation.
    """

    def __init__(self):
        super(RAE, self).__init__()
        # Attribute names act as checkpoint keys for TF-format load_weights;
        # keep them (and the existing duplicated "dense1" layer names) unchanged.
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_Output_Dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense1", activation='relu')
        self.d4 = Dense(64, name="dense1", activation='relu')
        # Built once so it is tracked by the model. Constructing it inside call()
        # created a fresh, untracked layer (gamma re-initialised to ones) on every
        # forward pass. Its initial gamma of ones gives identical outputs, and
        # checkpoints saved without this variable simply leave it at that value.
        self.norm = LayerNormalization(axis=-1, center=False, scale=True)

    def call(self, x):
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.norm(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def create_adversarial_pattern(ImagesDivs, TextsDivs, labels, LocalModels, server_model,
                               loss_fn=None):
    """Return the FGSM sign-gradient with respect to the first image client's features.

    Runs the full VFL forward pass: image client ``i`` uses ``LocalModels[i]``,
    and text client ``j`` uses ``LocalModels[len(ImagesDivs) + j]``. It then
    computes ``loss_fn(labels, output)`` and returns
    ``sign(d loss / d ImagesDivs[0])``. Only the first image client's feature
    slice is treated as adversary-controlled.

    Args:
        ImagesDivs: per-image-client feature slices for one batch.
        TextsDivs: per-text-client feature slices for one batch.
        labels: batch labels passed to ``loss_fn``.
        LocalModels: client bottom models, image clients first.
        server_model: server head taking the list of client outputs.
        loss_fn: callable ``(labels, output) -> loss``. Defaults to
            ``tf.keras.losses.CategoricalCrossentropy()``, so the function no
            longer depends on a global ``loss_object`` that is only bound
            inside the script's run loop.

    Returns:
        A tensor of -1/0/+1 values with the same shape as ``ImagesDivs[0]``.
    """
    if loss_fn is None:
        loss_fn = tf.keras.losses.CategoricalCrossentropy()
    num_img_clients = len(ImagesDivs)
    # GradientTape.watch requires a tensor; this also accepts numpy input.
    adv_input = tf.convert_to_tensor(ImagesDivs[0])

    with tf.GradientTape() as tape:
        tape.watch(adv_input)  # the adversary's feature slice
        LocalOutputs = [LocalModels[0](adv_input)]
        LocalOutputs += [LocalModels[i](ImagesDivs[i]) for i in range(1, num_img_clients)]
        LocalOutputs += [LocalModels[i + num_img_clients](text_div)
                         for i, text_div in enumerate(TextsDivs)]
        output = server_model(LocalOutputs)
        loss = loss_fn(labels, output)

    gradient = tape.gradient(loss, adv_input)
    # The sign of the gradient is the FGSM perturbation direction.
    return tf.sign(gradient)


def get_local_outputs(num_img_clients, images, img_feature_div,
        num_text_clients, texts, text_feature_div, LocalModels):
    """Run every client's bottom model on its own column slice of the inputs.

    Image client ``k`` sees ``images[:, img_feature_div[k]:img_feature_div[k+1]]``
    and uses ``LocalModels[k]``. Text client ``k`` sees the analogous slice of
    ``texts`` (bounded by ``text_feature_div``) and uses
    ``LocalModels[num_img_clients + k]``.

    Returns:
        List of model outputs: image clients first, then text clients.
    """
    image_outputs = [
        LocalModels[k](images[:, img_feature_div[k]:img_feature_div[k + 1]])
        for k in range(num_img_clients)
    ]
    text_outputs = [
        LocalModels[num_img_clients + k](texts[:, text_feature_div[k]:text_feature_div[k + 1]])
        for k in range(num_text_clients)
    ]
    return image_outputs + text_outputs


def _split_batch(images, texts, img_feature_div, text_feature_div):
    """Slice one batch into per-client feature blocks.

    Returns ``(ImagesDivs, TextsDivs)``. Image client ``i`` gets columns
    ``img_feature_div[i]:img_feature_div[i+1]`` of ``images``; text clients are
    split the same way using ``text_feature_div``.
    """
    ImagesDivs = [images[:, img_feature_div[i]:img_feature_div[i + 1]]
                  for i in range(len(img_feature_div) - 1)]
    TextsDivs = [texts[:, text_feature_div[i]:text_feature_div[i + 1]]
                 for i in range(len(text_feature_div) - 1)]
    return ImagesDivs, TextsDivs


for run_indx, Trial in enumerate(runs_list):
    if exp_type == 4:
        eps = 0.1  # 0.5, 0.1
        img_feature_div = [0, 225, 634]
        text_feature_div = [0, 500, 1000]
        if RAE_DEFEND:
            prefix = 'clean_4cli_rae_225_pretrain200_try'
            SavedModels = [
                'nus_results/local_image_0_225/checkpoints',
                'nus_results/local_image_225_634/checkpoints',
                'nus_results/local_text_0_500/checkpoints',
                'nus_results/local_text_500_1000/checkpoints',
            ]
            RAE_SavedModel = 'nus_results/clean_4cli_rae_225_pretrain200_try{}/rae_ckpt.tf'.format(Trial)
            Server_SavedModel = 'nus_results/clean_new_4cli_server_raeTrue_try{}/best_server_ckpt.tf'.format(Trial)
        else:
            # prefix = 'clean_2img_2text_bs64_try'
            # prefix = 'clean_lap0.05_bsl_2img_2text_bs64_try'
            prefix = 'clean_spar99.9_bsl_2img_2text_bs64_try'
            SavedModels = [
                'nus_results/{}{}/img_0_225/epoch100_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/img_225_634/epoch100_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_0_500/epoch100_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_500_1000/epoch100_checkpoints'.format(prefix, Trial),
            ]
            Server_SavedModel = 'nus_results/{}{}/server/epoch100_checkpoints'.format(prefix, Trial)

        baseline_prefix = 'clean_2img_2text_bs64_try'
        Baseline_SavedModels = [
            'nus_results/{}{}/img_0_225/epoch100_checkpoints'.format(baseline_prefix, Trial),
            'nus_results/{}{}/img_225_634/epoch100_checkpoints'.format(baseline_prefix, Trial),
            'nus_results/{}{}/text_0_500/epoch100_checkpoints'.format(baseline_prefix, Trial),
            'nus_results/{}{}/text_500_1000/epoch100_checkpoints'.format(baseline_prefix, Trial),
        ]
        Baseline_Server_SavedModel = 'nus_results/{}{}/server/epoch100_checkpoints'.format(baseline_prefix, Trial)

    elif exp_type == 2:
        eps = 0.01  # 0.01 ; 0.05
        img_feature_div = [0, 634]
        text_feature_div = [0, 1000]
        if RAE_DEFEND:
            prefix = 'clean_2cli_rae_pretrain100_try'
            SavedModels = [
                'nus_results/local_image_0_634/checkpoints',
                'nus_results/local_text_0_1000/checkpoints',
            ]
            RAE_SavedModel = 'nus_results/clean_2cli_rae_pretrain100_try{}/rae_ckpt.tf'.format(Trial)
            Server_SavedModel = 'nus_results/clean_new_2cli_server_raeTrue_try{}/best_server_ckpt.tf'.format(Trial)
        else:
            # prefix = 'clean_1img_1text_bs64_try'  # use epoch100_checkpoints
            prefix = 'clean_lap0.05_bsl_1img_1text_bs64_try'  # use checkpoints
            # prefix = 'clean_spar99.9_bsl_1img_1text_bs64_try'  # use checkpoints
            # NOTE(review): the prefixes marked "use checkpoints" are still loaded
            # from epoch100_checkpoints below -- confirm the intended directory.
            SavedModels = [
                'nus_results/{}{}/img_0_634/epoch100_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_0_1000/epoch100_checkpoints'.format(prefix, Trial),
            ]
            Server_SavedModel = 'nus_results/{}{}/server/epoch100_checkpoints'.format(prefix, Trial)

        baseline_prefix = 'clean_1img_1text_bs64_try'
        Baseline_SavedModels = [
            'nus_results/{}{}/img_0_634/epoch100_checkpoints'.format(baseline_prefix, Trial),
            'nus_results/{}{}/text_0_1000/epoch100_checkpoints'.format(baseline_prefix, Trial),
        ]
        Baseline_Server_SavedModel = 'nus_results/{}{}/server/epoch100_checkpoints'.format(baseline_prefix, Trial)

    else:
        raise ValueError('unsupported exp_type: {} (expected 2 or 4)'.format(exp_type))

    num_img_clients = len(img_feature_div) - 1
    num_text_clients = len(text_feature_div) - 1
    num_clients = num_img_clients + num_text_clients

    # Baseline models: used only to craft the adversarial perturbation.
    Baseline_LocalModels = []
    for i in range(num_clients):
        local_model = VFLPassiveModel()
        local_model.load_weights(Baseline_SavedModels[i])
        Baseline_LocalModels.append(local_model)
    Baseline_server_model = newVFLActiveModelWithOneLayer()
    Baseline_server_model.load_weights(Baseline_Server_SavedModel)

    # Models under evaluation.
    LocalModels = []
    for i in range(num_clients):
        local_model = VFLPassiveModel()
        local_model.load_weights(SavedModels[i])
        LocalModels.append(local_model)
    server_model = newVFLActiveModelWithOneLayer()
    server_model.load_weights(Server_SavedModel)
    if RAE_DEFEND:
        rae = RAE()
        rae.load_weights(RAE_SavedModel)

    loss_object = tf.keras.losses.CategoricalCrossentropy()

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    clean_test_loss = tf.keras.metrics.Mean(name='test_loss')
    clean_test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    # Clean accuracy: whole test set, no perturbation, no missing client.
    # NOTE(review): the RAE is not applied here even when RAE_DEFEND is True.
    LocalOutputs = get_local_outputs(num_img_clients, image_test, img_feature_div,
                                     num_text_clients, text_test, text_feature_div, LocalModels)
    active_output = server_model(LocalOutputs)
    clean_test_loss(loss_object(y_test, active_output))
    clean_test_accuracy(y_test, active_output)

    miss_results_test_acc[-2][run_indx] = clean_test_accuracy.result() * 100

    template = 'Run: {} , Clean Test Loss: {}, Test Accuracy: {}'
    print(template.format(Trial,
                          clean_test_loss.result(),
                          clean_test_accuracy.result() * 100))

    clean_test_loss.reset_states()
    clean_test_accuracy.reset_states()

    # FGSM perturbation of the first image client (the adversary). It depends
    # only on the batch and the baseline models, not on MISS_INDEX, so it is
    # crafted once per batch and reused for every MISS_INDEX below.
    adv_first_images = []
    for (images, texts), labels in test_ds:
        ImagesDivs, TextsDivs = _split_batch(images, texts, img_feature_div, text_feature_div)
        perturbations = create_adversarial_pattern(ImagesDivs, TextsDivs, labels,
                                                   Baseline_LocalModels, Baseline_server_model)
        adv_first_images.append(tf.clip_by_value(ImagesDivs[0] + eps * perturbations, -1, 1))

    for MISS_INDEX in MISS_INDEX_list:
        # test_ds is not shuffled, so batches line up with adv_first_images.
        for ((images, texts), labels), adv_first in zip(test_ds, adv_first_images):
            ImagesDivs, TextsDivs = _split_batch(images, texts, img_feature_div, text_feature_div)
            ImagesDivs[0] = adv_first

            LocalOutputs = [LocalModels[i](ImagesDivs[i]) for i in range(num_img_clients)]
            LocalOutputs += [LocalModels[i + num_img_clients](TextsDivs[i])
                             for i in range(num_text_clients)]
            if MISS_INDEX > 0:
                # Client MISS_INDEX (1-based) drops out: zero its 32-dim embedding.
                # float32 matches the other embeddings, which keeps the
                # concatenation in RAE_split from being promoted to float64.
                LocalOutputs[MISS_INDEX - 1] = np.zeros([images.shape[0], 32], dtype=np.float32)

            if RAE_DEFEND:
                H_rec2 = RAE_split(LocalOutputs, rae, MISS_INDEX=MISS_INDEX, exp_type=exp_type)
                # The defended run must score the RAE reconstruction, not the raw
                # (perturbed / zeroed) embeddings.
                LocalOutputs = tf.split(H_rec2, num_clients, 1)

            active_output = server_model(LocalOutputs)
            test_loss(loss_object(labels, active_output))
            test_accuracy(labels, active_output)

        template = 'Adv Example eps {} Run: {} , Use RAE: {}  MISS_INDEX: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(eps, Trial, RAE_DEFEND, MISS_INDEX,
                              test_loss.result(),
                              test_accuracy.result() * 100))

        miss_results_test_acc[MISS_INDEX][run_indx] = test_accuracy.result() * 100

        test_loss.reset_states()
        test_accuracy.reset_states()

# Summary over runs: clean accuracy (-2) first, then each MISS_INDEX.
template = '{} eps {} , Stats Use RAE: {} MISS_INDEX: {},  Test Accuracy: {:.4f} +- {:.4f}'
for index in [-2] + MISS_INDEX_list:
    print(template.format(prefix, eps, RAE_DEFEND, index,
                          np.mean(miss_results_test_acc[index]),
                          np.std(miss_results_test_acc[index])))

print("all time spent: ", time.time() - all_start)
