# import dependencies
import tensorflow as tf
import numpy as np
import argparse
import scipy
import matplotlib.pyplot as plt



# Command-line interface. Every option has a default, so the script can be run
# without arguments. Destinations are derived by argparse from the long flag
# names (e.g. --model_type -> args.model_type).
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_type", type=str, default="lenet")
parser.add_argument("-s", "--seed", type=int, default=1)
parser.add_argument("-w", "--weight_decay", type=float, default=1e-4)
parser.add_argument("-l", "--init_lr", type=float, default=5e-3)
parser.add_argument("-b", "--momentum", type=float, default=0.9)
parser.add_argument("--batch_size", type=int, default=50)
parser.add_argument("--dtype", type=str, default="float32")
parser.add_argument("--para_str", type=str, default="Htop5e+03")
parser.add_argument("--epochs", type=int, default=20)
# BooleanOptionalAction provides both the --wr and the --no-wr form.
parser.add_argument("--wr", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--basis_str", type=str, default="hessian")
parser.add_argument("--subtract_mean", action=argparse.BooleanOptionalAction, default=False)

args = parser.parse_args()

# Unpack the parsed options into module-level names used by the rest of the script.
model_type, seed = args.model_type, args.seed
weight_decay, init_lr, momentum = args.weight_decay, args.init_lr, args.momentum
batch_size, dtype = args.batch_size, args.dtype
para_str, epochs = args.para_str, args.epochs
with_replacement = args.wr  # selects the "WR_" file-name infix further down
basis_str = args.basis_str
subtract_mean = args.subtract_mean



# Run identifier: names the directory that holds this run's data and plots.
# It encodes the model type, weight decay, initial learning rate, batch size,
# momentum, the last two characters of the dtype (e.g. "32") and the seed.
model_str = (
  f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}"
  f"_b{batch_size:.0f}_m{momentum:.2f}_{dtype[-2:]}_{seed}"
)
initial_learning_rate = init_lr
# NOTE(review): presumably the learning rate after 100 decay steps with
# factor 0.98 -- confirm against the training script.
learning_rate_now = initial_learning_rate * 0.98**100
tf.keras.backend.set_floatx(dtype)

# CIFAR-10 images, cast to the working dtype and divided by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# load the model
#model = tf.keras.models.load_model(model_str+"/data/trained_model")



# Load the gradient time series. Files written for the with-replacement run
# carry an extra "WR_" infix in their name.
with_replacement_str = "WR_" if with_replacement else ""
grad_batch_list = np.load(
  f"{model_str}/data/grad_batch_timeseries_{with_replacement_str}{para_str}.npy"
)
grad_tot_list = np.load(
  f"{model_str}/data/grad_tot_timeseries_{with_replacement_str}{para_str}.npy"
)
# Axis 0 indexes time steps, axis 1 indexes the gradient components.
total_time = grad_batch_list.shape[0]
dim = grad_batch_list.shape[1]

# Noise term: mini-batch gradient minus total gradient at every time step.
grad_noise_list = np.subtract(grad_batch_list, grad_tot_list)

# Hessian: only its eigenvalues are stored, so it is represented as a
# diagonal matrix.
# NOTE(review): this presumes the gradients above are expressed in the
# Hessian eigenbasis -- confirm against the script that wrote the files.
H_eigval = np.load(f"{model_str}/data/{para_str}_eigval.npy").astype(dtype)
hessian = np.diag(H_eigval)

# Gradient noise covariance. np.cov expects one variable per row, hence the
# transpose of the (time, component) array.
noise_cov = np.cov(grad_noise_list.T)



# Fit variances
def func_pwr(x, y_0, pwr):
  """Evaluate the power law ``y_0 * x**pwr``.

  Args:
    x: Point(s) at which to evaluate; scalar or array.
    y_0: Prefactor (value of the power law at ``x == 1``).
    pwr: Exponent.

  Returns:
    ``y_0 * x**pwr``, with the same shape as ``x``.
  """
  powered = x**pwr
  return y_0*powered
def func_lin(x, y_0, pwr):
  """Evaluate the straight line ``y_0 + pwr * x`` (intercept ``y_0``, slope ``pwr``)."""
  return y_0 + x*pwr


def fit_var(H_eigval, noise_var):
  """Fit the power law ``noise_var = y_0 * H_eigval**pwr``.

  The fit is a straight-line least-squares fit of ``log10(noise_var)``
  against ``log10(H_eigval)``. The fitted exponent and its one-sigma
  uncertainty are printed.

  Args:
    H_eigval: 1-D array-like of Hessian eigenvalues.
    noise_var: 1-D array-like of noise variances, same length as ``H_eigval``.

  Returns:
    ``np.ndarray`` ``[y_0, pwr]``: prefactor (converted back from log10
    space) and exponent, in the argument order expected by ``func_pwr``.

  Raises:
    ValueError: If fewer than two points are strictly positive and finite
      in both inputs.
  """
  # Imported here so the fit does not depend on a bare ``import scipy``
  # exposing the ``optimize`` submodule as an attribute.
  from scipy.optimize import curve_fit

  eigenvalues = np.asarray(H_eigval)
  variances = np.asarray(noise_var)

  # log10 is only defined for strictly positive values. Zero or negative
  # entries would become -inf/NaN and make curve_fit abort, so they are
  # excluded from the fit.
  usable = (
    np.isfinite(eigenvalues) & np.isfinite(variances)
    & (eigenvalues > 0) & (variances > 0)
  )
  n_usable = int(np.count_nonzero(usable))
  if n_usable < usable.size:
    print(f"Fit Variance Noise: excluded {usable.size - n_usable} of {usable.size} non-positive or non-finite points")
  if n_usable < 2:
    raise ValueError("fit_var needs at least two points with positive, finite eigenvalue and variance")

  log_fit, pcov = curve_fit(func_lin, np.log10(eigenvalues[usable]), np.log10(variances[usable]))
  print(f"Fit Variance Noise exponent: {log_fit[1]}")
  print(f"Fit Variance Noise 1sigma error: {np.sqrt(np.diag(pcov))[1]}")
  # The intercept was fitted in log10 space; convert it back to a prefactor.
  return np.array([10**log_fit[0], log_fit[1]])



def plot_var(H_eigval, noise_var):
  """Scatter noise variance against Hessian eigenvalue with a power-law fit.

  Creates a new figure with log-log axes, draws the data points and the
  power law obtained from ``fit_var``. Uses the module-level color lists
  ``colors`` and ``colors2``.

  Args:
    H_eigval: Array of Hessian eigenvalues (x values).
    noise_var: Array of noise variances (y values), same length.

  Returns:
    The matplotlib ``Axes`` that was drawn on.
  """
  line_width = 0.7
  marker_area = 0.7
  legend_marker_area = 2

  prefactor_and_exponent = fit_var(H_eigval, noise_var)
  # A power law is a straight line on log-log axes, so evaluating it at the
  # first and last eigenvalue is enough to draw it.
  fit_x = np.array([H_eigval[0], H_eigval[-1]])
  fit_y = func_pwr(fit_x, *prefactor_and_exponent)

  _, ax = plt.subplots()
  ax.scatter(H_eigval, noise_var, s=marker_area, lw=0)
  ax.plot(fit_x, fit_y, color=colors2[1], lw=line_width)
  # NOTE(review): this extra point appears to exist only to give the legend a
  # larger marker than the data points; (-1, -1) is not a valid position on
  # log axes -- confirm it stays outside the visible range.
  ax.scatter(
    -1, -1,
    s=legend_marker_area,
    lw=0,
    label=r"$\sigma_{\delta g, i}^2$",
    color=colors[0],
  )

  ax.set_yscale("log")
  ax.set_xscale("log")
  ax.set_xlabel(r"Hessian eigenvalue $\lambda_i$")
  ax.set_ylabel(r"noise variance")
  ax.legend(loc=0)
  return ax



# Plot appearance: apply the style sheet, then expose its color cycle as
# `colors` next to a fixed hex palette `colors2`.
plt.style.use("plt_style.pstyle")
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]
colors2 = [
  "#0C5DA5", "#70b7f5",
  "#00B945", "#33ff7e",
  "#FF9500", "#ffb54d",
  "#FF2C00", "#ff6a4d",
]



# Cosine similarity between the Hessian and the noise covariance, with both
# matrices treated as flat vectors of their entries. ravel() is used for the
# dot product because flatten() always copies the full matrix. The default
# matrix norm is the Frobenius norm, i.e. the 2-norm of the flattened entries,
# so no flattening is needed for the denominator.
cos_sim = np.dot(np.ravel(hessian), np.ravel(noise_cov)) / (
  np.linalg.norm(hessian) * np.linalg.norm(noise_cov)
)
print("Cosine similarity between noise covariance and Hessian matrix: "+str(cos_sim))



# Plot the diagonal of the noise covariance against the Hessian eigenvalues
# and save the figure into this run's plots directory.
ax = plot_var(H_eigval, noise_cov.diagonal())
plt.savefig(
  f"{model_str}/plots/noise_variance_{with_replacement_str}{para_str}_{basis_str}.jpeg"
)



# Eigendecomposition of the noise covariance. eigh returns the eigenvalues in
# ascending order with the eigenvectors as columns; after the reordering below
# the eigenvalues are descending and row i of `eigvec` is the eigenvector that
# belongs to eigval[i].
eigval, eigvec = np.linalg.eigh(noise_cov)
eigval = eigval[::-1]
eigvec = eigvec.T[::-1]



def plot_noise_commute(eigval, eigvec, diagonal, idx2=5000):
  """Compare two sets of noise variances in one scatter plot over the index.

  Draws the leading entries of ``eigval`` (legend: along p_i^C) and of
  ``diagonal`` (legend: along p_i^H) against a 1-based index on a
  logarithmic y axis, in a new figure.

  Args:
    eigval: 1-D array of variances for the first series; the caller passes
      the noise covariance eigenvalues in descending order.
    eigvec: Not used by this function; kept so existing calls stay valid.
    diagonal: 1-D array of variances for the second series; the caller
      passes the diagonal of the noise covariance.
    idx2: Maximum number of leading entries to plot.

  Returns:
    Tuple ``(fig2, ax2)`` with the created figure and axes.
  """
  # Never request more points than are available: otherwise the index range
  # and the sliced data would differ in length and scatter would raise.
  n_points = min(idx2, len(eigval), len(diagonal))
  index = range(1, 1 + n_points)

  fig2, ax2 = plt.subplots()
  ax2.scatter(index, eigval[:n_points], s=1, label=r"along ${\bf p}_i^{\bf C}$")
  ax2.scatter(index, diagonal[:n_points], s=1, label=r"along ${\bf p}_i^{\bf H}$")
  ax2.legend()

  ax2.set_ylabel(r"noise variance")
  ax2.set_xlabel(r"index $i$")
  ax2.set_yscale("log")

  return (fig2, ax2)



# Compare the 200 largest covariance eigenvalues with the first 200 diagonal
# entries of the covariance and save the figure into this run's plots
# directory.
fig2, ax2 = plot_noise_commute(eigval, eigvec, noise_cov.diagonal(), idx2=200)
fig2.savefig(
  f"{model_str}/plots/noise_variance_eigval_comparison_{with_replacement_str}{para_str}_{basis_str}.jpeg"
)





