{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b9803f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# Render matplotlib figures inline in the notebook output.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a0e74b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "import pickle as pkl\n",
    "from collections.abc import MutableMapping\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as tri\n",
    "import matplotlib.ticker as ticker\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"DDE_BACKEND\"] = \"jax\"\n",
    "\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".XX\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"\n",
    "\n",
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "# config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "import optax\n",
    "\n",
    "try:\n",
    "    print(f'Jax: CPUs={jax.local_device_count(\"cpu\")} - GPUs={jax.local_device_count(\"gpu\")}')\n",
    "except:\n",
    "    pass\n",
    "    \n",
    "import deepxde_al_patch.deepxde as dde\n",
    "\n",
    "from deepxde_al_patch.model_loader import construct_model, construct_net\n",
    "from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop\n",
    "from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction\n",
    "from deepxde_al_patch.train_set_loader import load_data\n",
    "\n",
    "from deepxde_al_patch.ntk import NTKHelper\n",
    "from deepxde_al_patch.utils import get_pde_residue, print_dict_structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "873bd66c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib defaults applied to the figures created after this cell.\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 150,\n",
    "    'font.size': 14,\n",
    "    'figure.titlesize': 24,\n",
    "    'text.usetex': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43207528",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b64d35f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably the folder where figures are saved (inferred from the name only);\n",
    "# no cell visible in this notebook reads it -- TODO confirm.\n",
    "graph_root = 'al_pinn_graphs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ee1f343",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot error against elapsed time for the stored runs of each algorithm.\n",
    "# Each entry of the list below is\n",
    "#   (results-folder prefix, {run-folder suffix: (legend label, style dict)}, target error).\n",
    "# Only the 'c' and 'marker' entries of a style dict are used: every curve is\n",
    "# drawn with a solid line and the 'ls' entry is currently ignored.\n",
    "MAX_RUNS_PER_ALG = 5        # keep at most this many runs per algorithm\n",
    "UNCONVERGED_PENALTY = 200.  # ranking offset for runs that end above the target error\n",
    "\n",
    "for data_folder, algs, cutoff in [\n",
    "\n",
    "    (\n",
    "        '../../al_pinn_results_timing/conv-1d{5.0}_ic/nn-laaf-6-64_adam_bcsloss-1.0_budget-',\n",
    "        {\n",
    "            '10000-50-0-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='darkgrey', ls='--', marker='v')),\n",
    "            '100000-50-0-random_Hammersley_prop-0.8': ('Hamm (100k)', dict(c='black', ls='--', marker='^')),\n",
    "            '10000-50-0-residue_prop-0.8_alltype': ('RAD-All (10k)', dict(c='red', ls=':', marker='p')),\n",
    "            '1000-500-0-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (1k)', dict(c='blue', ls='-', marker='o')),\n",
    "        },\n",
    "        0.01\n",
    "    ),\n",
    "\n",
    "    (\n",
    "        '../../al_pinn_results_timing/fd-2d{1.0-0.01}_inv_anc[0,1]/nn-laaf-6-64_adam_bcsloss-1.0_budget-',\n",
    "        {\n",
    "            '10000-50-0-random_Hammersley_prop-0.8': ('Hamm (10k)', dict(c='darkgrey', ls='--', marker='v')),\n",
    "            '10000-50-0-residue_prop-0.8_alltype': ('RAD-All (10k)', dict(c='red', ls=':', marker='p')),\n",
    "            '1000-200-0-kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (1k)', dict(c='blue', ls='-', marker='o')),\n",
    "        },\n",
    "        0.05\n",
    "    ),\n",
    "]:\n",
    "\n",
    "    print(data_folder)\n",
    "\n",
    "    # Sort the run folders so that the runs kept below do not depend on the\n",
    "    # arbitrary order in which os.listdir returns directory entries.\n",
    "    cases = {x: sorted(os.listdir(f'{data_folder}{x}'))\n",
    "             for x in algs.keys() if os.path.exists(f'{data_folder}{x}')}\n",
    "\n",
    "    data = dict()\n",
    "\n",
    "    for alg_key, run_names in cases.items():\n",
    "\n",
    "        data[alg_key] = []\n",
    "\n",
    "        for run_name in run_names:\n",
    "\n",
    "            try:\n",
    "                with open(f'{data_folder}{alg_key}/{run_name}/timing.pkl', 'rb') as f:\n",
    "                    timing = pkl.load(f)\n",
    "            except (FileNotFoundError, NotADirectoryError):\n",
    "                # Run folder without timing data, or a stray plain file among the runs.\n",
    "                continue\n",
    "            timing = np.array(timing)\n",
    "            if timing.size == 0:\n",
    "                # Nothing was recorded for this run, so there is nothing to plot.\n",
    "                continue\n",
    "            # Column 1 is plotted as time in minutes (stored values presumably in\n",
    "            # seconds -- TODO confirm); column 2 is plotted as the error.\n",
    "            timing[:, 1] = timing[:, 1] / 60.\n",
    "            # Keep rows up to and including the first one below the target error,\n",
    "            # or the whole run when the target is never reached.\n",
    "            reached = np.flatnonzero(timing[:, 2] < cutoff)\n",
    "            last_row = reached[0] if reached.size else timing.shape[0] - 1\n",
    "            data[alg_key].append(timing[:last_row + 1])\n",
    "            if len(data[alg_key]) == MAX_RUNS_PER_ALG:\n",
    "                break\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    for a, runs in data.items():\n",
    "        if len(runs) == 0:\n",
    "            continue\n",
    "        print(a, len(runs))\n",
    "        label, style = algs[a]\n",
    "        # Rank the runs: one that reached the target error is scored by its final\n",
    "        # time, any other by its final error plus a fixed penalty.\n",
    "        scores = [run[-1, 1] if run[-1, 2] < cutoff else UNCONVERGED_PENALTY + run[-1, 2]\n",
    "                  for run in runs]\n",
    "        best_idx = min(range(len(scores)), key=scores.__getitem__)\n",
    "        for i, run in enumerate(runs):\n",
    "            line_kwargs = dict(color=style['c'], alpha=0.1, markerfacecolor='none', lw=1, ms=4)\n",
    "            if i == best_idx:\n",
    "                # Only the best run is drawn opaque and gets a legend entry.\n",
    "                line_kwargs.update(label=label, alpha=0.9)\n",
    "            ax.semilogy(run[:, 1], run[:, 2], style['marker'] + '-', **line_kwargs)\n",
    "\n",
    "    ax.axhline(cutoff, linestyle='--', color='darkgrey', zorder=1)\n",
    "    ax.set_xlabel('Time (min)')\n",
    "    ax.set_ylabel('Mean error')\n",
    "    ax.set_xlim(0, None)\n",
    "    ax.grid(alpha=0.2)\n",
    "    ax.legend(ncols=2)\n",
    "    # pyplot.show accepts only the keyword argument block, so no figure is passed.\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a15fe4ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc335b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled snapshot of one specific PINNACLE-K run into d.\n",
    "snapshot_path = '../../al_pinn_results_timing/conv-1d{5.0}_ic/nn-laaf-6-64_adam_bcsloss-1.0_budget-1000-500-0-kmeans_alignment_scale-none_mem_autoal/20230928122817/last_snapshot_0.010000000000000002.pkl'\n",
    "with open(snapshot_path, 'rb') as snapshot_file:\n",
    "    d = pkl.load(snapshot_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448c0e1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the snapshot's test prediction as a 200 x 200 image with a colour scale.\n",
    "pred_grid = d['pred_test'].reshape(200, 200)\n",
    "plt.imshow(pred_grid)\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3a0628",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a3b4da",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
