# import dependencies
import tensorflow as tf
import os
import argparse



# command-line hyperparameters; every option has a default, so the script
# also runs without any arguments
parser = argparse.ArgumentParser(
    description="Train a ResNet-20 style network on CIFAR-10.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-s', '--seed', type=int, default=1,
                    help="seed for TensorFlow's global RNG and the weight initializer")
parser.add_argument('-w', '--weight_decay', type=float, default=1e-4,
                    help="L2 regularization factor applied to the layer kernels")
parser.add_argument('-l', '--init_lr', type=float, default=5e-3,
                    help="initial learning rate of the SGD optimizer")
parser.add_argument('-b', '--momentum', type=float, default=0.9,
                    help="momentum of the SGD optimizer")
parser.add_argument('--batch_size', type=int, default=50,
                    help="mini-batch size used for training")
parser.add_argument('--dtype', type=str, default="float32",
                    help="floating point type passed to tf.keras.backend.set_floatx "
                         "and used for the input data")

# argparse derives the attribute names from the long option names, so no
# explicit dest= is needed
args = parser.parse_args()
seed = args.seed
weight_decay = args.weight_decay
init_lr = args.init_lr
momentum = args.momentum
batch_size = args.batch_size
dtype = args.dtype



# run identification and global floating point precision
model_type = "resnet20"
# the run name encodes all hyperparameters; only the last two characters of
# the dtype string are kept (e.g. "32" for "float32")
model_str = (
    f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}"
    f"_b{batch_size:.0f}_m{momentum:.2f}_{dtype[-2:]}_{seed}"
)
initial_learning_rate = init_lr
tf.keras.backend.set_floatx(dtype)

# load CIFAR-10, cast the images to the requested dtype and divide by 255
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# create the output directories; makedirs with exist_ok=True also fills in
# sub-directories that are missing from an already existing run directory
for subdir in ("data", "plots"):
  os.makedirs(os.path.join(model_str, subdir), exist_ok=True)



# number of residual units per stage (passed to residual_block below)
num_res_blocks = 3

# seed TensorFlow's global RNG and build the shared kernel initializer and
# L2 kernel regularizer
# NOTE(review): one seeded initializer instance is reused for every layer;
# depending on the Keras version, same-shaped kernels may then start from
# identical values - confirm against the installed version
tf.random.set_seed(seed)
initializer = tf.keras.initializers.GlorotNormal(seed=seed)
l2_reg = tf.keras.regularizers.L2(l2=weight_decay)

# model architecture
def _residual_unit(x, filters, downsample, conv_kwargs):
  """Apply one conv-relu-conv unit, add it to its shortcut and apply relu.

  With ``downsample`` set, the first 3x3 convolution and a 1x1 convolution on
  the shortcut both use stride 2; otherwise ``x`` itself is the shortcut.
  """
  strides = (2, 2) if downsample else (1, 1)
  y = tf.keras.layers.Conv2D(filters, 3, strides=strides, **conv_kwargs)(x)
  if downsample:
    # built after the main-path convolution so the layer creation order
    # (and thereby the auto-generated layer names) stays unchanged
    x = tf.keras.layers.Conv2D(filters, 1, strides=strides, **conv_kwargs)(x)
  y = tf.keras.layers.Activation('relu')(y)
  y = tf.keras.layers.Conv2D(filters, 3, **conv_kwargs)(y)
  x = tf.keras.layers.Add()([x, y])
  return tf.keras.layers.Activation('relu')(x)


def residual_block(x, initializer, regularizer, filters=16, num_res_blocks=3, downsampeling=False):
  """Build one stage of stacked residual units on top of ``x``.

  Args:
    x: input tensor of the stage.
    initializer: kernel initializer used by every convolution.
    regularizer: kernel regularizer used by every convolution.
    filters: number of filters of every convolution in the stage.
    num_res_blocks: number of residual units; one unit is always built, even
      for values below 1.
    downsampeling: if truthy, the first unit uses stride 2 and a 1x1
      convolution on its shortcut (parameter spelling kept for existing
      callers).

  Returns:
    The output tensor of the last residual unit.

  Note:
    Without ``downsampeling`` the shortcut is the unmodified input, so ``x``
    must already have ``filters`` channels for the addition to work.
  """
  conv_kwargs = dict(padding='same',
                     kernel_initializer=initializer,
                     kernel_regularizer=regularizer)
  # only the first unit may change resolution / channel count
  x = _residual_unit(x, filters, bool(downsampeling), conv_kwargs)
  for _ in range(1, num_res_blocks):
    x = _residual_unit(x, filters, False, conv_kwargs)
  return x

# assemble the network: stem convolution, three residual stages, pooled
# softmax classifier with 10 outputs
inputs = tf.keras.layers.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(
    16, 3, padding='same',
    kernel_initializer=initializer, kernel_regularizer=l2_reg)(inputs)
x = tf.keras.layers.Activation('relu')(x)

# (filters, downsampeling) of each stage, in order
for stage_filters, stage_downsampling in ((16, False), (32, True), (64, True)):
  x = residual_block(
      x,
      initializer=initializer,
      regularizer=l2_reg,
      filters=stage_filters,
      num_res_blocks=num_res_blocks,
      downsampeling=stage_downsampling)

x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(
    10, activation='softmax',
    kernel_initializer=initializer, kernel_regularizer=l2_reg)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


# learning rate schedule: the rate is multiplied by 0.98 once per epoch, so
# after 100 epochs it is about 13% of the initial learning rate
# (0.98**100 ~= 0.133)
# decay_steps has to be an integer step count: the number of batches per
# epoch, rounded up to account for a final partial batch
steps_per_epoch = -(-len(y_train) // batch_size)
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.98,
    staircase=True)

# compile the model: SGD with momentum driven by the learning rate schedule,
# sparse categorical cross-entropy loss, accuracy as the reported metric
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=momentum)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])



# train for 100 epochs, evaluating on the test split after every epoch
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=100,
    verbose=2,
    validation_data=(x_test, y_test))

# store the trained model inside the run directory
# NOTE(review): the path has no file extension; newer Keras releases may
# require a ".keras" or ".h5" suffix here - confirm against the installed version
save_path = model_str + "/data/trained_model"
model.save(save_path)

