# import dependencies
import tensorflow as tf
import numpy as np
import argparse



# Command-line interface. model_type, seed, weight_decay, init_lr, momentum,
# batch_size and dtype select the previously trained run (they make up
# model_str below); para_str selects what gets recorded.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_type', type=str, default="lenet")
parser.add_argument('-s', '--seed', type=int, default=1)
parser.add_argument('-w', '--weight_decay', type=float, default=1e-4)
parser.add_argument('-l', '--init_lr', type=float, default=5e-3)
parser.add_argument('-b', '--momentum', type=float, default=0.9)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--dtype', type=str, default="float32")
parser.add_argument('--para_str', type=str, default="Htop5e+03")
parser.add_argument('--epochs', type=int, default=20)
# --wr / --no-wr: resample each epoch's examples with replacement.
parser.add_argument('--wr', default=False, action=argparse.BooleanOptionalAction)

args = parser.parse_args()
(model_type, seed, weight_decay, init_lr, momentum,
 batch_size, dtype, para_str, epochs) = (
  args.model_type, args.seed, args.weight_decay, args.init_lr, args.momentum,
  args.batch_size, args.dtype, args.para_str, args.epochs)
with_replacement = args.wr



# Rebuild the run identifier; it is the directory holding the trained model.
model_str = (
  f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}"
  f"_m{momentum:.2f}_{dtype[-2:]}_{seed}"
)
initial_learning_rate = init_lr
# Resume from a decayed rate, 0.98**100 * init_lr (presumably the schedule
# value after 100 decay steps of the original training run — TODO confirm).
learning_rate_now = 0.98**100*initial_learning_rate
tf.keras.backend.set_floatx(dtype)

# CIFAR-10, cast to the working dtype and divided by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# Load the previously trained network for this run.
model = tf.keras.models.load_model(model_str+"/data/trained_model")



@tf.function
def flatten_tf(params):
  """Reshape every tensor in params to 1-D and concatenate them in list order."""
  flat_pieces = [tf.reshape(piece, [-1]) for piece in params]
  return tf.concat(flat_pieces, axis=0)

@tf.function
def add_list_tf(weights_list_0, weights_list_1):
  """Return the position-wise sums of two lists of tensors.

  Pairing follows zip, so trailing entries of a longer list are ignored.
  """
  summed = []
  for lhs, rhs in zip(weights_list_0, weights_list_1):
    summed.append(lhs + rhs)
  return summed

# Define what to record during training: one specific layer ("lyr<index>")
# or projections onto Hessian eigenvectors ("Htop<dim>").
if para_str.startswith("Htop"):
  # Number of projection directions, e.g. "Htop5e+03" -> 5000.
  dim = int(float(para_str[4:]))
  # NOTE(review): assumed to be a (dim, n_params) array whose rows are the
  # eigenvectors over the flattened parameters — confirm against the producer.
  H_eigvec = np.load(model_str+"/data/"+para_str+"_eigvec.npy").astype(dtype)
  def projector_fn(weights_list):
    """Flatten weights_list and project it onto the rows of H_eigvec."""
    weights = np.array(flatten_tf(weights_list))
    return np.matmul(H_eigvec, weights)
elif para_str.startswith("lyr"):
  layer_idx = int(para_str[3:])
  parameters = model.trainable_variables[layer_idx]
  dim = int(np.asarray(parameters).size)
  def projector_fn(weights_list):
    """Return entry layer_idx of weights_list, flattened to 1-D."""
    return np.array(weights_list[layer_idx]).flatten()
else:
  # Fail fast: otherwise projector_fn/dim stay undefined and the script
  # crashes much later with a NameError.
  raise ValueError(
    f"Unsupported para_str {para_str!r}: expected 'Htop<dim>' or 'lyr<layer index>'")



# SGD with momentum, starting from the already-decayed learning rate.
optimizer = tf.keras.optimizers.SGD(momentum=momentum, learning_rate=learning_rate_now)
# Running mean of the per-step total loss, fed by train_step.
train_loss = tf.keras.metrics.Mean(name='train_loss')

# train function
@tf.function
def train_step(images, labels):
  """Apply one optimizer step on (images, labels) and return its gradients.

  The objective is model.loss plus the model's regularization losses. The
  gradients are taken at the weights *before* the update is applied. The loss
  value is accumulated into the module-level train_loss metric.
  """
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    # tf.math.add_n rejects an empty list, so use 0 when the model carries
    # no regularization terms.
    regularization_loss = tf.math.add_n(model.losses) if model.losses else 0.0
    pred_loss = model.loss(labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_loss(total_loss)
  return gradients

# pure gradient function for calculating the total gradient
@tf.function
def gradient_step(images, labels):
  """Return the gradients of the training objective on one batch, without updating weights.

  The objective matches train_step: model.loss plus regularization losses.
  NOTE(review): the forward pass uses training=True, so any layers with
  training-mode side effects (e.g. batch-norm statistics) would be affected.
  Confirm this is intended.
  """
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    # tf.math.add_n rejects an empty list, so use 0 when the model carries
    # no regularization terms.
    regularization_loss = tf.math.add_n(model.losses) if model.losses else 0.0
    pred_loss = model.loss(labels, predictions)
    total_loss = pred_loss + regularization_loss
  return tape.gradient(total_loss, model.trainable_variables)

# Full-dataset gradient: walk the whole training set in order (no shuffling).
train_ds_cmpl = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
def gradient_total_fn():
  """Return the sum of gradient_step over every batch of train_ds_cmpl.

  The result is a list aligned with model.trainable_variables. It is a sum
  of per-batch gradients, not an average; the caller divides by the number
  of batches.
  """
  accumulated = [tf.zeros(shape=var.shape, dtype=dtype) for var in model.trainable_variables]
  for batch_images, batch_labels in train_ds_cmpl:
    accumulated = add_list_tf(accumulated, gradient_step(batch_images, batch_labels))
  return accumulated



# Train the network further. At every step, record both the full-dataset
# gradient and the mini-batch gradient, each passed through projector_fn.
# The batched dataset yields ceil(N / batch_size) batches (the last may be
# partial), so size the buffers with ceil. Floor division under-allocates
# and raises IndexError, losing all results, whenever batch_size does not
# divide N.
num_batches = int(np.ceil(y_train.size / batch_size))
gradient_total = np.zeros((num_batches*epochs, dim))
gradient_batch = np.zeros((num_batches*epochs, dim))
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(y_train.size).batch(batch_size)

step = 0
for _epoch in range(epochs):
  # Draw the examples with replacement if requested (seeded per epoch).
  if with_replacement:
    rng = np.random.default_rng(seed+_epoch)
    idx_list = rng.integers(low=0, high=y_train.size, size=y_train.size)
    train_ds = tf.data.Dataset.from_tensor_slices((x_train[idx_list], y_train[idx_list])).batch(batch_size)

  for images, labels in train_ds:
    # Both gradients are taken at the same weights: the full-dataset one is
    # computed first, then train_step takes its batch gradient and updates.
    gradient_total[step] = projector_fn(gradient_total_fn())/num_batches
    gradient_batch[step] = projector_fn(train_step(images, labels))
    step += 1

  print(f"Epoch {_epoch+1}/{epochs}")

# Runs drawn with replacement are saved with an extra "WR_" tag in the filename.
suffix = "WR_" if with_replacement else ""
np.save(model_str+"/data/grad_batch_timeseries_"+suffix+para_str, gradient_batch)
np.save(model_str+"/data/grad_tot_timeseries_"+suffix+para_str, gradient_total)





