import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
from sklearn import metrics

import os 
from nus_utils import *
from nus_wide_data_util import get_labeled_data
import time
        
DATASET = "nus"
BATCH_SIZE = 64
class_num = 5

# Modality fed to the local model ('image' or 'text') and the half-open column
# range [feature_div[0], feature_div[1]) of that modality's features it receives.
data_type = 'text'
feature_div = [500, 1000]
# feature_div=[0,1000]


EPOCHS = 200
# image (60000, 634)
# text (60000, 1000)

# Text-feature column whose value 1 marks a sample as triggered (-1 = last column).
trigger_pos = -1

# Output directory for metrics and checkpoints, e.g. ./nus_results/poison_text_500_1000_trigger-1
directory = os.path.join(
    './nus_results/',
    'poison_{}_{}_{}_trigger{}'.format(data_type, feature_div[0], feature_div[1], trigger_pos),
)
print('save into: ' + directory)
# exist_ok=True avoids the race between an existence check and the creation call.
os.makedirs(directory, exist_ok=True)



top_k = ['buildings', 'grass', 'animal', 'water', 'person']
print(top_k)

train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 10000, 'Test')

# Features are kept as an (image, text) pair of float32 arrays; labels as float32.
x_train = (
    np.array(train_X_image).astype('float32'),
    np.array(train_X_text).astype('float32'),
)
x_test = (
    np.array(test_X_image).astype('float32'),
    np.array(test_X_text).astype('float32'),
)
y_train = np.array(train_Y).astype('float32')
y_test = np.array(test_Y).astype('float32')


test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(BATCH_SIZE)

image_test, text_test = x_test

# Backdoor evaluation set: every test sample whose trigger column equals 1,
# relabeled to the label of training sample `sample_id_need_copy`.
trigger_mask = text_test[:, trigger_pos] == 1
image_backdoor = image_test[trigger_mask]
text_backdoor = text_test[trigger_mask]
y_backdoor = y_test[trigger_mask].copy()

sample_id_need_copy = 369
# Detached copy of that sample's text features; used in the loop to spot it in a batch.
text_feat_need_copy = x_train[1][sample_id_need_copy].copy()
y_backdoor[:] = y_train[sample_id_need_copy]

print("target label: ", y_train[sample_id_need_copy])
print("text_feat_need_copy: ", x_train[1][sample_id_need_copy])
print("num of train backdoor sample: ", np.sum(x_train[1][:, trigger_pos]))  # 152.0
print("num of test backdoor sample: ", np.sum(x_test[1][:, trigger_pos]))  # 102.0 -- TODO: it's not strong enough!!!




# Catalogue of experiment configuration names. The training loop only checks
# whether the active `training_mode` string contains 'backdoor'.
training_mode_list = [
    'backdoor',
    'normal',
    'backdoor_with_laplace_noise_0.01',
    'backdoor_with_laplace_noise_0.005',
    'backdoor_with_laplace_noise_0.001',
    'backdoor_with_laplace_noise_0.0005',
    'backdoor_with_gaussian_noise_0.01',
    'backdoor_with_gaussian_noise_0.005',
    'backdoor_with_gaussian_noise_0.001',
    'backdoor_with_gaussian_noise_0.0005',
    'backdoor_with_gradient_sparsification_95',
    'backdoor_with_gradient_sparsification_99',
    'backdoor_with_gradient_sparsification_99.5',
    'backdoor_with_gradient_sparsification_99.9',
    'backdoor_with_amplify_rate_1',
    'backdoor_with_amplify_rate_5',
    'backdoor_with_amplify_rate_10',
    'backdoor_with_amplify_rate_20',
    'backdoor_with_amplify_rate_30',
]


# One loss and one plain SGD optimizer, applied to both the server and the local model.
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Running metrics (Mean loss + categorical accuracy) for each evaluation stream.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.CategoricalAccuracy(name='train_accuracy'),
)

test_loss, test_accuracy, test_label_accuracy = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.CategoricalAccuracy(name='test_accuracy'),
    tf.keras.metrics.CategoricalAccuracy(name='test_label_accuracy'),
)

backdoor_loss, backdoor_accuracy = (
    tf.keras.metrics.Mean(name='backdoor_loss'),
    tf.keras.metrics.CategoricalAccuracy(name='backdoor_accuracy'),
)




# training_mode = 'backdoor'
# Any mode string containing 'backdoor' enables gradient replacement in the training loop.
training_mode = 'backdoor'

# Local model that embeds this party's feature slice, and the server model that
# classifies those embeddings. Both classes presumably come from
# `from nus_utils import *` (they are not defined in this file).
local_model = VFLPassiveModel()
# input_shape= (BATCH_SIZE, 28,  feature_div[1]-feature_div[0],1)
# print(input_shape)
# local_model.build(input_shape)
# local_model.load_weights(os.path.join('./nontrain_results/','mnist_local_7_14/','checkpoints'))
server_model = VFLActiveModelWithOneLayer()


# Per-epoch metric histories.
acc_train = []
acc_test = []
# One accuracy list per class, sized from class_num (previously a fixed 6 lists for 5 classes).
acc_test_label = [[] for _ in range(class_num)]
acc_backdoor = []
loss_train = []
loss_test = []
loss_backdoor = []


# Disabled pre-training evaluation. NOTE(review): it references test_images /
# test_labels, which are not defined in this script.
# test_images_div = test_images[:,feature_div[0]:feature_div[1]]
# test_output = server_model([local_model(test_images_div)])
# t_loss = loss_object(test_labels, test_output)
# test_loss(t_loss)
# test_accuracy(test_labels, test_output)
# template = ' Test Loss: {}, Test Accuracy: {}'
# print(template.format(
#                     test_loss.result(),
#                     test_accuracy.result()*100
#                         ))

best_acc = -1  # best clean test accuracy seen so far (checkpoint criterion)
has_poison_grad = False  # becomes True once `poison_grad` has been set in the loop

for epoch in range(EPOCHS):
    start = time.time()
    # Re-shuffle and re-batch the training data every epoch.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(y_train.shape[0]).batch(BATCH_SIZE)
    # Count of triggered samples whose embedding gradients were replaced this epoch.
    number_of_poison = 0

    for (images, texts), labels in train_ds:
        texts_np = texts.numpy()
        # Rows whose full text-feature vector equals that of the copied training sample.
        need_copy = np.all(texts_np == text_feat_need_copy, axis=1)
        # Rows carrying the trigger (trigger column == 1).
        need_poison = texts_np[:, trigger_pos] == 1

        with tf.GradientTape() as passive_tape:
            if data_type == 'image':
                data_div = images[:, feature_div[0]:feature_div[1]].numpy()
            elif data_type == 'text':
                data_div = texts[:, feature_div[0]:feature_div[1]].numpy()

            local_embedding = local_model(data_div)
            poisoned_embedding = local_embedding
            # Once a poison gradient exists, pass triggered rows' embeddings
            # through get_poisoned_matrix (amplify_rate=0) before the server sees them.
            if 'backdoor' in training_mode and need_poison.any() and has_poison_grad:
                poisoned_embedding = get_poisoned_matrix(
                    local_embedding, need_poison, poison_grad, amplify_rate=0)

            with tf.GradientTape() as server_tape:
                server_tape.watch(poisoned_embedding)
                output = server_model([poisoned_embedding])
                loss = loss_object(labels, output)

            emb_gradients, server_model_gradients = server_tape.gradient(
                loss, [poisoned_embedding, server_model.trainable_variables])
            # update server model
            optimizer.apply_gradients(zip(server_model_gradients, server_model.trainable_variables))

            if 'backdoor' in training_mode:
                if need_copy.any():
                    # Refresh the stored poison gradient from the copied sample's row(s).
                    poison_grad = copy_grad(emb_gradients, need_copy)
                    has_poison_grad = True
                    print('need_copy')  # this sample is expected to appear only once per epoch
                elif not has_poison_grad:
                    # Before the copied sample has been seen, use an all-zero gradient row.
                    poison_grad = emb_gradients.numpy()[0] * 0
                    has_poison_grad = True
                # Otherwise the previously stored poison_grad is reused.
                if need_poison.any() and has_poison_grad:
                    number_of_poison += np.sum(need_poison)
                    # Replace triggered rows' gradients before they reach the local model.
                    emb_gradients = get_poisoned_matrix(
                        emb_gradients, need_poison, poison_grad, amplify_rate=10)

            # Surrogate objective: emb_gradients is a constant (numpy), so the tape's
            # gradient of sum(local_embedding * emb_gradients) w.r.t. the local weights
            # back-propagates the (possibly replaced) embedding gradients.
            emd_loss = tf.multiply(local_embedding, emb_gradients.numpy())

        local_model_gradients = passive_tape.gradient(emd_loss, local_model.trainable_variables)
        optimizer.apply_gradients(zip(local_model_gradients, local_model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, output)

    # Backdoor evaluation: triggered test samples relabeled to the target label.
    if data_type == 'image':
        backdoor_data_div = image_backdoor[:, feature_div[0]:feature_div[1]]
    elif data_type == 'text':
        backdoor_data_div = text_backdoor[:, feature_div[0]:feature_div[1]]

    server_output = server_model([local_model(backdoor_data_div)])
    backdoor_loss(loss_object(y_backdoor, server_output))
    backdoor_accuracy(y_backdoor, server_output)
    acc_backdoor.append(backdoor_accuracy.result())

    # Clean test evaluation on the whole test set in a single forward pass.
    if data_type == 'image':
        test_data_div = image_test[:, feature_div[0]:feature_div[1]]
    elif data_type == 'text':
        test_data_div = text_test[:, feature_div[0]:feature_div[1]]

    test_output = server_model([local_model(test_data_div)])
    test_loss(loss_object(y_test, test_output))
    test_accuracy(y_test, test_output)

    # Disabled per-class evaluation. NOTE(review): y_test is one-hot (it is used
    # with CategoricalAccuracy), so `y_test==label_val` would not select rows
    # correctly; it would need np.argmax(y_test, axis=1) == label_val.
    # for label_val in range(class_num):
    #     y_val = y_test[y_test==label_val]
    #     if data_type=='image':
    #         val_data_div = image_test[y_test==label_val][:,feature_div[0]:feature_div[1]]
    #     elif data_type=='text':
    #         val_data_div = text_test[y_test==label_val][:,feature_div[0]:feature_div[1]]
    #     val_output = server_model([local_model(val_data_div)])
    #     test_label_accuracy.reset_states()
    #     tl_acc = test_label_accuracy(y_val, val_output)
    #     acc_test_label[label_val].append(tl_acc.numpy())

    template = 'Epoch {}, Poisoned {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
    print(template.format(epoch+1,
                          number_of_poison,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100,
                          backdoor_accuracy.result()*100))
    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())
    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())
    # Capture this epoch's test accuracy BEFORE the metrics are reset; reading it
    # after reset_states() always yields 0, which broke best-checkpoint selection.
    epoch_test_acc = float(test_accuracy.result())

    # Reset the metrics for the next epoch
    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    result_list = [acc_train, acc_test, acc_test_label, loss_train, loss_test]

    # Rewrite the accuracy histories each epoch, one value per line truncated to 6 chars.
    for file_name, history in (('acc_test.txt', acc_test),
                               ('acc_train.txt', acc_train),
                               ('acc_backdoor.txt', acc_backdoor)):
        with open(os.path.join(directory, file_name), "w") as outfile:
            outfile.write("\n".join(str(float(item))[:6] for item in history))

    # Save the latest local weights every epoch; save both models on a new best test accuracy.
    local_model.save_weights(os.path.join(directory, 'checkpoints'))
    if best_acc < epoch_test_acc:
        best_acc = epoch_test_acc
        local_model.save_weights(os.path.join(directory, 'best_ckpt'))
        server_model.save_weights(os.path.join(directory, 'server_ckpt'))
        

