import torch
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
sys.path.append("../src")
from networks import AutoGenNet

plt.style.use('../plot_params.dms')
# Factors converting cm / mm to inches, the unit matplotlib uses for `figsize`.
units_convert = {'cm': 1 / 2.54, 'mm': 1 / 2.54 / 10}

# TODO: save optimizer and import its infos
result_folder = "../results/sine-generation"
data_folder = "../data/sine-generation-hiddensize200-lr0.1-biaslearningTrue-biasinituniform-seed1-before-bifurcation-epoch500"

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(result_folder, exist_ok=True)

model = AutoGenNet(network_size=(1, 200, 1))
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine. NOTE(review): torch.load unpickles the file; only load trusted
# checkpoints (consider weights_only=True if the installed torch supports it).
model.load_state_dict(
    torch.load(os.path.join(data_folder, 'model_weights.pth'), map_location='cpu')
)

# Per-unit bias after training (x-axis of the first plot) and recorded activity,
# laid out as (time, unit) -- the later heatmap plots r.T with units on the y-axis.
bias_post_training = np.load(os.path.join(data_folder, 'bias_post_training.npy'))
r = np.load(os.path.join(data_folder, 'activity.npy'))

# Average activity over a period vs bias
T = 25           # window length in time steps (taken to be one period, per this section's title)
relax_time = 10  # leading steps discarded as transient before the window starts
# Rebinds `r` to the analysis window; the plots below use this trimmed array.
r = r[relax_time:relax_time+T]

mean_activity = np.mean(r, axis=0)  # time-average of each unit over the window
mm = units_convert['mm']
plt.figure(figsize=(45*mm, 45/1.25*mm))
plt.plot(
    bias_post_training,
    mean_activity,
    marker='o',
    markersize=2,
    markeredgecolor='white',
    mew=0.1,
    color='k',
    lw=0,
)
plt.xlabel('Bias after training')
plt.ylabel('Average activity over period')
plt.tight_layout()
plt.savefig(os.path.join(result_folder, 'AverageRate_vs_Bias.png'))

# Heatmap of activity over a period
square_side = 45/1.25*units_convert['mm']  # square figure, side length in inches
plt.figure(figsize=(square_side, square_side))
# Transpose so units run along the rows (y-axis) and time along the columns (x-axis).
activity_map = r.T
plt.imshow(activity_map)
plt.xlabel('Time')
plt.ylabel('Unit')
plt.tight_layout()
heatmap_path = os.path.join(result_folder, 'Activity_vs_Time_Heatmap.png')
plt.savefig(heatmap_path)

# Activity vs time
# Readout weights as a flat NumPy vector (one weight per unit). The parameter
# requires grad, so it must be detached first: converting it directly (as
# np.diag did) calls Tensor.numpy() and raises.
readout_weights = model.readout_layer.weight.detach().cpu().numpy().squeeze()
plt.figure(figsize=(45*units_convert['mm'], 45/1.25*units_convert['mm']))
# Each unit's trace multiplied by the sign of its readout weight
# (equivalent to r @ diag(sign(w)), without building an n x n matrix).
wr = r * np.sign(readout_weights)
plt.plot(wr, lw=0.3, alpha=0.3, color='grey')
# Readout-weighted sum over units (no readout bias term). `r` already holds
# exactly the analysis window r[relax_time:relax_time+T], so it is used as is;
# slicing it again kept only T - relax_time rows and misaligned this curve.
plt.plot(r @ readout_weights, color='k')
plt.xlabel('Time')
plt.ylabel(r'Sign of readout weight $\times$' + '\nactivity')
plt.tight_layout()
plt.savefig(os.path.join(result_folder, 'Activity_vs_Time.png'))


# Weight matrix vs Jacobian


def _relu_jacobian_at_fixed_point(folder):
    """Return ``(recurrent_weights, jacobian)`` for the checkpoint saved in *folder*.

    The last recorded activity vector is used as the fixed point r*. The
    pre-activation is u = W @ r* + b, and the Jacobian is diag(u > 0) @ W:
    row i of W is kept where u_i > 0 and zeroed otherwise.

    NOTE(review): this corresponds to r <- relu(W @ r + b) dynamics and
    assumes the activity has converged by its last time step -- confirm both
    against AutoGenNet and the simulation that produced activity.npy.
    """
    bias = np.load(os.path.join(folder, 'bias_post_training.npy'))
    weights = np.load(os.path.join(folder, 'recurrent_weights.npy'))
    activity = np.load(os.path.join(folder, 'activity.npy'))
    fixed_point = activity[-1, :]
    u = weights @ fixed_point + bias
    # Row mask via broadcasting: same values as np.diag(u > 0) @ weights, without
    # materialising an n x n diagonal matrix and a full matrix product.
    jacobian = (u > 0)[:, None] * weights
    return weights, jacobian


# import data before bifurcation
data_folder = "../data/sine-generation-hiddensize200-lr0.1-biaslearningTrue-biasinituniform-seed1-before-bifurcation-epoch100"
_, J = _relu_jacobian_at_fixed_point(data_folder)

# import data at start of training; `recurrent_weights` ends up holding these
# epoch-0 weights, which the eigenspectrum figure below plots.
data_folder = "../data/sine-generation-hiddensize200-lr0.1-biaslearningTrue-biasinituniform-seed1-before-bifurcation-epoch0"
recurrent_weights, J_start = _relu_jacobian_at_fixed_point(data_folder)


fig, axes = plt.subplots(
    ncols=2,
    figsize=(2*45*units_convert['mm'], 45/1.25*units_convert['mm']),
    sharex=True,
    sharey=True,
)
ax_weights, ax_jacobian = axes
unit_ticks = [-1, 0, 1]

# Left panel: eigenvalues of the recurrent weight matrix in the complex plane.
# NOTE(review): `recurrent_weights` was last loaded from the epoch0 folder, so
# this panel shows the start-of-training weights, not the epoch100 weights that
# `J` is built from -- confirm this is the intended comparison.
ax_weights.set_aspect('equal')
weight_eigs = np.linalg.eigvals(recurrent_weights)
ax_weights.plot(weight_eigs.real, weight_eigs.imag,
                marker='o', markersize=2.5, markeredgecolor='white', mew=0.2, color='k', lw=0)
ax_weights.set_xlabel('Re')
ax_weights.set_ylabel('Im')
ax_weights.set_xticks(unit_ticks)
ax_weights.set_yticks(unit_ticks)
ax_weights.set_title('Recurrent weights')

# Right panel: Jacobian eigenvalues at the later checkpoint (orange circles)
# overlaid with those at the start of training (faint grey squares).
jac_eigs = np.linalg.eigvals(J)
jac_start_eigs = np.linalg.eigvals(J_start)
# Eigenvalue pair noted by the author (which matrix/run it refers to is not
# recorded here): 0.9479738 +0.25675613j,  0.9479738 -0.25675613j
ax_jacobian.plot(jac_eigs.real, jac_eigs.imag,
                 marker='o', markersize=2.5, markeredgecolor='white', mew=0.2, color='orange', lw=0)
ax_jacobian.plot(jac_start_eigs.real, jac_start_eigs.imag,
                 marker='s', markersize=1.5, markeredgecolor='white', mew=0.2, color='grey', lw=0, alpha=0.4)
ax_jacobian.set_aspect('equal')
ax_jacobian.set_xlabel('Re')
ax_jacobian.set_xticks(unit_ticks)
ax_jacobian.set_yticks(unit_ticks)
ax_jacobian.set_title('Jacobian')
plt.tight_layout()
plt.savefig(os.path.join(result_folder, 'EigenSpectrum.png'))
