import time

import matplotlib.pyplot as plt
import numpy as np
import reservoirpy as rpy
from reservoirpy.nodes import Reservoir, Ridge
from sklearn.metrics import mean_squared_error

from utilities.pse import *  # presumably provides power_spectrum_error
from utils import D_stsp

# Global reservoirpy configuration: quiet output and a fixed global RNG seed.
rpy.verbosity(0)
rpy.set_seed(1)

# ESN hyper-parameters
units = 500               # number of reservoir units
leak_rate = 0.3           # leaking rate
spectral_radius = 0.5     # spectral radius
input_scaling = 0.1       # input scaling (a.k.a. input gain)
connectivity = 0.1        # connection probability of the recurrent weights
input_connectivity = 0.2  # connection probability of the input weights
regularization = 1e-8     # L2 (ridge) regularization coefficient
transient = 1             # number of warm-up steps
seed = 5                  # reservoir seed, for reproducibility

reservoir = Reservoir(
    units,
    lr=leak_rate,
    sr=spectral_radius,
    input_scaling=input_scaling,
    rc_connectivity=connectivity,
    input_connectivity=input_connectivity,
    seed=seed,
)
readout = Ridge(ridge=regularization)

# Chain reservoir -> ridge readout into a single trainable model.
esn = reservoir >> readout

# Time step between consecutive samples; only used to scale plot time axes.
dt = 0.01

# --- Data loading ---
# NOTE: the files are plain .npy arrays despite the *_npz variable names.
# After the (2, 0, 1) transpose, downstream code indexes the arrays as
# (trajectory, time, feature).
training_data_npz = np.load('datasets/Rossler/Rossler_Koop_Dataset_training.npy')
validation_data_npz = np.load('datasets/Rossler/Rossler_Koop_Dataset_validation.npy')
training_data = np.transpose(training_data_npz, (2, 0, 1))
validation_data = np.transpose(validation_data_npz, (2, 0, 1))
test = True  # when True, also evaluate on the test split at the end

# One-step-ahead supervision per trajectory: input x_t, target x_{t+1}.
training_Xm, training_Xp = training_data[:, :-1, :], training_data[:, 1:, :]

# perf_counter() is monotonic and high-resolution, so it is the right clock
# for measuring an elapsed interval (time.time() can jump with clock changes).
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# --- Closed-loop forecasting on the validation trajectories ---
# Each trajectory primes the ESN with its first `transient` samples, then the
# ESN runs autonomously: every prediction is fed back as the next input.
nb_generations = validation_data.shape[1] - 1 - transient
valid_traj_pred, valid_mse, valid_dstsp, valid_dh = [], [], [], []
for i in range(validation_data.shape[0]):
    # NOTE(review): reservoirpy's with_state() presumably restores the
    # reservoir state on exit, so every trajectory starts from the same state.
    with reservoir.with_state():
        primed = esn.run(validation_data[i, :transient, :])
        X_gen = np.zeros((nb_generations, 3))
        feedback = primed[-1]
        for step in range(nb_generations):
            feedback = esn(feedback)
            X_gen[step, :] = feedback

    # X_gen[k] is the forecast for time index transient + 1 + k.
    reference = validation_data[i, (1 + transient):]
    valid_traj_pred.append(X_gen)
    valid_mse.append(mean_squared_error(X_gen, reference))
    dstsp_temp = D_stsp(X_pred=X_gen, X_true=reference)
    if dstsp_temp > 1:
        # Treated as diverged: reported and left out of the D_stsp average
        # (it still counts toward the MSE and D_h averages).
        print("diverged trajectory with mse=", valid_mse[-1])
    else:
        valid_dstsp.append(dstsp_temp)
    valid_dh.append(power_spectrum_error(X_gen, reference))

# Snap numerically negligible D_stsp values to exactly zero before averaging.
validation_D_stsp_np = np.array(valid_dstsp)
validation_D_stsp_np[validation_D_stsp_np < 1e-8] = 0
print(f'MSE averaged over validation trajectories: \t {np.mean(valid_mse)}')
print(f'D_stsp averaged over validation trajectories: \t {np.mean(validation_D_stsp_np)}')
print(f'D_h averaged over validation trajectories: \t {np.mean(valid_dh)}')



# Plot the last validation trajectory (the `i` / `X_gen` left over from the
# validation loop) against its ground truth, one panel per state component.
time_axis = np.arange(nb_generations) * dt
truth = validation_data[i, (1 + transient):]
fig, ax = plt.subplots(3, 1, figsize=(8, 4))
for col, (panel, ylabel) in enumerate(zip(ax, ('h1', 'h2', 'h3'))):
    panel.plot(time_axis, truth[:, col], '-', label='true')
    panel.plot(time_axis, X_gen[:, col], '--', label='prediction')
    panel.set_ylabel(ylabel)
ax[2].set_xlabel('t')
# plt.legend() draws on the current axes, i.e. the last subplot only.
plt.legend()
plt.show()

if test:
    # Header only makes sense when the test split is actually evaluated.
    print('\n Evaluating test data:')
    test_data_npz = np.load('datasets/Rossler/Rossler_Koop_Dataset_test.npy')
    test_data = np.transpose(test_data_npz, (2, 0, 1))
    nb_generations = test_data.shape[1] - 1 - transient
    test_traj_pred = []
    test_mse = []
    test_dstsp = []
    test_dh = []
    for i in range(test_data.shape[0]):  # loop over all trajectories
        # Same protocol as the validation loop. The model maps x_t -> x_{t+1},
        # so running it on the first `transient` samples yields the forecast
        # for index `transient`. Feeding that back makes X_gen[k] the forecast
        # for index transient + 1 + k, which lines up with `target` below.
        with reservoir.with_state():
            warming_out = esn.run(test_data[i, :transient, :])
            X_gen = np.zeros((nb_generations, 3))
            y = warming_out[-1]
            for t in range(nb_generations):
                y = esn(y)
                X_gen[t, :] = y
        target = test_data[i, (1 + transient):]
        test_traj_pred.append(X_gen)
        test_mse.append(mean_squared_error(X_gen, target))
        dstsp_temp = D_stsp(X_pred=X_gen, X_true=target)
        if dstsp_temp > 1:
            # Diverged: reported and excluded from the D_stsp average only.
            print("diverged trajectory with mse=", test_mse[-1])
        else:
            test_dstsp.append(dstsp_temp)
        test_dh.append(power_spectrum_error(X_gen, target))

    # Snap numerically negligible D_stsp values to exactly zero before averaging.
    test_D_stsp_np = np.array(test_dstsp)
    test_D_stsp_np[test_D_stsp_np < 1e-8] = 0
    print(f'MSE averaged over test trajectories: \t {np.mean(test_mse)}')
    print(f'D_stsp averaged over test trajectories: \t {np.mean(test_D_stsp_np)}')
    print(f'D_h averaged over test trajectories: \t {np.mean(test_dh)}')

# seed = 1:
# Fit time:  9.575299263000488
# MSE averaged over validation trajectories: 	 0.01690784580180543
# D_stsp averaged over validation trajectories: 	 4.125272045316702e-05
# D_h averaged over validation trajectories: 	 0.10435934482793031
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.01635176507607742
# D_stsp averaged over test trajectories: 	 5.2186264833564554e-05
# D_h averaged over test trajectories: 	 0.09509333192149148
#
#
# seed = 2:
# Fit time:  9.570014953613281
# MSE averaged over validation trajectories: 	 0.017824433682696796
# D_stsp averaged over validation trajectories: 	 3.716396947263004e-05
# D_h averaged over validation trajectories: 	 0.09372364152818362
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.014988685091612703
# D_stsp averaged over test trajectories: 	 4.6032898592754286e-05
# D_h averaged over test trajectories: 	 0.08613544740119788
#
#
# seed = 3:
# Fit time:  9.241371870040894
# MSE averaged over validation trajectories: 	 0.017169103485345644
# D_stsp averaged over validation trajectories: 	 3.883224627267755e-05
# D_h averaged over validation trajectories: 	 0.10142009008155554
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.015581846713137192
# D_stsp averaged over test trajectories: 	 4.700289609834171e-05
# D_h averaged over test trajectories: 	 0.09057043142927235
#
# seed = 4:
# Fit time:  9.65462040901184
# MSE averaged over validation trajectories: 	 0.017954354310113784
# D_stsp averaged over validation trajectories: 	 4.0459726867877004e-05
# D_h averaged over validation trajectories: 	 0.09657758218108156
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 12.959660319375665
# D_stsp averaged over test trajectories: 	 4.2389640772355135e-05
# D_h averaged over test trajectories: 	 0.09987795684275257
#
# seed = 5:
# Fit time:  9.54546856880188
# MSE averaged over validation trajectories: 	 0.017748871051270418
# D_stsp averaged over validation trajectories: 	 6.301981874147753e-05
# D_h averaged over validation trajectories: 	 0.10237002217578073
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.016078049222405896
# D_stsp averaged over test trajectories: 	 4.3077354882461153e-05
# D_h averaged over test trajectories: 	 0.09755461748496218
#
#
# mean fit time: 9.517355012893677
# mean dstsp test: 4.613781103589537e-05