import time

import matplotlib.pyplot as plt
import numpy as np
import reservoirpy as rpy
from sklearn.metrics import mean_squared_error

from computational_experiments.metric_utils.EKL import empirical_KL
from computational_experiments.metric_utils.PSE import power_spectrum_error

# Set reservoirpy's verbosity level to 0 (presumably silences the library's
# progress output during fit/run — confirm against the reservoirpy docs).
rpy.verbosity(0)

# --- Echo state network hyperparameters -------------------------------------

# Seed handed to the reservoir; change it to repeat the experiment.
seed = 5

# Reservoir size (number of units).
units = 300

# Leaking rate of the reservoir units.
leak_rate = 0.3

# Spectral radius of the recurrent weight matrix.
spectral_radius = 1.25

# Input scaling (also called input gain).
input_scaling = 0.1

# Connection probability of the recurrent weights.
connectivity = 0.1

# Connection probability of the input weights.
input_connectivity = 0.2

# L2 regularization coefficient of the ridge readout.
regularization = 1e-4

# Number of warmup steps passed to `fit`.
transient = 0

from reservoirpy.nodes import Reservoir, Ridge

# Recurrent reservoir node, configured from the hyperparameters defined above.
reservoir = Reservoir(
    units,
    lr=leak_rate,
    sr=spectral_radius,
    input_scaling=input_scaling,
    rc_connectivity=connectivity,
    input_connectivity=input_connectivity,
    seed=seed,
)

# Ridge-regression readout node using the L2 regularization coefficient.
readout = Ridge(ridge=regularization)

# Compose both nodes into one model with reservoirpy's `>>` operator;
# `esn` is the object that is fitted and called in the rest of the script.
esn = reservoir >> readout

# Time increment between consecutive samples; used to scale the time axis
# of the validation plot.
dt = 0.01

# When True, the held-out test set is evaluated at the end of the script.
test = True

# data-loading
# The axis permutation (2, 0, 1) moves the last on-disk axis to the front, so
# the arrays are indexed as (trajectory, time step, state component), which is
# the layout the evaluation loops below rely on.
training_data_npz = np.load('../simulated_datasets/Lorenz_normalized_train.npy')
training_data = training_data_npz.transpose(2, 0, 1)

validation_data_npz = np.load('../simulated_datasets/Lorenz_normalized_valid.npy')
validation_data = validation_data_npz.transpose(2, 0, 1)

# One-step-ahead training pairs: inputs are every state but the last of each
# trajectory, targets are the same states shifted forward by one step.
training_Xm = training_data[:, :-1, :]
training_Xp = training_data[:, 1:, :]

# Fit the ESN on the one-step-ahead training pairs and report how long it took.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be distorted by system clock adjustments (NTP, DST) the way
# a difference of time.time() readings can. The values are only meaningful as
# a difference, not as absolute timestamps.
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# Closed-loop evaluation on the validation set: starting from the first state
# of each trajectory, the model's own output is fed back as its next input for
# `nb_preds` steps, and the resulting rollout is scored against the remainder
# of the true trajectory.
#
# NOTE(review): the model is called repeatedly without any explicit reset, so
# the reservoir state presumably carries over from training and from one
# trajectory to the next — confirm whether a reset or a warm-up phase is
# intended here.
nb_preds = validation_data.shape[1] - 1
valid_traj_pred, valid_MSE, valid_EKL, valid_PSE = [], [], [], []
for i, trajectory in enumerate(validation_data):  # one pass per trajectory
    reference = trajectory[1:]  # true states the rollout is compared with
    X_pred = np.zeros((nb_preds, 3))
    current = trajectory[0]
    for step in range(nb_preds):
        current = esn(current)
        X_pred[step] = current
    valid_traj_pred.append(X_pred)

    valid_MSE.append(mean_squared_error(X_pred, reference))
    valid_EKL.append(empirical_KL(X_pred=X_pred, X_true=reference))
    valid_PSE.append(power_spectrum_error(X_pred, reference))

# Every EKL value below 1e-8 (negative values included) is replaced by zero
# before averaging.
validation_EKL_np = np.array(valid_EKL)
negligible = validation_EKL_np < 1e-8
validation_EKL_np[negligible] = 0

print(f'MSE averaged over validation trajectories: \t {np.mean(valid_MSE)}')
print(f'EKL averaged over validation trajectories: \t {np.mean(validation_EKL_np)}')
print(f'PSE averaged over validation trajectories: \t {np.mean(valid_PSE)}')

# Plot the last validation trajectory against its closed-loop prediction,
# one panel per state component.
# The trajectory is taken explicitly from the collected results instead of
# relying on the loop variables left over from the evaluation loop.
last_pred = valid_traj_pred[-1]
last_true = validation_data[-1, 1:]
time_axis = np.arange(nb_preds) * dt

fig, ax = plt.subplots(3, 1, figsize=(8, 4))
for k, panel in enumerate(ax):
    panel.plot(time_axis, last_true[:, k], '-', label='true')
    panel.plot(time_axis, last_pred[:, k], '--', label='prediction')
    panel.set_ylabel(f'h{k + 1}')
ax[-1].set_xlabel('t')
ax[-1].legend()
# Figure-level title: plt.title() would only label the most recently created
# subplot (the bottom panel) rather than the whole figure.
fig.suptitle('Reservoir Computer')
plt.show()

if test:
    # Same closed-loop evaluation as for the validation set, applied to the
    # held-out test trajectories.
    #
    # NOTE(review): the model is called repeatedly without any explicit reset,
    # so the reservoir state presumably carries over from earlier calls and
    # from one trajectory to the next — confirm whether a reset or a warm-up
    # phase is intended here.
    test_data_npz = np.load('../simulated_datasets/Lorenz_normalized_test.npy')
    # Move the last on-disk axis to the front: (trajectory, time step, component).
    test_data = test_data_npz.transpose(2, 0, 1)
    nb_preds = test_data.shape[1] - 1
    test_traj_pred, test_MSE, test_EKL, test_PSE = [], [], [], []
    for i, trajectory in enumerate(test_data):  # one pass per trajectory
        reference = trajectory[1:]  # true states the rollout is compared with
        X_pred = np.zeros((nb_preds, 3))
        current = trajectory[0]
        for step in range(nb_preds):
            current = esn(current)
            X_pred[step] = current
        test_traj_pred.append(X_pred)

        test_MSE.append(mean_squared_error(X_pred, reference))
        test_EKL.append(empirical_KL(X_pred=X_pred, X_true=reference))
        test_PSE.append(power_spectrum_error(X_pred, reference))

    # Every EKL value below 1e-8 (negative values included) is replaced by
    # zero before averaging.
    test_EKL_np = np.array(test_EKL)
    negligible = test_EKL_np < 1e-8
    test_EKL_np[negligible] = 0

    print(f'MSE averaged over test trajectories: \t {np.mean(test_MSE)}')
    print(f'EKL averaged over test trajectories: \t {np.mean(test_EKL_np)}')
    print(f'PSE averaged over test trajectories: \t {np.mean(test_PSE)}')

# seed = 1:
# Fit time:  2.873450517654419
# MSE averaged over validation trajectories: 	 0.49014922952098255
# EKL averaged over validation trajectories: 	 0.011368723876668622
# PSE averaged over validation trajectories: 	 0.23960336552215117
# MSE averaged over test trajectories: 	 0.47364420734389534
# EKL averaged over test trajectories: 	 0.007201534877204112
# PSE averaged over test trajectories: 	 0.22322210376314522
#
# seed = 2:
# Fit time:  4.23048996925354
# MSE averaged over validation trajectories: 	 0.496795831397837
# EKL averaged over validation trajectories: 	 0.00965731441665479
# PSE averaged over validation trajectories: 	 0.2187207320339564
# MSE averaged over test trajectories: 	 0.4697687450519723
# EKL averaged over test trajectories: 	 0.007751698168609683
# PSE averaged over test trajectories: 	 0.21165033784313558
#
# seed = 3:
# Fit time:  3.2297987937927246
# MSE averaged over validation trajectories: 	 0.4887167691831239
# EKL averaged over validation trajectories: 	 0.008686872204675091
# PSE averaged over validation trajectories: 	 0.23114223426607416
# MSE averaged over test trajectories: 	 0.4727371134422841
# EKL averaged over test trajectories: 	 0.007880468905669898
# PSE averaged over test trajectories: 	 0.2101463269236138
#
#
# seed = 4:
# Fit time:  4.465778350830078
# MSE averaged over validation trajectories: 	 0.5014856496853026
# EKL averaged over validation trajectories: 	 0.008772619000453128
# PSE averaged over validation trajectories: 	 0.20428726396268257
# MSE averaged over test trajectories: 	 0.46958277724131103
# EKL averaged over test trajectories: 	 0.01019745551613056
# PSE averaged over test trajectories: 	 0.19720676609348384
#
# seed = 5:
# Fit time:  2.8995723724365234
# MSE averaged over validation trajectories: 	 0.4845026842896848
# EKL averaged over validation trajectories: 	 0.010811782177871467
# PSE averaged over validation trajectories: 	 0.2037504957265786
# MSE averaged over test trajectories: 	 0.480585543802613
# EKL averaged over test trajectories: 	 0.010619974843663647
# PSE averaged over test trajectories: 	 0.2044154830816703
#

# Averages over the five runs above (seeds 1-5):
# mean fit time (seconds): 3.539818000793457
# mean EKL on test data: 0.00873022646225558
# mean PSE on test data: 0.20932820354100973
