# Load the TensorBoard notebook extension
# %load_ext tensorboard

import os
import tempfile
import copy
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
from model import VFLPassiveModelCIFAR, VFLPassiveModel
from utils import calculate_l21_blocknorm,calculate_l21_rownorm,calculate_l21_colnorm, data_poison, get_poisoned_matrix,copy_grad,need_poison_trigger_check, data_poison_cifar
import time

import datetime



# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
DATASET = "cifar"  # "mnist" or "cifar"; selects the data loader and local model class

NUM_POISON_TRAIN = 1200  # forwarded to data_poison / data_poison_cifar
BATCH_SIZE = 128
EMB_DIM = 10  # assumed width of one client's embedding -- confirm against the local models
Pretrain_EPOCH = 100  # epochs of plain reconstruction pre-training
STAGE2_EPOCH = 200  # epochs of the alternating RAE / L updates


# Boundaries used to slice axis 1 of the training images into one block per client.
feature_div = [0, 9, 19, 32]
# feature_div = [0, 16, 32]  # two-client alternative
num_clients = len(feature_div) - 1
# The auto-encoder reconstructs the concatenated client embeddings, so its
# output width is derived from EMB_DIM rather than repeating the literal 10.
RAE_out_dim = EMB_DIM * num_clients
class_num = 10
trial = 5  # run index, only used in the output directory name
is_clean_model = False  # True selects the checkpoints without "poison" in their path

# Select the pretrained per-client checkpoints and the output directory.
# Only the last client's checkpoint differs between the clean and the poisoned
# setting.  NOTE: every path is CIFAR-specific regardless of DATASET.
if num_clients == 2:
    poison_tag = 'p600'
    models_path = [
        'img_results/cifar_qtrain_0_16/checkpoints',
        'img_results/cifar_qtrain_16_32/checkpoints'
        if is_clean_model
        else 'img_results/debug_cifar_qtrain_poison_16_32_p600/checkpoints',
    ]
elif num_clients == 3:
    poison_tag = 'p100'
    models_path = [
        'img_results/cifar_qtrain_0_9/best_ckpt',
        'img_results/cifar_qtrain_9_19/best_ckpt',
        'img_results/cifar_qtrain_19_32/best_ckpt'
        if is_clean_model
        else 'img_results/debug_cifar_qtrain_poison_19_32_p100/best_ckpt',
    ]
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        'No pretrained checkpoints configured for num_clients={} '
        '(supported: 2 or 3)'.format(num_clients))

# Directory name: optional "clean" prefix, dataset, client count, poison tag,
# number of pre-training epochs and trial index.
directory = './rae_results/{}{}_{}cli_rae_{}_pre{}_try{}'.format(
    'clean' if is_clean_model else '', DATASET, num_clients, poison_tag,
    Pretrain_EPOCH, trial)


# Create the output directory; exist_ok avoids the check-then-create race.
os.makedirs(directory, exist_ok=True)
print("save to", directory)
IS_VIS = False  # set to True to log training scalars via tf.summary


if IS_VIS:
    # NOTE(review): SummaryWriter is not referenced in the visible code; the
    # writer below comes from tf.summary.  Confirm before dropping this import.
    from tensorboardX import SummaryWriter
    # Each run logs into its own timestamped sub-directory.
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(directory, current_time, 'train')
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    # A separate test writer (<directory>/<time>/test) is currently disabled.


class RAE(Model):
    """Dense auto-encoder that also exposes its normalized hidden code.

    ``call`` feeds the input through ``d3`` and ``d1``, layer-normalizes the
    result (the "codeword"), then through ``d4`` and ``d2`` to produce a
    reconstruction of width ``RAE_out_dim``.
    """

    def __init__(self):
        super(RAE, self).__init__()

        # Attribute names and creation order are unchanged, so ``layers[0..3]``
        # still refer to d1, d2, d3, d4.  d3/d4 get their own layer names
        # instead of sharing "dense1" with d1.
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_out_dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense3", activation='relu')
        self.d4 = Dense(64, name="dense4", activation='relu')
        # The normalization used to be constructed inside ``call``, which built
        # a new, untracked layer on every forward pass whose scale therefore
        # always stayed at its initial value.  Creating it once here with
        # scale=False keeps that effective behaviour (pure normalization, no
        # extra trainable weights) without the per-call layer allocation.
        self.norm = LayerNormalization(axis=-1, center=False, scale=False)

    def call(self, x):
        """Return ``(reconstruction, codeword)`` for the input batch ``x``."""
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.norm(x)  # codeword: layer-normalized hidden activation
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2



# Prepare the datasets
if DATASET == "mnist":
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
elif DATASET == "cifar":
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
else:
    # Without this branch an unknown name only surfaced later as a NameError.
    raise ValueError(
        'Unsupported DATASET {!r}; expected "mnist" or "cifar"'.format(DATASET))


# Scale pixel values to the range [0, 1]
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Poison the training images; the helper also returns the affected indices.
if DATASET == "mnist":
    train_images, train_poison_list = data_poison(train_images, NUM_POISON_TRAIN)
else:
    train_images, train_poison_list = data_poison_cifar(train_images, NUM_POISON_TRAIN)

# Per-sample boolean mask.  A set makes each membership test O(1) instead of
# scanning the whole index list for every training sample.
# (assumes train_poison_list holds integer sample indices -- confirm in utils)
_poisoned_indices = set(np.asarray(train_poison_list).ravel().tolist())
train_need_poison_list = [indx in _poisoned_indices for indx in range(len(train_images))]


if DATASET == "mnist":
    # Append a trailing axis to the MNIST image array.
    train_images = train_images[:, :, :, np.newaxis]

# Build one pretrained local model per client and collect its embeddings.
LocalModels = []
DataDivs = []
LocalEmbeddings = []
for client_idx, (lo, hi) in enumerate(zip(feature_div[:-1], feature_div[1:])):
    # Client `client_idx` owns the slice [lo, hi) along axis 1 of the images.
    client_data = train_images[:, lo:hi]
    print(client_data.shape)

    client_model = VFLPassiveModel() if DATASET == "mnist" else VFLPassiveModelCIFAR()
    client_model.build(client_data.shape)
    # models_path entries are used as-is (no extra 'checkpoints' suffix).
    client_model.load_weights(models_path[client_idx])
    # Embed this client's whole training slice in a single forward pass.
    client_embedding = client_model(client_data)

    LocalModels.append(client_model)
    DataDivs.append(client_data)
    LocalEmbeddings.append(client_embedding)

# Concatenate the per-client embeddings along axis 1.
H_input = tf.concat(LocalEmbeddings, axis=1)

# Train the robust auto-encoder by alternately updating the matrix L (held in
# ``Low``) and the parameters of the auto-encoder.
# Low starts out as the concatenated embeddings themselves; the previous
# zero-initialisation was overwritten immediately and has been removed.
Low = H_input
# Learning rate starts at 1e-3 and is multiplied by 0.99 every 10000 steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=10000,
    decay_rate=0.99)

# Optimizer used for the pre-training stage.
RAE_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)


# Only needed by the disabled warm start below.
input_shape = (BATCH_SIZE, EMB_DIM * num_clients)
RAEModel = RAE()

# Optional warm start from an earlier run (disabled):
# RAEModel.build(input_shape)
# RAEModel.load_weights(
#     os.path.join( './nontrain_results/','stage2_4clients0.01/','checkpoints'))


loss_RAE = tf.keras.losses.MeanSquaredError()
# Per-epoch average losses collected during stage 2.
RAE_LOSS = []
L_LOSS = []

# Global step counters, used as the x-axis of the tf.summary scalars.
rae_step = 0
l_step = 0
rae_pretrain_step = 0
# Stage 1: pre-train the auto-encoder to reconstruct Low with a plain MSE loss.
# H_input, train_labels and Low are not modified in this stage, so the batched
# dataset is built once instead of once per epoch.
train_ds = tf.data.Dataset.from_tensor_slices(
    (H_input, train_labels, Low)).batch(BATCH_SIZE)

for epoch in range(Pretrain_EPOCH):
    start = time.time()

    # Only the `low` component of each batch is used during pre-training.
    for h_input, labels, low in train_ds:
        rae_pretrain_step += 1
        with tf.GradientTape() as passive_tape:
            RAE_output, layer_output = RAEModel(low)
            # MSE is symmetric in its two arguments, so the order does not
            # change the loss value.
            loss = loss_RAE(RAE_output, low)

        passive_RAE_L_gradients = passive_tape.gradient(loss, RAEModel.trainable_variables)
        RAE_optimizer.apply_gradients(zip(passive_RAE_L_gradients, RAEModel.trainable_variables))

        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('pretrain_rae', loss, step=rae_pretrain_step)

    # `loss` is the loss of the last batch of the epoch.
    print("epoch", epoch, "loss", loss , "time: ",time.time()- start, "step: ", rae_pretrain_step )
    # Overwrite the pre-training checkpoint after every epoch.
    RAEModel.save_weights(os.path.join(directory, 'pre_rae_ckpt.tf'))

def _write_loss_log(path, values):
    """Write one loss value per line to *path* with six significant digits.

    Formatting with ``.6g`` keeps the exponent of very small or very large
    values; slicing ``str(value)`` to six characters silently dropped it.
    """
    with open(path, "w") as outfile:
        outfile.write("\n".join("{:.6g}".format(float(item)) for item in values))


# Stage 2: alternate between (a) updating the auto-encoder with Low fixed and
# (b) updating Low with the auto-encoder fixed.  Both stages get fresh SGD
# optimizers that share the learning-rate schedule.
RAE_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
L_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
# Start at +inf so the first finite epoch always produces a "best" checkpoint.
best_loss = float('inf')
for epoch in range(STAGE2_EPOCH):
    start = time.time()
    # Low changes at the end of every epoch, so the dataset is rebuilt from it.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (H_input, train_labels, Low)).batch(BATCH_SIZE)

    # (a) Train the RAE: codeword penalty + 0.1 * reconstruction MSE.
    total_RAE_loss = 0
    for h_input, labels, low in train_ds:
        with tf.GradientTape() as rae_tape:
            # The model returns the reconstruction and the intermediate code.
            RAE_reconstruct, codeword = RAEModel(low)
            _RAE_loss_codeword = calculate_l21_colnorm(codeword)
            _RAE_loss_rae = loss_RAE(RAE_reconstruct, low) * 0.1
            loss = _RAE_loss_codeword + _RAE_loss_rae
        rae_step += 1
        RAE_gradients = rae_tape.gradient(loss, RAEModel.trainable_variables)
        # Update the parameters of the auto-encoder.
        RAE_optimizer.apply_gradients(zip(RAE_gradients, RAEModel.trainable_variables))
        total_RAE_loss += loss.numpy()

        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('RAE_loss_codeword', _RAE_loss_codeword, step=rae_step)
                tf.summary.scalar('RAE_loss_rae', _RAE_loss_rae, step=rae_step)
                tf.summary.scalar('RAE_loss', loss, step=rae_step)
                tf.summary.scalar('Norm_D1', tf.norm(RAEModel.layers[0].kernel).numpy(), step=rae_step)
                tf.summary.scalar('Norm_D2', tf.norm(RAEModel.layers[1].kernel).numpy(), step=rae_step)
                tf.summary.scalar('Norm_D3', tf.norm(RAEModel.layers[2].kernel).numpy(), step=rae_step)
                tf.summary.scalar('Norm_D4', tf.norm(RAEModel.layers[3].kernel).numpy(), step=rae_step)

    # (b) Update L: one gradient step per batch on a copy of that batch of Low.
    index_batch = 0
    Low_np = Low.numpy()  # receives the updated batches below
    total_L_loss = 0
    for h_input, labels, low in train_ds:
        index_batch += 1
        with tf.GradientTape() as L_tape:
            # NOTE(review): a new variable is handed to L_optimizer on every
            # batch; this relies on plain SGD keeping no per-variable state --
            # confirm for the TensorFlow version in use.
            L = tf.Variable(low, trainable=True)
            RAE_reconstruct, codeword = RAEModel(L)
            _L_loss_codeword = calculate_l21_colnorm(codeword)
            _L_loss_rae = loss_RAE(RAE_reconstruct, L) * 0.1
            # Penalise the residual between the original embeddings and L.
            _L_loss_h = calculate_l21_rownorm(h_input - L) * 0.1
            loss = _L_loss_codeword + _L_loss_rae + _L_loss_h
        l_step += 1
        L_gradients = L_tape.gradient(loss, [L])
        L_optimizer.apply_gradients(zip(L_gradients, [L]))  # update L
        total_L_loss += loss.numpy()
        # Write the updated rows back; the last batch may be shorter.
        batch_start = (index_batch - 1) * BATCH_SIZE
        batch_end = min(index_batch * BATCH_SIZE, Low_np.shape[0])
        Low_np[batch_start:batch_end, :] = L.numpy()
        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('L_loss_codeword', _L_loss_codeword, step=l_step)
                tf.summary.scalar('L_loss_h', _L_loss_h, step=l_step)
                tf.summary.scalar('L_loss_rae', _L_loss_rae, step=l_step)
                tf.summary.scalar('L_loss', loss, step=l_step)

    Low = tf.convert_to_tensor(Low_np, dtype=tf.float32)  # Low is updated

    # Evaluate the combined objective on the full, updated Low.
    RAE_output, layer_output = RAEModel(Low)
    stage2_loss = calculate_l21_colnorm(layer_output)+0.1*loss_RAE(RAE_output,Low)+0.1*calculate_l21_rownorm(H_input-Low)


    print('Epoch {}, stage2 loss: {} avg RAE Loss: {}, avg L Loss: {}'.format(epoch+1,stage2_loss, total_RAE_loss/index_batch,total_L_loss/index_batch))
    print('time', time.time()-start, "step: ",  rae_step,  l_step )

    RAE_LOSS.append(total_RAE_loss/index_batch)
    L_LOSS.append(total_L_loss/index_batch)
    if IS_VIS:
        with train_summary_writer.as_default():
            tf.summary.scalar('Stage2_loss', stage2_loss, step=epoch)

    # Save the latest weights, the latest L and the loss histories every epoch.
    RAEModel.save_weights(os.path.join(directory, 'checkpoints'))
    with open(os.path.join(directory,'L.npy'), 'wb') as f:
        np.save(f, Low.numpy())

    _write_loss_log(os.path.join(directory, 'stage2_rae_loss.txt'), RAE_LOSS)
    _write_loss_log(os.path.join(directory, 'stage2_l_loss.txt'), L_LOSS)


    # Keep a separate copy of the best model / L seen so far.  The comparison
    # uses a plain float so best_loss never turns into a tensor.
    stage2_loss_value = float(stage2_loss)
    if stage2_loss_value < best_loss:
        best_loss = stage2_loss_value
        last_save_epoch = epoch
        RAEModel.save_weights(os.path.join(directory, 'rae_best_ckpt.tf'))
        with open(os.path.join(directory,'opt_L.npy'), 'wb') as f:
            np.save(f, Low.numpy())
