import time

import numpy as np
from datafold import TSCDataFrame
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from swimnetworks import Dense

from computational_experiments.metric_utils.EKL import empirical_KL
from computational_experiments.metric_utils.PSE import power_spectrum_error
from kirnn import KIRNN

# Reproducibility: seed NumPy's legacy global RNG and create a dedicated Generator
# (the Generator is used below to draw the seed of the hidden layer).
np.random.seed(0)
rng = np.random.default_rng(5)

# hyperparams
layer_width = 300  # width of the hidden Dense layer of the dictionary
regularization_constant = 1e-4  # handed to KIRNN as `rcond`

dt = 0.01  # presumably the sampling interval of the trajectories; not referenced elsewhere in this script

# data-loading
# Each array is indexed [row, feature, trajectory]: every slice [:, :, k] becomes one
# single-trajectory TSCDataFrame whose feature columns are named '0', '1', ...
training_data_npz = np.load('../simulated_datasets/Rossler_normalized_train.npy')
validation_data_npz = np.load('../simulated_datasets/Rossler_normalized_valid.npy')
test = True  # when True, the test split is also loaded and evaluated at the end

num_training_traj = training_data_npz.shape[2]
training_feature_names = [str(feat) for feat in range(training_data_npz.shape[1])]
training_tsc_list = [
    TSCDataFrame.from_array(training_data_npz[:, :, traj], feature_names=training_feature_names)
    for traj in range(num_training_traj)
]
training_data = TSCDataFrame.from_frame_list(training_tsc_list)

num_validation_traj = validation_data_npz.shape[2]
validation_feature_names = [str(feat) for feat in range(validation_data_npz.shape[1])]
validation_tsc_list = [
    TSCDataFrame.from_array(validation_data_npz[:, :, traj], feature_names=validation_feature_names)
    for traj in range(num_validation_traj)
]
validation_data = TSCDataFrame.from_frame_list(validation_tsc_list)

# Dictionary: a single Dense layer (tanh activation, 'tanh' parameter sampler)
# wrapped in an sklearn Pipeline, then handed to the KIRNN model.
steps = [
    ("hidden", Dense(layer_width=layer_width,
                     activation='tanh',
                     parameter_sampler='tanh',
                     # NOTE(review): rng.integers(100, size=1) yields a length-1 array rather
                     # than a scalar int; kept unchanged so the drawn seed (and results) stay
                     # identical to previous runs.
                     random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

# The number of input features is read from the data (axis 1 is the feature axis,
# the same axis used to build the feature names) instead of being hard-coded.
kirnn_model = KIRNN(dictionary=network_dictionary,
                    n_features_in=training_data_npz.shape[1],
                    rcond=regularization_constant)

# perf_counter is a monotonic, high-resolution clock intended for measuring durations;
# time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
kirnn_model.fit(training_data)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# Per-trajectory EKL values below this threshold are set to 0 before averaging
# (presumably to suppress numerical noise around zero). One value is shared by all
# splits so the reported EKLs are comparable; the training split previously used
# 1e-6 while validation/test used 1e-8.
EKL_ZERO_THRESHOLD = 1e-8


def evaluate_split(model, tsc_list, split_name, ekl_zero_threshold=EKL_ZERO_THRESHOLD):
    """Roll out `model` on every trajectory of one split and print averaged metrics.

    Each trajectory is predicted from its initial state over its own time values.
    The prediction is then scored against the trajectory with the mean squared error
    (MSE), the empirical KL divergence (EKL) and the power-spectrum error (PSE).

    Parameters
    ----------
    model
        Fitted model exposing ``predict(initial_states, time_values=...)``.
    tsc_list : list of TSCDataFrame
        One single-trajectory frame per trajectory of the split.
    split_name : str
        Name used in the printed report, e.g. ``'training'``.
    ekl_zero_threshold : float
        EKL values strictly below this are set to 0 before averaging.

    Returns
    -------
    tuple
        ``(predictions, mse, ekl, pse, ekl_np)``. The first four are per-trajectory
        lists: predictions, MSE, raw EKL and PSE. ``ekl_np`` is a NumPy array
        holding the thresholded EKL values.
    """
    predictions, mse, ekl, pse = [], [], [], []
    for trajectory in tsc_list:
        pred = model.predict(trajectory.initial_states(),
                             time_values=trajectory.time_values())
        true_values = trajectory.to_numpy()
        predictions.append(pred)
        # sklearn's signature is (y_true, y_pred); MSE is symmetric, so the value is unchanged.
        mse.append(mean_squared_error(trajectory, pred))
        ekl.append(empirical_KL(X_pred=pred, X_true=true_values))
        pse.append(power_spectrum_error(x_gen=pred, x_true=true_values))

    ekl_np = np.array(ekl)
    ekl_np[ekl_np < ekl_zero_threshold] = 0

    print(f'MSE averaged over {split_name} trajectories: \t {np.mean(mse)}')
    print(f'EKL averaged over {split_name} trajectories: \t {np.mean(ekl_np)}')
    print(f'PSE averaged over {split_name} trajectories: \t {np.mean(pse)}')
    return predictions, mse, ekl, pse, ekl_np


# evaluation on training set
training_pred, training_MSE, training_EKL, training_PSE, training_EKL_np = evaluate_split(
    kirnn_model, training_tsc_list, 'training')

# evaluation on validation set
validation_pred, validation_MSE, validation_EKL, validation_PSE, validation_EKL_np = evaluate_split(
    kirnn_model, validation_tsc_list, 'validation')

if test:
    test_data_npz = np.load('../simulated_datasets/Rossler_normalized_test.npy')

    num_test_traj = test_data_npz.shape[2]
    test_feature_names = [str(feat) for feat in range(test_data_npz.shape[1])]
    test_tsc_list = [
        TSCDataFrame.from_array(test_data_npz[:, :, traj], feature_names=test_feature_names)
        for traj in range(num_test_traj)
    ]
    test_data = TSCDataFrame.from_frame_list(test_tsc_list)

    # evaluation on test set
    test_pred, test_MSE, test_EKL, test_PSE, test_EKL_np = evaluate_split(
        kirnn_model, test_tsc_list, 'test')
    # Legacy name kept so any existing interactive/downstream use keeps working.
    test_D_stsp_np = test_EKL_np

# seed = 1:
# Fit time:  5.140249967575073
# MSE averaged over validation trajectories: 	 0.02940165592784143
# EKL averaged over validation trajectories: 	 0.00043070953598153063
# PSE averaged over validation trajectories: 	 0.22717388872982058
# MSE averaged over test trajectories: 	 0.03132171539452584
# EKL averaged over test trajectories: 	 0.0003819032044585863
# PSE averaged over test trajectories: 	 0.2250928177344255
#
# seed = 2:
# Fit time:  4.650840520858765
# MSE averaged over validation trajectories: 	 0.020732687055289287
# EKL averaged over validation trajectories: 	 0.00012204299519220334
# PSE averaged over validation trajectories: 	 0.1958096841462998
# MSE averaged over test trajectories: 	 0.01717028125450918
# EKL averaged over test trajectories: 	 0.00011852507068670404
# PSE averaged over test trajectories: 	 0.18399414158402116
#
# seed = 3:
# Fit time:  4.391016960144043
# MSE averaged over validation trajectories: 	 0.026307814573230802
# EKL averaged over validation trajectories: 	 0.0001402812143837675
# PSE averaged over validation trajectories: 	 0.15190443958772723
# MSE averaged over test trajectories: 	 0.030581061540880805
# EKL averaged over test trajectories: 	 0.00013350165817390712
# PSE averaged over test trajectories: 	 0.15270656208549396
#
# seed = 4:
# Fit time:  6.3937599658966064
# MSE averaged over validation trajectories: 	 0.021223911963957152
# EKL averaged over validation trajectories: 	 0.00013320260882066198
# PSE averaged over validation trajectories: 	 0.1337950369782515
# MSE averaged over test trajectories: 	 0.018416397544973498
# EKL averaged over test trajectories: 	 9.314530620796075e-05
# PSE averaged over test trajectories: 	 0.12625251318801883
#
# seed = 5:
# Fit time:  6.206273794174194
# MSE averaged over validation trajectories: 	 0.02092555369935451
# EKL averaged over validation trajectories: 	 8.798172165121078e-05
# PSE averaged over validation trajectories: 	 0.10897360207036628
# MSE averaged over test trajectories: 	 0.01711909140947611
# EKL averaged over test trajectories: 	 5.8553638579771554e-05
# PSE averaged over test trajectories: 	 0.1053761889889552
#
#
# mean fit time: 5.356428241729736
# mean EKL on test data: 0.00015712577562138598
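# mean PSE on test data: 0.16017891537385517
# NOTE(review): the five per-seed test PSE values listed above average to ~0.15868, not the
# value recorded here (the fit-time and EKL means do match); re-check which runs were aggregated.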
