import time

import numpy as np
import pandas as pd
from datafold import TSCDataFrame
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.pipeline import Pipeline
from swimnetworks import Dense

from kirnn import KIRNN
from visualization_utils import plot_predictions_per_cell, plot_eigenvalues, plot_modes, plot_snapshot

# --- Hyperparameters ---------------------------------------------------------
# Width of the hidden layer and the regularization constant handed to the
# model as `rcond` further down in the script.
layer_width = 100
regularization_constant = 1e-4
# NOTE(review): the original comment here read "100, 1e-4" -- presumably the
# reference values for the two hyperparameters above; confirm.

# --- Reproducible randomness -------------------------------------------------
# Fixed seed so that repeated runs draw the same random numbers.
rng = np.random.default_rng(5)

# --- Configuration -----------------------------------------------------------
dt = 0.2  # delta time
measurement_cells = ['(30_60)', '(40_60)']
visualization_timesteps = [0, 54, 104, 179, 229]
system_observables = ['density']
data_splits = ['train', 'test']

# --- Data loading ------------------------------------------------------------
# One entry per split. With several observables, later ones would overwrite
# earlier ones for the same split because the key is the split name only.
timeseries_data_dict = {}
for split_name in data_splits:
    for observable in system_observables:
        csv_path = f'./data/{observable}_{split_name}.csv'
        try:
            loaded_tsc = TSCDataFrame.from_csv(csv_path)
        except FileNotFoundError:
            # Best effort: report the missing file and carry on. Code below
            # that needs the split will fail on the missing dictionary key.
            print(f"Could not load timeseries data at {csv_path}.")
        else:
            timeseries_data_dict[split_name] = loaded_tsc

# --- Snapshot plots of the training data -------------------------------------
# Each selected row is reshaped into a 110 x 80 grid, so a row has to hold
# exactly 8800 values.
snapshot_shape = (110, 80)
train_frame = timeseries_data_dict["train"]
for timestep in visualization_timesteps:
    snapshot_row = train_frame.iloc[[timestep]].to_numpy()
    plot_snapshot(snapshot_row.reshape(snapshot_shape), title=str(timestep))
# training
# The dictionary is a one-step pipeline holding a single Dense layer with ReLU
# activation and the 'relu' parameter sampler.
# NOTE(review): rng.integers(100, size=1) yields a length-1 array rather than
# a plain int. It is left unchanged so previously recorded results stay
# reproducible -- confirm that Dense is meant to receive an array here.
steps = [
    ("hidden",
     Dense(layer_width=layer_width,
           activation='relu',
           parameter_sampler='relu',
           random_seed=rng.integers(100, size=1))),
]
network_dictionary = Pipeline(steps=steps)

# The number of input features is the column count of the training data.
kirnn_model = KIRNN(dictionary=network_dictionary,
                    n_features_in=timeseries_data_dict['train'].shape[1],
                    rcond=regularization_constant,
                    compute_pseudospectrum=True)

# Time the fit with perf_counter: it is the clock intended for measuring
# elapsed intervals, whereas time.time() is wall-clock time and can jump when
# the system clock is adjusted.
start_time = time.perf_counter()
kirnn_model.fit(timeseries_data_dict['train'])
end_time = time.perf_counter()

elapsed_time = end_time - start_time  # seconds
print("Fit time: ", elapsed_time)

# --- Evaluation on the train set ---------------------------------------------
# Predict from the initial states of the training series over their own time
# values, then compare against the training data itself.
train_tsc = timeseries_data_dict['train']
pred = kirnn_model.predict(train_tsc.initial_states(),
                           time_values=train_tsc.time_values())

train_truth = train_tsc.to_numpy()
mse_error = mean_squared_error(train_truth, pred)
mae_error = mean_absolute_error(train_truth, pred)
print(f'MSE averaged over TRAIN trajectories: \t {mse_error}')
print(f'MAE averaged over TRAIN trajectories: \t {mae_error}')

# Per-cell plot of the series stored under ID 100 against the predictions,
# which are re-wrapped in a DataFrame carrying the training column labels.
train_reference_series = train_tsc.loc[100]
train_pred_frame = pd.DataFrame(pred, columns=train_tsc.columns)
plot_predictions_per_cell(
    train_reference_series,
    train_pred_frame,
    series_id=100,
    observable='density',
    visualization_timesteps=visualization_timesteps,
    measurement_cells=measurement_cells,
)

# --- Evaluation on the test set ----------------------------------------------
# Same procedure as for the train set, applied to the held-out split.
test_tsc = timeseries_data_dict['test']
pred = kirnn_model.predict(test_tsc.initial_states(),
                           time_values=test_tsc.time_values())

test_truth = test_tsc.to_numpy()
mse_error = mean_squared_error(test_truth, pred)
mae_error = mean_absolute_error(test_truth, pred)
print(f'MSE averaged over TEST trajectories: \t {mse_error}')
print(f'MAE averaged over TEST trajectories: \t {mae_error}')

# NOTE(review): the reference series is selected with ID 200 while
# series_id=100 is passed, unlike the train evaluation where both are 100.
# Left unchanged because plot_predictions_per_cell is not visible here --
# confirm whether series_id should be 200.
test_reference_series = test_tsc.loc[200]
test_pred_frame = pd.DataFrame(pred, columns=test_tsc.columns)
plot_predictions_per_cell(
    test_reference_series,
    test_pred_frame,
    series_id=100,
    observable='density',
    visualization_timesteps=visualization_timesteps,
    measurement_cells=measurement_cells,
    is_test=True,
)

# `observable` is the loop variable left over from the data-loading loop near
# the top of the script, i.e. the last observable that was iterated.
print(f"Visualizing predictions per cell for {observable} in {measurement_cells}")

# --- Spectral plots of the fitted model --------------------------------------
# Eigenvalue plot, requested together with the pseudospectrum.
plot_eigenvalues(model=kirnn_model, compute_pseudospectrum=True)

# Mode plot for indices 1 through 5 on a 110 x 80 grid, using the training
# data and no prediction time shift.
mode_plot_options = {
    'model': kirnn_model,
    'tsc': timeseries_data_dict['train'],
    'observable_type': 'density',
    'prediction_timeshift': 0,
    'grid_shape': [110, 80],
    'selection': [1, 2, 3, 4, 5],
}
plot_modes(**mode_plot_options)

# seed=1:
# Fit time:  1.615339994430542
# MSE averaged over TEST trajectories: 	 0.00022947441731416263
# MAE averaged over TEST trajectories: 	 0.008171400693440545
# seed=2:
# Fit time:  1.0019643306732178
# MSE averaged over TEST trajectories: 	 0.00012572218330904496
# MAE averaged over TEST trajectories: 	 0.006034654066810712
# seed=3:
# Fit time:  0.7973501682281494
# MSE averaged over TEST trajectories: 	 0.000367444530573591
# MAE averaged over TEST trajectories: 	 0.010435799492979717
# seed=4:
# Fit time:  1.4452214241027832
# MSE averaged over TEST trajectories: 	 0.00025102934008082336
# MAE averaged over TEST trajectories: 	 0.008441646958970405
# seed=5:
# Fit time:  1.5389597415924072
# MSE averaged over TEST trajectories: 	 7.561189583707855e-05
# MAE averaged over TEST trajectories: 	 0.0047072652258822265
#
# Average fit time: 1.28 s
# Average MSE (TEST): 0.0002098564734229401, max: 0.000367444530573591, min: 7.561189583707855e-05
# Average MAE (TEST): 0.007558153287616722, max: 0.010435799492979717, min: 0.0047072652258822265

