import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
from sklearn import metrics
from model import  VFLPassiveModel,VFLPassiveModelCIFAR
import os 

# Experiment configuration.
# DATASET selects the benchmark ("mnist" or "cifar"). feature_div holds slice
# boundaries along image axis 1 (rows); the passive party receives rows
# feature_div[0]:feature_div[1] of every image.
DATASET = "cifar"
BATCH_SIZE = 128
if DATASET not in ("mnist", "cifar"):
    # Fail before any directory is created or data is downloaded.
    raise ValueError(
        "Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))
if DATASET == "mnist":
    # Other splits tried: [0, 7, 14, 21, 28], [0, 14, 28].
    feature_div = [0, 7]
    # MNIST uses the weight-free (Add + softmax) server head.
    Server_Trainable = False
else:
    # Other splits tried: [0, 16], [0, 9], [9, 19], [16, 32].
    feature_div = [19, 32]
    Server_Trainable = True  # cifar
if not Server_Trainable:
    #  Active party not trainable
    class VFLActiveModel(Model):
        """Weight-free server head: element-wise sum of the inputs, then softmax."""

        def __init__(self):
            super().__init__()
            # Attribute names are kept stable: object-based checkpoints key on them.
            self.added = tf.keras.layers.Add()
            self.activation = layers.Activation('softmax')

        def call(self, x):
            """Sum the list of embeddings in ``x`` and return softmax probabilities."""
            summed = self.added(x)
            return self.activation(summed)

else:
    class VFLActiveModel(Model):
        """Trainable server head: concatenate -> Dense(32, relu) -> Dense(class_num, softmax).

        NOTE: ``class_num`` is a module-level name that is only looked up when
        the model is constructed, so it must be defined before
        ``VFLActiveModel()`` is called.
        """

        def __init__(self):
            super().__init__()
            # Attribute and layer names are kept stable: checkpoints key on them.
            self.concatenated = tf.keras.layers.Concatenate()
            self.d1 = layers.Dense(32, name="dense1", activation='relu')
            self.out = layers.Dense(class_num, name="out", activation='softmax')

        def call(self, x):
            """Concatenate the list of embeddings in ``x`` and classify them."""
            hidden = self.d1(self.concatenated(x))
            return self.out(hidden)


# Output folder for this run's accuracy logs and checkpoints, keyed by the
# dataset and the passive party's slice boundaries.
directory = './img_results/{}_qtrain_{}_{}'.format(DATASET, feature_div[0], feature_div[1])
print('save into: ' + directory)

# exist_ok=True makes creation idempotent without a separate existence check
# (which could race with another process creating the same folder).
os.makedirs(directory, exist_ok=True)

# Load the selected benchmark; an unknown name is rejected here instead of
# surfacing later as a NameError on train_images.
if DATASET == "mnist":
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
elif DATASET == "cifar":
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
else:
    raise ValueError(
        "Unsupported DATASET {!r}; expected 'mnist' or 'cifar'".format(DATASET))

# Normalize pixel values to be between 0 and 1
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
if DATASET == "mnist":
    # The MNIST arrays are 3-D; append a trailing channel axis so the images
    # are 4-D like the CIFAR ones.
    train_images = train_images[..., np.newaxis]
    test_images = test_images[..., np.newaxis]
print(train_images.shape)
print(test_images.shape)
print(train_labels.shape)

class_num = 10

# SparseCategoricalCrossentropy defaults to from_logits=False, which matches
# the softmax output of VFLActiveModel.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Running per-epoch metrics.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
test_label_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_label_accuracy')

EPOCHS = 100

# Passive-party feature extractor; the architecture depends on the dataset.
if DATASET == "mnist":
    local_model = VFLPassiveModel()
else:
    local_model = VFLPassiveModelCIFAR()
# Optional warm start: build local_model for its input slice, then e.g.
# local_model.load_weights(os.path.join('./nontrain_results/','mnist_local_7_14/','checkpoints'))

server_model = VFLActiveModel()

# Per-epoch histories.
acc_train = []
acc_test = []
# One history list per class for per-class test accuracy.
acc_test_label = [[] for _ in range(class_num)]
loss_train = []
loss_test = []

best_acc = -1

# Built once: the pipeline is identical every epoch, so rebuilding it inside
# the loop only re-copied the training arrays. Batches are served in the
# stored order every epoch (no shuffling).
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)).batch(BATCH_SIZE)

# `optimizer` updates the server model; the passive (local) model gets its own
# instance with identical hyperparameters. NOTE: the TF >= 2.11 keras
# optimizers build per-variable state from the first variable list they are
# applied to and reject unseen variables afterwards, so a single shared
# instance cannot update both models there.
local_optimizer = type(optimizer).from_config(optimizer.get_config())

for epoch in range(EPOCHS):
    # For each batch of images and labels
    for images, labels in train_ds:
        with tf.GradientTape() as passive_tape:
            # Passive party: embed only its slice (axis 1) of each image.
            images_div = images[:, feature_div[0]:feature_div[1]]
            local_embedding = local_model(images_div)

            with tf.GradientTape() as server_tape:
                # The embedding is a tensor, not a variable, so it must be
                # watched explicitly to obtain the gradient sent back to the
                # passive party.
                server_tape.watch(local_embedding)
                output = server_model([local_embedding])
                loss = loss_object(labels, output)

            emb_gradients, server_model_gradients = server_tape.gradient(
                loss, [local_embedding, server_model.trainable_variables])
            optimizer.apply_gradients(
                zip(server_model_gradients, server_model.trainable_variables))  # update server model

            # Surrogate objective for the passive party: the gradient of
            # sum(embedding * g) w.r.t. its weights is the chain-rule product
            # of the embedding Jacobian and g. `.numpy()` turns g into a
            # constant so this tape does not trace back through the server's
            # gradient computation.
            emd_loss = tf.multiply(local_embedding, emb_gradients.numpy())

        # A non-scalar target is differentiated as the sum of its elements.
        local_model_gradients = passive_tape.gradient(
            emd_loss, local_model.trainable_variables)
        local_optimizer.apply_gradients(
            zip(local_model_gradients, local_model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, output)

    # Evaluate on the whole test set in a single forward pass.
    test_images_div = test_images[:, feature_div[0]:feature_div[1]]
    test_output = server_model([local_model(test_images_div)])
    t_loss = loss_object(test_labels, test_output)
    test_loss(t_loss)
    test_accuracy(test_labels, test_output)

    # Per-class test accuracy (disabled); would fill acc_test_label.
    # for label_val in range(class_num):
    #     image_val = test_images[test_labels==label_val]
    #     y_val = test_labels[test_labels==label_val]
    #     image_val_div = image_val[:,feature_div[0]:feature_div[1]]
    #     val_output = server_model([local_model(image_val_div)])
    #     test_label_accuracy.reset_states()
    #     tl_acc = test_label_accuracy(y_val, val_output)
    #     acc_test_label[label_val].append(tl_acc.numpy())

    # Snapshot this epoch's metrics before they are reset below.
    epoch_train_loss = train_loss.result()
    epoch_train_acc = train_accuracy.result()
    epoch_test_loss = test_loss.result()
    epoch_test_acc = test_accuracy.result()

    template = 'Epoch {}, Clean, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          epoch_train_loss,
                          epoch_train_acc * 100,
                          epoch_test_loss,
                          epoch_test_acc * 100))
    acc_train.append(epoch_train_acc)
    acc_test.append(epoch_test_acc)
    loss_train.append(epoch_train_loss)
    loss_test.append(epoch_test_loss)

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    result_list = [acc_train, acc_test, acc_test_label, loss_train, loss_test]

    # Rewrite the full accuracy histories every epoch (values truncated to
    # 6 characters).
    with open(os.path.join(directory, 'acc_test.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_test))
    with open(os.path.join(directory, 'acc_train.txt'), "w") as outfile:
        outfile.write("\n".join(str(float(item))[:6] for item in acc_train))

    # Save the weights
    local_model.save_weights(os.path.join(directory, 'checkpoints'))
    # Track the best epoch using the accuracy snapshotted above; the
    # test_accuracy metric itself has already been reset at this point.
    if best_acc < float(epoch_test_acc):
        best_acc = float(epoch_test_acc)
        local_model.save_weights(os.path.join(directory, 'best_ckpt'))

