import os,sys

# Expose only GPU 0 to this process; this is set before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Experiment settings.
NE = True          # True: two-batch (noise-enhanced) training step; False: plain step
enh_noise = 1.5    # weight of the first batch's loss; the second batch gets (1 - enh_noise)
batch_size = 500   # global batch size, shared by the training and test datasets


# ------------------------------
# Split enh_noise (rounded to one decimal place) into its integer part and its
# first decimal digit; both are used to build the name of the output file.
enh_int, enh_point = divmod(round(enh_noise * 10), 10)

# Redirect everything printed from here on into a result file whose name
# encodes the noise-enhancement factor and the batch size.
result_file_name = "cifar100_NE" + str(enh_int) + "_" + str(enh_point) + "_Batch" + str(batch_size) + ".txt"
sys.stdout = open(result_file_name, "w")


import tensorflow as tf
import numpy as np
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist,cifar10,cifar100
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization, Activation, Conv2D, MaxPooling2D
from tensorflow.keras import utils as np_utils
from tensorflow.keras import regularizers


# Enable on-demand memory growth on every visible GPU, or report that none
# was found.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if not physical_devices:
    print("Not enough GPU hardware devices available")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)





# Re-seed NumPy's global RNG without a fixed seed (runs are not reproducible).
np.random.seed()

# CIFAR100: load the train/test splits and rescale pixel values by 1/255.
cifar100 = tf.keras.datasets.cifar100
(train_data, train_label), (test_data, test_label) = cifar100.load_data()
train_data = train_data.astype('float32') / 255
test_data = test_data.astype('float32') / 255

# Dataset sizes and number of target classes.
ntrain, ntest = len(train_label), len(test_label)
nclass = 100

# Convert the integer class labels to one-hot vectors of length nclass.
train_label = np_utils.to_categorical(train_label, nclass)
test_label = np_utils.to_categorical(test_label, nclass)

# network
def build_model():
    """Build the convolutional classifier for 32x32x3 inputs.

    The network is four convolutional stages followed by a dense head.  Every
    convolutional unit is ``Conv2D(3x3, same padding) -> BatchNormalization ->
    ReLU``, where batch normalization uses a virtual batch size of 100.  Each
    stage ends with 2x2 max pooling.  The head flattens the features and
    applies a 1024-unit ReLU layer and a softmax layer with ``nclass`` outputs.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    ghost_batch = 100
    # (number of filters, number of Conv-BN-ReLU units) for each stage.
    stage_plan = ((64, 3), (128, 3), (256, 3), (512, 2))

    model = Sequential()
    needs_input_shape = True
    for filters, units in stage_plan:
        for _ in range(units):
            if needs_input_shape:
                # Only the very first layer declares the input shape.
                model.add(Conv2D(filters, 3, padding="same", input_shape=(32, 32, 3)))
                needs_input_shape = False
            else:
                model.add(Conv2D(filters, 3, padding="same"))
            model.add(BatchNormalization(virtual_batch_size=ghost_batch))
            model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))

    # Classification head.
    model.add(Flatten())
    model.add(Dense(units=1024, activation='relu'))
    model.add(Dense(units=nclass, activation='softmax'))
    return model

# training
# Synchronous data-parallel strategy over the visible devices.
strategy = tf.distribute.MirroredStrategy()

# Shuffle over the whole training set; one epoch is a full pass of whole batches.
BUFFER_SIZE = ntrain
train_steps_per_epoch = ntrain // batch_size
test_steps_per_epoch = ntest // batch_size

# Two separately shuffled streams over the same training data: the
# two-batch training step draws one batch from each of them.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_data, train_label))
    .shuffle(BUFFER_SIZE)
    .batch(batch_size)
)
train_dataset2 = (
    tf.data.Dataset.from_tensor_slices((train_data, train_label))
    .shuffle(BUFFER_SIZE)
    .batch(batch_size)
)
train_iterator = strategy.make_dataset_iterator(train_dataset)
train_iterator2 = strategy.make_dataset_iterator(train_dataset2)

# The test set is batched in its stored order (no shuffling).
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_data, test_label))
    .batch(batch_size)
)
test_iterator = strategy.make_dataset_iterator(test_dataset)

# Initial learning rate of the optimizer.
lr = 0.001
with strategy.scope():
    # Loss function, optimizer and metrics are created inside the strategy scope.
    loss_object = tf.keras.losses.categorical_crossentropy
    # `learning_rate` is the supported argument name; `lr` is only a
    # deprecated backward-compatibility alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    # NOTE: the functions below read the module-level `model`, which must be
    # assigned before any of them is called.

    def compute_loss(labels, predictions):
        """Return the cross-entropy loss averaged over the global batch.

        The sum of ``model.losses`` is added to every example's
        cross-entropy, and the result is summed and divided by ``batch_size``
        through ``tf.nn.compute_average_loss``.

        Args:
            labels: one-hot target labels.
            predictions: class probabilities predicted by the model.
        """
        regularization_loss = tf.reduce_sum(model.losses)
        per_example_loss = loss_object(labels, predictions) + regularization_loss
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=batch_size)

    def train_step(images, labels, images2, labels2):
        """Run one two-batch (noise-enhanced) optimization step.

        The gradient is taken of ``enh_noise*loss + (1.0-enh_noise)*loss2``,
        where ``loss`` and ``loss2`` are the losses on the first and second
        batch.  ``train_accuracy`` is updated with the first batch only.

        Returns:
            The loss on the first batch (not the combined loss).
        """
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)
            predictions2 = model(images2, training=True)
            loss2 = compute_loss(labels2, predictions2)
            loss_scaled = enh_noise*loss + (1.0-enh_noise)*loss2
        gradients = tape.gradient(loss_scaled, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_accuracy(labels, predictions)
        return loss

    def test_step(images, labels):
        """Evaluate one batch and accumulate ``test_loss`` / ``test_accuracy``."""
        predictions = model(images, training=False)
        t_loss = loss_object(labels, predictions)
        test_loss(t_loss)
        test_accuracy(labels, predictions)

    def train_step_normal(inputs):
        """Run one ordinary single-batch optimization step.

        Args:
            inputs: an ``(images, labels)`` pair.

        Returns:
            The loss on the batch.
        """
        images, labels = inputs
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_accuracy(labels, predictions)
        return loss


# Build the network and train it until the mean training loss of an epoch
# drops below 0.001 or EPOCHS epochs have been run.
with strategy.scope():
    EPOCHS = 2000
    model = build_model()

    # The two training modes differ only in how one distributed step is run.
    if NE:
        @tf.function
        def distributed_train():
            """Run one two-batch step, drawing a batch from each training iterator."""
            images, labels = train_iterator.get_next()
            images2, labels2 = train_iterator2.get_next()
            per_replica_losses = strategy.experimental_run_v2(train_step, args=(images, labels, images2, labels2))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    else:
        @tf.function
        def distributed_train():
            """Run one ordinary step on the next batch of the training iterator."""
            per_replica_losses = strategy.experimental_run(train_step_normal, train_iterator)
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    @tf.function
    def distributed_test():
        """Evaluate the next test batch on all replicas."""
        images, labels = test_iterator.get_next()
        return strategy.experimental_run_v2(test_step, args=(images, labels))

    # Results reported after training.  They are defined up front so that the
    # report does not fail with a NameError when the loss threshold is never
    # reached: conv_time then stays None and acc_data holds the test accuracy
    # of the last epoch.
    acc_data = None
    conv_time = None

    itr = 0  # total number of training steps performed so far
    for epoch in range(EPOCHS):
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        # Initialize the iterator(s) for a fresh pass over the training data.
        train_iterator.initialize()
        if NE:
            train_iterator2.initialize()
        for _ in range(train_steps_per_epoch):
            total_loss += distributed_train()
            num_batches += 1
            itr += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        test_iterator.initialize()
        for _ in range(test_steps_per_epoch):
            distributed_test()
        acc_data = test_accuracy.result().numpy()*100

        # Halve the learning rate once the training loss is small.
        if train_loss < 0.02:
            optimizer.lr = lr/2.0
        # Converged: remember how many steps it took and stop.
        if train_loss < 0.001:
            conv_time = itr
            break


# Report the final training statistics and the values recorded by the
# training loop (written to the redirected standard output).
final_train_loss = train_loss.numpy()
final_train_accuracy = train_accuracy.result().numpy()*100
print("training loss = ", final_train_loss)
print("training accuracy = ", final_train_accuracy)
print("convergence time = ", conv_time)
print("test accuracy = ", acc_data)
