import matplotlib.pyplot as plt
import numpy as np

from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error
from swimnetworks import (Dense, Linear)

from datafold import TSCDataFrame
from swim_rnn import RNN
import time

# --- model hyperparameters ---
layer_width = 80                # width of the hidden Dense layer
regularization_constant = None  # forwarded to RNN as its second positional argument
delay = 5                       # forwarded to RNN as time_delay; also sets the initial window length (delay + 1)
pca_components = 2              # forwarded to RNN as pca_components

# --- reproducibility ---
# Seeded generator; the hidden layer's random_seed is drawn from it.
rng = np.random.default_rng(5)

# --- run control ---
test = True  # when truthy, the test set is loaded and evaluated as well

dt = 0.1  # delta time (not referenced by the active code in this script)

# data-loading (only use first coordinate)
# Each array is indexed as [:, 0, k]: index 0 on axis 1 selects the first
# coordinate and axis 2 enumerates the trajectories.
training_data_npz = np.load('datasets/VdP_train.npy')
validation_data_npz = np.load('datasets/VdP_valid.npy')

# One single-feature TSCDataFrame per training trajectory.
num_training_traj = training_data_npz.shape[2]
training_tsc_list = [
    TSCDataFrame.from_array(
        training_data_npz[:, 0, traj_idx].reshape(-1, 1),
        feature_names=['0'],
    )
    for traj_idx in range(num_training_traj)
]
training_data = TSCDataFrame.from_frame_list(training_tsc_list)

# Same construction for the validation trajectories.
num_validation_traj = validation_data_npz.shape[2]
validation_tsc_list = [
    TSCDataFrame.from_array(
        validation_data_npz[:, 0, traj_idx].reshape(-1, 1),
        feature_names=['0'],
    )
    for traj_idx in range(num_validation_traj)
]
validation_data = TSCDataFrame.from_frame_list(validation_tsc_list)

# training
# The network dictionary is a one-step pipeline holding a single tanh Dense
# layer. Its random_seed is drawn from the seeded `rng`, so the sampled layer
# is reproducible for a given rng seed.
# NOTE(review): rng.integers(100, size=1) yields a length-1 array rather than a
# plain int -- confirm that Dense accepts this for random_seed.
steps = [
    ("hidden",
     Dense(layer_width=layer_width, activation='tanh', parameter_sampler='tanh',
           random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

swim_rnn = RNN(network_dictionary, regularization_constant, time_delay=delay, pca_components=pca_components)

# Time only the fit call. perf_counter() is a monotonic, high-resolution clock
# intended for measuring short durations; time.time() follows the wall clock
# and can jump if the system time is adjusted during the run.
start_time = time.perf_counter()
swim_rnn.fit(training_data)
end_time = time.perf_counter()

elapsed_time = end_time - start_time  # seconds
print("Fit time: ", elapsed_time)

# evaluation on training set
# Only the first few training trajectories are scored. The count is capped by
# the number of trajectories actually loaded so that a small training set does
# not raise an IndexError.
num_training_eval_traj = min(5, num_training_traj)

training_pred = []
training_mse = []
for i in range(num_training_eval_traj):
    # The first (delay + 1) samples are the initial window; the model is asked
    # to predict at all remaining time values of the trajectory.
    pred = swim_rnn.predict(training_tsc_list[i].iloc[:(delay + 1)],
                            time_values=training_tsc_list[i].time_values()[(delay + 1):])
    # Score pred[1:] against the samples from index (delay + 1) on, exactly as
    # the validation and test evaluations do, so the three MSE values are
    # comparable. (Presumably the first row of `pred` is the initial state at
    # index `delay` -- verify against RNN.predict.)
    training_pred.append(pred[1:])
    training_mse.append(mean_squared_error(pred[1:], training_tsc_list[i][(delay + 1):]))
print(f'MSE averaged over training trajectories: \t {np.mean(training_mse)}')

# evaluation on validation set
# For every validation trajectory: feed the first (delay + 1) samples as the
# initial window, predict at the remaining time values, then score the
# prediction (without its first row) against the samples from (delay + 1) on.
validation_pred = []
validation_mse = []
for val_traj in validation_tsc_list:
    initial_window = val_traj.iloc[:(delay + 1)]
    future_times = val_traj.time_values()[(delay + 1):]
    rollout = swim_rnn.predict(initial_window, time_values=future_times)

    forecast = rollout[1:]
    reference = val_traj[(delay + 1):]
    validation_pred.append(forecast)
    validation_mse.append(mean_squared_error(forecast, reference))

mean_validation_mse = np.mean(validation_mse)
worst_validation_mse = np.max(validation_mse)
print(f'MSE averaged over validation trajectories: \t {mean_validation_mse}')
print(f'MAX MSE averaged over validation trajectories: \t {worst_validation_mse}')

num_plot_traj = 1  # only referenced by the commented-out plotting code below


# Optional evaluation on the held-out test set, controlled by the `test` flag.
if test:
    # evaluation on test set
    test_data_npz = np.load('datasets/VdP_test.npy')

    # One single-feature TSCDataFrame per test trajectory (first coordinate
    # only; axis 2 enumerates the trajectories).
    num_test_traj = test_data_npz.shape[2]
    test_tsc_list = [
        TSCDataFrame.from_array(test_data_npz[:, 0, i].reshape(-1, 1), feature_names=['0'])
        for i in range(num_test_traj)
    ]
    test_data = TSCDataFrame.from_frame_list(test_tsc_list)

    test_pred = []
    test_mse = []
    for i in range(num_test_traj):
        # Initial window: first (delay + 1) samples; predict at all remaining
        # time values of the trajectory.
        pred = swim_rnn.predict(test_tsc_list[i].iloc[:(delay + 1)],
                                time_values=test_tsc_list[i].time_values()[(delay + 1):])
        # The first predicted row is dropped before scoring against the
        # samples from index (delay + 1) on.
        test_pred.append(pred[1:])
        test_mse.append(mean_squared_error(pred[1:], test_tsc_list[i][(delay + 1):]))
    print(f'MSE averaged over test trajectories: \t {np.mean(test_mse)}')
    print(f'MAX MSE averaged over test trajectories: \t {np.max(test_mse)}')

# fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# for k in [0]:
#     ax.plot(test_tsc_list[k].time_values()[(delay+1):], test_tsc_list[k].to_numpy()[(delay+1):], '-', label='true')
#     ax.plot(test_tsc_list[k].time_values()[(delay+1):], test_pred[k], '--', label='prediction')
#     # ax.vlines(delay*dt, ymin=-3, ymax=3)
# ax.set_xlabel('$h_1$')
# ax.set_ylabel('$h_2$')
# # ax.set_aspect(1)
# ax.legend()
# ax.grid()
# # plt.savefig(f'VdP_validation_{num_plot_traj}_trajectories.pdf')
# plt.show()
#
#
# np.save('sampled_RNN_VdP_timedelay_TestTraj1.npy', test_pred[0])

# seed=1:
# Fit time:  0.24980592727661133
# MSE averaged over validation trajectories: 	 0.0002224817982690282
# MAX MSE averaged over validation trajectories: 	 0.0022375579207673147
# MSE averaged over test trajectories: 	 0.0006222437619712617
# MAX MSE averaged over test trajectories: 	 0.01025712719039545
#
# seed=2:
# Fit time:  0.2576115131378174
# MSE averaged over validation trajectories: 	 0.0017923451517382857
# MAX MSE averaged over validation trajectories: 	 0.026626532345753177
# MSE averaged over test trajectories: 	 0.004248229960761807
# MAX MSE averaged over test trajectories: 	 0.08942524919219888
#
# seed=3:
# Fit time:  0.3096330165863037
# MSE averaged over validation trajectories: 	 7.078452176844107e-05
# MAX MSE averaged over validation trajectories: 	 0.0009458162218798138
# MSE averaged over test trajectories: 	 0.0001574176659645995
# MAX MSE averaged over test trajectories: 	 0.0038854679757197638
#
# seed=4:
# Fit time:  0.2939300537109375
# MSE averaged over validation trajectories: 	 0.0021430375787849934
# MAX MSE averaged over validation trajectories: 	 0.07834584969983326
# MSE averaged over test trajectories: 	 0.01577204519661697
# MAX MSE averaged over test trajectories: 	 0.6968533670673117
#
# seed=5:
# Fit time:  0.3149123191833496
# MSE averaged over validation trajectories: 	 0.0008142708001634779
# MAX MSE averaged over validation trajectories: 	 0.032186106252485844
# MSE averaged over test trajectories: 	 0.0045219268171852785
# MAX MSE averaged over test trajectories: 	 0.18549489924792606

# mean fit time over seeds 1-5 (seconds): 0.28517856597900393
# mean test MSE over seeds 1-5: 0.005064372680499983