import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
from sklearn import metrics
from model import  VFLPassiveModel, VFLPassiveModelCIFAR
from utils import data_poison, data_poison_cifar, get_poisoned_matrix,copy_grad

import os
class_num = 10  # number of output classes produced by the server head
# training_mode = 'normal'
DATASET = 'cifar'  # 'mnist' or 'cifar'
# DATASET  = 'mnist'
trial = 3  # run index; only used to name the output directory
EPOCHS = 100


# feature_div holds slice boundaries along axis 1 of the images (image rows):
# party i receives images[:, feature_div[i]:feature_div[i + 1]]; the last party
# is the one whose gradients are manipulated in backdoor modes.
# training_mode starts with 'normal' or 'backdoor'; numeric options such as
# '_laplace_<scale>', '_gaussian_<scale>', '_sparsification_<percent>' and
# '_amplify_rate_<rate>' are parsed from the last '_'-separated token.
# Server_Trainable selects a trainable dense server head (True) or a
# parameter-free sum + softmax server (False).
if DATASET == "mnist":
    feature_div = [0, 7, 14, 21, 28]
    training_mode = 'backdoor_with_amplify_rate_20'  # mnist
    NUM_POISON_TRAIN = 600  # mnist
    Server_Trainable = False
    # feature_div=  [0,14,28]

elif DATASET == "cifar":
    # feature_div=  [0,16,32]
    # feature_div=  [0,8,16,24,32]
    feature_div = [0, 9, 19, 32]
    NUM_POISON_TRAIN = 100
    # NUM_POISON_TRAIN= 1500 # cifar  1200
    # NUM_POISON_TRAIN= 600
    # training_mode = 'backdoor_laplace_0.0001'
    # training_mode = 'backdoor_sparsification_90'
    # training_mode = 'backdoor_sparsification_99.9'
    # training_mode = 'backdoor_sparsification_95'
    # training_mode = 'backdoor_sparsification_99'

    # training_mode = 'normal'
    # training_mode = 'normal_sparsification_90'
    training_mode = 'normal_sparsification_95'
    # training_mode = 'normal_laplace_0.0001'

    # training_mode = 'backdoor_with_amplify_rate_1' # cifar
    Server_Trainable = True

else:
    # Fail fast here instead of with a NameError once the dataset is loaded.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))

BATCH_SIZE = 128
num_clients = len(feature_div) - 1  # one party per slice of feature_div
directory = './img_results/{}_bsl_{}cli_{}_p{}_try{}'.format(DATASET, num_clients, training_mode, NUM_POISON_TRAIN, trial)
# directory= './cifar_results/{}_bsl_{}cli_{}_p{}_try{}'.format(DATASET, num_clients,training_mode,NUM_POISON_TRAIN,trial)


if Server_Trainable:
    class VFLActiveModel(Model):
        """Trainable active-party (server) head.

        Concatenates the list of passive-party outputs, passes the result
        through a 32-unit ReLU layer and a ``class_num``-way softmax layer.
        """

        def __init__(self):
            super().__init__()
            # Attribute names are what save_weights records; keep them stable.
            self.concatenated = tf.keras.layers.Concatenate()
            self.d1 = layers.Dense(32, name="dense1", activation='relu')
            self.out = layers.Dense(class_num, name="out", activation='softmax')

        def call(self, x):
            merged = self.concatenated(x)
            hidden = self.d1(merged)
            return self.out(hidden)

else:
    class VFLActiveModel(Model):
        """Non-trainable active party: element-wise sum of the passive
        outputs followed by a softmax (no weights of its own).
        """

        def __init__(self):
            super().__init__()
            self.added = tf.keras.layers.Add()
            self.activation = layers.Activation('softmax')

        def call(self, x):
            summed = self.added(x)
            return self.activation(summed)

def get_local_outputs(num_clients, images, feature_div, LocalModels):
    """Run every passive model on its own slice of ``images``.

    Party ``k`` receives ``images[:, feature_div[k]:feature_div[k + 1]]`` and
    its output is produced by ``LocalModels[k]``.

    Args:
        num_clients: number of parties to evaluate (indices ``0 .. num_clients - 1``).
        images: batch that supports ``images[:, lo:hi]`` slicing.
        feature_div: slice boundaries along axis 1; needs ``num_clients + 1`` entries.
        LocalModels: indexable collection of per-party callables.

    Returns:
        List with one model output per party, in party order.
    """
    return [
        LocalModels[k](images[:, feature_div[k]:feature_div[k + 1]])
        for k in range(num_clients)
    ]


print("save to", directory)
# exist_ok=True replaces the exists()/makedirs() pair: no check-then-create race,
# and a no-op when the directory is already there.
os.makedirs(directory, exist_ok=True)

# One checkpoint sub-directory per passive party, named after its feature slice.
SavedPaths = []
for i in range(num_clients):
    _dir = os.path.join(directory, '{}_{}'.format(feature_div[i], feature_div[i + 1]))
    os.makedirs(_dir, exist_ok=True)
    SavedPaths.append(_dir)

# Checkpoint directory for the active party (server).
_dir = os.path.join(directory, 'server')
os.makedirs(_dir, exist_ok=True)
server_savedpath = _dir

# prepare the datasets

if DATASET == "mnist":
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
elif DATASET == "cifar":
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
else:
    # Fail fast instead of hitting a NameError on train_images below.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images.astype('float32') / 255.0, test_images.astype('float32') / 255.0
if DATASET == "mnist":
    train_images, train_poison_list = data_poison(train_images, NUM_POISON_TRAIN)  # train images are poisoned
else:
    train_images, train_poison_list = data_poison_cifar(train_images, NUM_POISON_TRAIN)  # train images are poisoned

# Per-sample "is poisoned" flags. A set gives O(1) membership, so this is O(N)
# instead of O(N * num_poisoned) with `in` on a list/array.
train_poison_set = set(train_poison_list)
train_need_poison_list = [indx in train_poison_set for indx in range(len(train_images))]
# Indices of unpoisoned training samples; sized from the actual training set
# (60000 for MNIST, 50000 for CIFAR-10) instead of a hard-coded 60000.
train_clean_list = list(set(range(len(train_images))) - train_poison_set)

# Keep an unpoisoned copy of the test set for clean-accuracy evaluation, then
# pass every test sample through the poisoning helper.
clean_test_images = copy.deepcopy(test_images)
if DATASET == "mnist":
    test_images, test_poison_list = data_poison(test_images, len(test_images))
else:
    test_images, test_poison_list = data_poison_cifar(test_images, len(test_images))

if DATASET == "mnist":
    # MNIST is (N, 28, 28); add a trailing channel axis.
    train_images = train_images[:, :, :, np.newaxis]
    test_images = test_images[:, :, :, np.newaxis]
    clean_test_images = clean_test_images[:, :, :, np.newaxis]

test_poison_set = set(test_poison_list)
test_need_poison_list = [indx in test_poison_set for indx in range(len(test_images))]
test_backdoor_images = test_images[test_need_poison_list]
test_backdoor_labels = copy.deepcopy(test_labels[test_need_poison_list])
# Backdoor target: every poisoned test sample is labelled with the class of
# training sample `sample_id_need_copy`, whose last-party feature slice the
# training loop looks for in each batch.
sample_id_need_copy = 1
feat_need_copy = copy.deepcopy(train_images[sample_id_need_copy, feature_div[-2]:feature_div[-1]])
test_backdoor_labels[:] = train_labels[sample_id_need_copy]
print('the label of the sample need copy = ', train_labels[sample_id_need_copy])
print(test_backdoor_labels)


# prepare the models 
# One passive model per party; the architecture depends on the dataset.
_passive_model_cls = VFLPassiveModel if DATASET == "mnist" else VFLPassiveModelCIFAR
LocalModels = [_passive_model_cls() for _ in range(num_clients)]
# Per-party residual buffers used by gradient sparsification; None until first use.
GradientsRes = [None] * num_clients

server_model = VFLActiveModel()


# Cross-entropy on the server's softmax probabilities (from_logits keeps its
# default of False, matching the softmax output layers).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

# Learning rate 1e-2 * 0.99 ** (step / 10000).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.99
)
# NOTE(review): this single optimizer is used for the server model and for
# every passive model, so its step counter (and hence the decay) advances once
# per apply_gradients call rather than once per batch — confirm this is intended.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

# Running metrics, reset by the training loop between epochs / evaluations.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)

test_loss, test_accuracy, test_label_accuracy = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_label_accuracy'),
)

backdoor_loss, backdoor_accuracy = (
    tf.keras.metrics.Mean(name='backdoor_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='backdoor_accuracy'),
)



print('training_mode = ', training_mode)

# Per-epoch metric histories.
acc_train, acc_test, acc_backdoor = [], [], []
acc_test_label = [[] for _ in range(11)]
loss_train, loss_test, loss_backdoor = [], [], []

# Set once the attacker has captured (or zero-initialised) a poison gradient.
has_poison_grad = False
# Best clean / backdoor test accuracies seen so far (drive checkpointing).
best_acc = best_bkd_acc = 0
# Built once and re-iterated every epoch (a tf.data.Dataset can be iterated
# repeatedly); rebuilding it per epoch only re-copied the full arrays.
# NOTE: the data is batched but NOT shuffled, so every epoch sees the same order.
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels, train_need_poison_list)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    number_of_poison = 0  # poisoned samples whose gradients were replaced this epoch
    # For each batch of images, labels and per-sample poison flags
    for images, labels, need_poison_list in train_ds:
        LocalOutputs = []
        LocalEmbeds = []
        with tf.GradientTape() as passive_tape:
            # Forward pass: each passive party embeds its own slice of the images.
            for i in range(num_clients):
                images_div = images[:, feature_div[i]:feature_div[i + 1]]
                local_output = local_embed = LocalModels[i](images_div)
                LocalOutputs.append(local_output)
                LocalEmbeds.append(local_embed)

            # The last party is the attacker: once it holds a poison gradient,
            # the output it sends for flagged samples is rewritten by
            # get_poisoned_matrix (defined in utils; semantics not visible here).
            if 'backdoor' in training_mode:
                if np.sum(need_poison_list) > 0:
                    if has_poison_grad:
                        LocalOutputs[-1] = get_poisoned_matrix(
                            LocalEmbeds[-1], need_poison_list, poison_grad, 0)

            # Server forward pass; the parties' outputs are watched so the server
            # can return d(loss)/d(output) to each party.
            with tf.GradientTape() as server_tape:
                for i in range(num_clients):
                    server_tape.watch(LocalOutputs[i])
                output = server_model(LocalOutputs)
                loss = loss_object(labels, output)

            trainable_variables = []
            for i in range(num_clients):
                trainable_variables.append(LocalOutputs[i])
            trainable_variables.append(server_model.trainable_variables)
            gradients = server_tape.gradient(loss, trainable_variables)
            EmbedGradients = gradients[:-1]
            server_model_gradients = gradients[-1]
            optimizer.apply_gradients(zip(server_model_gradients, server_model.trainable_variables))  # update server model

            # Optional defence: clip, then add Laplace / Gaussian noise to the
            # gradients returned to the parties (scale = last '_' token of the mode).
            location = 0.0
            threshold = 1e9
            if 'laplace' in training_mode:
                scale = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    EmbedGradients[i] = tf.clip_by_value(EmbedGradients[i], -threshold, threshold)
                    EmbedGradients[i] += np.random.laplace(location, scale, EmbedGradients[i].numpy().shape)
            if 'gaussian' in training_mode:
                scale = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    EmbedGradients[i] = tf.clip_by_value(EmbedGradients[i], -threshold, threshold)
                    EmbedGradients[i] += np.random.normal(location, scale, EmbedGradients[i].numpy().shape)

            # Optional defence: gradient sparsification. Entries below the
            # `percent`-th percentile of |grad| are withheld and kept as a
            # residual that is added back into the next batch's gradient
            # (only when the batch sizes match).
            if 'sparsification' in training_mode:
                percent = float(training_mode.split('_')[-1])
                for i in range(len(EmbedGradients)):
                    if GradientsRes[i] is not None and GradientsRes[i].shape[0] == EmbedGradients[i].shape[0]:
                        EmbedGradients[i] = EmbedGradients[i] + GradientsRes[i]
                    _thr = np.percentile(np.abs(EmbedGradients[i].numpy()), percent)
                    _mask = np.abs(EmbedGradients[i].numpy()) < _thr
                    GradientsRes[i] = np.multiply(EmbedGradients[i].numpy(), _mask)
                    EmbedGradients[i] -= GradientsRes[i]

            if 'backdoor' in training_mode:
                if 'amplify_rate' in training_mode:
                    amplify_rate = float(training_mode.split('_')[-1])
                else:
                    amplify_rate = 1
                poison_div = images[:, feature_div[-2]:feature_div[-1]].numpy()
                # Per-sample mask: True where the attacker's slice exactly equals
                # the target sample's slice (vectorised over all non-batch axes).
                need_copy = np.all(poison_div == feat_need_copy,
                                   axis=tuple(range(1, poison_div.ndim)))

                if np.sum(need_copy) > 0:
                    poison_grad = copy_grad(EmbedGradients[-1], need_copy)  # update the "poison_down_grad"
                    has_poison_grad = True
                    print('need_copy')
                elif not has_poison_grad:
                    # No target sample seen yet: start from an all-zero gradient row.
                    poison_grad = EmbedGradients[-1].numpy()[0] * 0
                    has_poison_grad = True

                if np.sum(need_poison_list) > 0:
                    if has_poison_grad:
                        number_of_poison += np.sum(need_poison_list)
                        EmbedGradients[-1] = get_poisoned_matrix(
                            EmbedGradients[-1], need_poison_list, poison_grad,
                            amplify_rate=amplify_rate)  # update local model

            # Passive backward: the gradient of sum(embed * g) w.r.t. the local
            # weights chains the received gradient g through each local model;
            # .numpy() makes g a constant for the tape.
            EmbLoss = []
            local_trainable_varaibles = []
            for i in range(num_clients):
                EmbLoss.append(tf.multiply(LocalEmbeds[i], EmbedGradients[i].numpy()))
                local_trainable_varaibles.append(LocalModels[i].trainable_variables)
        LocalGradients = passive_tape.gradient(EmbLoss, local_trainable_varaibles)
        for i in range(num_clients):
            optimizer.apply_gradients(zip(LocalGradients[i], local_trainable_varaibles[i]))

        train_loss(loss)
        train_accuracy(labels, output)

    # backdoor test
    LocalOutputs = get_local_outputs(num_clients, test_backdoor_images, feature_div, LocalModels)
    server_output = server_model(LocalOutputs)
    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()
    backdoor_loss(loss_object(test_backdoor_labels, server_output))
    backdoor_acc = backdoor_accuracy(test_backdoor_labels, server_output)

    # clean test
    LocalOutputs = get_local_outputs(num_clients, clean_test_images, feature_div, LocalModels)
    test_output = server_model(LocalOutputs)
    test_loss(loss_object(test_labels, test_output))
    test_accuracy(test_labels, test_output)

    acc_backdoor.append(backdoor_accuracy.result())

    template = 'Epoch {}, Poisoned {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
    print(template.format(epoch + 1,
                          number_of_poison,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100,
                          backdoor_accuracy.result() * 100))
    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())
    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())

    if best_bkd_acc < backdoor_accuracy.result():
        best_bkd_acc = backdoor_accuracy.result()
        # Save the weights
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'bkd_checkpoints'))
        server_model.save_weights(os.path.join(server_savedpath, 'bkd_checkpoints'))

    if best_acc < test_accuracy.result():
        best_acc = test_accuracy.result()
        # Save the weights
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'best_checkpoints'))
        server_model.save_weights(os.path.join(server_savedpath, 'best_checkpoints'))

    # Snapshot after the 100th epoch (logged as "Epoch 100"). The former
    # `epoch == 100` test could never fire because epoch only reaches EPOCHS - 1.
    if epoch + 1 == 100:
        for i in range(num_clients):
            LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'epoch100_checkpoints'))
        server_model.save_weights(os.path.join(server_savedpath, 'epoch100_checkpoints'))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    # Save the weights
    for i in range(num_clients):
        LocalModels[i].save_weights(os.path.join(SavedPaths[i], 'checkpoints'))
    server_model.save_weights(os.path.join(server_savedpath, 'checkpoints'))

    result_list = [acc_train, acc_test, acc_test_label, acc_backdoor, loss_train, loss_test, loss_backdoor]

    # Rewrite the per-epoch accuracy logs (percent, truncated to 6 characters).
    with open(os.path.join(directory, 'acc_test.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item * 100))[:6] for item in acc_test))

    with open(os.path.join(directory, 'acc_train.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item * 100))[:6] for item in acc_train))

    with open(os.path.join(directory, 'acc_backdoor.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item * 100))[:6] for item in acc_backdoor))

        
