import random
import time

import matplotlib.pyplot as plt
import numpy as np
import reservoirpy as rpy
from reservoirpy.nodes import Reservoir, Ridge
from sklearn.metrics import mean_squared_error

# --- Run configuration -------------------------------------------------------
rpy.verbosity(0)  # lowest reservoirpy verbosity level
test = True  # also evaluate on the held-out test set after validation

seed = 5  # reproducibility; also passed to the Reservoir constructor
random.seed(seed)  # seeds Python's `random` module only; nothing in the fixed run draws from it

# --- Hyperparameter search space ---------------------------------------------
# Candidate values for the (currently disabled) hyperparameter search.
# The fixed-hyperparameter run below does not read these lists.
units_list = [100, 200, 500]                                    # reservoir size (number of units)
leak_rate_list = [0.1, 0.3, 0.5, 0.7, 0.9]                      # leaking rate
spectral_radius_list = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 2, 3, 5]  # spectral radius of recurrent weights
input_scaling_list = [0.05, 0.1, 0.5, 1, 1.5, 2]                # input scaling (a.k.a. input gain)
connectivity_list = [0.2, 0.4, 0.6, 0.8, 1]                     # recurrent-weight connection probability
input_connectivity_list = [0.2, 0.4, 0.6, 0.8, 1]               # input-weight connection probability
regularization_list = [1e-4, 1e-6, 1e-8, 1e-10]                 # L2 (ridge) regularization coefficient

# Warmup steps passed to fit(); kept at 0 so results stay comparable across methods.
transient = 0

# --- Data --------------------------------------------------------------------
# Despite the `_npz` suffix these are plain .npy arrays. Moving the last on-disk
# axis to the front yields arrays indexed as (trajectory, time step, state dimension).
training_data_npz = np.load('../simulated_datasets/VdP_train.npy')
training_data = training_data_npz.transpose(2, 0, 1)
validation_data_npz = np.load('../simulated_datasets/VdP_valid.npy')
validation_data = validation_data_npz.transpose(2, 0, 1)

# One-step-ahead training pairs: inputs are samples 0..T-2, targets are samples 1..T-1.
training_Xm = training_data[:, :-1]
training_Xp = training_data[:, 1:]

# f = open("gridsearch_RC_VdP.txt", "a")  # log file for the (disabled) hyperparameter search

dt = 0.1  # spacing of the plot time axis; presumably the simulation time step — confirm
# for j in range(1000): # loop for the (disabled) random hyperparameter search

# Best outcome of the hyperparameter search:
# validation MSE=0.003488735658322748, u=500, 	 lr=0.9, 	 sr=0.5, 	 is=0.05, 	 connect=0.8, 	 ic=0.2, 	 r=1e-10 (seed=5)

# Selected hyperparameters (best configuration from the search above):
units = 500
leak_rate = 0.9
spectral_radius = 0.5
input_scaling = 0.05
connectivity = 0.8
input_connectivity = 0.2
regularization = 1e-10

# Echo state network: a seeded reservoir feeding a ridge-regression readout.
reservoir = Reservoir(units,
                      input_scaling=input_scaling,
                      sr=spectral_radius,
                      lr=leak_rate,
                      rc_connectivity=connectivity,
                      input_connectivity=input_connectivity,
                      seed=seed)
readout = Ridge(ridge=regularization)

esn = reservoir >> readout

# Time only the training step. perf_counter is monotonic and high-resolution,
# so the measurement is unaffected by system clock adjustments (unlike time.time).
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
end_time = time.perf_counter()
elapsed_time = end_time - start_time

def _predict_trajectories(model, reservoir_node, trajectories):
    """Autoregressively roll out ``model`` on every trajectory in ``trajectories``.

    Each rollout is seeded with the trajectory's first sample. Every prediction
    is then fed back as the next input, for ``trajectories.shape[1] - 1`` steps.

    Args:
        model: trained reservoirpy model mapping one sample to the next.
        reservoir_node: the model's reservoir. Its state is restored after each
            trajectory so rollouts do not influence one another.
        trajectories: array indexed as (trajectory, time step, state dimension).

    Returns:
        A ``(predictions, mses)`` tuple: one ``(n_steps, n_dims)`` prediction
        array per trajectory, and the per-trajectory MSE against
        ``trajectory[1:, :]``.
    """
    n_steps = trajectories.shape[1] - 1
    n_dims = trajectories.shape[2]
    predictions, mses = [], []
    for trajectory in trajectories:
        # NOTE(review): with_state() without reset=True keeps the current state
        # (the one left after fitting) as the starting point of every rollout
        # and restores it on exit. Pass reset=True to start each rollout from a
        # zero state instead (this would change the reported MSEs).
        with reservoir_node.with_state():
            prediction = np.zeros((n_steps, n_dims))
            y = trajectory[0, :]
            for t in range(n_steps):
                y = model(y)  # one-step prediction, fed back as the next input
                prediction[t, :] = y
        predictions.append(prediction)
        mses.append(mean_squared_error(trajectory[1:, :], prediction))
    return predictions, mses


valid_traj_pred, valid_MSE = _predict_trajectories(esn, reservoir, validation_data)
# Hyperparameter-search logging (enable together with the log file `f`):
# print(f'validation MSE={np.mean(valid_MSE)}, \t u={units}, \t lr={leak_rate}, \t sr={spectral_radius}, \t is={input_scaling}, \t connect={connectivity}, \t ic={input_connectivity}, \t r={regularization}')
# f.write(f'validation MSE={np.mean(valid_MSE)}, \t u={units}, \t lr={leak_rate}, \t sr={spectral_radius}, \t is={input_scaling}, \t connect={connectivity}, \t ic={input_connectivity}, \t r={regularization}')
# f.close()


test_traj_pred = []
test_MSE = []
if test:
    test_data_npz = np.load('../simulated_datasets/VdP_test.npy')
    test_data = np.transpose(test_data_npz, (2, 0, 1))
    test_traj_pred, test_MSE = _predict_trajectories(esn, reservoir, test_data)

# Only report a test MSE when the test set was evaluated: the mean of an
# empty list would print `nan` and emit a RuntimeWarning.
if test:
    print(f'validation MSE={np.mean(valid_MSE)}, \t test MSE={np.mean(test_MSE)}, \t fit_time={elapsed_time}')
else:
    print(f'validation MSE={np.mean(valid_MSE)}, \t fit_time={elapsed_time}')

# Plot the last trajectory of the evaluated split (test if enabled, else validation).
plot_data, plot_preds = (test_data, test_traj_pred) if test else (validation_data, valid_traj_pred)
if plot_preds:
    true_traj = plot_data[-1, 1:, :]
    pred_traj = plot_preds[-1]
    time_axis = np.arange(pred_traj.shape[0]) * dt

    fig, ax = plt.subplots(2, 1, figsize=(8, 4))
    for k, ylabel in enumerate(['h1', 'h2']):
        ax[k].plot(time_axis, true_traj[:, k], '-', label='true')
        ax[k].plot(time_axis, pred_traj[:, k], '--', label='prediction')
        ax[k].set_ylabel(ylabel)
        ax[k].legend()  # per-axes legend; plt.legend() would only label the last subplot
    plt.show()

print(" ")

# seed = 1:
# validation MSE=0.008013423271643293, 	 test MSE=0.020731463754683772, 	 fit_time=2.2906792163848877
#
# seed = 2:
# validation MSE=0.005426942086164147, 	 test MSE=0.011549697581332388, 	 fit_time=3.7403950691223145
#
# seed = 3:
# validation MSE=0.014629154985815813, 	 test MSE=0.014558099440724979, 	 fit_time=5.40882134437561
#
# seed = 4:
# validation MSE=0.005775710759731528, 	 test MSE=0.01254082514193595, 	 fit_time=3.3799593448638916
#
# seed = 5:
# validation MSE=0.003488735658322748, 	 test MSE=0.01954929665362315, 	 fit_time=4.002091646194458
#
# mean fit time: 3.7643893241882322
# mean test MSE: 0.015785876514460045
