# import dependencies
import tensorflow as tf
import os
import numpy as np
import argparse
import scipy
import matplotlib.pyplot as plt



# Command-line options. Most of them only select which pre-computed run
# directories / data files are read further below.
parser = argparse.ArgumentParser(
  description="Fit tau_SGD and lambda_cross from stored variance data and plot them against theory.")
parser.add_argument('-m', '--model_type', type=str, default="lenet", dest="model_type",
                    help="model name; first component of the run directory name")
parser.add_argument('-w', '--weight_decay', type=float, default=1e-4, dest="weight_decay",
                    help="weight decay; part of the run directory name")
parser.add_argument('--dtype', type=str, default="float32", dest="dtype",
                    help="Keras float type; its last two characters are part of the run directory name")
parser.add_argument('--para_str', type=str, default="Htop5e+03", dest="para_str",
                    help="parameter tag contained in the data file name")
parser.add_argument('--epochs', type=int, default=20, dest="epochs",
                    help="number of epochs (parsed, but not used by this script)")
parser.add_argument('--wr', default=False, action=argparse.BooleanOptionalAction,
                    help="read the 'WR_' (with replacement) variant of the data files")
parser.add_argument('--basis_str', type=str, default="hessian", dest="basis_str",
                    help="basis tag contained in the data file name")
parser.add_argument('--subtract_mean', default=False, action=argparse.BooleanOptionalAction,
                    help="read the 'SubMean' variant of the data files")
# Required: without at least one seed there is nothing to iterate over, and
# args.seed_list would be None, which fails later with an opaque TypeError.
parser.add_argument('--seed_list', action='append', type=int, required=True,
                    help="seed of a run to include; repeat the option once per seed")

args = parser.parse_args()
model_type = args.model_type
weight_decay = args.weight_decay
dtype = args.dtype
para_str = args.para_str
epochs = args.epochs
with_replacement = args.wr
basis_str = args.basis_str
subtract_mean = args.subtract_mean
seed_list = args.seed_list



# Use the requested floating-point type as the Keras default
tf.keras.backend.set_floatx(dtype)

# Load CIFAR-10, cast the images to the requested float type and divide by 255
train_split, test_split = tf.keras.datasets.cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0



# Values covered by each hyperparameter sweep
# (initial learning rate, batch size, momentum)
init_lr_list = [
  4e-2,
  3e-2,
  2e-2,
  1e-2,
  5e-3,
]
batch_size_list = [
  32,
  50,
  64,
  100,
  128,
]
momentum_list = [
  0,
  0.5,
  0.7,
  0.9,
  0.95,
]



def func_lin(x, y_0):
  """Straight line with fixed slope -1 and free intercept ``y_0``.

  Serves as the one-parameter fit model; the signature (independent
  variable first, then the fit parameter) is what ``curve_fit`` expects.
  """
  intercept = y_0
  return intercept - x

def tau_lambda_fit_fun(H_eigval, tau, lambda_cross, tau_fit_points=20):
  """Fit the plateau value of tau and the crossover curvature from data.

  Args:
    H_eigval: array of curvature values. Its first ``tau_fit_points`` entries
      are used for the fit (presumably the largest eigenvalues -- verify
      against how the data files are written).
    tau: array with one value per entry of ``H_eigval``.
    lambda_cross: threshold; all entries with ``H_eigval < lambda_cross``
      are averaged to obtain the plateau value of tau.
    tau_fit_points: number of leading entries used for the fit.

  Returns:
    Tuple ``(tau_average_fit, lambda_cross_fit)``. ``tau_average_fit`` is the
    mean of ``tau`` over the masked entries; ``lambda_cross_fit`` is a
    length-1 array because it is derived from the fitted parameter vector.
  """
  # Import the submodule explicitly: a bare ``import scipy`` does not make
  # ``scipy.optimize`` available on every SciPy version.
  from scipy.optimize import curve_fit

  plateau_mask = H_eigval < lambda_cross
  tau_average_fit = np.mean(tau[plateau_mask])

  log_curvature = np.log10(H_eigval[:tau_fit_points])
  log_tau = np.log10(tau[:tau_fit_points])
  # Model in log-log space: log10(tau) = y_0 - log10(curvature),
  # i.e. tau = 10**y_0 / curvature, so 10**y_0 is tau at unit curvature.
  fit_params, _ = curve_fit(func_lin, log_curvature, log_tau)
  tau_at_unit_curvature_from_fit = 10**fit_params
  # Curvature at which the fitted power law reaches the plateau value
  lambda_cross_fit = tau_at_unit_curvature_from_fit/tau_average_fit

  return (tau_average_fit, lambda_cross_fit)

def tau_lambda_theory_fn(parameter_theory, parameter_str, init_lr, batch_size, momentum, y_train_size):
  """Evaluate the theory expressions for tau_SGD and lambda_cross along a sweep.

  ``parameter_str`` selects the hyperparameter that is replaced by the array
  ``parameter_theory``: ``"lr"`` (initial learning rate), ``"b"`` (batch size)
  or ``"m"`` (momentum). The other two keep the values passed in.

  Returns ``(tau_sgd, lambda_cross)``; ``tau_sgd`` is always an array with
  ``len(parameter_theory)`` entries.
  """
  settings = {"lr": init_lr, "b": batch_size, "m": momentum}
  if parameter_str in settings:
    settings[parameter_str] = parameter_theory
  swept_lr = settings["lr"]
  swept_batch = settings["b"]
  swept_momentum = settings["m"]

  batches_per_epoch = np.ceil(y_train_size/swept_batch)
  # Initial learning rate scaled by 0.98**100 (presumably the decayed learning
  # rate at the evaluated epoch -- confirm against the training schedule)
  decayed_lr = swept_lr*(0.98**100)

  lambda_cross = 3*(1-swept_momentum)/(decayed_lr*batches_per_epoch)
  tau_value = (batches_per_epoch/3)*(1+swept_momentum)/(1-swept_momentum)
  # Broadcast to the sweep length so tau_sgd can always be plotted as a curve
  tau_sgd = tau_value*np.ones(len(parameter_theory))
  return (tau_sgd, lambda_cross)



# Fit tau_sgd and lambda_cross for every hyperparameter sweep and every seed
x_label_multi = []
parameter_list_multi = []
tau_sgd_fit_mean_multi = []
tau_sgd_fit_error_multi = []
lambda_cross_fit_mean_multi = []
lambda_cross_fit_error_multi = []
parameter_theory_multi = []
tau_sgd_theory_multi = []
lambda_cross_theory_multi = []

# Hyperparameters held fixed in each sweep; the swept one (None) is filled in
# for every point of the sweep.
sweep_settings = {
  "lr": {"init_lr": None, "batch_size": 50, "momentum": 0.9},
  "b": {"init_lr": 5e-3, "batch_size": None, "momentum": 0.9},
  "m": {"init_lr": 2e-2, "batch_size": 50, "momentum": None},
}
swept_key = {"lr": "init_lr", "b": "batch_size", "m": "momentum"}
sweep_values = {"lr": init_lr_list, "b": batch_size_list, "m": momentum_list}
axis_labels = {"lr": "Initial Learning Rate", "b": "Batch Size", "m": "Momentum"}

# The data file name depends only on the command-line options
with_replacement_str = "WR_" if with_replacement else ""
subtract_mean_str = "SubMean" if subtract_mean else ""
data_file_name = f"variance_data_{with_replacement_str}{para_str}_{basis_str}{subtract_mean_str}.npy"

for parameter_str, parameter_list in sweep_values.items():
  tau_sgd_mean_list = []
  tau_sgd_error_list = []
  lambda_cross_mean_list = []
  lambda_cross_error_list = []
  for parameter in parameter_list:
    hyper = dict(sweep_settings[parameter_str])
    hyper[swept_key[parameter_str]] = parameter
    init_lr = hyper["init_lr"]
    batch_size = hyper["batch_size"]
    momentum = hyper["momentum"]

    # Theory value of lambda_cross; used as the plateau threshold in the fit
    learning_rate_now = 0.98**100*init_lr
    num_batches = int(np.ceil(y_train.size/batch_size))
    lambda_cross = 3*(1-momentum)/(learning_rate_now*num_batches)

    # Run directory name up to (not including) the seed
    model_prefix = f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}_m{momentum:.2f}_{dtype[-2:]}_"

    tau_sgd_fit_list = []
    lambda_cross_fit_list = []
    for seed in seed_list:
      # Load the evaluated data of this run
      model_str = model_prefix+str(seed)
      data_array = np.load(model_str+"/data/"+data_file_name)

      # Fit tau_SGD and lambda_cross from rows 0 and 3 of the data array
      tau_sgd_fit, lambda_cross_fit = tau_lambda_fit_fun(data_array[0], data_array[3], lambda_cross)
      tau_sgd_fit_list.append(tau_sgd_fit)
      lambda_cross_fit_list.append(lambda_cross_fit)

    # Mean over seeds and twice the standard error of the mean
    tau_sgd_mean_list.append(np.mean(tau_sgd_fit_list))
    tau_sgd_error_list.append(2*np.std(tau_sgd_fit_list)/np.sqrt(len(tau_sgd_fit_list)))
    lambda_cross_mean_list.append(np.mean(lambda_cross_fit_list))
    lambda_cross_error_list.append(2*np.std(lambda_cross_fit_list)/np.sqrt(len(lambda_cross_fit_list)))

  # Theory curve over the range of the sweep; the swept hyperparameter is
  # replaced by parameter_theory inside tau_lambda_theory_fn.
  parameter_theory = np.linspace(np.min(parameter_list), np.max(parameter_list), num=100)
  tau_sgd_theory, lambda_cross_theory = tau_lambda_theory_fn(parameter_theory, parameter_str, init_lr, batch_size, momentum, y_train.size)

  x_label = axis_labels[parameter_str]

  x_label_multi.append(x_label)
  parameter_list_multi.append(parameter_list)
  tau_sgd_fit_mean_multi.append(np.array(tau_sgd_mean_list))
  tau_sgd_fit_error_multi.append(np.array(tau_sgd_error_list))
  lambda_cross_fit_mean_multi.append(np.array(lambda_cross_mean_list))
  lambda_cross_fit_error_multi.append(np.array(lambda_cross_error_list))
  parameter_theory_multi.append(parameter_theory)
  tau_sgd_theory_multi.append(tau_sgd_theory)
  lambda_cross_theory_multi.append(lambda_cross_theory)



def triple_plotter(x_label_multi,
                   parameter_list_multi,
                   tau_sgd_fit_mean_multi,
                   tau_sgd_fit_error_multi,
                   lambda_cross_fit_mean_multi,
                   lambda_cross_fit_error_multi,
                   parameter_theory_multi,
                   tau_sgd_theory_multi,
                   lambda_cross_theory_multi):
  """Draw a 2x3 grid comparing fitted values with the theory curves.

  The top row shows tau_SGD and the bottom row lambda_cross, one column per
  entry of ``x_label_multi`` (only the first three entries are drawn). All
  ``*_multi`` arguments are sequences with one entry per column.

  Every panel of a row uses the same y-limits. They are obtained by drawing
  all series of that row into a temporary axes and reading its autoscaled
  limits. Error bars are not drawn in the final panels; the lambda_cross
  errors only enter through the limits of the bottom row, and
  ``tau_sgd_fit_error_multi`` is accepted but not drawn.

  Side effects: applies the style file ``plt_style.pstyle`` (resolved by
  matplotlib, e.g. relative to the working directory) and sets
  ``plt.rcParams['figure.figsize']``.

  Returns:
    ``(fig, axes)`` of the final grid. That figure is also the current pyplot
    figure, so a following ``plt.savefig`` writes it.
  """
  plt.style.use('plt_style.pstyle')
  colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
  colors2 =  ["#0C5DA5", "#70b7f5", "#00B945", "#33ff7e", "#FF9500", "#ffb54d", "#FF2C00", "#ff6a4d"]

  plt.rcParams['figure.figsize'] = [2.5*2.7, 1.9*2.0250]

  # Shared y-limits of the tau row, measured on a scratch figure
  scratch_fig, scratch_ax = plt.subplots()
  for parameter_list, tau_sgd_fit_mean, _, parameter_theory, tau_sgd_theory in zip(parameter_list_multi, tau_sgd_fit_mean_multi, tau_sgd_fit_error_multi, parameter_theory_multi, tau_sgd_theory_multi):
    scratch_ax.scatter(parameter_list, tau_sgd_fit_mean, marker ="x", s = 1.5, color = colors[0])
    scratch_ax.plot(parameter_theory, tau_sgd_theory, lw = 0.7, color = colors2[1], label = "theory")
  ylim_tau = scratch_ax.get_ylim()
  # Close (not just clear) the scratch figure so it does not stay open
  plt.close(scratch_fig)

  # Shared y-limits of the lambda_cross row, measured the same way
  scratch_fig, scratch_ax = plt.subplots()
  for parameter_list, lambda_cross_fit_mean, lambda_cross_fit_error, parameter_theory, lambda_cross_theory in zip(parameter_list_multi, lambda_cross_fit_mean_multi, lambda_cross_fit_error_multi, parameter_theory_multi, lambda_cross_theory_multi):
    scratch_ax.errorbar(parameter_list, lambda_cross_fit_mean, lambda_cross_fit_error, fmt="none", capsize = 1.5, lw = 0.7, color = colors[0])
    scratch_ax.plot(parameter_theory, lambda_cross_theory, lw = 0.7, color = colors2[1], label = "theory")
  ylim_lambda = scratch_ax.get_ylim()
  plt.close(scratch_fig)

  n_cols = 3
  fig, axes = plt.subplots(2, n_cols)
  for idx, x_label in zip(range(n_cols), x_label_multi):
    parameter_list = parameter_list_multi[idx]
    parameter_theory = parameter_theory_multi[idx]

    # Top row: tau_SGD
    tau_ax = axes[0, idx]
    tau_ax.plot(parameter_theory, tau_sgd_theory_multi[idx], lw = 1.5, color = colors2[1], label = "theory", zorder = 0)
    tau_ax.scatter(parameter_list, tau_sgd_fit_mean_multi[idx], marker ="x", s = 14, lw=0.8, color = colors[0], label="data")
    tau_ax.set_ylim(ylim_tau)
    tau_ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

    # Bottom row: lambda_cross
    lambda_ax = axes[1, idx]
    lambda_ax.plot(parameter_theory, lambda_cross_theory_multi[idx], lw = 1.5, color = colors2[3], label = "theory", zorder = 0)
    lambda_ax.scatter(parameter_list, lambda_cross_fit_mean_multi[idx], marker ="x", s = 14, lw=0.8, color = colors[1], label="data")
    lambda_ax.set_ylim(ylim_lambda)
    lambda_ax.set_xlabel(x_label)

  axes[0, 0].set_ylabel(r"$\tau_{SGD}$")
  axes[1, 0].set_ylabel(r"$\lambda_{\rm cross}$")
  axes[0, n_cols-1].legend()
  axes[1, n_cols-1].legend()

  return fig, axes



# Create the output directory; exist_ok avoids the check-then-create race
os.makedirs("Appendix_Plots", exist_ok=True)

# Draw the comparison grid; it is the last figure created and therefore the
# current pyplot figure that savefig writes.
triple_plotter(
  x_label_multi=x_label_multi,
  parameter_list_multi=parameter_list_multi,
  tau_sgd_fit_mean_multi=tau_sgd_fit_mean_multi,
  tau_sgd_fit_error_multi=tau_sgd_fit_error_multi,
  lambda_cross_fit_mean_multi=lambda_cross_fit_mean_multi,
  lambda_cross_fit_error_multi=lambda_cross_fit_error_multi,
  parameter_theory_multi=parameter_theory_multi,
  tau_sgd_theory_multi=tau_sgd_theory_multi,
  lambda_cross_theory_multi=lambda_cross_theory_multi,
)

plt.savefig("Appendix_Plots/hyperparameter_dependence.jpeg")





