# This script contains the ablation study for the SVD layer of frozen-PINN-swim. 
# In this script, we DO use the SVD layer, to study its effect in comparison to the case when we do not use an SVD layer.

# Imports
import sys
sys.path.append('../../')
sys.path.append('../../src')
from sklearn.preprocessing import MinMaxScaler
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Nonlinear_Diffusion_Solver
import numpy as np
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
np.random.seed(2)
import time
cmap = cm.jet

# Every feature is min-max scaled into [0.65, 0.9]. This keeps x[:, 1]
# strictly positive on the training points, which matters because u0,
# forcing and analytical_sol raise x[:, 1] to negative powers.
scaler = MinMaxScaler(feature_range=(0.65, 0.9))

# Load the training point clouds and fit ONE scaler on boundary + interior
# points together, so both sets are mapped by the same affine transform.
data_boundary = np.load('../../../data/nature_inspired_dataset/nature_inspired_boundary_1000.npy')
data_interior = np.load('../../../data/nature_inspired_dataset/nature_inspired_interior_1000.npy')
n_boundary = data_boundary.shape[0]
data_full = np.vstack((data_boundary, data_interior))
data_new = scaler.fit_transform(data_full)
# Split the scaled stack back into its boundary (first rows) and interior parts.
data_boundary, data_interior = data_new[:n_boundary, :], data_new[n_boundary:, :]

# Load the evaluation points and map them with the SAME scaler that was fitted
# on the training data. Re-fitting here (fit_transform) would rescale the
# evaluation cloud by its own min/max. The evaluation coordinates would then
# not match the training geometry whenever the two clouds have different
# bounding boxes.
# NOTE(review): evaluation points outside the training bounding box now map
# slightly outside [0.65, 0.9]. This is still fine for x[:, 1]**-3 as long as
# the mapped values stay positive.
data_eval = np.load('../../../data/nature_inspired_dataset/nature_inspired_eval.npy')
x_eval = scaler.transform(data_eval)

# initial condition
def u0(x):
    return np.sin(np.pi * x[:, 0]) * x[:, 1]**-3  

# forcing
def forcing(x, t):
    """Source term of the manufactured problem, evaluated at points ``x``, time ``t``.

    Algebraically, with u = analytical_sol(x, t) = exp(-t) sin(pi x0) x1**-3,
    this equals du/dt - u * Laplacian(u) = -u + u**2 * (pi**2 - 12 / x1**2).
    NOTE(review): that presumably matches a PDE of the form u_t = u*Lap(u) + f
    inside Nonlinear_Diffusion_Solver; confirm against the solver.

    x: 2-D array whose columns 0 and 1 are the spatial coordinates.
    t: scalar or array broadcastable against the per-point arrays.
    """
    decay = np.exp(-t)
    sin_term = np.sin(np.pi * x[:, 0])
    y = x[:, 1]
    # Same left-to-right products as u = exp(-t) * sin(pi x0) * x1**-3.
    amplitude = decay * sin_term * y**(-3)
    # -1 - (-1)*z is written as -1 + z; negation is exact in floating point.
    correction = -1. + decay * y**(-5) * sin_term * (-12 + np.pi**2 * y**2)
    return amplitude * correction

# Boundary condition type handed to the solver. The boundary values
# themselves come from analytical_sol (passed as boundary_condition_true).
boundary_condition = "dirichlet"

# Analytical solution
def analytical_sol(x, t):
    """Exact solution u(x, t) = exp(-t) * sin(pi * x0) * x1**-3.

    x: 2-D array whose columns 0 and 1 are the spatial coordinates.
    t: scalar or array; NumPy broadcasting combines it with the per-point
       arrays (e.g. t of shape (T, 1, 1) gives a (T, 1, n_points) result).
    """
    decay = np.exp(-t)
    profile_x0 = np.sin(np.pi * x[:, 0])
    return decay * profile_x0 * x[:, 1]**-3
    
# Evaluation grid: 100 equally spaced times in [0, 1], shaped (100, 1, 1) so
# that analytical_sol broadcasts it against the 1-D per-point arrays and
# returns an array of shape (100, 1, n_eval_points).
t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)
x_train = data_interior  # scaled interior training points (not referenced again below)
u_true = analytical_sol(x_eval, t_eval)
# Drop the singleton middle axis -> (n_times, n_eval_points).
n_times, _, n_points = np.shape(u_true)
u_true = u_true.reshape((n_times, n_points))

# Time indices intended for plotting the true solution (not used below).
timesteps = [0, 30, 60, 99]

# Experiment configuration.
seeds = [1, 2, 3]   # random_state values for BasicAnsatz, one run per seed
experiments = []
width = 500         # n_neurons of the SWIM ansatz
reg_const = 1e-15   # regularization_scale for ansatz and solver
svd_cutoff = 1e-15  # svd_cutoff passed to solver.fit
rtol = 1e-6         # relative tolerance passed to solver.fit
atol = 1e-6         # absolute tolerance passed to solver.fit

# Per-seed result buffers, filled inside the seed loop.
n_seeds = len(seeds)
rmse_swim = np.ones(n_seeds)
rel_err_swim = np.ones(n_seeds)
time_swim = np.ones(n_seeds)
j = 0
info = []
# Seed loop: fit the SWIM nonlinear-diffusion solver once per seed and record
# the fit time, RMSE and relative L2 error on the evaluation grid.
# This is the "with SVD" arm of the ablation, so the flag is fixed to True.
svd_on = True
# RMS magnitude of the reference solution (RMSE against zero). It does not
# depend on the seed, so the relative-error denominator is computed once.
u_true_rms = np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))
for j, seed in enumerate(seeds):
    ansatz_swim = BasicAnsatz(
        n_neurons=width,
        activation="tanh",
        random_state=seed,
        regularization_scale=reg_const,
        parameter_sampler="tanh",
    )
    # NOTE(review): the boundary point coordinates themselves are passed as
    # "normal vectors". This looks like a placeholder (presumably unused for
    # Dirichlet conditions); confirm before switching boundary types.
    normal_vectors = data_boundary.copy()

    # Rebuilt for every seed so each run starts from a fresh Domain object.
    domain = Domain(
        interior_points=data_interior,
        boundary_points=data_boundary,
        normal_vectors=normal_vectors,
    )

    nonlinear_diffusion_solver_swim = Nonlinear_Diffusion_Solver(
        domain=domain,
        ansatz=ansatz_swim,
        u0=u0,
        boundary_condition=boundary_condition,
        forcing=forcing,
        regularization_scale=reg_const,
        scale_boundary_correction=1.,
        boundary_condition_true=analytical_sol,
    )

    # Fit over the whole evaluation time span; only the fit call is timed.
    # perf_counter is a monotonic clock intended for measuring intervals.
    ic_eval = u0(domain.interior_points)
    t_swim_start = time.perf_counter()
    sol_swim, solver_status_swim = nonlinear_diffusion_solver_swim.fit(
        t_span=[0, np.max(t_eval)],
        rtol=rtol,
        atol=atol,
        svd_cutoff=svd_cutoff,
        outer_basis=False,
        init_cond=ic_eval,
        svd_on=svd_on,
    )
    time_swim[j] = time.perf_counter() - t_swim_start

    # Evaluate on the test grid. The result is transposed so it can be
    # compared element-wise with u_true.
    u_swim_test = nonlinear_diffusion_solver_swim.evaluate(
        x_eval=x_eval, t_eval=t_eval, svd_on=svd_on
    ).T

    # Metrics: RMSE, and RMSE relative to the RMS of the true solution.
    rmse_swim[j] = np.sqrt(mean_squared_error(u_true, u_swim_test))
    rel_err_swim[j] = rmse_swim[j] / u_true_rms
    info.extend((time_swim[j], rmse_swim[j]))
    print('time=', time_swim[j], 'rmse_swim=', rmse_swim[j], 'rel_err_swim=', rel_err_swim[j])
    if svd_on:
        # _V_a is a private attribute of the fitted solution. Its first
        # dimension is reported as the last-layer width after the SVD step.
        print('Last layer width (after SVD): ', sol_swim._V_a.shape[0])

# Aggregate over seeds: mean fit time, and mean +- std of RMSE and of the
# relative L2 error.
print('swim-ode time = ', np.mean(time_swim))
rmse_mean, rmse_std = np.mean(rmse_swim), np.std(rmse_swim)
print('rmse swim-ode = ', rmse_mean, '+-', rmse_std)
rel_mean, rel_std = np.mean(rel_err_swim), np.std(rel_err_swim)
print('rel l-2 error swim-ode = ', rel_mean, '+-', rel_std)
# Keep this configuration's flat [time, rmse, time, rmse, ...] record.
experiments.append(info)
