{
 "cells": [
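  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make figures and tables\n",
    "\n",
    "Loads the trained models, loss histories and evaluation results, writes the figures as PDF files to `paths_config[\"graphs_folder\"]` and builds the LaTeX tables.\n",
    "\n",
    "## Imports and configuration"
   ]
  },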
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pyarrow.parquet as pq\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import equinox as eqx\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.patheffects as pe\n",
    "from matplotlib import font_manager\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.patches import Polygon, Rectangle\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "from scipy import stats\n",
    "from scipy.linalg import solve_toeplitz\n",
    "import PIL\n",
    "import IPython\n",
    "\n",
    "import rvsr\n",
    "from paths_config import paths_config\n",
    "from train_utils import get_test_loss_steps, get_train_loss_steps, preprocess_batch_for_superresolution_task\n",
    "from job import presets, hpars, get_preset_hpars\n",
    "from eval_mse import get_mask_sums\n",
    "from data_utils import np_linear_to_srgb\n",
    "from padding import Padding2dLayer\n",
    "\n",
    "# Register every system TrueType font so that the font families selected below can be found\n",
    "for font_path in font_manager.findSystemFonts(fontpaths=None, fontext=\"ttf\"):\n",
    "    font_manager.fontManager.addfont(font_path)\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "plt.rcParams[\"mathtext.fontset\"] = \"cm\"\n",
    "\n",
    "# Font sizes in points, shared by all figures\n",
    "MONOSPACE_SIZE = 12\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 18\n",
    "plt.rc(\"font\", size=SMALL_SIZE)                             # default text size\n",
    "plt.rc(\"axes\", titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)  # axes title, x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)                       # tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)                     # figure title\n",
    "\n",
    "preset_ids = sorted(presets.keys())\n",
    "num_seeds = 12\n",
    "oc_list = [0, 1, 5]  # output-crop values, cf. the _oc1/_oc5 preset suffixes\n",
    "\n",
    "def train_attempted(preset, seed):\n",
    "    \"\"\"Return whether a training run was launched for this preset and seed.\n",
    "\n",
    "    The presets extr1, extr2 and lp6x7 were only trained with seeds 0, 1 and 2.\n",
    "    \"\"\"\n",
    "    return not (preset in ('extr1', 'extr2', 'lp6x7') and seed > 2)\n",
    "\n",
    "os.makedirs(paths_config[\"graphs_folder\"], exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extrapolation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour i is shared by the methods that use i known neighbours (len(get_a(...)) - 1)\n",
    "cmap = [\n",
    "    \"black\",\n",
    "    \"#5b0188\",\n",
    "    \"#a6265e\",\n",
    "    \"#d67922\",\n",
    "]\n",
    "\n",
    "# Legend label -> [colour, line style]; solid: fixed extrapolation, dashed: linear prediction\n",
    "blur_padding_methods = {\n",
    "    \"zero, extr0\": [cmap[0], '-'],\n",
    "    \"repl, extr1\": [cmap[1], '-'],\n",
    "    \"extr2\": [cmap[2], '-'],\n",
    "    \"extr3\": [cmap[3], '-'],\n",
    "    \"lp1x1cs\": [cmap[1], '--'],\n",
    "    \"lp2x1, lp2x1cs\": [cmap[2], '--'],\n",
    "    \"lp3x1\": [cmap[3], '--'],\n",
    "}\n",
    "\n",
    "# Coefficient vectors of the fixed (sigma-independent) extrapolation methods\n",
    "fixed_a = {\n",
    "    \"zero, extr0\": [-1],\n",
    "    \"repl, extr1\": [-1, 1],\n",
    "    \"extr2\": [-1, 2, -1],\n",
    "    \"extr3\": [-1, 3, -3, 1]\n",
    "}\n",
    "\n",
    "def kappa(d, sigma):\n",
    "    \"\"\"Gaussian autocorrelation exp(-d^2 / (2 sigma^2)) at lag d.\"\"\"\n",
    "    return np.exp(-d**2/(2*sigma**2))\n",
    "\n",
    "def get_a(method, sigma):\n",
    "    \"\"\"Return the coefficient vector a of a padding method for blur scale sigma.\n",
    "\n",
    "    Every vector starts with -1. Fixed methods are looked up in `fixed_a`; for the\n",
    "    linear-prediction (lp*) methods the remaining entries solve the Toeplitz system\n",
    "    built from `kappa` (in closed form for the 1- and 2-tap cases).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `method` is not one of the supported method names.\n",
    "    \"\"\"\n",
    "    if method in fixed_a:\n",
    "        return np.array(fixed_a[method])\n",
    "    k0, k1, k2, k3 = (kappa(lag, sigma) for lag in range(4))\n",
    "    if method == \"lp1x1cs\":\n",
    "        return np.array([-1, k1/k0])  # = solve_toeplitz([k0], [k1])\n",
    "    if method == \"lp2x1, lp2x1cs\":\n",
    "        # = solve_toeplitz([k0, k1], [k1, k2]) in closed form\n",
    "        return np.array([-1, (k1*k0 - k2*k1)/(k0**2 - k1**2), (k2*k0 - k1*k1)/(k0**2 - k1**2)])\n",
    "    if method == \"lp3x1\":\n",
    "        return np.concatenate(([-1], solve_toeplitz([k0, k1, k2], [k1, k2, k3])))\n",
    "    raise ValueError(f\"Unknown padding method {method}\")\n",
    "\n",
    "def blur_padding_error_var(a, sigma):\n",
    "    \"\"\"Return sum_ij a_i a_j kappa(j - i, sigma): the variance of a @ x when x has covariance kappa(j - i, sigma).\"\"\"\n",
    "    i, j = np.meshgrid(np.arange(len(a)), np.arange(len(a)))\n",
    "    return np.sum(a[i] * a[j] * kappa(j - i, sigma))\n",
    "\n",
    "d = np.arange(-10, 10, 0.01)  # NOTE(review): not used in this cell\n",
    "sigma = np.arange(0.001, 6, 0.01)\n",
    "\n",
    "plt.figure(figsize=(5.5, 4.5))\n",
    "for method, (color, linestyle) in blur_padding_methods.items():\n",
    "    plt.semilogy(\n",
    "        sigma,\n",
    "        [blur_padding_error_var(get_a(method, s), s) for s in sigma],\n",
    "        linestyle,\n",
    "        label=method,\n",
    "        color=color,\n",
    "        linewidth=1 if linestyle == \"-\" else 2\n",
    "    )\n",
    "legend = plt.legend(labelspacing=0.08, framealpha=0)\n",
    "plt.setp(legend.get_texts(), font=\"cmtt10\", size=MONOSPACE_SIZE)\n",
    "plt.ylabel(\"NMSE\")\n",
    "plt.xlabel(r\"$\\sigma$\")\n",
    "plt.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_data_blur.pdf\"), pad_inches=0, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load loss histories and trained models\n",
    "Loss histories sometimes have gaps, most likely because training was interrupted between checkpointing and loss-history recording; model quality is not affected. Small gaps (largest step gap below 3x the nominal logging interval) are filled by linear interpolation; histories with larger gaps are discarded. Runs without a saved model are treated as failed, and their loss histories are ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Init related results\n",
    "if \"results\" not in locals():\n",
    "    results = {}\n",
    "results[\"train_loss_histories\"] = {}\n",
    "results[\"test_loss_histories\"] = {}\n",
    "results[\"trained_models\"] = {}\n",
    "results[\"train_loss_steps\"] = np.array(get_train_loss_steps(hpars))\n",
    "results[\"test_loss_steps\"] = np.array(get_test_loss_steps(hpars))\n",
    "\n",
    "# A loss history is discarded when its largest step gap is at least this many nominal logging intervals\n",
    "MAX_GAP_RATIO = 3\n",
    "\n",
    "def _load_loss_history(filename, kind, run_name):\n",
    "    \"\"\"Read a run's `<kind>_loss` history and align it with results[\"<kind>_loss_steps\"].\n",
    "\n",
    "    Args:\n",
    "        filename: parquet file with a \"step\" column and a \"<kind>_loss\" column.\n",
    "        kind: \"train\" or \"test\".\n",
    "        run_name: run identifier used in messages, e.g. \"zero_s3\".\n",
    "\n",
    "    Returns:\n",
    "        One loss value per reference step, or None when the file is missing or the history is\n",
    "        unusable (fewer than two entries, or gaps of MAX_GAP_RATIO or more logging intervals).\n",
    "        Histories with smaller gaps are filled in by linear interpolation.\n",
    "    \"\"\"\n",
    "    if not os.path.isfile(filename):\n",
    "        return None\n",
    "    reference_steps = results[f\"{kind}_loss_steps\"]\n",
    "    table = pq.read_table(filename)\n",
    "    history = table[f\"{kind}_loss\"].to_numpy()\n",
    "    steps = table[\"step\"].to_numpy()\n",
    "    if np.array_equal(steps, reference_steps):\n",
    "        return history\n",
    "    if len(steps) < 2:\n",
    "        # Too short to measure gaps (np.max of an empty diff would raise)\n",
    "        print(f\"Malformed {kind} loss history for run {run_name}, {len(steps)} entries, discarding {kind} loss history\")\n",
    "        return None\n",
    "    max_gap = np.max(np.diff(steps))\n",
    "    step_ratio = max_gap/np.max(np.diff(reference_steps))\n",
    "    if step_ratio < MAX_GAP_RATIO:\n",
    "        return np.interp(reference_steps, steps, history)\n",
    "    print(f\"Malformed {kind} loss history for run {run_name}, max_gap={max_gap}, discarding {kind} loss history\")\n",
    "    return None\n",
    "\n",
    "# Read loss histories and trained models\n",
    "for preset in preset_ids:\n",
    "    results[\"train_loss_histories\"][preset] = np.full((num_seeds, len(results[\"train_loss_steps\"])), fill_value=np.nan, dtype=np.float32)\n",
    "    results[\"test_loss_histories\"][preset] = np.full((num_seeds, len(results[\"test_loss_steps\"])), fill_value=np.nan, dtype=np.float32)\n",
    "    results[\"trained_models\"][preset] = []\n",
    "    for seed in range(num_seeds):\n",
    "        run_name = f\"{preset}_s{seed}\"\n",
    "        run_prefix = os.path.join(paths_config[\"trained_models_folder\"], run_name)\n",
    "        model_eqx_filename = f\"{run_prefix}.eqx\"\n",
    "        if not os.path.isfile(model_eqx_filename):\n",
    "            # Failed or not attempted: any loss history of this run is ignored (stays NaN)\n",
    "            results[\"trained_models\"][preset].append(None)\n",
    "            continue\n",
    "        for kind in (\"train\", \"test\"):\n",
    "            history = _load_loss_history(f\"{run_prefix}_{kind}_loss.parquet\", kind, run_name)\n",
    "            if history is not None:\n",
    "                results[f\"{kind}_loss_histories\"][preset][seed] = history\n",
    "        print(f\"Loading model {model_eqx_filename}\")\n",
    "        preset_hpars = get_preset_hpars(preset)\n",
    "        make_model = eqx.nn.make_with_state(preset_hpars[\"model_type\"])\n",
    "        model, model_state = make_model(sr_rate=preset_hpars[\"sr_rate\"], **preset_hpars[\"model_hpars\"], key=jr.PRNGKey(42))\n",
    "        # Same model type and hyperparameters but output_crop=0, passed to rvsr.load_rvsr_weights\n",
    "        model_oc0, model_state_oc0 = make_model(sr_rate=preset_hpars[\"sr_rate\"], **{**preset_hpars[\"model_hpars\"], \"output_crop\": 0}, key=jr.PRNGKey(42))\n",
    "        model = rvsr.load_rvsr_weights(model, model_eqx_filename, model_oc0)\n",
    "        results[\"trained_models\"][preset].append({\"model\": model, \"model_state\": model_state})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load evaluations of MSE, inference time, train time, layer NMSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['output_crop_sumsqs', 'largest_train_batch_sizes', 'train_step_times', 'largest_inference_batch_sizes', 'inference_step_times', 'padding_seed_layer_nmses'])\n"
     ]
    }
   ],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only open result files produced by this project.\n",
    "eval_results_filename = paths_config[\"eval_results_filename\"]\n",
    "eval_results_path = os.path.join(paths_config[\"results_folder\"], eval_results_filename)\n",
    "with open(eval_results_path, \"rb\") as f:\n",
    "    eval_results = pickle.load(f)\n",
    "\n",
    "print(eval_results.keys())\n",
    "\n",
    "# Merge into `results` (created here if the loading cell has not run); evaluation entries win on key clashes\n",
    "results = {**locals().get(\"results\", {}), **eval_results}\n",
    "del eval_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_common_success_seeds(presets):\n",
    "    \"\"\"Return, in ascending order, the seeds for which every preset in `presets` has a trained model.\n",
    "\n",
    "    A (preset, seed) run counts as successful when results[\"trained_models\"][preset][seed] is not None.\n",
    "    An empty `presets` yields all seeds.\n",
    "    \"\"\"\n",
    "    success = np.ones(num_seeds, dtype=bool)\n",
    "    for preset in presets:\n",
    "        success &= np.array([model is not None for model in results[\"trained_models\"][preset]])\n",
    "    return np.arange(num_seeds)[success]\n",
    "\n",
    "def get_color_list():\n",
    "    \"\"\"Per-seed marker colours (index = seed).\"\"\"\n",
    "    return ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'black', 'teal', 'gray', 'cyan', 'magenta', 'brown']\n",
    "\n",
    "def get_shape_list():\n",
    "    \"\"\"Per-seed marker shapes (index = seed).\"\"\"\n",
    "    return ['o', '^', 's', 'd', 'p', 'h', '*', 'x', '+', 'v', '>', '<']\n",
    "\n",
    "def get_1d_onion_bootstrap_abs(results, preset, valid_seeds, conf=0.95, resamples=10_000):\n",
    "    \"\"\"Bootstrap the mean over seeds of the normalised loss at each edge distance 0..10.\n",
    "\n",
    "    For edge distance e < 10 the loss is the difference of the output-crop sums of squares for\n",
    "    crops e and e+1, divided by the difference of the corresponding mask sums; e == 10 uses the\n",
    "    crop-10 sum of squares divided by its mask sum.\n",
    "\n",
    "    Returns:\n",
    "        (means_of_means, ci_low, ci_high, number of seeds), where ci_low and ci_high are the\n",
    "        distances from each mean down to / up to its bootstrap confidence bound (error-bar form).\n",
    "    \"\"\"\n",
    "    max_edge_dist = 10\n",
    "    # Loop invariant, so computed once. NOTE(review): reads the notebook-global `hpars`.\n",
    "    mask_sums = get_mask_sums(hpars)\n",
    "\n",
    "    def get_onion_loss(preset, seed, edge_dist):\n",
    "        sumsqs = results[\"output_crop_sumsqs\"][preset][seed]\n",
    "        if edge_dist < max_edge_dist:\n",
    "            loss = sumsqs[edge_dist] - sumsqs[edge_dist+1]\n",
    "            corr_fact = mask_sums[edge_dist] - mask_sums[edge_dist+1]\n",
    "        elif edge_dist == max_edge_dist:\n",
    "            loss = sumsqs[edge_dist]\n",
    "            corr_fact = mask_sums[edge_dist]\n",
    "        else:\n",
    "            raise ValueError(f\"edge_dist must be at most {max_edge_dist}, got {edge_dist}\")\n",
    "        return loss/corr_fact\n",
    "\n",
    "    means_of_means = []\n",
    "    ci_low, ci_high = [], []\n",
    "    for edge_dist in range(max_edge_dist + 1):\n",
    "        loss_means = np.array([np.mean(get_onion_loss(preset, seed, edge_dist)) for seed in valid_seeds])\n",
    "        mean_of_means = np.mean(loss_means)\n",
    "        result = stats.bootstrap(\n",
    "            (loss_means,), np.mean, confidence_level=conf, n_resamples=resamples\n",
    "        )\n",
    "        low, high = result.confidence_interval\n",
    "        means_of_means.append(mean_of_means)\n",
    "        ci_low.append(mean_of_means - low)\n",
    "        ci_high.append(high - mean_of_means)\n",
    "    return means_of_means, ci_low, ci_high, len(valid_seeds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save markers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save each seed's plot marker (colour + shape) as its own small, borderless PDF\n",
    "markersize = np.sqrt(160.0)/2  # points\n",
    "marker_figsize = (markersize/72*2, markersize/72*2)  # inches: twice the marker size, 72 points per inch\n",
    "colors_and_shapes = zip(get_color_list(), get_shape_list())\n",
    "for seed, (color, shape) in enumerate(colors_and_shapes):\n",
    "    fig, ax = plt.subplots(figsize=marker_figsize)\n",
    "    ax.plot(0, 0, marker=shape, color=color, alpha=0.4, markersize=markersize, markeredgewidth=0.5)\n",
    "    ax.set_xlim(-1, 1)\n",
    "    ax.set_ylim(-1, 1)\n",
    "    ax.axis(\"off\")\n",
    "    filename = os.path.join(paths_config[\"graphs_folder\"], f\"s{seed}.pdf\")\n",
    "    fig.savefig(filename, bbox_inches='tight', pad_inches=0)\n",
    "    plt.close(fig)\n",
    "    print(f\"Saved: {filename}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load test data and calculate test data RGB biases and mean variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded test data numpy array with shape (1000, 3, 512, 512) and dtype float32\n",
      "[-0.28047428 -0.30902568 -0.22211602]\n",
      "0.043799866\n"
     ]
    }
   ],
   "source": [
    "import data_utils\n",
    "\n",
    "data_id = \"test\"  # \"train\" or \"test\" data\n",
    "# NOTE(review): data_path always points at the test file, whatever data_id says\n",
    "data_path = os.path.join(paths_config[\"dataset_folder\"], paths_config[\"test_data_filename\"])\n",
    "data = data_utils.load_data_array(data_path, id=data_id)\n",
    "\n",
    "# Mean of each colour channel over all images and pixels (the RGB biases)\n",
    "channel_means = np.mean(data, axis=(-1, -2, -4), keepdims=False)\n",
    "print(channel_means)\n",
    "# Mean squared deviation of every pixel from its own image's per-channel mean\n",
    "per_image_channel_means = np.mean(data, axis=(-1, -2), keepdims=True)\n",
    "print(np.mean((data - per_image_channel_means)**2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample padding images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paper_diffuse_reflectance = 0.8\n",
    "input_size = (48, 48)\n",
    "num_padding = ((24, 24), (24, 24))\n",
    "# (height, width) of an input plus its padding\n",
    "padded_size = tuple(size + before + after for size, (before, after) in zip(input_size, num_padding))\n",
    "corner_size = 4\n",
    "example_indexes = [5, 1, 2, 3, 4, 0, 6, 7, 8, 9]  # Non-cherry-picked, just swapped 0 and 5 for better text visibility\n",
    "\n",
    "def to_srgb_strip(images_nchw):\n",
    "    \"\"\"Convert a channels-first image batch into one horizontal sRGB strip for imshow.\n",
    "\n",
    "    Undoes the x*0.5+0.5 normalisation, divides by paper_diffuse_reflectance, applies\n",
    "    np_linear_to_srgb, moves the channel axis last and places the images side by side.\n",
    "    \"\"\"\n",
    "    srgb = np_linear_to_srgb((images_nchw*0.5 + 0.5)/paper_diffuse_reflectance)\n",
    "    return np.hstack(np.moveaxis(srgb, source=1, destination=3))\n",
    "\n",
    "def add_corner_marks(ax, x0, y0, width, height, corner_size, **polyline_kwargs):\n",
    "    \"\"\"Draw the four L-shaped corners of the width x height rectangle whose corner is at (x0, y0).\"\"\"\n",
    "    for cy, dy in ((y0, corner_size), (y0 + height, -corner_size)):\n",
    "        for cx, dx in ((x0, corner_size), (x0 + width, -corner_size)):\n",
    "            ax.add_patch(Polygon(((cx, cy + dy), (cx, cy), (cx + dx, cy)), closed=False, **polyline_kwargs))\n",
    "\n",
    "batch = np.take(data, example_indexes, axis=0)\n",
    "# First output of the superresolution preprocessing (sr_rate 4), sized to input plus padding\n",
    "target = jax.jit(\n",
    "    lambda batch: preprocess_batch_for_superresolution_task(batch, padded_size[0]*4, padded_size[1]*4, 4, False)[0]\n",
    ")(batch)\n",
    "images = [to_srgb_strip(target)]\n",
    "# Central crop of the target; each padding layer below re-creates the border from it\n",
    "cropped_target = target[:, :, num_padding[0][0]:-num_padding[0][1], num_padding[1][0]:-num_padding[1][1]]\n",
    "padding_methods = [\"Target\"]\n",
    "image_lp6x7 = None\n",
    "for preset in presets:\n",
    "    if preset.endswith((\"_oc1\", \"_oc5\")) or preset in (\"zero-zero\", \"extr1\"):\n",
    "        continue\n",
    "    padding_method = \"zero\" if preset == \"zero-repl\" else preset\n",
    "    padding_methods.append(padding_method)\n",
    "    model_hpars = get_preset_hpars(preset)[\"model_hpars\"]\n",
    "    padding_layer = Padding2dLayer(num_padding, model_hpars[\"conv_padding_method\"], model_hpars[\"padding_method_kwargs\"])\n",
    "    image = to_srgb_strip(jax.vmap(padding_layer)(cropped_target))\n",
    "    if padding_method == \"lp6x7\":\n",
    "        image_lp6x7 = image\n",
    "    images.append(image)\n",
    "\n",
    "image = np.vstack(images)\n",
    "image_width = 11\n",
    "image_height = image_width*image.shape[0]/image.shape[1]\n",
    "row_height = image.shape[0]/len(padding_methods)\n",
    "dpi = image.shape[1]/image_width\n",
    "plt.figure(figsize=(image_width, image_height), dpi=dpi)\n",
    "plt.imshow(image, interpolation=\"nearest\")\n",
    "plt.axis('off')\n",
    "# Semi-transparent white halo behind the labels of the extr* rows\n",
    "halo = [pe.withStroke(linewidth=width, foreground=\"white\", alpha=0.5) for width in (4, 3, 2)]\n",
    "for row, padding_method in enumerate(padding_methods):\n",
    "    label_style = {} if row == 0 else {\"font\": \"cmtt10\", \"size\": MONOSPACE_SIZE}\n",
    "    path_effects = halo if padding_method.startswith(\"extr\") else []\n",
    "    plt.text(8, 8 + row*row_height, padding_method, horizontalalignment='left', verticalalignment='top', path_effects=path_effects, **label_style)\n",
    "plt.gca().set_position((0, 0, 1, 1))\n",
    "plt.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_padding_samples.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "\n",
    "if image_lp6x7 is None:\n",
    "    print(\"No lp6x7 preset found, skipping fig_lp6x7.pdf\")\n",
    "else:\n",
    "    # One row with the lp6x7 padding; the corners of each unpadded input region are marked in red\n",
    "    image_height = image_width*row_height/image.shape[1]\n",
    "    plt.figure(figsize=(image_width, image_height), dpi=dpi)\n",
    "    for col in range(len(example_indexes)):\n",
    "        add_corner_marks(\n",
    "            plt.gca(), num_padding[1][0] + 0.5 + col*padded_size[1], num_padding[0][0] + 0.5,\n",
    "            input_size[1], input_size[0], corner_size, edgecolor='red', fill=False, lw=0.125\n",
    "        )\n",
    "    plt.imshow(image_lp6x7, interpolation=\"nearest\")\n",
    "    plt.axis('off')\n",
    "    plt.gca().set_position((0, 0, 1, 1))\n",
    "    plt.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_lp6x7.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use the models for tiled processing and show the corner where the processed tiles meet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 10  # This should be a seed which resulted in success for all presets\n",
    "num_images = 8\n",
    "zoomed_area_width = 192\n",
    "paper_diffuse_reflectance = 0.8\n",
    "sr_rate = 4\n",
    "\n",
    "def infer(model, state, inputs, key=None):\n",
    "    \"\"\"Apply `model` to each item of `inputs` via jax.vmap, sharing `state` and `key`; return only the predictions.\"\"\"\n",
    "    predictions, state = jax.vmap(\n",
    "        model, axis_name=\"batch\", in_axes=(0, None, None), out_axes=(0, None)\n",
    "    )(inputs, state, key)\n",
    "    return predictions\n",
    "\n",
    "def center_crop(images, size):\n",
    "    \"\"\"Return the centred size x size window of the last two axes of `images`.\n",
    "\n",
    "    Unlike `images[..., cut:-cut, cut:-cut]` this stays correct when nothing has to be cut\n",
    "    (cut == 0 would give an empty slice) and when the size difference is odd.\n",
    "    \"\"\"\n",
    "    top = (images.shape[-2] - size)//2\n",
    "    left = (images.shape[-1] - size)//2\n",
    "    return images[..., top:top + size, left:left + size]\n",
    "\n",
    "def predict_tiled(model, model_state, preset_hpars, images):\n",
    "    \"\"\"Predict a 2x2 grid of tiles around the centre of `images` tile by tile and stitch the predictions.\n",
    "\n",
    "    Tiles are image_shape-sized; their windows start where two tiles per axis would be centred,\n",
    "    moved by -sr_rate and by +oc (first tile) or -oc (second tile), with oc = output_crop*sr_rate.\n",
    "    \"\"\"\n",
    "    tile_h, tile_w = preset_hpars[\"image_shape\"][1], preset_hpars[\"image_shape\"][2]\n",
    "    rate = preset_hpars[\"sr_rate\"]\n",
    "    oc = preset_hpars[\"model_hpars\"][\"output_crop\"]*rate\n",
    "    window_h, window_w = tile_h + 2*rate, tile_w + 2*rate\n",
    "    first_y = (images.shape[-2] - tile_h*2)//2\n",
    "    first_x = (images.shape[-1] - tile_w*2)//2\n",
    "    rows = []\n",
    "    for y, uncropped_shift_y in enumerate([first_y, first_y + tile_h]):\n",
    "        shift_y = uncropped_shift_y - rate + (oc if y == 0 else -oc)\n",
    "        row = []\n",
    "        for x, uncropped_shift_x in enumerate([first_x, first_x + tile_w]):\n",
    "            shift_x = uncropped_shift_x - rate + (oc if x == 0 else -oc)\n",
    "            tile_inputs, _ = preprocess_batch_for_superresolution_task(\n",
    "                images[:, :, shift_y:shift_y+window_h, shift_x:shift_x+window_w], tile_h, tile_w, rate, False, None\n",
    "            )\n",
    "            row.append(infer(model, model_state, tile_inputs))\n",
    "        rows.append(np.concatenate(row, axis=-1))\n",
    "    return np.concatenate(rows, axis=-2)\n",
    "\n",
    "def show_labeled_rows(image, labels, filename, **text_kwargs):\n",
    "    \"\"\"Display `image` filling the whole figure, write each label at the top left of its\n",
    "    equal-height row, and save the figure under `filename` in the graphs folder.\"\"\"\n",
    "    image_width = 11\n",
    "    image_height = image_width*image.shape[0]/image.shape[1]\n",
    "    row_height = image.shape[0]/len(labels)\n",
    "    plt.figure(figsize=(image_width, image_height), dpi=image.shape[1]/image_width)\n",
    "    plt.imshow(image, interpolation=\"nearest\")\n",
    "    plt.axis('off')\n",
    "    for row, label in enumerate(labels):\n",
    "        plt.text(8, 8 + row*row_height, label, horizontalalignment='left', verticalalignment='top', **text_kwargs)\n",
    "    plt.gca().set_position((0, 0, 1, 1))\n",
    "    plt.savefig(os.path.join(paths_config[\"graphs_folder\"], filename), pad_inches=0, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "example_indexes = [5, 1, 2, 3, 4, 0, 6, 7]  # Non-cherry-picked, just swapped 0 and 5 for better text visibility\n",
    "batch = np.take(data, example_indexes, axis=0)\n",
    "\n",
    "# Low-res inputs and high-res targets of a central zoomed_area_width window\n",
    "window_size = (zoomed_area_width + 2*sr_rate, zoomed_area_width + 2*sr_rate)\n",
    "shift_y = (batch.shape[-2]-window_size[0])//2\n",
    "shift_x = (batch.shape[-1]-window_size[1])//2\n",
    "inputs, targets = preprocess_batch_for_superresolution_task(\n",
    "    batch[:, :, shift_y:shift_y+window_size[0], shift_x:shift_x+window_size[1]], zoomed_area_width, zoomed_area_width, sr_rate, False, None\n",
    ")\n",
    "# Whole images processed in a single pass: the reference for the tiled predictions\n",
    "safe_inputs, _ = preprocess_batch_for_superresolution_task(\n",
    "    batch, batch.shape[-2]-sr_rate*2, batch.shape[-1]-sr_rate*2, sr_rate, False, None\n",
    ")\n",
    "upscaled_inputs = jax.image.resize(inputs, targets.shape, \"nearest\")\n",
    "srgb_inputs = np.moveaxis(np_linear_to_srgb((upscaled_inputs*0.5+0.5)/paper_diffuse_reflectance), source=1, destination=3)\n",
    "srgb_targets = np.moveaxis(np_linear_to_srgb((targets*0.5+0.5)/paper_diffuse_reflectance), source=1, destination=3)\n",
    "image = np.vstack([np.hstack(srgb_inputs), np.hstack(srgb_targets)])\n",
    "show_labeled_rows(image, [\"Low-res input\", \"High-res target\"], \"fig_inputs_targets.pdf\")\n",
    "\n",
    "text_style = {\"font\": \"cmtt10\", \"size\": MONOSPACE_SIZE}\n",
    "for output_crop in oc_list:\n",
    "    print(f\"Output crop {output_crop}\")\n",
    "    padding_methods = []\n",
    "    images = []\n",
    "    errors = []\n",
    "    for preset in presets:\n",
    "        if preset not in results[\"trained_models\"] or results[\"trained_models\"][preset][seed] is None:\n",
    "            continue\n",
    "        # Not stored in the global `hpars`, which other cells (e.g. get_1d_onion_bootstrap_abs) still read\n",
    "        preset_hpars = get_preset_hpars(preset)\n",
    "        if preset_hpars[\"model_hpars\"][\"output_crop\"] != output_crop:\n",
    "            continue\n",
    "        padding_methods.append(preset.split(\"_\")[0])\n",
    "        model = results[\"trained_models\"][preset][seed][\"model\"]\n",
    "        model_state = results[\"trained_models\"][preset][seed][\"model_state\"]\n",
    "        tiled = center_crop(predict_tiled(model, model_state, preset_hpars, batch), zoomed_area_width)\n",
    "        safe_predictions = center_crop(infer(model, model_state, safe_inputs), zoomed_area_width)\n",
    "        images.append(np.hstack(np.moveaxis(np_linear_to_srgb((tiled*0.5+0.5)/paper_diffuse_reflectance), source=1, destination=3)))\n",
    "        # Absolute deviation between tiled and whole-image processing, scaled by 4 for display\n",
    "        deviation = np.abs(safe_predictions-tiled)\n",
    "        print(np.max(deviation))\n",
    "        errors.append(np.hstack(np.moveaxis(np_linear_to_srgb(4*deviation/paper_diffuse_reflectance), source=1, destination=3)))\n",
    "    if not images:\n",
    "        print(f\"No trained model with output crop {output_crop} for seed {seed}, skipping figures\")\n",
    "        continue\n",
    "    image = np.vstack(images)\n",
    "    error = np.vstack(errors)\n",
    "    show_labeled_rows(image, padding_methods, f\"fig_stitched_output_crop_{output_crop}.pdf\", **text_style)\n",
    "    show_labeled_rows(error, padding_methods, f\"fig_shift_invariance_abs_deviation_output_crop_{output_crop}.pdf\", color=\"white\", **text_style)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer NMSE scatterplot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(5.5, 2*0.75))\n",
    "\n",
    "# Colour per padding method (keys of results[\"padding_seed_layer_nmses\"][train_preset])\n",
    "colors = {\n",
    "    \"lp2x3\": \"tab:blue\",\n",
    "    \"repl\": \"tab:orange\",\n",
    "    \"zero\": \"tab:green\",\n",
    "    \"extr3\": \"tab:red\"\n",
    "}\n",
    "\n",
    "# Hand-made legend layout in legend-axes coordinates:\n",
    "# legend_xs = marker column for zero_oc5, marker column for lp2x3_oc5, padding-name column\n",
    "# legend_ys = title, column headers, then one row per padding method\n",
    "legend_xs = [0.2, 0.5, 0.75]\n",
    "legend_ys = [1, 0.5, 0.41, 0.29, 0.17, 0.05]\n",
    "\n",
    "ax = fig.add_axes([0, 0, 0.65, 1])  # scatter plot: left 65% of the figure width\n",
    "# Two identical, overlapping axes on the right 35%, one per trained preset, drawn on as a legend\n",
    "legend_axs = (fig.add_axes([0.65, 0, 0.35, 1]), fig.add_axes([0.65, 0, 0.35, 1]))\n",
    "for train_preset_index, (train_preset, alpha) in enumerate(zip([\"zero_oc5\", \"lp2x3_oc5\"], [0.3, 1])):\n",
    "    legend_ax = legend_axs[train_preset_index]\n",
    "    legend_ax.set_xlim((0, 1))\n",
    "    legend_ax.set_ylim((0, 1))\n",
    "    legend_ax.axis(\"off\")\n",
    "    legend_ax.text(legend_xs[train_preset_index], legend_ys[1], train_preset.split(\"_\")[0], horizontalalignment=\"center\", verticalalignment=\"baseline\", font=\"cmtt10\", fontsize=MONOSPACE_SIZE)\n",
    "    # zip() stops after the four marker shapes, so at most four padding methods are plotted\n",
    "    for preset_index, ((preset, seed_list), shape) in enumerate(zip(sorted(results[\"padding_seed_layer_nmses\"][train_preset].items()), ['o', '^', 's', 'd'])):\n",
    "        seed_mean = np.mean(np.array(seed_list), axis=0)\n",
    "        ax.scatter(np.arange(11), seed_mean, label=preset if train_preset == \"lp2x3_oc5\" else None, marker=shape, color=colors[preset], alpha=alpha)\n",
    "        legend_ax.scatter([legend_xs[train_preset_index]], [legend_ys[preset_index + 2]], marker=shape, color=colors[preset], alpha=alpha)\n",
    "        if train_preset_index == 0:\n",
    "            legend_ax.text(legend_xs[2], legend_ys[preset_index + 2], preset, horizontalalignment=\"left\", verticalalignment=\"center\", font=\"cmtt10\", fontsize=MONOSPACE_SIZE)\n",
    "\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"Feature map #\")\n",
    "ax.set_ylabel(\"Mean NMSE\")\n",
    "ax.grid(which=\"major\", axis=\"y\")\n",
    "ax.set_xticks(np.arange(11))\n",
    "\n",
    "legend_axs[0].text(legend_xs[0]-0.1, legend_ys[0], \"Trained with\\noutput crop 5\\nand:\", horizontalalignment=\"left\", verticalalignment=\"top\")\n",
    "legend_axs[0].text(legend_xs[-1], legend_ys[1], \"Padding:\", horizontalalignment=\"left\", verticalalignment=\"baseline\")\n",
    "\n",
    "fig.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_layer_nmse.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Presets and seeds table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mark(attempted, model):\n",
    "    \"\"\"Return the LaTeX marker macro for one (preset, seed) table cell.\n",
    "\n",
    "    - training not attempted           -> dmark\n",
    "    - attempted and a model is present -> cmark\n",
    "    - attempted but the model is None  -> xmark\n",
    "    \"\"\"\n",
    "    if not attempted:\n",
    "        return r\"\\dmark\"\n",
    "    if model is not None:\n",
    "        return r\"\\cmark\"\n",
    "    return r\"\\xmark\"\n",
    "\n",
    "def create_checkmark_table(results):\n",
    "    \"\"\"Return LaTeX source for a table of training outcomes per preset (rows) and seed (columns).\n",
    "\n",
    "    Uses the notebook-level `presets`, `num_seeds`, `train_attempted` and `get_preset_hpars`.\n",
    "    \"\"\"\n",
    "    preset_to_seed_tickmarks = {}\n",
    "    for preset in presets:\n",
    "        attempted_flags = [train_attempted(preset, seed) for seed in range(num_seeds)]\n",
    "        preset_to_seed_tickmarks[preset] = [\n",
    "            get_mark(attempted, model)\n",
    "            for attempted, model in zip(attempted_flags, results[\"trained_models\"][preset])\n",
    "        ]\n",
    "\n",
    "    df = pd.DataFrame.from_dict(preset_to_seed_tickmarks, orient='index')\n",
    "    # One marker-image header per seed column, derived from the actual column count.\n",
    "    # A hard-coded 12 made this assignment fail whenever the seed count was not 12.\n",
    "    df.columns = [rf\"\\includegraphics{{markers/s{seed}.pdf}}\" for seed in range(df.shape[1])]\n",
    "\n",
    "    # Padding-method and output-crop columns derived from each preset (same row order as df).\n",
    "    paddings = [r\"\\texttt{\" + preset.split(\"_\")[0] + \"}\" for preset in presets]\n",
    "    output_crops = [get_preset_hpars(preset)[\"model_hpars\"][\"output_crop\"] for preset in presets]\n",
    "    df.insert(0, \"padding method\", paddings)\n",
    "    df.insert(1, \"output crop\", output_crops)\n",
    "\n",
    "    return df.to_latex(index=False)\n",
    "\n",
    "print(create_checkmark_table(results))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot test loss histories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_history_plot(\n",
    "    results,\n",
    "    loss_ylims=(0.0005, 0.003),\n",
    "):\n",
    "    \"\"\"Plot seed-averaged test-loss histories, one panel per training output crop (0, 1, 5).\n",
    "\n",
    "    Each panel averages only the seeds that succeeded for all of its presets\n",
    "    (`get_common_success_seeds`). The figure is saved to `fig_test_loss_history.pdf`.\n",
    "    \"\"\"\n",
    "    loss_steps = get_test_loss_steps(hpars)\n",
    "    # rc_context restores the legend font size even if plotting raises.\n",
    "    with plt.rc_context({'legend.fontsize': MONOSPACE_SIZE}):\n",
    "        fig, axs = plt.subplots(1, 3, figsize=(11, 4), sharex=True, sharey=True)\n",
    "        for ax, oc_value, oc_str in zip(axs, [0, 1, 5], [\"\", \"_oc1\", \"_oc5\"]):\n",
    "            paddings = [\"lp2x3\", \"repl\", \"zero\" if oc_value > 0 else \"zero-repl\"]\n",
    "            panel_presets = [padding + oc_str for padding in paddings]\n",
    "            seeds = get_common_success_seeds(panel_presets)\n",
    "            for preset, padding in zip(panel_presets, paddings):\n",
    "                histories = results[\"test_loss_histories\"][preset][seeds]\n",
    "                ax.plot(loss_steps, np.nanmean(histories, axis=0), label=padding, linewidth=.8, alpha=0.75)\n",
    "            ax.set_title(f\"Output crop {oc_value}\")\n",
    "            ax.grid(axis=\"both\", which=\"both\")\n",
    "            ax.set_yscale(\"log\")\n",
    "            ax.set_ylim(loss_ylims)\n",
    "            ax.set_xscale(\"log\")\n",
    "            legend = ax.legend()\n",
    "            plt.setp(legend.get_texts(), font=\"cmtt10\", size=MONOSPACE_SIZE)\n",
    "            ax.set_xlim(100, 1_500_000)\n",
    "\n",
    "        formatter = ScalarFormatter(useMathText=True, useOffset=True)\n",
    "        formatter.set_scientific(True)\n",
    "        formatter.set_powerlimits((-1, 1))\n",
    "        axs[0].yaxis.set_major_formatter(formatter)  # Apply to the first subplot\n",
    "\n",
    "        for ax in axs:\n",
    "            ax.autoscale(enable=False, axis='y')\n",
    "\n",
    "        # Move the axis multiplier (offset text) into the y label; it only exists after a draw.\n",
    "        fig.canvas.draw()\n",
    "        y_offset = axs[0].yaxis.get_offset_text().get_text()\n",
    "        if y_offset:\n",
    "            axs[0].yaxis.offsetText.set_visible(False)\n",
    "            if y_offset.startswith(r\"$\\times\"):\n",
    "                y_offset = \"$\" + y_offset[len(r\"$\\times\"):]\n",
    "            axs[0].set_ylabel(f\"Mean test MSE ({y_offset})\")\n",
    "        else:\n",
    "            axs[0].set_ylabel(\"Mean test MSE\")\n",
    "\n",
    "        # Manual tick labels in units of 1e-3 (assumes the offset above is 10^-3).\n",
    "        # ndarray.astype works on every NumPy version (np.astype is NumPy >= 2.0 only), and\n",
    "        # rounding avoids float artefacts such as \"0.7000000000000001\" in the labels.\n",
    "        ticks = np.array([5e-4, 6e-4, 7e-4, 1e-3, 2e-3, 3e-3])\n",
    "        labels = np.round(ticks * 1e3, 3).astype(str)\n",
    "        labels[1] = \"\"\n",
    "\n",
    "        axs[1].set_xlabel(\"Training step\")\n",
    "        axs[0].set_yticks(ticks, labels=labels)\n",
    "\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_test_loss_history.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "\n",
    "create_history_plot(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MSE scatterplot for presets and seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "from matplotlib.ticker import MultipleLocator\n",
    "\n",
    "def create_confidence_scatterplot(\n",
    "    results,\n",
    "    conf=0.95,\n",
    "    ylims=(5.275e-4, 5.625e-4),\n",
    "):\n",
    "    \"\"\"Scatter per-seed test MSE of every preset, one panel per training output crop (0, 1, 5).\n",
    "\n",
    "    Each seed gets a fixed marker shape/colour. For presets with at least 4 seeds, the\n",
    "    across-seed mean (black dash) and a bootstrap confidence interval at level `conf`\n",
    "    (translucent bar) are drawn. `ylims` is the shared y range.\n",
    "    The figure is saved to `fig_scatterplot_presets_seeds.pdf`.\n",
    "    \"\"\"\n",
    "    xtick_spacing = 1\n",
    "    width, height = 11*0.875, 6\n",
    "    # Panel key -> output crop, which is also the edge distance used to index the losses.\n",
    "    edge_dists = {\"full\": 0, \"el1\": 1, \"el5\": 5}\n",
    "\n",
    "    mask_sums = get_mask_sums(hpars)\n",
    "    shape_map = get_shape_list()\n",
    "    cmap = get_color_list()\n",
    "\n",
    "    def output_crop(preset):\n",
    "        return get_preset_hpars(preset)[\"model_hpars\"][\"output_crop\"]\n",
    "\n",
    "    crops = [output_crop(preset) for preset in presets]\n",
    "    full, el1, el5 = crops.count(0), crops.count(1), crops.count(5)\n",
    "    print(f\"{full} full, {el5} el5, and {el1} el1 models found\")\n",
    "    fig, ax = plt.subplots(1, 3, sharey=True, figsize=(width, height), gridspec_kw={'width_ratios': [full, el1, el5]})\n",
    "\n",
    "    def plot_loss_group(num_plot, loss_type):\n",
    "        \"\"\"Fill panel `num_plot` with every preset whose output crop matches `loss_type`.\"\"\"\n",
    "        if loss_type not in edge_dists:\n",
    "            raise Exception(f\"Unknown loss_type {loss_type}\")\n",
    "        edge_dist = edge_dists[loss_type]\n",
    "        group_presets = [preset for preset in sorted(presets) if output_crop(preset) == edge_dist]\n",
    "        names = []\n",
    "        loss_means = []\n",
    "        low_ci, high_ci = [], []\n",
    "        for x_step, preset in enumerate(group_presets):\n",
    "            incl_seeds = get_common_success_seeds([preset])\n",
    "            loss_arr = np.vstack([\n",
    "                results[\"output_crop_sumsqs\"][preset][seed][edge_dist] / mask_sums[edge_dist]\n",
    "                for seed in incl_seeds\n",
    "            ])\n",
    "            preset_best = np.mean(loss_arr, axis=1)  # (seeds,)\n",
    "            preset_name = preset.split(\"_\")[0]\n",
    "            if preset_name == \"zero\" and loss_type == \"full\":\n",
    "                preset_name = \"zero-repl\"\n",
    "            elif preset_name == \"zerozero\":\n",
    "                preset_name = \"zero-zero\"\n",
    "            names.append(preset_name)\n",
    "            if len(incl_seeds) < 4:  # no mean / confidence interval for 3 seeds or fewer\n",
    "                low_ci.append(np.nan)\n",
    "                high_ci.append(np.nan)\n",
    "                loss_means.append(np.nan)\n",
    "            else:\n",
    "                loss_mean = np.mean(preset_best)\n",
    "                ci = stats.bootstrap((preset_best,), np.nanmean, vectorized=True, confidence_level=conf).confidence_interval\n",
    "                # Clamp at 0: a bootstrap interval need not contain the sample mean, and\n",
    "                # errorbar rejects negative error lengths.\n",
    "                low_ci.append(np.maximum(loss_mean - ci.low, 0.0))\n",
    "                high_ci.append(np.maximum(ci.high - loss_mean, 0.0))\n",
    "                loss_means.append(loss_mean)\n",
    "            # One marker per seed; scatter() takes a single marker shape per call, hence the loop.\n",
    "            for n, seed in enumerate(incl_seeds):\n",
    "                ax[num_plot].scatter(x_step, preset_best[n], marker=shape_map[seed], c=cmap[seed], s=160.0, alpha=0.4)\n",
    "\n",
    "        # Across-seed means (black dash) with confidence intervals (wide translucent bar).\n",
    "        x = np.arange(len(names))\n",
    "        ax[num_plot].errorbar(x, loss_means, yerr=[low_ci, high_ci], linestyle='',\n",
    "                              capthick=2, color=\"black\", elinewidth=22, alpha=0.2)\n",
    "        ax[num_plot].scatter(x, loss_means, marker=\"_\", label=\"mean\", s=500, color=\"black\")\n",
    "        ax[num_plot].set_xticks(x, labels=names, rotation=90, font=\"cmtt10\", size=12)\n",
    "        ax[num_plot].grid(axis=\"y\")\n",
    "        ax[num_plot].xaxis.set_major_locator(MultipleLocator(xtick_spacing))\n",
    "        ax[num_plot].set_title(f'Output crop {edge_dist}')\n",
    "        if num_plot != 0:\n",
    "            ax[num_plot].tick_params(left=False, labelleft=False)\n",
    "        ax[num_plot].set_ylim(*ylims)\n",
    "\n",
    "    plot_loss_group(0, \"full\")\n",
    "    plot_loss_group(1, \"el1\")\n",
    "    plot_loss_group(2, \"el5\")\n",
    "\n",
    "    formatter = ScalarFormatter(useMathText=True)\n",
    "    formatter.set_scientific(True)\n",
    "    formatter.set_powerlimits((-3, 3))\n",
    "    ax[0].yaxis.set_major_formatter(formatter)  # Apply to the first subplot\n",
    "\n",
    "    ax[0].set_xlim(-0.5, full-0.5)\n",
    "    ax[1].set_xlim(-0.5, el1-0.5)\n",
    "    ax[2].set_xlim(-0.5, el5-0.5)\n",
    "\n",
    "    # Move the axis multiplier into the y label; the offset text only exists after a draw.\n",
    "    fig.canvas.draw()\n",
    "    y_offset = ax[0].yaxis.get_offset_text().get_text()\n",
    "    if y_offset:\n",
    "        ax[0].yaxis.offsetText.set_visible(False)\n",
    "        if y_offset.startswith(r\"$\\times\"):\n",
    "            y_offset = \"$\" + y_offset[len(r\"$\\times\"):]\n",
    "        ax[0].set_ylabel(f\"Test MSE ({y_offset})\")\n",
    "    else:\n",
    "        ax[0].set_ylabel(\"Test MSE\")\n",
    "\n",
    "    fig.subplots_adjust(wspace=0)\n",
    "    fig.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_scatterplot_presets_seeds.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "create_confidence_scatterplot(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Throughput chart"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def throughput_scatterplot():\n",
    "    \"\"\"Plot test MSE against throughput for every preset in `preset_ids` (2x2 grid).\n",
    "\n",
    "    Rows: mean test MSE over the full evaluated region (\"full\") and over the outermost\n",
    "    shell only (\"outermost\"). Columns: training throughput (images/s) and inference\n",
    "    throughput (pixels/s). Each preset is drawn as a 95% bootstrap CI bar over seeds,\n",
    "    coloured by its training output crop. The figure is saved to `fig_time.pdf`.\n",
    "    \"\"\"\n",
    "    # Hand-tuned per-preset text-label offsets; scaled to data units further below.\n",
    "    label_offsets_outermost = {\n",
    "        \"train\": {\n",
    "            \"lp1x1cs\": (0, 0),\n",
    "            \"lp1x1cs_oc1\": (0, 0),\n",
    "            \"lp2x1\": (34, -22),\n",
    "            \"lp2x1cs\": (0, 0),\n",
    "            \"lp2x1cs_oc1\": (0, 0),\n",
    "            \"lp2x3\": (19, -22),\n",
    "            \"lp2x3_oc1\": (0, 0),\n",
    "            \"lp2x3_oc5\": (0, 0),\n",
    "            \"lp2x5\": (0, 0),\n",
    "            \"lp3x3\": (8, 2),\n",
    "            \"lp4x5\": (0, 0),\n",
    "            \"lp6x7\": (0, 0),\n",
    "            \"zero-repl\": (0, 0),\n",
    "            \"zero_oc1\": (15, 0),\n",
    "            \"zero_oc5\": (20, 0),\n",
    "            \"zero-zero\": (0, 0),\n",
    "            \"repl\": (32, -10),\n",
    "            \"repl_oc1\": (0, 0),\n",
    "            \"repl_oc5\": (0, 0),\n",
    "            \"extr1\": (0, 0),\n",
    "            \"extr2\": (0, 0),\n",
    "            \"extr3\": (8, -15),\n",
    "        },\n",
    "        \"inference\": {\n",
    "            \"lp1x1cs\": (0, 0),\n",
    "            \"lp1x1cs_oc1\": (0, 0),\n",
    "            \"lp2x1\": (4, 1),\n",
    "            \"lp2x1cs\": (0, 0),\n",
    "            \"lp2x1cs_oc1\": (0, 0),\n",
    "            \"lp2x3\": (0, 0),\n",
    "            \"lp2x3_oc1\": (0, 0),\n",
    "            \"lp2x3_oc5\": (0, 0),\n",
    "            \"lp2x5\": (0, 0),\n",
    "            \"lp3x3\": (20, 0),\n",
    "            \"lp4x5\": (-8, -5),\n",
    "            \"lp6x7\": (29, -23),\n",
    "            \"zero-repl\": (8, -13),\n",
    "            \"zero_oc1\": (0, 0),\n",
    "            \"zero_oc5\": (0, 0),\n",
    "            \"zero-zero\": (0, 0),\n",
    "            \"repl\": (11, 5),\n",
    "            \"repl_oc1\": (0, 0),\n",
    "            \"repl_oc5\": (0, 0),\n",
    "            \"extr1\": (-5, -3),\n",
    "            \"extr2\": (0, 0),\n",
    "            \"extr3\": (8, -13),\n",
    "        }\n",
    "    }\n",
    "    label_offsets_full = {\n",
    "        \"train\": {\n",
    "            \"lp1x1cs\": (24, -15),\n",
    "            \"lp1x1cs_oc1\": (0, 0),\n",
    "            \"lp2x1\": (17, -15),\n",
    "            \"lp2x1cs\": (0, 5),\n",
    "            \"lp2x1cs_oc1\": (0, 0),\n",
    "            \"lp2x3\": (8, -14),\n",
    "            \"lp2x3_oc1\": (0, 0),\n",
    "            \"lp2x3_oc5\": (15, -3),\n",
    "            \"lp2x5\": (9, -13),\n",
    "            \"lp3x3\": (15, -3),\n",
    "            \"lp4x5\": (7, 7),\n",
    "            \"lp6x7\": (3, 6),\n",
    "            \"zero-repl\": (0, 0),\n",
    "            \"zero_oc1\": (8, 3),\n",
    "            \"zero_oc5\": (5, 5),\n",
    "            \"zero-zero\": (0, 4),\n",
    "            \"repl\": (15, -10),\n",
    "            \"repl_oc1\": (0, 0),\n",
    "            \"repl_oc5\": (0, 0),\n",
    "            \"extr1\": (0, 0),\n",
    "            \"extr2\": (-1, 0),\n",
    "            \"extr3\": (0, -5),\n",
    "        },\n",
    "        \"inference\": {\n",
    "            \"lp1x1cs\": (0, 0),\n",
    "            \"lp1x1cs_oc1\": (0, 0),\n",
    "            \"lp2x1\": (25, -2),\n",
    "            \"lp2x1cs\": (0, 0),\n",
    "            \"lp2x1cs_oc1\": (0, 0),\n",
    "            \"lp2x3\": (0, 0),\n",
    "            \"lp2x3_oc1\": (16, 0),\n",
    "            \"lp2x3_oc5\": (0, 0),\n",
    "            \"lp2x5\": (19, -18),\n",
    "            \"lp3x3\": (11, 5),\n",
    "            \"lp4x5\": (0, 0),\n",
    "            \"lp6x7\": (12, -13),\n",
    "            \"zero-repl\": (11, -13),\n",
    "            \"zero_oc1\": (0, 0),\n",
    "            \"zero_oc5\": (19, 0),\n",
    "            \"zero-zero\": (0, 0),\n",
    "            \"repl\": (25, -5),\n",
    "            \"repl_oc1\": (0, 0),\n",
    "            \"repl_oc5\": (0, 0),\n",
    "            \"extr1\": (0, 0),\n",
    "            \"extr2\": (0, 0),\n",
    "            \"extr3\": (-2, -4),\n",
    "        }\n",
    "    }\n",
    "\n",
    "    fig, axs = plt.subplots(2, 2, figsize=(11, 9), sharey=\"row\", sharex=\"col\")\n",
    "    mask_sums = get_mask_sums(hpars)\n",
    "    colors = list(mcolors.TABLEAU_COLORS.values())\n",
    "    for row, shell_mode in zip([0, 1], [\"full\", \"outermost\"]):\n",
    "        for col, mode in zip([0, 1], [\"train\", \"inference\"]):\n",
    "            for preset in preset_ids:\n",
    "                lc_value = get_preset_hpars(preset)[\"model_hpars\"][\"output_crop\"]\n",
    "                # Drops the first recorded step time (presumably compilation/warm-up; TODO confirm).\n",
    "                time_s_per_batch = np.mean(results[f\"{mode}_step_times\"][preset][1:])\n",
    "                batch_size = results[f\"largest_{mode}_batch_sizes\"][preset]\n",
    "                time_s_per_image = time_s_per_batch/batch_size\n",
    "                #time_s_per_image += 0.00001492114  # PCIe 4.0 memory transfer\n",
    "                time_s_per_pixel = time_s_per_image/mask_sums[lc_value]\n",
    "                throughput_images_per_s = 1/time_s_per_image\n",
    "                throughput_pixels_per_s = 1/time_s_per_pixel\n",
    "                padding = preset.split(\"_\")[0]\n",
    "                sse_tables = results[\"output_crop_sumsqs\"][preset].values()\n",
    "                if shell_mode == \"outermost\":\n",
    "                    label_offset_x, label_offset_y = label_offsets_outermost[mode][preset]\n",
    "                    # Outermost shell: difference between the crop-lc_value and crop-(lc_value + 1) sums.\n",
    "                    shell_pixels = mask_sums[lc_value] - mask_sums[lc_value + 1]\n",
    "                    seed_losses = [np.mean(sse[lc_value] - sse[lc_value + 1])/shell_pixels for sse in sse_tables]\n",
    "                elif shell_mode == \"full\":\n",
    "                    label_offset_x, label_offset_y = label_offsets_full[mode][preset]\n",
    "                    label_offset_y *= 0.2\n",
    "                    label_offset_x *= 2\n",
    "                    seed_losses = [np.mean(sse[lc_value])/(mask_sums[lc_value]) for sse in sse_tables]\n",
    "                result = stats.bootstrap(\n",
    "                    (np.array(seed_losses),), np.mean, confidence_level=0.95, n_resamples=10_000\n",
    "                )\n",
    "                ci_low, ci_high = result.confidence_interval\n",
    "                loss = np.mean(seed_losses)\n",
    "                x_val = throughput_images_per_s if mode == \"train\" else throughput_pixels_per_s\n",
    "                label_offset_mult = 1e6 if mode == \"inference\" else 0.2\n",
    "                axs[row, col].text(x_val + label_offset_x*label_offset_mult, loss + label_offset_y*1e-6, padding, horizontalalignment='right', rotation=-45, font=\"cmtt10\", size=11)\n",
    "                lc_order = 2 if lc_value == 5 else lc_value\n",
    "                _, _, bars = axs[row, col].errorbar(\n",
    "                    x_val,\n",
    "                    loss,\n",
    "                    yerr=((np.abs(ci_low-loss),), (np.abs(loss-ci_high),)),\n",
    "                    capsize=0,\n",
    "                    elinewidth=4,\n",
    "                    markersize=0,\n",
    "                    color=colors[lc_order]\n",
    "                )\n",
    "                plt.setp(bars, capstyle=\"round\")\n",
    "            if row == 1:\n",
    "                # Bottom row only (x axes are shared per column): move the multiplier into the x label.\n",
    "                formatter = ScalarFormatter(useMathText=True)\n",
    "                formatter.set_scientific(True)\n",
    "                formatter.set_powerlimits((-3, 3))\n",
    "                axs[row, col].xaxis.set_major_formatter(formatter)\n",
    "                axs[row, col].ticklabel_format(axis='x', style='sci', scilimits=(-3,3))\n",
    "                fig.canvas.draw()\n",
    "                x_offset = axs[row, col].xaxis.get_offset_text().get_text()\n",
    "                mode_name = \"training\" if mode == \"train\" else mode\n",
    "                unit = \"images\" if mode == \"train\" else \"pixels\"\n",
    "                if x_offset:\n",
    "                    axs[row, col].xaxis.offsetText.set_visible(False)\n",
    "                    if x_offset.startswith(r\"$\\times\"):\n",
    "                        x_offset = \"$\" + x_offset[len(r\"$\\times\"):]\n",
    "                    elif x_offset.startswith(\"1e\"):\n",
    "                        # Plain-text offset such as \"1e6\": rebuild it in the same mathtext form as the\n",
    "                        # branch above (no multiplication sign). rint avoids truncating e.g. 5.9999.\n",
    "                        exponent = int(np.rint(np.log10(float(x_offset))))\n",
    "                        x_offset = r\"$\\mathdefault{10^{\" + str(exponent) + r\"}}\\mathdefault{}$\"\n",
    "                    label_str = f\"{mode_name} throughput ({x_offset} {unit}/s)\"\n",
    "                else:\n",
    "                    label_str = f\"{mode_name} throughput ({unit}/s)\"\n",
    "                axs[row, col].set_xlabel(label_str[0].upper() + label_str[1:])  # capitalize first letter only\n",
    "\n",
    "        # Per row: move the y-axis multiplier into the y label.\n",
    "        y_axis_str = \"Outermost shell mean test MSE\" if shell_mode == \"outermost\" else \"Mean test MSE\"\n",
    "        formatter = ScalarFormatter(useMathText=True)\n",
    "        formatter.set_scientific(True)\n",
    "        formatter.set_powerlimits((-3, 3))\n",
    "        axs[row, 0].yaxis.set_major_formatter(formatter)\n",
    "        axs[row, 0].ticklabel_format(axis='y', style='sci', scilimits=(0,0))\n",
    "        fig.canvas.draw()\n",
    "        y_offset = axs[row, 0].yaxis.get_offset_text().get_text()\n",
    "        if y_offset:\n",
    "            axs[row, 0].yaxis.offsetText.set_visible(False)\n",
    "            if y_offset.startswith(r\"$\\times\"):\n",
    "                y_offset = \"$\" + y_offset[len(r\"$\\times\"):]\n",
    "            axs[row, 0].set_ylabel(f\"{y_axis_str} ({y_offset})\")\n",
    "        else:\n",
    "            axs[row, 0].set_ylabel(f\"{y_axis_str}\")\n",
    "\n",
    "    # Figure-level legend, layout and save run once, after both rows are drawn; inside the\n",
    "    # row loop they added a duplicate legend and first saved a half-finished figure.\n",
    "    legend_elements = [\n",
    "        plt.Line2D([0], [0], marker='o', color=\"w\", label='Output crop:', markerfacecolor=\"w\", markersize=7),\n",
    "        plt.Line2D([0], [0], marker='o', color=\"w\", label='0', markerfacecolor=colors[0], markersize=7),\n",
    "        plt.Line2D([0], [0], marker='o', color=\"w\", label='1', markerfacecolor=colors[1], markersize=7),\n",
    "        plt.Line2D([0], [0], marker='o', color=\"w\", label='5', markerfacecolor=colors[2], markersize=7)\n",
    "    ]\n",
    "    fig.legend(handles=legend_elements, loc=\"lower center\",\n",
    "               ncol=4, bbox_to_anchor=(0.51, 0.97), handlelength=0, frameon=False)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_time.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "\n",
    "throughput_scatterplot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Output crop bar graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def composite_loss_crop_scatterplot(\n",
    "        results,\n",
    "        conf = 0.95,\n",
    "        resamples = 10_000,\n",
    "        xmin = -0.625,\n",
    "        xmax = 10.625,\n",
    "        fixed_mult = 10**-4,\n",
    "    ):\n",
    "    \"\"\"Bar chart of shell-wise mean test MSE for lp2x3 / repl / zero padding (fig_loss_shells.pdf).\n",
    "\n",
    "    One panel per training output crop (0, 1, 5); within a panel, only seeds common to all\n",
    "    three presets are used. Shells with an index below the training output crop are drawn\n",
    "    faded with gray error bars. Bars clipped by the y range are annotated with\n",
    "    \"value ± CI\" in units of `fixed_mult`.\n",
    "    \"\"\"\n",
    "\n",
    "    # utility functions\n",
    "    def format_confidence_interval(value, ci):\n",
    "        \"\"\"Format `value ± ci` in units of `fixed_mult`; larger values get fewer decimals.\"\"\"\n",
    "        scaled_value, scaled_ci = value/fixed_mult, ci/fixed_mult\n",
    "        if scaled_value > 100:\n",
    "            decimals = 0\n",
    "        elif scaled_value > 10:\n",
    "            decimals = 1\n",
    "        else:\n",
    "            decimals = 2\n",
    "        return f\"{scaled_value:.{decimals}f} ± {scaled_ci:.{decimals}f}\"\n",
    "\n",
    "    def plot_group(means_of_means, ci_low, ci_high, preset, n, col, train_edge, ymax):\n",
    "        \"\"\"Draw the n-th preset of panel `col`: bars + CIs for shells 0-10 and labels for clipped bars.\"\"\"\n",
    "        x_offset = [-0.25, 0, 0.25][n]\n",
    "        means = np.array(means_of_means)\n",
    "        lows, highs = np.array(ci_low), np.array(ci_high)\n",
    "        padding = preset.split(\"_\")[0]\n",
    "        label_str = \"zero-repl\" if (padding == \"zero\" and col == 0) else padding\n",
    "        bar_color = list(mcolors.TABLEAU_COLORS.keys())[n]\n",
    "        # Shells below the training output crop (faded, unlabelled), then the remaining shells.\n",
    "        ranges = [np.arange(train_edge, dtype=np.int32), np.arange(train_edge, 11, dtype=np.int32)]\n",
    "        for shells, alpha, use_label, errcolor in zip(ranges, [0.5, 1], [False, True], [\"gray\", \"black\"]):\n",
    "            axs[col].bar(shells+x_offset, means[shells],\n",
    "                         label=label_str if use_label else None, color=bar_color,\n",
    "                         width=0.25, alpha=alpha\n",
    "            )\n",
    "            axs[col].errorbar(shells+x_offset, means[shells],\n",
    "                              yerr=(lows[shells], highs[shells]),\n",
    "                              linestyle='', color=errcolor\n",
    "            )\n",
    "        # Bars taller than the axis are clipped, so print their value and CI near the top.\n",
    "        for y, x, ci in zip(means_of_means, np.arange(11)+x_offset, np.maximum(ci_low, ci_high)):\n",
    "            if y > ymax:\n",
    "                text_y = ymax-3.5e-5 if n==1 else ymax-1e-6\n",
    "                axs[col].text(x-0.05, text_y, format_confidence_interval(y, ci),\n",
    "                              size=8, rotation=270, verticalalignment=\"top\",\n",
    "                              horizontalalignment=\"center\"\n",
    "                )\n",
    "\n",
    "    master_preset_list = [\n",
    "        [\"lp2x3\", \"repl\", \"zero-repl\"],\n",
    "        [\"lp2x3_oc1\", \"repl_oc1\", \"zero_oc1\"],\n",
    "        [\"lp2x3_oc5\", \"repl_oc5\", \"zero_oc5\"]\n",
    "    ]\n",
    "\n",
    "    # White gradient overlay used to fade out the bottom of each panel (built once).\n",
    "    fade_cmap = mcolors.LinearSegmentedColormap.from_list('white_to_transparent', [(1, 1, 1, 0), (1, 1, 1, 1)])\n",
    "    fade_gradient = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))[1]\n",
    "\n",
    "    # plot values\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(11, 4), sharex=True, sharey=False)\n",
    "    ymin, ymax = 5e-4, 6.75e-4 # tighter; ymin, ymax = 4e-4, 1.5e-2 shows everything\n",
    "    for col, (preset_list, cutoff) in enumerate(zip(master_preset_list, [0, 1, 5])):\n",
    "        ax = axs[col]\n",
    "        valid_seeds = get_common_success_seeds(preset_list)\n",
    "        print(f\"valid seeds: {valid_seeds}\")\n",
    "        for n, preset in enumerate(preset_list):\n",
    "            means_of_means, ci_low, ci_high, _ = get_1d_onion_bootstrap_abs(\n",
    "                results, preset, valid_seeds, conf, resamples\n",
    "            )\n",
    "            plot_group(means_of_means, ci_low, ci_high, preset, n, col, cutoff, ymax)\n",
    "        ax.set_xticks(np.arange(11))\n",
    "        if cutoff != 0:\n",
    "            ax.vlines([cutoff-0.5,], ymin, ymax, color=\"black\",\n",
    "                      linestyles=\"dashed\", linewidth=1.0, zorder=100)\n",
    "        ax.set_ylim(ymin, ymax)\n",
    "        ax.set_xlim(xmin, xmax)\n",
    "        ax.grid(axis=\"y\", which=\"major\")\n",
    "        # Single legend() call: a second, argument-less call would replace this legend\n",
    "        # and silently drop loc=\"upper right\".\n",
    "        legend = ax.legend(loc=\"upper right\")\n",
    "        plt.setp(legend.get_texts(), font=\"cmtt10\", size=MONOSPACE_SIZE)\n",
    "        if col != 0:\n",
    "            ax.set_yticklabels([])\n",
    "        ax.set_title(f\"Training output crop {cutoff}\")\n",
    "        ax.imshow(\n",
    "            fade_gradient, cmap=fade_cmap, extent=[xmin, xmax, ymin, ymin + 0.06*(ymax-ymin)], aspect=\"auto\", zorder=1\n",
    "        )\n",
    "    axs[1].set_xlabel(\"Shell #\")\n",
    "\n",
    "    # moving the multiplier from the axis to the axis label (offset text exists only after a draw)\n",
    "    formatter = ScalarFormatter(useMathText=True)\n",
    "    formatter.set_scientific(True)\n",
    "    formatter.set_powerlimits((-3, 3))\n",
    "    axs[0].yaxis.set_major_formatter(formatter) # Apply to the first subplot\n",
    "    fig.canvas.draw()\n",
    "    y_offset = axs[0].yaxis.get_offset_text().get_text()\n",
    "    if y_offset: # If there is an offset\n",
    "        axs[0].yaxis.offsetText.set_visible(False) # Hide the original offset text\n",
    "        if y_offset.startswith(r\"$\\times\"):\n",
    "            y_offset = \"$\" + y_offset[len(r\"$\\times\"):]\n",
    "        axs[0].set_ylabel(f\"Shell mean test MSE ({y_offset})\")\n",
    "    else:\n",
    "        axs[0].set_ylabel(\"Shell mean test MSE\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(paths_config[\"graphs_folder\"], \"fig_loss_shells.pdf\"), pad_inches=0, bbox_inches=\"tight\")\n",
    "\n",
    "composite_loss_crop_scatterplot(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numerical data of the output crop bar graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def composite_loss_crop_table(\n",
    "        results,\n",
    "        conf = 0.95,\n",
    "        resamples = 10_000,\n",
    "        restricted_seeds = False, # restrict all plots to have the same set of seeds\n",
    "    ):\n",
    "    \"\"\"Return the shell-mean test MSE of the output-crop bar graph as a DataFrame.\n",
    "\n",
    "    Rows are shell numbers; columns are an (output crop, preset) MultiIndex. Each cell is\n",
    "    \"mean±err\" in units of 1e-6, where err is the larger side of the bootstrap CI.\n",
    "    With `restricted_seeds`, all presets use the seeds common to all nine presets;\n",
    "    otherwise seeds are shared only within each output-crop group.\n",
    "    \"\"\"\n",
    "\n",
    "    def format_scaled(x, max_chars):\n",
    "        \"\"\"Round x to as many decimals as fit in about `max_chars` characters.\n",
    "\n",
    "        Rounds (instead of truncating a str() slice) and never drops integer digits,\n",
    "        e.g. 1234.5 -> \"1234\" rather than \"123\", and 1e-05 -> \"0.0\" rather than \"1e-\".\n",
    "        \"\"\"\n",
    "        if not np.isfinite(x):\n",
    "            return str(x)\n",
    "        int_digits = len(str(int(abs(x))))\n",
    "        decimals = max(max_chars - int_digits - 1, 0)\n",
    "        return f\"{x:.{decimals}f}\"\n",
    "\n",
    "    master_preset_list = [\n",
    "        [\"lp2x3\", \"repl\", \"zero-repl\"],\n",
    "        [\"lp2x3_oc1\", \"repl_oc1\", \"zero_oc1\"],\n",
    "        [\"lp2x3_oc5\", \"repl_oc5\", \"zero_oc5\"]\n",
    "    ]\n",
    "    oc_values = [0, 1, 5]\n",
    "    all_presets = [preset for preset_list in master_preset_list for preset in preset_list]\n",
    "    if restricted_seeds:\n",
    "        valid_seeds = get_common_success_seeds(all_presets)\n",
    "        print(f\"valid seeds: {valid_seeds}\")\n",
    "    oc_column = [oc for oc, preset_list in zip(oc_values, master_preset_list) for _ in preset_list]\n",
    "    df = pd.DataFrame(columns=pd.MultiIndex.from_arrays([oc_column, all_presets], names=[\"Output crop\", \"Preset\"]))\n",
    "\n",
    "    for preset_list, oc_value in zip(master_preset_list, oc_values):\n",
    "        if not restricted_seeds:\n",
    "            valid_seeds = get_common_success_seeds(preset_list)\n",
    "        for preset in preset_list:\n",
    "            means_of_means, ci_low_list, ci_high_list, _ = get_1d_onion_bootstrap_abs(\n",
    "                results, preset, valid_seeds, conf, resamples\n",
    "            )\n",
    "            for shell_num, (mean, ci_low, ci_high) in enumerate(zip(means_of_means, ci_low_list, ci_high_list)):\n",
    "                err = np.maximum(ci_high, ci_low)\n",
    "                # Units of 1e-6: ~5 characters for the mean, ~3 for the error.\n",
    "                df.loc[shell_num, (oc_value, preset)] = format_scaled(mean*1e6, 5) + \"±\" + format_scaled(err*1e6, 3)\n",
    "    return df\n",
    "\n",
    "composite_loss_crop_table(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General data table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_table(\n",
    "    results,\n",
    "    baseline_list = [\"zero-repl\", \"zero_oc1\", \"zero_oc5\"]\n",
    "):\n",
    "    \"\"\"Build the general comparison table of all presets (DataFrame and LaTeX string).\n",
    "\n",
    "    Presets are grouped by their output crop (0, 1, 5) and each group is compared\n",
    "    against its baseline from `baseline_list` (matched to the groups in order).\n",
    "    Columns: maximum train/inference batch sizes, train/inference throughput,\n",
    "    mean test MSE (units of 1e-6) with bootstrap error, mean test MSE difference\n",
    "    to the baseline in percent, and the same difference on the outermost shell\n",
    "    only. Within each group the best entry of every column is set in bold\n",
    "    (lowest for value-with-error columns, highest for plain numeric columns).\n",
    "\n",
    "    Args:\n",
    "        results: collected results dict (batch sizes, step times, sumsqs).\n",
    "        baseline_list: baseline presets for output crop 0, 1 and 5, in that order.\n",
    "\n",
    "    Returns:\n",
    "        (data, latex_str): the formatted DataFrame (all cells are strings) and\n",
    "        its LaTeX rendering using siunitx S columns. Nothing is written to disk.\n",
    "    \"\"\"\n",
    "\n",
    "    # helper functions\n",
    "    def get_error(arr):\n",
    "        # symmetric, worst-case half-width of the bootstrap CI of the nan-mean,\n",
    "        # measured from the sample mean (the CI describes the mean, not the\n",
    "        # individual elements of arr)\n",
    "        assert len(arr.shape) == 1, f\"{arr.shape}\"\n",
    "        result = stats.bootstrap((arr,), np.nanmean, vectorized=True)\n",
    "        low, high = result.confidence_interval\n",
    "        center = np.nanmean(arr)\n",
    "        return np.maximum(high - center, center - low)\n",
    "\n",
    "    def get_train_throughput(preset):\n",
    "        # images per second; the first step time is skipped (presumably warm-up /\n",
    "        # compilation -- TODO confirm)\n",
    "        batch_size = results[\"largest_train_batch_sizes\"][preset]\n",
    "        time_list = np.array(results[\"train_step_times\"][preset])\n",
    "        time_per_batch = np.mean(time_list[1:])\n",
    "        return batch_size / time_per_batch\n",
    "    \n",
    "    def get_inference_throughput(preset):\n",
    "        # pixels per second, with pixels per image taken from the mask sum of the\n",
    "        # preset's output crop; the first step time is skipped as above\n",
    "        batch_size = results[\"largest_inference_batch_sizes\"][preset]\n",
    "        time_list = np.array(results[\"inference_step_times\"][preset])\n",
    "        time_per_batch = np.mean(time_list[1:])\n",
    "        oc_value = get_preset_hpars(preset)[\"model_hpars\"][\"output_crop\"]\n",
    "        pixels_per_img = mask_sums[oc_value]\n",
    "        return batch_size * pixels_per_img / time_per_batch\n",
    "\n",
    "    def outermost_shell_diff(preset, baseline, lc_value, conf=0.95, resamples=10_000):\n",
    "        # Percent difference of the preset's shell loss to the baseline's on the\n",
    "        # single shell at edge distance `lc_value`, averaged over the seeds where\n",
    "        # both succeeded. Returned as a LaTeX value-with-error string whose error\n",
    "        # part is empty when preset is the baseline itself.\n",
    "        def get_onion_loss(seed, edge_dist):\n",
    "            # shell loss = sumsq difference between edge_dist and edge_dist+1,\n",
    "            # divided by the matching mask-sum difference; reuses the enclosing\n",
    "            # mask_sums instead of recomputing them for every seed\n",
    "            assert edge_dist < 10\n",
    "            sumsqs = results[\"output_crop_sumsqs\"]\n",
    "            corr_fact = mask_sums[edge_dist]-mask_sums[edge_dist+1]\n",
    "            loss = np.mean(sumsqs[preset][seed][edge_dist]\n",
    "                    - sumsqs[preset][seed][edge_dist+1]) / corr_fact\n",
    "            baseline_loss = np.mean(sumsqs[baseline][seed][edge_dist]\n",
    "                    - sumsqs[baseline][seed][edge_dist+1]) / corr_fact\n",
    "            return (loss - baseline_loss)/baseline_loss*100\n",
    "        lc_value = int(lc_value)\n",
    "        valid_seeds = get_common_success_seeds((preset, baseline))\n",
    "        loss_means = np.array([get_onion_loss(seed, lc_value) for seed in valid_seeds])\n",
    "        mean_diff = np.mean(loss_means)\n",
    "        if preset != baseline:\n",
    "            result = stats.bootstrap(\n",
    "                (loss_means,), np.mean, confidence_level=conf, n_resamples=resamples\n",
    "            )\n",
    "            low, high = result.confidence_interval\n",
    "            ci = np.maximum(np.abs(mean_diff - low), np.abs(high - mean_diff))\n",
    "        else:\n",
    "            ci = \"\"\n",
    "        return str(mean_diff) + r\" \\pm \" + str(ci)\n",
    "\n",
    "    mask_sums = get_mask_sums(hpars)\n",
    "    loss_mult = 1e-6\n",
    "    throughput_mult = 1e6\n",
    "    valid_presets = {0: [], 1: [], 5: []} # lists contain (preset, row_name) tuples\n",
    "    for preset in presets.keys():\n",
    "        row_name, loss_crop = (r\"\\texttt{\" + preset.split(\"_\")[0] + \"}\", get_preset_hpars(preset)[\"model_hpars\"][\"output_crop\"])\n",
    "        valid_presets[loss_crop].append((preset, row_name))\n",
    "\n",
    "    # setting up dataframe\n",
    "    index_col_1 = []\n",
    "    for output_crop, group_presets_list in valid_presets.items():\n",
    "        count = len(group_presets_list)\n",
    "        index_col_1 += [output_crop]*count\n",
    "    index_col_2 = []\n",
    "    for l in valid_presets.values():\n",
    "        for _, pad in sorted(l):\n",
    "            index_col_2.append(pad)\n",
    "    data = pd.DataFrame(index = pd.MultiIndex.from_arrays([index_col_1, index_col_2], names=[\"Output crop\", \"Preset\"]))\n",
    "\n",
    "    # column names; the first assignment of each below fixes the column order\n",
    "    inf_throughput_colname = r\"{Inference throughput ($10^{\" + str(int(round(np.log10(throughput_mult)))) + r\"}$ pixels/s)}\"\n",
    "    loss_colname = r\"{Mean test MSE ($10^{\" + str(int(round(np.log10(loss_mult)))) + r\"}$)}\"\n",
    "    diff_colname = \"Mean test MSE diff to \" + baseline_list[0] + r\" (\\%)\"\n",
    "    outermost_colname = \"Outermost shell mean test MSE diff to \" + baseline_list[0] + r\" (\\%)\"\n",
    "\n",
    "    # populate values\n",
    "    for (lc_value, preset_pad_list), baseline in zip(valid_presets.items(), baseline_list):\n",
    "        for preset, row_name in preset_pad_list:\n",
    "            row = (lc_value, row_name)\n",
    "            incl_seeds = get_common_success_seeds([baseline, preset])\n",
    "            p_list, b_list = [], []\n",
    "            for seed in incl_seeds:\n",
    "                p_loss = results[\"output_crop_sumsqs\"][preset][seed][lc_value] / mask_sums[lc_value]\n",
    "                b_loss = results[\"output_crop_sumsqs\"][baseline][seed][lc_value] / mask_sums[lc_value]\n",
    "                p_list.append(np.mean(p_loss))\n",
    "                b_list.append(np.mean(b_loss))\n",
    "            preset_losses = np.hstack(p_list)\n",
    "            baseline_losses = np.hstack(b_list)\n",
    "            diff = preset_losses - baseline_losses\n",
    "            diff_percent = diff/baseline_losses*100 # shape is (seed,)\n",
    "            data.loc[row, \"Maximum training batch size (images)\"] = results[\"largest_train_batch_sizes\"][preset]\n",
    "            data.loc[row, \"Training throughput (images/s)\"] = get_train_throughput(preset)\n",
    "            data.loc[row, \"Maximum inference batch size (images)\"] = results[\"largest_inference_batch_sizes\"][preset]\n",
    "            data.loc[row, inf_throughput_colname] = get_inference_throughput(preset) / throughput_mult\n",
    "            # the baseline's own MSE has a well-defined CI too; only its diff to\n",
    "            # itself is degenerate (all zeros), so only that error is left empty\n",
    "            loss_err_str = (get_error(preset_losses)/loss_mult).astype(str)\n",
    "            diff_err_str = get_error(diff_percent).astype(str) if preset != baseline else \"\"\n",
    "            loss_str = (np.mean(preset_losses)/loss_mult).astype(str)\n",
    "            data.loc[row, loss_colname] = loss_str + r\" \\pm \" + loss_err_str\n",
    "            diff_str = \"0.00\" if preset == baseline else np.mean(diff_percent).astype(str)\n",
    "            data.loc[row, diff_colname] = diff_str + r\" \\pm \" + diff_err_str\n",
    "            data.loc[row, outermost_colname] = outermost_shell_diff(preset, baseline, int(lc_value))\n",
    "\n",
    "    # replace invalid values (baseline diff cells have no error part)\n",
    "    data[data == r\"0.00 \\pm \"] = \"0.00\"\n",
    "    data[data == r\"0.0 \\pm \"] = \"0.00\"\n",
    "\n",
    "    # rename and format columns\n",
    "    def process_column_name(s):\n",
    "        return r\"{\\rotatebox[origin=l]{90}{\\parbox{3.0cm}{\\raggedright \" + s + r\"}}}\"\n",
    "    new_names = {}\n",
    "    for col in data.columns:\n",
    "        new_names[col] = process_column_name(col)\n",
    "    data = data.rename(columns=new_names)\n",
    "    data = data.astype(str)\n",
    "    \n",
    "    # make best values bold\n",
    "    for oc_value in [0, 1, 5]:\n",
    "        temp = data.loc[oc_value]\n",
    "        for col in temp.columns:\n",
    "            if temp[col].str.contains(\" \").any(): # if there are spaces, indicating \"val \\pm err\" format; lowest is best\n",
    "                stripped_series = temp[col].str.split(\" \").str[0]\n",
    "                target_idx = stripped_series.astype(float).idxmin()\n",
    "                target_val = temp.loc[target_idx, col]\n",
    "                data.loc[(oc_value, target_idx), col] = r\"\\bfseries \" + str(target_val)\n",
    "            else: # no spaces, indicating int or float only; highest is best, may be multiple best\n",
    "                # compare numerically: the cells are strings here, and a string max\n",
    "                # is lexicographic (\"64.0\" > \"128.0\")\n",
    "                numeric = temp[col].astype(float)\n",
    "                for idx in numeric[numeric == numeric.max()].index:\n",
    "                    data.loc[(oc_value, idx), col] = r\"\\bfseries \" + str(temp.loc[idx, col])\n",
    "\n",
    "    # format latex table\n",
    "    def int_col(width):\n",
    "        return f\"S[round-mode=places,round-precision=0,table-column-width={width}cm]\" # batch size and throughput columns\n",
    "    latex_str = data.to_latex(\n",
    "        column_format=(\n",
    "            r\"p{0.17cm}p{1.4cm}\"\n",
    "            + int_col(0.6)*4\n",
    "            + \"S[table-format=3.1(2)]S[table-format=3.2(2)]\"\n",
    "            + \"S[table-format=3.2(1)]\"\n",
    "        ),\n",
    "        multirow=True\n",
    "    )\n",
    "    return data, latex_str\n",
    "\n",
    "data, table_str = save_table(results)\n",
    "\n",
    "display(data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
