# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import csv
import os
from io import BytesIO
import time
import numpy as np
import pandas as pd
from numpy import linalg
import requests
from PIL import Image
from sklearn.utils import shuffle
import copy
from nus_wide_data_util import *
from io import BytesIO
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle
import datetime
def get_local_outputs(num_img_clients, images, img_feature_div, 
        num_text_clients, texts, text_feature_div, LocalModels):
    """Run every client's local model on its own slice of the features.

    Image clients come first: client ``k`` receives the image columns
    ``img_feature_div[k]:img_feature_div[k+1]`` and uses ``LocalModels[k]``.
    Text clients follow: client ``k`` receives the text columns
    ``text_feature_div[k]:text_feature_div[k+1]`` and uses
    ``LocalModels[num_img_clients + k]``.

    Returns:
        A new list with one model output per client, image clients first.
    """
    image_outputs = [
        LocalModels[k](images[:, img_feature_div[k]:img_feature_div[k + 1]])
        for k in range(num_img_clients)
    ]
    text_outputs = [
        LocalModels[num_img_clients + k](
            texts[:, text_feature_div[k]:text_feature_div[k + 1]])
        for k in range(num_text_clients)
    ]
    return image_outputs + text_outputs

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
USE_RAE = True   # feed the server model RAE reconstructions instead of raw local embeddings
IS_VIS = False   # write tf.summary scalars under <directory>/<timestamp>/train
exp_type = 2     # total number of clients; only 2 and 4 are configured below
# exp_type= 4
trial = 5        # run index, part of every checkpoint / output path
EPOCHS = 100  # todo 200

# True  -> every client uses a `local_*` checkpoint and the `clean_*` RAE run.
# False -> the last text client uses the `poison_text_*_trigger-1` checkpoint.
is_clean_model = True

if exp_type == 2:
    # One image client (features 0..634) and one text client (features 0..1000).
    img_feature_div = [0, 634]
    text_feature_div = [0, 1000]
    SavedModels = [
        'nus_results/local_image_0_634',
        'nus_results/local_text_0_1000' if is_clean_model
        else 'nus_results/poison_text_0_1000_trigger-1',
    ]
    _rae_run = '2cli_rae_pretrain100_try{}'.format(trial)
    # Alternative pre-training run:
    # RAE_SavedModel='nus_results/2cli_rae_pretrain30/rae_ckpt.tf'
    # OptimalL_SavedModel='nus_results/2cli_rae_pretrain30/L.npy'
elif exp_type == 4:
    # Two image clients and two text clients, each owning a contiguous
    # slice of its modality's feature columns.
    img_feature_div = [0, 225, 634]
    text_feature_div = [0, 500, 1000]
    SavedModels = [
        'nus_results/local_image_0_225',
        'nus_results/local_image_225_634',
        'nus_results/local_text_0_500',
        'nus_results/local_text_500_1000' if is_clean_model
        else 'nus_results/poison_text_500_1000_trigger-1',
    ]
    _rae_run = '4cli_rae_225_pretrain200_try{}'.format(trial)
    # Alternative pre-training run:
    # RAE_SavedModel= 'nus_results/4cli_rae_225_pretrain200/rae_best_ckpt.tf'
    # OptimalL_SavedModel='nus_results/4cli_rae_225_pretrain200/opt_L.npy'
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        'unsupported exp_type {!r}; expected 2 or 4'.format(exp_type))

# Clean runs live in directories carrying a `clean_` prefix.
_clean_prefix = 'clean_' if is_clean_model else ''
RAE_SavedModel = 'nus_results/{}{}/rae_ckpt.tf'.format(_clean_prefix, _rae_run)
OptimalL_SavedModel = 'nus_results/{}{}/L.npy'.format(_clean_prefix, _rae_run)

num_img_clients = len(img_feature_div) - 1
num_text_clients = len(text_feature_div) - 1
num_clients = num_img_clients + num_text_clients

# Every local model ends in a 32-unit dense layer, so the concatenated
# embedding (and the RAE reconstruction of it) is 32 * num_clients wide.
RAE_out_dim = 32 * num_clients
Batch_size = 64

directory = './nus_results/{}new_{}cli_server_rae{}_try{}'.format(
    _clean_prefix, num_clients, USE_RAE, trial)
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(directory, exist_ok=True)
print("save to", directory)

# SGD step size used when optimising the purified embedding L.
Purify_learning_rate = 0.5

if IS_VIS:
    # NOTE(review): SummaryWriter is not referenced in the visible code;
    # summaries are written through tf.summary below.
    from tensorboardX import SummaryWriter
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(directory, current_time, 'train')
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)


# Width of the server model's softmax output.
class_num = 5

# Label names handed to get_labeled_data; echoed for the run log.
top_k = ['buildings', 'grass', 'animal', 'water', 'person']
print(top_k)

train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
print(type(train_X_image), type(train_X_text), type(train_Y))
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 40000, 'Test')
print(type(test_X_image), type(test_X_text), type(test_Y))

# Features are kept as an (image, text) pair of float32 arrays; labels as a
# single float32 array.
x_train = (np.array(train_X_image).astype('float32'),
           np.array(train_X_text).astype('float32'))
x_test = (np.array(test_X_image).astype('float32'),
          np.array(test_X_text).astype('float32'))
y_train = np.array(train_Y).astype('float32')
y_test = np.array(test_Y).astype('float32')

# Batch and shuffle the data
# NOTE(review): in the visible code `train_ds` is rebuilt before it is first
# read and `test_ds` is never read; both are kept in case other code uses them.
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(1024).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(32)

# Number of test samples whose label column 3 equals 1.
print(np.sum(y_test[:, 3] == 1))

class VFLPassiveModel(Model):
    """Local (passive-party) model: flatten, then a 32-unit ReLU dense layer."""

    def __init__(self):
        super().__init__()
        # Weights are restored into instances of this class via load_weights
        # elsewhere in the file, so attribute and layer names stay untouched.
        self.flatten = Flatten()
        self.d1 = Dense(32, activation='relu', name="dense1")

    def call(self, x):
        """Return the dense-layer output for the flattened input ``x``."""
        flat = self.flatten(x)
        return self.d1(flat)


class newVFLActiveModelWithOneLayer(Model):
    """Server (active-party) head: 32-unit ReLU layer followed by a softmax
    over ``class_num`` classes."""

    def __init__(self):
        super().__init__()
        # `concatenated` is created but not used by call(); it is kept so the
        # model's attributes stay exactly as before.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(32, activation='relu', name="dense1")
        self.out = Dense(class_num, activation='softmax', name="out")

        # (disabled) self.add_loss(tf.abs(self.w1)+tf.abs(self.w2))

    def call(self, x):
        """Map the input embedding ``x`` to class probabilities."""
        hidden = self.d1(x)
        return self.out(hidden)

(image_test, text_test) = x_test

# Test samples whose last text feature equals 1 form the "backdoor" subset
# (presumably that feature is the trigger -- confirm against the poisoning
# script).  The mask is computed once and shared by all three selections.
backdoor_mask = text_test[:, -1] == 1
image_backdoor = image_test[backdoor_mask]
text_backdoor = text_test[backdoor_mask]
# Independent copy: these labels are overwritten in place later on.
y_backdoor = copy.deepcopy(y_test[backdoor_mask])

# Sum of the last text feature over the train / test sets.
print(np.sum(x_train[1][:,-1])) # 152.0

print(np.sum(x_test[1][:,-1])) # 102.0



# Restore one pre-trained local model per client, in the same order as
# SavedModels (image clients first, then text clients).
LocalModels = []
for i in range(num_clients):
    local_model = VFLPassiveModel()
    local_model.load_weights(os.path.join(SavedModels[i], 'checkpoints'))
    LocalModels.append(local_model)

# x_train already holds the float32 (image, text) arrays built from
# train_X_image / train_X_text, so reuse them instead of converting the whole
# training set a second time.
TRAIN_image, TRAIN_text = x_train

# Embeddings of the full training set, one entry per client.
LocalOutputs = get_local_outputs(num_img_clients, TRAIN_image, img_feature_div,
                                 num_text_clients, TRAIN_text, text_feature_div,
                                 LocalModels)

print("LocalOutputs", len(LocalOutputs), LocalOutputs[0].shape)
# Client embeddings concatenated along the feature axis.
H_input = tf.concat(LocalOutputs, 1)
# Reconstruction loss used by RAE_split / RAE_split_ini.
loss_RAE = tf.keras.losses.MeanSquaredError()

def RAE_split(LocalOutputs, rae, epochs=20):
    """Reconstruct the concatenated client embeddings through the RAE.

    A variable ``L`` is initialised with the client embeddings concatenated
    along axis 1 and refined by SGD (step size ``Purify_learning_rate``) so
    that the RAE output for ``L``, split into ``num_clients`` equal parts,
    matches each client's original embedding.  The objective is the sum over
    clients of ``sqrt(MSE)``.

    Args:
        LocalOutputs: sequence with one embedding per client.
        rae: callable returning ``(reconstruction, hidden)`` for an input.
        epochs: number of SGD steps applied to ``L``.

    Returns:
        The RAE reconstruction from the last forward pass.  With
        ``epochs <= 0`` no step is taken and the reconstruction of the
        initial ``L`` is returned (this case used to raise
        ``UnboundLocalError``).
    """
    L = tf.Variable(np.concatenate(tuple(LocalOutputs), axis=1), trainable=True)
    L_optimizer = tf.keras.optimizers.SGD(learning_rate=Purify_learning_rate)

    RAE_output = None
    for _ in range(epochs):
        with tf.GradientTape() as passive_tape:
            RAE_output, _ = rae(L)
            rec_split = tf.split(RAE_output, num_clients, axis=1)
            loss = 0
            for i in range(num_clients):
                loss += tf.sqrt(loss_RAE(rec_split[i], LocalOutputs[i]))

        # Conventional placement after the tape context; the result is the same.
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        L_optimizer.apply_gradients(zip(passive_RAE_L_gradients, [L]))

    if RAE_output is None:
        # No optimisation step was requested.
        RAE_output, _ = rae(L)

    # NOTE(review): the returned tensor comes from the forward pass made
    # *before* the final update of L; kept as-is to preserve results.
    return RAE_output

def RAE_split_ini(LocalOutputs, rae, outliers, epochs=20, is_vis=False):
    """Reconstruct client embeddings through the RAE, starting from a
    pre-computed low-rank estimate.

    Works like ``RAE_split`` except that the optimised variable ``L`` starts
    at ``concat(LocalOutputs) - outliers`` rather than at the raw
    concatenation.  ``L`` is refined by SGD (step size
    ``Purify_learning_rate``) to minimise the sum over clients of
    ``sqrt(MSE)`` between each client's slice of the RAE output and that
    client's original embedding.

    Args:
        LocalOutputs: sequence with one embedding per client.
        rae: callable returning ``(reconstruction, hidden)`` for an input.
        outliers: term subtracted from the concatenated embeddings to form
            the initial ``L``; must be broadcast-compatible with them.
        epochs: number of SGD steps applied to ``L``.
        is_vis: when true, print the loss after every step.

    Returns:
        The RAE reconstruction from the last forward pass.  With
        ``epochs <= 0`` no step is taken and the reconstruction of the
        initial ``L`` is returned (this case used to raise
        ``UnboundLocalError``).
    """
    initial_L = np.concatenate(tuple(LocalOutputs), axis=1) - outliers
    L = tf.Variable(initial_L, trainable=True)
    L_optimizer = tf.keras.optimizers.SGD(learning_rate=Purify_learning_rate)

    RAE_output = None
    for epoch in range(epochs):
        with tf.GradientTape() as passive_tape:
            RAE_output, _ = rae(L)
            rec_split = tf.split(RAE_output, num_clients, axis=1)
            loss = 0
            for i in range(num_clients):
                loss += tf.sqrt(loss_RAE(rec_split[i], LocalOutputs[i]))

        # Conventional placement after the tape context; the result is the same.
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        L_optimizer.apply_gradients(zip(passive_RAE_L_gradients, [L]))
        if is_vis:
            print('Epoch {}, purify L Loss: {}'.format(epoch+1,loss.numpy()))

    if RAE_output is None:
        # No optimisation step was requested.
        RAE_output, _ = rae(L)

    # NOTE(review): the returned tensor comes from the forward pass made
    # *before* the final update of L; kept as-is to preserve results.
    return RAE_output



class RAE(Model):
    """Auto-encoder applied to the concatenated client embeddings.

    Forward pass: ``d3`` (64, ReLU) -> ``d1`` (64, ReLU) -> layer
    normalisation (scale only, no centering) -> ``d4`` (64, ReLU) ->
    ``d2`` (``RAE_out_dim``, linear).
    """

    def __init__(self):
        super(RAE, self).__init__()

        # Attribute names, layer names and creation order of the Dense layers
        # are left exactly as they were so existing checkpoints keep loading.
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_out_dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense1", activation='relu')
        self.d4 = Dense(64, name="dense1", activation='relu')
        # Created once here instead of on every call(): the previous code
        # built a fresh, untracked LayerNormalization (and a fresh scale
        # variable) per forward pass.  It is appended after the Dense layers
        # so their order is unchanged.  The scale starts at the layer's
        # default initial value, which is what each per-call layer used.
        self.norm = LayerNormalization(axis=-1, center=False, scale=True)

    def call(self, x):
        """Return ``(reconstruction, normalised_hidden)`` for input ``x``."""
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.norm(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def calculate_l21_rownorm(X):
    r"""
    This function calculates the l2,1 norm of a matrix X taken over its rows,
    i.e., \sum_i ||X[i,:]||_2
    Input:
    -----
    X: {numpy array}, 2-D
    Output:
    ------
    l21_norm: {float}
    """
    # Raw docstring: "\s" is an invalid escape sequence in a normal string.
    return (np.sqrt(np.multiply(X, X).sum(1))).sum()

def calculate_l21_colnorm(X):
    r"""
    This function calculates the l2,1 norm of a matrix X taken over its
    columns, i.e., \sum_j ||X[:,j]||_2
    Input:
    -----
    X: {numpy array}, 2-D
    Output:
    ------
    l21_norm: {float}
    """
    # Raw docstring: "\s" is an invalid escape sequence in a normal string.
    return (np.sqrt(np.multiply(X, X).sum(0))).sum()



# Learning-rate schedule for the server model's optimizer.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.99)

# Pre-trained RAE.
rae = RAE()
rae.load_weights(RAE_SavedModel)

# Saved estimate L for the training-set embeddings.  The earlier placeholder
# assignments to `Low` were overwritten before being read and are dropped.
with open(OptimalL_SavedModel, 'rb') as f:
    Low_np = np.load(f)
if tuple(Low_np.shape) != tuple(H_input.shape):
    # Without this check a mismatching file could broadcast silently.
    raise ValueError(
        'shape of {} is {}, expected {}'.format(
            OptimalL_SavedModel, tuple(Low_np.shape), tuple(H_input.shape)))
Low = tf.convert_to_tensor(Low_np, dtype=tf.float32)

# Per-sample residual between the raw embeddings and the saved estimate;
# passed to RAE_split_ini as `outliers` during training.
S = H_input - Low

# Server model to be trained, with its loss and optimizer.
new_active_model = newVFLActiveModelWithOneLayer()
print(new_active_model)

loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer3 = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

# Running metrics for the train / test / backdoor splits.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
test_label_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_label_accuracy')

backdoor_loss = tf.keras.metrics.Mean(name='backdoor_loss')
backdoor_accuracy = tf.keras.metrics.CategoricalAccuracy(name='backdoor_accuracy')

# Every backdoor sample gets the label of training sample 369, so the
# backdoor accuracy measures how often that label is predicted.
sample_id_need_copy = 369
text_feat_need_copy = x_train[1][sample_id_need_copy].copy()
y_backdoor[:] = y_train[sample_id_need_copy]

# Per-epoch history.
acc_train, acc_test, acc_backdoor = [], [], []
acc_test_label = [[] for _ in range(6)]
loss_train, loss_test, loss_backdoor = [], [], []

server_step = 0   # global batch counter, used as the summary step
best_acc = -1     # best test accuracy so far
best_acc= -1
# Build the input pipeline once instead of re-slicing the full arrays every
# epoch.  With reshuffle_each_iteration=True the shuffle order is redrawn each
# time the dataset is iterated, so every epoch still sees a new permutation.
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train, S)).shuffle(
        y_train.shape[0], reshuffle_each_iteration=True).batch(Batch_size)

for epoch in range(EPOCHS):
    # Only reported in the log line below; nothing is poisoned in this loop.
    number_of_poison = 0
    start = time.time()

    # ------------------------------------------------------------------
    # Train the server model for one epoch (local models and RAE frozen).
    # ------------------------------------------------------------------
    for index, ((images, texts), labels, outliers) in enumerate(train_ds, start=1):
        server_step += 1

        LocalOutputs = get_local_outputs(num_img_clients, images, img_feature_div,
                                         num_text_clients, texts, text_feature_div,
                                         LocalModels)
        if USE_RAE:
            # The purification loss is printed for the 5th batch of each epoch.
            RAE_output = RAE_split_ini(LocalOutputs, rae, outliers,
                                       epochs=2, is_vis=(index == 5))
        else:
            RAE_output = np.concatenate(tuple(LocalOutputs), axis=1)

        with tf.GradientTape() as active_tape:
            active_output = new_active_model(RAE_output)
            loss = loss_object(labels, active_output)

        active_model_gradients = active_tape.gradient(
            loss, new_active_model.trainable_variables)
        optimizer3.apply_gradients(
            zip(active_model_gradients, new_active_model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, active_output)
        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('Server_loss', loss, step=server_step)
    new_active_model.save_weights(os.path.join(directory, 'new_server_ckpt.tf'))

    # ------------------------------------------------------------------
    # Evaluate on the backdoor subset (labels were set to the target label).
    # ------------------------------------------------------------------
    LocalOutputs = get_local_outputs(num_img_clients, image_backdoor, img_feature_div,
                                     num_text_clients, text_backdoor, text_feature_div,
                                     LocalModels)
    if USE_RAE:
        H_rec = RAE_split(LocalOutputs, rae)
    else:
        H_rec = np.concatenate(tuple(LocalOutputs), axis=1)

    active_output = new_active_model(H_rec)

    backdoor_loss.reset_states()
    backdoor_accuracy.reset_states()

    backdoor_loss(loss_object(y_backdoor, active_output))
    backdoor_accuracy(y_backdoor, active_output)
    acc_backdoor.append(backdoor_accuracy.result())

    # ------------------------------------------------------------------
    # Evaluate on the full test set.
    # ------------------------------------------------------------------
    LocalOutputs = get_local_outputs(num_img_clients, image_test, img_feature_div,
                                     num_text_clients, text_test, text_feature_div,
                                     LocalModels)
    if USE_RAE:
        H_rec2 = RAE_split(LocalOutputs, rae)
    else:
        H_rec2 = np.concatenate(tuple(LocalOutputs), axis=1)

    active_output = new_active_model(H_rec2)

    t_loss = loss_object(y_test, active_output)
    test_loss(t_loss)
    test_accuracy(y_test, active_output)
    # Keep a separate checkpoint of the best test accuracy seen so far.
    if best_acc < test_accuracy.result():
        best_acc = test_accuracy.result()
        new_active_model.save_weights(os.path.join(directory, 'best_server_ckpt.tf'))
    if IS_VIS:
        with train_summary_writer.as_default():
            tf.summary.scalar('Server_Test_clean', test_accuracy.result()*100, step=epoch)
            tf.summary.scalar('Server_Test_bkd', backdoor_accuracy.result()*100, step=epoch)

    loss_train.append(train_loss.result())
    loss_test.append(test_loss.result())
    loss_backdoor.append(backdoor_loss.result())

    template = 'Epoch {}, Poisoned {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}, Backdoor Accuracy: {}'
    print(template.format(epoch+1,
                          number_of_poison,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100,
                          backdoor_accuracy.result()*100))

    acc_train.append(train_accuracy.result())
    acc_test.append(test_accuracy.result())

    # The full history is rewritten every epoch so the files stay complete
    # even if the run is interrupted.
    with open(os.path.join(directory, 'stage4_acc.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_test))
    with open(os.path.join(directory, 'stage4_bkd.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_backdoor))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    # Wall-clock seconds for this epoch, training plus evaluation.
    end = time.time()
    print(end - start)
