import time

import matplotlib.pyplot as plt
import numpy as np
import reservoirpy as rpy
from sklearn.metrics import mean_squared_error

from utilities.pse import *
from utils import D_stsp

from reservoirpy.nodes import Reservoir, Ridge

rpy.verbosity(0)  # set reservoirpy's verbosity level to 0 (quiet)

# --- Echo state network hyperparameters ---
units = 300  # number of reservoir units
leak_rate = 0.3  # leaking rate
spectral_radius = 1.25  # spectral radius of the recurrent weights
input_scaling = 0.1  # input scaling (also called input gain)
connectivity = 0.1  # recurrent weights connectivity probability
input_connectivity = 0.2  # input weights connectivity probability
regularization = 1e-4  # L2 (ridge) regularization coefficient
transient = 0  # number of warmup steps, passed as `warmup` to esn.fit
seed = 5  # seed passed to the Reservoir, for reproducibility

# Reservoir feeding a linear ridge-regression readout, chained into one model.
reservoir = Reservoir(units, input_scaling=input_scaling, sr=spectral_radius,
                      lr=leak_rate, rc_connectivity=connectivity,
                      input_connectivity=input_connectivity, seed=seed)

readout = Ridge(ridge=regularization)

esn = reservoir >> readout

dt = 0.01  # time step between samples; used below for the plot's time axis

# data-loading
# The raw arrays are transposed with axes (2, 0, 1); everything below indexes
# the result as (trajectory, time step, state component).
# NOTE(review): this assumes the .npy files are stored as
# (time, component, trajectory) - confirm against the dataset generator.
training_data_npz = np.load('datasets/Lorenz_normalized_train.npy')
validation_data_npz = np.load('datasets/Lorenz_normalized_valid.npy')
training_data = np.transpose(training_data_npz, (2, 0, 1))
validation_data = np.transpose(validation_data_npz, (2, 0, 1))
test = True  # also evaluate on the test set at the end of the script

# One-step-ahead pairs: inputs are steps 0..T-2, targets are steps 1..T-1.
training_Xm, training_Xp = training_data[:, :-1, :], training_data[:, 1:, :]

# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock and can jump if the system time is adjusted mid-measurement.
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# Closed-loop generation on each validation trajectory: the model is seeded
# with the true first state and then fed its own predictions for the
# remaining nb_generations steps.
nb_generations = validation_data.shape[1] - 1
n_components = validation_data.shape[2]
valid_traj_pred = []
valid_mse = []
valid_dstsp = []
valid_dh = []
# NOTE(review): the reservoir state is never reset here, so each trajectory
# starts from whatever state the fit / previous generation left behind
# (reservoirpy models are stateful across calls by default). Confirm intended.
for i in range(validation_data.shape[0]):  # loop over all trajectories
    X_gen = np.zeros((nb_generations, n_components))
    y = validation_data[i, 0, :]
    for t in range(nb_generations):
        y = esn(y)
        X_gen[t, :] = y
    valid_traj_pred.append(X_gen)

    X_true = validation_data[i, 1:]  # ground truth aligned with X_gen
    valid_mse.append(mean_squared_error(X_gen, X_true))
    valid_dstsp.append(D_stsp(X_pred=X_gen, X_true=X_true))
    valid_dh.append(power_spectrum_error(X_gen, X_true))

# D_stsp values below 1e-8 (including any negative values) are set to exactly
# 0 before averaging.
validation_D_stsp_np = np.array(valid_dstsp)
validation_D_stsp_np[validation_D_stsp_np < 1e-8] = 0
print(f'MSE averaged over validation trajectories: \t {np.mean(valid_mse)}')
print(f'D_stsp averaged over validation trajectories: \t {np.mean(validation_D_stsp_np)}')
print(f'D_h averaged over validation trajectories: \t {np.mean(valid_dh)}')

# Plot true vs. generated components for the last validation trajectory
# (the one the original loop left in `i` / `X_gen`), made explicit here.
plot_idx = validation_data.shape[0] - 1
X_plot = valid_traj_pred[plot_idx]
time_axis = np.arange(nb_generations) * dt
fig, ax = plt.subplots(3, 1, figsize=(8, 4))
for k, comp_label in enumerate(('h1', 'h2', 'h3')):
    ax[k].plot(time_axis, validation_data[plot_idx, 1:, k], '-', label='true')
    ax[k].plot(time_axis, X_plot[:, k], '--', label='prediction')
    ax[k].set_ylabel(comp_label)
ax[2].set_xlabel('t')
ax[2].legend()
# suptitle titles the whole figure; plt.title would only title the bottom axes.
fig.suptitle('Reservoir Computer')
plt.show()

if test:
    # Same closed-loop evaluation as for validation, on the held-out test set.
    test_data_npz = np.load('datasets/Lorenz_normalized_test.npy')
    test_data = np.transpose(test_data_npz, (2, 0, 1))
    nb_generations = test_data.shape[1] - 1
    n_components = test_data.shape[2]
    test_traj_pred = []
    test_mse = []
    test_dstsp = []
    test_dh = []
    # NOTE(review): the reservoir state is not reset before or between test
    # trajectories, so the first one starts from the state left by the
    # validation runs above. Confirm this is intended.
    for i in range(test_data.shape[0]):  # loop over all trajectories
        X_gen = np.zeros((nb_generations, n_components))
        y = test_data[i, 0, :]
        for t in range(nb_generations):
            y = esn(y)
            X_gen[t, :] = y
        test_traj_pred.append(X_gen)

        X_true = test_data[i, 1:]  # ground truth aligned with X_gen
        test_mse.append(mean_squared_error(X_gen, X_true))
        test_dstsp.append(D_stsp(X_pred=X_gen, X_true=X_true))
        test_dh.append(power_spectrum_error(X_gen, X_true))

    # D_stsp values below 1e-8 (including any negative values) are set to
    # exactly 0 before averaging.
    test_D_stsp_np = np.array(test_dstsp)
    test_D_stsp_np[test_D_stsp_np < 1e-8] = 0
    print(f'MSE averaged over test trajectories: \t {np.mean(test_mse)}')
    print(f'D_stsp averaged over test trajectories: \t {np.mean(test_D_stsp_np)}')
    print(f'D_h averaged over test trajectories: \t {np.mean(test_dh)}')

# seed = 1:
# Fit time:  2.873450517654419
# MSE averaged over validation trajectories: 	 0.49014922952098255
# D_stsp averaged over validation trajectories: 	 0.011368723876668622
# D_h averaged over validation trajectories: 	 0.23960336552215117
# MSE averaged over test trajectories: 	 0.47364420734389534
# D_stsp averaged over test trajectories: 	 0.007201534877204112
# D_h averaged over test trajectories: 	 0.22322210376314522
#
# seed = 2:
# Fit time:  4.23048996925354
# MSE averaged over validation trajectories: 	 0.496795831397837
# D_stsp averaged over validation trajectories: 	 0.00965731441665479
# D_h averaged over validation trajectories: 	 0.2187207320339564
# MSE averaged over test trajectories: 	 0.4697687450519723
# D_stsp averaged over test trajectories: 	 0.007751698168609683
# D_h averaged over test trajectories: 	 0.21165033784313558
#
# seed = 3:
# Fit time:  3.2297987937927246
# MSE averaged over validation trajectories: 	 0.4887167691831239
# D_stsp averaged over validation trajectories: 	 0.008686872204675091
# D_h averaged over validation trajectories: 	 0.23114223426607416
# MSE averaged over test trajectories: 	 0.4727371134422841
# D_stsp averaged over test trajectories: 	 0.007880468905669898
# D_h averaged over test trajectories: 	 0.2101463269236138
#
#
# seed = 4:
# Fit time:  4.465778350830078
# MSE averaged over validation trajectories: 	 0.5014856496853026
# D_stsp averaged over validation trajectories: 	 0.008772619000453128
# D_h averaged over validation trajectories: 	 0.20428726396268257
# MSE averaged over test trajectories: 	 0.46958277724131103
# D_stsp averaged over test trajectories: 	 0.01019745551613056
# D_h averaged over test trajectories: 	 0.19720676609348384
#
# seed = 5:
# Fit time:  2.8995723724365234
# MSE averaged over validation trajectories: 	 0.4845026842896848
# D_stsp averaged over validation trajectories: 	 0.010811782177871467
# D_h averaged over validation trajectories: 	 0.2037504957265786
# MSE averaged over test trajectories: 	 0.480585543802613
# D_stsp averaged over test trajectories: 	 0.010619974843663647
# D_h averaged over test trajectories: 	 0.2044154830816703
#

# Averages over the five runs above (seeds 1-5):
# mean fit time: 3.539818000793457
# mean d_stsp on test data: 0.00873022646225558
# mean d_h on test data: 0.20932820354100973