
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import csv
import os
from io import BytesIO
import time
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle
import copy
from nus_wide_data_util import *
from model import VFLPassiveModelCIFAR
from utils import data_poison_cifar

# Reconstruction loss used when refining the RAE input in RAE_split.
loss_RAE = tf.keras.losses.MeanSquaredError()

# Learning-rate schedule handed to every SGD optimizer built in RAE_split:
# starts at 1e-2 and decays smoothly by a factor of 0.99 per 10000 steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_rate=0.99,
    decay_steps=10000,
)


# Number of output classes of the active (server) model's softmax layer.
class_num = 10
class newVFLActiveModelWithOneLayer(Model):
    """Active-party (server) head: Dense(32, ReLU) followed by a softmax classifier.

    ``call`` receives the merged client embeddings and returns ``class_num``
    class probabilities per row. Attribute names (``concatenated``, ``d1``,
    ``out``) are left unchanged so checkpoints restored with ``load_weights``
    keep matching.
    """

    def __init__(self):
        super().__init__()
        # Not used by ``call``; kept so the tracked sub-layer layout is unchanged.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(32, name="dense1", activation='relu')
        self.out = Dense(class_num, name="out", activation='softmax')

    def call(self, x):
        hidden = self.d1(x)
        probabilities = self.out(hidden)
        return probabilities

class VFLPassiveModel(Model):
    """Passive-party bottom model: flatten the local features, then Dense(32, ReLU).

    Note: the evaluation loop in this script instantiates ``VFLPassiveModelCIFAR``
    rather than this class.
    """

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(32, name="dense1", activation='relu')

    def call(self, x):
        return self.d1(self.flatten(x))

all_start = time.time()

# Feature split per experiment: exp_type is the number of clients, and client i
# owns the slice feature_div[i]:feature_div[i+1] along axis 1 of each image batch.
_FEATURE_DIVS = {
    2: [0, 16, 32],
    # 4: [0, 8, 16, 24, 32],
    3: [0, 9, 19, 32],
}

exp_type = 2
if exp_type not in _FEATURE_DIVS:
    # Fail fast with a clear message instead of a NameError on feature_div below.
    raise ValueError(
        "unsupported exp_type={}; expected one of {}".format(exp_type, sorted(_FEATURE_DIVS))
    )
feature_div = _FEATURE_DIVS[exp_type]
num_clients = len(feature_div) - 1
RAE_DEFEND = True

# Width of each client's embedding; the RAE output holds one such chunk per client.
EMBDDING_DIM = 10
RAE_out_dim = EMBDDING_DIM * num_clients

class RAE(Model):
    """Autoencoder-style network applied to the concatenated client embeddings.

    Forward path: d3 -> d1 -> LayerNormalization -> d4 -> d2. ``call`` returns
    ``(reconstruction, normalized_hidden)``; the reconstruction has
    ``RAE_out_dim`` columns.
    """

    def __init__(self):
        super(RAE, self).__init__()

        # NOTE(review): d1, d3 and d4 share the Keras layer name "dense1". Names are
        # left as-is so existing checkpoints keep loading; TF-format checkpoints are
        # matched by attribute path (d1..d4), not by layer name.
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_out_dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense1", activation='relu')
        self.d4 = Dense(64, name="dense1", activation='relu')
        # Built once here rather than inside ``call``. A layer constructed in ``call``
        # is recreated on every forward pass with a fresh gamma of ones, is never
        # tracked or saved, and is incompatible with tf.function. If a checkpoint
        # holds no weights for this layer, gamma stays at its initial ones, which
        # reproduces the previous per-call behaviour.
        self.norm = LayerNormalization(axis=-1, center=False, scale=True)

    def call(self, x):
        """Return ``(reconstruction, normalized_hidden)`` for input ``x``."""
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.norm(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def RAE_split(LocalOutputs, rae, MISS_INDEX=-1):
    """Refine the concatenated client embeddings through ``rae`` and return its output.

    The client embeddings are concatenated into ``h``, which seeds a trainable
    copy ``L``. For two SGD steps (learning rate from ``lr_schedule``), ``L`` is
    passed through ``rae``. ``L`` is then updated to minimise the sum, over the
    clients that are present, of the RMSE between that client's columns of the
    RAE output and the same columns of ``h``. Client ``k`` (1-based) is excluded
    from the loss when ``MISS_INDEX == k``. Any value that matches no client
    (e.g. -1) keeps every client.

    Args:
        LocalOutputs: per-client 2-D embeddings with matching row counts. They
            are concatenated along axis 1 in list order.
        rae: model whose call returns ``(reconstruction, hidden)``. The
            reconstruction is assumed to use the same column layout as the
            concatenated embeddings.
        MISS_INDEX: 1-based index of the missing client, or -1 for none.

    Returns:
        The RAE reconstruction from the last forward pass.

    Raises:
        ValueError: if no client remains to compute the loss on.
    """
    _num_clients = len(LocalOutputs)

    # Column range [start, stop) of each client's embedding inside the concatenation.
    bounds = []
    start = 0
    for out in LocalOutputs:
        stop = start + int(out.shape[1])
        bounds.append((start, stop))
        start = stop
    present = [b for i, b in enumerate(bounds) if i != MISS_INDEX - 1]
    if not present:
        raise ValueError(
            "RAE_split: no client left to compute the loss on "
            "(clients={}, MISS_INDEX={})".format(_num_clients, MISS_INDEX)
        )

    h = np.concatenate(tuple(LocalOutputs), axis=1)

    L = tf.Variable(h, trainable=True)
    optimizer2 = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

    for _ in range(2):
        with tf.GradientTape() as passive_tape:
            RAE_output, _layer_output = rae(L)
            # Summed left to right, so the loss is numerically identical to the
            # previous per-client-count expressions.
            loss = 0
            for lo, hi in present:
                loss += tf.sqrt(loss_RAE(RAE_output[:, lo:hi], h[:, lo:hi]))
        # Compute gradients after leaving the recording context.
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        optimizer2.apply_gradients(zip(passive_RAE_L_gradients, [L]))

    # NOTE(review): RAE_output comes from the forward pass *before* the final
    # apply_gradients, so the last update to L does not affect the result.
    # Behaviour kept as-is for reproducibility of existing results.
    return RAE_output

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
print("train", train_images.shape, "test",  test_images.shape)
test_images = test_images.astype('float32') / 255.0
clean_test_images = copy.deepcopy(test_images)
test_images, test_poison_list = data_poison_cifar(test_images, 10000)
# Boolean mask of poisoned test samples. Checking membership against a set makes
# this O(N) instead of O(N * len(test_poison_list)).
# Assumes test_poison_list is a flat sequence of indices (TODO confirm against data_poison_cifar).
test_poison_ids = set(test_poison_list)
test_need_poison_list = [indx in test_poison_ids for indx in range(len(test_images))]
test_backdoor_images = test_images[test_need_poison_list]
test_backdoor_labels = copy.deepcopy(test_labels[test_need_poison_list])  # overwritten below
sample_id_need_copy = 1
feat_need_copy = copy.deepcopy(train_images[sample_id_need_copy, feature_div[-2]:feature_div[-1]])
# Every backdoor sample is relabelled with the label of the copied training sample.
test_backdoor_labels[:] = train_labels[sample_id_need_copy]
print('the label of the sample need copy = ', train_labels[sample_id_need_copy])
print(test_backdoor_labels)


runs_list = [1, 2, 3]
# runs_list = list(range(1, 12))
num_runs = len(runs_list)

# Which client's embedding is zeroed in each evaluation pass:
# -1 means no client is missing; k > 0 means client k (1-based) is missing.
if exp_type == 2:
    MISS_INDEX_list = [-1, 1]
elif exp_type == 3:
    MISS_INDEX_list = [-1, 1, 2]

# Accuracy (percent) per MISS_INDEX, one slot per run; filled by the main loop.
miss_results_test_acc = {index: np.zeros(num_runs) for index in MISS_INDEX_list}
miss_results_bkd_acc = {index: np.zeros(num_runs) for index in MISS_INDEX_list}

print(runs_list)

def _checkpoint_paths(trial):
    """Return ``(prefix, SavedModels, RAE_SavedModel, Server_SavedModel)`` for ``trial``.

    ``SavedModels`` holds one bottom-model checkpoint per client, in client order.
    ``RAE_SavedModel`` is None when ``RAE_DEFEND`` is False. The layout is chosen
    from the module-level ``RAE_DEFEND`` and ``exp_type`` settings.

    Raises:
        ValueError: if no checkpoint layout is configured for ``exp_type``.
    """
    if RAE_DEFEND:
        if exp_type == 2:
            # Alternative "pre100" run:
            #   prefix = 'cifar_2cli_server_raeTrue_pre100_try'
            #   RAE    = 'rae_results/cifar_2cli_rae_p600_pre100_try{}/rae_best_ckpt.tf'
            #   server = 'rae_results/cifar_2cli_server_raeTrue_pre100_try{}/best_server_ckpt.tf'
            prefix = 'cifar_2cli_server_raeTrue_try'
            saved_models = [
                'img_results/cifar_qtrain_0_16/checkpoints',
                'img_results/debug_cifar_qtrain_poison_16_32_p600/checkpoints',
            ]
            rae_path = 'rae_results/cifar_2cli_rae_p600_try{}/rae_best_ckpt.tf'.format(trial)
            server_path = 'rae_results/cifar_2cli_server_raeTrue_try{}/best_server_ckpt.tf'.format(trial)
        elif exp_type == 3:
            prefix = 'cifar_3cli_server_raeTrue_pre100_try'
            saved_models = [
                'img_results/cifar_qtrain_0_9/best_ckpt',
                'img_results/cifar_qtrain_9_19/best_ckpt',
                'img_results/debug_cifar_qtrain_poison_19_32_p100/best_ckpt',
            ]
            rae_path = 'rae_results/cifar_3cli_rae_p100_pre100_try{}/rae_best_ckpt.tf'.format(trial)
            server_path = 'rae_results/cifar_3cli_server_raeTrue_pre100_try{}/best_server_ckpt.tf'.format(trial)
        else:
            raise ValueError('no RAE checkpoint layout configured for exp_type={}'.format(exp_type))
        return prefix, saved_models, rae_path, server_path

    # Baselines without the RAE defence.
    if exp_type == 3:
        prefix = 'cifar_bsl_3cli_backdoor_laplace_0.0001_p100_try'
        # Other 3-client baselines:
        #   'cifar_bsl_3cli_backdoor_sparsification_95_p100_try'
        #   'cifar_bsl_3cli_backdoor_with_amplify_rate_1_p100_try'
        root, parts = 'img_results', ['0_9', '9_19', '19_32']
    elif exp_type == 2:
        prefix = 'cifar_bsl_2cli_backdoor_laplace_0.0001_p600_try'
        # Other 2-client baselines:
        #   'cifar_bsl_2cli_backdoor_sparsification_90_p600_try'
        #   'cifar_bsl_2cli_backdoor_sparsification_95_p600_try'
        #   'cifar_bsl_2cli_backdoor_with_amplify_rate_1_p600_try'
        # (an 'img_results' root was the commented-out alternative here)
        root, parts = 'cifar_results', ['0_16', '16_32']
    else:
        raise ValueError('no baseline checkpoint layout configured for exp_type={}'.format(exp_type))
    run_dir = '{}/{}{}'.format(root, prefix, trial)
    saved_models = ['{}/{}/best_checkpoints'.format(run_dir, part) for part in parts]
    server_path = '{}/server/best_checkpoints'.format(run_dir)
    return prefix, saved_models, None, server_path


def _merged_embeddings(images, local_models, miss_index, rae=None):
    """Embed each client's feature slice and merge the per-client outputs.

    Client ``i`` sees ``images[:, feature_div[i]:feature_div[i+1]]``. When
    ``miss_index > 0``, client ``miss_index`` (1-based) is replaced with a zero
    embedding of width ``EMBDDING_DIM``. The embeddings are then refined by
    ``RAE_split`` when ``rae`` is given, otherwise concatenated along axis 1.
    """
    embeddings = [
        model(images[:, feature_div[i]:feature_div[i + 1]])
        for i, model in enumerate(local_models)
    ]
    if miss_index > 0:
        embeddings[miss_index - 1] = np.zeros([images.shape[0], EMBDDING_DIM], dtype=np.float64)
    if rae is not None:
        return RAE_split(embeddings, rae, MISS_INDEX=miss_index)
    return np.concatenate(tuple(embeddings), axis=1)


for run_indx, Trial in enumerate(runs_list):
    prefix, SavedModels, RAE_SavedModel, Server_SavedModel = _checkpoint_paths(Trial)
    print(prefix)

    LocalModels = []
    for path in SavedModels:
        local_model = VFLPassiveModelCIFAR()
        local_model.load_weights(path)
        LocalModels.append(local_model)

    new_active_model = newVFLActiveModelWithOneLayer()
    new_active_model.load_weights(Server_SavedModel)
    rae = None
    if RAE_DEFEND:
        rae = RAE()
        rae.load_weights(RAE_SavedModel)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    backdoor_loss = tf.keras.metrics.Mean(name='backdoor_loss')
    backdoor_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='backdoor_accuracy')

    for MISS_INDEX in MISS_INDEX_list:
        # Poisoned test subset, scored against the overwritten backdoor labels.
        H_rec2 = _merged_embeddings(test_backdoor_images, LocalModels, MISS_INDEX, rae)
        active_output = new_active_model(H_rec2)
        backdoor_loss(loss_object(test_backdoor_labels, active_output))
        backdoor_accuracy(test_backdoor_labels, active_output)

        # Clean (unpoisoned) test set, scored against the true labels.
        H_rec2 = _merged_embeddings(clean_test_images, LocalModels, MISS_INDEX, rae)
        active_output = new_active_model(H_rec2)
        test_loss(loss_object(test_labels, active_output))
        test_accuracy(test_labels, active_output)

        template = 'Run: {}, Use RAE: {} MISS_INDEX: {},  Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
        print(template.format(Trial, RAE_DEFEND, MISS_INDEX,
                              test_loss.result(),
                              test_accuracy.result() * 100,
                              backdoor_accuracy.result() * 100))

        miss_results_test_acc[MISS_INDEX][run_indx] = test_accuracy.result() * 100
        miss_results_bkd_acc[MISS_INDEX][run_indx] = backdoor_accuracy.result() * 100

        # Metrics accumulate across calls, so clear them before the next MISS_INDEX.
        test_loss.reset_states()
        test_accuracy.reset_states()
        backdoor_loss.reset_states()
        backdoor_accuracy.reset_states()



# Mean and standard deviation, across runs, of the accuracies for each MISS_INDEX.
template = '{} Stats Use RAE: {} MISS_INDEX: {},  Test Accuracy: {:.4f} +- {:.4f}, Backdoor Accuracy: {:.4f} +-{:.4f}'
for index in MISS_INDEX_list:
    test_accs = miss_results_test_acc[index]
    bkd_accs = miss_results_bkd_acc[index]
    summary = template.format(
        prefix, RAE_DEFEND, index,
        np.mean(test_accs), np.std(test_accs),
        np.mean(bkd_accs), np.std(bkd_accs),
    )
    print(summary)
elapsed = time.time() - all_start
print("all time spent: ", elapsed)
