import os
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten 
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import copy
from nus_wide_data_util import get_labeled_data
from nus_utils import *
from sklearn import metrics
import time

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
class_num = 5
BATCH_SIZE = 64
EPOCHS = 200

# Column boundaries of the feature blocks owned by each client. Consecutive
# entries [div[i], div[i+1]) delimit the slice fed to one local model, so a
# list of length k+1 defines k clients for that modality.
# Alternative split with two clients per modality:
#   img_feature_div = [0, 225, 634]
#   text_feature_div = [0, 500, 1000]
img_feature_div = [0, 634]
text_feature_div = [0, 1000]
print(img_feature_div, text_feature_div )

num_img_clients = len(img_feature_div) - 1
num_text_clients = len(text_feature_div) - 1
num_clients = num_img_clients + num_text_clients
# Dataset shapes noted by the original author:
#   image (60000, 634)
#   text  (60000, 1000)

# Select exactly one training mode. Other modes that have been used:
#   'backdoor_with_laplace_noise_0.1'
#   'backdoor_with_laplace_noise_0.05'
#   'backdoor_with_gradient_sparsification_99.9'
#   'backdoor_with_gradient_sparsification_99.5'
#   'normal'
#   'normal_with_gradient_sparsification_99.9'
#   'normal_with_laplace_noise_0.05'
training_mode = 'backdoor'

# Run index; it becomes part of the result directory name.
trial = 5



# Result-directory prefix for every training mode that has an output location
# configured. The prefix is followed by the client counts, batch size and
# trial index to form the directory name.
_RESULT_DIR_PREFIXES = {
    'backdoor_with_laplace_noise_0.1': 'lap0.1_bsl',
    'backdoor_with_laplace_noise_0.01': 'lap0.01_bsl',
    'backdoor_with_laplace_noise_0.05': 'lap0.05_bsl',
    'backdoor_with_gradient_sparsification_99.9': 'spar99.9_bsl',
    'backdoor_with_gradient_sparsification_99.5': 'spar99.5_bsl',
    'backdoor': 'bsl',
    'normal': 'clean',
    'normal_with_laplace_noise_0.05': 'clean_lap0.05_bsl',
    'normal_with_gradient_sparsification_99.9': 'clean_spar99.9_bsl',
}

# Fail early with a readable message instead of hitting a NameError on
# `directory` further down when the selected mode has no entry.
if training_mode not in _RESULT_DIR_PREFIXES:
    raise ValueError(
        "No result directory is configured for training_mode {!r}; "
        "expected one of: {}".format(
            training_mode, ', '.join(sorted(_RESULT_DIR_PREFIXES))))

directory = './nus_results/{}_{}img_{}text_bs{}_try{}'.format(
    _RESULT_DIR_PREFIXES[training_mode],
    num_img_clients, num_text_clients, BATCH_SIZE, trial)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(directory, exist_ok=True)
print("save to", directory)

# Checkpoint directories, one per client: image clients first, then text
# clients. This is the same order in which the local models are indexed when
# their weights are saved (SavedPaths[i] pairs with LocalModels[i]).
SavedPaths = []
for _prefix, _bounds, _count in (
        ('img', img_feature_div, num_img_clients),
        ('text', text_feature_div, num_text_clients)):
    for _k in range(_count):
        _dir = os.path.join(
            directory, '{}_{}_{}'.format(_prefix, _bounds[_k], _bounds[_k + 1]))
        # exist_ok replaces the race-prone "if not exists: makedirs" check.
        os.makedirs(_dir, exist_ok=True)
        SavedPaths.append(_dir)

# Checkpoint directory for the active (server) model.
server_savedpath = os.path.join(directory, 'server')
os.makedirs(server_savedpath, exist_ok=True)


# Labels used for the classification task.
top_k = ['buildings', 'grass', 'animal', 'water', 'person']
print(top_k)

# Load the train/test splits (get_labeled_data comes from nus_wide_data_util).
train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 40000, 'Test')

# Features are kept as (image, text) pairs of float32 arrays.
x_train = (np.array(train_X_image).astype('float32'),
           np.array(train_X_text).astype('float32'))
x_test = (np.array(test_X_image).astype('float32'),
          np.array(test_X_text).astype('float32'))
y_train = np.array(train_Y).astype('float32')
y_test = np.array(test_Y).astype('float32')

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Backdoor evaluation set: test rows whose last text feature equals 1.
image_test, text_test = x_test
_trigger_rows = text_test[:, -1] == 1
image_backdoor = image_test[_trigger_rows]
text_backdoor = text_test[_trigger_rows]
y_backdoor = y_test[_trigger_rows].copy()

# Training sample whose text features and label define the backdoor target:
# every backdoor test row is relabelled with this sample's label.
sample_id_need_copy = 369
text_feat_need_copy = x_train[1][sample_id_need_copy].copy()
y_backdoor[:] = y_train[sample_id_need_copy]

print("target label: ", y_train[sample_id_need_copy])
print("text_feat_need_copy: ", x_train[1][sample_id_need_copy])
print("num of train backdoor sample: ", np.sum(x_train[1][:, -1]))
print("num of test backdoor sample: ", y_backdoor.shape, np.sum(x_test[1][:, -1]))


# Catalogue of training-mode names, grouped by defence family. The order is:
# plain modes, laplace noise, gaussian noise, gradient sparsification, and
# hidden-layer variants.
training_mode_list = (
    ['backdoor', 'normal']
    + ['backdoor_with_laplace_noise_' + scale
       for scale in ('0.1', '0.01', '0.001', '0.0001')]
    + ['backdoor_with_gaussian_noise_' + scale
       for scale in ('0.1', '0.01', '0.001', '0.0001')]
    + ['backdoor_with_gradient_sparsification_' + percent
       for percent in ('95', '99', '99.5', '99.9')]
    + ['backdoor_with_' + depth + '_hidden_layer'
       for depth in ('one', 'two', 'three', 'four')]
)


# Loss and optimizer; the same optimizer instance is used for the active
# model and for every local model.
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Streaming metrics for the training, test and backdoor evaluations.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
test_label_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_label_accuracy')

backdoor_loss = tf.keras.metrics.Mean(name='backdoor_loss')
backdoor_accuracy = tf.keras.metrics.CategoricalAccuracy(name='backdoor_accuracy')

print('training_mode = ', training_mode)

# One passive model per client plus a slot for that client's sparsification
# residual (None until a residual has been stored).
LocalModels = [VFLPassiveModel() for _ in range(num_clients)]
GradientsRes = [None] * num_clients
active_model = VFLActiveModelWithOneLayer()

# Per-epoch history of accuracies and losses.
acc_train, acc_test, acc_backdoor = [], [], []
acc_test_label = [[] for _ in range(6)]
loss_train, loss_test, loss_backdoor = [], [], []

# Becomes True once a poison gradient has been stored by the training loop.
has_poison_grad = False
def get_local_outputs(num_img_clients, images, img_feature_div, 
        num_text_clients, texts, text_feature_div, LocalModels):
    """Run every local model on its own column slice of the inputs.

    Image client ``k`` receives ``images[:, img_feature_div[k]:img_feature_div[k+1]]``
    and uses ``LocalModels[k]``; text client ``k`` receives the matching slice
    of ``texts`` and uses ``LocalModels[num_img_clients + k]``.

    Returns:
        A list holding the image clients' outputs followed by the text
        clients' outputs, in client order.
    """
    outputs = [
        LocalModels[k](images[:, img_feature_div[k]:img_feature_div[k + 1]])
        for k in range(num_img_clients)
    ]
    outputs.extend(
        LocalModels[num_img_clients + k](
            texts[:, text_feature_div[k]:text_feature_div[k + 1]])
        for k in range(num_text_clients)
    )
    return outputs

# Best backdoor accuracy and best training accuracy seen so far; each epoch's
# metrics are compared against these to decide when to overwrite checkpoints.
best_bkd_acc =0 
best_acc=0

# Training driver. Each epoch: one pass over the shuffled training set
# (updating the active model and every local model per batch), then an
# evaluation on the backdoor subset and on the test set, checkpointing, and a
# rewrite of the accuracy history files.
for epoch in range(EPOCHS):
    start= time.time()
    # Rebuild the dataset every epoch so it is shuffled over the full training
    # set before batching.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(y_train.shape[0]).batch(BATCH_SIZE)
    # Number of rows flagged for gradient poisoning in this epoch; reported in
    # the epoch summary line.
    number_of_poison = 0

    for (images, texts), labels in train_ds:
        LocalOutputs =[] 
        LocalEmbeds= []
        # Row mask: True where every text feature of the row equals
        # text_feat_need_copy (i.e. the row matches the reference sample).
        need_copy = np.min((texts == text_feat_need_copy).numpy(), axis=1)
        # passive_tape records the local models' forward passes so their
        # variables can be differentiated after the embedding gradients have
        # been computed (and possibly modified) below.
        with tf.GradientTape() as passive_tape:
            # Forward pass of every local model on its own feature slice.
            # LocalOutputs is what the active model consumes; LocalEmbeds
            # keeps the untouched embeddings for the local backward pass.
            for i in range(num_img_clients):
                images_div = images[:,img_feature_div[i]:img_feature_div[i+1]]
                local_output = local_embed = LocalModels[i](images_div)
                LocalOutputs.append(local_output)
                LocalEmbeds.append(local_embed)
            for i in range(num_text_clients):
                texts_div = texts[:,text_feature_div[i]:text_feature_div[i+1]]
                local_output = local_embed = LocalModels[i+num_img_clients](texts_div)
                LocalOutputs.append(local_output)
                LocalEmbeds.append(local_embed)


            # In backdoor modes, rows whose last text feature equals 1 have
            # the last client's output replaced via get_poisoned_matrix (with
            # amplify_rate=0), once a poison gradient has been stored.
            # NOTE(review): get_poisoned_matrix is not defined in this file
            # (presumably from nus_utils); confirm its exact semantics there.
            if 'backdoor' in training_mode:
                need_poison = (texts.numpy()[:,-1] == 1)
                if np.sum(need_poison) > 0:
                    if has_poison_grad:
                        LocalOutputs[-1] = get_poisoned_matrix(LocalEmbeds[-1], need_poison, poison_grad,amplify_rate=0)
            # active_tape watches the local outputs explicitly so the loss
            # can be differentiated with respect to them as well as the
            # active model's variables.
            with tf.GradientTape() as active_tape:
                for i in range(num_clients):
                    active_tape.watch(LocalOutputs[i])
                # Debug prints kept from the original:
                # print(len(LocalOutputs))
                # print(LocalOutputs[0].shape)
                output = active_model(LocalOutputs)
                loss = loss_object(labels, output)
            
            # Differentiation sources: one entry per local output, followed by
            # the list of the active model's variables as the final entry.
            trainable_variables = []
            for i in range(num_clients):
                trainable_variables.append(LocalOutputs[i])
            trainable_variables.append(active_model.trainable_variables)

            gradients = active_tape.gradient(loss,  trainable_variables)
            # Split the result: per-client embedding gradients, and the
            # active (server) model's variable gradients.
            EmbedGradients= gradients[:-1]
            server_model_gradients= gradients[-1]
            optimizer.apply_gradients(zip(server_model_gradients, active_model.trainable_variables)) # update server model

            # Optional noise on the embedding gradients. The noise scale is
            # parsed from the last '_'-separated field of training_mode;
            # gradients are clipped to [-threshold, threshold] first.
            location = 0.0
            threshold = 1e9
            if 'laplace' in training_mode:
                scale = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    EmbedGradients[i]=  tf.clip_by_value(EmbedGradients[i], -threshold, threshold)
                    EmbedGradients[i] += np.random.laplace(location, scale, EmbedGradients[i].numpy().shape)
            if 'gaussian' in training_mode:
                scale = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    EmbedGradients[i]=  tf.clip_by_value(EmbedGradients[i], -threshold, threshold)
                    EmbedGradients[i] += np.random.normal(location, scale, EmbedGradients[i].numpy().shape)

    
            # Optional gradient sparsification. `percent` is parsed from the
            # mode name. Entries whose magnitude is below that percentile are
            # removed from the gradient and stored in GradientsRes; the stored
            # residual is added back on the next batch when its leading
            # dimension matches the current gradient's.
            if 'sparsification' in training_mode:
                percent = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    if GradientsRes[i] is not None and GradientsRes[i].shape[0]== EmbedGradients[i].shape[0]:
                        EmbedGradients[i] = EmbedGradients[i] + GradientsRes[i]
                    _thr = np.percentile(np.abs(EmbedGradients[i].numpy()), percent)
                    _mask = np.abs(EmbedGradients[i].numpy()) < _thr
                    GradientsRes[i] = np.multiply(EmbedGradients[i].numpy(), _mask)
                    EmbedGradients[i] -= GradientsRes[i]


    
            # If the batch contains a row matching the reference sample,
            # refresh poison_grad from the last client's gradient via
            # copy_grad. Otherwise, if nothing has been stored yet, initialise
            # poison_grad to zeros shaped like one row of that gradient.
            # NOTE(review): copy_grad is not defined in this file (presumably
            # from nus_utils); confirm what it extracts.
            if np.sum(need_copy) > 0:
                poison_grad = copy_grad(EmbedGradients[-1], need_copy)
                has_poison_grad = True
                print('need_copy')
            elif has_poison_grad == False:
                poison_grad =  EmbedGradients[-1].numpy()[0]*0
                has_poison_grad = True
            # In backdoor modes, rewrite the last client's gradient for the
            # flagged rows using poison_grad.
            # NOTE(review): has_poison_grad is always True at this point
            # because the branch above sets it, so the inner check never skips.
            if  'backdoor' in training_mode:
                need_poison = (texts.numpy()[:,-1] == 1)
                if np.sum(need_poison) > 0:
                    if has_poison_grad:
                        number_of_poison += np.sum(need_poison)
                        EmbedGradients[-1] = get_poisoned_matrix(EmbedGradients[-1], need_poison, poison_grad)
   
            # Surrogate loss per client: embedding times its (constant, numpy)
            # gradient. Differentiating it with respect to the local variables
            # applies the received embedding gradient through the local model.
            EmbLoss=[]
            local_trainable_varaibles =[]
            for i in range(num_clients):
                EmbLoss.append(tf.multiply(LocalEmbeds[i], EmbedGradients[i].numpy()))
                local_trainable_varaibles.append(LocalModels[i].trainable_variables)
        # Backward pass for the local models, then one optimizer step each.
        LocalGradients = passive_tape.gradient(EmbLoss, local_trainable_varaibles)
        for i in range(num_clients):
            optimizer.apply_gradients(zip(LocalGradients[i], local_trainable_varaibles[i]))
        
        # Accumulate the running training metrics for this epoch.
        train_loss(loss)
        train_accuracy(labels, output)
    
    # Backdoor evaluation: run the whole backdoor subset through the models in
    # one call and score it against the target label stored in y_backdoor.
    LocalOutputs = get_local_outputs(num_img_clients, image_backdoor, img_feature_div, 
                     num_text_clients, text_backdoor, text_feature_div, LocalModels) 

    server_output = active_model(LocalOutputs)
    # The backdoor metrics are reset here, so each epoch's value reflects only
    # this single evaluation.
    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()
    backdoor_loss(loss_object(y_backdoor, server_output))
    backdoor_acc = backdoor_accuracy(y_backdoor, server_output)

    # Test evaluation, accumulated batch by batch into the test metrics.
    for (test_images, test_texts), test_labels in test_ds:
        LocalOutputs = get_local_outputs(num_img_clients, test_images, img_feature_div, 
                     num_text_clients, test_texts, text_feature_div, LocalModels) 
    
        test_output = active_model(LocalOutputs)
        test_loss(loss_object(test_labels, test_output))
        test_accuracy(test_labels, test_output)


    # Record this epoch's results in the history lists and print a summary.
    acc_backdoor.append(backdoor_accuracy.result())
    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())
    end= time.time()
    template = 'Epoch {}, Poisoned {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}, Time: {}'
    print(template.format(epoch+1,
                        number_of_poison,
                        train_loss.result(),
                        train_accuracy.result()*100,
                        test_loss.result(),
                        test_accuracy.result()*100,
                        backdoor_accuracy.result()*100,
                        end-start
                        ))
    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())
    # 'checkpoints' holds the weights from the epoch with the highest backdoor
    # accuracy so far.
    if backdoor_accuracy.result() > best_bkd_acc : 
        # Save the weights
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'checkpoints'))
        active_model.save_weights(os.path.join(server_savedpath, 'checkpoints'))
        best_bkd_acc= backdoor_accuracy.result() 
    
    # Fixed snapshot.
    # NOTE(review): `epoch` is zero-based, so this fires on the iteration that
    # is printed as "Epoch 101"; confirm that is what 'epoch100' should mean.
    if epoch==100:
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'epoch100_checkpoints'))
        active_model.save_weights(os.path.join(server_savedpath, 'epoch100_checkpoints'))
   
    # 'best_checkpoints' holds the weights from the epoch with the highest
    # training accuracy so far.
    # NOTE(review): keyed on training accuracy, not test accuracy; confirm
    # this is intended.
    if best_acc < train_accuracy.result():
        best_acc= train_accuracy.result()
        # Save the weights
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'best_checkpoints'))
        active_model.save_weights(os.path.join(server_savedpath, 'best_checkpoints'))


    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    
    # Bundle of all history lists; not read anywhere else in this file.
    result_list= [acc_train, acc_test, acc_test_label, acc_backdoor, loss_train, loss_test, loss_backdoor]

    # Rewrite the accuracy histories every epoch, one value per line. Each
    # value is written as the first six characters of its float string.
    with open(os.path.join(directory, 'acc_test.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_test))

    with open(os.path.join(directory, 'acc_train.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_train))

    with open(os.path.join(directory, 'acc_backdoor.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_backdoor))

