{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "726eae1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auto-reload edited imported modules (e.g. deepxde_al_patch) and render figures inline.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9452c693",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "import pickle as pkl\n",
    "from collections.abc import MutableMapping\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as tri\n",
    "import matplotlib.ticker as ticker\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"DDE_BACKEND\"] = \"jax\"\n",
    "\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".XX\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"\n",
    "\n",
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "# config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "import optax\n",
    "\n",
    "try:\n",
    "    print(f'Jax: CPUs={jax.local_device_count(\"cpu\")} - GPUs={jax.local_device_count(\"gpu\")}')\n",
    "except Exception as e:\n",
    "    # Informational only: with CUDA_VISIBLE_DEVICES=\"\" the GPU backend may be unavailable and the\n",
    "    # query raises. Report it instead of a bare except that would also swallow KeyboardInterrupt.\n",
    "    print(f'Jax device query failed ({type(e).__name__}): {e}')\n",
    "    \n",
    "import deepxde_al_patch.deepxde as dde\n",
    "\n",
    "from deepxde_al_patch.model_loader import construct_model, construct_net\n",
    "from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop\n",
    "from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction\n",
    "from deepxde_al_patch.train_set_loader import load_data\n",
    "\n",
    "from deepxde_al_patch.ntk import NTKHelper\n",
    "from deepxde_al_patch.utils import get_pde_residue, print_dict_structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3eec5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide matplotlib defaults: resolution, font sizes, and no LaTeX text rendering.\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 80,\n",
    "    'font.size': 18,\n",
    "    'figure.titlesize': 24,\n",
    "    'text.usetex': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5213bab1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b25c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the main comparison graphs (same value as graph_root; not referenced elsewhere in this notebook).\n",
    "main_graph = 'al_pinn_graphs_final/main'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3a7825",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for cross-algorithm comparison plots; not referenced by the eigen-plot cells below.\n",
    "graph_root = 'al_pinn_graphs_final/main'\n",
    "# NOTE(review): presumably the maximum number of runs per algorithm to aggregate - TODO confirm.\n",
    "max_runs = 10\n",
    "# Result sub-folder name of each algorithm -> (legend label, matplotlib line-style kwargs).\n",
    "algs = {\n",
    "    'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls=':', marker='p')),\n",
    "    'random_Hammersley_prop-0.8': ('Hammersley', dict(c='grey', ls=':', marker='h')),\n",
    "    'residue_prop-0.8': ('RAD', dict(c='red', ls='--', marker='v')),\n",
    "    'residue_prop-0.8_alltype': ('RAD-All', dict(c='orange', ls='--', marker='^')),\n",
    "    'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (ours)', dict(c='green', ls='-', marker='s')),\n",
    "    'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (ours)', dict(c='blue', ls='-', marker='o')),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95b9ac1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative prefix of the results/graphs folders; the paths in later cells hard-code the same '../../' prefix.\n",
    "data_folder = '../../'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa39778",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb4cccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def contour_on_ax(ax, xs, zs, levels, res=200, rm_axis=False):\n",
    "    \"\"\"Draw a filled contour of scattered values on `ax` and return the contour set.\n",
    "\n",
    "    xs holds the point coordinates in its first two columns and zs one value per point.\n",
    "    The values are linearly interpolated over a triangulation of xs onto a res x res\n",
    "    regular grid spanning the bounding box of xs, then drawn with `levels` and the\n",
    "    RdBu_r colour map. rm_axis=True hides the tick labels; otherwise major ticks are\n",
    "    spaced 1 apart when an axis spans more than 1.0, else 0.5 apart.\n",
    "    \"\"\"\n",
    "    xi, yi = [np.linspace(np.min(xs[:, i]), np.max(xs[:, i]), res) for i in range(2)]\n",
    "    triang = tri.Triangulation(xs[:, 0], xs[:, 1])\n",
    "    interpolator = tri.LinearTriInterpolator(triang, zs)\n",
    "    Xi, Yi = np.meshgrid(xi, yi)\n",
    "    zi = interpolator(Xi, Yi)\n",
    "    cb = ax.contourf(xi, yi, zi, levels=levels, cmap=\"RdBu_r\")\n",
    "    if rm_axis:\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "    else:\n",
    "        for axis, col in ((ax.xaxis, 0), (ax.yaxis, 1)):\n",
    "            span = np.max(xs[:, col]) - np.min(xs[:, col])\n",
    "            axis.set_major_locator(ticker.MultipleLocator(1 if span > 1.0 else 0.5))\n",
    "    return cb\n",
    "\n",
    "\n",
    "def _colour_limits(values, sym_colour=False, ptile=False, p_d=1):\n",
    "    \"\"\"Return (min_, max_) colour limits shared by all panels of one figure row.\n",
    "\n",
    "    values is a list of 1-D arrays (one per panel). With ptile=True the limits are the\n",
    "    p_d-th and (100 - p_d)-th percentiles instead of the extremes. With sym_colour=True\n",
    "    and limits straddling zero they are made symmetric about zero. A degenerate range\n",
    "    (constant data) is padded so np.linspace yields strictly increasing levels, which\n",
    "    contourf requires.\n",
    "    \"\"\"\n",
    "    if ptile:\n",
    "        min_ = np.percentile(values, p_d)\n",
    "        max_ = np.percentile(values, 100 - p_d)\n",
    "    else:\n",
    "        min_ = np.min(values)\n",
    "        max_ = np.max(values)\n",
    "    if sym_colour and (min_ < 0 < max_):\n",
    "        m = max(-min_, max_)\n",
    "        min_, max_ = -m, m\n",
    "    if not max_ > min_:\n",
    "        pad = 1e-6 * max(abs(min_), 1.0)\n",
    "        min_, max_ = min_ - pad, max_ + pad\n",
    "    return min_, max_\n",
    "\n",
    "\n",
    "def plot_contours(xs, ys_list, titles, res=200, sym_colour=False, ptile=False, cbar=True):\n",
    "    \"\"\"Grid of contour plots: one column per array in ys_list, one row per output column.\n",
    "\n",
    "    Each element of ys_list is an (n_points, nrows) array of values at the points xs;\n",
    "    column i of every element is drawn in row i, and all panels of a row share one\n",
    "    colour scale (see _colour_limits). With ptile=True values are clipped to that scale.\n",
    "    titles label the top-row panels. Returns (fig, axs): axs is a flat list of Axes\n",
    "    when nrows == 1 and a 2-D array of Axes otherwise.\n",
    "    \"\"\"\n",
    "    nrows = ys_list[0].shape[1]\n",
    "    ncols = len(ys_list)\n",
    "    fig, axs = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=ncols,\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        # one extra column width for the colour bar; never let the width reach 0\n",
    "        figsize=(3 * max(ncols + (1 if cbar else -1), 1), 3 * nrows + 2),\n",
    "        constrained_layout=True,\n",
    "        # always 2-D, so single-row and single-column layouts index the same way\n",
    "        squeeze=False,\n",
    "    )\n",
    "\n",
    "    for i in range(nrows):\n",
    "        row_vals = [y[:, i] for y in ys_list]\n",
    "        min_, max_ = _colour_limits(row_vals, sym_colour=sym_colour, ptile=ptile)\n",
    "        if ptile:\n",
    "            # clip just inside [min_, max_] so no value falls outside the contour levels\n",
    "            row_vals = [np.clip(y, min_ + 1e-9, max_ - 1e-9) for y in row_vals]\n",
    "        levels = np.linspace(min_, max_, num=res)\n",
    "        for ax, zs in zip(axs[i], row_vals):\n",
    "            cb = contour_on_ax(ax, xs, zs, levels, res, rm_axis=not cbar)\n",
    "        if cbar:\n",
    "            fig.colorbar(cb, ax=axs[i])\n",
    "    for ax, title in zip(axs[0], titles):\n",
    "        ax.set_title(title)\n",
    "\n",
    "    if nrows == 1:\n",
    "        # single-row figures keep returning a flat list of Axes\n",
    "        return fig, axs.ravel().tolist()\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7216fdf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_training_data(ax, samples):\n",
    "    \"\"\"Overlay the training points in `samples` on `ax`.\n",
    "\n",
    "    Residual points ('res') are black circles, optional anchor points ('anc') blue\n",
    "    triangles, and the j-th point set of samples['bcs'] squares in colour C{j+1}.\n",
    "    Markers are drawn above the contours (zorder 10) and are not clipped at the axes.\n",
    "    \"\"\"\n",
    "    def scatter(pts, marker, colour):\n",
    "        ax.plot(pts[:, 0], pts[:, 1], marker, color=colour,\n",
    "                ms=4., alpha=0.95, zorder=10, clip_on=False)\n",
    "\n",
    "    scatter(samples['res'], 'o', 'black')\n",
    "    if 'anc' in samples.keys():\n",
    "        scatter(samples['anc'], '^', 'blue')\n",
    "    for colour_idx, bc_pts in enumerate(samples['bcs'], start=1):\n",
    "        scatter(bc_pts, 's', f'C{colour_idx}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b4419f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574e5e78",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8407376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot folder of one kmeans_alignment run on conv-1d, and the output folder for its eigen-plots.\n",
    "example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-80_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/kmeans_alignment_scale-none_mem_autoal/20230914101511'\n",
    "eigplot_folder = '../../al_pinn_graphs_final/eigplots/conv-80'\n",
    "\n",
    "os.makedirs(eigplot_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e81f6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "# Title of the first (non-decomposed) column for each decomposed signal.\n",
    "first_col_titles = {\n",
    "    'ys_true': 'True solution',\n",
    "    'ys_pred': 'NN prediction',\n",
    "    'ys_res': 'Prediction - truth',\n",
    "}\n",
    "\n",
    "# For each checkpoint: eigen-decompose the NTK of the trained net on a grid and plot how\n",
    "# partial sums over the top eigenvectors reconstruct the truth, the prediction and the error.\n",
    "for s in [10000, 100000]:\n",
    "\n",
    "    steps_range = [s]\n",
    "    step_idx = 0\n",
    "\n",
    "    print('plots for step', s)\n",
    "\n",
    "    # Free every live device buffer left from earlier iterations/cells before rebuilding.\n",
    "    for buf in jax.devices()[0].client.live_buffers():\n",
    "        buf.delete()\n",
    "\n",
    "    model, model_aux = construct_model(\n",
    "\n",
    "        pde_name='conv-1d', \n",
    "        data_seed=40,\n",
    "        pde_const=(1.0,), \n",
    "        use_pdebench=True,\n",
    "        test_max_pts=50000,\n",
    "        include_ic=True,\n",
    "        data_root='~/pdebench',\n",
    "\n",
    "        # model params\n",
    "        hidden_layers=8, \n",
    "        hidden_dim=128, \n",
    "        activation='tanh', \n",
    "        initializer='Glorot uniform', \n",
    "        arch=None, \n",
    "\n",
    "    )\n",
    "\n",
    "    # Merge all snapshot files of the run (keys: training steps; None: test set), in sorted\n",
    "    # order so overlapping keys resolve deterministically.\n",
    "    # NOTE: pickle.load can execute arbitrary code - only load trusted result files.\n",
    "    d = dict()\n",
    "    for file in sorted(os.listdir(example_folder)):\n",
    "        if file.startswith('snapshot_data'):\n",
    "            with open(f'{example_folder}/{file}', 'rb') as f:\n",
    "                d.update(pkl.load(f))\n",
    "\n",
    "    x_test = d[None]['x_test']\n",
    "    y_test = d[None]['y_test']\n",
    "\n",
    "    d_modified = {\n",
    "        'x_test': x_test,\n",
    "        'y_test': y_test,\n",
    "        'steps': steps_range,\n",
    "        'res_mean': [d[k]['residue_test_mean'] for k in steps_range],\n",
    "        'err_mean': [d[k]['error_test_mean'] for k in steps_range],\n",
    "        'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],\n",
    "        'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],\n",
    "        'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],\n",
    "        'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],\n",
    "        'res': [d[k]['residue_test'] for k in steps_range],\n",
    "        'err': [d[k]['error_test'] for k in steps_range],\n",
    "        'pred': [d[k]['pred_test'] for k in steps_range],\n",
    "        'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],\n",
    "        'inv': [d[k]['params'][1] for k in steps_range],\n",
    "        'params': [d[k]['params'][0] for k in steps_range],\n",
    "    }\n",
    "\n",
    "    ntk = NTKHelper(model)\n",
    "\n",
    "    # res x res regular grid snapped to the nearest test points, so ground truth is known there.\n",
    "    res = 80\n",
    "    xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]\n",
    "    grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T\n",
    "    grid_idxs = np.argmin(cdist(grid, x_test), axis=1)\n",
    "    grid = x_test[grid_idxs]\n",
    "    grid_ans = y_test[grid_idxs]\n",
    "\n",
    "    jac_I = ntk.get_jac(grid, code=-2, params=d_modified['params'][step_idx])\n",
    "    jac_N = ntk.get_jac(grid, code=-1, params=d_modified['params'][step_idx])\n",
    "\n",
    "    # Joint kernel over both blocks; the lower-left block is T_in.T so T stays symmetric for eigh.\n",
    "    T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)\n",
    "    T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)\n",
    "    T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)\n",
    "\n",
    "    T = np.block([[T_ii, T_in], [T_in.T, T_nn]])\n",
    "    T = T + 1e-9 * np.eye(T.shape[0])\n",
    "\n",
    "    # eigh sorts ascending: flip to descending; each row of eigvects is one eigenvector.\n",
    "    eigvals, eigvects = np.linalg.eigh(T)\n",
    "    eigvals = eigvals[::-1] / (res**2)\n",
    "    eigvects = eigvects.T[::-1]\n",
    "\n",
    "    # Targets: true solution on the first block, zeros on the second (PDE-residual) block.\n",
    "    ans_flat = grid_ans.reshape(-1)\n",
    "    ys_true = np.concatenate([ans_flat, jnp.zeros_like(ans_flat)])\n",
    "\n",
    "    ys_ = model.net.apply(d_modified['params'][step_idx], grid)\n",
    "    ys_res = model.data.pde(grid, (ys_, lambda x: model.net.apply(d_modified['params'][step_idx], x)))[0]\n",
    "    ys_pred = np.concatenate([ys_.reshape(-1), ys_res.reshape(-1)])\n",
    "\n",
    "    ys_diff = ys_pred - ys_true\n",
    "\n",
    "    for k, eigvals_rank in enumerate([[10, 20, 50, 100, 200, 500, 1000], [10, 50, 500]]):\n",
    "\n",
    "        for ys, name in [\n",
    "            (ys_true, 'ys_true'),\n",
    "            (ys_pred, 'ys_pred'),\n",
    "            (ys_diff, 'ys_res'),\n",
    "        ]:\n",
    "\n",
    "            # Project onto the eigenbasis; partial sums of the top-i terms are the reconstructions.\n",
    "            coeffs = np.sum(ys * eigvects, axis=1)\n",
    "            scaled_vects = coeffs[:,None] * eigvects\n",
    "\n",
    "            fig, axs = plot_contours(\n",
    "                xs=grid, \n",
    "                ys_list=[ys.reshape(2, -1).T] + [\n",
    "                    np.sum(scaled_vects[:i], axis=0).reshape(2, -1).T\n",
    "                    for i in eigvals_rank\n",
    "                ], \n",
    "                titles=[first_col_titles[name]] + [\n",
    "                    f'Top {i} eig.fn.'\n",
    "                    for i in eigvals_rank\n",
    "                ], \n",
    "                res=200, sym_colour=False, ptile=False, cbar=True)\n",
    "\n",
    "            axs[0,0].set_ylabel('Experimental pts.')\n",
    "            axs[1,0].set_ylabel('PDE Collocation pts.')\n",
    "            fig.savefig(os.path.join(eigplot_folder, f'eigdecomp-s{steps_range[step_idx]}-{name}-{k}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0eb69f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _increasing_levels(lo, hi, num):\n",
    "    \"\"\"np.linspace(lo, hi, num), padding a degenerate range (hi <= lo) so the levels\n",
    "    are strictly increasing, as contourf requires.\"\"\"\n",
    "    if not hi > lo:\n",
    "        pad = 1e-6 * max(abs(lo), abs(hi), 1.0)\n",
    "        lo, hi = lo - pad, hi + pad\n",
    "    return np.linspace(lo, hi, num=num)\n",
    "\n",
    "\n",
    "def plot_contours_eigval(xs, ys_list, titles, res=200, cbar=False):\n",
    "    \"\"\"Plot a reference field next to eigenvector fields, one row per output column.\n",
    "\n",
    "    Each element of ys_list is an (n_points, nrows) array of values at the points xs.\n",
    "    Column 0 (ys_list[0], e.g. the NN output) gets its own min-max colour scale and\n",
    "    visible ticks. The remaining columns of a row share a colour scale symmetric about\n",
    "    zero whose half-width is the larger magnitude of their 1st / 99th percentiles; their\n",
    "    values are clipped to it. With cbar=True a colour bar is added to each row.\n",
    "    titles label the top-row panels. Returns (fig, axs) with axs a 2-D array of Axes.\n",
    "    \"\"\"\n",
    "    nrows = ys_list[0].shape[1]\n",
    "    fig, axs = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=len(ys_list),\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        figsize=(4 * len(ys_list) - 3, 4 * nrows),\n",
    "        constrained_layout=True,\n",
    "        # always 2-D: also works for a single row and/or a single column\n",
    "        squeeze=False,\n",
    "    )\n",
    "\n",
    "    p_d = 1\n",
    "    for i in range(nrows):\n",
    "        row_vals = [y[:, i] for y in ys_list]\n",
    "        cb = contour_on_ax(\n",
    "            axs[i][0], xs, row_vals[0],\n",
    "            _increasing_levels(np.min(row_vals[0]), np.max(row_vals[0]), res),\n",
    "            res, rm_axis=False)\n",
    "        eig_vals = row_vals[1:]\n",
    "        if eig_vals:\n",
    "            m = max(-np.percentile(eig_vals, p_d), np.percentile(eig_vals, 100 - p_d))\n",
    "            eig_vals = [np.clip(y, -m + 1e-9, m - 1e-9) for y in eig_vals]\n",
    "            levels = _increasing_levels(-m, m, res)\n",
    "            for ax, zs in zip(axs[i][1:], eig_vals):\n",
    "                cb = contour_on_ax(ax, xs, zs, levels, res, rm_axis=True)\n",
    "        if cbar:\n",
    "            # uses the last contour drawn, i.e. the shared eigenvector scale when present\n",
    "            fig.colorbar(cb, ax=axs[i])\n",
    "    for ax, title in zip(axs[0], titles):\n",
    "        ax.set_title(title)\n",
    "\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c32e08",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "eigvals_rank = [1, 2, 3, 10, 20, 100, 1000]  # 1-based ranks: rank r is row r-1 of eigvects\n",
    "steps_range = [10000, 100000]\n",
    "\n",
    "for step_idx in range(len(steps_range)):\n",
    "    \n",
    "    print('plots number', step_idx)\n",
    "    \n",
    "    # Free every live device buffer left from earlier iterations/cells before rebuilding.\n",
    "    for buf in jax.devices()[0].client.live_buffers():\n",
    "        buf.delete()\n",
    "    \n",
    "    model, model_aux = construct_model(\n",
    "\n",
    "        pde_name='conv-1d', \n",
    "        data_seed=40,\n",
    "        pde_const=(1.0,), \n",
    "        use_pdebench=True,\n",
    "        test_max_pts=50000,\n",
    "        include_ic=True,\n",
    "        data_root='~/pdebench',\n",
    "\n",
    "        # model params\n",
    "        hidden_layers=8, \n",
    "        hidden_dim=128, \n",
    "        activation='tanh', \n",
    "        initializer='Glorot uniform', \n",
    "        arch=None, \n",
    "\n",
    "    )\n",
    "\n",
    "    # Merge all snapshot files (keys: training steps; None: test set), in sorted order so\n",
    "    # overlapping keys resolve deterministically. NOTE: pickle.load can execute arbitrary code.\n",
    "    d = dict()\n",
    "    for file in sorted(os.listdir(example_folder)):\n",
    "        if file.startswith('snapshot_data'):\n",
    "            with open(f'{example_folder}/{file}', 'rb') as f:\n",
    "                d.update(pkl.load(f))\n",
    "\n",
    "    x_test = d[None]['x_test']\n",
    "\n",
    "    d_modified = {\n",
    "        'x_test': x_test,\n",
    "        'y_test': d[None]['y_test'],\n",
    "        'steps': steps_range,\n",
    "        'res_mean': [d[k]['residue_test_mean'] for k in steps_range],\n",
    "        'err_mean': [d[k]['error_test_mean'] for k in steps_range],\n",
    "        'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],\n",
    "        'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],\n",
    "        'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],\n",
    "        'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],\n",
    "        'res': [d[k]['residue_test'] for k in steps_range],\n",
    "        'err': [d[k]['error_test'] for k in steps_range],\n",
    "        'pred': [d[k]['pred_test'] for k in steps_range],\n",
    "        'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],\n",
    "        'inv': [d[k]['params'][1] for k in steps_range],\n",
    "        'params': [d[k]['params'][0] for k in steps_range],\n",
    "    }\n",
    "\n",
    "    ntk = NTKHelper(model)\n",
    "\n",
    "    res = 90\n",
    "    xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]\n",
    "    grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T\n",
    "    \n",
    "    ys = model.net.apply(d_modified['params'][step_idx], grid)\n",
    "    ys_res = model.data.pde(grid, (ys, lambda x: model.net.apply(d_modified['params'][step_idx], x)))[0]\n",
    "    ys_pred_grid = jnp.concatenate([ys, ys_res], axis=1)\n",
    "\n",
    "    jac_I = ntk.get_jac(grid, code=-2, params=d_modified['params'][step_idx])\n",
    "    jac_N = ntk.get_jac(grid, code=-1, params=d_modified['params'][step_idx])\n",
    "\n",
    "    T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)\n",
    "    T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)\n",
    "    T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)\n",
    "\n",
    "    T = jnp.block([[T_ii, T_in], [T_in.T, T_nn]])\n",
    "    T = T + 1e-9 * jnp.eye(T.shape[0])\n",
    "\n",
    "    # eigh sorts ascending: flip to descending; each row of eigvects is one eigenvector.\n",
    "    eigvals, eigvects = jnp.linalg.eigh(T)\n",
    "    eigvals = eigvals[::-1] / (res**2)\n",
    "    eigvects = eigvects.T[::-1]\n",
    "    \n",
    "    # Spectrum against its 1-based rank, consistent with the lambda_r panel titles below.\n",
    "    top_eigvals = eigvals[:1000]\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.semilogy(np.arange(1, len(top_eigvals) + 1), top_eigvals)\n",
    "    ax.set_xlabel('Eigenvalue rank')\n",
    "    ax.set_ylabel('Eigenvalue')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(eigplot_folder, f's{steps_range[step_idx]}-eigval.pdf'))\n",
    "    plt.close('all')\n",
    "    \n",
    "    eigvects_modifies = [jnp.array([eigvects[idx-1, :res**2], eigvects[idx-1, res**2:]]).T for idx in eigvals_rank]\n",
    "    fig, axs = plot_contours_eigval(\n",
    "        grid, \n",
    "        [ys_pred_grid] + eigvects_modifies, \n",
    "        # eigvals is 0-based, so rank i is eigvals[i-1] (the same row as the eigenvector plotted)\n",
    "        ['NN output'] + [rf'$\\lambda_{{{i}}}$ = {eigvals[i-1]:.1E}' for i in eigvals_rank], \n",
    "    )\n",
    "    axs[0,0].set_ylabel('Experimental pts.')\n",
    "    axs[1,0].set_ylabel('PDE Collocation pts.')\n",
    "    fig.suptitle(f'Step {steps_range[step_idx]}', fontsize=1.8*plt.rcParams['font.size'])\n",
    "    fig.savefig(os.path.join(eigplot_folder, f's{steps_range[step_idx]}-eigvect.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "    plt.close('all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ab4c95",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47387d92",
   "metadata": {},
   "outputs": [],
   "source": [
    "eigvals_rank = [0, 1, 10, 100]  # 0-based: row indices into the descending-sorted eigvects\n",
    "\n",
    "# Alternative run (PINNACLE-K); swap with the active block below to plot it instead.\n",
    "# example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/kmeans_alignment-none_mem_autoal/'\n",
    "# eigplot_folder = '../../al_pinn_graphs_eigplots/conv-40'\n",
    "# alg_name = 'PINNACLE-K'\n",
    "# step_num = 50000\n",
    "\n",
    "example_folder = '../../al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0/sampling_alignment_scale-none_mem_autoal/20230901072031'\n",
    "eigplot_folder = '../../al_pinn_graphs_eigplots/conv-40'\n",
    "alg_name = 'PINNACLE-S'\n",
    "step_num = 50000\n",
    "\n",
    "# Figures below are saved here; create it so a fresh run does not fail in savefig.\n",
    "os.makedirs(eigplot_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "125910aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free every live device buffer from earlier cells before rebuilding the model.\n",
    "for buf in jax.devices()[0].client.live_buffers():\n",
    "    buf.delete()\n",
    "\n",
    "model, model_aux = construct_model(\n",
    "\n",
    "    pde_name='conv-1d', \n",
    "    data_seed=40,\n",
    "    pde_const=(1.0,), \n",
    "    use_pdebench=True,\n",
    "    test_max_pts=50000,\n",
    "    include_ic=True,\n",
    "    data_root='~/pdebench',\n",
    "\n",
    "    # model params\n",
    "    hidden_layers=8, \n",
    "    hidden_dim=128, \n",
    "    activation='tanh', \n",
    "    initializer='Glorot uniform', \n",
    "    arch=None, \n",
    "\n",
    ")\n",
    "\n",
    "# Merge all snapshot files (keys: training steps; None: test set), in sorted order so\n",
    "# overlapping keys resolve deterministically. NOTE: pickle.load can execute arbitrary code.\n",
    "d = dict()\n",
    "for file in sorted(os.listdir(example_folder)):\n",
    "    if file.startswith('snapshot_data'):\n",
    "        with open(f'{example_folder}/{file}', 'rb') as f:\n",
    "            d.update(pkl.load(f))\n",
    "\n",
    "x_test = d[None]['x_test']\n",
    "\n",
    "params = d[step_num]['params'][0]\n",
    "# Training set at every saved step (the None key holds the test set, not a step).\n",
    "train_pts_series = {s: d[s]['al_intermediate']['chosen_pts'] for s in d.keys() if s is not None}\n",
    "chosen_pts = d[step_num]['al_intermediate']['new_points']\n",
    "\n",
    "print(d[step_num]['error_test_mean'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e816a70b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of residual / IC / BC points in the training set at each saved step.\n",
    "# Labels assume bcs[0] holds the initial-condition points and bcs[1] the boundary points (TODO confirm).\n",
    "steps = sorted(train_pts_series.keys())\n",
    "pts_prop = np.array([\n",
    "    (train_pts_series[s]['res'].shape[0],\n",
    "     train_pts_series[s]['bcs'][0].shape[0],\n",
    "     train_pts_series[s]['bcs'][1].shape[0])\n",
    "    for s in steps\n",
    "])\n",
    "\n",
    "# Explicit figure so the plot cannot land on a figure left open by an earlier cell.\n",
    "fig, ax = plt.subplots()\n",
    "ax.stackplot(steps, *(pts_prop / np.sum(pts_prop, axis=1)[:,None]).T, alpha=0.8,\n",
    "             labels=['Residual', 'IC', 'BC'])\n",
    "ax.legend(loc='lower right')\n",
    "ax.set_xlabel('Steps')\n",
    "ax.set_ylabel('Proportion of training set')\n",
    "ax.set_xticks(range(0, 150001, 50000))\n",
    "fig.savefig(os.path.join(eigplot_folder, f'all-pointsel_{alg_name}.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d9f1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical NTK of the trained network (params at step_num) on a res x res grid spanning the\n",
    "# bounding box of the test points, and its eigendecomposition. Later cells read grid,\n",
    "# ys_pred_grid, eigvals and eigvects from here.\n",
    "ntk = NTKHelper(model)\n",
    "\n",
    "res = 90\n",
    "xi, yi = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res) for i in range(2)]\n",
    "grid = jnp.array([y.flatten() for y in jnp.meshgrid(xi, yi)]).T\n",
    "\n",
    "# Network output and PDE residual on the grid, stacked as columns [output, residual].\n",
    "ys = model.net.apply(params, grid)\n",
    "ys_res = model.data.pde(grid, (ys, lambda x: model.net.apply(params, x)))[0]\n",
    "ys_pred_grid = jnp.concatenate([ys, ys_res], axis=1)\n",
    "\n",
    "# NOTE(review): code=-2 appears to give the output block and code=-1 the PDE-residual block,\n",
    "# judging by how the two eigenvector halves are labelled when plotted - confirm in NTKHelper.\n",
    "jac_I = ntk.get_jac(grid, code=-2, params=params)\n",
    "jac_N = ntk.get_jac(grid, code=-1, params=params)\n",
    "\n",
    "T_ii = ntk.get_ntk(jac1=jac_I, jac2=jac_I)\n",
    "T_in = ntk.get_ntk(jac1=jac_I, jac2=jac_N)\n",
    "T_nn = ntk.get_ntk(jac1=jac_N, jac2=jac_N)\n",
    "\n",
    "# Joint kernel; the lower-left block is T_in.T so T is exactly symmetric, plus a small diagonal jitter.\n",
    "T = jnp.block([[T_ii, T_in], [T_in.T, T_nn]])\n",
    "T = T + 1e-9 * jnp.eye(T.shape[0])\n",
    "\n",
    "# eigh returns ascending eigenvalues: flip to descending and scale by the number of grid points.\n",
    "# Each row of eigvects is an eigenvector: first res**2 entries = first block, the rest = second block.\n",
    "eigvals, eigvects = jnp.linalg.eigh(T)\n",
    "eigvals = eigvals[::-1] / (res**2)\n",
    "eigvects = eigvects.T[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6283a411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finer 200 x 200 grid. Use *_fine names so the coarse-grid state from the NTK cell above\n",
    "# (res, grid, ys_pred_grid - which must match eigvects) is not overwritten: the eigenvector\n",
    "# plots below slice eigvects by grid size, and a res=200 grid against 90x90 eigenvectors\n",
    "# produced ragged/mismatched arrays on a Restart & Run All.\n",
    "res_fine = 200\n",
    "xi_fine, yi_fine = [jnp.linspace(jnp.min(x_test[:,i]), jnp.max(x_test[:,i]), res_fine) for i in range(2)]\n",
    "grid_fine = jnp.array([y.flatten() for y in jnp.meshgrid(xi_fine, yi_fine)]).T\n",
    "\n",
    "ys_fine = model.net.apply(params, grid_fine)\n",
    "ys_res_fine = model.data.pde(grid_fine, (ys_fine, lambda x: model.net.apply(params, x)))[0]\n",
    "ys_pred_grid_fine = jnp.concatenate([ys_fine, ys_res_fine], axis=1)\n",
    "\n",
    "# Jacobians on the fine grid, consumed by the coarse x fine cross-NTK cell.\n",
    "jac_Ip = ntk.get_jac(grid_fine, code=-2, params=params)\n",
    "jac_Np = ntk.get_jac(grid_fine, code=-1, params=params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f74cf87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-NTK between the res=90 grid (rows: jac_I, jac_N) and the res=200 grid (columns: jac_Ip, jac_Np).\n",
    "# NOTE(review): T is not used by the cells below; presumably meant to extend the coarse\n",
    "# eigenvectors to the fine grid - TODO confirm or remove.\n",
    "T = jnp.block([\n",
    "    [ntk.get_ntk(jac1=jac_I, jac2=jac_Ip), ntk.get_ntk(jac1=jac_I, jac2=jac_Np)], \n",
    "    [ntk.get_ntk(jac1=jac_N, jac2=jac_Ip), ntk.get_ntk(jac1=jac_N, jac2=jac_Np)]\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3377c006",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c76beb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row of eigvects stacks [output block; PDE-residual block] over the points of `grid`.\n",
    "# Fail early with a clear message if the eigendecomposition and grid come from different grids.\n",
    "n_pts = grid.shape[0]\n",
    "assert eigvects.shape[1] == 2 * n_pts and ys_pred_grid.shape[0] == n_pts, (\n",
    "    f'eigvects ({eigvects.shape[1]} entries) / ys_pred_grid ({ys_pred_grid.shape[0]} rows) '\n",
    "    f'do not match grid ({n_pts} points); re-run the NTK eigendecomposition on the same grid')\n",
    "eigvects_modifies = [jnp.array([eigvects[idx, :n_pts], eigvects[idx, n_pts:]]).T for idx in eigvals_rank]\n",
    "fig, axs = plot_contours_eigval(\n",
    "    grid, \n",
    "    [ys_pred_grid] + eigvects_modifies, \n",
    "    ['NN output'] + [rf'$\\lambda_{{{i}}}$ = {eigvals[i]:.1E}' for i in eigvals_rank], \n",
    "    cbar=True,\n",
    ")\n",
    "# Overlay chosen_pts (new points at step_num) on the eigenvector panels, not the NN-output column.\n",
    "for ax_row in axs:\n",
    "    for ax in ax_row[1:]:\n",
    "        plot_training_data(ax, chosen_pts)\n",
    "axs[0,0].set_ylabel('Prediction')\n",
    "axs[1,0].set_ylabel('PDE Residual')\n",
    "fig.suptitle(alg_name)\n",
    "fig.savefig(os.path.join(eigplot_folder, f's{step_num}-pointsel_{alg_name}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "plt.close('all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45d66af1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE-residual half of each eigenvector only: a single-row figure.\n",
    "n_pts = grid.shape[0]\n",
    "assert eigvects.shape[1] == 2 * n_pts and ys_pred_grid.shape[0] == n_pts, (\n",
    "    f'eigvects ({eigvects.shape[1]} entries) / ys_pred_grid ({ys_pred_grid.shape[0]} rows) '\n",
    "    f'do not match grid ({n_pts} points); re-run the NTK eigendecomposition on the same grid')\n",
    "eigvects_modifies = [jnp.array([eigvects[idx, n_pts:]]).T for idx in eigvals_rank]\n",
    "fig, axs = plot_contours_eigval(\n",
    "    grid, \n",
    "    [ys_pred_grid[:,1:2]] + eigvects_modifies, \n",
    "    ['NN output'] + [rf'$\\lambda_{{{i}}}$ = {eigvals[i]:.1E}' for i in eigvals_rank], \n",
    "    cbar=True,\n",
    ")\n",
    "# Iterate this figure's own Axes (a leftover `ax` from another cell would draw on a closed figure).\n",
    "for ax_row in axs:\n",
    "    for ax in ax_row[1:]:\n",
    "        plot_training_data(ax, chosen_pts)\n",
    "axs[0][0].set_ylabel('PDE Residual')\n",
    "fig.suptitle(alg_name)\n",
    "fig.savefig(os.path.join(eigplot_folder, f's{step_num}-pointsel_{alg_name}-resonly.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "plt.close('all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e79a8cd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
