import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
from sklearn import metrics

import os 
from nus_utils import *
from nus_wide_data_util import get_labeled_data
import time
        
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
DATASET = "nus"
BATCH_SIZE = 64
EPOCHS = 100

# Modality fed to the local model and the column range [start, end) of it that
# the local party holds. Feature widths recorded for this dataset:
#   image (60000, 634)
#   text  (60000, 1000)
# Other settings that have been used:
#   data_type = 'image'; feature_div = [0, 634] | [0, 225] | [225, 634]
#   data_type = 'text';  feature_div = [0, 500] | [500, 1000]
data_type = 'text'
feature_div = [0, 1000]

print(feature_div)

# Results (metric logs and checkpoints) go to e.g. ./nus_results/local_text_0_1000
directory = './nus_results/' + 'local_{}_{}_{}'.format(data_type, *feature_div)
print('save into: ' + directory)

if not os.path.exists(directory):
    os.makedirs(directory)


# Number of classes and the label names requested from get_labeled_data.
class_num = 5
top_k = ['buildings', 'grass', 'animal', 'water', 'person']
print(top_k)


def _float32(values):
    """Return *values* converted to a new float32 NumPy array."""
    return np.array(values).astype('float32')


train_X_image, train_X_text, train_Y = get_labeled_data('', top_k, 60000, 'Train')
test_X_image, test_X_text, test_Y = get_labeled_data('', top_k, 10000, 'Test')

# Features of each split are kept as an (image, text) pair; labels separately.
x_train = (_float32(train_X_image), _float32(train_X_text))
x_test = (_float32(test_X_image), _float32(test_X_text))
y_train = _float32(train_Y)
y_test = _float32(test_Y)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Unpacked test modalities, sliced per feature_div during evaluation.
image_test, text_test = x_test


# Objective and SGD optimizer used by the training loop.
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Running metrics; the training loop reads and resets them every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
test_label_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_label_accuracy')

backdoor_loss = tf.keras.metrics.Mean(name='backdoor_loss')
backdoor_accuracy = tf.keras.metrics.CategoricalAccuracy(name='backdoor_accuracy')

# Local (passive-party) model first, then the server (active-party) head.
# To warm-start the local model, build it with an input shape and call
# local_model.load_weights(<checkpoint path>) before training.
local_model = VFLPassiveModel()
server_model = VFLActiveModelWithOneLayer()

# Per-epoch histories.
acc_train, acc_test = [], []
acc_test_label = [[] for _ in range(6)]  # one list per label slot
acc_backdoor = []
loss_train, loss_test, loss_backdoor = [], [], []


# test_images_div = test_images[:,feature_div[0]:feature_div[1]]
# test_output = server_model([local_model(test_images_div)])
# t_loss = loss_object(test_labels, test_output)
# test_loss(t_loss)
# test_accuracy(test_labels, test_output)
# template = ' Test Loss: {}, Test Accuracy: {}'
# print(template.format(
#                     test_loss.result(),
#                     test_accuracy.result()*100
#                         ))

# Best test accuracy seen so far; the matching local weights go to 'best_ckpt'.
best_acc = -1

# The local model gets its own SGD instance (same defaults as `optimizer`).
# Keras optimizers from TF 2.11 onward bind to the variables of their first
# apply_gradients() call and raise a KeyError when asked to update a different
# model, so a single optimizer cannot step both the server and local models.
# For plain SGD the update math is unchanged.
local_optimizer = tf.keras.optimizers.SGD()

if data_type not in ('image', 'text'):
    raise ValueError("data_type must be 'image' or 'text', got {!r}".format(data_type))

for epoch in range(EPOCHS):
    start = time.time()
    # Re-shuffle and batch the training data every epoch.
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(y_train.shape[0]).batch(BATCH_SIZE)

    for (images, texts), labels in train_ds:
        with tf.GradientTape() as passive_tape:
            # The local party only sees columns feature_div[0]:feature_div[1]
            # of the selected modality.
            if data_type == 'image':
                data_div = images[:, feature_div[0]:feature_div[1]].numpy()
            else:
                data_div = texts[:, feature_div[0]:feature_div[1]].numpy()

            local_embedding = local_model(data_div)

            with tf.GradientTape() as server_tape:
                server_tape.watch(local_embedding)
                output = server_model([local_embedding])
                loss = loss_object(labels, output)

            # Server step: gradients w.r.t. its own weights and w.r.t. the
            # embedding it received from the local model.
            emb_gradients, server_model_gradients = server_tape.gradient(
                loss, [local_embedding, server_model.trainable_variables])
            optimizer.apply_gradients(
                zip(server_model_gradients, server_model.trainable_variables))

            # Surrogate objective for the local model: the gradient of
            # sum(embedding * g) w.r.t. the local weights equals the chain-rule
            # gradient given the returned embedding gradient g. .numpy() turns g
            # into a constant so passive_tape does not differentiate through it.
            emb_loss = tf.multiply(local_embedding, emb_gradients.numpy())

        # A non-scalar target is implicitly summed by GradientTape.gradient.
        local_model_gradients = passive_tape.gradient(
            emb_loss, local_model.trainable_variables)
        local_optimizer.apply_gradients(
            zip(local_model_gradients, local_model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, output)

    # Evaluate on the full test split in a single forward pass.
    if data_type == 'image':
        test_data_div = image_test[:, feature_div[0]:feature_div[1]]
    else:
        test_data_div = text_test[:, feature_div[0]:feature_div[1]]

    test_output = server_model([local_model(test_data_div)])
    t_loss = loss_object(y_test, test_output)
    test_loss(t_loss)
    test_accuracy(y_test, test_output)

    # for label_val in range(class_num):
    #     y_val = y_test[y_test==label_val]
    #     if data_type=='image':
    #         val_data_div = image_test[y_test==label_val][:,feature_div[0]:feature_div[1]]
    #     elif data_type=='text':
    #         val_data_div = text_test[y_test==label_val][:,feature_div[0]:feature_div[1]]
    #     val_output = server_model([local_model(val_data_div)])
    #     test_label_accuracy.reset_states()
    #     tl_acc = test_label_accuracy(y_val, val_output)
    #     acc_test_label[label_val].append(tl_acc.numpy())

    # Snapshot this epoch's metrics BEFORE the metric objects are reset below.
    # The best-checkpoint check relies on this: reading test_accuracy after
    # reset_states() always returns 0.
    epoch_train_loss = train_loss.result()
    epoch_train_acc = train_accuracy.result()
    epoch_test_loss = test_loss.result()
    epoch_test_acc = test_accuracy.result()

    template = 'Epoch {}, Clean, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          epoch_train_loss,
                          epoch_train_acc * 100,
                          epoch_test_loss,
                          epoch_test_acc * 100))
    acc_train.append(epoch_train_acc)
    acc_test.append(epoch_test_acc)
    loss_train.append(epoch_train_loss)
    loss_test.append(epoch_test_loss)

    # Reset the metrics for the next epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    result_list = [acc_train, acc_test, acc_test_label, loss_train, loss_test]

    # Rewrite the accuracy logs every epoch (values truncated to 6 characters).
    with open(os.path.join(directory, 'acc_test.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_test))
    with open(os.path.join(directory, 'acc_train.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_train))

    # Always save the latest local weights; additionally keep the best ones.
    local_model.save_weights(os.path.join(directory, 'checkpoints'))
    if float(epoch_test_acc) > best_acc:
        best_acc = float(epoch_test_acc)
        local_model.save_weights(os.path.join(directory, 'best_ckpt'))
        

