import time

import matplotlib.pyplot as plt
import numpy as np
from datafold import TSCDataFrame
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from swimnetworks import Dense

from kirnn import KIRNN

# hyperparams
layer_width = 80  # width of the hidden Dense layer used as the KIRNN dictionary
regularization_constant = 1e-8  # passed to KIRNN as ``rcond``

rng = np.random.default_rng(5)

dt = 0.1  # delta time
# NOTE(review): dt is not passed to TSCDataFrame.from_array below, so the
# frames carry whatever default time values datafold assigns -- confirm intended.


def _to_tsc_list(trajectories):
    """Split a stacked trajectory array into one TSCDataFrame per trajectory.

    ``trajectories[:, :, k]`` is the k-th trajectory; its rows become the
    frame's samples and its columns the features, named "0", "1", ....
    """
    # Column names are identical for every trajectory, so build them once.
    feature_names = [str(j) for j in range(trajectories.shape[1])]
    return [TSCDataFrame.from_array(trajectories[:, :, k], feature_names=feature_names)
            for k in range(trajectories.shape[2])]


# data-loading
# These are plain .npy arrays (indexed by .shape/slicing below), despite the
# ``_npz`` suffix on the variable names.
training_data_npz = np.load('../simulated_datasets/VdP_train.npy')
validation_data_npz = np.load('../simulated_datasets/VdP_valid.npy')
test = True  # also evaluate on the held-out test set at the end of the script

training_tsc_list = _to_tsc_list(training_data_npz)
num_training_traj = len(training_tsc_list)
training_data = TSCDataFrame.from_frame_list(training_tsc_list)

validation_tsc_list = _to_tsc_list(validation_data_npz)
num_validation_traj = len(validation_tsc_list)
validation_data = TSCDataFrame.from_frame_list(validation_tsc_list)

# training
# The KIRNN dictionary is a single Dense hidden layer wrapped in a Pipeline.
steps = [
    ("hidden",
     Dense(layer_width=layer_width,
           activation='tanh',
           parameter_sampler='tanh',
           # NOTE(review): this passes a length-1 array rather than an int; left
           # as-is because converting it could change how the seed is consumed.
           random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

kirnn_model = KIRNN(dictionary=network_dictionary,
                    # number of state features, taken from the data instead of
                    # hard-coding it (it is 2 for this data set)
                    n_features_in=training_data.shape[1],
                    rcond=regularization_constant)

# perf_counter is monotonic and high-resolution, unlike time.time(), which makes
# it the appropriate clock for measuring an elapsed interval.
start_time = time.perf_counter()
kirnn_model.fit(training_data)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

def _predict_and_score(model, tsc_list):
    """Roll out ``model`` from each trajectory's initial state and score it.

    Each trajectory in ``tsc_list`` is predicted over its own time values,
    starting from its initial state.

    Returns ``(predictions, mses)``, each a list with one entry per
    trajectory, in the order of ``tsc_list``.
    """
    predictions = []
    mses = []
    for tsc in tsc_list:
        pred = model.predict(tsc.initial_states(), time_values=tsc.time_values())
        predictions.append(pred)
        # Skip the first sample: it is the initial state the rollout starts from.
        # sklearn's signature is mean_squared_error(y_true, y_pred).
        mses.append(mean_squared_error(tsc[1:], pred[1:]))
    return predictions, mses


# evaluation on training set
training_pred, training_MSE = _predict_and_score(kirnn_model, training_tsc_list)
print(f'MSE averaged over training trajectories: \t {np.mean(training_MSE)}')

# evaluation on validation set
validation_pred, validation_MSE = _predict_and_score(kirnn_model, validation_tsc_list)
print(f'MSE averaged over validation trajectories: \t {np.mean(validation_MSE)}')

num_plot_traj = 1

# Draw the validation trajectories to plot once, so both figures show the same ones.
plot_traj_ids = rng.integers(0, num_validation_traj, size=num_plot_traj)

# Figure 1: each state component over time, true (solid) vs predicted (dashed).
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for k in plot_traj_ids:
    t_values = validation_tsc_list[k].time_values()
    true_states = validation_tsc_list[k].to_numpy()
    # np.asarray handles predict() returning either an ndarray or a (TSC)DataFrame.
    pred_states = np.asarray(validation_pred[k])
    for j in range(true_states.shape[1]):
        true_line, = ax.plot(t_values, true_states[:, j], '-',
                             label=f'true $h_{{{j + 1}}}$ (traj {k})')
        ax.plot(t_values, pred_states[:, j], '--', color=true_line.get_color(),
                label=f'prediction $h_{{{j + 1}}}$ (traj {k})')
ax.set_xlabel('$t$')
ax.set_ylabel('state')
ax.legend()
plt.tight_layout()
plt.show()

# Figure 2: phase portrait (h_1 vs h_2), true (solid) vs predicted (dashed).
fig, ax = plt.subplots()
for k in plot_traj_ids:
    true_states = validation_tsc_list[k].to_numpy()
    # Convert before indexing: unpacking a DataFrame would yield its column
    # labels, not its data.
    pred_states = np.asarray(validation_pred[k])
    true_line, = ax.plot(true_states[:, 0], true_states[:, 1], '-', label=f'true (traj {k})')
    ax.plot(pred_states[:, 0], pred_states[:, 1], '--', color=true_line.get_color(),
            label=f'prediction (traj {k})')
ax.set_xlabel('$h_1$')
ax.set_ylabel('$h_2$')
ax.set_aspect(1)
ax.legend()
plt.show()

if test:
    # Held-out test set, stored in the same array layout as training/validation.
    test_data_npz = np.load('../simulated_datasets/VdP_test.npy')

    num_test_traj = test_data_npz.shape[2]
    # Column names are identical for every trajectory, so build them once.
    test_feature_names = [str(j) for j in range(test_data_npz.shape[1])]
    test_tsc_list = [TSCDataFrame.from_array(test_data_npz[:, :, k], feature_names=test_feature_names)
                     for k in range(num_test_traj)]
    test_data = TSCDataFrame.from_frame_list(test_tsc_list)

    # evaluation on test set: roll out from each initial state, skip that first
    # sample in the MSE, and pass arguments as mean_squared_error(y_true, y_pred).
    test_pred = []
    test_MSE = []
    for tsc in test_tsc_list:
        pred = kirnn_model.predict(tsc.initial_states(), time_values=tsc.time_values())
        test_pred.append(pred)
        test_MSE.append(mean_squared_error(tsc[1:], pred[1:]))
    print(f'MSE averaged over test trajectories: \t {np.mean(test_MSE)}')

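# Reference results for different values of the seed passed to
# np.random.default_rng at the top of this script (fit times are machine-dependent):
#
# seed = 1: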
# Fit time:  0.2781996726989746
# MSE averaged over training trajectories: 	 2.856164974540874e-06
# MSE averaged over validation trajectories: 	 0.0003023686718317407
# MSE averaged over test trajectories: 	 0.0009198011534011048
#
# seed = 2:
# Fit time:  0.21685338020324707
# MSE averaged over training trajectories: 	 1.985967817849042e-06
# MSE averaged over validation trajectories: 	 0.00021106607235258878
# MSE averaged over test trajectories: 	 0.0012817655676518555
#
# seed = 3:
# Fit time:  0.17678594589233398
# MSE averaged over training trajectories: 	 2.2811149625708387e-06
# MSE averaged over validation trajectories: 	 6.716843262428966e-05
# MSE averaged over test trajectories: 	 0.0008009498814094368
#
# seed = 4:
# Fit time:  0.34900903701782227
# MSE averaged over training trajectories: 	 1.8718101717462866e-06
# MSE averaged over validation trajectories: 	 0.00022223771728883523
# MSE averaged over test trajectories: 	 0.0007089266434631788
#
# seed = 5:
# Fit time:  0.20610928535461426
# MSE averaged over training trajectories: 	 1.8972165661707506e-06
# MSE averaged over validation trajectories: 	 0.0002877674573554258
# MSE averaged over test trajectories: 	 0.0010640042193456067
