# import dependencies
import tensorflow as tf
import numpy as np
import argparse



# Command-line interface of the analysis script.
# The long option name doubles as the attribute name on `args`, so no explicit
# `dest` is needed.
parser = argparse.ArgumentParser(
  description="Project a recorded weight time series onto an eigenbasis and "
              "compare the measured variances with the theoretical prediction."
)
# Options that make up the name of the run directory (see model_str below)
parser.add_argument('-m', '--model_type', type=str, default="lenet",
                    help="model name, first part of the run directory name")
parser.add_argument('-s', '--seed', type=int, default=1,
                    help="seed, last part of the run directory name")
parser.add_argument('-w', '--weight_decay', type=float, default=1e-4,
                    help="weight decay, part of the run directory name")
parser.add_argument('-l', '--init_lr', type=float, default=5e-3,
                    help="initial learning rate")
parser.add_argument('-b', '--momentum', type=float, default=0.9,
                    help="momentum coefficient")
parser.add_argument('--batch_size', type=int, default=50,
                    help="mini-batch size")
parser.add_argument('--dtype', type=str, default="float32",
                    help="floating point type passed to Keras and used for the data")
# Options that select which recording is analysed and how
parser.add_argument('--para_str', type=str, default="Htop5e+03",
                    help="name of the recording; must start with 'Htop' or 'lyr'")
# NOTE(review): `epochs` is not read anywhere in the visible part of this script.
parser.add_argument('--epochs', type=int, default=20,
                    help="epoch count (not read by the analysis in this file)")
parser.add_argument('--wr', default=False, action=argparse.BooleanOptionalAction,
                    help="analyse the 'WR_' weight time series")
# Restricting the choices rejects a typo immediately instead of failing with a
# NameError after all data has been loaded.
parser.add_argument('--basis_str', type=str, default="hessian",
                    choices=("hessian", "sigma"),
                    help="eigenbasis the weights are projected onto")
parser.add_argument('--subtract_mean', default=False, action=argparse.BooleanOptionalAction,
                    help="remove the mean velocity from the weight time series")

# Unpack the parsed options into the module-level names used by the rest of the script
args = parser.parse_args()
model_type = args.model_type
seed = args.seed
weight_decay = args.weight_decay
init_lr = args.init_lr
momentum = args.momentum
batch_size = args.batch_size
dtype = args.dtype
para_str = args.para_str
epochs = args.epochs
with_replacement = args.wr
basis_str = args.basis_str
subtract_mean = args.subtract_mean



# Network specific parameters
# Name of the run directory: model type, formatted hyperparameters, the last
# two characters of the dtype and the seed, joined by underscores.
model_str = (
  f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}"
  f"_b{batch_size:.0f}_m{momentum:.2f}_{dtype[-2:]}_{seed}"
)
initial_learning_rate = init_lr
# NOTE(review): 0.98**100 looks like a decay factor of 0.98 applied 100 times;
# confirm against the training script.
learning_rate_now = 0.98**100*initial_learning_rate
tf.keras.backend.set_floatx(dtype)

# Load CIFAR-10, cast the images to the working dtype and scale them by 1/255
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# Load the trained model stored in the run directory
model = tf.keras.models.load_model(model_str+"/data/trained_model")



# Loading the weight timeseries
# The with-replacement recording lives next to the default one, tagged "WR_".
if with_replacement:
  weights_file = model_str+"/data/weights_timeseries_WR_"+para_str+".npy"
else:
  weights_file = model_str+"/data/weights_timeseries_"+para_str+".npy"
weights_list = np.load(weights_file)
# Axis 0 is indexed as time below, axis 1 holds the recorded coordinates
total_time = weights_list.shape[0]
dim = weights_list.shape[1]
time_vec = np.array(range(total_time))



# Subtract mean velocity if necessary
if subtract_mean:
  # Mean step-to-step displacement of every coordinate over the whole series
  vel_list = np.diff(weights_list, axis=0)
  vel_mean = vel_list.mean(axis=0)
  # Remove the linear drift accumulated up to each time step (in place)
  for step, snapshot in enumerate(weights_list):
    weights_list[step] = snapshot - step*vel_mean



# Load Hessian (Hessian eigenvector projection or layer wise recording).
# Defines H_eigval, H_eigvec (one eigenvector per row) and hessian.
if para_str.startswith("Htop"):
  # Only the eigenvalues are stored on disk; the identity is used as eigenbasis.
  # NOTE(review): presumably the recorded weights are already coordinates along
  # the Hessian eigenvectors -- confirm against the recording script.
  H_eigvec = np.identity(dim)
  H_eigval = np.load(model_str+"/data/"+para_str+"_eigval.npy").astype(dtype)
  hessian = np.diag(H_eigval)
elif para_str.startswith("lyr"):
  hessian = np.load(model_str+"/data/hessian_"+para_str+".npy")
  # eigh returns ascending eigenvalues with eigenvectors as columns;
  # reorder to descending eigenvalues with eigenvectors as rows.
  H_eigval, H_eigvec = np.linalg.eigh(hessian)
  H_eigval = np.flip(H_eigval, axis = 0)
  H_eigvec = np.flip(np.transpose(H_eigvec), axis = 0)
else:
  # Without this branch the script would only fail later with a NameError
  # on H_eigvec / hessian, far away from the actual cause.
  raise ValueError(
    f"Unsupported para_str {para_str!r}: expected a name starting with 'Htop' or 'lyr'."
  )


# Project the mean-subtracted weights onto the requested eigenbasis.
# theta_i gets one row per basis direction and one column per time step.
if basis_str not in ("hessian", "sigma"):
  # Fail here with a clear message instead of a NameError on theta_i later on
  raise ValueError(
    f"Unsupported basis_str {basis_str!r}: expected 'hessian' or 'sigma'."
  )

if basis_str == "hessian":
  # Hessian eigenbasis: the rows of H_eigvec are the projection directions
  theta_i = np.tensordot(H_eigvec, weights_list - np.mean(weights_list, axis = 0), axes=(1, 1))
elif basis_str == "sigma":
  # Weight covariance (Sigma) eigenbasis. eigh returns the eigenvectors as
  # columns in ascending eigenvalue order; transpose and flip so that the rows
  # are eigenvectors in descending order. The eigenvalues themselves are not
  # needed: the variances are measured from theta_i afterwards.
  Sigma_eigvec = np.linalg.eigh(np.cov(np.transpose(weights_list)))[1]
  Sigma_eigvec = np.transpose(Sigma_eigvec)
  Sigma_eigvec = np.flip(Sigma_eigvec, axis = 0)
  theta_i = np.tensordot(Sigma_eigvec, weights_list - np.mean(weights_list, axis = 0), axes=(1, 1))
  # Curvature along each Sigma eigendirection v: the quadratic form v^T H v
  H_eigval = np.array([np.matmul(v, np.matmul(hessian, v)) for v in Sigma_eigvec])

# Calculate variances
# Velocity: step-to-step change of every projected coordinate (time is axis 1)
vel_i = theta_i[:, 1:] - theta_i[:, :-1]
# Variance over time of position and velocity, one value per direction
variance_theta = theta_i.var(axis=1)
variance_vel = vel_i.var(axis=1)
# Per-direction ratio 2 * var(theta) / var(velocity)
tau = np.divide(2*variance_theta, variance_vel)



# Theory functions variables
# Number of batches needed to cover the training labels once (rounded up)
num_batches = int(np.ceil(y_train.size/batch_size))
# Short names read as globals by corr and variance_vec_fn:
# M = number of batches, S = batch size, eta = learning rate, beta = momentum
M, S = num_batches, batch_size
eta, beta = learning_rate_now, momentum

# Variance theory
def corr(h):
  """Return the correlation factor for two steps that are ``h`` apart.

  Reads the module-level batch count ``M``.
  NOTE(review): presumably the correlation of the mini-batch noise when the
  M batches of an epoch are drawn without replacement -- confirm.

  Args:
    h: lag between the two steps; only its absolute value matters.

  Returns:
    1 for ``h == 0``, ``-(M - |h|) / (M * (M - 1))`` while ``|h| < M + 1``
    and 0 for larger lags.
  """
  if h == 0:
    return 1
  lag = np.abs(h)
  if lag < M + 1:
    pair_count = M * (M - 1)
    return -(M - lag) / pair_count
  # Beyond M steps the factor vanishes
  return 0 if lag > M else None

def variance_vec_fn(lambda_theory):
  """Evaluate the theoretical variance vector for one curvature value.

  Reads the module-level variables ``M``, ``S``, ``eta`` and ``beta`` and
  calls ``corr`` for every lag from 1 to ``M - 1``.

  Args:
    lambda_theory: curvature value the prediction is evaluated at; the noise
      variance is taken as ``lambda_theory / S``.

  Returns:
    Array with two entries. The caller uses entry 0 as the variance of theta
    and entry 1 as the variance of the velocity.
  """
  scaled_curv = eta*lambda_theory
  noise_var = lambda_theory/S

  # 2x2 matrix whose powers enter the correlation sum below
  step_matrix = np.array([[1 + beta - scaled_curv, -beta], [1, 0]])

  # 2x2 matrix mapping the correlation-corrected vector to the variances
  response = np.array([
    [(1 + beta)/scaled_curv, 2*beta*(scaled_curv - 1 - beta)/scaled_curv],
    [2, 2*(scaled_curv - 2)],
  ])
  response = response/((1-beta)*(2*(1+beta) - scaled_curv))

  # The correlation matrix is accumulated as an explicit sum instead of via a
  # matrix inverse to ensure numerical stability.
  corr_sum = np.array([[0, 0], [0, 0]])
  for lag in range(1, M):
    corr_sum = corr_sum + corr(lag)*np.linalg.matrix_power(step_matrix, lag)
  first_column_only = np.array([[1, 0], [0, 0]])
  corr_sum = -np.matmul(corr_sum, first_column_only)

  unit = np.array([1, 0])
  corrected = unit - np.matmul(corr_sum + np.transpose(corr_sum), unit)
  return eta**2*noise_var*np.matmul(response, corrected)



# Calculate Tau Theory
# The theory curve is evaluated on a log-spaced grid of curvature values
# spanning the range of the measured eigenvalues.
steps_theory = 100
# A logarithmic grid needs strictly positive end points. Zero or negative
# eigenvalues would turn log10 into -inf/NaN and with it the whole theory
# curve, so only the positive eigenvalues define the range.
lambda_range = H_eigval[H_eigval > 0]
if lambda_range.size == 0:
  # No positive curvature at all: fall back to the full set (the theory curve
  # is then NaN, as before) so that the measured data is still saved below.
  lambda_range = H_eigval
lambda_max = np.max(lambda_range)
lambda_min = np.min(lambda_range)
lambda_theory = np.logspace(np.log10(lambda_min), np.log10(lambda_max), num=steps_theory)
# One (variance of theta, variance of velocity) pair per grid point
variance_vec_list = np.array([variance_vec_fn(lambda_temp) for lambda_temp in lambda_theory])
variance_theta_theory = variance_vec_list[:,0]
variance_vel_theory = variance_vec_list[:,1]
# Same ratio as computed for the measured data
tau_theory = 2*variance_theta_theory/variance_vel_theory



# Save analysed data
# Measured rows: curvature, variance of theta, variance of velocity, tau
data_array = np.array([H_eigval, variance_theta, variance_vel, tau])
# Theory rows: curvature grid, predicted tau
theory_array = np.array([lambda_theory, tau_theory])
# Optional tags so that differently configured runs do not overwrite each other
with_replacement_str = "WR_" if with_replacement else ""
subtract_mean_str = "SubMean" if subtract_mean else ""
file_tag = with_replacement_str+para_str+"_"+basis_str+subtract_mean_str
np.save(model_str+"/data/variance_data_"+file_tag, data_array)
np.save(model_str+"/data/tau_theory_"+file_tag, theory_array)

