
# Imports
import sys
sys.path.append('../../../')
sys.path.append('../../../src')
sys.path.append('../../')
sys.path.append('../../../../')
sys.path.append('../../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Reaction_Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples.utils import *
from mpl_toolkits.mplot3d import Axes3D
from contextlib import redirect_stdout

# Set seeds
np.random.seed(2)
rng = np.random.default_rng(seed=123)

# ### Problem Setup

d = 10  # Number of spatial dimensions

# Output files (numeric results and redirected text log). Derived from `d` so
# the names stay in sync if the dimension is changed.
filename = f"dim_{d}.npy"
filename_text = f"dim_{d}_2.txt"

# Numbers of boundary / interior sample points. Only the *_train counts are
# used in this script; the *_test counts are not referenced below.
n_b_train = 15000
n_b_test = 5000
n_int_train = 15000
n_int_test = 5000

# Train boundary points, sampled on the boundary of [-1, 1]^d by the
# examples.utils helper. NOTE(review): `boundary_labels` is not used below.
X_b_train, boundary_labels = sample_boundary_lhs(d, n_b_train, bounds=(-1, 1))

# Train interior points
X_int_train = sample_interior_lhs(d, n_int_train, bounds=(-1, 1))

# Test interior points: an n_grid x n_grid grid in the (x_0, x_1) plane, with
# the remaining d - 2 coordinates drawn uniformly from [-1, 1].
n_grid = 100
x_0 = np.linspace(-1, 1, n_grid, endpoint=True)
x_1 = np.linspace(-1, 1, n_grid, endpoint=True)
xx, yy = np.meshgrid(x_0, x_1)
# Re-seed so the test set is reproducible regardless of what was drawn above.
np.random.seed(2)
rng = np.random.default_rng(seed=123)
X_test = rng.uniform(low=-1, high=1, size=(n_grid * n_grid, d))
X_test[:, 0] = xx.reshape(-1)
X_test[:, 1] = yy.reshape(-1)

#### Problem setup

# initial condition
def u0(x):
    """Return the initial state u(x, 0) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2).

    Only the first two columns of ``x`` are read; further columns are ignored.
    """
    half_pi = np.pi / 2
    first, second = x[:, 0], x[:, 1]
    return 2 * np.sin(half_pi * first) * np.cos(half_pi * second)

# forcing
def forcing(x, t):
    """Source term f(x, t) of the manufactured problem.

    With S = sin(pi x_0 / 2) and C = cos(pi x_1 / 2):

        f = (pi^2 - 2) e^{-t} S C  -  4 e^{-2t} S^2 C

    The first term equals u_t - Laplace(u) for u = analytical_sol.
    NOTE(review): the second term carries C to the first power, whereas the
    square of analytical_sol would carry C^2. Confirm it matches the reaction
    term used by Reaction_Diffusion_Solver.

    ``t`` may be a scalar or an array that broadcasts against the per-point
    arrays, e.g. shape (n_times, 1, 1).
    """
    half_pi = np.pi / 2
    s = np.sin(half_pi * x[:, 0])
    c = np.cos(half_pi * x[:, 1])
    linear_part = (np.pi**2 - 2) * np.exp(-t) * s * c
    quadratic_part = 4. * np.exp(-2*t) * (s * s) * c
    return linear_part - quadratic_part


# boundary condition: type string passed to Reaction_Diffusion_Solver. The
# boundary values themselves are supplied separately to the solver through
# `boundary_condition_true=analytical_sol`.
boundary_condition = "dirichlet"

# Analytical solution
def analytical_sol(x, t):
    """Exact solution u(x, t) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2) e^{-t}.

    Only the first two columns of ``x`` are read. ``t`` may be a scalar or an
    array such as shape (n_times, 1, 1), which broadcasts to a result of
    shape (n_times, 1, n_points).
    """
    half_pi = np.pi / 2
    spatial_part = 2 * np.sin(half_pi * x[:, 0]) * np.cos(half_pi * x[:, 1])
    return spatial_part * np.exp(-t)

# Test data
# Time grids shaped (n_times, 1, 1) so they broadcast against the per-point
# arrays inside analytical_sol: 100 instants on [0, 1] for the training
# evaluation, and only the final time t = 1 for testing.
t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)
t_eval_test = np.array([1]).reshape(-1, 1, 1)
x_train = X_int_train  # spatial training points (alias, not a copy)

# Reference solutions. With the time grids above, analytical_sol returns shape
# (n_times, 1, n_points); squeezing the singleton middle axis gives
# (n_times, n_points).
u_true = np.squeeze(analytical_sol(X_test, t_eval_test), axis=1)
u_true_train = np.squeeze(analytical_sol(x_train, t_eval), axis=1)

# Snapshot indices into t_eval.
# NOTE(review): not referenced elsewhere in this script; presumably left over
# from a plotting step.
timesteps = [0, 30, 60, 99]

# %%
def sample_parameters_randomly(x, _, rng, n_neurons=None):
    """Draw random hidden-layer weights and biases for an ELM-style ansatz.

    Weights are i.i.d. standard normal; biases are uniform on [-2*pi, 2*pi].

    Args:
        x: Input points; only ``x.shape[1]`` (the input dimension) is used.
        _: Ignored. Presumably present to match the sampler call signature
            expected by the caller; verify against BasicAnsatz.
        rng: ``numpy.random.Generator`` used for both draws (weights first,
            then biases).
        n_neurons: Number of hidden neurons. When omitted, the module-level
            ``width`` (the current value of the sweep loop) is used, which is
            what this function always did before.

    Returns:
        ``(weights, biases, None, None)``, where ``weights`` has shape
        ``(x.shape[1], n_neurons)`` and ``biases`` has shape ``(1, n_neurons)``.
    """
    if n_neurons is None:
        # Backward-compatible fallback to the global loop variable.
        n_neurons = width
    weights = rng.normal(loc=0, scale=1, size=(x.shape[1], n_neurons))
    biases = rng.uniform(low=-2 * np.pi, high=2 * np.pi, size=(1, n_neurons))
    return weights, biases, None, None
seeds = [2]
experiments = []
widths = [2000, 3000, 4000, 5000, 6000, 8000]
reg_consts = [1e-6, 1e-8, 1e-10]
# `sample_parameters_randomly` (defined above) may also be listed here.
param_samplers = ['random', 'tanh']

# RMS magnitude of the reference solutions. Dividing an RMSE by it gives the
# relative l-2 error. These do not depend on the sweep, so compute them once.
u_true_rms = np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))
u_true_train_rms = np.sqrt(mean_squared_error(u_true_train, np.zeros_like(u_true_train)))

# Sweep over (width, reg_const, param_sampler). Everything printed inside goes
# to `filename_text`. One row per configuration is collected in `experiments`.
with open(filename_text, 'w', encoding='utf-8') as f, redirect_stdout(f):
    for width in widths:
        for reg_const in reg_consts:
            # Solver tolerances are tied to the regularization strength.
            svd_cutoff = reg_const
            rtol = 100 * reg_const
            atol = 100 * reg_const
            for param_sampler in param_samplers:
                n_seeds = len(seeds)
                rmse_elm = np.ones(n_seeds)
                rel_err_elm = np.ones(n_seeds)
                rmse_elm_train = np.ones(n_seeds)
                rel_err_elm_train = np.ones(n_seeds)
                time_elm = np.ones(n_seeds)
                # Fresh record per configuration: [fit time, test rmse] for each
                # seed. This was previously one list shared by every
                # configuration, so all rows of `experiments` aliased the same
                # ever-growing list.
                info = []
                for j, seed in enumerate(seeds):
                    # Reset the global RNG state before every run.
                    # NOTE(review): these fixed seeds ignore `seed`. Only
                    # BasicAnsatz(random_state=seed) varies with it.
                    np.random.seed(2)
                    rng = np.random.default_rng(seed=123)

                    # Ansatz whose hidden-layer parameters are drawn by
                    # `param_sampler`, seeded with `seed`.
                    ansatz_elm = BasicAnsatz(
                        n_neurons=width,
                        activation="tanh",
                        random_state=seed,
                        regularization_scale=reg_const,
                        parameter_sampler=param_sampler,
                    )

                    # Normal vectors at the boundary points.
                    # NOTE(review): the boundary coordinates themselves are used
                    # as normals. On a [-1, 1]^d cube the outward normal would be
                    # a signed unit axis vector. Confirm this is intended (it
                    # may not matter for Dirichlet conditions).
                    normal_vectors = X_b_train.copy()

                    domain = Domain(
                        interior_points=X_int_train,
                        boundary_points=X_b_train,
                        normal_vectors=normal_vectors,
                        sample_points=X_int_train,
                    )

                    reaction_diffusion_solver_elm = Reaction_Diffusion_Solver(
                        domain=domain,
                        ansatz=ansatz_elm,
                        u0=u0,
                        boundary_condition=boundary_condition,
                        forcing=forcing,
                        regularization_scale=reg_const,
                        scale_boundary_correction=1.,
                        boundary_condition_true=analytical_sol,
                    )

                    # Fit on [0, max(t_eval)] from the initial condition at the
                    # interior points, timing only the fit itself.
                    ic_eval = u0(domain.interior_points)
                    t_elm_start = time.perf_counter()
                    sol_elm, solver_status_elm = reaction_diffusion_solver_elm.fit(
                        t_span=[0, np.max(t_eval)],
                        rtol=rtol,
                        atol=atol,
                        svd_cutoff=svd_cutoff,
                        outer_basis=False,
                        init_cond=ic_eval,
                    )
                    t_elm_stop = time.perf_counter()
                    time_elm[j] = t_elm_stop - t_elm_start

                    # Evaluate on test and train data
                    u_elm_test = reaction_diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval_test).T
                    u_elm_train = reaction_diffusion_solver_elm.evaluate(x_eval=x_train, t_eval=t_eval).T

                    # Compute metrics (RMSE and relative l-2 error)
                    rmse_elm[j] = np.sqrt(mean_squared_error(u_true, u_elm_test))
                    rel_err_elm[j] = rmse_elm[j] / u_true_rms

                    rmse_elm_train[j] = np.sqrt(mean_squared_error(u_true_train, u_elm_train))
                    rel_err_elm_train[j] = rmse_elm_train[j] / u_true_train_rms

                    info.extend([time_elm[j], rmse_elm[j]])

                # Train
                print('-------------------------------------------------------------------------')
                print('Width: ', width, 'param_sampler: ', param_sampler, 'reg_const', reg_const, 'atol', atol)
                print('-------------------------------------------------------------------------')
                print('Train: Frozen-pinn-swim time = ', np.mean(time_elm))
                print('Train: rmse Frozen-pinn-swim = ',np.mean(rmse_elm_train), '+-', np.std(rmse_elm_train))
                print('Train: rel l-2 error Frozen-pinn-swim = ',np.mean(rel_err_elm_train), '+-', np.std(rel_err_elm_train))

                # Test
                print('Test: rmse Frozen-pinn-swim = ',np.mean(rmse_elm), '+-', np.std(rmse_elm))
                print('Test: rel l-2 error Frozen-pinn-swim = ',np.mean(rel_err_elm), '+-', np.std(rel_err_elm))
                print('-------------------------------------------------------------------------\n')

                experiments.append(info)

# One row per (width, reg_const, param_sampler), in sweep order:
# [time_seed0, rmse_seed0, time_seed1, rmse_seed1, ...].
res = np.vstack(experiments)
np.save(filename, res)  # save
