import os
import tempfile
import copy
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
from model import VFLPassiveModel, VFLPassiveModelCIFAR
from utils import calculate_l21_blocknorm,calculate_l21_rownorm,calculate_l21_colnorm, data_poison, get_poisoned_matrix,copy_grad,need_poison_trigger_check, data_poison_cifar
import time
import datetime


# ---- Experiment configuration ----
class_num = 10
DATASET = "cifar"

BATCH_SIZE = 128
# Width of each client's embedding; the RAE reconstructs the concatenation of
# all client embeddings, so its output width is EMB_DIM * num_clients.
EMB_DIM = 10
USE_RAE = True

trial = 1
is_clean_model = False

if DATASET == "mnist":
    # Alternative 2-client split: feature_div = [0, 14, 28]
    feature_div = [0, 7, 14, 21, 28]
    training_mode = 'backdoor_with_amplify_rate_20'
    NUM_POISON_TRAIN = 600
    Server_Trainable = False
elif DATASET == "cifar":
    # Alternative splits used in other runs: [0, 16, 32] (NUM_POISON_TRAIN=1200)
    # and [0, 8, 16, 24, 32].
    feature_div = [0, 9, 19, 32]
    NUM_POISON_TRAIN = 100
    training_mode = 'backdoor_with_amplify_rate_1'
    Server_Trainable = True
else:
    # Only these two datasets are loaded further down; fail early instead of
    # silently using the CIFAR config and crashing later on an unbound name.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))

# feature_div holds slice boundaries, so there is one client per adjacent pair.
num_clients = len(feature_div) - 1
RAE_out_dim = EMB_DIM * num_clients

# Checkpoint locations for the pretrained client models, the RAE, the saved
# optimal L matrix, and this run's output directory. Only the CIFAR layouts
# with 2 or 3 clients are configured here.
if num_clients == 2:
    if is_clean_model:
        models_path = [
            'img_results/cifar_qtrain_0_16/checkpoints',
            'img_results/cifar_qtrain_16_32/checkpoints'
        ]
        RAE_SavedModel = 'rae_results/cleancifar_2cli_rae_p600_pre100_try{}/rae_best_ckpt.tf'.format(trial)
        OptimalL_SavedModel = 'rae_results/cleancifar_2cli_rae_p600_pre100_try{}/opt_L.npy'.format(trial)
        directory = './rae_results/clean{}_{}cli_server_rae{}_try{}'.format(DATASET, num_clients, USE_RAE, trial)
    else:
        models_path = [
            'img_results/cifar_qtrain_0_16/checkpoints',
            './img_results/debug_cifar_qtrain_poison_16_32_p600/checkpoints'
        ]
        RAE_SavedModel = 'rae_results/cifar_2cli_rae_p600_pre100_try{}/rae_best_ckpt.tf'.format(trial)
        OptimalL_SavedModel = 'rae_results/cifar_2cli_rae_p600_pre100_try{}/opt_L.npy'.format(trial)
        directory = './rae_results/{}_{}cli_server_rae{}_pre100_try{}'.format(DATASET, num_clients, USE_RAE, trial)
elif num_clients == 3:
    if is_clean_model:
        models_path = [
            'img_results/cifar_qtrain_0_9/best_ckpt',
            'img_results/cifar_qtrain_9_19/best_ckpt',
            'img_results/cifar_qtrain_19_32/best_ckpt'
        ]
        RAE_SavedModel = 'rae_results/cleancifar_3cli_rae_p100_pre100_try{}/rae_best_ckpt.tf'.format(trial)
        OptimalL_SavedModel = 'rae_results/cleancifar_3cli_rae_p100_pre100_try{}/opt_L.npy'.format(trial)
        directory = './rae_results/clean{}_{}cli_server_rae{}_try{}'.format(DATASET, num_clients, USE_RAE, trial)
    else:
        models_path = [
            'img_results/cifar_qtrain_0_9/best_ckpt',
            'img_results/cifar_qtrain_9_19/best_ckpt',
            'img_results/debug_cifar_qtrain_poison_19_32_p100/best_ckpt'
        ]
        RAE_SavedModel = 'rae_results/cifar_3cli_rae_p100_pre100_try{}/rae_best_ckpt.tf'.format(trial)
        OptimalL_SavedModel = 'rae_results/cifar_3cli_rae_p100_pre100_try{}/opt_L.npy'.format(trial)
        directory = './rae_results/{}_{}cli_server_rae{}_pre100_try{}'.format(DATASET, num_clients, USE_RAE, trial)
else:
    # Without this, models_path / RAE_SavedModel / directory stay undefined and
    # the script dies later with a confusing NameError.
    raise ValueError(
        "No checkpoint paths configured for num_clients={} (only 2 or 3 are supported)".format(num_clients))



STAGE2_EPOCH = 50

# Create the output directory. exist_ok avoids the race between an existence
# check and the creation call.
os.makedirs(directory, exist_ok=True)
print("save to", directory)
IS_VIS = False

# SGD learning rate for optimizing the latent L during RAE purification.
Purify_learning_rate = 0.5


if not Server_Trainable:
    class VFLActiveModel(Model):
        """Fixed (non-trainable) active-party head.

        Adds the per-client input tensors element-wise and applies a softmax;
        it owns no trainable weights.
        """

        def __init__(self):
            super().__init__()
            self.added = tf.keras.layers.Add()
            self.activation = layers.Activation('softmax')

        def call(self, x):
            summed = self.added(x)
            return self.activation(summed)

else:
    class VFLActiveModel(Model):
        """Trainable active-party head.

        Concatenates the per-client input tensors, then applies a 32-unit ReLU
        layer followed by a class_num-way softmax output layer.
        """

        def __init__(self):
            super().__init__()
            self.concatenated = tf.keras.layers.Concatenate()
            self.d1 = layers.Dense(32, name="dense1", activation='relu')
            self.out = layers.Dense(class_num, name="out", activation='softmax')

        def call(self, x):
            joined = self.concatenated(x)
            hidden = self.d1(joined)
            return self.out(hidden)




if IS_VIS:
    # TensorBoard logging goes through tf.summary. The log directory is
    # timestamped so successive runs do not overwrite each other.
    # (The former `from tensorboardX import SummaryWriter` was unused and only
    # added a third-party dependency that could fail to import.)
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(directory, current_time, 'train')
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)



class RAE(Model):
    """Autoencoder applied to the concatenated client embeddings.

    Forward pass: d3 -> d1 -> LayerNormalization -> d4 -> d2.
    Returns ``(reconstruction, codeword)``. The reconstruction has width
    RAE_out_dim; the codeword is the layer-normalized 64-unit activation.
    """

    def __init__(self):
        super(RAE, self).__init__()

        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_out_dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense1", activation='relu')
        self.d4 = Dense(64, name="dense1", activation='relu')
        # Created once here instead of inside call(). The original built a new
        # LayerNormalization (with fresh variables) on every forward pass. Its
        # gamma starts at ones, as the per-call layer's did, so outputs match.
        # NOTE(review): a checkpoint saved without this attribute leaves gamma
        # at its initial value on load_weights — confirm with saved RAE ckpts.
        self.norm = LayerNormalization(axis=-1, center=False, scale=True)

    def call(self, x):
        x = self.d3(x)
        x = self.d1(x)
        codeword = self.norm(x)
        x = self.d4(codeword)
        x = self.d2(x)
        return x, codeword


# prepare the datasets


if DATASET == "mnist":
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
elif DATASET == "cifar":
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()


# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images.astype('float32') / 255.0, test_images.astype('float32') / 255.0
if DATASET == "mnist":
    train_images, train_poison_list = data_poison(train_images, NUM_POISON_TRAIN)  # train images are poisoned
else:
    train_images, train_poison_list = data_poison_cifar(train_images, NUM_POISON_TRAIN)  # train images are poisoned
# Per-sample boolean mask over the training set. Membership is checked against
# a set; checking against the list made building the mask O(N * P).
_train_poison_ids = set(train_poison_list)
train_need_poison_list = [indx in _train_poison_ids for indx in range(len(train_images))]
clean_test_images = copy.deepcopy(test_images)
if DATASET == "mnist":
    test_images, test_poison_list = data_poison(test_images, 10000)
else:
    test_images, test_poison_list = data_poison_cifar(test_images, 10000)


# Indices of unpoisoned training samples. Derived from the real training-set
# size rather than a hard-coded 60000, which is only correct for MNIST
# (CIFAR-10 has 50000 training images).
train_clean_list = list(set(range(len(train_images))) - _train_poison_ids)
if DATASET == "mnist":
    # Add a trailing channel axis to the grayscale images.
    train_images = train_images[:, :, :, np.newaxis]
    test_images = test_images[:, :, :, np.newaxis]
    clean_test_images = clean_test_images[:, :, :, np.newaxis]

_test_poison_ids = set(test_poison_list)
test_need_poison_list = [indx in _test_poison_ids for indx in range(len(test_images))]
test_backdoor_images = test_images[test_need_poison_list]
test_backdoor_labels = copy.deepcopy(test_labels[test_need_poison_list])  # overwritten just below
sample_id_need_copy = 1
feat_need_copy = copy.deepcopy(train_images[sample_id_need_copy, feature_div[-2]:feature_div[-1]])
# Every backdoor test sample is labelled with the label of training sample
# `sample_id_need_copy`; backdoor accuracy measures how often that label wins.
test_backdoor_labels[:] = train_labels[sample_id_need_copy]
print('the label of the sample need copy = ', train_labels[sample_id_need_copy])
print(test_backdoor_labels)



# prepare the models
# One passive (client) model per slice of the images along axis 1, with slice
# boundaries from feature_div. Each model restores weights from the matching
# models_path entry and embeds the full training set.
LocalModels = []
DataDivs = []
LocalEmbeddings = []
PassiveModelClass = VFLPassiveModel if DATASET == "mnist" else VFLPassiveModelCIFAR

for i, (lo, hi) in enumerate(zip(feature_div[:-1], feature_div[1:])):
    client_slice = train_images[:, lo:hi]
    client_model = PassiveModelClass()
    client_model.build(client_slice.shape)
    client_model.load_weights(models_path[i])
    LocalModels.append(client_model)
    DataDivs.append(client_slice)
    LocalEmbeddings.append(client_model(client_slice))

# Concatenated client embeddings for the whole training set.
H_input = tf.concat(LocalEmbeddings, axis=1)



# Per-client reconstruction loss used by the RAE purification routines.
loss_MSE = tf.keras.losses.MeanSquaredError()

# Restore the pretrained RAE.
input_shape = (BATCH_SIZE, RAE_out_dim)
RAEModel = RAE()
RAEModel.build(input_shape)
RAEModel.load_weights(RAE_SavedModel)

# Load the saved optimal L matrix. The residual S = H_input - Low is fed to
# training batch by batch as the `outliers` term.
with open(OptimalL_SavedModel, 'rb') as low_file:
    Low_np = np.load(low_file)
Low = tf.convert_to_tensor(Low_np, dtype=tf.float32)
S = H_input - Low

def _purify(batch_embedding, rae_model, init, epochs, is_vis=False):
    """Fit a latent input L so that rae_model(L) reconstructs the client embeddings.

    Shared implementation of RAE_split and RAE_split_init. L starts at `init`
    and is updated with plain SGD (learning rate Purify_learning_rate) for
    `epochs` steps. The loss is the sum over clients of
    sqrt(MSE(reconstruction slice, client embedding)).

    Args:
        batch_embedding: list of per-client embedding tensors, in client order.
        rae_model: the RAE; called as ``rae_model(L) -> (reconstruction, codeword)``.
        init: initial value for L (concatenated width RAE_out_dim).
        epochs: number of optimization steps.
        is_vis: if True, print the loss at every step.

    Returns:
        The reconstruction from the last step's forward pass, computed before
        that step's update was applied. If epochs <= 0, returns the
        reconstruction of `init` instead of failing on an unbound name.
    """
    L = tf.Variable(init, trainable=True)
    # L is a new variable on every call, so it gets its own optimizer.
    L_optimizer = tf.keras.optimizers.SGD(learning_rate=Purify_learning_rate)

    RAE_reconstruct = None
    for epoch in range(epochs):
        with tf.GradientTape() as l_tape:
            RAE_reconstruct, _codeword = rae_model(L)
            rec_split = tf.split(RAE_reconstruct, num_clients, 1)
            loss = 0
            for i in range(num_clients):
                loss += tf.sqrt(loss_MSE(rec_split[i], batch_embedding[i]))
        # Differentiate and update outside the tape so these ops are not recorded.
        L_gradients = l_tape.gradient(loss, [L])
        L_optimizer.apply_gradients(zip(L_gradients, [L]))
        if is_vis:
            print('Epoch {}, purify L Loss: {}'.format(epoch+1, loss.numpy()))

    if RAE_reconstruct is None:
        RAE_reconstruct, _codeword = rae_model(L)
    return RAE_reconstruct


def RAE_split(batch_embedding, RAEModel, epochs=20):
    """Purify client embeddings with the RAE, starting L at their concatenation.

    Returns the RAE reconstruction (concatenated over clients); callers split
    it back into per-client tensors with tf.split.
    """
    h = np.concatenate(tuple(batch_embedding), axis=1)
    return _purify(batch_embedding, RAEModel, h, epochs)


def RAE_split_init(batch_embedding, RAEModel, outliers, epochs=20, is_vis=False):
    """Purify client embeddings with the RAE, starting L at concat(embeddings) - outliers.

    `outliers` is subtracted from the concatenated embeddings to form the
    initial L. If `is_vis` is set, the loss is printed at every step.
    """
    h = np.concatenate(tuple(batch_embedding), axis=1)
    return _purify(batch_embedding, RAEModel, h - outliers, epochs, is_vis=is_vis)


# ---- Stage 4: retrain the server (active party) model ----
STAGE4_EPOCH = 100
new_server_model = VFLActiveModel()  # the server's model

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
# NOTE(review): `optimizer` is not used below; updates go through server_optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Running metrics, reset by the training loop.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)
test_loss, test_accuracy, test_label_accuracy = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_label_accuracy'),
)
backdoor_loss, backdoor_accuracy = (
    tf.keras.metrics.Mean(name='backdoor_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='backdoor_accuracy'),
)

# Server optimizer: SGD whose learning rate starts at 1e-2 and is multiplied
# by 0.99 every 10000 steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.99,
)
server_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

# Per-epoch history and training-loop state.
acc_train, acc_test, acc_backdoor = [], [], []
acc_test_label = [[] for _ in range(6)]
loss_train, loss_test, loss_backdoor = [], [], []
server_step = 0
best_acc = -1

def _client_embeddings(images):
    """Run each client's local model on its feature_div slice (axis 1) of `images`.

    Returns a list with one embedding tensor per client, in client order.
    """
    return [LocalModels[i](images[:, feature_div[i]:feature_div[i + 1]])
            for i in range(num_clients)]


def _eval_server_input(embeddings):
    """Return the per-client server inputs used at evaluation time.

    With USE_RAE, the embeddings are purified by RAE_split (20 steps) and split
    back per client. Otherwise they are returned unchanged.
    """
    if not USE_RAE:
        return embeddings
    return tf.split(RAE_split(embeddings, RAEModel, epochs=20), num_clients, 1)


# Batches of (images, labels, matching rows of S). Built once: tf.data
# reshuffles on every pass because reshuffle_each_iteration defaults to True.
# Rebuilding it each epoch only re-copied the whole training set.
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels, S)).shuffle(train_images.shape[0]).batch(BATCH_SIZE)

for epoch in range(STAGE4_EPOCH):
    index = 0
    start = time.time()
    total_loss = 0
    for images, labels, outliers in train_ds:
        server_step += 1
        index += 1
        # feature purifying, i.e., training stage 3, eq.(7)
        batch_embedding = _client_embeddings(images)

        if USE_RAE:
            # Print the purification loss for one batch per epoch (the 5th).
            RAE_output = RAE_split_init(batch_embedding, RAEModel, outliers, epochs=2,
                                        is_vis=(index == 5))
            RAE_output = tf.split(RAE_output, num_clients, 1)
        else:
            RAE_output = batch_embedding  # no RAE

        with tf.GradientTape() as active_tape:
            active_output = new_server_model(RAE_output)
            loss = loss_object(labels, active_output)

        server_vars = new_server_model.trainable_variables
        # The non-trainable head (Server_Trainable=False) has no variables;
        # apply_gradients would raise "No gradients provided", so skip it then.
        if server_vars:
            server_model_gradients = active_tape.gradient(loss, server_vars)
            server_optimizer.apply_gradients(zip(server_model_gradients, server_vars))
        total_loss += loss.numpy()
        train_loss(loss)
        train_accuracy(labels, active_output)
        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('Server_loss', loss, step=server_step)

    print('Epoch {},  Loss: {}, Accuracy: {}'.format(epoch+1, total_loss/index, train_accuracy.result()))
    new_server_model.save_weights(os.path.join(directory, 'new_server_ckpt.tf'))

    # ---- Backdoor evaluation: every sample should hit the copied label ----
    H_rec = _eval_server_input(_client_embeddings(test_backdoor_images))
    active_output = new_server_model(H_rec)

    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()
    backdoor_loss(loss_object(test_backdoor_labels, active_output))
    backdoor_accuracy(test_backdoor_labels, active_output)
    acc_backdoor.append(backdoor_accuracy.result())

    # ---- Clean test evaluation ----
    H_rec2 = _eval_server_input(_client_embeddings(clean_test_images))
    active_output = new_server_model(H_rec2)

    test_loss(loss_object(test_labels, active_output))
    test_accuracy(test_labels, active_output)

    current_test_acc = float(test_accuracy.result())
    if best_acc < current_test_acc:
        best_acc = current_test_acc
        new_server_model.save_weights(os.path.join(directory, 'best_server_ckpt.tf'))
    if IS_VIS:
        with train_summary_writer.as_default():
            tf.summary.scalar('Server_Test_clean', test_accuracy.result()*100, step=epoch)
            tf.summary.scalar('Server_Test_bkd', backdoor_accuracy.result()*100, step=epoch)

    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())

    template = 'Epoch {} Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100,
                          backdoor_accuracy.result()*100))

    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())
    # Rewrite the full per-epoch histories each epoch (values truncated to 6 chars).
    with open(os.path.join(directory, 'stage4_acc.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_test))
    with open(os.path.join(directory, 'stage4_bkd.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_backdoor))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    print("Training time for this epoch", time.time() - start)