{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyper-parameter search for Frozen-PINN-swim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import sys\n",
    "import time\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.io\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Make the local swimpde package importable before importing from it\n",
    "sys.path.append('../../../../')\n",
    "sys.path.append('../../../../src')\n",
    "from swimpde import Domain\n",
    "from swimpde import BasicAnsatz\n",
    "from swimpde import BurgersSolver\n",
    "\n",
    "plt.rcParams['image.cmap'] = 'jet'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data and reference solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the reference solution and its space/time grids from the .mat file\n",
    "data = scipy.io.loadmat('../../../../data/burgers_shock.mat')\n",
    "\n",
    "# Evaluation grids as column vectors\n",
    "t_eval = data['t'].flatten().reshape(-1, 1)\n",
    "x_eval = data['x'].flatten().reshape(-1, 1)\n",
    "\n",
    "# Reference solution, transposed so that it flattens in the same order as the (t, x) mesh below\n",
    "u_exact = np.real(data['usol']).T\n",
    "\n",
    "# All (x, t) pairs of the evaluation mesh, one pair per row\n",
    "X, T = np.meshgrid(x_eval, t_eval)\n",
    "X_ = np.column_stack((X.flatten(), T.flatten()))\n",
    "\n",
    "# Ground truth as a single column\n",
    "u_true = u_exact.flatten().reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def u0(x):\n",
    "    \"\"\"Initial condition: u(x, 0) = -sin(pi * x).\"\"\"\n",
    "    return -1 * np.sin(np.pi * x)\n",
    "\n",
    "\n",
    "def forcing(x, t):\n",
    "    \"\"\"Zero forcing term: returns one zero per row of `x`; `t` is not used.\"\"\"\n",
    "    return np.zeros(x.shape[0])\n",
    "\n",
    "\n",
    "# Boundary condition identifier passed to the solver\n",
    "boundary_condition = \"zero dirichlet\"\n",
    "\n",
    "# Domain information and spatial points for the first time-block\n",
    "n_points_1d = 4000  # number of randomly drawn spatial points\n",
    "x_lim = [-1, 1]\n",
    "rng = np.random.default_rng(seed=123)\n",
    "x_space = rng.uniform(x_lim[0], x_lim[1], n_points_1d).reshape((-1, 1))\n",
    "# NOTE(review): x_space is an unsorted random sample, so this slice only drops the\n",
    "# first and last draws (not the points closest to the boundary). Kept as-is so the\n",
    "# set of interior points, and therefore the results, stay unchanged.\n",
    "x_space_inner = x_space[1:-1]\n",
    "interior_points = x_space_inner\n",
    "\n",
    "# The two end points of the interval, stacked into a (2, 1) column.\n",
    "# np.vstack replaces np.row_stack, which is a deprecated alias in NumPy 2.0.\n",
    "left = x_lim[0]\n",
    "right = x_lim[1]\n",
    "boundary_points = np.vstack([left, right])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and evaluate Frozen-PINN-swim for different hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter grid\n",
    "n_sample = [3000, 1000]  # no. of sampling points (for computing gradients)\n",
    "n_col = [2400, 1600, 800, 500]  # no. of collocation points (re-sampled)\n",
    "ratios = [10, 4, 3, 2, 1.5]  # n_col / width; the width is int(n_col / ratio)\n",
    "grad_percent = [0.01]  # fraction of the max gradient added to the sampling distribution\n",
    "reg_consts = [1e-6, 1e-8, 1e-10, 1e-12, 1e-14]  # regularization scale, also used as SVD cutoff\n",
    "seeds = [1, 2, 3]\n",
    "time_blocks = [1, 3]\n",
    "tols = [1e-4, 1e-6, 1e-8]  # used for both rtol and atol\n",
    "\n",
    "# Names of the entries of each row appended to `experiments`, in order\n",
    "result_columns = [\n",
    "    \"t_b\", \"tol\", \"n_s\", \"n_c\", \"width\", \"g_p\", \"r_c\", \"svd_cutoff\",\n",
    "    \"time_swim\", \"rmse_swim\", \"re_swim\",\n",
    "]\n",
    "experiments = []\n",
    "\n",
    "\n",
    "def collocation_points_probabilities(df_dx, grad_fraction=0.01):\n",
    "    \"\"\"Return a probability distribution for (re)-sampling collocation points.\n",
    "\n",
    "    The absolute values of `df_dx` are shifted by `grad_fraction` times their\n",
    "    maximum and then normalized so that they sum to one.\n",
    "    \"\"\"\n",
    "    gradients = np.abs(df_dx)\n",
    "    gradients = gradients + grad_fraction * np.max(gradients)\n",
    "    return gradients / np.sum(gradients)\n",
    "\n",
    "\n",
    "t_final = np.max(t_eval)\n",
    "u_true_norm = np.linalg.norm(u_true)\n",
    "\n",
    "# itertools.product iterates in the same order as the equivalent nested loops\n",
    "# (last argument varies fastest), so the draws from `rng` are unchanged.\n",
    "for t_b, n_s, n_c, r, g_p, r_c, tol in itertools.product(\n",
    "    time_blocks, n_sample, n_col, ratios, grad_percent, reg_consts, tols\n",
    "):\n",
    "    svd_co = r_c  # the SVD cutoff is tied to the regularization constant\n",
    "    width = int(n_c / r)\n",
    "\n",
    "    # Sampling distribution for this configuration's gradient fraction\n",
    "    resampling_probabilities = partial(collocation_points_probabilities, grad_fraction=g_p)\n",
    "\n",
    "    # Fresh sample points per configuration; this does not include boundary points\n",
    "    sample_test_points = np.sort(rng.uniform(x_lim[0] + 1e-4, x_lim[1] - 1e-4, n_s)).reshape((-1, 1))\n",
    "    domain = Domain(\n",
    "        interior_points=interior_points,\n",
    "        boundary_points=boundary_points,\n",
    "        sample_points=sample_test_points,\n",
    "    )\n",
    "\n",
    "    # Per-seed metrics; every entry is overwritten in the loop below\n",
    "    time_swim = np.empty(len(seeds))\n",
    "    rmse_swim = np.empty(len(seeds))\n",
    "    rel_err_swim = np.empty(len(seeds))\n",
    "\n",
    "    for j, seed in enumerate(seeds):\n",
    "        ansatz_swim = BasicAnsatz(\n",
    "            n_neurons=width,\n",
    "            activation=\"tanh\",\n",
    "            random_state=seed,\n",
    "            regularization_scale=r_c,\n",
    "            parameter_sampler=\"tanh\",\n",
    "        )\n",
    "        burgers_solver_swim = BurgersSolver(\n",
    "            domain=domain,\n",
    "            ansatz=ansatz_swim,\n",
    "            u0=u0,\n",
    "            boundary_condition=boundary_condition,\n",
    "            forcing=forcing,\n",
    "            regularization_scale=r_c,\n",
    "            c=(0.01 / np.pi),\n",
    "        )\n",
    "\n",
    "        # Fit; perf_counter is a monotonic clock meant for measuring elapsed time\n",
    "        t_swim_start = time.perf_counter()\n",
    "        sol_swim, solver_status_swim = burgers_solver_swim.fit_time_blocks(\n",
    "            t_span=[0, t_final],\n",
    "            rtol=tol,\n",
    "            atol=tol,\n",
    "            svd_cutoff=svd_co,\n",
    "            time_blocks=t_b,\n",
    "            prob_distr_resampling=resampling_probabilities,\n",
    "            n_col=n_c,\n",
    "            outer_basis=False,\n",
    "        )\n",
    "        time_swim[j] = time.perf_counter() - t_swim_start\n",
    "\n",
    "        # Evaluate on the reference grid and flatten like u_true\n",
    "        pred_swim = burgers_solver_swim.evaluate_blocks(\n",
    "            x_eval=x_eval, t_eval=t_eval, time_blocks=t_b, solver_status=solver_status_swim\n",
    "        )\n",
    "        u_swim = pred_swim.T.flatten()[:, None]\n",
    "\n",
    "        rmse_swim[j] = np.sqrt(mean_squared_error(u_true, u_swim))  # root mean squared error\n",
    "        # For single-column arrays the default norm equals the 2-norm, without an SVD\n",
    "        rel_err_swim[j] = np.linalg.norm(u_true - u_swim) / u_true_norm\n",
    "        print(\"rmse_swim, re_swim\")\n",
    "        print(rmse_swim[j], rel_err_swim[j])\n",
    "\n",
    "    # One row per configuration, with metrics averaged over the seeds.\n",
    "    # t_b and tol are recorded so that every row identifies its full configuration.\n",
    "    info = [\n",
    "        t_b, tol, n_s, n_c, width, g_p, r_c, svd_co,\n",
    "        np.mean(time_swim), np.mean(rmse_swim), np.mean(rel_err_swim),\n",
    "    ]\n",
    "    print(\", \".join(result_columns))\n",
    "    print(info)\n",
    "    experiments.append(info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
