# %% [markdown]
# # Solving a nonlinear diffusion equation with ELM-ODE and SWIM-ODE

# %%
# Imports
import sys
sys.path.append('../')
sys.path.append('../../../')
sys.path.append('../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples.utils import *
from mpl_toolkits.mplot3d import Axes3D
#from exautils import *

from contextlib import redirect_stdout



# %% [markdown]
# ### Problem Setup

# %%
# Hyper-parameters of the study.
ratio = 5  # training points per hidden neuron (n_points / width)

# Sweep grids: every combination below is run in the experiment cell.
# Alternative settings: reg_svd in {1e-6, 1e-8, 1e-10}, widths in
# {500, 1000, 2000, 4000}, dimensions = [5], rms = [0.05], reg_const = 1e-12.
reg_svd_constants = [1e-12]  # SVD cutoff, also reused as the regularization scale
dimensions = [7]             # spatial dimensions d
widths = [4000]              # hidden-layer widths
rms = [0.05, 0.005, 0.1]     # half-range r_m of the uniform ELM weight/bias sampler

# Tolerances forwarded to the ODE solver.
rtol = 1e-7  # relative tolerance
atol = 1e-7  # absolute tolerance

# Output files: raw results (.npy) and the captured stdout log (.txt).
filename = "data_svd/dim_7_4_svd.npy"
filename_text = "data_svd/dim_7_4_svd.txt"

# Initial condition u(x, 0) = cos(sum_i x_i / d)
def u0(x):
    """Return cos of the mean coordinate for every row of ``x``.

    ``x`` is indexed as (n_points, d); the result has one value per point.
    """
    n_dims = np.shape(x)[1]
    coord_sum = np.sum(x, axis=1)
    return np.cos(coord_sum / n_dims)

# Forcing term f(x, t) = exp(-t) * cos(sum_i x_i / d) * (1/d - 1)
def forcing(x, t):
    """Return the source term at points ``x`` (n_points, d) and time(s) ``t``.

    ``t`` may be a scalar or an array that broadcasts against the per-point
    values. For ``analytical_sol`` this equals u_t - Laplacian(u), i.e. it is
    consistent with u_t = Laplacian(u) + f (assuming that is the solver's
    model -- confirm in Diffusion_Solver).
    """
    inv_d = 1. / np.shape(x)[1]
    spatial_profile = np.cos(np.sum(x, axis=1) * inv_d)
    return np.exp(-t) * spatial_profile * (inv_d - 1)

# Boundary-condition type handed to Diffusion_Solver; the boundary values
# themselves come from `boundary_condition_true=analytical_sol`.
boundary_condition: str = "dirichlet"

# Analytical solution u(x, t) = exp(-t) * cos(sum_i x_i / d)
def analytical_sol(x, t):
    """Return the exact solution at points ``x`` (n_points, d) and time(s) ``t``.

    With ``t`` of shape (n_t, 1, 1) the result broadcasts to (n_t, 1, n_points).
    """
    n_dims = np.shape(x)[1]
    time_decay = np.exp(-t)
    return time_decay * np.cos(np.sum(x, axis=1) / n_dims)

# Test-set sizes (a fresh test set is sampled for every experiment).
n_int_test = 8000  # points inside the domain
n_b_test = 2000    # points on the boundary

# %% [markdown]
# ### Run experiments: fit SWIM and ELM solvers for every hyper-parameter combination

# %%
import os

# Layout of every row appended to `experiments` (and therefore of the saved .npy):
#   0 d | 1 width | 2 time_swim | 3 rmse_swim | 4 rel_err_swim |
#   5 time_elm | 6 rmse_elm | 7 rel_err_elm | 8 reg_svd | 9 r_m |
#   10 outer width ELM | 11 outer width SWIM   (10 and 11 only when svd_on)


def _rmse_and_rel_err(u_ref, u_pred):
    """Return ``(rmse, rel_err)`` of ``u_pred`` against the reference ``u_ref``.

    ``rel_err`` is the RMSE divided by the RMS of ``u_ref``; for 2-D arrays of
    equal shape this is the Frobenius-norm ratio ||u_pred - u_ref|| / ||u_ref||.
    """
    rmse = np.sqrt(mean_squared_error(u_ref, u_pred))
    rel_err = rmse / np.sqrt(mean_squared_error(u_ref, np.zeros_like(u_ref)))
    return rmse, rel_err


# Both output files live in a sub-directory; create it so open()/np.save do not fail.
for _out_path in (filename_text, filename):
    os.makedirs(os.path.dirname(_out_path) or ".", exist_ok=True)

with open(filename_text, 'w', encoding="utf-8") as f:
    # Everything printed inside this block (including any solver output) goes to the log file.
    with redirect_stdout(f):
        experiments = []
        for d in dimensions:
            for w in widths:
                for reg_svd in reg_svd_constants:
                    for r_m in rms:
                        # Re-seed NumPy's global RNG for every configuration; presumably this is
                        # what makes the LHS point samplers reproducible (TODO confirm in examples.utils).
                        np.random.seed(2)
                        svd_on = True
                        seed = 1  # random_state for both ansatzes

                        # One constant drives both the least-squares regularization and the SVD cutoff.
                        reg_const = reg_svd
                        svd_cutoff = reg_svd

                        # Train: `ratio` points per neuron, 20 % on the boundary and 80 % inside.
                        n_training = ratio * w
                        n_b_train = n_training // 5
                        n_int_train = n_training - n_b_train
                        X_b_train, boundary_labels = sample_boundary_lhs(d, n_b_train, bounds=(-1, 1))
                        X_int_train = sample_interior_lhs(d, n_int_train, bounds=(-1, 1))

                        # Test: a separately sampled set of boundary and interior points.
                        # (Sampling order is kept: the global RNG is consumed in this sequence.)
                        X_b_test, boundary_labels_test = sample_boundary_lhs(d, n_b_test, bounds=(-1, 1))
                        X_int_test = sample_interior_lhs(d, n_int_test, bounds=(-1, 1))
                        X_test = np.vstack((X_int_test, X_b_test))  # interior + boundary points
                        X_train = np.vstack((X_int_train, X_b_train))
                        t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)  # 100 time points in [0, 1]

                        # Ground truth: analytical_sol broadcasts to (n_t, 1, n_points); drop the middle axis.
                        u_true = analytical_sol(X_test, t_eval)
                        u_true = np.reshape(u_true, (np.shape(u_true)[0], np.shape(u_true)[2]))
                        u_true_train = analytical_sol(X_train, t_eval)
                        u_true_train = np.reshape(u_true_train, (np.shape(u_true_train)[0], np.shape(u_true_train)[2]))

                        # ELM parameter sampler: weights and biases drawn i.i.d. uniformly from
                        # [-r_m, r_m]. The (x, _, rng) signature and the two trailing None values
                        # follow what BasicAnsatz presumably expects from a custom sampler.
                        def sample_parameters_randomly(x, _, rng):
                            weights = rng.uniform(low=-1. * r_m, high=r_m, size=(x.shape[1], w))
                            biases = rng.uniform(low=-1. * r_m, high=r_m, size=(1, w))
                            idx0 = None
                            idx1 = None
                            return weights, biases, idx0, idx1

                        ansatz_swim = BasicAnsatz(
                            n_neurons=w,
                            activation="tanh",
                            random_state=seed,
                            regularization_scale=reg_const,
                            parameter_sampler="tanh",
                        )
                        ansatz_elm = BasicAnsatz(
                            n_neurons=w,
                            activation="tanh",
                            random_state=seed,
                            regularization_scale=reg_const,
                            parameter_sampler=sample_parameters_randomly,
                        )

                        # The boundary points themselves are passed as normal vectors.
                        # NOTE(review): on the boundary of [-1, 1]^d the outward normal is not x itself;
                        # presumably unused for Dirichlet conditions -- confirm in Domain/Diffusion_Solver.
                        normal_vectors = X_b_train.copy()
                        domain = Domain(
                            interior_points=X_int_train,
                            boundary_points=X_b_train,
                            normal_vectors=normal_vectors,
                        )
                        diffusion_solver_swim = Diffusion_Solver(
                            domain=domain,
                            ansatz=ansatz_swim,
                            u0=u0,
                            boundary_condition=boundary_condition,
                            forcing=forcing,
                            regularization_scale=reg_const,
                            scale_boundary_correction=1.,
                            boundary_condition_true=analytical_sol,
                        )
                        diffusion_solver_elm = Diffusion_Solver(
                            domain=domain,
                            ansatz=ansatz_elm,
                            u0=u0,
                            boundary_condition=boundary_condition,
                            forcing=forcing,
                            regularization_scale=reg_const,
                            scale_boundary_correction=1.,
                            boundary_condition_true=analytical_sol,
                        )

                        # Initial condition evaluated on the interior training points.
                        ic_eval = u0(domain.interior_points)

                        # SWIM: time only the fit, then evaluate on test and training points.
                        t_start = time.perf_counter()
                        sol_swim, solver_status_swim = diffusion_solver_swim.fit(
                            t_span=[0, np.max(t_eval)],
                            rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                            outer_basis=False,
                            init_cond=ic_eval,
                            svd_on=svd_on,
                        )
                        time_swim = time.perf_counter() - t_start
                        u_swim_test = diffusion_solver_swim.evaluate(x_eval=X_test, t_eval=t_eval, svd_on=svd_on).T
                        u_swim_train = diffusion_solver_swim.evaluate(x_eval=X_train, t_eval=t_eval, svd_on=svd_on).T
                        rmse_swim, rel_err_swim = _rmse_and_rel_err(u_true, u_swim_test)
                        rmse_swim_train, rel_err_swim_train = _rmse_and_rel_err(u_true_train, u_swim_train)

                        # ELM: same protocol as SWIM.
                        t_start = time.perf_counter()
                        sol_elm, solver_status_elm = diffusion_solver_elm.fit(
                            t_span=[0, np.max(t_eval)],
                            rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                            outer_basis=False,
                            init_cond=ic_eval,
                            svd_on=svd_on,
                        )
                        time_elm = time.perf_counter() - t_start
                        u_elm_test = diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval, svd_on=svd_on).T
                        u_elm_train = diffusion_solver_elm.evaluate(x_eval=X_train, t_eval=t_eval, svd_on=svd_on).T
                        rmse_elm, rel_err_elm = _rmse_and_rel_err(u_true, u_elm_test)
                        rmse_elm_train, rel_err_elm_train = _rmse_and_rel_err(u_true_train, u_elm_train)

                        # One result row; see the column layout at the top of this cell.
                        info = [d, w, time_swim, rmse_swim, rel_err_swim,
                                time_elm, rmse_elm, rel_err_elm, reg_svd, r_m]
                        if svd_on:
                            # Outer widths (rows of _V_a): ELM first, then SWIM.
                            info += [sol_elm._V_a.shape[0], sol_swim._V_a.shape[0]]

                        print('-----------------------------------------------------------')
                        print('dim = ', d, 'width = ', w, 'reg_svd = ', reg_svd, 'r_m', r_m)
                        if svd_on:
                            print('w_outer_swim = ', sol_swim._V_a.shape[0], 'w_outer_elm = ', sol_elm._V_a.shape[0])
                        print('rmse_swim =', rmse_swim, 'rel_err_swim =', rel_err_swim, 'time =', time_swim)
                        print('rmse_swim (train) =', rmse_swim_train, 'rel_err_swim (train)= ', rel_err_swim_train, '\n')
                        print('rmse_elm =', rmse_elm, 'rel_err_elm = ', rel_err_elm, 'time =', time_elm)
                        print('rmse_elm (train) =', rmse_elm_train, 'rel_err_elm (train)= ', rel_err_elm_train)
                        print('-----------------------------------------------------------\n')
                        experiments.append(info)

res = np.vstack(experiments)
np.save(filename, res)  # save



# Column indices into each row of `res`. The experiment loop stores the relative
# errors at positions 4 (SWIM) and 7 (ELM) and, when svd_on, appends the outer
# widths in the order ELM (`sol_elm._V_a`) then SWIM (`sol_swim._V_a`).
REL_ERR_SWIM_COL = 4
REL_ERR_ELM_COL = 7
W_OUTER_ELM_COL = -2
W_OUTER_SWIM_COL = -1

# Axes: (dimension, width, reg_svd, r_m, column) -- the nesting order of the experiment loop.
res_reshape = res.reshape((len(dimensions), len(widths), len(reg_svd_constants), len(rms), -1))

# One colour per dimension (each dimension gets its own pair of curves).
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap(name, lut) is supported.
cmap = plt.get_cmap("jet", max(len(dimensions), 1))
fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(3, 3))

# For each dimension and width keep the best (minimum) relative error over the
# (reg_svd, r_m) grid for SWIM and ELM, plus the outer width of the run that
# achieved it; save that summary per dimension and plot error vs. width.
for d, dim in enumerate(dimensions):
    re_swim_w = np.ones((len(widths),))
    re_elm_w = np.ones((len(widths),))
    w_o_swim_w = np.ones((len(widths),))
    w_o_elm_w = np.ones((len(widths),))
    filename_widths_errors = f'data_svd/{dim}d_svd_3.npy'

    for w in range(len(widths)):
        err_swim = res_reshape[d, w, :, :, REL_ERR_SWIM_COL]
        err_elm = res_reshape[d, w, :, :, REL_ERR_ELM_COL]
        best_swim = np.unravel_index(np.argmin(err_swim), err_swim.shape)
        best_elm = np.unravel_index(np.argmin(err_elm), err_elm.shape)

        re_swim_w[w] = err_swim[best_swim]
        re_elm_w[w] = err_elm[best_elm]
        w_o_swim_w[w] = res_reshape[d, w, best_swim[0], best_swim[1], W_OUTER_SWIM_COL]
        w_o_elm_w[w] = res_reshape[d, w, best_elm[0], best_elm[1], W_OUTER_ELM_COL]

    info = [widths, re_swim_w, re_elm_w, w_o_swim_w, w_o_elm_w]

    # Dashed: SWIM, solid: ELM. Markers keep single-width sweeps visible.
    ax.semilogy(widths, re_swim_w, linestyle='--', marker='o', color=cmap(d), label=f"d = {dim}")
    ax.semilogy(widths, re_elm_w, marker='o', color=cmap(d))

    np.save(filename_widths_errors, np.vstack(info))
    print(info)

ax.set_xlabel('Width of hidden layer', fontsize=fontsize)
ax.set_ylabel(r'Relative  ' + r'$\mathbb{L}_{2}$ ' + r'error', fontsize=fontsize)
ax.tick_params(axis='both', labelsize=fontsize)
ax.legend()
fig.tight_layout()




