# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import csv
import os
from io import BytesIO
import time
import numpy as np
import pandas as pd
from numpy import linalg
import requests
from PIL import Image
from sklearn.utils import shuffle

import datetime
import copy
from nus_wide_data_util import *
from io import BytesIO
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle


# ---------------------------------------------------------------------------
# Experiment configuration.
#
# trial          -- run index; appears in the output directory name.
# exp_type       -- 2 or 4; selects one of the feature-split layouts below.
# is_clean_model -- True: load only 'local_*' checkpoints.
#                   False: the last text checkpoint is replaced by the
#                   'poison_text_*' one.
#
# Each branch defines:
#   img_feature_div / text_feature_div -- column boundaries of the feature
#       slices; consecutive pairs [a, b) give one slice each.
#   Pretrain_EPOCH / STAGE2_EPOCH      -- epoch counts of the two phases.
#   SavedModels                        -- checkpoint directories, image
#       slices first, then text slices.
#   directory                          -- where results are written.
# ---------------------------------------------------------------------------
trial = 5
exp_type = 4
is_clean_model = True

if exp_type == 2:
    # One image slice and one text slice.
    img_feature_div = [0, 634]
    text_feature_div = [0, 1000]

    Pretrain_EPOCH = 100
    STAGE2_EPOCH = 100
    if is_clean_model:
        SavedModels = [
            'nus_results/local_image_0_634',
            'nus_results/local_text_0_1000',
        ]
        directory = './nus_results/clean_2cli_rae_pretrain{}_try{}'.format(Pretrain_EPOCH, trial)
    else:
        SavedModels = [
            'nus_results/local_image_0_634',
            'nus_results/poison_text_0_1000_trigger-1',
        ]
        directory = './nus_results/2cli_rae_pretrain{}_try{}'.format(Pretrain_EPOCH, trial)
elif exp_type == 4:
    # Two image slices and two text slices.
    img_feature_div = [0, 225, 634]
    text_feature_div = [0, 500, 1000]
    Pretrain_EPOCH = 200
    STAGE2_EPOCH = 100

    if is_clean_model:
        SavedModels = [
            'nus_results/local_image_0_225',
            'nus_results/local_image_225_634',
            'nus_results/local_text_0_500',
            'nus_results/local_text_500_1000',
        ]
        directory = './nus_results/clean_4cli_rae_225_pretrain{}_try{}'.format(Pretrain_EPOCH, trial)
    else:
        SavedModels = [
            'nus_results/local_image_0_225',
            'nus_results/local_image_225_634',
            'nus_results/local_text_0_500',
            'nus_results/poison_text_500_1000_trigger-1',
        ]
        directory = './nus_results/4cli_rae_225_pretrain{}_try{}'.format(Pretrain_EPOCH, trial)
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        "Unsupported exp_type {!r}; expected 2 or 4".format(exp_type))
    


# Number of feature slices per modality: N boundaries describe N-1 slices.
num_img_clients = len(img_feature_div) - 1
num_text_clients = len(text_feature_div) - 1
num_clients = num_img_clients + num_text_clients

# exist_ok=True replaces the separate exists() check: no window between the
# check and the creation, and nothing happens when the directory is present.
os.makedirs(directory, exist_ok=True)
IS_VIS = False  # set True to log scalars through tf.summary
print("save to", directory)

Batch_size = 64
# Every VFLPassiveModel ends in Dense(32), and the per-client outputs are
# concatenated along axis 1, so the autoencoder reconstructs 32 values per
# client.
RAE_out_dim = 32 * num_clients

if IS_VIS:
    # NOTE(review): SummaryWriter is not referenced in the visible code; the
    # writer actually used below comes from tf.summary. Kept in case another
    # part of the project relies on the import.
    from tensorboardX import SummaryWriter
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(directory, current_time, 'train')
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)


class_num = 5

# Label names handed to get_labeled_data; class_num equals their count.
top_k = ['buildings', 'grass', 'animal', 'water', 'person']
print(top_k)

train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
print(type(train_X_image), type(train_X_text), type(train_Y))
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 10000, 'Test')
print(type(test_X_image), type(test_X_text), type(test_Y))

# x_* are (image_features, text_features) pairs; everything is float32.
x_train = (np.array(train_X_image).astype('float32'),
           np.array(train_X_text).astype('float32'))
x_test = (np.array(test_X_image).astype('float32'),
          np.array(test_X_text).astype('float32'))
y_train = np.array(train_Y).astype('float32')
y_test = np.array(test_Y).astype('float32')

# Batch and shuffle the data
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(1024).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(32)

# Number of test rows whose label column 3 equals 1. The former bare
# `np.sum(y_test, axis=0)` statement discarded its result and was removed.
print(np.count_nonzero(y_test[:, 3] == 1))

class VFLPassiveModel(Model):
    """Local encoder: flattens its input and projects it to 32 ReLU features."""

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(32, activation='relu', name="dense1")

    def call(self, x):
        """Return the 32-feature embedding of `x`."""
        return self.d1(self.flatten(x))


class newVFLActiveModelWithOneLayer(Model):
    """Classifier head: Dense(32, relu) followed by a softmax over class_num classes."""

    def __init__(self):
        super().__init__()
        # Not used by call(); the caller is expected to pass already
        # concatenated features.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(32, activation='relu', name="dense1")
        self.out = Dense(class_num, activation='softmax', name="out")

    def call(self, x):
        """Return class probabilities for the feature batch `x`."""
        hidden = self.d1(x)
        return self.out(hidden)

image_test, text_test = x_test

TRAIN_image = np.array(train_X_image).astype('float32')
TRAIN_text = np.array(train_X_text).astype('float32')

# Restore one local encoder per client from its saved checkpoint.
LocalModels = []
for client_idx in range(num_clients):
    ckpt_prefix = os.path.join(SavedModels[client_idx], 'checkpoints')
    local_model = VFLPassiveModel()
    local_model.load_weights(ckpt_prefix)
    LocalModels.append(local_model)

# Run every encoder on its own column slice of the training features:
# image clients come first, text clients follow.
LocalOutputs = []
image_bounds = zip(img_feature_div[:-1], img_feature_div[1:])
for model_idx, (start_col, end_col) in enumerate(image_bounds):
    LocalOutputs.append(LocalModels[model_idx](TRAIN_image[:, start_col:end_col]))

text_bounds = zip(text_feature_div[:-1], text_feature_div[1:])
for model_idx, (start_col, end_col) in enumerate(text_bounds, start=num_img_clients):
    LocalOutputs.append(LocalModels[model_idx](TRAIN_text[:, start_col:end_col]))

print("LocalOutputs", len(LocalOutputs), LocalOutputs[0].shape)
# Concatenate the per-client embeddings along the feature axis.
H_input = tf.concat(LocalOutputs, 1)
loss_RAE = tf.keras.losses.MeanSquaredError()



class RAE(Model):
    """Autoencoder applied to the concatenated local embeddings.

    call(x) returns the pair ``(reconstruction, code)`` where

        code           = LayerNorm(d1(d3(x)))   # 64 features
        reconstruction = d2(d4(code))           # RAE_out_dim features
    """

    def __init__(self):
        super(RAE, self).__init__()

        # Python attribute names are unchanged. Only the Keras layer names
        # were made unique (three layers used to be called "dense1").
        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(RAE_out_dim, name="dense2", activation=None)
        self.d3 = Dense(64, name="dense3", activation='relu')
        self.d4 = Dense(64, name="dense4", activation='relu')

        # The normalisation layer used to be constructed inside call() with
        # scale=True. That built a brand-new layer (and a fresh gamma
        # initialised to ones) on every forward pass, so gamma was never
        # trained and never saved. Building it once with scale=False yields
        # the same values without allocating a variable per call and keeps
        # the set of trained/saved variables exactly as before.
        # NOTE(review): if a learnable scale was intended, switch to
        # scale=True here -- that changes training and checkpoint contents.
        self.code_norm = LayerNormalization(axis=-1, center=False, scale=False)

    def call(self, x):
        """Return ``(reconstruction, code)`` for the batch `x`."""
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.code_norm(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def calculate_l21_rownorm(X, eps=1e-12):
    r"""
    This function calculates the l21 norm of a matrix X, i.e., \sum ||X[i,:]||_2

    TensorFlow tensors/variables are handled with TensorFlow ops so the result
    stays differentiable under tf.GradientTape. The previous NumPy-only
    implementation converted tensors to arrays, which made this term a
    constant with no gradient.

    Input:
    -----
    X: {numpy array or tf tensor}, 2-D
    eps: {float} added under the square root on the tensor path only, so the
        gradient is finite (zero) for all-zero rows instead of NaN
    Output:
    ------
    l21_norm: {float} for array input, {scalar tf tensor} for tensor input
    """
    # The isinstance check comes first so plain ndarray input never touches
    # TensorFlow.
    if not isinstance(X, np.ndarray) and tf.is_tensor(X):
        if not X.dtype.is_floating:
            X = tf.cast(X, tf.float32)
        row_sq = tf.reduce_sum(tf.square(X), axis=1)
        return tf.reduce_sum(tf.sqrt(row_sq + eps))
    return (np.sqrt(np.multiply(X, X).sum(1))).sum()

def calculate_l21_colnorm(X, eps=1e-12):
    r"""
    This function calculates the l21 norm of a matrix X, i.e., \sum ||X[:,j]||_2

    TensorFlow tensors/variables are handled with TensorFlow ops so the result
    stays differentiable under tf.GradientTape. The previous NumPy-only
    implementation converted tensors to arrays, which made this term a
    constant with no gradient.

    Input:
    -----
    X: {numpy array or tf tensor}, 2-D
    eps: {float} added under the square root on the tensor path only, so the
        gradient is finite (zero) for all-zero columns instead of NaN
    Output:
    ------
    l21_norm: {float} for array input, {scalar tf tensor} for tensor input
    """
    # The isinstance check comes first so plain ndarray input never touches
    # TensorFlow.
    if not isinstance(X, np.ndarray) and tf.is_tensor(X):
        if not X.dtype.is_floating:
            X = tf.cast(X, tf.float32)
        col_sq = tf.reduce_sum(tf.square(X), axis=0)
        return tf.reduce_sum(tf.sqrt(col_sq + eps))
    return (np.sqrt(np.multiply(X, X).sum(0))).sum()



# `Low` starts as the raw concatenated embeddings and is refined in stage 2.
# (The former tf.zeros(...) initialisation was overwritten on the next line
# and only cost an extra full-size allocation.)
Low = H_input

# Learning-rate schedule shared by every SGD optimizer below.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.99)

rae = RAE()

# Global step counters, used as the `step` of the tf.summary scalars.
rae_pretrain_step = 0
rae_step = 0
l_step = 0

# Per-epoch average losses of the two stage-2 sub-steps.
RAE_LOSS = []
L_LOSS = []

RAE_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
# Stage 1: pre-train the autoencoder on plain reconstruction of `Low`.
#
# H_input, y_train and Low do not change during pre-training and the pipeline
# has no shuffle, so it is built once instead of once per epoch.
train_ds = tf.data.Dataset.from_tensor_slices(
    (H_input, y_train, Low)).batch(Batch_size)

for epoch in range(Pretrain_EPOCH):  # todo:100
    start = time.time()

    # Only `low` is used here; the other two tuple elements are ignored.
    for h_input, labels, low in train_ds:
        rae_pretrain_step += 1
        with tf.GradientTape() as passive_tape:
            RAE_output, layer_output = rae(low)
            # MSE is symmetric in its two arguments, so the (prediction,
            # target) order used here gives the same value as the
            # conventional (target, prediction) order.
            loss = loss_RAE(RAE_output, low)

        passive_RAE_L_gradients = passive_tape.gradient(loss, rae.trainable_variables)
        RAE_optimizer.apply_gradients(zip(passive_RAE_L_gradients, rae.trainable_variables))

        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('pretrain_rae', loss, step=rae_pretrain_step)

    # `loss` is the loss of the last batch of the epoch.
    print("epoch", epoch, "loss", loss , "time: ",time.time()- start, "step: ", rae_pretrain_step )
    # Overwritten after every epoch, so the file always holds the latest weights.
    rae.save_weights(os.path.join(directory, 'pre_rae_ckpt.tf'))

# Stage 2: alternate, once per epoch, between
#   (1) updating the autoencoder weights with `Low` held fixed, and
#   (2) updating `Low` batch by batch with the autoencoder weights held fixed.
best_loss = float('inf')
# Fresh optimizers so the learning-rate schedule restarts for stage 2.
RAE_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
L_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

for epoch in range(STAGE2_EPOCH):
    # Rebuilt every epoch because `Low` is replaced at the end of each epoch.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (H_input, y_train, Low)).batch(Batch_size)

    start = time.time()

    # ---- (1) autoencoder update ------------------------------------------
    total_RAE_loss = 0
    for h_input, labels, low in train_ds:
        rae_step += 1
        with tf.GradientTape() as passive_tape:
            RAE_output, layer_output = rae(low)
            loss = calculate_l21_colnorm(layer_output) + 0.1 * loss_RAE(RAE_output, low)
        passive_RAE_L_gradients = passive_tape.gradient(loss, rae.trainable_variables)
        RAE_optimizer.apply_gradients(zip(passive_RAE_L_gradients, rae.trainable_variables))
        total_RAE_loss += loss.numpy()
        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('RAE_loss', loss, step=rae_step)

    # ---- (2) update of Low -----------------------------------------------
    index_batch = 0
    Low_np = Low.numpy()  # host-side buffer that receives the updated rows
    total_L_loss = 0
    for h_input, labels, low in train_ds:
        l_step += 1
        index_batch += 1
        # A new trainable variable per batch holds this batch's rows of Low.
        # NOTE(review): L_optimizer therefore sees a different variable on
        # every step; confirm the installed optimizer implementation accepts
        # that.
        L = tf.Variable(low, trainable=True)
        with tf.GradientTape() as passive_tape:
            RAE_output, layer_output = rae(L)
            loss = (calculate_l21_colnorm(layer_output)
                    + 0.1 * loss_RAE(RAE_output, L)
                    + 0.1 * calculate_l21_rownorm(h_input - L))
        passive_RAE_L_gradients = passive_tape.gradient(loss, [L])
        L_optimizer.apply_gradients(zip(passive_RAE_L_gradients, [L]))
        total_L_loss += loss.numpy()
        if IS_VIS:
            with train_summary_writer.as_default():
                tf.summary.scalar('L_loss', loss, step=l_step)

        # Write the updated rows back; the last batch may be shorter.
        row_start = (index_batch - 1) * Batch_size
        row_end = min(index_batch * Batch_size, Low_np.shape[0])
        Low_np[row_start:row_end, :] = L.read_value()

    # ---- epoch summary over the full data set ------------------------------
    Low = tf.convert_to_tensor(Low_np, dtype=tf.float32)
    RAE_output, layer_output = rae(Low)
    stage2_loss = (calculate_l21_colnorm(layer_output)
                   + 0.1 * loss_RAE(RAE_output, Low)
                   + 0.1 * calculate_l21_rownorm(H_input - Low))
    print('Epoch {}, stage2 loss: {} avg RAE Loss: {}, avg L Loss: {}'.format(epoch+1,stage2_loss, total_RAE_loss/index_batch,total_L_loss/index_batch))
    print('time', time.time()-start, "step: ",  rae_step,  l_step )
    # Both sub-steps iterate the same dataset, so index_batch is the batch
    # count for either of them.
    RAE_LOSS.append(total_RAE_loss / index_batch)
    L_LOSS.append(total_L_loss / index_batch)

    if IS_VIS:
        with train_summary_writer.as_default():
            tf.summary.scalar('Stage2_loss', stage2_loss, step=epoch)

    # Residual between the raw embeddings and Low; not used further in this
    # script but kept available as a module-level name.
    S = H_input - Low
    rae.save_weights(os.path.join(directory, 'rae_ckpt.tf'))
    with open(os.path.join(directory, 'L.npy'), 'wb') as f:
        np.save(f, Low.numpy())

    # Loss histories are rewritten in full after every epoch. A format spec
    # is used instead of slicing str(float(x))[:6], which cut off exponents
    # (e.g. "1.5e-05" -> "1.5e-0") and digits of values >= 1e6.
    with open(os.path.join(directory, 'stage2_rae_loss.txt'), "w") as outfile:
        outfile.write("\n".join("{:.6g}".format(float(item)) for item in RAE_LOSS))
    with open(os.path.join(directory, 'stage2_l_loss.txt'), "w") as outfile:
        outfile.write("\n".join("{:.6g}".format(float(item)) for item in L_LOSS))

    # Keep a separate copy of the best state seen so far.
    if stage2_loss < best_loss:
        best_loss = stage2_loss
        last_save_epoch = epoch
        rae.save_weights(os.path.join(directory, 'rae_best_ckpt.tf'))
        with open(os.path.join(directory, 'opt_L.npy'), 'wb') as f:
            np.save(f, Low.numpy())
