import time

import matplotlib.pyplot as plt
import numpy as np
import reservoirpy as rpy
from reservoirpy.nodes import Reservoir, Ridge
from sklearn.metrics import mean_squared_error

from utilities.pse import *  # star import; presumably supplies power_spectrum_error
from utils import D_stsp

# reservoirpy verbosity level 0 (quiet).
rpy.verbosity(0)
# Global seeding (rpy.set_seed) is deliberately not used; the Reservoir below
# receives its own `seed` instead.

# ----- ESN hyperparameters -----
units = 500                # size of the reservoir (number of units)
leak_rate = 0.3            # leaking rate (lr)
spectral_radius = 0.5      # spectral radius of the recurrent weights (sr)
input_scaling = 0.1        # input scaling, a.k.a. input gain
connectivity = 0.1         # connection probability of the recurrent weights
input_connectivity = 0.2   # connection probability of the input weights
regularization = 1e-8      # L2 (ridge) regularization coefficient
transient = 0              # number of warmup steps passed to fit()
seed = 5                   # seed handed to Reservoir, for reproducibility

# ----- Model: reservoir feeding a ridge-regression readout -----
reservoir = Reservoir(
    units,
    lr=leak_rate,
    sr=spectral_radius,
    input_scaling=input_scaling,
    input_connectivity=input_connectivity,
    rc_connectivity=connectivity,
    seed=seed,
)
readout = Ridge(ridge=regularization)
esn = reservoir >> readout

# Time step between consecutive samples (assumed); in this script it only
# scales the time axis of the plots.
dt = 0.01

# ----- Data loading -----
# np.load on a .npy file returns a plain ndarray (not an .npz archive).
# The transpose moves the last stored axis to the front, so the arrays are
# indexed as (trajectory, time step, state component) from here on.
training_data = np.transpose(np.load('datasets/Rossler_normalized_train.npy'), (2, 0, 1))
validation_data = np.transpose(np.load('datasets/Rossler_normalized_valid.npy'), (2, 0, 1))

# Whether to also evaluate on the held-out test set after validation.
test = True

# One-step-ahead teacher-forcing pairs along the time axis: input x_t, target x_{t+1}.
training_Xm, training_Xp = training_data[:, :-1, :], training_data[:, 1:, :]

# perf_counter is monotonic and high-resolution, unlike the wall clock
# returned by time.time(), so it is the right tool for measuring durations.
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
elapsed_time = time.perf_counter() - start_time
print("Fit time: ", elapsed_time)

def _autonomous_rollout(x0, n_steps):
    """Generate ``n_steps`` states by feeding the trained ESN its own output.

    Uses the module-level ``reservoir`` and trained ``esn``.

    Args:
        x0: initial condition, a single state vector (one time step).
        n_steps: number of states to generate after ``x0``.

    Returns:
        ndarray of shape (n_steps, n_dims), where n_dims is the length of the
        last axis of ``x0`` (previously hard-coded to 3).
    """
    n_dims = np.shape(x0)[-1]
    x_gen = np.zeros((n_steps, n_dims))
    # Same state context as before. Presumably with_state() keeps one rollout
    # from carrying reservoir state into the next; confirm against the
    # installed reservoirpy version.
    with reservoir.with_state():
        y = x0
        for t in range(n_steps):
            y = esn(y)  # closed loop: previous prediction is the next input
            x_gen[t, :] = y
    return x_gen


def _evaluate_split(data, split_name):
    """Roll out every trajectory of ``data`` from its first state and score it.

    Prints a warning for each trajectory with D_stsp > 1, then prints the
    split averages of MSE, D_stsp and D_h.

    Args:
        data: array indexed as (trajectory, time step, state component).
        split_name: label used in the printed summary, e.g. 'validation'.

    Returns:
        Tuple ``(predictions, mse, dstsp, dh, dstsp_clamped)``. The first four
        are per-trajectory lists: generated arrays, MSE values, D_stsp values
        and power-spectrum errors. The last is the D_stsp values as an ndarray
        with entries below 1e-8 set to 0.
    """
    n_steps = data.shape[1] - 1
    predictions, mse_list, dstsp_list, dh_list = [], [], [], []
    for traj in data:
        x_true = traj[1:]  # targets: every state after the initial condition
        x_gen = _autonomous_rollout(traj[0], n_steps)
        mse = mean_squared_error(x_gen, x_true)
        dstsp = D_stsp(X_pred=x_gen, X_true=x_true)
        if dstsp > 1:
            print(f"diverged trajectory with mse={mse}, dstsp={dstsp}")
        predictions.append(x_gen)
        mse_list.append(mse)
        dstsp_list.append(dstsp)
        dh_list.append(power_spectrum_error(x_gen, x_true))

    dstsp_clamped = np.array(dstsp_list)
    # Report values below 1e-8 as exactly 0 (presumably numerical noise).
    dstsp_clamped[dstsp_clamped < 1e-8] = 0
    print(f'MSE averaged over {split_name} trajectories: \t {np.mean(mse_list)}')
    print(f'D_stsp averaged over {split_name} trajectories: \t {np.mean(dstsp_clamped)}')
    print(f'D_h averaged over {split_name} trajectories: \t {np.mean(dh_list)}')
    return predictions, mse_list, dstsp_list, dh_list, dstsp_clamped


# ----- Validation -----
valid_traj_pred, valid_mse, valid_dstsp, valid_dh, validation_D_stsp_np = _evaluate_split(
    validation_data, 'validation')

# ----- Plot the last validation trajectory: true vs. generated, one panel per component -----
# Uses explicit indices instead of loop variables leaked from the evaluation loop.
if valid_traj_pred:
    x_true_last = validation_data[-1, 1:]
    x_gen_last = valid_traj_pred[-1]
    t_axis = np.arange(x_gen_last.shape[0]) * dt
    n_plot_dims = x_gen_last.shape[1]
    fig, ax = plt.subplots(n_plot_dims, 1, figsize=(8, 4), squeeze=False)
    ax = ax[:, 0]
    for k in range(n_plot_dims):
        ax[k].plot(t_axis, x_true_last[:, k], '-', label='true')
        ax[k].plot(t_axis, x_gen_last[:, k], '--', label='prediction')
        ax[k].set_ylabel(f'h{k + 1}')
    ax[-1].set_xlabel('t')
    ax[-1].legend()
    plt.show()

# ----- Optional test-set evaluation -----
if test:
    # Header is printed only when the test set is actually evaluated.
    print('\n Evaluating test data:')
    test_data = np.transpose(np.load('datasets/Rossler_normalized_test.npy'), (2, 0, 1))
    test_traj_pred, test_mse, test_dstsp, test_dh, test_D_stsp_np = _evaluate_split(
        test_data, 'test')

# seed = 1:
# Fit time:  8.056436538696289
# MSE averaged over validation trajectories: 	 654.3199314852455
# D_stsp averaged over validation trajectories: 	 0.00016210154310590877
# D_h averaged over validation trajectories: 	 0.12017678981350692
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 3879.7903713068845
# D_stsp averaged over test trajectories: 	 3.7885417905032276e-05
# D_h averaged over test trajectories: 	 0.1928823798609178
#
#
#
# seed = 2:
# Fit time:  8.309618473052979
# MSE averaged over validation trajectories: 	 0.03049757414739526
# D_stsp averaged over validation trajectories: 	 0.00033600091166895553
# D_h averaged over validation trajectories: 	 0.13133252818977273
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.02845912862859082
# D_stsp averaged over test trajectories: 	 0.00022541200710735025
# D_h averaged over test trajectories: 	 0.12918676237117802
#
#
#
# seed = 3:
# Fit time:  8.272890090942383
# MSE averaged over validation trajectories: 	 0.01663168442628681
# D_stsp averaged over validation trajectories: 	 7.194669216741149e-05
# D_h averaged over validation trajectories: 	 0.09720479107668957
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 296.03752516548997
# D_stsp averaged over test trajectories: 	 3.9061767861130803e-05
# D_h averaged over test trajectories: 	 0.10816137476767393
#
#
#
# seed = 4:
# Fit time:  7.948967933654785
# MSE averaged over validation trajectories: 	 0.016708791178992894
# D_stsp averaged over validation trajectories: 	 0.00012114594653108677
# D_h averaged over validation trajectories: 	 0.08596944981262801
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 0.008766918938158324
# D_stsp averaged over test trajectories: 	 6.251434515167985e-05
# D_h averaged over test trajectories: 	 0.07594553536795373
#
#
# seed = 5:
# Fit time:  7.96442723274231
# MSE averaged over validation trajectories: 	 188.57324325230678
# D_stsp averaged over validation trajectories: 	 9.11550103495426e-05
# D_h averaged over validation trajectories: 	 0.09831216136327903
#
#  Evaluating test data:
# MSE averaged over test trajectories: 	 1322.367376261182
# D_stsp averaged over test trajectories: 	 5.154190367966549e-05
# D_h averaged over test trajectories: 	 0.1727504321395439
#
#
# mean fit time: 8.11046805381775
# mean dstsp test: 8.328308834097174e-05
# mean dh test: 0.13578529690145347
