{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "726eae1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9452c693",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import os\n",
    "import pickle as pkl\n",
    "from collections.abc import MutableMapping\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as tri\n",
    "import matplotlib.ticker as ticker\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"DDE_BACKEND\"] = \"jax\"\n",
    "\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\".XX\"\n",
    "# os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"]=\"platform\"\n",
    "\n",
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "# config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "import optax\n",
    "\n",
    "# Best-effort report of the devices JAX can see; purely informational.\n",
    "try:\n",
    "    print(f'Jax: CPUs={jax.local_device_count(\"cpu\")} - GPUs={jax.local_device_count(\"gpu\")}')\n",
    "except Exception as exc:\n",
    "    # A failed device query must not stop the notebook, but say why nothing was counted.\n",
    "    print(f'Jax: could not query device counts ({exc})')\n",
    "    \n",
    "import deepxde_al_patch.deepxde as dde\n",
    "\n",
    "from deepxde_al_patch.model_loader import construct_model, construct_net\n",
    "from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop\n",
    "from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction\n",
    "from deepxde_al_patch.train_set_loader import load_data\n",
    "\n",
    "from deepxde_al_patch.ntk import NTKHelper\n",
    "from deepxde_al_patch.utils import get_pde_residue, print_dict_structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3eec5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global Matplotlib appearance shared by every figure in this notebook.\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 150,\n",
    "    'font.size': 18,\n",
    "    \"figure.titlesize\": 24,\n",
    "    'text.usetex': False,\n",
    "})\n",
    "\n",
    "# Graph root that each run's `graph_root` is compared against when selecting cases.\n",
    "main_graph = 'al_pinn_graphs_final/main'\n",
    "# Base directory onto which result and graph folder paths are joined.\n",
    "data_folder = '../../'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5213bab1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e91efbb8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7216fdf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def contour_on_ax(ax, xs, zs, levels, res=200, rm_axis=False):\n",
    "    \"\"\"Draw a filled contour plot of scattered samples on ``ax``.\n",
    "\n",
    "    The samples are triangulated and linearly interpolated onto a regular\n",
    "    ``res`` x ``res`` grid spanning the bounding box of the first two columns\n",
    "    of ``xs``.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        xs: Array whose columns 0 and 1 hold the sample coordinates.\n",
    "        zs: Values to contour, one per row of ``xs``.\n",
    "        levels: Contour levels forwarded to ``ax.contourf``.\n",
    "        res: Number of grid points along each axis.\n",
    "        rm_axis: If True, blank the tick labels. Otherwise place major ticks\n",
    "            every 1 unit when the axis extent exceeds 1.0 and every 0.5 unit\n",
    "            when it does not.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by ``ax.contourf`` (usable as a colourbar mappable).\n",
    "    \"\"\"\n",
    "    # Bounding box per axis; reused for both the grid and the tick spacing.\n",
    "    extents = [(np.min(xs[:, i]), np.max(xs[:, i])) for i in range(2)]\n",
    "    xi, yi = [np.linspace(lo, hi, res) for lo, hi in extents]\n",
    "    Xi, Yi = np.meshgrid(xi, yi)\n",
    "\n",
    "    triang = tri.Triangulation(xs[:, 0], xs[:, 1])\n",
    "    interpolator = tri.LinearTriInterpolator(triang, zs)\n",
    "    zi = interpolator(Xi, Yi)\n",
    "\n",
    "    cb = ax.contourf(xi, yi, zi, levels=levels, cmap=\"RdBu_r\")\n",
    "    if rm_axis:\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "    else:\n",
    "        for axis, (lo, hi) in zip((ax.xaxis, ax.yaxis), extents):\n",
    "            axis.set_major_locator(ticker.MultipleLocator(1 if (hi - lo) > 1.0 else 0.5))\n",
    "    return cb\n",
    "\n",
    "\n",
    "def plot_contours(xs, ys_list, titles, res=200, sym_colour=False, ptile=False, cbar=True, axislabels=None):\n",
    "    \"\"\"Plot a grid of filled contour maps with one shared colour scale per row.\n",
    "\n",
    "    Each entry of ``ys_list`` becomes a subplot column and each output column\n",
    "    ``y[:, i]`` becomes a subplot row; all panels in a row share contour levels.\n",
    "\n",
    "    Args:\n",
    "        xs: Sample coordinates, forwarded to ``contour_on_ax``.\n",
    "        ys_list: Non-empty list of 2-D arrays (one row per sample); the column\n",
    "            count of the first entry sets the number of subplot rows.\n",
    "        titles: Titles for the top row, paired in order with ``ys_list``.\n",
    "        res: Interpolation grid resolution and number of contour levels.\n",
    "        sym_colour: If True and the value range straddles zero, widen it to be\n",
    "            symmetric about zero.\n",
    "        ptile: If True, use the 1st/99th percentiles as colour limits and clip\n",
    "            the values to them instead of using the raw min/max.\n",
    "        cbar: If True, add one colourbar per row; if False, tick labels are\n",
    "            blanked instead.\n",
    "        axislabels: Optional ``(xlabel, ylabel)`` pair. When None, no axis\n",
    "            labels are set.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, axs)`` where ``axs`` is a flat list of axes for a single-row\n",
    "        figure and a 2-D array of axes otherwise.\n",
    "    \"\"\"\n",
    "    nrows = ys_list[0].shape[1]\n",
    "    ncols = len(ys_list)\n",
    "    # squeeze=False keeps the axes grid 2-D, so a single row or a single column\n",
    "    # needs no special indexing.\n",
    "    fig, grid = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=ncols,\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        squeeze=False,\n",
    "        figsize=(3 * (ncols + (1 if cbar else -1)), 3 * nrows + 2),\n",
    "        constrained_layout=True\n",
    "    )\n",
    "\n",
    "    p_d = 1  # percentile trimmed from each tail when ptile=True\n",
    "    # Multi-row figures pull clipped values just inside the level range;\n",
    "    # single-row figures clip exactly onto it.\n",
    "    clip_margin = 0.0 if nrows == 1 else 1e-9\n",
    "\n",
    "    for i in range(nrows):\n",
    "        row_values = [y[:, i] for y in ys_list]\n",
    "        if ptile:\n",
    "            min_ = np.percentile(row_values, p_d)\n",
    "            max_ = np.percentile(row_values, 100 - p_d)\n",
    "        else:\n",
    "            min_ = np.min(row_values)\n",
    "            max_ = np.max(row_values)\n",
    "        if sym_colour and (min_ < 0 < max_):\n",
    "            m = max(-min_, max_)\n",
    "            min_, max_ = -m, m\n",
    "        if ptile:\n",
    "            row_values = [np.clip(y, min_ + clip_margin, max_ - clip_margin) for y in row_values]\n",
    "        levels = np.linspace(min_, max_, num=res)\n",
    "\n",
    "        for ax, zs in zip(grid[i], row_values):\n",
    "            cb = contour_on_ax(ax, xs, zs, levels, res, rm_axis=not cbar)\n",
    "            if axislabels is not None:\n",
    "                ax.set_xlabel(axislabels[0])\n",
    "        if axislabels is not None:\n",
    "            grid[i, 0].set_ylabel(axislabels[1])\n",
    "        if cbar:\n",
    "            fig.colorbar(cb, ax=list(grid[i]))\n",
    "\n",
    "    for ax, title in zip(grid[0], titles):\n",
    "        ax.set_title(title)\n",
    "\n",
    "    # Preserve the historical return shape: flat list for one row, 2-D array otherwise.\n",
    "    axs = grid[0].tolist() if nrows == 1 else grid\n",
    "    return fig, axs\n",
    "\n",
    "\n",
    "def plot_training_data(ax, samples):\n",
    "    ms = 4.\n",
    "    ax.plot(samples['res'][:, 0], samples['res'][:, 1], 'o', color='black', ms=ms, alpha=0.95, zorder=10, clip_on=False)\n",
    "    if 'anc' in samples.keys():\n",
    "        ax.plot(samples['anc'][:, 0], samples['anc'][:, 1], '^', color='blue', ms=ms, alpha=0.95, zorder=10, clip_on=False)\n",
    "    for i, bc_pts in enumerate(samples['bcs']):\n",
    "        ax.plot(bc_pts[:, 0], bc_pts[:, 1], 's', color=f'C{i+1}', ms=ms, alpha=0.95, zorder=10, clip_on=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3201d005",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "235d37a7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for run_idx in [0, 1, 2, 3, 4]:\n",
    "\n",
    "    if run_idx == 0:\n",
    "        graph_root = 'al_pinn_graphs_final/main'\n",
    "        max_runs = 10\n",
    "        algs = {\n",
    "            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls=':', marker='p')),\n",
    "            'random_Hammersley_prop-0.8': ('Hammersley', dict(c='grey', ls=':', marker='h')),\n",
    "            'residue_prop-0.8': ('RAD', dict(c='red', ls='--', marker='v')),\n",
    "            'residue_prop-0.8_alltype': ('RAD-All', dict(c='orange', ls='--', marker='^')),\n",
    "            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (ours)', dict(c='green', ls='-', marker='s')),\n",
    "            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (ours)', dict(c='blue', ls='-', marker='o')),\n",
    "        }\n",
    "\n",
    "    elif run_idx == 1:\n",
    "        graph_root = 'al_pinn_graphs_final/mult_nonad'\n",
    "        max_runs = 5\n",
    "        algs = {\n",
    "            'random_pseudo_prop-0.5': ('Unif-0.5', dict(c='black', ls='-', marker='s')),\n",
    "            'random_pseudo_prop-0.8': ('Unif-0.8', dict(c='black', ls='--', marker='p')),\n",
    "            'random_pseudo_prop-0.95': ('Unif-0.95', dict(c='black', ls='-.', marker='h')),\n",
    "            'random_Hammersley_prop-0.5': ('Hamm-0.5', dict(c='m', ls='-', marker='s')),\n",
    "            'random_Hammersley_prop-0.8': ('Hamm-0.8', dict(c='m', ls='--', marker='p')),\n",
    "            'random_Hammersley_prop-0.95': ('Hamm-0.95', dict(c='m', ls='-.', marker='h')),\n",
    "            'random_Sobol_prop-0.5': ('Sobol-0.5', dict(c='red', ls='-', marker='s')),\n",
    "            'random_Sobol_prop-0.8': ('Sobol-0.8', dict(c='red', ls='--', marker='p')),\n",
    "            'random_Sobol_prop-0.95': ('Sobol-0.95', dict(c='red', ls='-.', marker='h')),\n",
    "        }\n",
    "\n",
    "    elif run_idx == 2:\n",
    "        graph_root = 'al_pinn_graphs_final/mult_adapt'\n",
    "        max_runs = 5\n",
    "        algs = {\n",
    "            'residue_prop-0.5': ('RAD-0.5', dict(c='red', ls='-', marker='^')),\n",
    "            'residue_prop-0.8': ('RAD-0.8', dict(c='red', ls='--', marker='v')),\n",
    "            'residue_prop-0.95': ('RAD-0.95', dict(c='red', ls='-.', marker='>')),\n",
    "            'residue_prop-0.5_alltype': ('RAD-All-0.5', dict(c='lightblue', ls='-', marker='^')),\n",
    "            'residue_prop-0.8_alltype': ('RAD-All-0.8', dict(c='lightblue', ls='--', marker='v')),\n",
    "            'residue_prop-0.95_alltype': ('RAD-All-0.95', dict(c='lightblue', ls='-.', marker='>')),\n",
    "            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S', dict(c='green', ls='-', marker='o')),\n",
    "            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K', dict(c='blue', ls='-', marker='s')),\n",
    "        }\n",
    "\n",
    "    elif run_idx == 3:\n",
    "        graph_root = 'al_pinn_graphs_final/pinnacle-k'\n",
    "        max_runs = 5\n",
    "        algs = {\n",
    "            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls='--', marker='p')),\n",
    "            'kmeans_alignment_scale-none_autoal': ('No Memory', dict(c='orange', ls='-', marker='^')),\n",
    "            'kmeans_alignment_scale-none_mem': ('No Auto Trigger', dict(c='m', ls='-', marker='v')),\n",
    "            'kmeans_alignment_scale-none_mem_autoal_fb': ('No Dyn Alloc', dict(c='lightblue', ls='-', marker='s')),\n",
    "            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K', dict(c='blue', ls='-', marker='o')),\n",
    "        }\n",
    "\n",
    "    elif run_idx == 4:\n",
    "        graph_root = 'al_pinn_graphs_final/pinnacle-s'\n",
    "        max_runs = 5\n",
    "        algs = {\n",
    "            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls='--', marker='p')),\n",
    "            'sampling_alignment_scale-none_autoal': ('No Memory', dict(c='orange', ls='-', marker='^')),\n",
    "            'sampling_alignment_scale-none_mem': ('No Auto Trigger', dict(c='m', ls='-', marker='v')),\n",
    "            'sampling_alignment_scale-none_mem_autoal_fb': ('No Dyn Alloc', dict(c='lightblue', ls='-', marker='s')),\n",
    "            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S', dict(c='green', ls='-', marker='o')),\n",
    "        }\n",
    "\n",
    "    ######## RUN NORMAL PLOTS\n",
    "\n",
    "    case_list = [\n",
    "\n",
    "        (\n",
    "            'al_pinn_results/fd-2d{1.0-0.01}_inv_anc[0,1]/nn-laaf-6-64_adam_bcsloss-1.0_budget-1000-200-30',\n",
    "            [0, 10000, 20000, 100000],\n",
    "            '2D Fluid Dynamics (Inv)',\n",
    "        ),\n",
    "\n",
    "        (\n",
    "            'al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',\n",
    "            [0, 10000, 50000, 200000],\n",
    "            '1D Advection',\n",
    "        ),\n",
    "        \n",
    "        (\n",
    "            'al_pinn_results/conv-1d{1.0}_pb-40_anc/nn-None-8-128_adam_bcsloss-1.0_budget-500-200-1',\n",
    "            [0, 10000, 50000, 200000],\n",
    "            '1D Advection (w/ Exp Pts)',\n",
    "        ),\n",
    "\n",
    "        (\n",
    "            'al_pinn_results/conv-1d{1.0}_pb-80_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',\n",
    "            [0, 10000, 50000, 200000],\n",
    "            '1D Advection',\n",
    "        ),\n",
    "\n",
    "        (\n",
    "            'al_pinn_results/burgers-1d{0.02}_pb-20_ic/nn-None-4-128_adam_bcsloss-1.0_budget-300-100-0', \n",
    "            [0, 10000, 50000, 200000],\n",
    "            '1D Burger\\'s',\n",
    "        ),\n",
    "\n",
    "    ]\n",
    "\n",
    "    if graph_root != main_graph:\n",
    "        case_list = [c for c in case_list if ('conv' in c[0]) or ('burger' in c[0])]\n",
    "\n",
    "\n",
    "    for case_folder, steps_plot, suptit in case_list:\n",
    "\n",
    "        if steps_plot[-1] > 150000:\n",
    "            tick_spacing = 100000\n",
    "        elif steps_plot[-1] > 50000:\n",
    "            tick_spacing = 50000\n",
    "        else:\n",
    "            tick_spacing = 10000\n",
    "\n",
    "        throwout = []\n",
    "\n",
    "        print('---------------------------------------------------\\n\\n')\n",
    "        print('PROCESSING:', graph_root, case_folder)\n",
    "\n",
    "        max_steps = steps_plot[-1]\n",
    "\n",
    "        root_folder = os.path.join(data_folder, case_folder)\n",
    "\n",
    "        _, arch, depth, width = root_folder.split('/')[-1].split('_')[0].split('-')\n",
    "\n",
    "    #     net, _ = construct_net(\n",
    "    #         input_dim=2, \n",
    "    #         output_dim=1, \n",
    "    #         hidden_layers=int(depth), \n",
    "    #         hidden_dim=int(width), \n",
    "    #         arch=(None if arch == 'None' else arch)\n",
    "    #     )\n",
    "\n",
    "        cases = {x: os.listdir(f'{root_folder}/{x}') for x in algs.keys() if os.path.exists(f'{root_folder}/{x}')}\n",
    "        print('Exist:', list(cases.keys()))\n",
    "\n",
    "        if len(list(cases.keys())) < 1:\n",
    "            continue\n",
    "\n",
    "        data = dict()\n",
    "        steps_min = dict()\n",
    "        plotted_cases = dict()\n",
    "\n",
    "        for c in cases.keys():\n",
    "\n",
    "            s_min = float('inf')\n",
    "\n",
    "            runs = []\n",
    "            runs_cases = []\n",
    "\n",
    "            for r in cases[c]:\n",
    "\n",
    "                try:\n",
    "\n",
    "                    d = dict()\n",
    "\n",
    "                    if len(runs) < max_runs:\n",
    "\n",
    "                        for file in os.listdir(f'{root_folder}/{c}/{r}'):\n",
    "\n",
    "                            if file.startswith('snapshot_data'):\n",
    "\n",
    "                                fname = f'{root_folder}/{c}/{r}/{file}'\n",
    "\n",
    "                                with open(fname, 'rb') as f:\n",
    "                                    d_update = pkl.load(f)\n",
    "\n",
    "                                d.update(d_update)\n",
    "\n",
    "                    steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])\n",
    "                    if 'eik' in case_folder:\n",
    "                        steps_range.pop(0)\n",
    "\n",
    "                    if (len(steps_range) > 0) and (max_steps == steps_range[-1]) and (None in d.keys()):\n",
    "\n",
    "                        print(c, r, sorted([x for x in d.keys() if (x is not None)])[-1])\n",
    "\n",
    "                        s_min = min(s_min, steps_range[-1])\n",
    "\n",
    "                        x_test = d[None]['x_test']\n",
    "\n",
    "                        d_modified = {\n",
    "                            'x_test': x_test,\n",
    "                            'y_test': d[None]['y_test'],\n",
    "                            'steps': steps_range,\n",
    "                            'res_mean': [d[k]['residue_test_mean'] for k in steps_range],\n",
    "                            'err_mean': [d[k]['error_test_mean'] for k in steps_range],\n",
    "                            'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],\n",
    "                            'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],\n",
    "                            'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],\n",
    "                            'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],\n",
    "                            'res': [d[k]['residue_test'] for k in steps_range],\n",
    "                            'err': [d[k]['error_test'] for k in steps_range],\n",
    "                            'pred': [d[k]['pred_test'] #if 'pred_test' in d[k].keys() \n",
    "    #                                  else net.apply(d[k]['params'][0], x_test)\n",
    "                                     for k in steps_range],\n",
    "                            'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],\n",
    "                            'inv': [d[k]['params'][1] for k in steps_range],\n",
    "                        }\n",
    "\n",
    "                        if x_test.shape[1] == 2:\n",
    "\n",
    "                            arr_shape = [d_modified['y_test'].shape[1]] + [np.unique(x).shape[0] for x in d_modified['x_test'].T]\n",
    "                            d_modified['y_test_fft'] = np.fft.fftn(\n",
    "                                d_modified['y_test'].reshape(*arr_shape), \n",
    "                                axes=[1, 2]\n",
    "                            )\n",
    "                            d_modified['pred_fft'] = [np.fft.fftn(\n",
    "                                y.reshape(*arr_shape), axes=[1, 2]) \n",
    "                                for y in d_modified['pred']]\n",
    "\n",
    "                            d_modified['fft_err'] = [np.abs(yf - d_modified['y_test_fft'])\n",
    "                                for yf in d_modified['pred_fft']]\n",
    "\n",
    "                #             idxs = np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))[0].T\n",
    "                #             idxs = np.array([idxs, idxs])\n",
    "\n",
    "                            idxs = np.array(np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))).swapaxes(1, 2)\n",
    "\n",
    "                            klow = (idxs <= 4).all(axis=0).astype(float)\n",
    "                            kmid = (idxs <= 12).all(axis=0).astype(float) - klow\n",
    "                            khigh = (idxs <= np.inf).all(axis=0).astype(float) - kmid - klow\n",
    "\n",
    "                            for s, k in [('low', klow), ('mid', kmid), ('high', khigh)]:\n",
    "                                d_modified[f'fft_mean_{s}'] = [np.sum(yf * k[None, :]) / (np.sum(k) * yf.shape[0])\n",
    "                                    for yf in d_modified['fft_err']]\n",
    "\n",
    "                        if 'darcy' in case_folder:\n",
    "                            a_pred = [y[:,0] for y in d_modified['pred']]\n",
    "                            a_true = d_modified['y_test'][:,0]\n",
    "                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2) for y in a_pred]\n",
    "                            d_modified['a_err_q90'] = [np.percentile((a_true - y)**2, 90) for y in a_pred]\n",
    "                            f_true = np.array(a_true > 0.5, dtype=float)\n",
    "                            f_pred = [np.array(y > 0.5, dtype=float) for y in a_pred]\n",
    "                            d_modified['bool_err_mean'] = [np.mean(np.abs(f_true - y)) for y in f_pred]\n",
    "                            multidim = True\n",
    "                        elif 'reacdiff' in case_folder:\n",
    "                            multidim = True\n",
    "                        elif 'fd-2d' in case_folder:\n",
    "                            a_pred = [y[:,2] for y in d_modified['pred']]\n",
    "                            a_true = d_modified['y_test'][:,2]\n",
    "                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2) for y in a_pred]\n",
    "                            d_modified['a_err_q90'] = [np.percentile((a_true - y)**2, 90) for y in a_pred]\n",
    "                            multidim = True\n",
    "                        elif 'eik' in case_folder:\n",
    "                            a_pred = [y[:,1] for y in d_modified['pred']]\n",
    "                            a_true = d_modified['y_test'][:,1]\n",
    "                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2) for y in a_pred]\n",
    "                            d_modified['a_err_q90'] = [np.percentile((a_true - y)**2, 90) for y in a_pred]\n",
    "                            multidim = True\n",
    "                        else:\n",
    "                            multidim = False\n",
    "\n",
    "                        if 'inv' in case_folder:\n",
    "                            d_modified['inv_param_true'] = [float(x) for x in case_folder.split('{')[1].split('}')[0].split('-')]\n",
    "                            d_modified['inv_param_pred'] = tuple([float(d[k]['params'][1][i]) for k in steps_range] for i in range(len(d_modified['inv_param_true'])))\n",
    "                            inv_params = d_modified['inv_param_true']\n",
    "                            plot_inv_param = True\n",
    "                        else:\n",
    "                            plot_inv_param = False\n",
    "\n",
    "                        runs.append(d_modified)\n",
    "                        runs_cases.append(r)\n",
    "\n",
    "                    else:\n",
    "                        throwout.append(f'{root_folder}/{c}/{r}')\n",
    "\n",
    "                except Exception as e:\n",
    "    #                 raise e\n",
    "                    pass\n",
    "\n",
    "            if len(runs) > 0:\n",
    "                data[c] = runs\n",
    "                steps_min[c] = s_min\n",
    "                plotted_cases[c] = runs_cases\n",
    "\n",
    "        print('To plot algorithms =', {k: len(data[k]) for k in data.keys()})\n",
    "        print()\n",
    "\n",
    "        print('Throwout:')\n",
    "        [print(th) for th in throwout]\n",
    "\n",
    "        if len([k for k in data.keys()]) < 1:\n",
    "            continue\n",
    "\n",
    "        graph_folder = os.path.join(data_folder, graph_root, case_folder)\n",
    "        os.makedirs(graph_folder, exist_ok=True)\n",
    "\n",
    "        with open(os.path.join(graph_folder, 'cases_plotted'), 'w+') as f:\n",
    "            f.write(str(plotted_cases))\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['err_mean'] for y in data[c]]\n",
    "            mean = np.mean(ys, axis=0)\n",
    "            err = np.std(ys, axis=0)\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean error')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #     fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'err_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        figl = plt.figure(figsize=(6, 1.5))\n",
    "        axl = figl.add_subplot(111)\n",
    "        axl.legend(*ax.get_legend_handles_labels() , loc=\"center\", ncol=3)\n",
    "        axl.axis('off')\n",
    "        figl.savefig(os.path.join(graph_folder, f'labels.pdf'), bbox_inches='tight', pad_inches=0.05)\n",
    "        plt.close('all')\n",
    "\n",
    "        figl = plt.figure(figsize=(6, 1))\n",
    "        axl = figl.add_subplot(111)\n",
    "        axl.legend(*ax.get_legend_handles_labels() , loc=\"center\", ncol=len(ax.get_legend_handles_labels()[0]))\n",
    "        axl.axis('off')\n",
    "        figl.savefig(os.path.join(graph_folder, f'labels_flat.pdf'), bbox_inches='tight', pad_inches=0.05)\n",
    "        plt.close('all')\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['err_mean'] for y in data[c]]\n",
    "            mean = np.percentile(ys, 50, axis=0)\n",
    "            err1 = mean - np.percentile(ys, 20, axis=0)\n",
    "            err2 = np.percentile(ys, 80, axis=0) - mean\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean error')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #     fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'err_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        if ('darcy' in case_folder) or ('eik' in case_folder) or ('fd' in case_folder):\n",
    "\n",
    "            if 'darcy' in case_folder:\n",
    "                field_name = 'a'\n",
    "            elif 'eik' in case_folder:\n",
    "                field_name = 'v'\n",
    "            elif 'fd' in case_folder:\n",
    "                field_name = 'p'\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y['a_err_mean'] for y in data[c]]\n",
    "                mean = np.mean(ys, axis=0)\n",
    "                err = np.std(ys, axis=0)\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'Mean error of {field_name}(x)')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "            # fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'err-{field_name}_avg.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y['a_err_mean'] for y in data[c]]\n",
    "                mean = np.percentile(ys, 50, axis=0)\n",
    "                err1 = mean - np.percentile(ys, 20, axis=0)\n",
    "                err2 = np.percentile(ys, 80, axis=0) - mean\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'Mean error of {field_name}(x)')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "            # fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'err-{field_name}_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y['a_err_q90'] for y in data[c]]\n",
    "                mean = np.mean(ys, axis=0)\n",
    "                err = np.std(ys, axis=0)\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'90th Quantile Error of {field_name}(x)')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "            # fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'err-{field_name}_q90.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')\n",
    "\n",
    "            if 'darcy' in case_folder:\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['bool_err_mean'] for y in data[c]]\n",
    "                    mean = np.mean(ys, axis=0)\n",
    "                    err = np.std(ys, axis=0)\n",
    "                    label, marker = algs[c]\n",
    "                    ax.set_yscale(\"log\")\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Mean boolean error of {field_name}(x)')\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                # fig.suptitle(suptit)\n",
    "                fig.savefig(os.path.join(graph_folder, f'err-{field_name}-bool_avg.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "                plt.close('all')\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['bool_err_mean'] for y in data[c]]\n",
    "                    mean = np.median(ys, axis=0)\n",
    "                    err1 = mean - np.percentile(ys, 20, axis=0)\n",
    "                    err2 = np.percentile(ys, 80, axis=0) - mean\n",
    "                    label, marker = algs[c]\n",
    "                    ax.set_yscale(\"log\")\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Median boolean error of {field_name}(x)')\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                # fig.suptitle(suptit)\n",
    "                fig.savefig(os.path.join(graph_folder, f'err-{field_name}-bool_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "            plt.close('all')\n",
    "\n",
    "\n",
    "        for q in [50, 90, 95, 100]:\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y[f'err_q{q}'] for y in data[c]]\n",
    "                mean = np.mean(ys, axis=0)\n",
    "                err = np.std(ys, axis=0)\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'{q}th Quantile error')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "            # fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'err_q{q}.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')\n",
    "\n",
    "\n",
    "    #     fig, ax = plt.subplots(figsize=(5, 4))\n",
    "    #     for c in cases.keys():\n",
    "    #         label, marker = algs[c]\n",
    "    #         ax.semilogy(\n",
    "    #             data[c][0]['steps'], \n",
    "    #             [jnp.sqrt(jnp.mean(e**2)) for e in data[c][np.argmin([x['err_mean'][-1] for x in data[c]])]['err']], \n",
    "    #             marker, label=label\n",
    "    #         )\n",
    "    #     ax.set_xlabel('Steps')\n",
    "    #     ax.set_ylabel('RMSE')\n",
    "    #     \n",
    "    #     fig.savefig(os.path.join(graph_folder, f'rmse.pdf'))\n",
    "    #     plt.close('all')\n",
    "\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['res_mean'] for y in data[c]]\n",
    "            mean = np.mean(ys, axis=0)\n",
    "            err = np.std(ys, axis=0)\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean residue')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "        # fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'res_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        if plot_inv_param:\n",
    "            for i in range(len(inv_params)):\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['inv_param_pred'][i] for y in data[c]]\n",
    "                    mean = np.mean(ys, axis=0)\n",
    "                    err = np.std(ys, axis=0)\n",
    "                    label, marker = algs[c]\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, err, capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.axhline(y=inv_params[i], color='black')\n",
    "                ax.set_ylim(max(0., inv_params[i]-0.05), None)\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Inv. param {i+1}')\n",
    "                # fig.suptitle(suptit)\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['inv_param_pred'][i] for y in data[c]]\n",
    "                    mean = np.percentile(ys, 50, axis=0)\n",
    "                    err1 = mean - np.percentile(ys, 20, axis=0)\n",
    "                    err2 = np.percentile(ys, 80, axis=0) - mean\n",
    "                    label, marker = algs[c]\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.axhline(y=inv_params[i], color='black')\n",
    "                ax.set_ylim(max(0., inv_params[i]-0.05), None)\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Inv. param {i+1}')\n",
    "                # fig.suptitle(suptit)\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['inv_param_pred'][i] for y in data[c]]\n",
    "                    inv_err = np.abs(inv_params[i] - np.array(ys))\n",
    "                    mean = np.mean(inv_err, axis=0)\n",
    "                    err = np.std(inv_err, axis=0)\n",
    "                    label, marker = algs[c]\n",
    "                    ax.set_yscale(\"log\")\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Inv. param {i+1} error')\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                # fig.suptitle(suptit)\n",
    "                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_err.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y['inv_param_pred'][i] for y in data[c]]\n",
    "                    inv_err = np.abs(inv_params[i] - np.array(ys))\n",
    "                    mean = np.percentile(inv_err, 50, axis=0)\n",
    "                    err1 = mean - np.percentile(inv_err, 20, axis=0)\n",
    "                    err2 = np.percentile(inv_err, 80, axis=0) - mean\n",
    "                    label, marker = algs[c]\n",
    "                    ax.set_yscale(\"log\")\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel('Mean error')\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                # fig.suptitle(suptit)\n",
    "                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_err_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "                plt.close('all')\n",
    "\n",
    "\n",
    "        if x_test.shape[1] == 2:\n",
    "\n",
    "            for s in ['low', 'mid', 'high']:\n",
    "\n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                for c in data.keys():\n",
    "                    ys = [y[f'fft_mean_{s}'] for y in data[c]]\n",
    "                    mean = np.mean(ys, axis=0)\n",
    "                    err = np.std(ys, axis=0)\n",
    "                    label, marker = algs[c]\n",
    "                    ax.set_yscale(\"log\")\n",
    "                    ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "                ax.set_xlabel('Steps')\n",
    "                ax.set_ylabel(f'Mean FFT ({s}) diff.')\n",
    "                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "                # fig.suptitle(suptit)\n",
    "                fig.savefig(os.path.join(graph_folder, f'fft-{s}_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "            plt.close('all')\n",
    "\n",
    "\n",
    "        k0 = list(data.keys())[0]\n",
    "        best_from = 'err_mean'\n",
    "        \n",
    "        if ('conv' in case_folder) or ('burger' in case_folder) or ('eik' in case_folder):\n",
    "            \n",
    "            if ('conv' in case_folder) or ('burger' in case_folder):\n",
    "                axislabel = ('x', 't')\n",
    "            else:\n",
    "                axislabel = ('x', 'y')\n",
    "\n",
    "            for stidx in [-1, -2]:\n",
    "\n",
    "                for dim in range(data[k0][0]['y_test'].shape[1]):\n",
    "\n",
    "                    if '_change/' in case_folder:\n",
    "                        start_y = [data[k0][0]['pred'][0][:,dim:dim+1], data[k0][0]['y_test'][:,dim:dim+1]]\n",
    "                        start_title = ['Initial model', 'True solution']\n",
    "                    else:\n",
    "                        start_y = [data[k0][0]['y_test'][:,dim:dim+1]]\n",
    "                        start_title = ['True solution']\n",
    "\n",
    "                    fig, axs = plot_contours(\n",
    "                        xs=data[k0][0]['x_test'], \n",
    "                        ys_list=start_y + [data[c][np.argmin([x[best_from][-1] for x in data[c]])]['pred'][stidx][:,dim:dim+1] for c in data.keys()], \n",
    "                        titles=start_title + [algs[c][0] for c in data.keys()], \n",
    "                        axislabels=axislabel,\n",
    "                    )\n",
    "            #         fig.suptitle(f'{suptit}, step {steps_plot[stidx]}')\n",
    "            #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                    fig.savefig(os.path.join(graph_folder, f'pred_s{data[c][0][\"steps\"][stidx]}-d{dim}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                    plt.close('all')\n",
    "\n",
    "                    fig, axs = plot_contours(\n",
    "                        xs=data[k0][0]['x_test'], \n",
    "                        ys_list=[data[c][np.argmin([x[best_from][-1] for x in data[c]])]['err'][stidx][:,dim:dim+1] for c in data.keys()],\n",
    "                        titles=[algs[c][0] for c in data.keys()], \n",
    "                        axislabels=axislabel,\n",
    "                    )\n",
    "            #         fig.suptitle(f'{suptit}, step {steps_plot[stidx]}')\n",
    "            #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                    fig.savefig(os.path.join(graph_folder, f'err_s{data[c][0][\"steps\"][stidx]}-d{dim}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                    plt.close('all')\n",
    "\n",
    "                fig, axs = plot_contours(\n",
    "                    xs=data[k0][0]['x_test'], \n",
    "                    ys_list=[data[c][np.argmin([x[best_from][-1] for x in data[c]])]['res'][stidx] for c in data.keys()],\n",
    "                    titles=[algs[c][0] for c in data.keys()], \n",
    "                    axislabels=axislabel,\n",
    "                )\n",
    "        #         fig.suptitle(f'{suptit}, step {steps_plot[stidx]}')\n",
    "        #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                fig.savefig(os.path.join(graph_folder, f'res_s{data[c][0][\"steps\"][stidx]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                plt.close('all')\n",
    "                \n",
    "                fig, ax = plt.subplots(figsize=(5, 4))\n",
    "                ms = 6.\n",
    "                samples = data[c][0]['chosen_pts'][0]\n",
    "                ax.plot(samples['res'][:, 0], samples['res'][:, 1], 'o', color='black', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='PDE CL Pts')\n",
    "                for i, (bc_pts, name) in enumerate(zip(samples['bcs'], ['IC CL Pts', 'BC CL Pts'])):\n",
    "                    ax.plot(bc_pts[:, 0], bc_pts[:, 1], 's', color=f'C{i+1}', ms=ms, alpha=0.95, zorder=10, clip_on=False, label=name)\n",
    "                if 'anc' in samples.keys():\n",
    "                    ax.plot(samples['anc'][:, 0], samples['anc'][:, 1], '^', color='blue', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='Exp Pts')\n",
    "                figl = plt.figure(figsize=(6, 1.5))\n",
    "                axl = figl.add_subplot(111)\n",
    "                axl.legend(*ax.get_legend_handles_labels() , loc=\"center\", ncol=4)\n",
    "                axl.axis('off')\n",
    "                figl.savefig(os.path.join(graph_folder, f'labels_trainpts.pdf'), bbox_inches='tight', pad_inches=0.05)\n",
    "                plt.close('all')\n",
    "\n",
    "\n",
    "            if (graph_root == main_graph) or ('pinnacle' in graph_root):\n",
    "\n",
    "                for c in data.keys():\n",
    "\n",
    "                    min_idx = np.argmin([x[best_from][-1] for x in data[c]])\n",
    "                    steps = [data[c][min_idx]['steps'].index(s) for s in steps_plot]\n",
    "\n",
    "                    fig, axs = plot_contours(\n",
    "                        xs=data[k0][0]['x_test'], \n",
    "                        ys_list=[data[c][min_idx]['pred'][s] for s in steps],\n",
    "                        titles=[f'Step {s}' for s in steps_plot], \n",
    "                        axislabels=axislabel,\n",
    "                    )\n",
    "                    if multidim:\n",
    "                        for ax, s in zip(axs[0], steps):\n",
    "                            plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "                    else:\n",
    "                        for ax, s in zip(axs, steps):\n",
    "                            plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "            #         fig.suptitle(algs[c][0])\n",
    "            #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                    fig.savefig(os.path.join(graph_folder, f'data_pred_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                    plt.close('all')\n",
    "\n",
    "                    fig, axs = plot_contours(\n",
    "                        xs=data[k0][0]['x_test'], \n",
    "                        ys_list=[data[c][min_idx]['err'][s] for s in steps], \n",
    "                        titles=[f'Step {s}' for s in steps_plot], \n",
    "                        axislabels=axislabel,\n",
    "                    )\n",
    "                    if multidim:\n",
    "                        for ax, s in zip(axs[0], steps):\n",
    "                            plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "                    else:\n",
    "                        for ax, s in zip(axs, steps):\n",
    "                            plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "            #         fig.suptitle(algs[c][0])\n",
    "            #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                    fig.savefig(os.path.join(graph_folder, f'data_err_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                    plt.close('all')\n",
    "\n",
    "                    fig, axs = plot_contours(\n",
    "                        xs=data[c][min_idx]['x_test'], \n",
    "                        ys_list=[data[c][min_idx]['res'][s] for s in steps], \n",
    "                        titles=[f'Step {data[c][0][\"steps\"][s]}' for s in steps], \n",
    "                        axislabels=axislabel,\n",
    "                    )\n",
    "                    for ax, s in zip(axs, steps):\n",
    "                        plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "            #         fig.suptitle(algs[c][0])\n",
    "            #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                    fig.savefig(os.path.join(graph_folder, f'data_res_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                    plt.close('all')\n",
    "\n",
    "\n",
    "    ######################################## FINE TUNING EXPERIMENTS\n",
    "\n",
    "    case_list = [\n",
    "\n",
    "        (\n",
    "            'al_pinn_results_ic_change/conv-1d{1.0}_ftic-898-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-200-50-0',\n",
    "            [0, 10000, 50000, 200000],\n",
    "    #         [0, 20000, 40000, 60000, 80000, 100000, 150000],\n",
    "            '1D Advection (FT)'\n",
    "        ),\n",
    "\n",
    "        (\n",
    "            'al_pinn_results_ic_change/burgers-1d{0.02}_ftic-131-20_ic/nn-None-4-128_adam_bcsloss-1.0_budget-200-50-0',\n",
    "            [0, 10000, 50000, 200000],\n",
    "    #         [0, 20000, 40000, 60000, 80000, 100000, 150000],\n",
    "            '1D Burgers (FT)'\n",
    "        ),\n",
    "\n",
    "    ]\n",
    "\n",
    "    for case_folder, steps_plot, suptit in case_list:\n",
    "\n",
    "        if steps_plot[-1] > 150000:\n",
    "            tick_spacing = 100000\n",
    "        elif steps_plot[-1] > 50000:\n",
    "            tick_spacing = 50000\n",
    "        else:\n",
    "            tick_spacing = 10000\n",
    "\n",
    "        throwout = []\n",
    "\n",
    "        print('---------------------------------------------------\\n\\n')\n",
    "        print('PROCESSING:', graph_root, case_folder)\n",
    "\n",
    "        max_steps = steps_plot[-1]\n",
    "\n",
    "        root_folder = os.path.join(data_folder, case_folder)\n",
    "\n",
    "        _, arch, depth, width = root_folder.split('/')[-1].split('_')[0].split('-')\n",
    "\n",
    "        net, _ = construct_net(\n",
    "            input_dim=2, \n",
    "            output_dim=1, \n",
    "            hidden_layers=int(depth), \n",
    "            hidden_dim=int(width), \n",
    "            arch=(None if arch == 'None' else arch)\n",
    "        )\n",
    "\n",
    "        cases = {x: os.listdir(f'{root_folder}/{x}') for x in algs.keys() if os.path.exists(f'{root_folder}/{x}')}\n",
    "        print('Exist:', list(cases.keys()))\n",
    "\n",
    "        data = dict()\n",
    "        steps_min = dict()\n",
    "        plotted_cases = dict()\n",
    "\n",
    "        for c in cases.keys():\n",
    "\n",
    "            s_min = float('inf')\n",
    "\n",
    "            runs = []\n",
    "            runs_cases = []\n",
    "\n",
    "            included = set()\n",
    "\n",
    "            for r in sorted(cases[c]):\n",
    "\n",
    "                try:\n",
    "\n",
    "                    ver = r.split('_')[0]\n",
    "                    if ver in included:\n",
    "                        continue\n",
    "\n",
    "                    d = dict()\n",
    "\n",
    "                    for file in os.listdir(f'{root_folder}/{c}/{r}'):\n",
    "\n",
    "                        if file.startswith('snapshot_data'):\n",
    "\n",
    "                            fname = f'{root_folder}/{c}/{r}/{file}'\n",
    "\n",
    "                            with open(fname, 'rb') as f:\n",
    "                                d_update = pkl.load(f)\n",
    "\n",
    "                            d.update(d_update)\n",
    "\n",
    "                    steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])\n",
    "                    if (len(steps_range) > 0) and (max_steps == steps_range[-1]) and (None in d.keys()):\n",
    "\n",
    "                        print(c, r, sorted([x for x in d.keys() if (x is not None)])[-1])\n",
    "                        included.add(ver)\n",
    "\n",
    "                        s_min = min(s_min, steps_range[-1])\n",
    "\n",
    "                        x_test = d[None]['x_test']\n",
    "\n",
    "                        d_modified = {\n",
    "                            'x_test': x_test,\n",
    "                            'y_test': d[None]['y_test'],\n",
    "                            'steps': steps_range,\n",
    "                            'res_mean': [d[k]['residue_test_mean'] for k in steps_range],\n",
    "                            'err_mean': [d[k]['error_test_mean'] for k in steps_range],\n",
    "                            'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],\n",
    "                            'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],\n",
    "                            'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],\n",
    "                            'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],\n",
    "                            'res': [d[k]['residue_test'] for k in steps_range],\n",
    "                            'err': [d[k]['error_test'] for k in steps_range],\n",
    "                            'pred': [d[k]['pred_test'] if 'pred_test' in d[k].keys() \n",
    "                                     else net.apply(d[k]['params'][0], x_test)\n",
    "                                     for k in steps_range],\n",
    "                            'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],\n",
    "                            'inv': [d[k]['params'][1] for k in steps_range],\n",
    "                        }\n",
    "\n",
    "                        arr_shape = [d_modified['y_test'].shape[1]] + [np.unique(x).shape[0] for x in d_modified['x_test'].T]\n",
    "                        d_modified['y_test_fft'] = np.fft.fftn(\n",
    "                            d_modified['y_test'].reshape(*arr_shape), \n",
    "                            axes=[1, 2]\n",
    "                        )\n",
    "                        d_modified['pred_fft'] = [np.fft.fftn(\n",
    "                            y.reshape(*arr_shape), axes=[1, 2]) \n",
    "                            for y in d_modified['pred']]\n",
    "\n",
    "                        d_modified['fft_err'] = [np.abs(yf - d_modified['y_test_fft'])\n",
    "                            for yf in d_modified['pred_fft']]\n",
    "\n",
    "            #             idxs = np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))[0].T\n",
    "            #             idxs = np.array([idxs, idxs])\n",
    "\n",
    "                        idxs = np.array(np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))).swapaxes(1, 2)\n",
    "\n",
    "                        klow = (idxs <= 4).all(axis=0).astype(float)\n",
    "                        kmid = (idxs <= 12).all(axis=0).astype(float) - klow\n",
    "                        khigh = (idxs <= np.inf).all(axis=0).astype(float) - kmid - klow\n",
    "\n",
    "                        for s, k in [('low', klow), ('mid', kmid), ('high', khigh)]:\n",
    "                            d_modified[f'fft_mean_{s}'] = [np.sum(yf * k[None, :]) / (np.sum(k) * yf.shape[0])\n",
    "                                for yf in d_modified['fft_err']]\n",
    "\n",
    "                        runs.append(d_modified)\n",
    "                        runs_cases.append(r)\n",
    "\n",
    "                    else:\n",
    "                        throwout.append(f'{root_folder}/{c}/{r}')\n",
    "\n",
    "                except:\n",
    "                    pass\n",
    "\n",
    "            if len(runs) > 0:\n",
    "                data[c] = runs\n",
    "                steps_min[c] = s_min\n",
    "                plotted_cases[c] = runs_cases\n",
    "\n",
    "        print('To plot algorithms =', {k: len(data[k]) for k in data.keys()})\n",
    "\n",
    "        if len([k for k in data.keys()]) == 0:\n",
    "            continue\n",
    "\n",
    "        graph_folder = os.path.join(data_folder, graph_root, case_folder)\n",
    "        os.makedirs(graph_folder, exist_ok=True)\n",
    "\n",
    "        with open(os.path.join(graph_folder, 'cases_plotted'), 'w+') as f:\n",
    "            f.write(str(plotted_cases))\n",
    "\n",
    "\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['err_mean'] for y in data[c]]\n",
    "            mean = np.mean(ys, axis=0)\n",
    "            err = np.std(ys, axis=0)\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean error')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #     fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'err_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['err_mean'] for y in data[c]]\n",
    "            mean = np.percentile(ys, 50, axis=0)\n",
    "            err1 = mean - np.percentile(ys, 20, axis=0)\n",
    "            err2 = np.percentile(ys, 80, axis=0) - mean\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean error')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "        # fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'err_med.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "\n",
    "        for q in [50, 90, 95, 100]:\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y[f'err_q{q}'] for y in data[c]]\n",
    "                mean = np.mean(ys, axis=0)\n",
    "                err = np.std(ys, axis=0)\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'{q}th Quantile error')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #         fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'err_q{q}.pdf'), bbox_inches='tight', pad_inches=0)\n",
    "            plt.close('all')\n",
    "\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        for c in data.keys():\n",
    "            ys = [y['res_mean'] for y in data[c]]\n",
    "            mean = np.mean(ys, axis=0)\n",
    "            err = np.std(ys, axis=0)\n",
    "            label, marker = algs[c]\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "        ax.set_xlabel('Steps')\n",
    "        ax.set_ylabel('Mean residue')\n",
    "        ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #     fig.suptitle(suptit)\n",
    "        fig.savefig(os.path.join(graph_folder, f'res_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "\n",
    "\n",
    "        for s in ['low', 'mid', 'high']:\n",
    "\n",
    "            fig, ax = plt.subplots(figsize=(5, 4))\n",
    "            for c in data.keys():\n",
    "                ys = [y[f'fft_mean_{s}'] for y in data[c]]\n",
    "                mean = np.mean(ys, axis=0)\n",
    "                err = np.std(ys, axis=0)\n",
    "                label, marker = algs[c]\n",
    "                ax.set_yscale(\"log\")\n",
    "                ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, \n",
    "                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])\n",
    "            ax.set_xlabel('Steps')\n",
    "            ax.set_ylabel(f'Mean FFT ({s}) diff.')\n",
    "            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "    #         fig.suptitle(suptit)\n",
    "            fig.savefig(os.path.join(graph_folder, f'fft-{s}_mean.pdf'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "\n",
    "        axislabel = ('x', 't')\n",
    "\n",
    "        k0 = list(data.keys())[0]\n",
    "        idx = 0\n",
    "\n",
    "        if '_change/' in case_folder:\n",
    "            start_y = [data[k0][idx]['pred'][0], data[k0][idx]['y_test']]\n",
    "            start_title = ['Initial model', 'True solution']\n",
    "        else:\n",
    "            start_y = [data[k0][idx]['y_test']]\n",
    "            start_title = ['True solution']\n",
    "\n",
    "        if graph_root == main_graph:\n",
    "            keys_used = [\n",
    "                'random_Hammersley_prop-0.8', 'residue_prop-0.8_alltype', \n",
    "                'sampling_alignment_scale-none_mem_autoal', 'kmeans_alignment_scale-none_mem_autoal'\n",
    "            ]\n",
    "        else:\n",
    "            keys_used = data.keys()\n",
    "\n",
    "        fig, axs = plot_contours(\n",
    "            xs=data[k0][idx]['x_test'], \n",
    "            ys_list=start_y + [data[c][idx]['pred'][-1] for c in keys_used], \n",
    "            titles=start_title + [algs[c][0] for c in keys_used], \n",
    "            axislabels=axislabel,\n",
    "        )\n",
    "    #     fig.suptitle(f'{suptit}, step {data[c][0][\"steps\"][-1]}')\n",
    "    #     fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "        fig.savefig(os.path.join(graph_folder, f'pred_s{data[c][0][\"steps\"][-1]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        fig, axs = plot_contours(\n",
    "            xs=data[k0][idx]['x_test'], \n",
    "            ys_list=[data[c][idx]['err'][-1] for c in keys_used],\n",
    "            titles=[algs[c][0] for c in keys_used], \n",
    "            axislabels=axislabel,\n",
    "        )\n",
    "    #     fig.suptitle(f'{suptit}, step {data[c][0][\"steps\"][-1]}')\n",
    "    #     fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "        fig.savefig(os.path.join(graph_folder, f'err_s{data[c][0][\"steps\"][-1]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "\n",
    "        fig, axs = plot_contours(\n",
    "            xs=data[k0][idx]['x_test'], \n",
    "            ys_list=[data[c][idx]['res'][-1] for c in keys_used],\n",
    "            titles=[algs[c][idx] for c in keys_used], \n",
    "            axislabels=axislabel,\n",
    "        )\n",
    "    #     fig.suptitle(f'{suptit}, step {data[c][0][\"steps\"][-1]}')\n",
    "    #     fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "        fig.savefig(os.path.join(graph_folder, f'res_s{data[c][idx][\"steps\"][-1]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "        plt.close('all')\n",
    "        \n",
    "        fig, ax = plt.subplots(figsize=(5, 4))\n",
    "        ms = 6.\n",
    "        samples = data[c][0]['chosen_pts'][0]\n",
    "        ax.plot(samples['res'][:, 0], samples['res'][:, 1], 'o', color='black', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='PDE CL Pts')\n",
    "        for i, (bc_pts, name) in enumerate(zip(samples['bcs'], ['IC CL Pts', 'BC CL Pts'])):\n",
    "            ax.plot(bc_pts[:, 0], bc_pts[:, 1], 's', color=f'C{i+1}', ms=ms, alpha=0.95, zorder=10, clip_on=False, label=name)\n",
    "        if 'anc' in samples.keys():\n",
    "            ax.plot(samples['anc'][:, 0], samples['anc'][:, 1], '^', color='blue', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='Exp Pts')\n",
    "        figl = plt.figure(figsize=(6, 1.5))\n",
    "        axl = figl.add_subplot(111)\n",
    "        axl.legend(*ax.get_legend_handles_labels() , loc=\"center\", ncol=4)\n",
    "        axl.axis('off')\n",
    "        figl.savefig(os.path.join(graph_folder, f'labels_trainpts.pdf'), bbox_inches='tight', pad_inches=0.05)\n",
    "        plt.close('all')\n",
    "\n",
    "        if graph_root == main_graph:\n",
    "\n",
    "            for c in data.keys():\n",
    "\n",
    "                min_idx = 0\n",
    "        #         min_idx = np.argmin([x['err_mean'][-1] for x in data[c]])\n",
    "                steps = [data[c][min_idx]['steps'].index(s) for s in steps_plot]\n",
    "\n",
    "                fig, axs = plot_contours(\n",
    "                    xs=data[k0][0]['x_test'], \n",
    "                    ys_list=(\n",
    "#                         [data[c][min_idx]['y_test']] +\n",
    "                        [data[c][min_idx]['pred'][s] for s in steps]\n",
    "                    ), \n",
    "                    titles=(\n",
    "#                         ['True solution'] +\n",
    "                        [f'Step {s}' for s in steps_plot]\n",
    "                    ), \n",
    "                    axislabels=axislabel,\n",
    "                )\n",
    "                for ax, s in zip(axs, steps):\n",
    "                    plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "        #         fig.suptitle(algs[c][0])\n",
    "        #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                fig.savefig(os.path.join(graph_folder, f'data_pred_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                plt.close('all')\n",
    "\n",
    "                fig, axs = plot_contours(\n",
    "                    xs=data[k0][0]['x_test'], \n",
    "                    ys_list=[data[c][min_idx]['err'][s] for s in steps], \n",
    "                    titles=[f'Step {s}' for s in steps_plot], \n",
    "                    axislabels=axislabel,\n",
    "                )\n",
    "                for ax, s in zip(axs, steps):\n",
    "                    plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "        #         fig.suptitle(algs[c][0])\n",
    "        #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                fig.savefig(os.path.join(graph_folder, f'data_err_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                plt.close('all')\n",
    "\n",
    "                fig, axs = plot_contours(\n",
    "                    xs=data[c][min_idx]['x_test'], \n",
    "                    ys_list=[data[c][min_idx]['res'][s] for s in steps], \n",
    "                    titles=[f'Step {data[c][0][\"steps\"][s]}' for s in steps], \n",
    "                    axislabels=axislabel,\n",
    "                )\n",
    "                for ax, s in zip(axs, steps):\n",
    "                    plot_training_data(ax, data[c][0]['chosen_pts'][s])\n",
    "        #         fig.suptitle(algs[c][0])\n",
    "        #         fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "                fig.savefig(os.path.join(graph_folder, f'data_res_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)\n",
    "                plt.close('all')\n",
    "\n",
    "    print('=================================DONE=============================\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101d4430",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b6b6c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a083a11a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
