# import dependencies
import tensorflow as tf
import numpy as np
import argparse
import scipy
import matplotlib.pyplot as plt



# Command-line interface. The destination attribute of every option is the
# name of its long flag, which is what argparse derives by default.
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_type", type=str, default="lenet")
parser.add_argument("-s", "--seed", type=int, default=1)
parser.add_argument("-w", "--weight_decay", type=float, default=1e-4)
parser.add_argument("-l", "--init_lr", type=float, default=5e-3)
parser.add_argument("-b", "--momentum", type=float, default=0.9)
parser.add_argument("--batch_size", type=int, default=50)
parser.add_argument("--dtype", type=str, default="float32")
parser.add_argument("--para_str", type=str, default="Htop5e+03")
parser.add_argument("--epochs", type=int, default=20)
# --wr / --no-wr and --subtract_mean / --no-subtract_mean boolean switches
parser.add_argument("--wr", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--basis_str", type=str, default="hessian")
parser.add_argument("--subtract_mean", action=argparse.BooleanOptionalAction, default=False)

args = parser.parse_args()

# Expose the parsed options as module-level names used by the rest of the script.
model_type, seed = args.model_type, args.seed
weight_decay, init_lr, momentum = args.weight_decay, args.init_lr, args.momentum
batch_size, dtype = args.batch_size, args.dtype
para_str, epochs = args.para_str, args.epochs
with_replacement = args.wr
basis_str, subtract_mean = args.basis_str, args.subtract_mean



# Network-specific parameters.
# model_str names the directory of the trained network; it is assembled from
# the model type, the training hyper-parameters, the last two characters of
# the dtype string and the seed.
model_str = (
  f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}"
  f"_m{momentum:.2f}_{dtype[-2:]}_{seed}"
)
initial_learning_rate = init_lr
# Initial learning rate scaled by 0.98**100.
# NOTE(review): presumably the decayed learning rate at the point where the
# data was evaluated -- confirm against the training script.
learning_rate_now = 0.98**100*initial_learning_rate
tf.keras.backend.set_floatx(dtype)

# Load CIFAR-10, cast the images to the working dtype and divide by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# Load the trained model stored under the model directory.
model = tf.keras.models.load_model(f"{model_str}/data/trained_model")



# Load the evaluated data. The file names carry an optional "WR_" prefix
# (--wr) and an optional "SubMean" suffix (--subtract_mean).
with_replacement_str = "WR_" if with_replacement else ""
subtract_mean_str = "SubMean" if subtract_mean else ""
data_array = np.load(
  f"{model_str}/data/variance_data_{with_replacement_str}{para_str}"
  f"_{basis_str}{subtract_mean_str}.npy"
)
theory_array = np.load(
  f"{model_str}/data/tau_theory_{with_replacement_str}{para_str}"
  f"_{basis_str}{subtract_mean_str}.npy"
)



# Fit variances
def func_pwr(x, y_0, pwr):
  """Evaluate the power law ``y_0 * x**pwr``.

  ``y_0`` is the prefactor and ``pwr`` the exponent.
  """
  x_to_the_power = x**pwr
  return y_0*x_to_the_power
def func_lin(x, y_0, pwr):
  """Evaluate the straight line ``y_0 + pwr*x``.

  Used as the fit model in log10-log10 space: there ``y_0`` is the log10 of
  a power-law prefactor and ``pwr`` (the slope) is the power-law exponent.
  """
  return y_0 + x*pwr


def fit_var(data_array, theory_array, lambda_cross):
  """Fit power laws to the theta and velocity variances.

  Both fits are straight-line fits (``func_lin``) of log10(variance) against
  log10(eigenvalue). The fitted intercept is converted back to a prefactor,
  so each returned pair ``[prefactor, exponent]`` can be passed directly to
  ``func_pwr``. The exponents and their 1-sigma errors are printed.

  Args:
    data_array: indexable whose entry 0 holds the eigenvalues, entry 1 the
      theta variances and entry 2 the velocity variances. The values are
      passed through log10, so they should be positive.
    theory_array: not used by the fit; kept so existing callers keep working.
    lambda_cross: cross-over eigenvalue. The theta fit only uses the points
      from the first eigenvalue below ``lambda_cross/10`` onwards, searching
      from the position of the largest eigenvalue.

  Returns:
    Tuple ``(fit_variance_theta, fit_variance_vel, idx_start_fit_theta,
    idx_start_fit_vel)``. The first two are arrays ``[prefactor, exponent]``;
    the last two are the index where the theta fit starts and the index of
    the largest eigenvalue. The velocity fit itself uses all points.
    If no eigenvalue at or after the largest one is below ``lambda_cross/10``,
    the theta "fit" is a constant (exponent 0) at the geometric mean of all
    theta variances and ``idx_start_fit_theta`` is the last index.
  """
  # Import the submodule explicitly: a bare ``import scipy`` does not
  # guarantee that ``scipy.optimize`` is loaded on every SciPy version.
  from scipy.optimize import curve_fit

  eigenvalues = np.asarray(data_array[0])
  var_theta = np.asarray(data_array[1])
  var_vel = np.asarray(data_array[2])
  log_eigenvalues = np.log10(eigenvalues)

  idx_start_fit_vel = np.argmax(eigenvalues)
  # Offsets (relative to idx_start_fit_vel) of all eigenvalues below the cut.
  # Searching this way cannot raise when nothing qualifies, unlike an
  # unguarded next() over the same slice.
  below_cut = np.flatnonzero(eigenvalues[idx_start_fit_vel:] < lambda_cross/10)
  if below_cut.size > 0:
    idx_start_fit_theta = idx_start_fit_vel + below_cut[0]
    fit_variance_theta, pcov = curve_fit(
      func_lin,
      log_eigenvalues[idx_start_fit_theta:],
      np.log10(var_theta[idx_start_fit_theta:]),
    )
  else:
    # No data in the fit region: fall back to a flat line through the mean
    # of the log-variances, with zero reported uncertainty.
    idx_start_fit_theta = len(eigenvalues)-1
    fit_variance_theta = [np.mean(np.log10(var_theta)), 0]
    pcov = np.zeros((2,2))
  print(f"Fit Variance Theta exponent: {fit_variance_theta[1]}")
  print(f"Fit Variance Theta 1sigma error: {np.sqrt(np.diag(pcov))[1]}")
  # Convert the log10 intercept back into a multiplicative prefactor.
  fit_variance_theta = np.array([10**fit_variance_theta[0], fit_variance_theta[1]])

  fit_variance_vel, pcov = curve_fit(func_lin, log_eigenvalues, np.log10(var_vel))
  print(f"Fit Variance Velocity exponent: {fit_variance_vel[1]}")
  print(f"Fit Variance Velocity 1sigma error: {np.sqrt(np.diag(pcov))[1]}")
  fit_variance_vel = np.array([10**fit_variance_vel[0], fit_variance_vel[1]])
  return (fit_variance_theta, fit_variance_vel, idx_start_fit_theta, idx_start_fit_vel)



# Cross-over eigenvalue marked in the plots:
#   lambda_cross = 3*(1 - momentum) / (learning_rate_now * num_batches)
# where num_batches is the number of training labels divided by the batch
# size (true division, so it is a float).
num_batches = y_train.size / batch_size
lambda_cross = 3 * (1 - momentum) / (learning_rate_now * num_batches)
def add_lambda_cross(ax):
  """Mark the module-level ``lambda_cross`` on the x-axis of ``ax``.

  Draws a thin dashed grey vertical line at ``lambda_cross`` and appends an
  extra x tick there, labelled with the LaTeX string for lambda_cross. The
  x-limits found on entry are restored at the end, so the visible range is
  left unchanged.
  """
  x_limits = ax.get_xlim()
  ax.axvline(x=lambda_cross, linestyle="--", linewidth=0.5, color="grey", zorder=0)

  # Collect the current labels before the tick positions are replaced.
  current_labels = ax.get_xticklabels()
  tick_positions = np.append(ax.get_xticks(), [lambda_cross])
  ax.set_xticks(tick_positions)
  ax.set_xticklabels(current_labels + [r"$\lambda_{cross}$"])

  ax.set_xlim(*x_limits)



def plot_var(data_array, theory_array):
  """Plot the theta and velocity variances with their power-law fits.

  Entries 1 and 2 of ``data_array`` are scattered against entry 0 on
  log-log axes. The fits come from ``fit_var`` (called with the module-level
  ``lambda_cross``). The theta fit is drawn dashed from the largest
  eigenvalue to the last one and solid from the start of its fit range to
  the last one; the velocity fit is drawn solid from the largest eigenvalue
  to the last one. ``theory_array`` is only forwarded to ``fit_var``.

  Returns the matplotlib axes of the new figure.
  """
  line_width = 0.7
  marker_size = 1
  theta_fit, vel_fit, theta_start, vel_start = fit_var(data_array, theory_array, lambda_cross)

  # Each fit line is defined by its two end points.
  eigenvalues = data_array[0]
  span_total = np.array([eigenvalues[vel_start], eigenvalues[-1]])
  span_theta = np.array([eigenvalues[theta_start], eigenvalues[-1]])

  _fig, ax = plt.subplots()

  # Theta variance: data, extrapolated fit (dashed), fitted range (solid).
  ax.scatter(eigenvalues, data_array[1], s=marker_size, lw=0, label=r"$\sigma_{\theta, i}^2$")
  ax.plot(span_total, func_pwr(span_total, *theta_fit), color=colors2[1], linestyle="--", linewidth=line_width)
  ax.plot(span_theta, func_pwr(span_theta, *theta_fit), color=colors2[1], linewidth=line_width)

  # Velocity variance: data and fit.
  ax.scatter(eigenvalues, data_array[2], s=marker_size, lw=0, label=r"$\sigma_{v, i}^2$")
  ax.plot(span_total, func_pwr(span_total, *vel_fit), color=colors2[3], linewidth=line_width)

  ax.set_xscale("log")
  ax.set_yscale("log")
  ax.set_xlabel(r"Hessian eigenvalue $\lambda_i$")
  ax.set_ylabel(r"fluctuations")
  ax.legend(loc=0)  # 0 is matplotlib's code for "best"
  return ax



def plot_tau(data_array, theory_array):
  """Plot correlation times against eigenvalues with the theory curve.

  Entry 3 of ``data_array`` is scattered against entry 0 (label "data");
  entry 1 of ``theory_array`` is drawn as a line against entry 0 (label
  "theory") in the third colour of the module-level ``colors`` list. Both
  axes are logarithmic.

  Returns the matplotlib axes of the new figure.
  """
  _fig, ax = plt.subplots()

  measured_x, measured_tau = data_array[0], data_array[3]
  theory_x, theory_tau = theory_array[0], theory_array[1]
  ax.scatter(measured_x, measured_tau, s=0.6, lw=0, label="data")
  ax.plot(theory_x, theory_tau, label="theory", color=colors[2], linewidth=0.7)

  ax.set_xscale("log")
  ax.set_yscale("log")
  ax.set_xlabel(r"Hessian eigenvalue $\lambda_i$")
  ax.set_ylabel(r"correlation time $\tau_i$")
  ax.legend(loc=3)  # 3 is matplotlib's code for "lower left"
  return ax



# Plot styling: apply the style file, then read the colours of its property
# cycle into `colors`.
plt.style.use('plt_style.pstyle')
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# Additional palette, arranged as (darker, lighter) pairs per hue.
colors2 = [
  "#0C5DA5", "#70b7f5",  # blue
  "#00B945", "#33ff7e",  # green
  "#FF9500", "#ffb54d",  # orange
  "#FF2C00", "#ff6a4d",  # red
]



# Variance figure: plot, mark lambda_cross, save into the model's plots folder.
ax = plot_var(data_array, theory_array)
add_lambda_cross(ax)
plt.savefig(
  f"{model_str}/plots/variance_{with_replacement_str}{para_str}"
  f"_{basis_str}{subtract_mean_str}.jpeg"
)



# Correlation-time figure: same steps; savefig writes the most recently
# created figure.
ax = plot_tau(data_array, theory_array)
add_lambda_cross(ax)
plt.savefig(
  f"{model_str}/plots/tau_{with_replacement_str}{para_str}"
  f"_{basis_str}{subtract_mean_str}.jpeg"
)





