import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
from sklearn import metrics
from model import  VFLPassiveModel, VFLPassiveModelCIFAR
from utils import data_poison, get_poisoned_matrix,copy_grad,need_poison_trigger_check, data_poison_cifar
import os

DATASET = "cifar"
BATCH_SIZE = 128

# Dataset-specific configuration:
#   feature_div      -- [start, stop) row slice of each image given to the passive party
#   Server_Trainable -- whether the active-party model has trainable layers
#   NUM_POISON_TRAIN -- sample count passed to the training-set poisoning helper
if DATASET != "mnist":
    # CIFAR (any non-MNIST value). Alternatives tried: [16, 32]; 600 poisoned samples.
    feature_div = [19, 32]
    Server_Trainable = True
    NUM_POISON_TRAIN = 100
else:
    # Alternatives tried: [0, 7, 14, 21, 28] and [0, 14, 28]; 1200 poisoned samples.
    feature_div = [21, 28]
    Server_Trainable = False
    NUM_POISON_TRAIN = 600
if Server_Trainable:
    class VFLActiveModel(Model):
        """Active-party (server) model with a trainable head.

        The list of passive-party embeddings is concatenated and passed through
        a 32-unit ReLU dense layer followed by a softmax output layer.
        """

        def __init__(self, num_classes=None):
            """Create the layers.

            Args:
                num_classes: number of output units. When omitted, the
                    module-level ``class_num`` is used; it is looked up at
                    construction time, so it must already be defined.
            """
            super(VFLActiveModel, self).__init__()
            if num_classes is None:
                # Late-bound global kept for backward compatibility: class_num
                # is assigned further down this script, after the class body.
                num_classes = class_num
            self.concatenated = tf.keras.layers.Concatenate()
            self.d1 = layers.Dense(32, name="dense1", activation='relu')
            self.out = layers.Dense(num_classes, name="out", activation='softmax')

        def call(self, x):
            """Return class probabilities for the list of embeddings ``x``."""
            x = self.concatenated(x)
            x = self.d1(x)
            return self.out(x)

else:
    # Active party not trainable
    class VFLActiveModel(Model):
        """Active-party model without trainable parameters.

        Sums the passive-party embeddings element-wise and applies softmax.
        NOTE(review): presumably each embedding already has one entry per
        class -- confirm against the passive model's output width.
        """

        def __init__(self, num_classes=None):
            """Create the layers.

            Args:
                num_classes: accepted for interface parity with the trainable
                    variant; unused here because there is no output layer.
            """
            super(VFLActiveModel, self).__init__()
            self.added = tf.keras.layers.Add()
            self.activation = layers.Activation('softmax')

        def call(self, x):
            """Return the softmax of the element-wise sum of the embeddings in ``x``."""
            x = self.added(x)
            return self.activation(x)


# 'backdoor' enables the gradient-replacement attack in the training loop; the
# trailing '_<number>' is parsed there as the amplify rate.
training_mode = 'backdoor_with_amplify_rate_1'

# Output directory for checkpoints and metric logs, named after the dataset,
# the passive party's row range and the number of poisoned training samples.
directory= './img_results/debug_{}_qtrain_poison_{}_{}_p{}'.format(DATASET,feature_div[0],feature_div[1],NUM_POISON_TRAIN)
print('save into: '+ directory)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(directory, exist_ok=True)

# Load the raw dataset selected by DATASET.
if DATASET=="mnist":
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
elif DATASET=="cifar":
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
else:
    # Fail with a clear message instead of a NameError on the print below.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))
print("Train size", train_images.shape, train_labels.shape,  "Test size ", test_images.shape, test_labels.shape )

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images.astype('float32') / 255.0, test_images.astype('float32') / 255.0


# Apply the backdoor trigger to part of the training set. The helper returns the
# (possibly modified) images and a collection of the poisoned sample indices.
if DATASET=="mnist":
    train_images, train_poison_list = data_poison(train_images, NUM_POISON_TRAIN) # train images are poisoned
else:
    train_images, train_poison_list = data_poison_cifar(train_images, NUM_POISON_TRAIN) # train images are poisoned

# Per-sample boolean mask over the training set. Membership is tested against a
# set so building the mask is O(N) rather than O(N * num_poisoned).
train_poison_set = set(train_poison_list)
train_need_poison_list = [indx in train_poison_set for indx in range(len(train_images))]

# Keep an untouched copy of the test set before the trigger is applied.
clean_test_images = copy.deepcopy(test_images)
# Request poisoning of every test sample (the count argument mirrors
# NUM_POISON_TRAIN above); both MNIST and CIFAR-10 have 10000 test images.
num_test_images = len(test_images)
if DATASET=="mnist":
    test_images, test_poison_list = data_poison(test_images, num_test_images)
else:
    test_images, test_poison_list = data_poison_cifar(test_images, num_test_images)

# Indices of unpoisoned training samples. Derived from the real training-set
# size (60000 for MNIST, 50000 for CIFAR-10) instead of a hardcoded 60000.
train_clean_list = list(set(range(len(train_images))) - train_poison_set)

if DATASET=="mnist":
    # MNIST images are (N, 28, 28); add the trailing channel axis.
    train_images = train_images[:,:,:,np.newaxis]
    test_images = test_images[:,:,:,np.newaxis]
    clean_test_images= clean_test_images[:,:,:,np.newaxis]

print(train_images.shape)
print(test_images.shape)
print(train_labels.shape)


test_poison_set = set(test_poison_list)
test_need_poison_list = [indx in test_poison_set for indx in range(len(test_images))]

# Backdoor evaluation set: poisoned test samples restricted to the passive
# party's rows, all relabelled with the label of training sample
# `sample_id_need_copy`.
test_backdoor_images = test_images[test_need_poison_list]
test_backdoor_images_div = test_backdoor_images[:,feature_div[0]:feature_div[1]]
test_backdoor_labels = copy.deepcopy(test_labels[test_need_poison_list])
sample_id_need_copy = 1
# Passive-party feature slice of that sample; the training loop looks for
# batch rows whose slice matches it exactly.
feat_need_copy = copy.deepcopy(train_images[sample_id_need_copy, feature_div[0]:feature_div[1]])
test_backdoor_labels[:] = train_labels[sample_id_need_copy]
print('the label of the sample need copy = ', train_labels[sample_id_need_copy])
print(test_backdoor_labels)



# Class labels are the digits 0-9 rendered as strings.
class_names = [str(digit) for digit in range(10)]
class_num = len(class_names)

# Objective and optimizer used for both the active and the passive updates.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Streaming Keras metrics for training, test and backdoor evaluation.
train_loss, test_loss, backdoor_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'test_loss', 'backdoor_loss')
)
train_accuracy, test_accuracy, test_label_accuracy, backdoor_accuracy = (
    tf.keras.metrics.SparseCategoricalAccuracy(name=metric_name)
    for metric_name in (
        'train_accuracy', 'test_accuracy', 'test_label_accuracy', 'backdoor_accuracy')
)



EPOCHS = 100

# Passive party: embeds only its slice of each image.
local_model = VFLPassiveModel() if DATASET == "mnist" else VFLPassiveModelCIFAR()
# Active party: turns the passive embedding(s) into class probabilities.
server_model = VFLActiveModel()
print('training_mode = ', training_mode)

# Per-epoch metric histories.
acc_train, acc_test, acc_backdoor = [], [], []
loss_train, loss_test, loss_backdoor = [], [], []
# 11 independent per-label accumulators (presumably class_num + 1 slots).
acc_test_label = [[] for _ in range(11)]

# Attack state and best-checkpoint tracking.
has_poison_grad = False
best_bkd_acc = 0
best_acc = 0
# ---------------------------------------------------------------------------
# Training loop.
#
# Each step: the passive party embeds its feature slice; on batches with
# flagged samples the embedding is modified via get_poisoned_matrix; the active
# party computes the loss and updates its own layers; the embedding gradient is
# then (for flagged samples) modified via get_poisoned_matrix using the
# recorded gradient of the copied sample and `amplify_rate`; finally it is
# back-propagated into the passive model.
# ---------------------------------------------------------------------------

# Loop-invariant setup, hoisted out of the epoch / batch loops.
amplify_rate = None
if 'amplify_rate' in training_mode:
    amplify_rate = float(training_mode.split('_')[-1])
elif 'backdoor' in training_mode:
    # Previously this surfaced as a NameError on the first poisoned batch.
    raise ValueError(
        "training_mode {!r} enables 'backdoor' but has no amplify_rate suffix".format(training_mode))

# Batches are taken in a fixed order (no shuffling), so the dataset is the same
# every epoch and only needs to be built once.
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels, train_need_poison_list)).batch(BATCH_SIZE)

# Clean accuracy is measured on the copy taken before the test set was poisoned;
# the poisoned test images are evaluated separately as the backdoor set.
clean_test_images_div = clean_test_images[:, feature_div[0]:feature_div[1]]

for epoch in range(EPOCHS):
    number_of_poison = 0
    # For each batch of images and labels
    for images, labels, need_poison_list in train_ds:

        with tf.GradientTape() as passive_tape:
            images_div = images[:, feature_div[0]:feature_div[1]].numpy()
            poisoned_embedding = local_embedding = local_model(images_div)

            if 'backdoor' in training_mode:
                if np.sum(need_poison_list) > 0 and has_poison_grad:
                    poisoned_embedding = get_poisoned_matrix(
                        local_embedding, need_poison_list, poison_grad, 0)

            with tf.GradientTape() as server_tape:
                server_tape.watch(poisoned_embedding)
                output = server_model([poisoned_embedding])
                loss = loss_object(labels, output)

            emb_gradients, server_model_gradients = server_tape.gradient(
                loss, [poisoned_embedding, server_model.trainable_variables])
            # update server model
            optimizer.apply_gradients(
                zip(server_model_gradients, server_model.trainable_variables))

            if 'backdoor' in training_mode:
                # Rows of this batch whose passive-party slice exactly equals
                # feat_need_copy (vectorized over all non-batch axes).
                need_copy = np.all(
                    images_div == feat_need_copy,
                    axis=tuple(range(1, images_div.ndim)))

                if np.sum(need_copy) > 0:
                    # Record the gradient of the copied sample for later reuse.
                    poison_grad = copy_grad(emb_gradients, need_copy)
                    has_poison_grad = True
                    print('need_copy')
                elif not has_poison_grad:
                    # Zero placeholder until a batch containing the sample is seen.
                    poison_grad = emb_gradients.numpy()[0] * 0
                    has_poison_grad = True

                if np.sum(need_poison_list) > 0 and has_poison_grad:
                    number_of_poison += np.sum(need_poison_list)
                    # update local model
                    emb_gradients = get_poisoned_matrix(
                        emb_gradients, need_poison_list, poison_grad, amplify_rate)

            # Surrogate objective: the gradients are taken as constants
            # (.numpy()), so the tape's gradient of this (summed) product
            # w.r.t. the local weights is the chain-rule product with emb_gradients.
            emd_loss = tf.multiply(local_embedding, emb_gradients.numpy())

        [local_model_gradients] = passive_tape.gradient(
            [emd_loss], [local_model.trainable_variables])
        optimizer.apply_gradients(zip(local_model_gradients, local_model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, output)

    # Backdoor evaluation: poisoned test samples relabelled to the target label.
    server_output = server_model([local_model(test_backdoor_images_div)])
    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()
    backdoor_loss(loss_object(test_backdoor_labels, server_output))
    backdoor_accuracy(test_backdoor_labels, server_output)

    # Clean test evaluation.
    test_output = server_model([local_model(clean_test_images_div)])
    t_loss = loss_object(test_labels, test_output)
    test_loss(t_loss)
    test_accuracy(test_labels, test_output)

    # for label_val in range(class_num):
    #     image_val = test_images[test_labels==label_val]
    #     y_val = test_labels[test_labels==label_val]
    #     image_val_div = image_val[:,feature_div[0]:feature_div[1]]
    #     val_output = server_model([local_model(image_val_div)])
    #     test_label_accuracy.reset_states()
    #     tl_acc = test_label_accuracy(y_val, val_output)
    #     acc_test_label[label_val].append(tl_acc.numpy())
    # acc_test_label[class_num].append(backdoor_accuracy.result())
    acc_backdoor.append(backdoor_accuracy.result())

    template = 'Epoch {}, Poisoned {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
    print(template.format(epoch+1,
                          number_of_poison,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100,
                          backdoor_accuracy.result()*100))
    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())
    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())

    if best_bkd_acc < backdoor_accuracy.result():
        best_bkd_acc = backdoor_accuracy.result()
        # Save the weights
        local_model.save_weights(os.path.join(directory, 'checkpoints'))
        server_model.save_weights(os.path.join(directory, 'server_ckpt'))
    local_model.save_weights(os.path.join(directory, 'current_checkpoints'))
    server_model.save_weights(os.path.join(directory, 'current_server_ckpt'))

    if best_acc < test_accuracy.result():
        best_acc = test_accuracy.result()
        local_model.save_weights(os.path.join(directory, 'best_ckpt'))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    # Rewrite the accuracy histories every epoch (values truncated to 6 chars).
    for file_name, history in (('acc_test.txt', acc_test),
                               ('acc_train.txt', acc_train),
                               ('acc_backdoor.txt', acc_backdoor)):
        with open(os.path.join(directory, file_name), "w") as outfile:
            outfile.write("\n".join(str(float(item))[:6] for item in history))


        
