# This script runs the SVD-layer ablation study for frozen-pinn-elm on the Burgers equation.

import sys
sys.path.append('../../../../')
sys.path.append('../../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import BurgersSolver
import numpy as np
from sklearn.metrics import mean_squared_error
import scipy.io
from matplotlib import ticker
import time
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'jet'

# Load the reference solution (time grid, space grid and u on that grid)
data = scipy.io.loadmat('../../../../data/burgers_shock.mat')
t_eval = data['t'].flatten().reshape(-1, 1)
x_eval = data['x'].flatten().reshape(-1, 1)
u_exact = np.real(data['usol']).T
X, T = np.meshgrid(x_eval, t_eval)
X_ = np.column_stack((X.flatten(), T.flatten()))

# Set ground truth: the reference solution as one column vector
u_true = u_exact.flatten().reshape(-1, 1)

# initial condition
def u0(x):
    """Initial condition u(x, 0) = -sin(pi * x)."""
    return -np.sin(np.pi * x)

# forcing
def forcing(x, t):
    """Source term of the PDE: identically zero, one value per row of ``x``.

    ``t`` is accepted but not used.
    """
    n_rows = x.shape[0]
    return np.zeros(n_rows)

# boundary condition
boundary_condition = "zero dirichlet"

# Domain information and spatial points for the first time-block
n_points_1d = 4000 # No. of points in space
x_lim = [-1, 1] # Domain range

# Interior points: uniform random samples in x_lim (not sorted)
rng = np.random.default_rng(seed=123)
x_space = rng.uniform(x_lim[0], x_lim[1], n_points_1d).reshape((-1, 1))
# NOTE(review): x_space is unsorted, so [1:-1] drops two arbitrary samples rather
# than the interval endpoints (which uniform sampling essentially never hits).
# Kept unchanged so results stay reproducible.
x_space_inner = x_space[1:-1]
interior_points = x_space_inner

# Boundary points (excluding corners): the two ends of the 1D interval, shape (2, 1)
left = x_lim[0]
right = x_lim[1]
# np.vstack instead of np.row_stack, which is deprecated since NumPy 2.0
boundary_points = np.vstack([left, right])

# Hyper-parameters
n_sample = 6000 # No. of sampling points (for computing gradients)
n_col = 3000 # No. of collocation points (to be re-sampled)
width = 2000 # Width (number of hidden neurons of the ansatz)
reg_const = 1e-7 # Regularization constant
svd_cutoff = 1e-10 # SVD threshold
seeds = [1, 2, 3] # Seeds (to compute mean errors)
time_blocks = 1 # Number of time-blocks for sampling
info = [] # List to store errors and time measurements

# Compute prob. distribution for (re)-sampling collocation points
def collocation_points_probabilities(df_dx):
    """Turn gradient values into a probability distribution over points.

    Points with a larger ``|df_dx|`` get a higher probability. A floor of 5% of
    the largest gradient magnitude is added to every point, so regions with a
    small gradient still have a non-zero chance of being sampled.

    Parameters
    ----------
    df_dx : array_like
        Gradient values at the candidate points.

    Returns
    -------
    numpy.ndarray
        Non-negative weights with the same shape as ``df_dx`` that sum to one.
        If every gradient is zero, a uniform distribution is returned. The
        naive formula would divide by zero and produce NaNs.
    """
    gradients = np.abs(df_dx)
    gradients = gradients + 0.05 * np.max(gradients)
    total = np.sum(gradients)
    if total == 0:
        # Flat profile: the 5% floor is also zero, so fall back to uniform weights.
        return np.full(gradients.shape, 1.0 / gradients.size)
    return gradients / total

# Points at which the gradient of the solution is computed at the end of a
# time-block. The 1e-4 margin keeps them strictly inside x_lim, so no
# boundary points are included.
sample_test_points = rng.uniform(x_lim[0] + 1e-4, x_lim[1] - 1e-4, n_sample)
sample_test_points = np.sort(sample_test_points).reshape((-1, 1))

# Spatial domain: interior, boundary and gradient-sampling points
domain = Domain(
    interior_points=interior_points,
    boundary_points=boundary_points,
    sample_points=sample_test_points,
)

# Parameter sampler for sampling of weights and biases using data-agnostic distribution
def parameter_sampler_uniform(x, y, rng):
    """Sample hidden-layer weights and biases without looking at the data.

    Despite the name, only the biases are uniform. The weights come from a
    standard normal distribution and the biases from U(-3, 3). ``x`` and ``y``
    are not used here; presumably they are part of the sampler signature that
    BasicAnsatz expects (TODO confirm). The layer size is read from the
    module-level ``width``.

    Returns
    -------
    weights : ndarray of shape (1, width)
    biases : ndarray of shape (1, width)
    idx_from, idx_to : ndarray of shape (width,)
        Each is ``np.arange(width)``, returned as two separate arrays.
    """
    bias_low, bias_high = -3, 3
    # Draw order matters for reproducibility: weights first, then biases.
    weights = rng.normal(size=(1, width))
    biases = rng.uniform(bias_low, bias_high, size=(1, width))
    return weights, biases, np.arange(width), np.arange(width)

# Per-seed training time, RMSE and relative L2 error
rel_err_elm = np.ones((len(seeds), ))
time_elm = np.ones((len(seeds), ))
rmse_elm = np.ones((len(seeds), ))

# SVD layer switched off for this run; the same flag is used for fit and evaluation.
svd_on = False

# Loop over different seeds
for j, seed in enumerate(seeds):
    # frozen-pinn-elm network ansatz
    ansatz_elm = BasicAnsatz(
        n_neurons=width,
        activation="tanh",
        random_state=seed,
        regularization_scale=reg_const,
        parameter_sampler=parameter_sampler_uniform,
    )
    # Burgers solver
    burgers_solver_elm = BurgersSolver(
        domain=domain,
        ansatz=ansatz_elm,
        u0=u0,
        boundary_condition=boundary_condition,
        forcing=forcing,
        regularization_scale=reg_const,
        c=(0.01/np.pi),
    )
    # ELM fit; perf_counter is a monotonic clock intended for measuring durations
    t_elm_start = time.perf_counter()
    sol_elm, solver_status_elm = burgers_solver_elm.fit_time_blocks(
        t_span=[0, np.max(t_eval)],
        rtol=1e-8,
        atol=1e-8,
        svd_cutoff=svd_cutoff,
        time_blocks=time_blocks,
        prob_distr_resampling=collocation_points_probabilities,
        n_col=n_col,
        outer_basis=False,
        svd_on=svd_on,
    )
    t_elm_stop = time.perf_counter()
    time_elm[j] = t_elm_stop - t_elm_start
    # Evaluate ELM-ODE predictions on the reference grid (x_eval, t_eval)
    u_elm = (burgers_solver_elm.evaluate_blocks(x_eval=x_eval, t_eval=t_eval, time_blocks=time_blocks, solver_status=solver_status_elm, svd_on=svd_on)).T
    # Prediction as a column vector, laid out like u_true
    u_pred = u_elm.flatten()[:, None]
    # Compute metrics
    mse_elm = mean_squared_error(u_true, u_pred)  # mean squared error
    rmse_elm[j] = np.sqrt(mse_elm)  # Root Mean Squared Error
    rel_err_elm[j] = np.linalg.norm(u_true - u_pred, 2) / np.linalg.norm(u_true, 2)
    # Report this seed's errors: one value per header column
    print("rmse_elm, re_elm")
    print(rmse_elm[j], rel_err_elm[j])
    if svd_on:
        print('Last layer width (after SVD): ', sol_elm._V_a.shape[0])

# Aggregate over seeds, in this order:
# [mean time, mean RMSE, std RMSE, mean rel. error, std rel. error]
info.extend([
    np.mean(time_elm),
    np.mean(rmse_elm),
    np.std(rmse_elm),
    np.mean(rel_err_elm),
    np.std(rel_err_elm),
])
print(info)


# Print errors and time measurements for Burgers with frozen-pinn-elm
res = np.vstack(info).reshape(-1)

# Burgers time measurements (the last five entries of `info`:
# mean time, mean/std RMSE, mean/std relative error over the seeds)
print('Burgers equation: Errors and time measurements using frozen-pinn-elm')
print('training time for frozen-pinn-elm = ', res[-5])
print('rmse frozen-pinn-elm = ', res[-4], '+-', res[-3])
print('rel error frozen-pinn-elm = ', res[-2], '+-', res[-1])


# True and model solutions. NOTE: u_elm holds the prediction of the last seed
# only, so this error refers to that single run.
error_u = np.linalg.norm(u_true - u_elm.flatten()[:, None], 2) / np.linalg.norm(u_true, 2)
print('Relative L2 error on the entire spatio-temporal domain: %e' % (error_u))

# visualize the solution
fig, ax = plt.subplots(1, 3, figsize=(5, 4), constrained_layout=True)
fontsize = 14
# The images are sampled on the evaluation grid, so the spatial extent is taken
# from x_eval rather than from the random training points in x_space.
extent = [0, 1, np.min(x_eval), np.max(x_eval)]
aspect = 0.3
sol_img1 = ax[0].imshow(u_exact.T, extent=extent, origin='lower', aspect=aspect)
sol_img2 = ax[1].imshow(u_elm.T, extent=extent, origin='lower', aspect=aspect)
error_img = ax[2].imshow(np.abs(u_elm - u_exact).T, extent=extent, origin='lower', aspect=aspect)
# Dashed lines at the snapshot times shown in the second figure
for t_line in (0.25, 0.5, 0.75):
    ax[0].axvline(x=t_line, color='k', linestyle='--', linewidth=2)
# NOTE(review): these dotted lines look like boundaries of three time-blocks,
# but time_blocks = 1 in this script -- confirm they are still wanted.
for t_line in (0.33, 0.66):
    ax[1].axvline(x=t_line, color='gray', linestyle='dotted', linewidth=3)
cbar_true = fig.colorbar(sol_img1, ax=ax[0], location='bottom')
cbar_elm = fig.colorbar(sol_img2, ax=ax[1], location='bottom')
cbar_err = fig.colorbar(error_img, ax=ax[2], location='bottom', format='%.0e', fraction=0.046)

# Keep the colorbars readable: at most a couple of ticks each
for cbar in (cbar_err, cbar_elm, cbar_true):
    cbar.locator = ticker.MaxNLocator(nbins=2)
    cbar.update_ticks()

ax[0].set_title('Ground truth')
ax[1].set_title('ELM-ODE')
ax[2].set_title('Absolute error')
fig.savefig("burgers_elm_ode_1.pdf")

# Solution snapshots at three time indices. No constrained_layout here: the
# tight_layout() call below would replace it anyway, with a warning.
fig, ax = plt.subplots(1, 3, figsize=(6, 3))
for panel, t_idx in zip(ax, (25, 50, 75)):
    panel.plot(x_eval, u_exact[t_idx, :], 'b-', linewidth=2, label='Ground truth')
    panel.plot(x_eval, u_elm[t_idx, :], 'r--', linewidth=2, label='ELM-ODE (resampling)')
    panel.set_xlabel('$x$', fontsize=fontsize)
    panel.set_ylabel('$u(t,x)$', fontsize=fontsize)
    # Title shows the actual time of the plotted row rather than a hard-coded value
    panel.set_title(f'$t = {t_eval[t_idx, 0]:.2f}$', fontsize=fontsize)
    panel.axis('square')
    panel.set_xlim([-1.1, 1.1])
    panel.set_ylim([-1.1, 1.1])

# Create a single legend for all plots
fig.legend(*ax[1].get_legend_handles_labels(), loc='upper center', ncol=2, fontsize=12, frameon=False)
fig.tight_layout()
fig.savefig("burgers_elm_ode_2.pdf", bbox_inches='tight')
    


