{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8761e1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib as mpl\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import wandb\n",
    "\n",
    "# `import matplotlib` alone does not guarantee that the `pyplot` submodule is loaded;\n",
    "# import it explicitly instead of relying on seaborn having imported it first.\n",
    "import matplotlib.pyplot as plt\n",
    "idx = pd.IndexSlice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59c4419",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wandb public API client; used below to query the runs of the project.\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf351e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to analyse; NUM_CLASSES must match it. Swap the commented line to switch.\n",
    "# DATASET, NUM_CLASSES = 'cifar10', 10\n",
    "DATASET, NUM_CLASSES = 'cifar100', 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "685c42bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the MongoDB-style filter does not work with parameter names that contain dots\n",
    "# (wandb appears to treat them as nested keys). Nesting the config is not an option\n",
    "# either, because params not specified in the sweep could then not be updated.\n",
    "\n",
    "project = 'anonymous/project'\n",
    "\n",
    "# Server-side filter: finished runs on the selected dataset only.\n",
    "run_filter = {\n",
    "    '$and': [\n",
    "        {'config.dataset': DATASET},\n",
    "        {'state': 'finished'},\n",
    "    ],\n",
    "}\n",
    "all_runs = [*api.runs(project, run_filter)]\n",
    "\n",
    "len(all_runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f57f9868",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude runs that did not log even a single history step.\n",
    "all_runs = list(filter(lambda run: run.historyLineCount, all_runs))\n",
    "len(all_runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fdaf80c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ARCH = 'vgg11'\n",
    "ARCH = 'resnet_v1_18'\n",
    "\n",
    "# WEIGHT_DECAY = 0.0\n",
    "# WEIGHT_DECAY = 1e-4\n",
    "WEIGHT_DECAY = 5e-4\n",
    "# MAX_LR = 0.9\n",
    "MAX_LR = np.inf\n",
    "JACOBIAN_METRIC = 'train/jacobian_eval_jtj'\n",
    "CONFIG_NAME = f'{DATASET}_{ARCH}_wd{WEIGHT_DECAY}'\n",
    "\n",
    "# Alternative setting: no weight-decay filter and the JJ' Jacobian metric.\n",
    "# WEIGHT_DECAY = None\n",
    "# MAX_LR = np.inf\n",
    "# JACOBIAN_METRIC = 'train/jacobian_jjt'\n",
    "# CONFIG_NAME = f'{DATASET}_{ARCH}'\n",
    "\n",
    "\n",
    "def keep_run(run):\n",
    "    \"\"\"Return True if `run` belongs to the experiment selected above.\"\"\"\n",
    "    cfg = run.config\n",
    "    # Deprecated flags: skip runs that set base_learning_rate or enabled mixup.\n",
    "    if cfg.get('train.base_learning_rate', None) is not None:\n",
    "        return False\n",
    "    if cfg.get('train.mixup', False):\n",
    "        return False\n",
    "    # Fixed training setup shared by all runs in the analysis.\n",
    "    expected = {\n",
    "        'model.arch': ARCH,\n",
    "        'train.batch_size': 128,\n",
    "        'train.num_epochs': 90,\n",
    "        'train.schedule': 'piece',\n",
    "        'train.schedule_step_epochs': '50,80',\n",
    "    }\n",
    "    if any(cfg.get(key, None) != value for key, value in expected.items()):\n",
    "        return False\n",
    "    if cfg.get('model.res_init_zero', True) != True:  # noqa: E712 (equality, not identity)\n",
    "        return False\n",
    "    if not cfg['train.learning_rate'] <= MAX_LR:\n",
    "        return False\n",
    "    # The Jacobian metric must be logged ('train/jacobian' is the legacy name of jjt).\n",
    "    run_summary = run.summary\n",
    "    has_jacobian = JACOBIAN_METRIC in run_summary or (\n",
    "        JACOBIAN_METRIC == 'train/jacobian_jjt' and 'train/jacobian' in run_summary)\n",
    "    if not has_jacobian:\n",
    "        return False\n",
    "    if WEIGHT_DECAY is None:\n",
    "        return True\n",
    "    # Relative comparison of weight decay (epsilon avoids dividing by zero).\n",
    "    rel_err = abs(cfg.get('train.weight_decay', 0.0) - WEIGHT_DECAY) / (WEIGHT_DECAY + 1e-10)\n",
    "    return rel_err < 1e-5\n",
    "\n",
    "\n",
    "runs = [run for run in all_runs if keep_run(run)]\n",
    "len(runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645750c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected number of optimizer steps for a complete run (used to detect divergence).\n",
    "# NOTE(review): assumes 50000 training examples, batch size 128 and that the final\n",
    "# partial batch is dropped (floor division) - TODO confirm against the training code.\n",
    "num_epochs = 90\n",
    "steps_per_epoch = 50000 // 128\n",
    "num_steps = steps_per_epoch * num_epochs\n",
    "\n",
    "num_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9763e23d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# runs = [run for run in runs if 'train.mixup_beta' in run.config]\n",
    "# len(runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da31222a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# runs = [run for run in runs if run.summary['_step'] == num_steps]\n",
    "# len(runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d6a0312",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Check how many do not have mixup_beta.\n",
    "# len([run for run in runs if 'train.mixup_beta' in run.config])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca0c294",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Old jobs with mixup did not log train/acc; keep only runs that report it.\n",
    "runs = list(filter(lambda run: 'train/acc' in run.summary, runs))\n",
    "len(runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38170eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the server-side filter should leave only finished runs.\n",
    "{run.state for run in runs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70eb1edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run bookkeeping (state, history length, timestamps), indexed by run name.\n",
    "# NOTE(review): runs sharing a display name collapse to the last one (dict keys).\n",
    "metadata_fields = ['state', 'historyLineCount', 'createdAt', 'heartbeatAt']\n",
    "metadata = pd.DataFrame.from_dict(\n",
    "    {run.name: {field: getattr(run, field) for field in metadata_fields} for run in runs},\n",
    "    orient='index')\n",
    "\n",
    "# Parse the timestamp columns (names containing 'At') into datetimes.\n",
    "timestamp_cols = [col for col in metadata.columns if 'At' in col]\n",
    "for col in timestamp_cols:\n",
    "    metadata[col] = pd.to_datetime(metadata[col])\n",
    "\n",
    "metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b15eca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sanitize(x):\n",
    "    \"\"\"Return NaN for string values and `x` unchanged otherwise.\n",
    "\n",
    "    wandb reports NaN metrics as the string 'NaN' rather than a float, so every\n",
    "    string value is treated as missing.\n",
    "    \"\"\"\n",
    "    return float('nan') if isinstance(x, str) else x\n",
    "\n",
    "def sanitize_values(d):\n",
    "    \"\"\"Return a new dict with `sanitize` applied to every value of mapping `d`.\"\"\"\n",
    "    cleaned = {}\n",
    "    for key, value in d.items():\n",
    "        cleaned[key] = sanitize(value)\n",
    "    return cleaned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd22404",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One row per run with its final summary values (string NaNs mapped to float NaN).\n",
    "summary = (\n",
    "    pd.DataFrame.from_dict({run.name: sanitize_values(run.summary) for run in runs}, orient='index')\n",
    "    .drop(columns=['_wandb'])  # wandb-internal field, not needed for the analysis\n",
    ")\n",
    "\n",
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e948ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per run with its config flags (keys such as 'train.learning_rate').\n",
    "config_by_run = {run.name: run.config for run in runs}\n",
    "config = pd.DataFrame.from_dict(config_by_run, orient='index')\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4471eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backwards compatibility: older runs logged 'train/jacobian_jjt' as 'train/jacobian'.\n",
    "# Back-fill missing 'train/jacobian_jjt' values from the legacy column. If no run logged\n",
    "# the legacy column there is nothing to fill (previously this raised KeyError when\n",
    "# 'train/jacobian_jjt' was absent too, e.g. when only JACOBIAN_METRIC was required).\n",
    "if 'train/jacobian' in summary.columns:\n",
    "    legacy_jacobian = summary['train/jacobian']\n",
    "    summary['train/jacobian_jjt'] = summary.get('train/jacobian_jjt', legacy_jacobian).fillna(legacy_jacobian)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83843a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add some derived metrics.\n",
    "\n",
    "summary['epoch/val_gap/acc'] = summary['epoch/train/acc'] - summary['epoch/val/acc']\n",
    "summary['epoch/val_gap/loss'] = summary['epoch/val/loss'] - summary['epoch/train/loss']\n",
    "\n",
    "# Hessian magnitude max(|hessian_max|, |hessian_min|). Some runs only log 'train/hessian';\n",
    "# use |train/hessian| to fill rows where the max/min pair is missing so that mixing run\n",
    "# versions does not leave NaNs (and a missing column does not raise KeyError).\n",
    "hessian_abs = pd.Series(np.nan, index=summary.index)\n",
    "if 'train/hessian_max' in summary.columns and 'train/hessian_min' in summary.columns:\n",
    "    hessian_abs = np.maximum(\n",
    "        np.abs(summary['train/hessian_max']),\n",
    "        np.abs(summary['train/hessian_min']))\n",
    "if 'train/hessian' in summary.columns:\n",
    "    hessian_abs = hessian_abs.fillna(np.abs(summary['train/hessian']))\n",
    "summary['train/hessian_abs'] = hessian_abs\n",
    "\n",
    "# Back-fill error metrics as 1 - accuracy where they were not logged.\n",
    "for subset in ['train', 'val']:\n",
    "    acc_metric = f'epoch/{subset}/acc'\n",
    "    err_metric = f'epoch/{subset}/err'\n",
    "    err_val = 1.0 - summary[acc_metric]\n",
    "    summary[err_metric] = summary.get(err_metric, err_val).fillna(err_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8952a599",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the summary table, including the back-filled and derived metrics.\n",
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4295d661",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backwards compatibility:\n",
    "# mixup=False  mixup_beta=null => mixup_beta=0     # v0\n",
    "# mixup=True   mixup_beta=null => EXCLUDE          # v0 (no train/acc metric)\n",
    "# mixup=null   mixup_beta=x    => mixup_beta=x     # v1\n",
    "\n",
    "# If no run set mixup_beta the column is absent; create it so the fill/assert below work.\n",
    "if 'train.mixup_beta' not in config:\n",
    "    config['train.mixup_beta'] = np.nan\n",
    "\n",
    "if 'train.mixup' in config.columns:\n",
    "    # mixup=True runs were already excluded (they lack the 'train/acc' metric).\n",
    "    assert not config['train.mixup'].eq(True).any()\n",
    "    # Replace mixup=False (mixup_beta unset) with mixup_beta = 0.\n",
    "    config['train.mixup_beta'] = config['train.mixup_beta'].fillna(0.0)\n",
    "    del config['train.mixup']\n",
    "\n",
    "assert np.all(np.isfinite(config['train.mixup_beta']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fb636d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs that did not set 'train.sam_rho' are treated as rho = 0.\n",
    "if 'train.sam_rho' not in config:\n",
    "    config['train.sam_rho'] = np.nan\n",
    "config['train.sam_rho'] = config['train.sam_rho'].fillna(0)\n",
    "\n",
    "assert np.all(np.isfinite(config['train.sam_rho']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a53307",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backwards compatibility:\n",
    "# aug=null    aug_prob=null => aug='cifar' aug_prob=0  # v0\n",
    "# aug='cifar' aug_prob=null => aug='cifar' aug_prob=1  # v1\n",
    "# aug=''      aug_prob=null => aug='cifar' aug_prob=0  # v1\n",
    "# aug='cifar' aug_prob=x    => aug='cifar' aug_prob=x  # v2\n",
    "\n",
    "# If aug_prob is set, do not modify it.\n",
    "# If aug_prob is nan, set to 0 where aug is nan or empty, set to 1 where aug is 'cifar'.\n",
    "# Then replace all nans in aug with 'cifar'.\n",
    "\n",
    "# A flag that no run set has no column at all (e.g. only v0 runs have no 'train.aug');\n",
    "# create it as all-nan so the rules above apply.\n",
    "for col in ['train.aug', 'train.aug_prob']:\n",
    "    if col not in config:\n",
    "        config[col] = np.nan\n",
    "# Replace nans in aug_prob with 1 if aug == 'cifar', 0 if aug is nan or ''.\n",
    "aug_prob_default = pd.Series(np.where(config['train.aug'] == 'cifar', 1.0, 0.0), index=config.index)\n",
    "config['train.aug_prob'] = config['train.aug_prob'].fillna(aug_prob_default)\n",
    "# Replace nans and empty string in aug with 'cifar'.\n",
    "config['train.aug'] = config['train.aug'].fillna('cifar').replace('', 'cifar')\n",
    "\n",
    "assert np.all(np.isfinite(config['train.aug_prob']))\n",
    "assert set(config['train.aug']) == {'cifar'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f937c83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the config table after the backwards-compatibility normalisation.\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f345d71a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flags that take a single value across all runs (NaN counts as a value).\n",
    "is_constant = config.nunique(dropna=False) == 1\n",
    "const_cols = config.columns[is_constant].tolist()\n",
    "config[const_cols].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6655f79e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flags that vary across runs define the experimental groups. NaN is ignored, i.e. a\n",
    "# flag missing from some runs is assumed to take its default value there.\n",
    "# Use `> 1` rather than `!= 1`: a column with no non-null value at all (e.g. a flag that\n",
    "# is None in every run) has nunique == 0 and does not vary.\n",
    "# 'seed' is filtered out rather than removed, since it is absent from the list when\n",
    "# only a single seed was run (list.remove would raise ValueError).\n",
    "num_values = config.nunique(dropna=True)\n",
    "group_cols = [col for col in config.columns[num_values > 1] if col != 'seed']\n",
    "\n",
    "allowed_group_cols = [\n",
    "    'model.bn',\n",
    "    'train.learning_rate',\n",
    "    'train.weight_decay',\n",
    "    'train.ce_smooth',\n",
    "    'train.mixup_beta',\n",
    "    'train.aug_prob',\n",
    "    'train.sam_rho',\n",
    "]\n",
    "\n",
    "# Only these parameters should vary across configs.\n",
    "assert set(group_cols) <= set(allowed_group_cols), group_cols\n",
    "\n",
    "# Reorder to match allowed_group_cols.\n",
    "group_cols = [x for x in allowed_group_cols if x in group_cols]\n",
    "group_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2aae9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take most recent run amongst runs with same config.\n",
    "# Runs are keyed by (seed, *group_cols); when a key occurs more than once (e.g. a\n",
    "# re-launched job), keep only the run with the latest heartbeatAt.\n",
    "\n",
    "instances = config[['seed', *group_cols]].join(metadata, how='inner')\n",
    "instances.index.name = 'name'\n",
    "instances = instances.reset_index()\n",
    "\n",
    "# Check no nans before sorting (can screw with ordering).\n",
    "assert np.all(np.isfinite(instances[['seed', *group_cols, 'heartbeatAt']]))\n",
    "\n",
    "# Order by group and then recency (descending by heartbeatAt).\n",
    "instances = instances.sort_values(\n",
    "    ['seed', *group_cols, 'heartbeatAt'],\n",
    "    ascending=[True, *[True for _ in group_cols], False])\n",
    "\n",
    "instances = instances.set_index(['seed', *group_cols])\n",
    "\n",
    "# Because of the sort above, keep='first' retains the most recent run of each key.\n",
    "is_duplicate = instances.index.duplicated(keep='first')\n",
    "instances = instances.loc[~is_duplicate]\n",
    "\n",
    "assert instances.index.is_unique\n",
    "assert instances.index.is_monotonic_increasing\n",
    "\n",
    "# Restrict config to the selected runs (their names are in the 'name' column).\n",
    "config = config.loc[instances['name']]\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b1a4680",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns shown in the summary tables. Use the Jacobian metric selected for the run\n",
    "# filter and the plots (JACOBIAN_METRIC) instead of hard-coding one variant, which\n",
    "# every remaining run is guaranteed to have logged.\n",
    "MAIN_COLS = ['epoch/train/loss', 'epoch/train/acc', 'epoch/val/acc', JACOBIAN_METRIC, 'train/hessian']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7913d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per (group config, seed) with all summary metrics, sorted by index.\n",
    "table_all = config[[*group_cols, 'seed']].join(summary, how='inner')\n",
    "table_all = table_all.set_index([*group_cols, 'seed']).sort_index()\n",
    "\n",
    "assert table_all.index.is_unique\n",
    "assert table_all.index.is_monotonic_increasing\n",
    "\n",
    "table_all[MAIN_COLS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f8e1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Does a high val error coincide with a high train error?\n",
    "table_all.plot.scatter(x='epoch/train/err', y='epoch/val/err');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa70edd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get fraction of failures due to divergence.\n",
    "# NOTE: Unreliable before commit which changed divergence to not raise ValueError.\n",
    "\n",
    "# A run counts as diverged unless it completed every step (assumes no early stopping)\n",
    "# and its final objective is finite (not nan/inf).\n",
    "completed = table_all['_step'] == num_steps\n",
    "finite_objective = np.isfinite(table_all['train/objective'])\n",
    "table_all['diverged'] = ~(completed & finite_objective)\n",
    "\n",
    "# Fraction of seeds that diverged for each group.\n",
    "diverged_rate = table_all['diverged'].groupby(group_cols).mean()\n",
    "\n",
    "# Inspect failure fractions that are non-zero.\n",
    "diverged_rate[diverged_rate > 0].to_frame('diverged_rate')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b32073ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divergence by config parameter (check if one is causing divergence).\n",
    "\n",
    "diverged_as_float = table_all['diverged'].astype(float)\n",
    "for col in group_cols:\n",
    "    print(diverged_as_float.groupby(col).describe(percentiles=[]), end='\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bfeeab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the runs that diverged (and the now-constant 'diverged' column).\n",
    "table_not_diverged = table_all[~table_all['diverged']].drop(columns='diverged').copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e430d07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thresholds for flagging runs that trained poorly. Only TRAIN_LOSS_THRESHOLD is defined;\n",
    "# the loss-based filter option uses it as a fraction of log(NUM_CLASSES).\n",
    "# TRAIN_ACC_THRESHOLD = {\n",
    "#     'cifar10': 0.2,\n",
    "#     'cifar100': 0.1,\n",
    "# }[DATASET]\n",
    "\n",
    "# TRAIN_ACC_THRESHOLD = 2 / NUM_CLASSES\n",
    "# TRAIN_ACC_THRESHOLD = 0.1\n",
    "\n",
    "TRAIN_LOSS_THRESHOLD = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e60a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose how poorly-trained (but not diverged) runs are flagged as failures.\n",
    "# Option 1: train accuracy below a threshold.\n",
    "# FILTER = f'th{TRAIN_ACC_THRESHOLD:.3g}'\n",
    "# failure = ~(table_not_diverged['epoch/train/acc'] >= TRAIN_ACC_THRESHOLD)\n",
    "\n",
    "# Option 2: train loss above a fraction of log(NUM_CLASSES).\n",
    "# FILTER = f'loss{TRAIN_LOSS_THRESHOLD:.3g}'\n",
    "# failure = ~(table_not_diverged['epoch/train/loss'] <= TRAIN_LOSS_THRESHOLD * np.log(NUM_CLASSES))\n",
    "\n",
    "# Option 3 (active): flag nothing.\n",
    "FILTER = 'nofilter'\n",
    "failure = False\n",
    "\n",
    "not_diverged = ~table_all['diverged']\n",
    "table_not_diverged = table_all.loc[not_diverged].drop(columns=['diverged']).copy()\n",
    "table_not_diverged['failure'] = failure\n",
    "\n",
    "np.mean(table_not_diverged['failure'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6518e26d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train acc stats by config parameter (check if one parameter is causing poor accuracy).\n",
    "\n",
    "for col in group_cols:\n",
    "    train_acc_stats = table_not_diverged.groupby(col)['epoch/train/acc'].describe()\n",
    "    print(train_acc_stats.to_string(), end='\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba16542b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Val acc stats by config parameter.\n",
    "\n",
    "for col in group_cols:\n",
    "    val_acc_stats = table_not_diverged.groupby(col)['epoch/val/acc'].describe()\n",
    "    print(val_acc_stats.to_string(), end='\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbdfdb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute fraction of seeds that failed for each group (add this to mean).\n",
    "failure_by_group = table_not_diverged['failure'].astype(float).groupby(group_cols)\n",
    "failure_stats = failure_by_group.describe()[['count', 'mean']]\n",
    "failure_stats['sum'] = failure_by_group.sum()\n",
    "failure_rate = failure_stats['mean']\n",
    "\n",
    "# Inspect failure fractions that are non-zero.\n",
    "# Should be mostly due to high learning rates.\n",
    "failure_stats[failure_stats['mean'] > 0].sort_values('mean', kind='stable')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6333c3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude failed runs; afterwards 'failure' is all False and is dropped.\n",
    "# Guard against the filter having rejected every run.\n",
    "assert not table_not_diverged['failure'].all()\n",
    "\n",
    "is_kept = ~table_not_diverged['failure']\n",
    "table = table_not_diverged.loc[is_kept].drop(columns='failure').copy()\n",
    "\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c88c8660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configs with fewer than MIN_TRIALS remaining runs (these are excluded next).\n",
    "MIN_TRIALS = 3\n",
    "\n",
    "group_keys = table.reset_index('seed').index  # index without the seed level\n",
    "trial_count = group_keys.value_counts()\n",
    "\n",
    "trial_count[trial_count < MIN_TRIALS].to_frame('count')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0773d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude configs where the number of remaining (non-diverged, non-failed) runs is too low.\n",
    "\n",
    "# Look up count (same for all seeds).\n",
    "# Caution: Use hack to get around index problem (seed missing from trial_count.index).\n",
    "# trial_count is keyed by the group levels only, so it is looked up with `table`'s index\n",
    "# minus 'seed', and np.array re-attaches the counts positionally to the full index.\n",
    "table['trial_count'] = pd.Series(\n",
    "    np.array(trial_count[table.reset_index('seed').index]),\n",
    "    index=table.index)\n",
    "\n",
    "table = table.loc[table['trial_count'] >= MIN_TRIALS]\n",
    "table = table.copy()\n",
    "\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "038d1415",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-group statistics over seeds; keep the mean of every column.\n",
    "stats = table.groupby(group_cols).describe()\n",
    "mean = stats.xs('mean', level=1, axis=1).copy()\n",
    "# Fraction of failed seeds per group (computed before failed runs were excluded).\n",
    "mean['failure'] = failure_rate\n",
    "\n",
    "mean[['trial_count', 'failure', *MAIN_COLS]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b3fb0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate into different experiments, varying one parameter.\n",
    "\n",
    "# Baseline value of each method parameter (0 = method not applied); used to hold every\n",
    "# method except the one being studied fixed.\n",
    "default_values = {\n",
    "    'train.aug_prob': 0.0,\n",
    "    'train.ce_smooth': 0.0,\n",
    "    'train.mixup_beta': 0.0,\n",
    "    'train.weight_decay': 0.0,\n",
    "    'train.sam_rho': 0.0,\n",
    "}\n",
    "\n",
    "def slice_table(table, key):\n",
    "    \"\"\"Select rows where every method other than `key` is at its default value.\n",
    "\n",
    "    The index levels of those other methods are dropped; 'model.bn',\n",
    "    'train.learning_rate', `key` and any remaining levels (e.g. 'seed') are kept.\n",
    "    If no other method varies, `table` is returned unchanged.\n",
    "    \"\"\"\n",
    "    levels = tuple(col for col in group_cols if col not in {key, 'model.bn', 'train.learning_rate'})\n",
    "    if not levels:\n",
    "        # Nothing to hold fixed; `xs` with an empty tuple of levels is not a valid selection.\n",
    "        return table\n",
    "    values = tuple(default_values[level] for level in levels)\n",
    "    return table.xs(values, level=levels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7c216c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names used for panel titles and axis labels in the plots below.\n",
    "\n",
    "# Accuracy / loss summary metrics.\n",
    "METRIC_NAMES = {\n",
    "    'epoch/val/loss': 'Val. loss (excl. reg.)',\n",
    "    'epoch/train/loss': 'Train loss (excl. reg.)',\n",
    "    'epoch/val_gap/loss': 'Gen. gap (loss)',\n",
    "    'epoch/val/acc': 'Val. acc.',\n",
    "    'epoch/val/err': 'Val. err.',\n",
    "    'epoch/val/acc_pct': 'Val acc. (%)',\n",
    "    'epoch/val/err_pct': 'Val err. (%)',\n",
    "    'epoch/train/acc': 'Train acc.',\n",
    "    'epoch/train/err': 'Train err.',\n",
    "    'epoch/train/acc_pct': 'Train acc. (%)',\n",
    "    'epoch/train/err_pct': 'Train err. (%)',\n",
    "    # 'epoch/val_gap/acc': 'Gen. gap (acc)',\n",
    "}\n",
    "\n",
    "# Eigenvalue-type metrics. The Jacobian variants intentionally share the label 'Jacobian'.\n",
    "EIGENVAL_NAMES = {\n",
    "    'train/hessian': 'Hessian (mag.)',\n",
    "    'train/hessian_max': 'Hessian',\n",
    "    'train/hessian_min': 'Hessian (min.)',\n",
    "    'train/hessian_abs': 'Hessian (abs.)',\n",
    "    'train/gauss_newton': 'Gauss-Newton',\n",
    "    'train/jacobian_eval_jtj': 'Jacobian',\n",
    "    'train/jacobian_jtj': 'Jacobian',  # 'Jacobian (train mode)',\n",
    "    'train/jacobian_jjt': 'Jacobian',  # 'Jacobian (JJ\\', train mode)',\n",
    "}\n",
    "\n",
    "SUMMARY_NAMES = {**METRIC_NAMES, **EIGENVAL_NAMES}\n",
    "\n",
    "# Scalars drawn with a logarithmic y-axis.\n",
    "USE_LOG_AXIS = ['train/hessian_max', 'train/gauss_newton']\n",
    "\n",
    "# Column titles, one per method parameter.\n",
    "METHOD_NAMES = {\n",
    "    'train.weight_decay': 'Weight Decay',\n",
    "    'train.ce_smooth': 'Label Smoothing',\n",
    "    'train.mixup_beta': 'Mixup',\n",
    "    'train.aug_prob': 'Data Aug.',\n",
    "    'train.sam_rho': 'Sharpness Aware Min.',\n",
    "}\n",
    "\n",
    "# Symbols used as x-axis labels for each method parameter.\n",
    "METHOD_PARAM_NAMES = {\n",
    "    'train.weight_decay': 'λ',\n",
    "    'train.ce_smooth': 'α',\n",
    "    'train.mixup_beta': 'β',\n",
    "    'train.aug_prob': 'p',\n",
    "    'train.sam_rho': 'ρ',\n",
    "}\n",
    "\n",
    "DATASET_NAMES = {\n",
    "    'cifar10': 'CIFAR10',\n",
    "    'cifar100': 'CIFAR100',\n",
    "}\n",
    "\n",
    "ARCH_NAMES = {\n",
    "    'vgg11': 'VGG11',\n",
    "    'resnet_v1_18': 'ResNet18',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f99c3862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct learning rates present in the final table, ascending.\n",
    "lrs = table.index.get_level_values('train.learning_rate').unique().sort_values().tolist()\n",
    "\n",
    "lrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb95e3a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def odds2prob(a):\n",
    "    \"\"\"Map odds a in [0, inf) to the probability a / (1 + a) in [0, 1).\"\"\"\n",
    "    return a / (1 + a)\n",
    "\n",
    "def prob2odds(p):\n",
    "    \"\"\"Inverse of `odds2prob`: map probability p to odds p / (1 - p).\"\"\"\n",
    "    return p / (1 - p)\n",
    "\n",
    "# Seaborn kwargs colouring by learning rate with a categorical palette.\n",
    "# Alternative: hue_norm=mpl.colors.LogNorm(vmin=min(lrs), vmax=max(lrs)), palette='flare'.\n",
    "lr_hue_kwargs = {'hue': 'train.learning_rate', 'palette': 'tab10'}\n",
    "\n",
    "# 1/sqrt((b/a)^2 / 2^2 - 1) * a = 1/sqrt((10/2)^2-1) = about 1/5.\n",
    "\n",
    "def set_xscale(ax, target, values):\n",
    "    \"\"\"Configure the x-axis scale and ticks of `ax` for method parameter `target`.\n",
    "\n",
    "    `values` are the sorted parameter values present in the data and are used as\n",
    "    fixed major tick locations. Weight decay and SAM rho use an asinh scale (linear\n",
    "    near 0, logarithmic further out); mixup beta uses the map beta / (1 + beta).\n",
    "    Other targets keep the default scale.\n",
    "    \"\"\"\n",
    "    if target == 'train.weight_decay' or target == 'train.sam_rho':\n",
    "        positive = sorted(x for x in values if x > 0)\n",
    "        # With no positive value there is no sensible asinh width; keep the default scale.\n",
    "        if positive:\n",
    "            min_pos = positive[0]\n",
    "            # The ratio of the two smallest positive values sets the width of the linear\n",
    "            # region; fall back to 10 (one decade) when only one positive value exists.\n",
    "            ratio = positive[1] / min_pos if len(positive) > 1 else 10.0\n",
    "            ax.set_xscale('asinh', base=10, linear_width=min_pos * 2 / ratio)\n",
    "            ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(values))\n",
    "            ax.xaxis.set_minor_locator(mpl.ticker.SymmetricalLogLocator(\n",
    "                base=10, linthresh=min_pos, subs=tuple(range(10))))\n",
    "\n",
    "    if target == 'train.mixup_beta':\n",
    "        ax.set_xscale('function', functions=(odds2prob, prob2odds))\n",
    "        ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(values))\n",
    "        ax.xaxis.set_major_formatter(mpl.ticker.FixedFormatter(values))\n",
    "\n",
    "    if target == 'train.sam_rho':\n",
    "        # Label the ticks with the raw parameter values.\n",
    "        ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(values))\n",
    "        ax.xaxis.set_major_formatter(mpl.ticker.FixedFormatter(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eab8f60b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def num_to_string(x, vmin, vmax):\n",
    "    \"\"\"Format tick value `x` compactly with at most 3 significant digits.\n",
    "\n",
    "    Scientific notation is used, without padding (e.g. '1e-3' rather than\n",
    "    '1.000e-03'), when the plain form needs more than 3 digit characters.\n",
    "    `vmin` and `vmax` are unused; they exist so the function can replace\n",
    "    `LogFormatter._num_to_string` (see `make_sane`).\n",
    "    \"\"\"\n",
    "    text = f'{x:.3g}'\n",
    "    digits = text.lstrip('-').replace('.', '')\n",
    "    if len(digits) > 3:\n",
    "        text = f'{x:.3e}'\n",
    "    if 'e' not in text:\n",
    "        return text\n",
    "    mantissa, exponent = text.split('e')\n",
    "    return f'{float(mantissa):.3g}e{int(exponent):d}'\n",
    "\n",
    "def make_sane(formatter: mpl.ticker.LogFormatter):\n",
    "    \"\"\"Patch `formatter` in place to format labels with `num_to_string`; return it.\n",
    "\n",
    "    NOTE: this overrides a private (underscore) matplotlib method on the instance.\n",
    "    \"\"\"\n",
    "    setattr(formatter, '_num_to_string', num_to_string)\n",
    "    return formatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1cbd26a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_method_cols(methods, scalars, suffix='', log_axis_for=(), bn_values=(False, True)):\n",
    "    \"\"\"Grid of plots: one row per scalar metric, one column per method parameter.\n",
    "\n",
    "    For each value in `bn_values` one figure is drawn, saved under plots/ and shown.\n",
    "    Rows are selected with `.loc[bn_value]` on the first index level of the sliced\n",
    "    tables (assumed to be 'model.bn'). Each panel has a line per learning rate, the\n",
    "    individual seeds (faint) and the per-group means (large markers).\n",
    "\n",
    "    Args:\n",
    "        methods: method parameters (index levels of `table`), one column each.\n",
    "        scalars: summary metrics, one row each; each must be a key of SUMMARY_NAMES.\n",
    "        suffix: appended to the saved file name.\n",
    "        log_axis_for: scalars whose y-axis uses a log scale.\n",
    "        bn_values: batch-norm settings to plot, one figure each.\n",
    "    \"\"\"\n",
    "    width = 3\n",
    "    height = 2.5\n",
    "    marker_size = 50\n",
    "\n",
    "    num_rows = len(scalars)\n",
    "    num_cols = len(methods)\n",
    "\n",
    "    scatter_kwargs = dict(style='train.learning_rate', hue='train.learning_rate', palette='tab10')\n",
    "    line_kwargs = dict(**scatter_kwargs)\n",
    "\n",
    "    for bn_value in bn_values:\n",
    "        # squeeze=False keeps `axs` 2-D, so axs[row, col] also works with a single\n",
    "        # scalar or a single method.\n",
    "        fig, axs = plt.subplots(num_rows, num_cols, figsize=(width*num_cols, height*num_rows), squeeze=False)\n",
    "        fig.suptitle(f'{DATASET_NAMES[DATASET]:s}, {ARCH_NAMES[ARCH]:s} (BN: {str(bn_value).lower()})')\n",
    "\n",
    "        for col, method in enumerate(methods):\n",
    "            # Rows where all other methods are at their default, for this BN setting.\n",
    "            subtable = slice_table(table, method).loc[bn_value]\n",
    "            submean = slice_table(mean, method).loc[bn_value]\n",
    "\n",
    "            for row, scalar in enumerate(scalars):\n",
    "                assert scalar in SUMMARY_NAMES, scalar\n",
    "                assert scalar in subtable.columns, scalar\n",
    "                ax = axs[row, col]\n",
    "                plt.sca(ax)\n",
    "                sns.lineplot(data=subtable, x=method, y=scalar, **line_kwargs, markers=True)\n",
    "                # Plot samples.\n",
    "                sns.scatterplot(subtable, x=method, y=scalar, **scatter_kwargs, legend=False, alpha=0.3)\n",
    "                # Plot mean with larger markers. Must be higher than markers in lines.\n",
    "                sns.scatterplot(submean, x=method, y=scalar, **scatter_kwargs, legend=False, s=marker_size, zorder=3)\n",
    "\n",
    "                set_xscale(ax, method, np.sort(table.index.get_level_values(method).unique()))\n",
    "                if 'acc' in scalar:\n",
    "                    ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, _: f'{100*x:g}%'))\n",
    "                if scalar in log_axis_for:\n",
    "                    plt.yscale('log')\n",
    "                    ax.yaxis.set_major_formatter(make_sane(mpl.ticker.LogFormatter()))\n",
    "                    ax.yaxis.set_minor_formatter(make_sane(mpl.ticker.LogFormatter()))\n",
    "                # Label only the outer panels: y-label in the first column, title in the\n",
    "                # first row, x-label and tick labels in the last row.\n",
    "                if col == 0:\n",
    "                    plt.ylabel(SUMMARY_NAMES.get(scalar, scalar))\n",
    "                else:\n",
    "                    plt.ylabel(None)\n",
    "                if row == 0:\n",
    "                    plt.title(METHOD_NAMES[method])\n",
    "                if row == num_rows - 1:\n",
    "                    plt.xlabel(METHOD_PARAM_NAMES[method])\n",
    "                else:\n",
    "                    plt.xlabel(None)\n",
    "                    ax.xaxis.set_ticklabels([])  # Disable tick labels.\n",
    "                plt.grid()\n",
    "                # Keep the legend only in the top row.\n",
    "                plt.legend(loc='upper right')\n",
    "                if row != 0:\n",
    "                    ax.legend().set_visible(False)\n",
    "\n",
    "        plt.tight_layout()\n",
    "\n",
    "        plot_dir = Path('plots')\n",
    "        plot_dir.mkdir(0o755, exist_ok=True)\n",
    "        filename = f'methods_{CONFIG_NAME}_{FILTER}_bn{int(bn_value)}'\n",
    "        plt.savefig(plot_dir / f'{filename}{suffix}.pdf')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bfdc0b6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Method parameters to plot: the varying group columns that are methods (i.e. not\n",
    "# 'model.bn' or 'train.learning_rate'). Unlike group_cols[2:], this stays correct\n",
    "# when BN or the learning rate does not vary.\n",
    "methods = [col for col in group_cols if col in METHOD_NAMES]\n",
    "\n",
    "scalars = [\n",
    "    'epoch/val_gap/loss',\n",
    "    JACOBIAN_METRIC,\n",
    "    'train/hessian_max',\n",
    "]\n",
    "\n",
    "plot_method_cols(methods, scalars, suffix='_brief', log_axis_for=USE_LOG_AXIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89d7deec",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Method parameters to plot: the varying group columns that are methods (i.e. not\n",
    "# 'model.bn' or 'train.learning_rate'). Unlike group_cols[2:], this stays correct\n",
    "# when BN or the learning rate does not vary.\n",
    "methods = [col for col in group_cols if col in METHOD_NAMES]\n",
    "\n",
    "scalars = [\n",
    "    'epoch/val_gap/loss',\n",
    "    JACOBIAN_METRIC,\n",
    "    'train/hessian_max',\n",
    "    'train/gauss_newton',\n",
    "    'epoch/train/loss',\n",
    "    'epoch/val/acc',\n",
    "]\n",
    "\n",
    "plot_method_cols(methods, scalars, log_axis_for=USE_LOG_AXIS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf7a3021",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_effect(method, scalars, bn_values=(False, True), log_axis_for=()):\n",
    "    \"\"\"Plot the effect of one method's parameter on several scalar metrics.\n",
    "\n",
    "    For each value in `bn_values` a single-row figure is drawn (one panel per\n",
    "    scalar), saved as plots/effect_*.pdf and shown. Other methods are held at\n",
    "    their default values via `slice_table`. If `method` is not an index level of\n",
    "    `table`, a message is printed and nothing is plotted.\n",
    "\n",
    "    Args:\n",
    "        method: method parameter name, e.g. 'train.weight_decay'.\n",
    "        scalars: summary metric names, one panel each.\n",
    "        bn_values: values of the 'model.bn' index level to plot.\n",
    "        log_axis_for: scalars whose y-axis uses a log scale.\n",
    "    \"\"\"\n",
    "    if method not in table.index.names:\n",
    "        print('method not present in table')\n",
    "        return\n",
    "\n",
    "    width = 3\n",
    "    height = 2.5\n",
    "    marker_size = 70\n",
    "\n",
    "    subtable = slice_table(table, method)\n",
    "    submean = slice_table(mean, method)\n",
    "    # xlabel = f'{METHOD_NAMES[method]:s} ({METHOD_PARAM_NAMES[method]:s})'\n",
    "    xlabel = METHOD_PARAM_NAMES[method]\n",
    "\n",
    "    scatter_kwargs = dict(style='train.learning_rate', hue='train.learning_rate', palette='tab10')\n",
    "    line_kwargs = dict(**scatter_kwargs)\n",
    "\n",
    "    for bn_value in bn_values:\n",
    "        # Restrict to this batch-norm setting (the 'model.bn' level is dropped).\n",
    "        group = subtable.xs(bn_value, level='model.bn')\n",
    "        group_mean = submean.xs(bn_value, level='model.bn')\n",
    "\n",
    "        num_rows = 1\n",
    "        num_cols = len(scalars)\n",
    "        fig, axs = plt.subplots(num_rows, num_cols, figsize=(width*num_cols, 0.75 + height*num_rows), squeeze=False)\n",
    "        fig.suptitle(f'{METHOD_NAMES[method]:s}: {DATASET_NAMES[DATASET]:s}, {ARCH_NAMES[ARCH]:s} (BN: {str(bn_value).lower()})')\n",
    "\n",
    "        row = 0\n",
    "        for col, scalar in enumerate(scalars):\n",
    "            ax = axs[row, col]\n",
    "            plt.sca(ax)\n",
    "            sns.lineplot(group, x=method, y=scalar, **line_kwargs, markers=True)\n",
    "            # Plot samples.\n",
    "            sns.scatterplot(group, x=method, y=scalar, **scatter_kwargs, legend=False, alpha=0.3)\n",
    "            # Plot mean. Must be higher than markers in lines.\n",
    "            sns.scatterplot(group_mean, x=method, y=scalar, **scatter_kwargs, legend=False, s=marker_size, zorder=3)\n",
    "\n",
    "            # Ticks come from all values of `method` in the full table, not just this slice.\n",
    "            set_xscale(ax, method, np.sort(table.index.get_level_values(method).unique()))\n",
    "            if scalar in log_axis_for:\n",
    "                plt.yscale('log')\n",
    "                ax.yaxis.set_major_formatter(make_sane(mpl.ticker.LogFormatter()))\n",
    "                ax.yaxis.set_minor_formatter(make_sane(mpl.ticker.LogFormatter()))\n",
    "            plt.grid()\n",
    "\n",
    "            plt.title(SUMMARY_NAMES.get(scalar, scalar))\n",
    "            plt.ylabel(None)\n",
    "            if row == num_rows - 1:\n",
    "                plt.xlabel(xlabel)\n",
    "            else:\n",
    "                plt.xlabel(None)\n",
    "                ax.xaxis.set_ticklabels([])\n",
    "            # Keep the legend only in the first panel.\n",
    "            plt.legend(loc='best')\n",
    "            if col != 0:\n",
    "                ax.legend().set_visible(False)\n",
    "\n",
    "        plt.tight_layout()\n",
    "\n",
    "        plot_dir = Path('plots')\n",
    "        plot_dir.mkdir(0o755, exist_ok=True)\n",
    "        plt.savefig(plot_dir / f'effect_{CONFIG_NAME}_{FILTER}_bn{int(bn_value)}_{method}.pdf')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a13a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalars plotted (one panel each) by the plot_effect calls below.\n",
    "scalars = [\n",
    "    'epoch/train/loss',\n",
    "    'epoch/val_gap/loss',\n",
    "    JACOBIAN_METRIC,\n",
    "    'train/hessian_max',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5024cb40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For weight decay the training loss also gets a log y-axis.\n",
    "plot_effect(\n",
    "    'train.weight_decay',\n",
    "    scalars,\n",
    "    log_axis_for=(*USE_LOG_AXIS, 'epoch/train/loss'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afef566a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_effect(method='train.ce_smooth', scalars=scalars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f620569",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_effect(method='train.aug_prob', scalars=scalars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36182625",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_effect(method='train.mixup_beta', scalars=scalars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68fc76f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_effect(method='train.sam_rho', scalars=scalars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4fa956",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def plot_vs_eigenvalue(method):\n",
    "    \"\"\"For each metric in `metrics`, plot it against `method`, the learning rate and each eigenvalue.\n",
    "\n",
    "    Uses the globals `table`, `metrics`, `eigenvals` and `lr_hue_kwargs`. One figure per\n",
    "    metric, with one row per `model.bn` value present for `method` and the columns:\n",
    "    metric vs. `method`, metric vs. learning rate, then metric vs. each entry of\n",
    "    `eigenvals` (individual trials faint, means over the last index level large).\n",
    "    Metrics whose name contains 'acc' get a linear y-axis, all others a log y-axis.\n",
    "    \"\"\"\n",
    "    marker_size = 70\n",
    "\n",
    "    subtable = slice_table(table, method)\n",
    "\n",
    "    for y in metrics:\n",
    "        num_rows = subtable.index.get_level_values('model.bn').nunique()\n",
    "        num_cols = 2 + len(eigenvals)\n",
    "        # squeeze=False keeps `axs` 2-D even when only one BN value is present.\n",
    "        fig, axs = plt.subplots(num_rows, num_cols, figsize=(6*num_cols, 0.5 + 4.5*num_rows), squeeze=False)\n",
    "        fig.suptitle(y)\n",
    "        yscale = 'linear' if 'acc' in y else 'log'\n",
    "\n",
    "        for row, (bn_value, group) in enumerate(subtable.groupby('model.bn')):\n",
    "            group = group.xs(bn_value, level='model.bn')\n",
    "\n",
    "            # Metric vs. the method's parameter.\n",
    "            ax = axs[row, 0]\n",
    "            plt.sca(ax)\n",
    "            sns.lineplot(group, x=method, y=y, markers=True, **lr_hue_kwargs)\n",
    "            # markers=True above doesn't work?\n",
    "            sns.scatterplot(group, x=method, y=y, style=method, **lr_hue_kwargs, legend=False)\n",
    "            set_xscale(ax, method, np.sort(table.index.get_level_values(method).unique()))\n",
    "            plt.yscale(yscale)\n",
    "            plt.grid()\n",
    "            plt.ylabel(f'bn={bn_value}')\n",
    "\n",
    "            # Metric vs. learning rate.\n",
    "            ax = axs[row, 1]\n",
    "            plt.sca(ax)\n",
    "            sns.lineplot(group, x='train.learning_rate', y=y, style=method, markers=True, color='#ccc')\n",
    "            sns.scatterplot(group, x='train.learning_rate', y=y, style=method, **lr_hue_kwargs, legend=False, zorder=3)\n",
    "            plt.xscale('log')\n",
    "            plt.yscale(yscale)\n",
    "            plt.grid()\n",
    "            plt.ylabel(None)\n",
    "\n",
    "            # Metric vs. each eigenvalue.\n",
    "            for i, x in enumerate(eigenvals):\n",
    "                ax = axs[row, 2 + i]\n",
    "                plt.sca(ax)\n",
    "                # Individual trials.\n",
    "                sns.scatterplot(group, x=x, y=y, style=method, **lr_hue_kwargs, alpha=0.3)\n",
    "                # Mean over the last index level (presumably the seed). Computed without\n",
    "                # writing into `group`, which the next eigenvalue's panel plots again.\n",
    "                group_mean = group[[x, y]].groupby(group.index.names[:-1]).mean()\n",
    "                sns.scatterplot(group_mean, x=x, y=y, style=method, **lr_hue_kwargs, s=marker_size)\n",
    "\n",
    "                plt.xscale('log')\n",
    "                plt.yscale(yscale)\n",
    "                plt.legend(loc='lower left')\n",
    "                ax.legend().set_visible(False)\n",
    "                plt.ylabel(None)\n",
    "\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "070e8b48",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# plot_vs_eigenvalue('train.weight_decay')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64093f0a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# plot_vs_eigenvalue('train.ce_smooth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2904f8b0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# plot_vs_eigenvalue('train.mixup_beta')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3a77043",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# plot_vs_eigenvalue('train.aug_prob')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb33f9d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_vs_eigenvalue('train.sam_rho')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54ae809e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_vs_eigenvalue():\n",
    "    \"\"\"Plot the `eigenvals` metrics against 'epoch/val_gap/loss' for several methods at once.\n",
    "\n",
    "    Trials of each method in `methods` are stacked into one table whose index level\n",
    "    `reg_method` names the method and `reg_degree` holds its parameter value. The figure\n",
    "    has one row per entry of `eigenvals` and one column per `model.bn` value present;\n",
    "    individual trials are faint points and per-configuration means are joined by lines.\n",
    "    \"\"\"\n",
    "    methods = ['train.ce_smooth', 'train.aug_prob', 'train.mixup_beta']\n",
    "    eigenvals = ['train/hessian', 'train/jacobian_eval_jtj']\n",
    "    mean_levels = ['model.bn', 'reg_method', 'reg_degree', 'train.learning_rate']\n",
    "\n",
    "    subtables = []\n",
    "    submeans = []\n",
    "    for method in methods:\n",
    "        subtable = (\n",
    "            slice_table(table, method)\n",
    "            .assign(reg_method=method)\n",
    "            .set_index('reg_method', append=True)\n",
    "            # Common name for each method's parameter level so the tables can be concatenated.\n",
    "            .rename_axis(index={method: 'reg_degree'})\n",
    "        )\n",
    "        subtables.append(subtable)\n",
    "        submeans.append(subtable.groupby(mean_levels).mean())\n",
    "\n",
    "    method_trials = pd.concat(subtables).reorder_levels([*mean_levels, 'seed'])\n",
    "    method_means = pd.concat(submeans).reorder_levels(mean_levels)\n",
    "\n",
    "    y = 'epoch/val_gap/loss'\n",
    "    # Only BN settings that actually occur; sorted so False comes before True.\n",
    "    bn_values = sorted(method_means.index.get_level_values('model.bn').unique())\n",
    "    if not bn_values:\n",
    "        print('no runs to plot')\n",
    "        return\n",
    "\n",
    "    num_rows = len(eigenvals)\n",
    "    num_cols = len(bn_values)\n",
    "    # squeeze=False keeps `axs` 2-D even when there is a single column.\n",
    "    fig, axs = plt.subplots(num_rows, num_cols, figsize=(6*num_cols, 0.5 + 4.5*num_rows), squeeze=False)\n",
    "    fig.suptitle(y)\n",
    "\n",
    "    for col, bn_value in enumerate(bn_values):\n",
    "        group_trials = method_trials.xs(bn_value, level='model.bn')\n",
    "        group_means = method_means.xs(bn_value, level='model.bn')\n",
    "        for row, x in enumerate(eigenvals):\n",
    "            ax = axs[row, col]\n",
    "            plt.sca(ax)\n",
    "            # Note the axes: the gap `y` goes on the x-axis and the eigenvalue `x` on the y-axis.\n",
    "            sns.scatterplot(group_trials, x=y, y=x, hue='reg_method', style='train.learning_rate', markers=True, palette='colorblind', alpha=0.3, legend=False)\n",
    "            sns.lineplot(group_means, x=y, y=x, hue='reg_method', style='train.learning_rate', markers=True, palette='colorblind')\n",
    "            plt.loglog()\n",
    "            if row == 0:\n",
    "                plt.title(f'bn: {bn_value}')\n",
    "            if row != num_rows - 1:\n",
    "                plt.xlabel(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c697ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_all_vs_eigenvalue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eae5f3a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
