{
 "cells": [
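  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solving a nonlinear reaction-diffusion equation using ELM-ODE\n",
    "\n",
    "The problem is posed on $[-1, 1]^5$ for $t \\in [0, 1]$ with Dirichlet boundary conditions; the analytical reference solution is $u(x, t) = 2\\sin(\\pi x_0/2)\\cos(\\pi x_1/2)\\,e^{-t}$."
   ]
  },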
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import sys\n",
    "\n",
    "# Make the project root and its `src` directory importable *before* the star import\n",
    "# below, so that anything `utils` imports from those locations can be resolved.\n",
    "sys.path.append('../../')\n",
    "sys.path.append('../../src')\n",
    "\n",
    "# NOTE(review): sample_boundary_lhs, sample_interior_lhs, plot and plot_error are not\n",
    "# defined in this notebook and presumably come from this star import; the same seems\n",
    "# to hold for Domain, BasicAnsatz and Reaction_Diffusion_Solver (their explicit\n",
    "# swimpde imports are commented out) -- TODO confirm and make these imports explicit.\n",
    "from utils import *\n",
    "# from swimpde import Domain\n",
    "# from swimpde import BasicAnsatz\n",
    "# from swimpde import Reaction_Diffusion_Solver\n",
    "\n",
    "# Imported after `utils` (as before), so these bindings take precedence over the star import.\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.cm as cm\n",
    "import time\n",
    "cmap = cm.jet\n",
    "\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "# Set seeds\n",
    "np.random.seed(2)\n",
    "rng = np.random.default_rng(seed=123)\n",
    "print(sys.path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample training and test points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point-set sizes\n",
    "d = 5  # Spatial dimension of the domain [-1, 1]^d\n",
    "n_b_train = 1000  # Number of boundary training points\n",
    "n_b_test = 2000  # Number of boundary test points\n",
    "n_int_train = 1000  # Number of interior training points (previously tried: 16000, 20000, 4000 * d)\n",
    "n_grid = 100  # Resolution of the (x_0, x_1) test grid along each axis\n",
    "\n",
    "# Train and test boundary points\n",
    "X_b_train, boundary_labels = sample_boundary_lhs(d, n_b_train, bounds=(-1, 1))\n",
    "X_b_test, boundary_labels_test = sample_boundary_lhs(d, n_b_test, bounds=(-1, 1))\n",
    "\n",
    "# Train interior points\n",
    "X_int_train = sample_interior_lhs(d, n_int_train, bounds=(-1, 1))\n",
    "\n",
    "# Test interior points: a regular n_grid x n_grid grid in (x_0, x_1), with the\n",
    "# remaining d - 2 coordinates drawn uniformly from [-1, 1].\n",
    "x_0 = np.linspace(-1, 1, n_grid, endpoint=True)\n",
    "x_1 = np.linspace(-1, 1, n_grid, endpoint=True)\n",
    "xx, yy = np.meshgrid(x_0, x_1)\n",
    "# Re-seed so the test set does not depend on the random draws made above.\n",
    "np.random.seed(2)\n",
    "rng = np.random.default_rng(seed=123)\n",
    "X_test = rng.uniform(low=-1, high=1, size=(n_grid * n_grid, d))\n",
    "X_test[:, 0] = xx.ravel()\n",
    "X_test[:, 1] = yy.ravel()\n",
    "\n",
    "print(X_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"X_test.npy\", X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boundary training points: shape and the first few rows\n",
    "print(X_b_train.shape)\n",
    "X_b_train[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interior training points: shape and the first few rows\n",
    "print(X_int_train.shape)\n",
    "X_int_train[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Space-time point sets: time t in [0, 1] is prepended as column 0.\n",
    "# They get their own names (XT_*) because the spatial arrays X_int_train, X_b_train,\n",
    "# X_b_test and X_test are used further below by the ELM-ODE solver, u0 and\n",
    "# analytical_sol, which read the spatial coordinates as x[:, 0] and x[:, 1];\n",
    "# overwriting them would feed the time column in as x_0 (and made these cells\n",
    "# non-idempotent).\n",
    "d_time = 1  # Dimension of the time domain\n",
    "\n",
    "# Train and test boundary points in time\n",
    "t_b_train, t_boundary_labels = sample_boundary_lhs(d_time, n_b_train, bounds=(0, 1))\n",
    "t_b_test, t_boundary_labels_test = sample_boundary_lhs(d_time, n_b_test, bounds=(0, 1))\n",
    "\n",
    "# Train interior points in time\n",
    "t_int_train = sample_interior_lhs(d_time, n_int_train, bounds=(0, 1))\n",
    "\n",
    "# Test points: the spatial test grid (same seeds and construction as X_test) at t = 1\n",
    "x_0 = np.linspace(-1, 1, 100, endpoint=True)\n",
    "x_1 = np.linspace(-1, 1, 100, endpoint=True)\n",
    "xx, yy = np.meshgrid(x_0, x_1)\n",
    "tt = xx*0 + 1.0\n",
    "tt = tt.reshape(-1,1)\n",
    "np.random.seed(2)\n",
    "rng = np.random.default_rng(seed=123)\n",
    "XT_test = rng.uniform(low=-1, high=1, size=(100 * 100, d))\n",
    "XT_test[:, 0] = xx.reshape(-1)\n",
    "XT_test[:, 1] = yy.reshape(-1)\n",
    "XT_test = np.concatenate((tt, XT_test), axis=1)\n",
    "\n",
    "print(XT_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_int_train.shape)\n",
    "print(t_int_train.shape)\n",
    "XT_int_train = np.concatenate((t_int_train, X_int_train), axis=1)\n",
    "print(XT_int_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_b_train.shape)\n",
    "print(t_b_train.shape)\n",
    "XT_b_train = np.concatenate((t_b_train, X_b_train), axis=1)\n",
    "print(XT_b_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_b_test.shape)\n",
    "print(t_b_test.shape)\n",
    "XT_b_test = np.concatenate((t_b_test, X_b_test), axis=1)\n",
    "print(XT_b_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(XT_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The space-time arrays are saved under the file names used before\n",
    "# (note: this overwrites the spatial X_test.npy saved earlier in the notebook).\n",
    "np.save(\"X_int_train.npy\", XT_int_train)\n",
    "np.save(\"X_b_train.npy\", XT_b_train)\n",
    "np.save(\"X_b_test.npy\", XT_b_test)\n",
    "np.save(\"X_test.npy\", XT_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(X_int_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** print all of the values below:\n",
    "\n",
    "- `train_interior`\n",
    "- `train_boundary`\n",
    "- `test_interior`\n",
    "- `test_boundary`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Initial condition u(x, 0) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2)\n",
    "def u0(x):\n",
    "    return 2. * np.sin(0.5 * np.pi * x[:, 0]) * np.cos(0.5 * np.pi * x[:, 1])\n",
    "\n",
    "# Forcing term; with u = analytical_sol(x, t) it equals (pi^2 / 2 - 1) * u - u^2.\n",
    "def forcing(x, t):\n",
    "    return ((np.pi**2. - 2.) * np.exp(-t) * np.sin(0.5 * np.pi * x[:, 0]) * np.cos(0.5 * np.pi * x[:, 1])\n",
    "            - 4. * np.exp(-2. * t) * ((np.sin(0.5 * np.pi * x[:, 0]))**2.) * (np.cos(0.5 * np.pi * x[:, 1])**2.))\n",
    "\n",
    "# Boundary condition type\n",
    "boundary_condition = \"dirichlet\"\n",
    "\n",
    "# Analytical solution u(x, t) = u0(x) * exp(-t)\n",
    "def analytical_sol(x, t):\n",
    "    return 2. * np.sin(0.5 * np.pi * x[:, 0]) * np.cos(0.5 * np.pi * x[:, 1]) * np.exp(-t)\n",
    "\n",
    "def _eval_analytical(x, t):\n",
    "    \"\"\"analytical_sol at points x (one per row) and times t of shape (T, 1, 1), as a (T, N) array.\"\"\"\n",
    "    u = analytical_sol(x, t)  # broadcasts to (T, 1, N)\n",
    "    return np.reshape(u, (np.shape(u)[0], np.shape(u)[2]))\n",
    "\n",
    "t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)  # Entire time domain\n",
    "t_eval_test = np.array([1]).reshape(-1, 1, 1)  # Only the final time point\n",
    "x_train = X_int_train  # Interior training points\n",
    "\n",
    "# Interior and boundary training data\n",
    "u_true_int_train = _eval_analytical(x_train, t_eval)  # Interior points at all times in t_eval\n",
    "u_true_bdry_train = _eval_analytical(X_b_train, t_eval_test)  # Boundary points at the final time point\n",
    "\n",
    "# Test data at the final time point\n",
    "u_true_test_final_time_point = _eval_analytical(X_test, t_eval_test)  # Interior test points\n",
    "u_true_b_test = _eval_analytical(X_b_test, t_eval_test)  # Boundary test points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the true solution on the test points at selected time indices of t_eval\n",
    "timesteps = [0, 30, 60, 99]\n",
    "u_true_full = analytical_sol(X_test, t_eval)\n",
    "u_true_full = np.reshape(u_true_full, (u_true_full.shape[0], u_true_full.shape[2]))\n",
    "\n",
    "plot(\n",
    "    x=X_test,\n",
    "    u=u_true_full,\n",
    "    timesteps=timesteps,\n",
    "    title='Ground truth (test)',\n",
    "    cmap_offset=0,\n",
    "    savefig=True,\n",
    "    figname='nrd_ground_truth.png',\n",
    "    marker_size=5.0,\n",
    "    extent=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit and evaluate ELM-ODE with well-performing hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def gram_schmidt(B, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Orthonormalizes two basis vectors per sample using the Gram-Schmidt process.\n",
    "\n",
    "    Parameters:\n",
    "    B : ndarray of shape (N, 2, D) - Two basis vectors for each of the N samples\n",
    "    eps : float - Lower bound for the norms that are divided by; a vector whose norm\n",
    "          is below eps (e.g. a zero gradient, or the residual of two parallel\n",
    "          vectors) is scaled by 1/eps instead of producing NaN/inf\n",
    "\n",
    "    Returns:\n",
    "    B_orth : ndarray of shape (N, 2, D) - Orthonormalized basis vectors\n",
    "\n",
    "    B itself is not modified.\n",
    "    \"\"\"\n",
    "    # Out-of-place division: B[:, 0] is a view, so `/=` would overwrite the caller's array.\n",
    "    u1 = B[:, 0] / np.maximum(np.linalg.norm(B[:, 0], axis=1, keepdims=True), eps)\n",
    "\n",
    "    # Remove the projection on u1, then normalize.\n",
    "    u2 = B[:, 1] - np.einsum('ij,ij->i', B[:, 1], u1)[:, np.newaxis] * u1\n",
    "    u2 = u2 / np.maximum(np.linalg.norm(u2, axis=1, keepdims=True), eps)\n",
    "\n",
    "    return np.stack((u1, u2), axis=1)\n",
    "\n",
    "def project_on_gradient_plane(X1, X2, grad_f):\n",
    "    \"\"\"\n",
    "    Projects the displacement vectors X2 - X1 onto the 2D plane spanned by the gradients at X1 and X2.\n",
    "\n",
    "    Parameters:\n",
    "    X1 : ndarray of shape (N, D) - First set of points (one point per row)\n",
    "    X2 : ndarray of shape (N, D) - Second set of points\n",
    "    grad_f : callable - Returns the (N, D) gradients for an (N, D) batch of points\n",
    "\n",
    "    Returns:\n",
    "    v_proj : ndarray of shape (N, D) - Displacements projected onto the gradient plane\n",
    "    \"\"\"\n",
    "    G1 = grad_f(X1)  # Gradient at the first points (N, D)\n",
    "    G2 = grad_f(X2)  # Gradient at the second points (N, D)\n",
    "\n",
    "    # Orthonormal basis of the plane spanned by the two gradients, shape (N, 2, D)\n",
    "    B_orth = gram_schmidt(np.stack((G1, G2), axis=1))\n",
    "\n",
    "    V_ij = X2 - X1  # Displacement vectors (N, D)\n",
    "\n",
    "    coeffs = np.einsum('nij,nj->ni', B_orth, V_ij)  # Coordinates in the plane basis (N, 2)\n",
    "    v_proj = np.einsum('ni,nij->nj', coeffs, B_orth)  # Back to the ambient space (N, D)\n",
    "\n",
    "    return v_proj\n",
    "\n",
    "def gradient_u0(x):\n",
    "    \"\"\"\n",
    "    Gradient of u0(x) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2) for a batch x of shape (N, D);\n",
    "    only the first two components are non-zero.\n",
    "    \"\"\"\n",
    "    grad = np.zeros_like(x)\n",
    "    grad[:, 0] = np.pi * np.cos(0.5 * np.pi * x[:, 0]) * np.cos(0.5 * np.pi * x[:, 1])\n",
    "    grad[:, 1] = - np.pi * np.sin(0.5 * np.pi * x[:, 0]) * np.sin(0.5 * np.pi * x[:, 1])\n",
    "    return grad\n",
    "\n",
    "# Demo on random point pairs in d dimensions\n",
    "N = 3  # Number of pairs\n",
    "X1 = np.random.randn(N, d)  # First set of points\n",
    "X2 = np.random.randn(N, d)  # Second set of points\n",
    "\n",
    "# Compute projected vectors\n",
    "V_proj = project_on_gradient_plane(X1, X2, gradient_u0)\n",
    "\n",
    "print(\"Projected Vectors:\\n\", V_proj)\n",
    "print(\"Original Vectors:\\n\", X2 - X1)\n",
    "print(\"X2:\\n\", X2)\n",
    "print(\"X1:\\n\", X1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_parameters(x, y, rng, n_neurons=None):\n",
    "    \"\"\"\n",
    "    Sample directions from points to other points in the given dataset (x, y).\n",
    "\n",
    "    Each candidate direction is the displacement between a random pair of distinct data\n",
    "    points, projected onto the plane spanned by gradient_u0 at both points and normalized.\n",
    "    n_neurons candidates are then drawn with replacement, with probabilities from\n",
    "    weight_probabilities.\n",
    "\n",
    "    Args:\n",
    "        x: (N, D) data points\n",
    "        y: (N, K) function values at x\n",
    "        rng: numpy Generator used for all random draws\n",
    "        n_neurons: number of directions to sample; if None, the notebook-level `width`\n",
    "            (set by the experiment loop) is used, as before\n",
    "\n",
    "    Returns:\n",
    "        directions (n_neurons, D), dists (n_neurons, 1), idx_from (n_neurons,), idx_to (n_neurons,)\n",
    "    \"\"\"\n",
    "    if n_neurons is None:\n",
    "        # Explicit fallback to the global instead of an implicit read deep in the body.\n",
    "        n_neurons = width\n",
    "    n_points = x.shape[0]\n",
    "\n",
    "    # Repeat the pair sampling when more neurons than data points are requested.\n",
    "    n_repetitions = max(1, int(np.ceil(n_neurons / n_points)))\n",
    "\n",
    "    # 1 <= delta < n_points guarantees idx_to != idx_from for every candidate pair.\n",
    "    candidates_idx_from = rng.integers(low=0, high=n_points, size=n_points * n_repetitions)\n",
    "    delta = rng.integers(low=1, high=n_points, size=candidates_idx_from.shape[0])\n",
    "    candidates_idx_to = (candidates_idx_from + delta) % n_points\n",
    "\n",
    "    # Project the pair displacements onto the plane spanned by the gradients of u0.\n",
    "    X1 = x[candidates_idx_from, ...]\n",
    "    X2 = x[candidates_idx_to, ...]\n",
    "    directions = project_on_gradient_plane(X1, X2, gradient_u0)\n",
    "\n",
    "    # Normalize; the clip avoids dividing by zero for vanishing projections.\n",
    "    dists = np.linalg.norm(directions, axis=1, keepdims=True)\n",
    "    dists = np.clip(dists, a_min=1e-10, a_max=None)\n",
    "    directions = directions / dists\n",
    "\n",
    "    dy = y[candidates_idx_to, :] - y[candidates_idx_from, :]\n",
    "\n",
    "    # Always sample with replacement to avoid being forced to pick low-probability pairs.\n",
    "    probabilities = weight_probabilities(dy, dists)\n",
    "    selected_idx = rng.choice(dists.shape[0], size=n_neurons, replace=True, p=probabilities)\n",
    "\n",
    "    directions = directions[selected_idx]\n",
    "    dists = dists[selected_idx]\n",
    "    idx_from = candidates_idx_from[selected_idx]\n",
    "    idx_to = candidates_idx_to[selected_idx]\n",
    "\n",
    "    return directions, dists, idx_from, idx_to\n",
    "\n",
    "def weight_probabilities(dy, dists, sample_uniformly=False):\n",
    "    \"\"\"Compute the probability that a certain weight should be chosen as part of the network.\n",
    "    This method computes all probabilities at once, without removing the new weights one by one.\n",
    "\n",
    "    Args:\n",
    "        dy: function differences over the M candidate pairs, one row per pair\n",
    "        dists: distances between the base points of each pair, shape (M, 1)\n",
    "        sample_uniformly: if True, return the uniform distribution\n",
    "\n",
    "    Returns:\n",
    "        probabilities: (M,) probabilities for the weights.\n",
    "    \"\"\"\n",
    "    # Maximum change over all outputs, so that good gradients are sampled for every output.\n",
    "    gradients = (np.max(np.abs(dy), axis=1, keepdims=True) / dists).ravel()\n",
    "\n",
    "    if sample_uniformly or np.sum(gradients) < 1e-10:\n",
    "        # When all gradients are small, avoid dividing by a small number\n",
    "        # and default to the uniform distribution.\n",
    "        probabilities = np.ones_like(gradients) / len(gradients)\n",
    "    else:\n",
    "        probabilities = gradients / np.sum(gradients)\n",
    "\n",
    "    return probabilities\n",
    "\n",
    "\n",
    "def sample_parameters_tanh(x, y, rng, n_neurons=None):\n",
    "    \"\"\"\n",
    "    Sample tanh weights (D, n_neurons) and biases (1, n_neurons) from pairs of data points.\n",
    "\n",
    "    With s = artanh(1/2), x @ weights + biases equals -s at x[idx_from] and, for a\n",
    "    non-degenerate pair, +s at x[idx_to] (assuming the ansatz evaluates\n",
    "    tanh(x @ weights + biases), i.e. -1/2 and +1/2 there).\n",
    "    n_neurons is forwarded to sample_parameters (None: use the global `width`).\n",
    "    \"\"\"\n",
    "    scale = 0.5 * (np.log(1 + 1/2) - np.log(1 - 1/2))  # = artanh(1/2)\n",
    "\n",
    "    directions, dists, idx_from, idx_to = sample_parameters(x, y, rng, n_neurons)\n",
    "    weights = (2 * scale * directions / dists).T\n",
    "    biases = -np.sum(x[idx_from, :] * weights.T, axis=-1).reshape(1, -1) - scale\n",
    "\n",
    "    return weights, biases, idx_from, idx_to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_parameters_randomly_normal(x, _, rng):\n",
    "    \"\"\"Random parameters: weights ~ N(0, 1), biases ~ U(-2 pi, 2 pi); uses the global `width`.\n",
    "\n",
    "    This variant used to be named `sample_parameters_randomly` as well, and the later\n",
    "    definition of that name in this cell silently replaced it; it is renamed so that it\n",
    "    stays available.\n",
    "    \"\"\"\n",
    "    weights = rng.normal(loc=0, scale=1, size=(x.shape[1], width))\n",
    "    biases = rng.uniform(low=-2 * np.pi, high=2 * np.pi, size=(1, width))\n",
    "    return weights, biases, None, None\n",
    "\n",
    "seeds = [1]\n",
    "experiments = []\n",
    "widths = [400]  # e.g. [2000, 3000, 4000, 5000, 8000]\n",
    "reg_consts = [1e-10]  # e.g. [1e-8, 1e-10]\n",
    "svd_on = True\n",
    "\n",
    "info = []\n",
    "\n",
    "def collocation_points_probabilities(df_dx):\n",
    "    \"\"\"Probability distribution for (re)sampling collocation points.\n",
    "\n",
    "    Proportional to |df_dx| plus 1% of its maximum, so that (unless df_dx is\n",
    "    identically zero) every point keeps a non-zero probability.\n",
    "    \"\"\"\n",
    "    gradients = np.abs(df_dx)\n",
    "    gradients = gradients + 0.01 * np.max(gradients)\n",
    "    return gradients/np.sum(gradients)\n",
    "\n",
    "def sample_parameters_randomly(x, _, rng):\n",
    "    \"\"\"Random parameters: weights and biases ~ U(0, 1); uses the global `width`.\"\"\"\n",
    "    weights = rng.uniform(low=0, high=1, size=(x.shape[1], width))\n",
    "    biases = rng.uniform(low=0, high=1, size=(1, width))\n",
    "    return weights, biases, None, None\n",
    "\n",
    "def _rmse_and_rel_err(u_true, u_pred):\n",
    "    \"\"\"RMSE of u_pred w.r.t. u_true, and that RMSE divided by the RMS of u_true.\"\"\"\n",
    "    rmse = np.sqrt(mean_squared_error(u_true, u_pred))\n",
    "    return rmse, rmse / np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))\n",
    "\n",
    "param_samplers = [sample_parameters_tanh]  # alternatives: sample_parameters_randomly, 'random'\n",
    "\n",
    "for width in widths:\n",
    "    for reg_const in reg_consts:\n",
    "        svd_cutoff = reg_const\n",
    "        rtol = 1e6 * reg_const\n",
    "        atol = 1e6 * reg_const\n",
    "        for param_sampler in param_samplers:\n",
    "            # New list per configuration; a single shared list made every entry of\n",
    "            # `experiments` the same object, accumulating the results of all configurations.\n",
    "            info = []\n",
    "            rmse_elm = np.ones(len(seeds))\n",
    "            rel_err_elm = np.ones(len(seeds))\n",
    "            rmse_elm_train = np.ones(len(seeds))\n",
    "            rel_err_elm_train = np.ones(len(seeds))\n",
    "            rmse_elm_train_b = np.ones(len(seeds))\n",
    "            rel_err_elm_train_b = np.ones(len(seeds))\n",
    "            rmse_elm_test_b = np.ones(len(seeds))\n",
    "            rel_err_elm_test_b = np.ones(len(seeds))\n",
    "            time_elm = np.ones(len(seeds))\n",
    "            for j, seed in enumerate(seeds):\n",
    "                # Reset the global and Generator RNGs (independently of `seed`)\n",
    "                np.random.seed(2)\n",
    "                rng = np.random.default_rng(seed=123)\n",
    "\n",
    "                ansatz_elm = BasicAnsatz(\n",
    "                    n_neurons=width,\n",
    "                    activation=\"tanh\",\n",
    "                    random_state=seed,\n",
    "                    regularization_scale=reg_const,\n",
    "                    parameter_sampler=param_sampler,\n",
    "                )\n",
    "                # NOTE(review): the boundary points themselves are passed as normal\n",
    "                # vectors -- TODO confirm this is intended.\n",
    "                normal_vectors = X_b_train.copy()\n",
    "\n",
    "                # Domain\n",
    "                domain = Domain(\n",
    "                    interior_points=X_int_train,\n",
    "                    boundary_points=X_b_train,\n",
    "                    normal_vectors=normal_vectors,\n",
    "                    sample_points=X_int_train,\n",
    "                )\n",
    "\n",
    "                reaction_diffusion_solver_elm = Reaction_Diffusion_Solver(\n",
    "                    domain=domain,\n",
    "                    ansatz=ansatz_elm,\n",
    "                    u0=u0,\n",
    "                    boundary_condition=boundary_condition,\n",
    "                    forcing=forcing,\n",
    "                    regularization_scale=reg_const,\n",
    "                    scale_boundary_correction=10000.,\n",
    "                    boundary_condition_true=analytical_sol,\n",
    "                )\n",
    "                # Compute weights and biases of the elm network\n",
    "                time_blocks = 1\n",
    "                # ic_eval = u0(domain.interior_points)\n",
    "                ic_eval = u0(domain.all_points)\n",
    "                t_elm_start = time.time()\n",
    "\n",
    "                sol_elm, solver_status_elm = reaction_diffusion_solver_elm.fit(\n",
    "                    t_span=[0, np.max(t_eval)],\n",
    "                    rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,\n",
    "                    outer_basis=False,\n",
    "                    init_cond=ic_eval,\n",
    "                    svd_on=svd_on,\n",
    "                )\n",
    "                \"\"\"\n",
    "                sol_swim, solver_status_swim = reaction_diffusion_solver_elm.fit_time_blocks(t_span=[0, np.max(t_eval)], \n",
    "                                                    rtol=rtol, atol=atol, \n",
    "                                                    svd_cutoff=svd_cutoff, \n",
    "                                                    time_blocks=time_blocks, \n",
    "                                                    prob_distr_resampling = collocation_points_probabilities, \n",
    "                                                    init_cond=ic_eval,\n",
    "                                                    n_col=n_int_train, outer_basis=False,\n",
    "                                                    #svd_on=svd_on\n",
    "                                                    );\n",
    "                \"\"\"\n",
    "                t_elm_stop = time.time()\n",
    "                time_elm[j] = t_elm_stop - t_elm_start\n",
    "\n",
    "                # Evaluate on test data\n",
    "                \"\"\"\n",
    "                u_elm_test = (reaction_diffusion_solver_elm.evaluate_blocks(x_eval= X_test, t_eval = t_eval, \n",
    "                                                            time_blocks = time_blocks, \n",
    "                                                            solver_status = solver_status_swim,\n",
    "                                                            #svd_on=svd_on\n",
    "                                                            )).T\n",
    "                u_elm_train = (reaction_diffusion_solver_elm.evaluate_blocks(x_eval= x_train, t_eval = t_eval, \n",
    "                                                            time_blocks = time_blocks, \n",
    "                                                            solver_status = solver_status_swim,\n",
    "                                                            #svd_on=svd_on\n",
    "                                                            )).T\n",
    "                \"\"\"\n",
    "                u_elm_test = reaction_diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval_test, svd_on=svd_on).T\n",
    "                u_elm_train = reaction_diffusion_solver_elm.evaluate(x_eval=x_train, t_eval=t_eval, svd_on=svd_on).T\n",
    "                u_elm_boundary_train = reaction_diffusion_solver_elm.evaluate(x_eval=X_b_train, t_eval=t_eval_test, svd_on=svd_on).T\n",
    "                u_elm_boundary_test = reaction_diffusion_solver_elm.evaluate(x_eval=X_b_test, t_eval=t_eval_test, svd_on=svd_on).T\n",
    "\n",
    "                # Compute metrics (RMSE and RMSE relative to the RMS of the reference)\n",
    "                rmse_elm[j], rel_err_elm[j] = _rmse_and_rel_err(u_true_test_final_time_point, u_elm_test)\n",
    "                rmse_elm_train[j], rel_err_elm_train[j] = _rmse_and_rel_err(u_true_int_train, u_elm_train)\n",
    "                rmse_elm_train_b[j], rel_err_elm_train_b[j] = _rmse_and_rel_err(u_true_bdry_train, u_elm_boundary_train)\n",
    "                rmse_elm_test_b[j], rel_err_elm_test_b[j] = _rmse_and_rel_err(u_true_b_test, u_elm_boundary_test)\n",
    "\n",
    "                info.append(time_elm[j])\n",
    "                info.append(rmse_elm[j])\n",
    "                print('time=', time_elm[j], 'rmse_elm=', rmse_elm[j], 'rel_err_elm=',rel_err_elm[j])\n",
    "\n",
    "            # Train \n",
    "            print('-------------------------------------------------------------------------')\n",
    "            print('Width: ', width, 'param_sampler: ', param_sampler, 'reg_const', reg_const, 'atol', atol)\n",
    "            print('-------------------------------------------------------------------------')\n",
    "            print('Train: elm-ode time = ', np.mean(time_elm))\n",
    "            print('Train: rmse elm-ode = ',np.mean(rmse_elm_train), '+-', np.std(rmse_elm_train))\n",
    "            print('Train: rel l-2 error elm-ode = ',np.mean(rel_err_elm_train), '+-', np.std(rel_err_elm_train))\n",
    "            print('Train: rel l-2 error elm-ode (boundary) = ',np.mean(rel_err_elm_train_b), '+-', np.std(rel_err_elm_train_b))\n",
    "\n",
    "            # Test\n",
    "            print('Test: rmse elm-ode = ',np.mean(rmse_elm), '+-', np.std(rmse_elm))\n",
    "            print('Test: rel l-2 error elm-ode = ',np.mean(rel_err_elm), '+-', np.std(rel_err_elm))\n",
    "            print('Test: rel l-2 error elm-ode (boundary) = ',np.mean(rel_err_elm_test_b), '+-', np.std(rel_err_elm_test_b))\n",
    "            print('-------------------------------------------------------------------------\\n')\n",
    "            experiments.append(info)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of coefficients returned by the solver's underscore-prefixed (non-public) _get_c0\n",
    "c0 = reaction_diffusion_solver_elm._get_c0(outer_basis=False)\n",
    "print('width of the output layer = ', np.shape(c0.reshape(-1))[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the ELM-ODE solution on the interior training points at selected time indices.\n",
    "# `timesteps` is set explicitly (to the values in effect here in a top-to-bottom run)\n",
    "# because the error-plot cells below reassign it, which made re-running this cell\n",
    "# depend on execution order.\n",
    "timesteps = [0, 30, 60, 99]\n",
    "plot(x=x_train, u=u_elm_train, timesteps=timesteps,\n",
    "     title='SWIM-ODE solution', cmap_offset=0., marker_size=5.0, extent=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the absolute error of SWIM-ODE on the interior training points\n",
    "timesteps = [0, 29, 59, 99]\n",
    "plot_error(x=x_train, u_true=u_true_int_train, u_nn=u_elm_train, timesteps=timesteps,\n",
    "           figsize=(7,3), fontsize=14,\n",
    "           title='Absolute error: SWIM-ODE (train)', savefig=True,\n",
    "           # Separate file name so this figure is not overwritten by the test-error plot below.\n",
    "           figname='nrd_swim_ode_error_train.png', marker_size=5.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the absolute error of SWIM-ODE on the test points over the whole time range.\n",
    "# New names so the final-time `u_elm_test` from the experiment cell is not clobbered.\n",
    "timesteps = [0, 29, 59, 99]\n",
    "u_true_test_full = analytical_sol(X_test, t_eval)\n",
    "u_true_test_full = np.reshape(u_true_test_full, (np.shape(u_true_test_full)[0], np.shape(u_true_test_full)[2]))\n",
    "u_elm_test_full = reaction_diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval, svd_on=svd_on).T\n",
    "\n",
    "plot_error(x=X_test, u_true=u_true_test_full, u_nn=u_elm_test_full, timesteps=timesteps,\n",
    "           figsize=(7,3), fontsize=14,\n",
    "           title='Absolute error: SWIM-ODE (test)', savefig=True,\n",
    "           # Separate file name so this figure does not overwrite the training-error plot above.\n",
    "           figname='nrd_swim_ode_error_test.png', marker_size=5.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
