{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c0ad37",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6aaa966",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281ef3ad",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "import pickle as pkl\n",
    "from collections.abc import MutableMapping\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"DDE_BACKEND\"] = \"jax\"\n",
    "\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".XX\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"\n",
    "\n",
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "# config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "import optax\n",
    "\n",
    "try:\n",
    "    print(f'Jax: CPUs={jax.local_device_count(\"cpu\")} - GPUs={jax.local_device_count(\"gpu\")}')\n",
    "except:\n",
    "    pass\n",
    "    \n",
    "import deepxde_al_patch.deepxde as dde\n",
    "\n",
    "from deepxde_al_patch.model_loader import construct_model\n",
    "from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop\n",
    "from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction\n",
    "from deepxde_al_patch.train_set_loader import load_data\n",
    "\n",
    "from deepxde_al_patch.ntk import NTKHelper\n",
    "from deepxde_al_patch.utils import get_pde_residue, print_dict_structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e52987",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6b82e404",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552cfecb",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "inverse_problem = False\n",
    "\n",
    "model, model_aux = construct_model(\n",
    "    \n",
    "    #     # load data - without pdebench\n",
    "    pde_name='conv-1d', \n",
    "    data_seed=40,\n",
    "    pde_const=(1.,), \n",
    "    use_pdebench=True,\n",
    "#     inverse_problem=inverse_problem, \n",
    "#     inverse_problem_guess=(0.8,),\n",
    "    num_domain=2000, \n",
    "    num_boundary=500, \n",
    "    num_initial=500,\n",
    "    include_ic=(not inverse_problem),\n",
    "    data_root='~/pdebench',\n",
    "    test_max_pts=50000,\n",
    "    \n",
    "#     #     # load data - without pdebench\n",
    "#     pde_name='burgers-1d', \n",
    "#     data_seed=20,\n",
    "#     pde_const=(0.02,), \n",
    "#     use_pdebench=True,\n",
    "#     inverse_problem=inverse_problem, \n",
    "#     inverse_problem_guess=(0.01,),\n",
    "#     num_domain=2000, \n",
    "#     num_boundary=500, \n",
    "#     num_initial=500,\n",
    "#     include_ic=True,\n",
    "#     data_root='~/pdebench',\n",
    "#     test_max_pts=50000,\n",
    "    \n",
    "#     # load data - without pdebench\n",
    "#     pde_name='fd-2d', \n",
    "#     pde_const=(1.0, 0.01), \n",
    "#     use_pdebench=False,\n",
    "#     inverse_problem=inverse_problem, \n",
    "#     inverse_problem_guess=(0., 0.),\n",
    "#     num_domain=2000, \n",
    "#     num_boundary=500, \n",
    "#     num_initial=500,\n",
    "#     include_ic=False,\n",
    "#     test_max_pts=50000,\n",
    "    # model params\n",
    "    hidden_layers=6, \n",
    "    hidden_dim=32, \n",
    "    activation='tanh', \n",
    "    initializer='Glorot uniform', \n",
    "\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd033c6d",
   "metadata": {},
   "source": [
    "### Experiments area\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb37bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "method = 'random'\n",
    "optim = 'adam'\n",
    "\n",
    "\n",
    "if method == 'random':\n",
    "    al_args = dict(\n",
    "        method='pseudo',\n",
    "        res_proportion=0.8,\n",
    "    )\n",
    "\n",
    "elif method == 'residue':\n",
    "    al_args = dict(\n",
    "        res_proportion=0.8,\n",
    "        select_icbc_with_residue=False,\n",
    "        select_anc_with_residue=False,\n",
    "    )\n",
    "    \n",
    "elif method.startswith('eig'):\n",
    "    al_args = dict(\n",
    "        num_points_round=200,\n",
    "        weight_method= \"alignment\", # possible options are 'none', 'labels', 'eigvals', 'labels_train\n",
    "        num_candidates_res=800, #300\n",
    "        num_candidates_bcs=200,\n",
    "        num_candidates_init=200,\n",
    "        memory = True, # True to remember old points and add on new ones\n",
    "        use_init_train_pts=False,\n",
    "        sampling = 'pseudo', # uniform, pseudo\n",
    "        min_num_points_bcs=1,\n",
    "        min_num_points_res=1,\n",
    "        use_anc_in_train=False,\n",
    "#         points_pool_size=5,\n",
    "    )\n",
    "    \n",
    "# elif method == 'gd':\n",
    "#     al_args = dict(\n",
    "#         points_pool_size=2000,\n",
    "#         num_points_round=1000,\n",
    "#         eig_min=0.001, #1e-2\n",
    "#         lr=1e-2,\n",
    "#         train_steps=1000, #1000\n",
    "#         indicator='span',\n",
    "#         compare_mode=True,\n",
    "#         crit='fr', \n",
    "#         active_eig=50,\n",
    "#         eps=1e-4,\n",
    "#         dist_reg=0.,\n",
    "#     )\n",
    "\n",
    "else:\n",
    "    raise ValueError\n",
    "    \n",
    "    \n",
    "if optim == 'lbfgs':\n",
    "    optim_args = dict(\n",
    "        train_steps=1000,\n",
    "        al_every=20,\n",
    "        select_anchors_every=20,\n",
    "        snapshot_every=20,\n",
    "        optim_method='lbfgs', \n",
    "        optim_lr=1e-3, \n",
    "        optim_args=dict(),\n",
    "    )\n",
    "\n",
    "elif optim == 'adam':\n",
    "    optim_args = dict(\n",
    "        train_steps=15000,\n",
    "        al_every=1000,\n",
    "        select_anchors_every=1000,\n",
    "        snapshot_every=1000,\n",
    "        optim_method='adam', \n",
    "        optim_lr=1e-3, \n",
    "        optim_args=dict(),\n",
    "    )\n",
    "    \n",
    "\n",
    "train_loop = ModifiedTrainLoop(\n",
    "    model=model, \n",
    "    inverse_problem=inverse_problem,\n",
    "    point_selector_method=method,\n",
    "    point_selector_args=al_args,\n",
    "    mem_pts_total_budget=500,\n",
    "    anchor_budget=1,\n",
    "    autoscale_loss_w_bcs=True,\n",
    "    ntk_ratio_threshold=None,\n",
    "    anc_measurable_idx=0,\n",
    "#     anc_measurable_idx=jnp.array([0, 1]),\n",
    "    tensorboard_plots=False,\n",
    "#     log_dir=f'./scrap/{datetime.now().strftime(\"%Y%m%d%H%M%S\")}',\n",
    "    **optim_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49858c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.plot(model.data.test_x, model.data.test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c03e60",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plot_prediction(train_loop, res=200, out_idx=0);\n",
    "# plot_prediction(train_loop, res=200, out_idx=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1606f540",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "train_loop.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ff67a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, _ = train_loop.plot_training_data()\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cfcd1d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# train_loop.pde_residue_fn(train_loop.current_params, train_loop.x_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aedf9eb",
   "metadata": {},
   "source": [
    "### Visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f6d39c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = (8,6)\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.size': 12,\n",
    "    'text.usetex': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c9582ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualisation\n",
    "\n",
    "# train_loop.plot_training_data(step_idx=0)\n",
    "train_loop.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32a68f6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plt.plot(model.data.test_x, train_loop.solution_prediction(model.data.test_x));\n",
    "# train_loop.current_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a77e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = [300]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67209e85",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_prediction(train_loop=train_loop, step_idxs=steps, plot_training_data=False, out_idx=0);\n",
    "plot_prediction(train_loop=train_loop, step_idxs=steps, plot_training_data=False, out_idx=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cee4bf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_error(train_loop=train_loop, step_idxs=steps, plot_training_data=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe7a32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_residue_loss(train_loop=train_loop, step_idxs=steps, plot_training_data=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67e24008",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
