import matplotlib.pyplot as plt
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error
from swimnetworks import (Dense, Linear)

from swim_rnn import RNN
from utils import D_stsp
from datafold import InitialCondition
from datafold import TSCDataFrame
import time
from utilities.pse import *

# --- hyperparameters ---------------------------------------------------------
# Width of the hidden Dense layer and the regularization constant handed to RNN.
regularization_constant = 1e-7
layer_width = 200

# Step size; only referenced by the commented-out plotting code further down.
dt = 0.01

# --- reproducibility ---------------------------------------------------------
# Seed NumPy's legacy global RNG and create a dedicated Generator; the Generator
# supplies the hidden layer's random seed when the network is built.
# torch.manual_seed(0)
rng = np.random.default_rng(5)
np.random.seed(0)

# data-loading
# Both arrays are indexed below as [:, :, trajectory]: the last axis enumerates
# trajectories and axis 1 determines the number of feature columns.
training_data_npz = np.load('datasets/Lorenz_normalized_train.npy')
validation_data_npz = np.load('datasets/Lorenz_normalized_valid.npy')

# When True, the test set is additionally loaded and evaluated at the end of the script.
test = True

# Column labels "0", "1", ... derived from the training array's feature count.
# They are built once and shared by every trajectory of both splits instead of
# being regenerated on each loop iteration.
feature_names = [str(dim) for dim in range(training_data_npz.shape[1])]

# One TSCDataFrame per training trajectory, then the merged collection used for fitting.
num_training_traj = training_data_npz.shape[2]
training_tsc_list = [
    TSCDataFrame.from_array(training_data_npz[:, :, traj], feature_names=feature_names)
    for traj in range(num_training_traj)
]
training_data = TSCDataFrame.from_frame_list(training_tsc_list)

# Same conversion for the validation trajectories (labelled with the training feature names).
num_validation_traj = validation_data_npz.shape[2]
validation_tsc_list = [
    TSCDataFrame.from_array(validation_data_npz[:, :, traj], feature_names=feature_names)
    for traj in range(num_validation_traj)
]
validation_data = TSCDataFrame.from_frame_list(validation_tsc_list)

# Network dictionary: a pipeline with a single Dense layer (tanh activation,
# 'tanh' parameter sampler).
# NOTE: rng.integers(100, size=1) yields a length-1 integer array drawn from
# [0, 100), not a Python int; it is forwarded unchanged as random_seed.
steps = [
    ("hidden", Dense(layer_width=layer_width, activation='tanh', parameter_sampler='tanh',
                     random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

swim_rnn = RNN(network_dictionary, regularization_constant)

# Time only the fit call. perf_counter() is a monotonic, high-resolution clock,
# so the measured interval cannot be distorted by system clock adjustments the
# way a time.time() difference can.
start_time = time.perf_counter()
swim_rnn.fit(training_data)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print("Fit time: ", elapsed_time)

# evaluation on training set
# Per-trajectory accumulators (predictions, MSE, D_stsp, D_h). The loop that
# would fill them is commented out below, so in the code shown they stay empty.
training_pred, training_mse = [], []
training_D_stsp, training_D_h = [], []
# for i in range(3):
#     pred = swim_rnn.predict(training_tsc_list[i].initial_states(),
#                             time_values=training_tsc_list[i].time_values())
#     training_pred.append(pred)
#     training_mse.append(mean_squared_error(pred[1:], training_tsc_list[i][1:]))
#     training_D_stsp.append(D_stsp(X_pred=pred[1:], X_true=training_tsc_list[i][1:].to_numpy()))
#     training_D_h.append(power_spectrum_error(x_gen=pred[1:], x_true=training_tsc_list[i][1:].to_numpy()))
# training_D_stsp_np = np.array(training_D_stsp)
# training_D_stsp_np[training_D_stsp_np < 1e-6] = 0 # setting values to 0 if they are small, because due to numerical rounding they could also be negative
# print(f'MSE averaged over training trajectories: \t {np.mean(training_mse)}')
# print(f'D_stsp averaged over training trajectories: \t {np.mean(training_D_stsp_np)}')
# print(f'D_h averaged over training trajectories: \t {np.mean(training_D_h)}')

# evaluation on validation set
# For every validation trajectory: predict from its initial state over its own
# time values, then score the prediction with MSE, D_stsp and the power
# spectrum error. All three metrics skip the first row of both series.
validation_pred = []
validation_mse = []
validation_D_stsp = []
validation_D_h = []
for reference in validation_tsc_list:
    pred = swim_rnn.predict(reference.initial_states(),
                            time_values=reference.time_values())
    validation_pred.append(pred)

    pred_tail = pred[1:]
    reference_tail = reference[1:]
    # NOTE: scikit-learn's signature is (y_true, y_pred); the prediction is
    # passed first here, which leaves the squared error unchanged.
    validation_mse.append(mean_squared_error(pred_tail, reference_tail))
    validation_D_stsp.append(D_stsp(X_pred=pred_tail, X_true=reference_tail.to_numpy()))
    validation_D_h.append(power_spectrum_error(x_gen=pred_tail, x_true=reference_tail.to_numpy()))

# Zero out every D_stsp value below 1e-8 (this also removes any negative values)
# before averaging.
validation_D_stsp_np = np.array(validation_D_stsp)
below_threshold = validation_D_stsp_np < 1e-8
validation_D_stsp_np[below_threshold] = 0

print(f'MSE averaged over validation trajectories: \t {np.mean(validation_mse)}')
print(f'D_stsp averaged over validation trajectories: \t {np.mean(validation_D_stsp_np)}')
print(f'D_h averaged over validation trajectories: \t {np.mean(validation_D_h)}')
print(f'Max Dstsp={np.max(validation_D_stsp_np)}')
# num_plot_traj = 1
#
# for k in rng.integers(0, 3, size=num_plot_traj):
#     fig, ax = plt.subplots(3, 1, figsize=(8, 4))
#     ax[0].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_tsc_list[k].to_numpy()[:,0], '-', label='true')
#     ax[0].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_pred[k][:,0], '--', label='prediction')
#     ax[1].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_tsc_list[k].to_numpy()[:,1], '-', label='true')
#     ax[1].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_pred[k][:,1], '--', label='prediction')
#     ax[2].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_tsc_list[k].to_numpy()[:,2], '-', label='true')
#     ax[2].plot(np.arange(validation_tsc_list[i].shape[0]) * dt, validation_pred[k][:,2], '--', label='prediction')
#     plt.legend()
#     plt.show()

# fig = plt.figure(figsize=(4, 4))
# ax = fig.add_subplot(1, 1, 1, projection='3d')
# for k in rng.integers(0, num_validation_traj, size=num_plot_traj):
#     ax.plot(*validation_tsc_list[k].to_numpy().T, '-', label='true')
#     ax.plot(*validation_pred[k].T, '--', label='prediction')
# ax.set_xlabel('$h_1$')
# ax.set_ylabel('$h_2$')
# ax.set_zlabel('$h_3$')
# ax.set_aspect("equal")
# ax.legend()
# plt.savefig(f'Lorenz_validation_{num_plot_traj}_phase_space.pdf')
# plt.show()

# Optional test-set evaluation, mirroring the validation evaluation.
if test:
    test_data_npz = np.load('datasets/Lorenz_normalized_test.npy')

    # Column labels follow the training array's feature count, as for the other
    # splits; they are built once rather than once per trajectory.
    test_feature_names = [str(dim) for dim in range(training_data_npz.shape[1])]

    # One TSCDataFrame per test trajectory (last array axis), plus the merged collection.
    num_test_traj = test_data_npz.shape[2]
    test_tsc_list = [
        TSCDataFrame.from_array(test_data_npz[:, :, traj], feature_names=test_feature_names)
        for traj in range(num_test_traj)
    ]
    test_data = TSCDataFrame.from_frame_list(test_tsc_list)

    # evaluation on test set
    # Predict each trajectory from its initial state over its own time values;
    # every metric skips the first row of both prediction and reference.
    test_pred = []
    test_mse = []
    test_D_stsp = []
    test_D_h = []
    for i in range(num_test_traj):
        pred = swim_rnn.predict(test_tsc_list[i].initial_states(), time_values=test_tsc_list[i].time_values())
        test_pred.append(pred)
        test_mse.append(mean_squared_error(pred[1:], test_tsc_list[i][1:]))
        test_D_stsp.append(D_stsp(X_pred=pred[1:], X_true=test_tsc_list[i][1:].to_numpy()))
        test_D_h.append(power_spectrum_error(x_gen=pred[1:], x_true=test_tsc_list[i][1:].to_numpy()))

    # Zero out every D_stsp value below 1e-8 (this also removes any negative
    # values) before averaging.
    test_D_stsp_np = np.array(test_D_stsp)
    test_D_stsp_np[test_D_stsp_np < 1e-8] = 0
    print(f'MSE averaged over test trajectories: \t {np.mean(test_mse)}')
    print(f'D_stsp averaged over test trajectories: \t {np.mean(test_D_stsp_np)}')
    print(f'D_h averaged over test trajectories: \t {np.mean(test_D_h)}')
    print(f'Max Dstsp={np.max(test_D_stsp_np)}')

# seed = 1:
# Fit time:  1.905015468597412
# MSE averaged over validation trajectories: 	 0.45808561466266207
# D_stsp averaged over validation trajectories: 	 0.007810825524116718
# D_h averaged over validation trajectories: 	 0.2090123408254966
# Max Dstsp=0.04696804355321585
# MSE averaged over test trajectories: 	 0.42333721642968636
# D_stsp averaged over test trajectories: 	 0.004458073765154723
# D_h averaged over test trajectories: 	 0.2012964234555135
# Max Dstsp=0.028316024820638563
#
# seed = 2:
# Fit time:  1.3379480838775635
# MSE averaged over validation trajectories: 	 0.465284502574145
# D_stsp averaged over validation trajectories: 	 0.007912617703803127
# D_h averaged over validation trajectories: 	 0.20219245328441823
# Max Dstsp=0.0513857856287726
# MSE averaged over test trajectories: 	 0.4077575949123488
# D_stsp averaged over test trajectories: 	 0.0046171649783463995
# D_h averaged over test trajectories: 	 0.20177087647812086
# Max Dstsp=0.04334327101015657
#
# seed = 3:
# Fit time:  1.9211618900299072
# MSE averaged over validation trajectories: 	 0.46545866689444293
# D_stsp averaged over validation trajectories: 	 0.010970091377190816
# D_h averaged over validation trajectories: 	 0.20466282511105827
# Max Dstsp=0.08807884595319339
# MSE averaged over test trajectories: 	 0.4173356612605709
# D_stsp averaged over test trajectories: 	 0.0053629192093051625
# D_h averaged over test trajectories: 	 0.20662826486465488
# Max Dstsp=0.038197985588259244
#
# seed = 4:
# Fit time:  1.7766201496124268
# MSE averaged over validation trajectories: 	 0.4746695500299186
# D_stsp averaged over validation trajectories: 	 0.006376543706586626
# D_h averaged over validation trajectories: 	 0.21479329153110518
# Max Dstsp=0.030189508008428213
# MSE averaged over test trajectories: 	 0.41935384792323144
# D_stsp averaged over test trajectories: 	 0.00366493563439397
# D_h averaged over test trajectories: 	 0.20338815511522473
# Max Dstsp=0.02070454394094152
#
# seed = 5:
# Fit time:  1.422229528427124
# MSE averaged over validation trajectories: 	 0.45719821385694187
# D_stsp averaged over validation trajectories: 	 0.006184012667483223
# D_h averaged over validation trajectories: 	 0.2061499794089979
# Max Dstsp=0.02259059894444154
# MSE averaged over test trajectories: 	 0.41996861742300756
# D_stsp averaged over test trajectories: 	 0.00370650146861692
# D_h averaged over test trajectories: 	 0.2045480607786683
# Max Dstsp=0.02414919523550347
#
# mean fit time: 1.6725950241088867 (seconds; average of the five fit times listed above)
# mean D_stsp on test data: 0.004361919011163435
# mean D_h on test data: 0.20352635613843645