# import dependencies
import tensorflow as tf
import os
import argparse



# Command-line interface. Every hyperparameter has a default, so the script
# can also be launched without any arguments.
# Each entry: (option strings, value type, default value).
_CLI_OPTIONS = (
    (("-s", "--seed"), int, 1),
    (("-w", "--weight_decay"), float, 1e-4),
    (("-l", "--init_lr"), float, 5e-3),
    (("-b", "--momentum"), float, 0.9),
    (("--batch_size",), int, 50),
    (("--dtype",), str, "float32"),
)

parser = argparse.ArgumentParser()
for _flags, _kind, _default in _CLI_OPTIONS:
    # argparse derives the attribute name from the long option string,
    # so `dest` does not have to be spelled out.
    parser.add_argument(*_flags, type=_kind, default=_default)

args = parser.parse_args()

# Expose the parsed values as plain module-level names for the rest of the script.
seed, weight_decay, init_lr = args.seed, args.weight_decay, args.init_lr
momentum, batch_size, dtype = args.momentum, args.batch_size, args.dtype



# Network-specific parameters.
model_type = "lenet"

# Run identifier, also used as the output directory name. It encodes the
# hyperparameters, the last two characters of the dtype string and the seed.
model_str = (
    f"{model_type}_"
    f"wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}_m{momentum:.2f}_"
    f"{dtype[-2:]}_{seed}"
)

initial_learning_rate = init_lr

# Set the Keras default float type to the one requested on the command line.
tf.keras.backend.set_floatx(dtype)

# Load CIFAR-10, cast the images to the requested float type and divide by
# 255. The label arrays are left untouched.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# Create the output directories for this run.
# makedirs(..., exist_ok=True) is idempotent: it creates the run directory
# when needed and also restores a missing "data" or "plots" subdirectory if
# the run directory already exists (e.g. after an interrupted earlier run).
for _subdir in ("data", "plots"):
    os.makedirs(os.path.join(model_str, _subdir), exist_ok=True)



# Initialize the model: seed TensorFlow, then build a LeNet-style network.
tf.random.set_seed(seed)
initializer = tf.keras.initializers.GlorotNormal(seed=seed)
l2_reg = tf.keras.regularizers.L2(weight_decay)

# Every trainable layer uses the same kernel initializer and L2 penalty.
_kernel_opts = dict(kernel_initializer=initializer, kernel_regularizer=l2_reg)

model = tf.keras.models.Sequential([
    # Two conv (5x5, 'same' padding) + 2x2 max-pool stages on 32x32x3 inputs.
    tf.keras.layers.Conv2D(6, 5, activation='relu', input_shape=(32, 32, 3),
                           padding='same', **_kernel_opts),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(16, 5, activation='relu', padding='same',
                           **_kernel_opts),
    tf.keras.layers.MaxPooling2D((2, 2)),
    # Dense classifier head ending in a 10-way softmax.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu', **_kernel_opts),
    tf.keras.layers.Dense(84, activation='relu', **_kernel_opts),
    tf.keras.layers.Dense(10, activation='softmax', **_kernel_opts),
])

# Learning rate schedule: multiply the rate by 0.98 once per epoch, so after
# 100 epochs it is about 13% of the initial value (0.98 ** 100 ~= 0.133).
# The number of optimizer steps per epoch is the ceiling of
# (#samples / batch_size), because the last, possibly smaller, batch also
# counts as a step. Using that integer keeps the staircase aligned with epoch
# boundaries even when batch_size does not divide the dataset size.
steps_per_epoch = (y_train.size + batch_size - 1) // batch_size
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.98,
    staircase=True)

# Compile the model: SGD with momentum driven by the decaying learning-rate
# schedule, sparse categorical cross-entropy loss, accuracy as the metric.
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=momentum)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=sgd_optimizer, loss=loss_fn, metrics=['accuracy'])



# Train for 100 epochs; the test split is passed as validation data.
validation_set = (x_test, y_test)
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=100,
    validation_data=validation_set,
    verbose=2,
)

# Save the trained model inside the run's data directory.
# NOTE(review): the path has no file extension; newer Keras versions may
# require ".keras" or ".h5" here -- confirm against the installed version.
model.save(f"{model_str}/data/trained_model")

