
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import csv
import os
from io import BytesIO
import time
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle
import copy
from nus_wide_data_util import *

# Reconstruction criterion used inside RAE_split (mean squared error).
loss_RAE = tf.keras.losses.MeanSquaredError()

# Learning-rate schedule handed to each SGD optimizer created in RAE_split:
# lr(step) = 1e-2 * 0.99 ** (step / 10000), decayed smoothly (no staircase).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.99,
)


def get_local_outputs(num_img_clients, images, img_feature_div,
        num_text_clients, texts, text_feature_div, LocalModels):
    """Run every client model on its own column slice of the inputs.

    Image client ``k`` (``k < num_img_clients``) receives
    ``images[:, img_feature_div[k]:img_feature_div[k + 1]]`` and is served by
    ``LocalModels[k]``. Text client ``j`` then receives
    ``texts[:, text_feature_div[j]:text_feature_div[j + 1]]`` and is served by
    ``LocalModels[num_img_clients + j]``.

    Returns:
        A list with one model output per client: image clients first, then
        text clients.
    """
    feature_chunks = [images[:, img_feature_div[k]:img_feature_div[k + 1]]
                      for k in range(num_img_clients)]
    feature_chunks += [texts[:, text_feature_div[k]:text_feature_div[k + 1]]
                       for k in range(num_text_clients)]
    # Models are indexed positionally so client order matches LocalModels order.
    return [LocalModels[k](chunk) for k, chunk in enumerate(feature_chunks)]


# Width of the server's softmax output (five labels are requested in top_k below).
class_num = 5


class newVFLActiveModelWithOneLayer(Model):
    """Server-side (active party) head: a 32-unit ReLU layer, then a softmax over ``class_num`` classes."""

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are left unchanged: weights are
        # restored with load_weights, and TF-format checkpoints match
        # sub-layers by attribute name.
        # NOTE(review): `concatenated` is never used in call(); kept so the
        # set of tracked layers stays identical.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(32, name="dense1", activation='relu')
        self.out = Dense(class_num, name="out", activation='softmax')

    def call(self, x):
        hidden = self.d1(x)
        return self.out(hidden)

class VFLPassiveModel(Model):
    """Client-side (passive party) encoder: flatten the input, then a 32-unit ReLU projection."""

    def __init__(self):
        super().__init__()
        # Attribute names kept as-is: checkpoints restored via load_weights
        # are matched by attribute name.
        self.flatten = Flatten()
        self.d1 = Dense(32, name="dense1", activation='relu')

    def call(self, x):
        return self.d1(self.flatten(x))

all_start = time.time()  # wall-clock start; the total is printed as "all time spent" at the end

def RAE_split(LocalOutputs, rae, MISS_INDEX=-1, exp_type=4, num_steps=20):
    """Recover client embeddings by optimizing the RAE input on the observed clients.

    The per-client embeddings in ``LocalOutputs`` are concatenated into a
    target ``h``. A trainable copy ``L`` of ``h`` is updated with SGD (using
    the module-level ``lr_schedule``) for ``num_steps`` steps. Each step
    minimizes the sum of per-client RMSEs between ``rae(L)`` and ``h``, where
    the sum runs over every *observed* client block. The block of the missing
    client is excluded from the loss, so its reconstruction comes only from
    the other clients.

    Args:
        LocalOutputs: list of per-client embeddings. Each is a 2-D
            array/tensor; its width along axis 1 defines that client's column
            block in ``h``.
        rae: model whose call returns ``(reconstruction, hidden)``.
        MISS_INDEX: 1-based index of the missing client. A value outside
            ``1..len(LocalOutputs)`` (e.g. -1) means no client is missing.
        exp_type: retained for backward compatibility only. The block layout
            is now derived from ``LocalOutputs``; previously 32-wide blocks
            for exactly 2 or 4 clients were hard-coded.
        num_steps: number of SGD updates applied to ``L`` (default 20, as before).

    Returns:
        The RAE reconstruction of the optimized input, computed after the
        final update.

    Raises:
        ValueError: if no observed client block remains to fit.
    """
    h = np.concatenate(tuple(LocalOutputs), axis=1)
    L = tf.Variable(h, trainable=True)
    optimizer2 = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

    # Column range [lo, hi) of each client's embedding inside h.
    bounds = []
    offset = 0
    for output in LocalOutputs:
        width = int(output.shape[1])
        bounds.append((offset, offset + width))
        offset += width

    # The caller zeroes LocalOutputs[MISS_INDEX - 1]; that block must not be fitted.
    missing = MISS_INDEX - 1 if 1 <= MISS_INDEX <= len(bounds) else None
    observed = [b for i, b in enumerate(bounds) if i != missing]
    if not observed:
        raise ValueError("RAE_split needs at least one observed client embedding")

    for _step in range(num_steps):
        with tf.GradientTape() as passive_tape:
            RAE_output, _hidden = rae(L)
            # Sum of per-client RMSE over the observed blocks only (same
            # argument order and summation order as the original formula).
            loss = sum(tf.sqrt(loss_RAE(RAE_output[:, lo:hi], h[:, lo:hi]))
                       for lo, hi in observed)
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        optimizer2.apply_gradients(zip(passive_RAE_L_gradients, [L]))

    # Re-run the RAE so the result reflects the last update. Previously the
    # reconstruction from before the final step was returned, wasting that step.
    RAE_output, _hidden = rae(L)
    return RAE_output


# Concept labels requested from get_labeled_data.
top_k = ['buildings', 'grass', 'animal', 'water', 'person']

train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
# TODO: the test split is capped at 40000 samples; the original note said to change it to 60000.
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 40000, 'Test')

# Convert once to float32 arrays; x_* are (image_features, text_features) pairs.
x_train = (np.array(train_X_image).astype('float32'), np.array(train_X_text).astype('float32'))
x_test = (np.array(test_X_image).astype('float32'), np.array(test_X_text).astype('float32'))
y_train = np.array(train_Y).astype('float32')
y_test = np.array(test_Y).astype('float32')


(image_test, text_test) = x_test
# Test samples whose last text feature equals 1 form the backdoor set.
# The mask is computed once and reused for images, texts and labels.
backdoor_mask = text_test[:, -1] == 1
image_backdoor = image_test[backdoor_mask]
text_backdoor = text_test[backdoor_mask]
y_backdoor = copy.deepcopy(y_test[backdoor_mask])
sample_id_need_copy = 369
text_feat_need_copy = copy.deepcopy(x_train[1][sample_id_need_copy])
# Every backdoor sample is relabelled with the label of training sample 369 (the target label).
y_backdoor[:] = y_train[sample_id_need_copy]
mode_need_train_list = ['backdoor']


print("target label: ", y_train[sample_id_need_copy])
print("text_feat_need_copy: ", x_train[1][sample_id_need_copy])
# Index the converted arrays so boolean masking works even if the loader returns lists.
print("num of train backdoor sample: ", x_train[0][x_train[1][:, -1] == 1].shape, np.sum(x_train[1][:, -1]))  # 152.0
print("num of test backdoor sample: ", y_backdoor.shape, np.sum(x_test[1][:, -1]))  # 102.0 -- TODO: original note: "it's not strong enough"

# Experiment switches:
# - RAE_DEFEND: recover embeddings through the RAE before the server model.
# - exp_type: selects the 4-client (2 image + 2 text) or 2-client layout.
RAE_DEFEND = False
exp_type = 2

# Which client to drop (-1 = none) and the RAE output width (32 per client).
MISS_INDEX_list, RAE_Output_Dim = ([-1, 1, 2, 3], 128) if exp_type == 4 else ([-1, 1], 64)

# Trial ids of the saved checkpoints to evaluate.
# NOTE: an earlier variant used runs [21, 22, 23] when RAE_DEFEND is False.
runs_list = [1, 2, 3]
num_runs = len(runs_list)

# Per-MISS_INDEX accuracy arrays, one slot per run, filled by the evaluation loop.
miss_results_test_acc = {miss_idx: np.zeros(num_runs) for miss_idx in MISS_INDEX_list}
miss_results_bkd_acc = {miss_idx: np.zeros(num_runs) for miss_idx in MISS_INDEX_list}

class RAE(Model):
    """Autoencoder over the concatenated client embeddings (used by RAE_split).

    Forward pass: d3 -> d1 -> layer normalization -> d4 -> d2. ``call``
    returns ``(reconstruction, normalized_hidden)``. The reconstruction has
    ``RAE_Output_Dim`` features.
    """

    def __init__(self):
        super(RAE, self).__init__()
        # Attribute names are kept unchanged: TF-format checkpoints restored
        # with load_weights match sub-layers by attribute name.
        # NOTE(review): d1, d3 and d4 all carry the Keras name "dense1"; left
        # as-is to avoid disturbing existing checkpoints.
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_Output_Dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense1", activation='relu')
        self.d4 = Dense(64, name="dense1", activation='relu')
        # Built once here instead of inside call(). The old code created a new
        # LayerNormalization on every forward pass. Its gamma was freshly
        # initialized to ones and never tracked by the model, so it was not in
        # trainable_variables or checkpoints. That leaked variables per call
        # and fails under tf.function. With scale=False the math is identical
        # (multiplying by ones is exact), and no new checkpoint variables are added.
        self.normalize = LayerNormalization(axis=-1, center=False, scale=False)

    def call(self, x):
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.normalize(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def _server_input(images, texts, miss_index, img_feature_div, text_feature_div,
                  local_models, rae_model=None):
    """Build the server-model input for one data split.

    Each client model in ``local_models`` embeds its own feature slice (via
    ``get_local_outputs``). When ``miss_index > 0``, the embedding of client
    ``miss_index - 1`` is replaced by a zero block of width 32 to simulate a
    missing party. With ``rae_model`` the embeddings are recovered through
    ``RAE_split``; otherwise they are concatenated along axis 1.
    """
    outputs = get_local_outputs(len(img_feature_div) - 1, images, img_feature_div,
                                len(text_feature_div) - 1, texts, text_feature_div,
                                local_models)
    if miss_index > 0:
        outputs[miss_index - 1] = np.zeros([images.shape[0], 32])
    if rae_model is not None:
        return RAE_split(outputs, rae_model, MISS_INDEX=miss_index, exp_type=exp_type)
    return np.concatenate(tuple(outputs), axis=1)


for run_indx in range(0, num_runs):
    Trial = runs_list[run_indx]
    if exp_type == 4:
        img_feature_div = [0, 225, 634]
        text_feature_div = [0, 500, 1000]
        if RAE_DEFEND:
            prefix = '4cli_rae_225_pretrain200_try'
            SavedModels = [
                'nus_results/local_image_0_225/checkpoints',
                'nus_results/local_image_225_634/checkpoints',
                'nus_results/local_text_0_500/checkpoints',
                'nus_results/poison_text_500_1000_trigger-1/checkpoints',
            ]
            RAE_SavedModel = 'nus_results/4cli_rae_225_pretrain200_try{}/rae_ckpt.tf'.format(Trial)
            Server_SavedModel = 'nus_results/new_4cli_server_raeTrue_try{}/best_server_ckpt.tf'.format(Trial)
        else:
            # Alternative baselines: 'bsl_2img_2text_bs64_try', 'spar99.9_bsl_2img_2text_bs64_try'.
            prefix = 'lap0.05_bsl_2img_2text_bs64_try'
            SavedModels = [
                'nus_results/{}{}/img_0_225/best_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/img_225_634/best_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_0_500/best_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_500_1000/best_checkpoints'.format(prefix, Trial),
            ]
            Server_SavedModel = 'nus_results/{}{}/server/best_checkpoints'.format(prefix, Trial)

    elif exp_type == 2:
        img_feature_div = [0, 634]
        text_feature_div = [0, 1000]
        if RAE_DEFEND:
            prefix = '2cli_rae_pretrain100_try'
            SavedModels = [
                'nus_results/local_image_0_634/checkpoints',
                'nus_results/poison_text_0_1000_trigger-1/checkpoints',
            ]
            RAE_SavedModel = 'nus_results/2cli_rae_pretrain100_try{}/rae_ckpt.tf'.format(Trial)
            Server_SavedModel = 'nus_results/new_2cli_server_raeTrue_try{}/best_server_ckpt.tf'.format(Trial)
        else:
            # 'bsl_1img_1text_bs64_try' uses the epoch100_checkpoints directories.
            # The alternatives 'lap0.05_bsl_1img_1text_bs64_try' and
            # 'spar99.9_bsl_1img_1text_bs64_try' use plain 'checkpoints' directories instead.
            prefix = 'bsl_1img_1text_bs64_try'
            SavedModels = [
                'nus_results/{}{}/img_0_634/epoch100_checkpoints'.format(prefix, Trial),
                'nus_results/{}{}/text_0_1000/epoch100_checkpoints'.format(prefix, Trial),
            ]
            Server_SavedModel = 'nus_results/{}{}/server/epoch100_checkpoints'.format(prefix, Trial)

    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('unsupported exp_type: {} (expected 2 or 4)'.format(exp_type))

    num_img_clients = len(img_feature_div) - 1
    num_text_clients = len(text_feature_div) - 1
    num_clients = num_img_clients + num_text_clients

    # Restore one passive (client) model per feature slice, in client order.
    LocalModels = []
    for i in range(num_clients):
        local_model = VFLPassiveModel()
        local_model.load_weights(SavedModels[i])
        LocalModels.append(local_model)

    new_active_model = newVFLActiveModelWithOneLayer()
    new_active_model.load_weights(Server_SavedModel)
    rae = None
    if RAE_DEFEND:
        rae = RAE()
        rae.load_weights(RAE_SavedModel)

    loss_object = tf.keras.losses.CategoricalCrossentropy()

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    backdoor_loss = tf.keras.metrics.Mean(name='backdoor_loss')
    backdoor_accuracy = tf.keras.metrics.CategoricalAccuracy(name='backdoor_accuracy')

    print("backdoor data", image_backdoor.shape, text_backdoor.shape)
    for MISS_INDEX in MISS_INDEX_list:
        # Backdoor split: triggered test samples, all relabelled with the target label.
        H_rec2 = _server_input(image_backdoor, text_backdoor, MISS_INDEX,
                               img_feature_div, text_feature_div, LocalModels, rae)
        active_output = new_active_model(H_rec2)
        backdoor_loss(loss_object(y_backdoor, active_output))
        backdoor_accuracy(y_backdoor, active_output)

        # Clean test split.
        H_rec2 = _server_input(image_test, text_test, MISS_INDEX,
                               img_feature_div, text_feature_div, LocalModels, rae)
        active_output = new_active_model(H_rec2)
        test_loss(loss_object(y_test, active_output))
        test_accuracy(y_test, active_output)

        template = 'Run: {} , Use RAE: {} MISS_INDEX: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
        print(template.format(Trial, RAE_DEFEND, MISS_INDEX,
                              test_loss.result(),
                              test_accuracy.result() * 100,
                              backdoor_accuracy.result() * 100))

        miss_results_test_acc[MISS_INDEX][run_indx] = test_accuracy.result() * 100
        miss_results_bkd_acc[MISS_INDEX][run_indx] = backdoor_accuracy.result() * 100

        # Metrics accumulate across calls; clear them before the next MISS_INDEX.
        test_loss.reset_states()
        test_accuracy.reset_states()
        backdoor_loss.reset_states()
        backdoor_accuracy.reset_states()
       
# Mean and standard deviation over runs, per MISS_INDEX setting.
summary_template = '{} Stats Use RAE: {} MISS_INDEX: {},  Test Accuracy: {:.4f} +- {:.4f}, Backdoor Accuracy: {:.4f} +-{:.4f}'
for miss_idx in MISS_INDEX_list:
    test_accs = miss_results_test_acc[miss_idx]
    bkd_accs = miss_results_bkd_acc[miss_idx]
    print(summary_template.format(
        prefix, RAE_DEFEND, miss_idx,
        np.mean(test_accs), np.std(test_accs),
        np.mean(bkd_accs), np.std(bkd_accs),
    ))

print("all time spent: ", time.time() - all_start)


