# import dependencies
import tensorflow as tf
import os
import argparse
import time
import numpy as np
from matplotlib import pyplot as plt



# Create the output directory for the saved training histories.
# makedirs(exist_ok=True) replaces the exists()+mkdir() pair, which could
# fail if the directory appeared between the check and the creation.
os.makedirs("comparison_test_accuracy", exist_ok=True)



# Hyperparameters shared by every training run in this script.
dtype = "float32"      # floating point type used for the data and the network
batch_size = 50        # examples per gradient update
epochs_train = 100     # epochs per run
init_lr = 0.005        # learning rate of the first epoch
lrs_scale = 0.98       # factor the learning rate is multiplied by each epoch
momentum = 0.9         # SGD momentum
weight_decay = 0.0001  # L2 regularization factor

# Inclusive range of seeds; one pair of runs is trained per seed.
seeds_low, seeds_high = 0, 19



# Load CIFAR-10 through the Keras dataset helper, then cast the images to the
# configured dtype and divide them by 255.0. The labels are left untouched.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0
y_train_size = 50000



def cfr10_lenet_create(
    dtype = "float32",
    seed = 1,
    weight_decay = 1e-4,
    init_lr = 5e-3,
    momentum = 0.9,
    batch_size = 50,
    ):
  """Build and compile a LeNet-style Keras model for 32x32x3 inputs.

  The network consists of two 5x5 'same'-padded convolutions (6 and 16
  filters), each followed by 2x2 max pooling, and three dense layers
  (120, 84 and 10 units, the last one with softmax activation). Every
  kernel is initialized with a seeded Glorot-normal initializer and carries
  an L2 regularizer.

  Note that this function changes global state: it sets the Keras default
  float type to `dtype` and the global TensorFlow seed to `seed`.

  Args:
    dtype: string containing the datatype of the network, float32 or float64.
    seed: seed for weight initialization.
    weight_decay: parameter for L2 regularization.
    init_lr: learning rate of the SGD optimizer the model is compiled with.
    momentum: parameter for heavy ball momentum.
    batch_size: example batch size for the gradient update; here it is only
      used as part of the model name.

  Returns:
    A LeNet for Cifar-10 as a compiled keras model with the desired parameters.
  """

  # The model name helps to identify saved data later on
  model_type = "CFR10_LeNet"
  model_name = (model_type+"_"+dtype[-2:]+"_"+f"{seed:.0f}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}")
  tf.keras.backend.set_floatx(dtype)

  # model initialization
  tf.random.set_seed(seed)
  initializer = tf.keras.initializers.GlorotNormal(seed=seed)
  l2_reg = tf.keras.regularizers.L2(weight_decay)
  # The same initializer and regularizer objects are passed to every layer.
  kernel_args = dict(kernel_initializer=initializer, kernel_regularizer=l2_reg)

  # model architecture
  model = tf.keras.models.Sequential(name=model_name)
  model.add(tf.keras.layers.Conv2D(6, 5, activation='relu', input_shape=(32, 32, 3), padding='same', **kernel_args))
  model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  model.add(tf.keras.layers.Conv2D(16, 5, activation='relu', padding='same', **kernel_args))
  model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(120, activation='relu', **kernel_args))
  model.add(tf.keras.layers.Dense(84, activation='relu', **kernel_args))
  model.add(tf.keras.layers.Dense(10, activation='softmax', **kernel_args))

  # compile model
  model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate = init_lr, momentum = momentum),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])

  return model



# For every seed, train one model drawing the epoch's examples without
# replacement (shuffled pass over the training set) and one drawing them with
# replacement, and save the per-epoch history of each run to disk.
with_replacement_list = [False,True]

for init_seed in range(seeds_low, seeds_high+1):
  print(f"######### seed: {init_seed}/{seeds_low}-{seeds_high} #########")
  for with_replacement in with_replacement_list:
    print(f"With replacement: {with_replacement}")
    # create the model
    model = cfr10_lenet_create(
      dtype = dtype,
      seed = init_seed,
      weight_decay = weight_decay,
      init_lr = init_lr,
      momentum = momentum,
      batch_size = batch_size)

    # loss and metrics of this run
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    #test function
    @tf.function
    def test_step(images, labels):
      # training=False is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      predictions = model(images, training=False)
      regularization_loss=tf.math.add_n(model.losses)
      pred_loss=loss_fn(labels, predictions)
      t_loss=pred_loss + regularization_loss
      test_loss(t_loss)
      test_accuracy(labels, predictions)

    # Base seed for the data order of this run; it depends only on init_seed,
    # so both sampling variants of a seed start from the same value.
    rng = np.random.default_rng(init_seed)
    seed = rng.integers(low=0, high=1000000, size=1)[0]

    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
    # One row per epoch: train loss, train acc, test loss, test acc, lr
    history_rows = []
    # Pre-divided so that the first multiplication in the loop yields init_lr.
    learning_rate_now = init_lr/lrs_scale

    time_stamp = time.monotonic()
    time_stamp_0 = time_stamp

    for _epoch in range(epochs_train):
      learning_rate_now = lrs_scale*learning_rate_now
      # NOTE(review): a new SGD optimizer is created every epoch, so its
      # momentum state starts from scratch each epoch. This is kept as is to
      # leave the numerical results unchanged - confirm that it is intended.
      optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate_now, momentum=momentum)

      #train function, re-wrapped every epoch so it uses this epoch's optimizer
      @tf.function
      def train_step(images, labels):
        with tf.GradientTape() as tape:
          # training=True is only needed if there are layers with different
          # behavior during training versus inference (e.g. Dropout).
          predictions = model(images, training=True)
          regularization_loss=tf.math.add_n(model.losses)
          pred_loss=loss_fn(labels, predictions)
          total_loss=pred_loss + regularization_loss
        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(total_loss)
        train_accuracy(labels, predictions)

      #ensure that batch shuffling is consistent
      tf.random.set_seed(seed+_epoch)
      if with_replacement:
        # Draw the examples of this epoch with replacement
        epoch_rng = np.random.default_rng(seed+_epoch)
        idx_list = epoch_rng.integers(low=0, high=y_train.size, size=y_train.size)
        train_ds = tf.data.Dataset.from_tensor_slices((x_train[idx_list], y_train[idx_list])).batch(batch_size)
      else:
        # Shuffled pass over the whole training set (without replacement)
        train_ds = tf.data.Dataset.from_tensor_slices(
          (x_train, y_train)).shuffle(y_train.shape[0]).batch(batch_size)

      # Reset the metrics at the start of the next epoch
      train_loss.reset_state()
      train_accuracy.reset_state()
      test_loss.reset_state()
      test_accuracy.reset_state()

      for images, labels in train_ds:
        train_step(images, labels)
      for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

      # Read every metric once and reuse the values for history and logging
      epoch_train_loss = float(train_loss.result())
      epoch_train_acc = float(train_accuracy.result())
      epoch_test_loss = float(test_loss.result())
      epoch_test_acc = float(test_accuracy.result())
      history_rows.append([epoch_train_loss, epoch_train_acc, epoch_test_loss, epoch_test_acc, learning_rate_now])

      time_temp = time.monotonic()
      time_diff = time_temp - time_stamp
      time_stamp = time_temp
      time_total = time_temp - time_stamp_0
      # Estimated remaining time, extrapolated from the mean epoch duration
      TREM = (epochs_train - _epoch-1)*time_total/(_epoch+1)

      output_str = (
        f'Epoch {_epoch + 1}, 'f'{time_diff:.0f}s, '
        f'Loss: {epoch_train_loss:.3e}, '
        f'Accuracy: {epoch_train_acc:.4e}, '
        f'Test Loss: {epoch_test_loss:.2e}, '
        f'Test Accuracy: {epoch_test_acc:.3e}, '
        f'Learning Rate: {learning_rate_now:.2e}, '
        f'TREM: {TREM:.0f}s ({TREM/3600:.0f}h)')
      print(output_str)

    # reshape keeps the (epochs, 5) layout even if no epoch was trained
    history = np.array(history_rows, dtype=np.float64).reshape(-1, 5)
    np.save(f"comparison_test_accuracy/histroy_WR_{with_replacement}_epochs_{epochs_train}_lrsscale_{lrs_scale}_{model.name}", history)



# Load the saved histories of all seeds for both sampling variants and stack
# them into arrays whose first axis runs over the seeds.
history_with_replacement = []
history_without_replacement = []
for seed in range(seeds_low, seeds_high+1):
  # Built from the hyperparameters with the same format cfr10_lenet_create
  # uses for the model name, so the file names follow the configuration
  # instead of relying on a hard-coded copy of it.
  model_name = f"CFR10_LeNet_{dtype[-2:]}_{seed:.0f}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}"
  history_with_replacement.append(np.load(f"comparison_test_accuracy/histroy_WR_True_epochs_{epochs_train}_lrsscale_{lrs_scale}_{model_name}.npy"))
  history_without_replacement.append(np.load(f"comparison_test_accuracy/histroy_WR_False_epochs_{epochs_train}_lrsscale_{lrs_scale}_{model_name}.npy"))

history_with_replacement = np.array(history_with_replacement)
history_without_replacement = np.array(history_without_replacement)



# Per-epoch test accuracy (column 3 of the history rows), averaged over the
# seeds, together with the standard error of that mean.
x_space = np.arange(1, epochs_train+1)

test_acc_wr = history_with_replacement[:,:,3]
runs_wr = history_with_replacement.shape[0]
test_acc_True_mean = np.mean(test_acc_wr, axis=0)
test_acc_True_1sigerr = np.std(test_acc_wr, axis=0)/np.sqrt(runs_wr)

test_acc_wor = history_without_replacement[:,:,3]
runs_wor = history_without_replacement.shape[0]
test_acc_False_mean = np.mean(test_acc_wor, axis=0)
test_acc_False_1sigerr = np.std(test_acc_wor, axis=0)/np.sqrt(runs_wor)



# idx_start = 0
# plt.errorbar(x_space[idx_start:], test_acc_False_mean[idx_start:], test_acc_False_1sigerr[idx_start:], capsize=2)
# plt.errorbar(x_space[idx_start:], test_acc_True_mean[idx_start:], test_acc_True_1sigerr[idx_start:], capsize=2)



# Per-seed difference between the best test accuracy reached without and with
# replacement, summarized as mean and standard error over the seeds selected
# by idx_start.
idx_start = 0
best_acc_without = np.max(history_without_replacement[:,:,3], axis=1)
best_acc_with = np.max(history_with_replacement[:,:,3], axis=1)
test_acc_difference_max = best_acc_without - best_acc_with

selected_differences = test_acc_difference_max[idx_start:]
test_acc_difference_max_mean = np.mean(selected_differences)
test_acc_difference_max_err = np.std(selected_differences)/np.sqrt(len(selected_differences))

print(f"Diff Max:{test_acc_difference_max_mean:.4f}+-{test_acc_difference_max_err:.4f}")



# Best test accuracy (column 3 of the history rows) reached by each seed,
# summarized per sampling variant as mean and standard error over the seeds.
best_acc_True = np.max(history_with_replacement[:,:,3], axis=1)
best_acc_False = np.max(history_without_replacement[:,:,3], axis=1)

max_acc_True = np.mean(best_acc_True)
max_acc_True_err = np.std(best_acc_True)/np.sqrt(history_with_replacement.shape[0])
max_acc_False = np.mean(best_acc_False)
# The error of the runs without replacement is normalized by the number of
# runs without replacement (it used the with-replacement count before).
max_acc_False_err = np.std(best_acc_False)/np.sqrt(history_without_replacement.shape[0])

print(f"WR_False Max Acc: {max_acc_False:.4f}+-{max_acc_False_err:.4f}")
print(f"WR_True Max Acc: {max_acc_True:.4f}+-{max_acc_True_err:.4f}")





