import time

import matplotlib.pyplot as plt
import numpy as np
from datafold import TSCDataFrame
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from swimnetworks import Dense

from kirnn import KIRNN

# ablation switches
use_koopman = False
use_SWIM = False

# hyperparameters
layer_width = 80
regularization_constant = 1e-8

rng = np.random.default_rng(1)

dt = 0.1  # delta time

# Raw datasets. As used below, the last array axis enumerates trajectories and
# the second axis enumerates features (rows are presumably time steps -- confirm
# against the dataset generator).
training_data_npz = np.load('../simulated_datasets/VdP_train.npy')
validation_data_npz = np.load('../simulated_datasets/VdP_valid.npy')
test = True  # gates the test-set evaluation at the end of the script

# One TSCDataFrame per training trajectory; features are named "0", "1", ...
num_training_traj = training_data_npz.shape[2]
training_tsc_list = [
    TSCDataFrame.from_array(
        training_data_npz[:, :, traj_idx],
        feature_names=[str(col) for col in range(training_data_npz.shape[1])],
    )
    for traj_idx in range(num_training_traj)
]
training_data = TSCDataFrame.from_frame_list(training_tsc_list)

# Same wrapping for the validation trajectories.
num_validation_traj = validation_data_npz.shape[2]
validation_tsc_list = [
    TSCDataFrame.from_array(
        validation_data_npz[:, :, traj_idx],
        feature_names=[str(col) for col in range(validation_data_npz.shape[1])],
    )
    for traj_idx in range(num_validation_traj)
]
validation_data = TSCDataFrame.from_frame_list(validation_tsc_list)

# training
# Sampling scheme for the hidden layer: 'tanh' when the SWIM ablation switch
# is enabled, 'random' otherwise.
parameter_sampler = 'tanh' if use_SWIM else 'random'

steps = [
    ("hidden",
     Dense(layer_width=layer_width,
           activation='tanh',
           parameter_sampler=parameter_sampler,
           # NOTE: the seed is a length-1 integer array drawn from `rng`,
           # not a plain Python int.
           random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

kirnn_model = KIRNN(dictionary=network_dictionary,
                    n_features_in=2,
                    rcond=regularization_constant,
                    use_koopman=use_koopman)

# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be distorted by system clock adjustments.
start_time = time.perf_counter()
kirnn_model.fit(training_data)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# evaluation on training set: predict each trajectory from its initial state
# over the trajectory's own time values, then score prediction vs. reference
# with the first row dropped from both (presumably the shared initial
# condition -- confirm).
training_pred = [
    kirnn_model.predict(reference.initial_states(), time_values=reference.time_values())
    for reference in training_tsc_list
]
training_MSE = [
    mean_squared_error(predicted[1:], reference[1:])
    for predicted, reference in zip(training_pred, training_tsc_list)
]
print(f'MSE averaged over training trajectories: \t {np.mean(training_MSE)}')

# evaluation on validation set: same procedure as for the training set
validation_pred = [
    kirnn_model.predict(reference.initial_states(), time_values=reference.time_values())
    for reference in validation_tsc_list
]
validation_MSE = [
    mean_squared_error(predicted[1:], reference[1:])
    for predicted, reference in zip(validation_pred, validation_tsc_list)
]
print(f'MSE averaged over validation trajectories: \t {np.mean(validation_MSE)}')

num_plot_traj = 1

# Draw the trajectories to display once, so that the time-series figure and
# the phase-space figure show the same validation trajectories.
plot_indices = rng.integers(0, num_validation_traj, size=num_plot_traj)

# Figure 1: state components over time (x-axis is time, not a state variable).
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for k in plot_indices:
    t = validation_tsc_list[k].time_values()
    ax.plot(t, validation_tsc_list[k].to_numpy(), '-', label='true')
    ax.plot(t, np.asarray(validation_pred[k]), '--', label='prediction')
ax.set_xlabel('$t$')
ax.set_ylabel('$h_1, h_2$')
ax.legend()
plt.tight_layout()
plt.show()

# Figure 2: phase portrait. Unpacking the transposed array yields one
# positional argument per feature, so this relies on exactly two features.
# np.asarray keeps the unpacking valid if the prediction is DataFrame-like
# rather than an ndarray.
fig, ax = plt.subplots()
for k in plot_indices:
    ax.plot(*validation_tsc_list[k].to_numpy().T, '-', label='true')
    ax.plot(*np.asarray(validation_pred[k]).T, '--', label='prediction')
ax.set_xlabel('$h_1$')
ax.set_ylabel('$h_2$')
ax.set_aspect(1)
ax.legend()
plt.show()

if test:
    test_data_npz = np.load('../simulated_datasets/VdP_test.npy')

    # One TSCDataFrame per test trajectory (last array axis); features are
    # named "0", "1", ...
    num_test_traj = test_data_npz.shape[2]
    test_tsc_list = [
        TSCDataFrame.from_array(
            test_data_npz[:, :, traj_idx],
            feature_names=[str(col) for col in range(test_data_npz.shape[1])],
        )
        for traj_idx in range(num_test_traj)
    ]
    test_data = TSCDataFrame.from_frame_list(test_tsc_list)

    # evaluation on test set: predict from each initial state, then compare
    # against the reference with the first row dropped from both
    test_pred = [
        kirnn_model.predict(reference.initial_states(), time_values=reference.time_values())
        for reference in test_tsc_list
    ]
    test_MSE = [
        mean_squared_error(predicted[1:], reference[1:])
        for predicted, reference in zip(test_pred, test_tsc_list)
    ]
    print(f'MSE averaged over test trajectories: \t {np.mean(test_MSE)}')

#
# use_koopman = False
# use_SWIM = False
# MSE averaged over test trajectories: 	 0.02978032122526424
# MSE averaged over test trajectories: 	 0.03271619039450429
# MSE averaged over test trajectories: 	 0.09190086639770868
# MSE averaged over test trajectories: 	 0.02933212206327753
# MSE averaged over test trajectories: 	 0.05134604971610702
# avg: 0.04701510995937235; max: 0.09190086639770868 ; min: 0.02933212206327753
#
# use_koopman = True
# use_SWIM = False
# MSE averaged over test trajectories: 	 0.0032093557131553947
# MSE averaged over test trajectories: 	 0.0021412823937094036
# MSE averaged over test trajectories: 	 19.546946155709623  # seed 3: omitted from the avg/max/min below, as the value is likely an outlier
# MSE averaged over test trajectories: 	 0.02208688293428485
# MSE averaged over test trajectories: 	 0.019098118467823486
# MSE averaged over test trajectories: 	 0.6817881645883418
# avg: 0.14566476081946297; max: 0.6817881645883418, min: 0.0021412823937094036
#
# use_koopman = False
# use_SWIM = True
# MSE averaged over test trajectories: 	 0.03329407351992163
# MSE averaged over test trajectories: 	 0.03329809339789158
# MSE averaged over test trajectories: 	 0.03248613489717295
# MSE averaged over test trajectories: 	 0.03325770458804691
# MSE averaged over test trajectories: 	 0.03420085898500914
# avg: 0.033307373077608435, max: 0.03420085898500914, min: 0.03248613489717295
#
# use_koopman = True
# use_SWIM = True
# MSE averaged over test trajectories: 	 0.0009198011534011048
# MSE averaged over test trajectories: 	 0.0012817655676518555
# MSE averaged over test trajectories: 	 0.0008009498814094368
# MSE averaged over test trajectories: 	 0.0007089266434631788
# MSE averaged over test trajectories: 	 0.0010640042193456067
# avg: 0.0009550894930542367, max: 0.0012817655676518555, min: 0.0007089266434631788
