{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solve the viscous Burgers' equation with Frozen-pinn-swim\n",
    "In this notebook we show how Frozen-pinn-swim solves the viscous Burgers' equation accurately by resampling the collocation points and by sampling the network's weights and biases in a data-dependent way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../')\n",
    "sys.path.append('../../../../src')\n",
    "from swimpde import Domain\n",
    "from swimpde import BasicAnsatz\n",
    "from swimpde import BurgersSolver\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from matplotlib import ticker\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['image.cmap'] = 'jet'\n",
    "import scipy\n",
    "from scipy.stats import norm\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data and reference solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the reference solution stored in burgers_shock.mat\n",
    "data = scipy.io.loadmat('../../../../data/burgers_shock.mat')\n",
    "t_eval = data['t'].reshape(-1, 1)  # time grid as a column vector\n",
    "x_eval = data['x'].reshape(-1, 1)  # spatial grid as a column vector\n",
    "# Transposed; presumably rows follow t_eval and columns x_eval, like u_swim below (confirm)\n",
    "u_exact = np.real(data['usol']).T\n",
    "X, T = np.meshgrid(x_eval, t_eval)\n",
    "X_ = np.column_stack((X.ravel(), T.ravel()))  # one (x, t) pair per row\n",
    "\n",
    "# Ground truth as a single column, flattened in row-major order\n",
    "u_true = u_exact.reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial condition: u(x, 0) = -sin(pi * x)\n",
    "def u0(x):\n",
    "    \"\"\"Return the initial condition -sin(pi * x), evaluated elementwise on x.\"\"\"\n",
    "    return -1 * np.sin(np.pi * x)\n",
    "\n",
    "# Forcing term: identically zero\n",
    "def forcing(x, t):\n",
    "    \"\"\"Return a zero forcing value for each row of x; t is not used.\"\"\"\n",
    "    return np.zeros(x.shape[0])\n",
    "\n",
    "# Boundary condition: zero Dirichlet values at both ends of x_lim\n",
    "boundary_condition = \"zero dirichlet\"\n",
    "\n",
    "# Domain information and spatial points for the first time-block\n",
    "n_points_1d = 4000 # No. of points in space\n",
    "x_lim = [-1, 1] # Domain range\n",
    "\n",
    "# Interior points: uniform random draws in x_lim with a fixed seed.\n",
    "# NOTE(review): the draws are unsorted, so [1:-1] drops two arbitrary points\n",
    "# rather than the extremes (looks like a leftover from an evenly spaced grid).\n",
    "# It is kept so results stay reproducible and match the plotting cells.\n",
    "rng = np.random.default_rng(seed=123)\n",
    "x_space = rng.uniform(x_lim[0], x_lim[1], n_points_1d).reshape((-1, 1))\n",
    "x_space_inner = x_space[1:-1]\n",
    "interior_points = x_space_inner\n",
    "\n",
    "# Boundary points: the two endpoints of x_lim as a (2, 1) column.\n",
    "# np.vstack replaces np.row_stack, which is deprecated as of NumPy 2.0.\n",
    "left = x_lim[0]\n",
    "right = x_lim[1]\n",
    "boundary_points = np.vstack([left, right])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and evaluate Frozen-pinn-swim network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "n_sample = 3000  # sampling points at which gradients are computed\n",
    "n_col = 800  # collocation points (re-sampled during the fit)\n",
    "width = 450  # number of neurons of the ansatz\n",
    "reg_const = 1e-7  # regularization constant\n",
    "svd_cutoff = 1e-9  # SVD threshold\n",
    "seeds = [3]  # one run per seed; errors are averaged over all seeds\n",
    "time_blocks = 3  # number of time-blocks for sampling\n",
    "info = []  # collects the timing and error statistics\n",
    "\n",
    "# Define prob. distribution for (re)-sampling collocation points\n",
    "def collocation_points_probabilities(df_dx, floor=0.01):\n",
    "    \"\"\"Return sampling probabilities proportional to |df_dx|.\n",
    "\n",
    "    Every entry is raised by `floor` times the largest magnitude, so flat\n",
    "    regions still receive collocation points. The result has the shape of\n",
    "    df_dx and sums to 1. If df_dx is zero everywhere, a uniform distribution\n",
    "    is returned instead of the 0/0 NaNs the plain normalisation would give.\n",
    "    \"\"\"\n",
    "    gradients = np.abs(df_dx)\n",
    "    gradients = gradients + floor * np.max(gradients)\n",
    "    total = np.sum(gradients)\n",
    "    if total == 0:\n",
    "        return np.full(gradients.shape, 1.0 / gradients.size)\n",
    "    return gradients / total\n",
    "\n",
    "# Points at which the gradient of the solution is computed at the end of each time-block.\n",
    "# Drawing them 1e-4 inside x_lim keeps the boundary points out of this set.\n",
    "sample_test_points = rng.uniform(x_lim[0] + 1e-4, x_lim[1] - 1e-4, n_sample)\n",
    "sample_test_points = np.sort(sample_test_points).reshape((-1, 1))\n",
    "\n",
    "# Domain: interior and boundary points plus the gradient sample points\n",
    "domain = Domain(\n",
    "    interior_points=interior_points,\n",
    "    boundary_points=boundary_points,\n",
    "    sample_points=sample_test_points,\n",
    ")\n",
    "\n",
    "# Store errors and time required for each seed\n",
    "rel_err_swim = np.ones((len(seeds), ))\n",
    "time_swim = np.ones((len(seeds), ))\n",
    "rmse_swim = np.ones((len(seeds), ))\n",
    "for j, seed in enumerate(seeds):  # one training run per entry of `seeds`\n",
    "    # SWIM network ansatz\n",
    "    ansatz_swim = BasicAnsatz(\n",
    "        n_neurons=width,\n",
    "        activation=\"tanh\",\n",
    "        random_state=seed,\n",
    "        regularization_scale=reg_const,\n",
    "        parameter_sampler=\"tanh\"\n",
    "    )\n",
    "    # Burgers equation solver; c = 0.01/pi is presumably the viscosity (confirm in BurgersSolver)\n",
    "    burgers_solver_swim = BurgersSolver(\n",
    "        domain=domain,\n",
    "        ansatz=ansatz_swim,\n",
    "        u0=u0,\n",
    "        boundary_condition=boundary_condition,\n",
    "        forcing=forcing,\n",
    "        regularization_scale=reg_const,\n",
    "        c=(0.01/np.pi)\n",
    "    )\n",
    "    # SWIM fit, timed with a monotonic high-resolution clock\n",
    "    t_swim_start = time.perf_counter()\n",
    "    sol_swim, solver_status_swim = burgers_solver_swim.fit_time_blocks(\n",
    "        t_span=[0, np.max(t_eval)], rtol=1e-8, atol=1e-8, svd_cutoff=svd_cutoff,\n",
    "        time_blocks=time_blocks, prob_distr_resampling=collocation_points_probabilities,\n",
    "        n_col=n_col, outer_basis=False\n",
    "    )\n",
    "    t_swim_stop = time.perf_counter()\n",
    "    time_swim[j] = t_swim_stop - t_swim_start\n",
    "    # Evaluate SWIM on the reference grid; transposed so rows index time\n",
    "    # (assumes evaluate_blocks returns (n_x, n_t) - TODO confirm)\n",
    "    u_swim = (burgers_solver_swim.evaluate_blocks(x_eval=x_eval, t_eval=t_eval, time_blocks=time_blocks, solver_status=solver_status_swim)).T\n",
    "    u_pred = u_swim.flatten()[:, None]  # column vector, same flattening as u_true\n",
    "    # Compute metrics\n",
    "    mse_swim = mean_squared_error(u_true, u_pred)  # mean squared error\n",
    "    rmse_swim[j] = np.sqrt(mse_swim)  # root mean squared error\n",
    "    # Relative l-2 error; for an (n, 1) column the matrix 2-norm equals the Euclidean norm\n",
    "    rel_err_swim[j] = np.linalg.norm(u_true - u_pred, 2) / np.linalg.norm(u_true, 2)\n",
    "    print(\"rmse_swim, re_swim\")\n",
    "    print(rmse_swim[j], rel_err_swim[j])\n",
    "\n",
    "# Summary statistics: mean time, RMSE (mean, std), relative l-2 error (mean, std)\n",
    "info.extend([\n",
    "    np.mean(time_swim),\n",
    "    np.mean(rmse_swim), np.std(rmse_swim),\n",
    "    np.mean(rel_err_swim), np.std(rel_err_swim),\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time-block whose basis functions are shown, and the line width of the basis curves\n",
    "plot_block_id = 2\n",
    "lw = 2\n",
    "\n",
    "# Start and end times of the selected block, assuming time_blocks equal-length blocks over t_span\n",
    "t_span = [0, np.max(t_eval)]\n",
    "t_block_size = (t_span[-1] - t_span[0]) / time_blocks\n",
    "plot_t_block = [plot_block_id * t_block_size, (plot_block_id + 1) * t_block_size]\n",
    "\n",
    "# Basis-function values on x_eval (one column per neuron) and, per neuron,\n",
    "# the indices of the two sample points used below for the slope estimate\n",
    "ansatz_plot = burgers_solver_swim.ansatz_collection[plot_block_id]\n",
    "plot_matrix_tanh = ansatz_plot.evaluate_model(x_eval)\n",
    "id_from = ansatz_plot._model[0].idx_from\n",
    "id_to = ansatz_plot._model[0].idx_to\n",
    "sample_points = burgers_solver_swim.domain.interior_points\n",
    "\n",
    "\n",
    "def evalutate_plot_value(point, plot_block_id):\n",
    "    \"\"\"Return the basis functions of time-block `plot_block_id` evaluated at `point`, flattened.\"\"\"\n",
    "    block_model = burgers_solver_swim.ansatz_collection[plot_block_id]\n",
    "    values = block_model.evaluate_model(point)\n",
    "    return values.ravel()\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(6, 4))\n",
    "\n",
    "# Row of u_swim whose time is closest to the start of the plotted block.\n",
    "# A nearest-neighbour lookup is used because an exact float comparison\n",
    "# (t_eval == plot_t_block[0]) can select no row at all after round-off.\n",
    "t_row = np.argmin(np.abs(t_eval.ravel() - plot_t_block[0]))\n",
    "\n",
    "ax[0].set_title(\"SWIM basis functions\")\n",
    "ax[0].plot(x_eval, u_swim[t_row, :].ravel(), '-',\n",
    "           color=\"black\", linewidth=3, label=\"Burgers solution\")\n",
    "\n",
    "# Hand-picked neurons of the plotted block. The loop variable is `k`\n",
    "# rather than `id` so the builtin id() is not shadowed in the notebook.\n",
    "id_list = [4, 121, 41, 29]\n",
    "for k in id_list:\n",
    "    x_from = sample_points[id_from[k]]\n",
    "    x_to = sample_points[id_to[k]]\n",
    "    f_from = evalutate_plot_value(x_from, plot_block_id)[k]\n",
    "    f_to = evalutate_plot_value(x_to, plot_block_id)[k]\n",
    "    # Secant slope of basis function k between its two sample points; .item()\n",
    "    # turns the 1-element result into a Python float for formatting.\n",
    "    slope = ((f_to - f_from) / (x_to - x_from)).item()\n",
    "    ax[0].plot(x_eval, plot_matrix_tanh[:, k],\n",
    "               label=r\"Gradient $\\approx$\" + f\"{slope:.2f}\",\n",
    "               linestyle='--', linewidth=lw)\n",
    "\n",
    "# Exploration aid (presumably how id_list was chosen): neurons whose value v at the\n",
    "# grid point nearest x = 0.1 satisfies 0.01 < 1 - |v| < 0.1\n",
    "x_row = np.argmin(np.abs(x_eval.ravel() - 0.1))\n",
    "gap = 1 - np.abs(plot_matrix_tanh[x_row, :].astype(np.float64))\n",
    "for i in np.flatnonzero((gap > 0.01) & (gap < 0.1)):\n",
    "    print(\"value =\", plot_matrix_tanh[x_row, i], \", id =\", i)\n",
    "\n",
    "\n",
    "# Reference solution at the start of the block chosen in the setup above\n",
    "# (nearest time row, since an exact float == match may select nothing)\n",
    "t_row = np.argmin(np.abs(t_eval.ravel() - plot_t_block[0]))\n",
    "ax[1].set_title(\"ELM basis functions\")\n",
    "ax[1].plot(x_eval, u_swim[t_row, :].ravel(), '-',\n",
    "           color=\"black\", linewidth=3, label=\"Burgers solution\")\n",
    "\n",
    "# Switch to the first time-block; its sample points are x_space[1:-1],\n",
    "# the interior points the domain was built from\n",
    "plot_block_id = 0\n",
    "plot_matrix_tanh = burgers_solver_swim.ansatz_collection[plot_block_id].evaluate_model(x_eval)\n",
    "id_from = burgers_solver_swim.ansatz_collection[plot_block_id]._model[0].idx_from\n",
    "id_to = burgers_solver_swim.ansatz_collection[plot_block_id]._model[0].idx_to\n",
    "sample_points = x_space[1:-1]\n",
    "\n",
    "id_list = [0, 2, 4, 5]\n",
    "for k in id_list:  # `k`, not `id`, so the builtin id() is not shadowed\n",
    "    x_from = sample_points[id_from[k]]\n",
    "    x_to = sample_points[id_to[k]]\n",
    "    f_from = evalutate_plot_value(x_from, plot_block_id)[k]\n",
    "    f_to = evalutate_plot_value(x_to, plot_block_id)[k]\n",
    "    # Secant slope (f_to - f_from) / (x_to - x_from) as a Python float\n",
    "    slope = ((f_to - f_from) / (x_to - x_from)).item()\n",
    "    ax[1].plot(x_eval, plot_matrix_tanh[:, k],\n",
    "               label=r\"Gradient $\\approx$\" + f\"{slope:.2f}\",\n",
    "               linestyle='--', linewidth=lw)\n",
    "\n",
    "for axis in ax:\n",
    "    axis.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.15))  # legend below each panel\n",
    "fig.tight_layout()\n",
    "fig.savefig('burgers_basis_functions.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot settings: time-block to inspect, line width and font size\n",
    "plot_block_id, lw, fontsize = 2, 2, 14\n",
    "t_span = [0, np.max(t_eval)]\n",
    "t_block_size = (t_span[-1] - t_span[0]) / time_blocks\n",
    "plot_t_block = [plot_block_id * t_block_size, (plot_block_id + 1) * t_block_size]\n",
    "\n",
    "# Basis-function matrix on x_eval and the sample-point index pairs of the selected block\n",
    "selected_model = burgers_solver_swim.ansatz_collection[plot_block_id]\n",
    "plot_matrix_tanh = selected_model.evaluate_model(x_eval)\n",
    "id_from, id_to = selected_model._model[0].idx_from, selected_model._model[0].idx_to\n",
    "sample_points = burgers_solver_swim.domain.interior_points\n",
    "\n",
    "\n",
    "def evalutate_plot_value(point, plot_block_id):\n",
    "    \"\"\"Values of all basis functions of block `plot_block_id` at `point`, as a 1-D array.\"\"\"\n",
    "    values = burgers_solver_swim.ansatz_collection[plot_block_id].evaluate_model(point)\n",
    "    return values.ravel()\n",
    "\n",
    "\n",
    "# --- First Figure: SWIM basis functions ---\n",
    "fig1, ax1 = plt.subplots(figsize=(3, 2))\n",
    "\n",
    "# Nearest row of u_swim to the block start; an exact float == match may select nothing\n",
    "t_row = np.argmin(np.abs(t_eval.ravel() - plot_t_block[0]))\n",
    "\n",
    "ax1.set_title(\"SWIM basis functions\")\n",
    "ax1.plot(x_eval, u_swim[t_row, :].ravel(), '-',\n",
    "         color=\"black\", linewidth=3, label=\"Burgers solution\")\n",
    "\n",
    "id_list = [4, 121, 41, 29]\n",
    "for k in id_list:  # `k`, not `id`, so the builtin id() is not shadowed\n",
    "    x_from = sample_points[id_from[k]]\n",
    "    x_to = sample_points[id_to[k]]\n",
    "    f_from = evalutate_plot_value(x_from, plot_block_id)[k]\n",
    "    f_to = evalutate_plot_value(x_to, plot_block_id)[k]\n",
    "    # Secant slope of basis function k between its two sample points, as a Python float\n",
    "    slope = ((f_to - f_from) / (x_to - x_from)).item()\n",
    "    ax1.plot(x_eval, plot_matrix_tanh[:, k],\n",
    "             label=rf\"Gradient ≈ {slope:.2f}\",\n",
    "             linestyle='--', linewidth=lw)\n",
    "\n",
    "# Exploration aid: neurons whose value v at the grid point nearest x = 0.1 satisfies 0.01 < 1 - |v| < 0.1\n",
    "x_row = np.argmin(np.abs(x_eval.ravel() - 0.1))\n",
    "gap = 1 - np.abs(plot_matrix_tanh[x_row, :].astype(np.float64))\n",
    "for i in np.flatnonzero((gap > 0.01) & (gap < 0.1)):\n",
    "    print(\"value =\", plot_matrix_tanh[x_row, i], \", id =\", i)\n",
    "ax1.tick_params(axis='both', labelsize=fontsize)\n",
    "ax1.set_xlabel('x', fontsize=fontsize)\n",
    "ax1.set_ylabel('y', fontsize=fontsize)\n",
    "# Legend intentionally omitted in this compact figure; call ax1.legend() to show the slopes.\n",
    "fig1.tight_layout()\n",
    "fig1.savefig('swim_basis_functions.pdf')\n",
    "\n",
    "# --- Second Figure: ELM basis functions ---\n",
    "fig2, ax2 = plt.subplots(figsize=(3, 2))\n",
    "\n",
    "# Reference solution at the start of the block chosen in the setup above\n",
    "# (nearest time row, since an exact float == match may select nothing)\n",
    "t_row = np.argmin(np.abs(t_eval.ravel() - plot_t_block[0]))\n",
    "ax2.set_title(\"ELM basis functions\")\n",
    "ax2.plot(x_eval, u_swim[t_row, :].ravel(), '-',\n",
    "         color=\"black\", linewidth=3, label=\"Burgers solution\")\n",
    "\n",
    "# Switch to the first time-block; its sample points are x_space[1:-1],\n",
    "# the interior points the domain was built from\n",
    "plot_block_id = 0\n",
    "plot_matrix_tanh = burgers_solver_swim.ansatz_collection[plot_block_id].evaluate_model(x_eval)\n",
    "id_from = burgers_solver_swim.ansatz_collection[plot_block_id]._model[0].idx_from\n",
    "id_to = burgers_solver_swim.ansatz_collection[plot_block_id]._model[0].idx_to\n",
    "sample_points = x_space[1:-1]\n",
    "\n",
    "id_list = [0, 2, 4, 5]\n",
    "for k in id_list:  # `k`, not `id`, so the builtin id() is not shadowed\n",
    "    x_from = sample_points[id_from[k]]\n",
    "    x_to = sample_points[id_to[k]]\n",
    "    f_from = evalutate_plot_value(x_from, plot_block_id)[k]\n",
    "    f_to = evalutate_plot_value(x_to, plot_block_id)[k]\n",
    "    # Secant slope (f_to - f_from) / (x_to - x_from) as a Python float\n",
    "    slope = ((f_to - f_from) / (x_to - x_from)).item()\n",
    "    ax2.plot(x_eval, plot_matrix_tanh[:, k],\n",
    "             label=rf\"Gradient ≈ {slope:.2f}\",\n",
    "             linestyle='--', linewidth=lw)\n",
    "ax2.tick_params(axis='both', labelsize=fontsize)\n",
    "ax2.set_xlabel('x', fontsize=fontsize)\n",
    "ax2.set_ylabel('y', fontsize=fontsize)\n",
    "# Legend intentionally omitted in this compact figure; call ax2.legend() to show the slopes.\n",
    "fig2.tight_layout()\n",
    "fig2.savefig('elm_basis_functions.pdf')\n",
    "\n",
    "# --- Third Figure: target function only ---\n",
    "# Own fig3/ax3 names so the ELM figure handles (fig2/ax2) are not overwritten.\n",
    "fig3, ax3 = plt.subplots(figsize=(3, 2))\n",
    "\n",
    "# Nearest row of u_swim to the block start (robust to float round-off)\n",
    "t_row = np.argmin(np.abs(t_eval.ravel() - plot_t_block[0]))\n",
    "ax3.set_title(\"Target function\")\n",
    "ax3.plot(x_eval, u_swim[t_row, :].ravel(), '-',\n",
    "         color=\"black\", linewidth=3, label=\"Burgers solution\")\n",
    "\n",
    "ax3.tick_params(axis='both', labelsize=fontsize)\n",
    "ax3.set_xlabel('x', fontsize=fontsize)\n",
    "ax3.set_ylabel('y', fontsize=fontsize)\n",
    "fig3.tight_layout()\n",
    "fig3.savefig('data.pdf')\n",
    "\n",
    "# --- Fourth Figure: density of the standard normal distribution (mean 0, std 1) ---\n",
    "x = np.linspace(-4, 4, 500)\n",
    "pdf = norm.pdf(x, loc=0, scale=1)\n",
    "\n",
    "fig4, ax4 = plt.subplots(figsize=(3, 2))\n",
    "ax4.plot(x, pdf, label='Standard Normal Distribution', color='gray', linewidth=3)\n",
    "ax4.set_title('Gaussian (μ=0, σ=1)', fontsize=fontsize)\n",
    "ax4.set_xlabel('w', fontsize=fontsize)\n",
    "ax4.set_ylabel('P(w)', fontsize=fontsize)\n",
    "ax4.tick_params(axis='both', labelsize=fontsize)\n",
    "fig4.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
