{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solution Error vs. Advection Coefficient for SWIM and ELM Without Space-Time Separation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "sys.path.append('../../src')\n",
    "from swimpde import Domain\n",
    "from swimpde import BasicAnsatz\n",
    "from swimpde import AdvectionSolver\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import time\n",
    "from scipy.stats import qmc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
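    "## Problem setup\n",
    "\n",
    "We solve the 1D advection equation $u_t + \\beta u_x = 0$ on $x \\in [0, 2\\pi]$, $t \\in [0, 1]$ with initial condition $u(x, 0) = \\sin(x)$, zero forcing and periodic boundary conditions; errors are measured against the exact solution $u(x, t) = \\sin(x - \\beta t)$."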
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Space-time domain: x in [0, 2*pi], t in [0, 1]\n",
    "x_lim = [0, 2 * np.pi] # space domain\n",
    "t_lim = [0, 1] # time domain\n",
    "rng = np.random.default_rng(seed=123)  # NOTE(review): not used by any cell in this notebook\n",
    "\n",
    "# initial condition u(x, 0) = sin(x)\n",
    "def u0(x):\n",
    "    \"\"\"Return the initial profile sin(x), evaluated elementwise.\"\"\"\n",
    "    return np.sin(x)\n",
    "\n",
    "# boundary condition: periodic in x (x = 0 and x = 2*pi are identified). It is forwarded\n",
    "# unchanged to AdvectionSolver; what the \"strict\" variant additionally enforces (e.g. matching\n",
    "# first derivatives across the two edges) is defined inside the solver -- TODO confirm\n",
    "boundary_condition = \"periodic strict\"\n",
    "\n",
    "# forcing: zero right-hand side (homogeneous advection equation)\n",
    "def forcing(x, t):\n",
    "    \"\"\"Return one zero per row of x; t is unused.\"\"\"\n",
    "    return np.zeros(x.shape[0])\n",
    "\n",
    "# Boundary and initial data\n",
    "n_initial = 200 # No. of initial points where solution is known\n",
    "n_boundary = 200 # Nominal boundary-point budget, split evenly over n_edges edges\n",
    "n_edges = 4 # edges of the space-time rectangle\n",
    "pps = int(n_boundary/n_edges) # points per edge\n",
    "# NOTE(review): only the left and right edges (x = 0, x = 2*pi) are sampled below, so\n",
    "# 2 * pps = 100 boundary points are used rather than n_boundary -- confirm intent.\n",
    "\n",
    "# Initial data: Latin-hypercube samples of x at t = 0\n",
    "x_0 = qmc.LatinHypercube(d=1, seed=221).random(n=n_initial)\n",
    "x_0 = qmc.scale(x_0, x_lim[0], x_lim[1])\n",
    "X_init = np.hstack((x_0, np.zeros_like(x_0)))\n",
    "u_init = u0(x_0)\n",
    "data_init = np.column_stack((X_init, u_init.reshape(-1, 1)))  # columns: x, t, u(x, 0)\n",
    "\n",
    "# Boundary data: the same pps time samples on both spatial edges\n",
    "y_0 = qmc.LatinHypercube(d=1, seed=3).random(n=pps)\n",
    "y_0 = qmc.scale(y_0, t_lim[0], t_lim[1])\n",
    "X_left = np.hstack((np.full_like(y_0, x_lim[0]), y_0))\n",
    "X_right = np.hstack((np.full_like(y_0, x_lim[1]), y_0))\n",
    "\n",
    "# Visualize where initial and boundary data sit in the (x, t) plane\n",
    "fig, ax = plt.subplots(figsize=(3, 3))\n",
    "for pts, lbl in ((X_left, 'left boundary'), (X_right, 'right boundary'), (X_init, 'initial condition')):\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], label=lbl)\n",
    "ax.set(xlabel=\"x\", ylabel=\"t\", title=\"Initial and boundary data\")\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "\n",
    "# Test data: uniform n_test_time x n_test_space grid over the space-time domain\n",
    "n_test_space = 256\n",
    "n_test_time = 100\n",
    "x_test = np.linspace(x_lim[0], x_lim[1], n_test_space)\n",
    "t_test = np.linspace(t_lim[0], t_lim[1], n_test_time)\n",
    "xx_test, yy_test = np.meshgrid(x_test, t_test) # yy_test holds t; both are (n_test_time, n_test_space)\n",
    "X_test = np.hstack((xx_test.reshape(-1, 1), yy_test.reshape(-1, 1))) # (x, t) rows of the uniform test grid\n",
    "print(np.shape(X_test))\n",
    "\n",
    "# coordinates of boundary points on the left and right edges (corners t = 0, 1 are not sampled explicitly)\n",
    "# np.vstack instead of np.row_stack, which is a deprecated alias since NumPy 2.0\n",
    "boundary_points = np.vstack([X_left, X_right])\n",
    "\n",
    "# Analytical solution of the PDE\n",
    "def analytical_sol(x, t, c):\n",
    "    \"\"\"Exact solution: the initial profile sin(x) translated by c * t, i.e. sin(x - c * t).\"\"\"\n",
    "    travelling_x = x - c * t\n",
    "    return np.sin(travelling_x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter run\n",
    "conv_coeffs = [1e-2, 1e-1, 1, 1e1, 40, 100] # Values of convection coefficients (beta)\n",
    "n_s = 1000 # No. of collocation points for coefficients below 5\n",
    "n_s_2 = 8000 # No. of collocation points for coefficients >= 5 (beta = 10, 40, 100 here)\n",
    "ratio = 2 # collocation points per neuron: width n_b = n_c // ratio\n",
    "reg = 1e-10 # regularization constant\n",
    "seeds = [1, 2, 3] # random seeds\n",
    "\n",
    "experiments = [] # one summary row per convection coefficient\n",
    "for conv_coeff in conv_coeffs:\n",
    "    # Ground truth for the selected convection coefficient on the test grid, as a column\n",
    "    u_true_test = analytical_sol(xx_test, yy_test, conv_coeff).reshape(-1, 1)\n",
    "    # RMS of the exact solution: seed-independent, so computed once per coefficient\n",
    "    rms_true = np.sqrt(mean_squared_error(u_true_test, np.zeros_like(u_true_test)))\n",
    "\n",
    "    # Collocation budget: coefficients >= 5 use the larger budget n_s_2\n",
    "    n_c = n_s if conv_coeff < 5 else n_s_2\n",
    "\n",
    "    # Interior collocation points: tensor grid of the x- and t-coordinates of\n",
    "    # int(sqrt(n_c)) Latin-hypercube samples, i.e. int(sqrt(n_c))**2 <= n_c points\n",
    "    # NOTE(review): info below records the nominal n_c, not this actual grid size\n",
    "    n_train = int(np.sqrt(n_c))\n",
    "    sampler_2d_m = qmc.LatinHypercube(d=2, seed=3)\n",
    "    xy_space_train = sampler_2d_m.random(n=n_train)\n",
    "    xy_space_scaled_train = qmc.scale(xy_space_train, [x_lim[0], t_lim[0]], [x_lim[1], t_lim[1]])\n",
    "    xx_train, yy_train = np.meshgrid(xy_space_scaled_train[:, 0], xy_space_scaled_train[:, 1])\n",
    "    X_train = np.hstack((xx_train.reshape(-1, 1), yy_train.reshape(-1, 1))) # (x, t) training points\n",
    "    interior_points = X_train\n",
    "\n",
    "    # Domain\n",
    "    domain = Domain(\n",
    "        interior_points=interior_points,\n",
    "        boundary_points=boundary_points,\n",
    "    )\n",
    "\n",
    "    # Network width derived from the collocation budget\n",
    "    n_b = int(n_c//ratio)\n",
    "\n",
    "    # Per-seed metrics; a NaN left behind would flag a slot that was never written\n",
    "    rmse_swim = np.full(len(seeds), np.nan)\n",
    "    rmse_elm = np.full(len(seeds), np.nan)\n",
    "    rel_err_swim = np.full(len(seeds), np.nan)\n",
    "    rel_err_elm = np.full(len(seeds), np.nan)\n",
    "    time_swim = np.full(len(seeds), np.nan)\n",
    "    time_elm = np.full(len(seeds), np.nan)\n",
    "\n",
    "    # j indexes the result arrays; s is the seed value itself (any integer)\n",
    "    for j, s in enumerate(seeds):\n",
    "        ansatz_elm = BasicAnsatz(\n",
    "            activation='tanh',\n",
    "            n_neurons=n_b,\n",
    "            random_state=s,\n",
    "            regularization_scale=reg,\n",
    "            parameter_sampler=\"random\"\n",
    "        )\n",
    "        ansatz_swim = BasicAnsatz(\n",
    "            activation='tanh',\n",
    "            n_neurons=n_b,\n",
    "            random_state=s,\n",
    "            regularization_scale=reg,\n",
    "            parameter_sampler=\"tanh\"\n",
    "        )\n",
    "        adv_solver_elm = AdvectionSolver(\n",
    "            domain=domain,\n",
    "            ansatz=ansatz_elm,\n",
    "            boundary_condition=boundary_condition,\n",
    "            c=conv_coeff,\n",
    "            forcing=forcing,\n",
    "            regularization_scale=reg,\n",
    "            u0=u0\n",
    "        )\n",
    "        adv_solver_swim = AdvectionSolver(\n",
    "            domain=domain,\n",
    "            ansatz=ansatz_swim,\n",
    "            boundary_condition=boundary_condition,\n",
    "            c=conv_coeff,\n",
    "            forcing=forcing,\n",
    "            regularization_scale=reg,\n",
    "            u0=u0\n",
    "        )\n",
    "        # ELM (No-ODE)\n",
    "        t_elm_start = time.perf_counter()\n",
    "        adv_solver_elm.fit_no_ode_periodic_bc(data_init=data_init, svd_cutoff = 1e-12)\n",
    "        time_elm[j] = time.perf_counter() - t_elm_start\n",
    "\n",
    "        # SWIM (No-ODE)\n",
    "        t_swim_start = time.perf_counter()\n",
    "        adv_solver_swim.fit_no_ode_periodic_bc(data_init=data_init, svd_cutoff = 1e-12)\n",
    "        time_swim[j] = time.perf_counter() - t_swim_start\n",
    "\n",
    "        # Evaluate on test data\n",
    "        u_swim_test = adv_solver_swim.evaluate_no_ode(X_test).T\n",
    "        u_elm_test = adv_solver_elm.evaluate_no_ode(X_test).T\n",
    "\n",
    "        # Compute metrics: RMSE, and RMSE relative to the RMS of the exact solution\n",
    "        rmse_elm[j] = np.sqrt(mean_squared_error(u_true_test, u_elm_test))\n",
    "        rmse_swim[j] = np.sqrt(mean_squared_error(u_true_test, u_swim_test))\n",
    "        rel_err_elm[j] = rmse_elm[j] / rms_true\n",
    "        rel_err_swim[j] = rmse_swim[j] / rms_true\n",
    "\n",
    "    # Summary row. Column order matters: the plotting cell reads it by negative index\n",
    "    # [conv_coeff, n_c, n_b, reg, time_elm, time_swim, rmse_elm mean, std,\n",
    "    #  rmse_swim mean, std, rel_err_elm mean, std, rel_err_swim mean, std]\n",
    "    info = [\n",
    "        conv_coeff, n_c, n_b, reg,\n",
    "        np.mean(time_elm), np.mean(time_swim),\n",
    "        np.mean(rmse_elm), np.std(rmse_elm),\n",
    "        np.mean(rmse_swim), np.std(rmse_swim),\n",
    "        np.mean(rel_err_elm), np.std(rel_err_elm),\n",
    "        np.mean(rel_err_swim), np.std(rel_err_swim),\n",
    "    ]\n",
    "    experiments.append(info)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect results: one row per convection coefficient (column layout set in the experiment loop)\n",
    "fontsize = 12\n",
    "res = np.vstack(experiments)\n",
    "\n",
    "# Report the beta = 40 row, located by the value stored in column 0 instead of by position\n",
    "report_beta = 40\n",
    "report_rows = res[res[:, 0] == report_beta]\n",
    "assert len(report_rows) > 0, f'beta = {report_beta} was not part of the experiment'\n",
    "report = report_rows[0]\n",
    "print(f'Information for beta = {report_beta}:')\n",
    "print('time_swim = ', report[-9], 'time_elm = ', report[-10])\n",
    "print('rmse elm = ', report[-8], '+-', report[-7])\n",
    "print('rmse swim = ', report[-6], '+-', report[-5])\n",
    "print('rel error elm = ', report[-4], '+-', report[-3])\n",
    "print('rel error swim = ', report[-2], '+-', report[-1])\n",
    "\n",
    "# Extract mean errors per beta; column layout of res (means over seeds):\n",
    "# 0 beta, -8 rmse elm, -6 rmse swim, -4 rel l2 elm, -2 rel l2 swim\n",
    "betas = res[:, [0]] # convection coefficient of each row, as a column\n",
    "rmse_elm_mean = res[:, [-8]]\n",
    "rmse_swim_mean = res[:, [-6]]\n",
    "rel_l2_elm_mean = res[:, [-4]]\n",
    "rel_l2_swim_mean = res[:, [-2]]\n",
    "\n",
    "# Store the errors for different values of beta: columns (beta, elm, swim)\n",
    "rmse = np.hstack((betas, rmse_elm_mean, rmse_swim_mean))\n",
    "rel_l2 = np.hstack((betas, rel_l2_elm_mean, rel_l2_swim_mean))\n",
    "\n",
    "# Store rmse and rel l2 error values for convergence plots\n",
    "np.save('adv_swim_elm_rmse.npy', rmse)\n",
    "np.save('adv_swim_elm_rel_l2.npy', rel_l2)\n",
    "\n",
    "# Visualize relative L2 error (columns -4 / -2 of res) vs convection coefficient (column 0)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "ax.loglog(res[:, 0], res[:, -4], label='Frozen-PINN-elm')\n",
    "ax.loglog(res[:, 0], res[:, -2], label='Frozen-PINN-swim')\n",
    "ax.set_xlabel(r'$\\beta$')\n",
    "ax.set_ylabel(r'Relative  $\\mathbb{L}_{2}$ error')\n",
    "ax.tick_params(axis='both', labelsize=fontsize)\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "# fig.savefig('advection_swim_elm.pdf', bbox_inches='tight')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
