# import dependencies
import tensorflow as tf
import numpy as np
import argparse
import scipy
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes



# Command-line configuration. argparse derives each destination from the
# long flag (e.g. --init_lr -> args.init_lr), so no explicit dest is needed.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_type', type=str, default="lenet")
parser.add_argument('-s', '--seed', type=int, default=1)
parser.add_argument('-w', '--weight_decay', type=float, default=1e-4)
parser.add_argument('-l', '--init_lr', type=float, default=5e-3)
parser.add_argument('-b', '--momentum', type=float, default=0.9)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--dtype', type=str, default="float32")
# e.g. "Htop5e+03"; the script later parses everything after the first four
# characters as the number of wanted Hessian eigenvectors.
parser.add_argument('--para_str', type=str, default="Htop5e+03")

args = parser.parse_args()
(model_type, seed, weight_decay, init_lr,
 momentum, batch_size, dtype, para_str) = (
  args.model_type, args.seed, args.weight_decay, args.init_lr,
  args.momentum, args.batch_size, args.dtype, args.para_str)



# network specific parameters
# Run identifier; the script reads/writes <model_str>/data/ and <model_str>/plots/.
model_str = model_type+"_"+f"wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}_m{momentum:.2f}_"+dtype[-2:]+"_"+str(seed)
initial_learning_rate = init_lr
# (learning_rate_now is computed where it is used, next to lambda_cross.)
hessian_batch_size = 200  # batch size for calculating Hessian vector products
# para_str looks like the default "Htop5e+03": the characters after the
# 4-character prefix are a float literal giving the number of wanted
# eigenvectors (here 5000).
eig_number = int(float(para_str[4:]))  # number of wanted eigenvectors
tf.keras.backend.set_floatx(dtype)

# load dataset: CIFAR-10, cast to the configured float dtype, then divided
# by 255 to rescale the pixel values.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# load the model saved under this run's directory
model = tf.keras.models.load_model(f"{model_str}/data/trained_model")

# Trainable variables with respect to which the Hessian is taken; their
# flattened concatenation defines the coordinate order of every vector below.
parameters = model.trainable_variables

@tf.function
def flatten_tf(params):
  """Concatenate a sequence of tensors into a single rank-1 tensor.

  Every element of ``params`` is reshaped to 1-D and the pieces are joined
  in the order they appear in ``params``.
  """
  flat_pieces = [tf.reshape(piece, shape=(-1,)) for piece in params]
  return tf.concat(flat_pieces, 0)

# Hessian vector product for one batch of examples via the Pearlmutter trick
@tf.function
def HVP_partial_tf(vec, images, labels):
  """Return the Hessian-vector product H @ vec for one batch, as a 1-D tensor.

  The batch loss is ``model.loss`` on the predictions plus the sum of
  ``model.losses`` (when there are any). Its gradient g with respect to the
  module-level ``parameters`` is taken, and the scalar <g, vec> is
  differentiated once more, giving H @ vec without ever forming H.

  Args:
    vec: 1-D tensor with one entry per trainable parameter, in the order
      produced by ``flatten_tf(parameters)``.
    images: one batch of model inputs.
    labels: the matching batch of targets.

  Returns:
    1-D tensor with the same layout as ``vec``.
  """
  # NOTE(review): training=True mirrors training-mode behaviour (e.g. any
  # dropout / batch-norm layers); confirm that is intended for the Hessian.
  with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
      # Read the regularization terms before the forward pass, as before.
      reg_terms = model.losses
      predictions = model(images, training=True)
      total_loss = model.loss(labels, predictions)
      # tf.math.add_n rejects an empty list, so only add regularization
      # when the model actually defines some.
      if reg_terms:
        total_loss = total_loss + tf.math.add_n(reg_terms)
    gradient_flat = flatten_tf(inner_tape.gradient(total_loss, parameters))
    # Explicit scalar <g, vec>; its gradient w.r.t. the parameters is H @ vec.
    gradient_dot = tf.reduce_sum(gradient_flat * tf.stop_gradient(vec))
  return flatten_tf(outer_tape.gradient(gradient_dot, parameters))

# Hessian vector product for all examples
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(hessian_batch_size)
def HVP_fun(vec):
  """Return the training-set Hessian-vector product H @ vec as a NumPy array.

  Each per-batch product from HVP_partial_tf is weighted by its batch size
  (the last batch may be smaller than hessian_batch_size). The sum is then
  divided by the number of training examples. Assuming ``model.loss``
  averages over the batch, this is the HVP of the example-averaged loss.

  Args:
    vec: array with one entry per trainable parameter. scipy's
      LinearOperator may pass shape (dim,) or (dim, 1); both are accepted.

  Returns:
    1-D array of dtype ``dtype`` (LinearOperator reshapes it as needed).
  """
  # Flatten first: a (dim, 1) column would otherwise broadcast against the
  # (dim,) gradient inside HVP_partial_tf into a (dim, dim) product.
  # Convert to a tensor once rather than once per batch.
  vec = tf.constant(np.asarray(vec, dtype=dtype).reshape(-1))
  output = tf.zeros_like(vec)
  for images, labels in train_ds:
    output += HVP_partial_tf(vec, images, labels) * labels.shape[0]
  return output.numpy() / y_train.shape[0]



# Calculate the eigenvalues and eigenvectors of the Hessian via scipy algorithm.
# Import the submodule explicitly: a bare `import scipy` does not make
# `scipy.sparse.linalg` available on older SciPy releases.
import scipy.sparse.linalg

# Total number of trainable parameters (size of the Hessian), computed from
# the variable shapes instead of materializing a flattened copy.
dim = sum(p.shape.num_elements() for p in parameters)
# Matrix-free operator: the Hessian is only accessed through HVP_fun.
hessian_operator = scipy.sparse.linalg.LinearOperator(shape=(dim, dim), dtype=dtype, matvec=HVP_fun)
# which="LA": the eig_number algebraically largest eigenvalues.
H_eigval, H_eigvec = scipy.sparse.linalg.eigsh(hessian_operator, k=eig_number, which="LA")
# Reverse eigsh's output order (intended: largest eigenvalue first) and store
# eigenvectors as rows, so H_eigvec[i] belongs to H_eigval[i].
H_eigval = np.flip(H_eigval, axis=0)
H_eigvec = np.flip(np.transpose(H_eigvec), axis=0)

np.save(model_str+"/data/"+para_str+"_eigval", H_eigval)
np.save(model_str+"/data/"+para_str+"_eigvec", H_eigvec)



# Calculate the Hessian eigenvalue density
def eigenvalue_density(H_eigval, dim, sigma=1e-2, min_density_factor=5, num_points=10000, xlim=None):
  """Estimate a spectral density by Gaussian-smoothing a set of eigenvalues.

  Each eigenvalue contributes a normalized Gaussian of standard deviation
  ``sigma``. A constant floor is added, and the result is divided by ``dim``.

  Args:
    H_eigval: 1-D array-like of eigenvalues.
    dim: normalization constant (the script passes the parameter count).
    sigma: standard deviation of each Gaussian kernel.
    min_density_factor: the floor added to every point before normalization
      is ``1 / (sqrt(2*pi) * sigma * min_density_factor)``. It keeps the
      curve strictly positive, e.g. for log-scale plots.
    num_points: number of evaluation points.
    xlim: optional (lower, upper) evaluation range (tuple, list or array).
      Defaults to [min(H_eigval) - 30*sigma, max(H_eigval) + 6*sigma].

  Returns:
    (density_x, density_y): evaluation points and float64 density values,
    each of length ``num_points``.
  """
  H_eigval = np.asarray(H_eigval)
  prefactor = 1/(np.sqrt(2*np.pi)*sigma)

  def gauss_fn(x, x_0):
    return prefactor*np.exp(-0.5*(x-x_0)**2/(sigma**2))

  # `is None`, not `== None`: comparing a NumPy array to None is
  # element-wise and would make this `if` raise.
  if xlim is None:
    density_x = np.linspace(np.min(H_eigval)-30*sigma, np.max(H_eigval)+6*sigma, num=num_points)
  else:
    density_x = np.linspace(xlim[0], xlim[1], num=num_points)
  # Evaluate point by point so memory stays O(len(H_eigval)) rather than
  # O(num_points * len(H_eigval)); the float64 buffer fixes the output dtype.
  density_y = np.zeros(num_points)
  for idx, x in enumerate(density_x):
    density_y[idx] = np.sum(gauss_fn(x, H_eigval))
  density_y += prefactor/min_density_factor
  density_y = density_y/dim
  return density_x, density_y



# Plotting the cross-over value
# NOTE(review): 0.98**100 looks like a 0.98 exponential decay of the initial
# learning rate applied 100 times (i.e. the rate at the end of training) —
# confirm against the training schedule.
learning_rate_now = 0.98**100*initial_learning_rate
# Training batches per epoch (number of training examples / batch_size).
num_batches = y_train.shape[0]/batch_size
lambda_cross = 3*(1-momentum)/(learning_rate_now*num_batches)
def add_lambda_cross(ax):
  """Mark lambda_cross on ``ax`` and keep the existing x tick labels.

  Draws a thin dashed grey vertical line at ``lambda_cross`` and adds an x
  tick there labelled ``$\\lambda_{cross}$``. The current x-limits are
  restored afterwards.
  """
  xlim_l, xlim_u = ax.get_xlim()
  ax.axvline(x=lambda_cross, ls="--", color="grey", zorder=0, lw=0.5)
  ticks = ax.get_xticks()
  # Format the current ticks explicitly. Before the first draw,
  # get_xticklabels() returns empty strings on matplotlib < 3.6, which would
  # blank every existing tick label once they are fixed below.
  labels = list(ax.xaxis.get_major_formatter().format_ticks(ticks))
  ax.set_xticks(np.append(ticks, [lambda_cross]))
  ax.set_xticklabels(labels + [r"$\lambda_{cross}$"])
  ax.set_xlim(xlim_l, xlim_u)



# Plot the Hessian eigenvalue density: a log-scale main panel over the whole
# computed spectrum plus an inset over a fixed window near zero.
# NOTE(review): 'plt_style.pstyle' is presumably a style file in the working
# directory — confirm it ships with the script.
plt.style.use('plt_style.pstyle')
plt.rcParams['figure.figsize'] = [2*2.7, 1*2.0250]  # figure size in inches
fig, ax = plt.subplots()
# Inset spanning 50% x 50% of the main axes, anchored upper right (loc=1).
axins = inset_axes(ax, width="50%", height="50%", loc=1, borderpad=1.4)

# Main Plot
# Coarse smoothing (sigma=1e-2) over the default range derived from H_eigval.
(density_x, density_y) = eigenvalue_density(H_eigval, dim, sigma=1e-2)
ax.plot(density_x, density_y)
ax.set_yscale("log")
# Keep the autoscaled upper y-limit; start the axis at the smallest plotted
# density value (ylim_l is not used).
ylim_l, ylim_u = ax.get_ylim()
ax.set_ylim(np.min(density_y), ylim_u)
ax.set_ylabel(r"Hessian Eigenvalue Density")
ax.set_xlabel(r"Eigenvalue $\lambda$")

# Inset
# Fixed window around zero with much finer smoothing (sigma=5e-5).
xlim=(-2.1e-3,2.1e-2)
(density_x, density_y) = eigenvalue_density(H_eigval, dim, sigma=5e-5, xlim=xlim)
axins.plot(density_x, density_y)
axins.set_yscale("log")
ylim_l, ylim_u = axins.get_ylim()
axins.set_ylim(np.min(density_y), ylim_u)
# Keep the autoscaled left edge but clip the right edge to the window.
xlim_l,xlim_u = axins.get_xlim()
axins.set_xlim(xlim_l, xlim[1])

# Mark lambda_cross on the main axes only, then write the figure to disk.
add_lambda_cross(ax)
plt.savefig(model_str+"/plots/Hessian_density_"+para_str+".jpeg")





