import time

import numpy as np
import reservoirpy as rpy
from reservoirpy.nodes import Reservoir, Ridge
from sklearn.metrics import mean_squared_error

from computational_experiments.metric_utils.EKL import empirical_KL
from computational_experiments.metric_utils.PSE import power_spectrum_error

# Silence reservoirpy's log/progress output.
rpy.verbosity(0)

# --- Echo state network hyperparameters -------------------------------------
units = 500               # number of reservoir units
leak_rate = 0.3           # leaking rate of the reservoir neurons
spectral_radius = 0.5     # spectral radius of the recurrent weights
input_scaling = 0.1       # input scaling (also called input gain)
connectivity = 0.1        # connection probability of the recurrent weights
input_connectivity = 0.2  # connection probability of the input weights
regularization = 1e-8     # L2 (ridge) regularization coefficient
transient = 0             # number of warmup steps, passed to fit() as `warmup`
seed = 5                  # seed for the reservoir weights, for reproducibility

# Reservoir feeding a ridge-regression readout, chained into a single model.
reservoir = Reservoir(
    units,
    lr=leak_rate,
    sr=spectral_radius,
    input_scaling=input_scaling,
    rc_connectivity=connectivity,
    input_connectivity=input_connectivity,
    seed=seed,
)
readout = Ridge(ridge=regularization)
esn = reservoir >> readout

dt = 0.01  # time step (presumably the sampling interval of the simulated data); not used below

# --- Data loading -----------------------------------------------------------
# NOTE: the paths are relative to the current working directory, so the script
# is expected to be launched from its own directory.
# Despite the `_npz` suffix these are plain .npy arrays. The (2, 0, 1)
# permutation moves the last stored axis to the front, so that afterwards
# axis 0 indexes trajectories, axis 1 time steps and axis 2 state variables
# (as used by the evaluation code below). This assumes the files store
# (time, variables, trajectories) -- TODO confirm against the data generator.
training_data_npz = np.load('../simulated_datasets/Rossler_normalized_train.npy')
validation_data_npz = np.load('../simulated_datasets/Rossler_normalized_valid.npy')
training_data = np.transpose(training_data_npz, (2, 0, 1))
validation_data = np.transpose(validation_data_npz, (2, 0, 1))
test = True  # when True, also evaluate on the test split (loaded further below)

# One-step-ahead prediction: the inputs are the states at steps 0..T-2 and the
# targets the states at steps 1..T-1 of every trajectory.
training_Xm, training_Xp = training_data[:, :-1, :], training_data[:, 1:, :]

# perf_counter is monotonic and high-resolution -- the right clock for
# measuring durations (time.time() can jump if the system clock is adjusted).
start_time = time.perf_counter()
esn = esn.fit(training_Xm, training_Xp, warmup=transient)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

def _rollout_and_score(model, reservoir_node, data):
    """Roll ``model`` out autonomously on every trajectory and score it.

    Each rollout starts from the first state of a trajectory. The model's own
    prediction is fed back as the next input for ``T - 1`` steps, where ``T``
    is the trajectory length. The result is compared with the true states at
    steps 1..T-1.

    Parameters
    ----------
    model : the fitted reservoirpy model. Called on one state, it returns the
        predicted next state.
    reservoir_node : the model's Reservoir node. Its state is managed with
        ``with_state()`` for every trajectory.
    data : array indexed as (trajectory, time step, state variable).

    Returns
    -------
    tuple ``(predictions, mse, ekl, pse, ekl_array)``:
        ``predictions`` holds one ``(T - 1, n_vars)`` array per trajectory.
        ``mse``, ``ekl`` and ``pse`` are lists with one value per trajectory.
        ``ekl_array`` is ``ekl`` as an array, with values below 1e-8 set to 0.
    """
    nb_steps = data.shape[1] - 1
    # Previously hard-coded to 3 (the Rossler state dimension).
    nb_vars = data.shape[2]
    predictions, mse, ekl, pse = [], [], [], []
    for trajectory in data:
        truth = trajectory[1:]
        # with_state() (default stateful=False) restores the reservoir state on
        # exit, so every trajectory starts from the same post-fit state.
        with reservoir_node.with_state():
            pred = np.zeros((nb_steps, nb_vars))
            y = trajectory[0, :]
            for t in range(nb_steps):
                # Closed loop: the prediction becomes the next input.
                y = model(y)
                pred[t, :] = y
        predictions.append(pred)
        # sklearn's signature is (y_true, y_pred); MSE is symmetric anyway.
        mse.append(mean_squared_error(truth, pred))
        ekl.append(empirical_KL(X_pred=pred, X_true=truth))
        pse.append(power_spectrum_error(pred, truth))
    ekl_array = np.array(ekl)
    # Snap tiny EKL values to exactly 0. NOTE(review): presumably this removes
    # numerical noise from the estimator -- confirm intent.
    ekl_array[ekl_array < 1e-8] = 0
    return predictions, mse, ekl, pse, ekl_array


def _report(split_name, mse, ekl_array, pse):
    """Print the mean MSE, EKL and PSE over the trajectories of one split."""
    print(f'MSE averaged over {split_name} trajectories: \t {np.mean(mse)}')
    print(f'EKL averaged over {split_name} trajectories: \t {np.mean(ekl_array)}')
    print(f'PSE averaged over {split_name} trajectories: \t {np.mean(pse)}')


valid_traj_pred, valid_MSE, valid_EKL, valid_PSE, validation_EKL_np = (
    _rollout_and_score(esn, reservoir, validation_data)
)
_report('validation', valid_MSE, validation_EKL_np, valid_PSE)

if test:
    test_data_npz = np.load('../simulated_datasets/Rossler_normalized_test.npy')
    # Same axis permutation as for the training/validation arrays.
    test_data = np.transpose(test_data_npz, (2, 0, 1))
    test_traj_pred, test_MSE, test_EKL, test_PSE, test_EKL_np = (
        _rollout_and_score(esn, reservoir, test_data)
    )
    _report('test', test_MSE, test_EKL_np, test_PSE)

# Recorded console output of this script for different values of `seed`:
# seed = 1:
# Fit time:  8.056436538696289
# MSE averaged over validation trajectories: 	 654.3199314852455
# EKL averaged over validation trajectories: 	 0.00016210154310590877
# PSE averaged over validation trajectories: 	 0.12017678981350692
# MSE averaged over test trajectories: 	 3879.7903713068845
# EKL averaged over test trajectories: 	 3.7885417905032276e-05
# PSE averaged over test trajectories: 	 0.1928823798609178
#
#
# seed = 2:
# Fit time:  8.309618473052979
# MSE averaged over validation trajectories: 	 0.03049757414739526
# EKL averaged over validation trajectories: 	 0.00033600091166895553
# PSE averaged over validation trajectories: 	 0.13133252818977273
# MSE averaged over test trajectories: 	 0.02845912862859082
# EKL averaged over test trajectories: 	 0.00022541200710735025
# PSE averaged over test trajectories: 	 0.12918676237117802
#
#
# seed = 3:
# Fit time:  8.272890090942383
# MSE averaged over validation trajectories: 	 0.01663168442628681
# EKL averaged over validation trajectories: 	 7.194669216741149e-05
# PSE averaged over validation trajectories: 	 0.09720479107668957
# MSE averaged over test trajectories: 	 296.03752516548997
# EKL averaged over test trajectories: 	 3.9061767861130803e-05
# PSE averaged over test trajectories: 	 0.10816137476767393
#
#
# seed = 4:
# Fit time:  7.948967933654785
# MSE averaged over validation trajectories: 	 0.016708791178992894
# EKL averaged over validation trajectories: 	 0.00012114594653108677
# PSE averaged over validation trajectories: 	 0.08596944981262801
# MSE averaged over test trajectories: 	 0.008766918938158324
# EKL averaged over test trajectories: 	 6.251434515167985e-05
# PSE averaged over test trajectories: 	 0.07594553536795373
#
#
# seed = 5:
# Fit time:  7.96442723274231
# MSE averaged over validation trajectories: 	 188.57324325230678
# EKL averaged over validation trajectories: 	 9.11550103495426e-05
# PSE averaged over validation trajectories: 	 0.09831216136327903
# MSE averaged over test trajectories: 	 1322.367376261182
# EKL averaged over test trajectories: 	 5.154190367966549e-05
# PSE averaged over test trajectories: 	 0.1727504321395439
#
#
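# Averages over seeds 1-5 above. Test MSE is not averaged: it spans several
# orders of magnitude across seeds (about 0.009 to 3880).
# mean fit time: 8.11046805381775
# mean EKL test: 8.328308834097174e-05
# mean PSE test: 0.13578529690145347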
