"""
Training script for the finite difference parameter fitting. It relies on the pysindy library.
"""
from collections import defaultdict
import datetime
import os
import pickle

import numpy as np

import pysindy as ps
import poisson_solutions

# We need a custom optimizer and metric that trim off the grid points next to the domain boundary.
# The finite difference stencils at the borders introduce extra error that cannot be avoided.
# So, we train and evaluate only on the "interior" of the domain.
# To actually solve a system, you would need a way to handle the BCs.
def trim_border(X, n):
    """Drop ``n`` entries from each end of the leading axis of ``X``.

    Parameters
    ----------
    X : array-like exposing an ``axes`` attribute
        Data to trim; it is copied into a plain ``numpy`` array first.
    n : int
        Number of entries removed from the start and from the end of
        axis 0.  ``n == 0`` returns the data untrimmed.

    Returns
    -------
    ps.AxesArray
        The trimmed data, re-wrapped with the ``axes`` of ``X``.
    """
    values = np.array(X)
    # ``values[0:-0]`` is ``values[0:0]`` (empty), so "trim nothing" has to be
    # handled explicitly rather than through the negative stop index.
    interior = values if n == 0 else values[n:-n, ...]
    return ps.AxesArray(interior, X.axes)

def reshape_for_pysindy(x):
    """Transpose ``x`` and fold it into a 3-D array with a trailing unit axis.

    The last axis of ``x`` becomes the leading axis of the result, all
    remaining entries are flattened into the middle axis, and a final axis
    of length one is appended.  For a 2-D input of shape ``(a, b)`` the
    result has shape ``(b, a, 1)``.
    """
    leading_size = x.shape[-1]
    flipped = x.T
    return flipped.reshape(leading_size, -1, 1)


def my_metric(model, data):
    """Mean squared residual of ``model`` on ``data``, ignoring border entries.

    ``data['u']`` is reshaped with ``reshape_for_pysindy`` and fed to
    ``model.predict``; the prediction is compared against the reshaped
    ``data['f']``.  The first two and last two entries along the leading
    axis are excluded from both before the squared difference is averaged.
    """
    predicted = model.predict(reshape_for_pysindy(data['u']))
    target = reshape_for_pysindy(data['f'])
    # Same two-entry margin on each side for prediction and target.
    interior = slice(2, -2)
    residual = predicted[interior, ...] - target[interior, ...]
    return (residual ** 2).mean()

class MySTLSQ(ps.STLSQ):
    """``ps.STLSQ`` that discards border rows before running the parent fit.

    The flattened sample axis of the inputs is folded back into
    ``(n_grid, -1, n_columns)``, the first two and last two entries along
    the leading axis are removed, and the remainder is flattened again
    before being handed to ``ps.STLSQ.fit``.

    Parameters
    ----------
    n_grid : int, keyword-only, default 22
        Length of the leading axis used when folding the rows back.
    *args, **kwargs
        Forwarded unchanged to ``ps.STLSQ.__init__``.
    """
    # NOTE(review): if ps.STLSQ follows the scikit-learn estimator protocol,
    # the ``*args`` in ``__init__`` may break ``get_params``/``clone`` --
    # confirm before using this class with those utilities.

    def __init__(self, *args, n_grid: int = 22, **kwargs):
        self.n_grid = n_grid
        super().__init__(*args, **kwargs)

    def _interior_rows(self, samples):
        """Return ``samples`` without the rows of the two outermost entries on each side."""
        values = np.array(samples)
        n_columns = values.shape[-1]
        # Presumably rows are ordered with the grid index varying slowest, so
        # this reshape recovers the grid axis -- TODO confirm against caller.
        folded = values.reshape(self.n_grid, -1, n_columns)
        interior = folded[2:-2]
        return ps.AxesArray(interior.reshape(-1, n_columns), samples.axes)

    def fit(self, _X, _y):
        """Fit on the interior rows of ``_X`` and ``_y``; returns what ``ps.STLSQ.fit`` returns."""
        return super().fit(self._interior_rows(_X), self._interior_rows(_y))


# Driver: for every training parameter in the N_grid dataset, fit a SINDy model
# with a finite-difference PDE library, score it on every dataset of that grid
# with my_metric, and pickle the collected errors and coefficients.
train_data = poisson_solutions.create_dataset_dict()

sweep = defaultdict(dict)
N_grid = 22
N_seeds = 5
# Pass the grid size explicitly so the optimizer's reshape follows N_grid
# instead of relying on MySTLSQ's default happening to match it.
sparse_opt = MySTLSQ(n_grid=N_grid)
opt = sparse_opt
all_data = train_data[N_grid]
for p_train, dataset in all_data.items():
    # NOTE(review): `seed` is not passed to anything below, so unless the fit
    # is randomized internally each pass repeats the same work and overwrites
    # the same sweep entry -- confirm whether seeding was intended.
    for seed in range(N_seeds):
        print(f"Training FD for p={p_train}")
        flib = ps.PDELibrary(
            function_library=ps.PolynomialLibrary(
                degree=1, include_bias=False),
            derivative_order=2,
            # The first row of dataset['x'] is used as the spatial grid.
            spatial_grid=dataset['x'][0, :],
            diff_kwargs={"is_uniform": True,
                         "periodic": False,
                         "drop_endpoints": False,
                         "order": 2})
        model = ps.SINDy(feature_library=flib,
                         optimizer=opt)
        model.fit(reshape_for_pysindy(dataset['u']),
                  x_dot=reshape_for_pysindy(dataset['f']))
        # Score the freshly trained model on every dataset of this grid,
        # including the one it was trained on.
        eval_error = {
            p: my_metric(model, test_data)
            for p, test_data in all_data.items()
        }
        sweep['sparse'][p_train] = eval_error, model.coefficients()


save_path = '.'
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = os.path.join(save_path, f'pysindy_poisson_multiple_seed_runs_{timestamp}.pkl')
with open(filename, 'wb') as f:
    pickle.dump(sweep, f)
# len(sweep) only counts the top-level keys (just 'sparse'); report the number
# of stored results instead.
n_saved = sum(len(runs) for runs in sweep.values())
print(f"Saved {n_saved} runs to {filename}")
