# import dependencies
import tensorflow as tf
import numpy as np
import argparse
import matplotlib.pyplot as plt



parser = argparse.ArgumentParser()
# Command-line options as (flags, add_argument keyword arguments), registered
# in this order so --help lists them the same way.
for _flags, _opts in (
  (('-m', '--model_type'), dict(type=str, default="lenet", dest="model_type")),
  (('-s', '--seed'), dict(type=int, default=1, dest="seed")),
  (('-w', '--weight_decay'), dict(type=float, default=1e-4, dest="weight_decay")),
  (('-l', '--init_lr'), dict(type=float, default=5e-3, dest="init_lr")),
  (('-b', '--momentum'), dict(type=float, default=0.9, dest="momentum")),
  (('--batch_size',), dict(type=int, default=50, dest="batch_size")),
  (('--dtype',), dict(type=str, default="float32", dest="dtype")),
  (('--para_str',), dict(type=str, default="Htop5e+03", dest="para_str")),
  (('--epochs',), dict(type=int, default=20, dest="epochs")),
  # BooleanOptionalAction also registers the negated form --no-wr.
  (('--wr',), dict(default=False, action=argparse.BooleanOptionalAction)),
):
  parser.add_argument(*_flags, **_opts)

# Parse once and expose every option as a module-level name for the rest of the script.
args = parser.parse_args()
model_type, seed, weight_decay = args.model_type, args.seed, args.weight_decay
init_lr, momentum, batch_size = args.init_lr, args.momentum, args.batch_size
dtype, para_str, epochs = args.dtype, args.para_str, args.epochs
with_replacement = args.wr



# Network-specific parameters.
# Make Keras build layers/variables in the requested floating-point precision.
tf.keras.backend.set_floatx(dtype)
# Run identifier built from the hyper-parameters; it is also used below as the
# directory prefix for this run's data/ and plots/ files.
model_str = f"{model_type}_wd{weight_decay:.0e}_lr{init_lr:.0e}_b{batch_size:.0f}_m{momentum:.2f}_{dtype[-2:]}_{seed}"
initial_learning_rate = init_lr
# NOTE(review): presumably the learning rate after 100 multiplicative decays of
# 0.98 — confirm; this value is not read anywhere in the visible script.
learning_rate_now = initial_learning_rate*0.98**100

# Load CIFAR-10, cast the images to the configured dtype and divide by 255
# (Keras ships the pixels as uint8 in 0..255, so this maps them to [0, 1]).
# NOTE(review): only y_train.size is used later in the visible script.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(dtype) / 255.0
x_test = x_test.astype(dtype) / 255.0

# Restore the trained Keras model stored in this run's data directory.
# NOTE(review): `model` is not used later in the visible script.
trained_model_path = model_str + "/data/trained_model"
model = tf.keras.models.load_model(trained_model_path)



# Load the gradient time series saved for this run. When --wr is set, the
# "WR_" infix selects the with-replacement variant of the files.
with_replacement_str = "WR_" if with_replacement else ""
grad_batch_list = np.load(model_str+"/data/grad_batch_timeseries_"+with_replacement_str+para_str+".npy")
grad_tot_list = np.load(model_str+"/data/grad_tot_timeseries_"+with_replacement_str+para_str+".npy")
# Axis 0 is treated as time (update steps), axis 1 as gradient components.
dim = grad_batch_list.shape[1]
total_time = grad_batch_list.shape[0]

# Noise term: batch gradient minus total gradient at each recorded step.
# NOTE(review): relies on both arrays having the same shape; a broadcastable
# but different shape would be subtracted silently.
grad_noise_list = grad_batch_list - grad_tot_list



# Empirical autocorrelation of the gradient noise, averaged over components.
# Each component's autocorrelation is divided by its zero-lag value, so every
# curve equals 1 at lag 0 before averaging.
num_batches = y_train.size/batch_size  # update steps per epoch; converts lags to epochs
mid = total_time//2  # index of zero lag in np.correlate(..., mode="same") output
corr_sum = np.zeros(total_time)
for idx in range(dim):
  noise = grad_noise_list[:, idx]
  # np.correlate evaluates every lag directly (no FFT), i.e. O(total_time**2)
  # work per component; this loop dominates the runtime for long series.
  corr_temp = np.correlate(noise, noise, mode="same")
  # NOTE(review): a component whose noise is identically zero has
  # corr_temp[mid] == 0, and this division would make the whole average NaN.
  corr_sum += corr_temp/corr_temp[mid]
corr_sum = corr_sum/dim

# Lag window (in update steps) to plot. The zero lag is skipped because it is
# 1 by construction. The window is clamped to the lags the series actually
# contains, so a short run (total_time < 4000) cannot produce negative slice
# starts and x/y arrays of different lengths.
max_lag = 2000
neg_lags = min(max_lag, mid)
pos_lags = min(max_lag - 1, total_time - 1 - mid)
numerics_x = np.concatenate((np.arange(-neg_lags, 0), np.arange(1, pos_lags + 1)))/num_batches
numerics_y = np.concatenate((corr_sum[mid-neg_lags:mid], corr_sum[mid+1:mid+1+pos_lags]))



# Theoretical autocorrelation, drawn as a piecewise-linear curve over lags of
# -2..2 epochs. With replacement it is zero everywhere. Without replacement it
# is zero for |n| >= 1 epoch and reaches -1/num_batches at n = 0.
# Kept as NumPy arrays so that theory_y +/- an offset is elementwise arithmetic.
theory_x = [-2, -1, 0, 1, 2]
if with_replacement:
  theory_y = np.zeros(len(theory_x))
else:
  theory_y = np.array([0, 0, -1/num_batches, 0, 0])
# 1-sigma width of the band drawn around the theory curve; presumably the
# standard error of an average over total_time*dim samples — TODO confirm.
theory_error = 1/np.sqrt(total_time*dim)



# Apply the project's matplotlib style sheet (path relative to the working directory).
plt.style.use('plt_style.pstyle')
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']  # colour cycle defined by the active style
# Hand-picked palette of dark/light colour pairs.
colors2 = [
    "#0C5DA5", "#70b7f5",
    "#00B945", "#33ff7e",
    "#FF9500", "#ffb54d",
    "#FF2C00", "#ff6a4d",
]
# Figure size in inches: [width, height].
plt.rcParams['figure.figsize'] = [2 * 2.7, 1 * 2.0250]



# Plot the measured autocorrelation against the theoretical curve and its
# 2-sigma band, then save the figure into this run's plots/ directory.
fig, ax = plt.subplots()
ax.scatter(numerics_x, numerics_y, label="data", s=1.5, lw=0)
ax.plot(theory_x, theory_y, label="theory", color=colors[2], lw=1)
# Convert explicitly: if theory_y is a plain list, `list + x` only works when x
# happens to be a NumPy scalar (its reflected add coerces the list to an
# array). With a Python float it would raise TypeError.
theory_y_arr = np.asarray(theory_y, dtype=float)
band = 2*theory_error
ax.plot(theory_x, theory_y_arr + band, color=colors2[5], lw=1, ls="--")
# Only the lower dashed line carries a label, so the band appears once in the legend.
ax.plot(theory_x, theory_y_arr - band, label=r"$2\sigma$-interval", color=colors2[5], lw=1, ls="--")
ax.set_xlabel("Update step difference: n [epochs]")
ax.set_ylabel(r"corr$_{\delta g, avg}$(n)")
# scilimits=(m, m) fixes the y-axis order of magnitude at 10^-3.
ax.ticklabel_format(axis="y", style="scientific", scilimits=[-3, -3])
ax.legend(loc=3)  # loc=3 is "lower left"
plt.savefig(model_str+"/plots/autocorrelation_"+with_replacement_str+para_str+".jpeg")





