# import dependencies
import numpy as np
import matplotlib.pyplot as plt
import os



# SGD parameters
seed = 1             # seed of the random number generator
dim = 2500           # number of weights (length of theta)
total_steps = 12000  # number of update steps
eta = 0.1            # prefactor of -theta in the velocity update

# Initialize SGD
rng = np.random.default_rng(seed)
theta = np.zeros(dim)
velocity = np.zeros(dim)

# Pre-allocate the trajectories instead of growing Python lists that are
# converted with np.array afterwards; that conversion temporarily keeps every
# snapshot in memory twice. Row 0 holds the initial (all-zero) state.
theta_list = np.zeros((total_steps + 1, dim))
velocity_list = np.zeros((total_steps + 1, dim))

# Per-coordinate standard deviation of the Gaussian noise. It does not change
# between steps, so it is built once instead of once per iteration.
noise_std = np.full(dim, 1)

# Train for "total_steps" steps
for step in range(total_steps):
    # velocity: -eta*theta plus zero-mean Gaussian noise
    velocity = -eta*theta + rng.normal(0, noise_std)
    theta = theta + velocity
    # row step+1 stores the state after this update
    theta_list[step + 1] = theta
    velocity_list[step + 1] = velocity



# calculate the variances in the original basis
# (theta_list / velocity_list hold one snapshot per row, one weight per column)
variance_theta_original = np.var(theta_list, axis = 0)
# The first entry of velocity_list is the all-zero initial velocity recorded
# before any update. It is skipped so that the same increments are used here
# as in the eigenbasis below, where the velocities come from np.diff.
variance_vel_original = np.var(velocity_list[1:], axis = 0)

# calculate the variances in the eigenbasis of the weight covariance matrix sigma
# Only the eigenvectors (columns of the second return value) are needed; the
# eigenvalues are discarded because the variances are recomputed with np.var.
_, eigenvectors_sigma = np.linalg.eigh(np.cov(theta_list, rowvar=False))
# Project every snapshot onto the eigenvectors: result is (basis index, snapshot)
theta_sigma = np.tensordot(eigenvectors_sigma, theta_list, axes=(0, 1))
# Velocities in the eigenbasis are the differences between consecutive snapshots
vel_sigma = np.diff(theta_sigma, axis=1)
variance_theta_sigma = np.var(theta_sigma, axis = 1)
variance_vel_sigma = np.var(vel_sigma , axis = 1)



def plot_var(variance_theta, variance_vel):
  """Scatter-plot weight and velocity variances against the basis index.

  Both inputs are reversed before plotting, so their last entry is drawn at
  basis index 1. The y-axis is logarithmic. A new figure is created on every
  call.

  Args:
    variance_theta: 1-D sequence with one weight variance per basis direction.
    variance_vel: 1-D sequence with one velocity variance per basis direction.

  Returns:
    The matplotlib Axes the data was drawn on.
  """
  _, ax = plt.subplots()
  # Take the index range from the data itself rather than from a module-level
  # constant, so the function works for inputs of any length.
  index_theta = range(1, len(variance_theta) + 1)
  index_vel = range(1, len(variance_vel) + 1)
  ax.scatter(index_theta, np.flip(variance_theta), label = r"$\sigma_{\theta, i}^2$",  s = 2)
  ax.scatter(index_vel, np.flip(variance_vel), label = r"$\sigma_{v, i}^2$",  s = 2)
  ax.set_xlabel("basis index $i$")
  ax.set_ylabel("fluctuations")
  ax.legend(loc=1)
  ax.set_yscale("log")
  return ax



def plot_tau(variance_theta, variance_vel):
  """Scatter-plot 2 * variance_theta / variance_vel against the basis index.

  The ratio is labelled as the correlation time tau_i. Both inputs are
  reversed before the ratio is taken, so their last entry is drawn at basis
  index 1. The y-axis is logarithmic. A new figure is created on every call.

  Args:
    variance_theta: 1-D sequence with one weight variance per basis direction.
    variance_vel: 1-D sequence with one velocity variance per basis direction,
      of the same length as variance_theta.

  Returns:
    The matplotlib Axes the data was drawn on.
  """
  _, ax = plt.subplots()
  tau = 2*np.flip(variance_theta)/np.flip(variance_vel)
  # Take the index range from the data itself rather than from a module-level
  # constant, so the function works for inputs of any length.
  ax.scatter(range(1, len(tau) + 1), tau,  s = 2)
  ax.set_xlabel("basis index $i$")
  ax.set_ylabel(r"correlation time $\tau_i$")
  ax.set_yscale("log")
  return ax



# Create the output directory for the figures.
# exist_ok=True makes this a no-op when the directory is already there and
# avoids the gap between a separate existence check and the creation.
os.makedirs("Artificial_SGD", exist_ok=True)



# Plot style: activate the style sheet, then read the colors of the
# resulting property cycle from the rc parameters
plt.style.use("plt_style.pstyle")
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]
# Second palette, given as hex codes without a leading "#"
colors2 = [
  "0C5DA5", "a0cff8",
  "00B945", "99ffbe",
  "FF9500", "ffd599",
  "FF2C00", "ffbfb3",
]



# Variances, figure 1 of 2: eigenbasis of the weight covariance matrix.
# The y-limits of this figure are stored so the second figure can reuse them.
ax = plot_var(variance_theta=variance_theta_sigma, variance_vel=variance_vel_sigma)
ax.set_title(r"$\bf \Sigma$ Eigenbasis")
ylim_var = ax.get_ylim()
plt.savefig("Artificial_SGD/var_sigma.jpeg")

# Variances, figure 2 of 2: original basis, drawn with the y-limits of figure 1
ax = plot_var(variance_theta=variance_theta_original, variance_vel=variance_vel_original)
ax.set_title("Original basis")
ax.set_ylim(*ylim_var)
plt.savefig("Artificial_SGD/var_original.jpeg")



# Correlation time, figure 1 of 2: eigenbasis of the weight covariance matrix.
# The y-limits of this figure are stored so the second figure can reuse them.
ax = plot_tau(variance_theta=variance_theta_sigma, variance_vel=variance_vel_sigma)
ax.set_title(r"$\bf \Sigma$ Eigenbasis")
ylim_tau = ax.get_ylim()
plt.savefig("Artificial_SGD/tau_sigma.jpeg")

# Correlation time, figure 2 of 2: original basis, drawn with the y-limits of figure 1
ax = plot_tau(variance_theta=variance_theta_original, variance_vel=variance_vel_original)
ax.set_title("Original basis")
ax.set_ylim(*ylim_tau)
plt.savefig("Artificial_SGD/tau_original.jpeg")

