{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Solving nonlinear diffusion equation using Frozen-PINN-swim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import sys\n",
    "sys.path.append('../../../')\n",
    "sys.path.append('../../')\n",
    "sys.path.append('../../../src/')\n",
    "from swimpde import Domain\n",
    "from swimpde import BasicAnsatz\n",
    "from swimpde import Diffusion_Solver\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.cm as cm\n",
    "import time\n",
    "cmap = cm.jet\n",
    "from examples.utils import *\n",
    "from utils import *\n",
    "from high_dim_diffusion.utils import *\n",
    "\n",
    "# Set seeds\n",
    "np.random.seed(2)\n",
    "rng = np.random.default_rng(seed=123)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup: sample training and test points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem dimension and sample sizes\n",
    "d = 100              # Spatial dimension\n",
    "n_b_train = 5_000    # Boundary points used for training\n",
    "n_b_test = 2_000     # Boundary points used for testing\n",
    "n_int_train = 5_000  # Interior points used for training\n",
    "n_int_test = 8_000   # Interior points used for testing (16000 noted as an alternative)\n",
    "\n",
    "# Boundary samples: training set is drawn first, then the test set\n",
    "X_b_train = sample_boundary_lhs_ball(d, n_b_train)\n",
    "X_b_test = sample_boundary_lhs_ball(d, n_b_test)\n",
    "\n",
    "# Interior samples: training set is drawn first, then the test set\n",
    "X_int_train = sample_interior_lhs_ball(d, n_int_train)\n",
    "X_int_test = sample_interior_lhs_ball(d, n_int_test)\n",
    "\n",
    "# The full test set stacks the interior test points on top of the boundary test points\n",
    "X_test = np.vstack([X_int_test, X_b_test])\n",
    "\n",
    "# Visualize the interior training points\n",
    "plot_interior_points(X_int_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup: initial condition, forcing, and analytical solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initial condition\n",
    "def u0(x):\n",
    "    \"\"\"Initial condition: sum of squared entries along axis 1, divided by 2 * x.shape[1].\n",
    "\n",
    "    The reduced axis is kept as a length-1 axis (keepdims=True).\n",
    "    \"\"\"\n",
    "    n_dims = np.shape(x)[1]\n",
    "    sq_norm = np.sum(np.square(x), axis=1, keepdims=True)\n",
    "    return sq_norm / (2.0 * n_dims)\n",
    "\n",
    "# forcing\n",
    "def forcing(x, t):\n",
    "    \"\"\"Forcing term; identically zero, independent of x and t.\"\"\"\n",
    "    return 0.0\n",
    "\n",
    "# boundary condition\n",
    "boundary_condition = \"dirichlet\"\n",
    "\n",
    "# Analytical solution\n",
    "def analytical_sol(x, t):\n",
    "    \"\"\"Reference solution: t + sum(x**2, axis=1, keepdims=True) / (2 * x.shape[1]).\n",
    "\n",
    "    t is combined with the spatial term through NumPy broadcasting.\n",
    "    \"\"\"\n",
    "    n_dims = np.shape(x)[1]\n",
    "    spatial_part = np.sum(np.square(x), axis=1, keepdims=True) / (2.0 * n_dims)\n",
    "    return t + spatial_part\n",
    "    \n",
    "# Test data\n",
    "t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)  # evaluation times, reshaped to (100, 1, 1)\n",
    "x_train = X_int_train  # spatial training points\n",
    "\n",
    "# Evaluate the analytical solution on the test points and keep only its first two axes\n",
    "u_true = analytical_sol(X_test, t_eval)\n",
    "u_true = u_true.reshape(u_true.shape[:2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the true solution\n",
    "timesteps = [0, 30, 60, 99]\n",
    "plot(x=X_test, u=u_true, timesteps=timesteps, \n",
    "           title='Ground truth', cmap_offset=0,\n",
    "           savefig=True,figname='ground_truth.png',marker_size=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit and evaluate Frozen-PINN-swim with tuned hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "seeds = [1]  # alternative: [1, 2, 3]\n",
    "width = 100  # number of neurons (alternative: 600)\n",
    "reg_const = 1e-8  # regularization constant\n",
    "svd_cutoff = 1e-8\n",
    "rtol = 1e-3  # relative tolerance passed to fit (alternative: 1e-6)\n",
    "atol = 1e-3  # absolute tolerance passed to fit (alternative: 1e-6)\n",
    "svd_on = True\n",
    "\n",
    "# Per-seed results\n",
    "experiments = []\n",
    "info = []\n",
    "rmse_elm = np.ones((len(seeds), ))\n",
    "rel_err_elm = np.ones((len(seeds), ))\n",
    "time_elm = np.ones((len(seeds), ))\n",
    "\n",
    "\n",
    "def sample_parameters_randomly(x, _, rng):\n",
    "    \"\"\"Draw weights and biases uniformly from [-r_m, r_m] with r_m = 0.05.\n",
    "\n",
    "    NOTE: this sampler is not passed to BasicAnsatz below, which uses\n",
    "    parameter_sampler='tanh'; it is kept only as an alternative.\n",
    "\n",
    "    Returns (weights, biases, None, None) with weights of shape\n",
    "    (x.shape[1], width) and biases of shape (1, width).\n",
    "    \"\"\"\n",
    "    r_m = 0.05\n",
    "    weights = rng.uniform(low=-r_m, high=r_m, size=(x.shape[1], width))\n",
    "    biases = rng.uniform(low=-r_m, high=r_m, size=(1, width))\n",
    "    return weights, biases, None, None\n",
    "\n",
    "\n",
    "# Loop over different seeds\n",
    "for j, seed in enumerate(seeds):\n",
    "    # Reset the RNGs to the same fixed state on every iteration;\n",
    "    # only the ansatz random_state below varies with `seed`.\n",
    "    np.random.seed(2)\n",
    "    rng = np.random.default_rng(seed=123)\n",
    "\n",
    "    ansatz_elm = BasicAnsatz(\n",
    "        n_neurons=width,\n",
    "        activation=\"tanh\",\n",
    "        random_state=seed,\n",
    "        regularization_scale=reg_const,\n",
    "        parameter_sampler='tanh'\n",
    "    )\n",
    "\n",
    "    # Normal vectors are taken to be the boundary points themselves\n",
    "    # (presumably outward normals of an origin-centred unit ball -- TODO confirm\n",
    "    # against sample_boundary_lhs_ball).\n",
    "    normal_vectors = X_b_train.copy()\n",
    "\n",
    "    # Domain\n",
    "    domain = Domain(\n",
    "        interior_points=X_int_train,\n",
    "        boundary_points=X_b_train,\n",
    "        normal_vectors=normal_vectors\n",
    "    )\n",
    "\n",
    "    diffusion_solver_elm = Diffusion_Solver(\n",
    "        domain=domain,\n",
    "        ansatz=ansatz_elm,\n",
    "        u0=u0,\n",
    "        boundary_condition=boundary_condition,\n",
    "        forcing=forcing,\n",
    "        regularization_scale=reg_const,\n",
    "        scale_boundary_correction=10000.,\n",
    "        boundary_condition_true=analytical_sol,\n",
    "        ode_solver='LSODA'\n",
    "    )\n",
    "\n",
    "    # Fit the solver and time it; perf_counter is used because it is meant for\n",
    "    # measuring short durations and cannot go backwards.\n",
    "    ic_eval = u0(domain.interior_points)\n",
    "    t_elm_start = time.perf_counter()\n",
    "    sol_elm, solver_status_elm = diffusion_solver_elm.fit(\n",
    "        t_span=[0, np.max(t_eval)],\n",
    "        rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,\n",
    "        outer_basis=False,\n",
    "        init_cond=ic_eval,\n",
    "        svd_on=svd_on)\n",
    "    t_elm_stop = time.perf_counter()\n",
    "    time_elm[j] = t_elm_stop - t_elm_start\n",
    "\n",
    "    # Evaluate on test data; the transpose matches the layout of u_true\n",
    "    u_elm_test = diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval, svd_on=svd_on).T\n",
    "\n",
    "    # Compute metrics once: RMSE, and RMSE relative to the RMS of the true solution\n",
    "    rmse_elm[j] = np.sqrt(mean_squared_error(u_true, u_elm_test))\n",
    "    rel_err_elm[j] = rmse_elm[j]/np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))\n",
    "    info.append(time_elm[j])\n",
    "    info.append(rmse_elm[j])\n",
    "    print('time=', time_elm[j], 'rmse_elm=', rmse_elm[j], 'rel_err_elm=',rel_err_elm[j])\n",
    "\n",
    "print('Frozen-PINN-swim time = ', np.mean(time_elm))\n",
    "print('rmse Frozen-PINN-swim = ',np.mean(rmse_elm), '+-', np.std(rmse_elm))\n",
    "print('rel l-2 error Frozen-PINN-swim = ',np.mean(rel_err_elm), '+-', np.std(rel_err_elm))\n",
    "experiments.append(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('width of the output layer = ', np.shape(diffusion_solver_elm._get_c0(outer_basis=False).reshape(-1))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the Frozen-PINN-swim solution\n",
    "plot(x=X_test, u=u_elm_test, timesteps=timesteps, \n",
    "           title='Frozen-PINN-swim solution',cmap_offset=0.,marker_size=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the absolute error of the Frozen-PINN-swim solution\n",
    "plot_error(x=X_test, u_true=u_true, u_nn=u_elm_test, timesteps=timesteps, \n",
    "           figsize=(7,3), fontsize=14, \n",
    "           title='Absolute error: Frozen-PINN-swim',savefig=True, \n",
    "           figname='swim_ode_error.png',marker_size=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solution comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_2d_planes(X_test, analytical_sol, diffusion_solver_elm)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_true_solution(X_test, analytical_sol, diffusion_solver_elm, planes=[0, 99])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define required inputs\n",
    "# coordinates = [1, 3, 4]  # Indices to vary\n",
    "# plot_coordinate_comparisons(X_test, analytical_sol, diffusion_solver_elm, coordinates)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
