{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee667ea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os, sys, json, random, re\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "import json, itertools, copy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "REPO_ROOT = os.path.abspath('..')\n",
    "sys.path.append(REPO_ROOT)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib.legend_handler import HandlerTuple\n",
    "from matplotlib.container import ErrorbarContainer \n",
    "from matplotlib.lines import Line2D\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f845a257",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_df(df, filter_dict):\n",
    "  _df = df.loc[df[list(filter_dict)].eq(pd.Series(filter_dict)).all(axis=1)]\n",
    "  return _df\n",
    "\n",
    "def get_grid(df, cols):\n",
    "  grid = {}\n",
    "  for col in cols:\n",
    "    unique_vals = df[col].unique()\n",
    "    if len(unique_vals) > 1:\n",
    "      grid[col] = unique_vals\n",
    "  return grid\n",
    "\n",
    "def grid_to_param_dicts(grid):\n",
    "  all_params = []\n",
    "  for values in itertools.product(*grid.values()):\n",
    "      param_dict = {key: value for key, value in zip(grid.keys(), values)}\n",
    "      all_params.append(param_dict)\n",
    "  return all_params\n",
    "\n",
    "def format_xticks(ax, axis):\n",
    "  if axis == 'x':\n",
    "    ticks = ax.get_xticks()\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels([f'{int(x/1000)}k' for x in ticks])\n",
    "  elif axis == 'y':\n",
    "    ticks = ax.get_yticks()\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels([f'{int(x/1000)}k' for x in ticks])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebad0405",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dir = os.path.join(REPO_ROOT, 'data')\n",
    "load_path = os.path.join(load_dir, 'fig1_AGG.csv')\n",
    "\n",
    "df = pd.read_csv(load_path)\n",
    "print(f\"df (agg) shape: {df.shape}\")\n",
    "\n",
    "seeds = sorted(df['seed'].unique())\n",
    "print(f\"seeds: {seeds}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d5a6a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_cols = sorted(df.columns)\n",
    "\n",
    "logM_cols = [col for col in df_cols if col.startswith('M=')]\n",
    "k_cols = [col for col in df_cols if col.startswith('k=')]\n",
    "metric_cols = logM_cols + k_cols + ['correct', 'kl']\n",
    "\n",
    "axis_cols = ['iteration', 'eval']\n",
    "param_cols = sorted(list(set(df_cols) - set(metric_cols + axis_cols)))\n",
    "print(f\"df param_cols: {param_cols}\")\n",
    "\n",
    "\n",
    "print(f\"\\n# df cols: {len(df_cols)}\")\n",
    "print(f\"\\t- {len(metric_cols)} metrics\")\n",
    "print(f\"\\t- {len(param_cols)} params\")\n",
    "print(f\"\\t- {len(axis_cols)} axes\")\n",
    "print(f\"total {len(param_cols + metric_cols + axis_cols)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "030bc0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_cols = list(set(param_cols) - set(['seed', 'ratio1', 'ratio2']))\n",
    "grid = get_grid(df, grid_cols)\n",
    "all_params = grid_to_param_dicts(grid)\n",
    "\n",
    "for k, v in grid.items():\n",
    "  print(f'{k}: {v}')\n",
    "\n",
    "print(f'\\n# grid combinations: {len(all_params)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9bcc50",
   "metadata": {},
   "outputs": [],
   "source": [
    "tdf_load_dir = os.path.join(REPO_ROOT, 'data')\n",
    "tdf_load_path = os.path.join(tdf_load_dir, 'fig1_tournament.csv')\n",
    "tdf = pd.read_csv(tdf_load_path)\n",
    "print(f\"Loaded tournament results {tdf.shape}\")\n",
    "\n",
    "\n",
    "tdf = tdf.sort_values('logM')\n",
    "print(f\"tournament result columns: {tdf.columns}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2726eeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_plot_logMs = [1.0, 1.6, 2.0, 2.2, 2.5]\n",
    "plot_logMs = [1.0, 2.0, 2.5]\n",
    "plot_Ms = [2**logM for logM in plot_logMs]\n",
    "plot_ks = [4, 8, 12]\n",
    "\n",
    "plot_logM_cols = [f'M={logM}' for logM in plot_logMs]\n",
    "all_plot_logM_cols = [f'M={logM}' for logM in all_plot_logMs]\n",
    "plot_k_cols = [f'k={k}' for k in plot_ks]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee6cbbb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "colors = sns.color_palette('cool', len(all_plot_logMs))\n",
    "darker_colors = [tuple(min(c * 0.9, 1.0) for c in color[:3]) + (color[3],) for color in [mcolors.to_rgba(color) for color in colors]]\n",
    "lighter_colors = [tuple(min(c * 1.2, 1.0) for c in color[:3]) + (color[3],) for color in [mcolors.to_rgba(color) for color in colors]]\n",
    "darker_red = tuple(c * 0.7 for c in (1, 0, 0, 1))\n",
    "\n",
    "def get_cov_styles(logM_idx, seed=None):\n",
    "  plot_logM = all_plot_logMs[logM_idx]\n",
    "  n_value = np.round(2**plot_logM, 1)\n",
    "  n_value_str = str(n_value)\n",
    "  pcov_label = r\"$\\mathsf{Cov}_{\" + n_value_str + r\"}$\"\n",
    "  npcov_label = r\"$N=\" + n_value_str + r\"$ \"\n",
    "  k_value = plot_ks[logM_idx] if logM_idx < len(plot_logMs) else None\n",
    "  cov_style = {\n",
    "    'linewidth': 5,\n",
    "    'alpha': 1.0,\n",
    "    'zorder': logM_idx+1+len(plot_logMs),\n",
    "    'color': colors[logM_idx],\n",
    "    'label': pcov_label,\n",
    "  }\n",
    "  cov_fill_style = {\n",
    "    'alpha': 0.3,\n",
    "    'zorder': logM_idx+1,\n",
    "    'color': colors[logM_idx],\n",
    "  }\n",
    "  cov_x_style = {\n",
    "    's': 150,\n",
    "    'marker': 'd',\n",
    "    'zorder': 2*len(plot_logMs) + logM_idx + 1,\n",
    "    'alpha': 1.0,\n",
    "    'label': 'Selected model' if seed == 1 else None,\n",
    "    'color': colors[logM_idx],\n",
    "    'edgecolor': 'black',\n",
    "    'linewidth': 2,\n",
    "  }\n",
    "  cov_scatter_style = {\n",
    "    's': 30, \n",
    "    'label': 'Pass@{k_value}'.format(k_value=k_value) if seed == 1 else None,\n",
    "    'color': colors[logM_idx],\n",
    "    'alpha': 0.6,\n",
    "    'zorder': 3,\n",
    "    'edgecolor': 'none',\n",
    "  }\n",
    "  return {'line': cov_style, 'fill': cov_fill_style, 'x': cov_x_style, 'scatter': cov_scatter_style}\n",
    "\n",
    "def get_kl_styles(logM_idx, seed=None):\n",
    "  k_value = plot_ks[logM_idx]\n",
    "  kl_style = {\n",
    "    'linewidth': 4,\n",
    "    'linestyle': (0, (2, 1)),\n",
    "    'zorder': 2,\n",
    "    'alpha': 1.0,\n",
    "    'label': r'$\\mathsf{KL}$',\n",
    "    'color': 'black',\n",
    "  }\n",
    "  kl_fill_style = {\n",
    "    'color': 'black',\n",
    "    'alpha': 0.2,\n",
    "    'zorder': 1,\n",
    "  }\n",
    "  kl_x_style = {\n",
    "    'color': darker_red,\n",
    "    's': 90,\n",
    "    'marker': 'o',\n",
    "    'zorder': 2,\n",
    "    'alpha': 1.0,\n",
    "    'label': r'Model w/ min. $\\mathsf{KL}$' if seed == 1 else None,\n",
    "    'edgecolor': 'black',\n",
    "    'linewidth': 1.5,\n",
    "  }\n",
    "  kl_scatter_style = {\n",
    "    's': 30, \n",
    "    'label': 'Pass@{k_value}'.format(k_value=k_value) if seed == 1 else None,\n",
    "    'color': 'black', \n",
    "    'alpha': 0.4,\n",
    "    'zorder': 1,\n",
    "    'edgecolor': 'none',\n",
    "  }\n",
    "  return {'line': kl_style, 'fill': kl_fill_style, 'x': kl_x_style, 'scatter': kl_scatter_style}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed26fad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "sns.set_context(\"paper\", font_scale=2.0)\n",
    "\n",
    "\n",
    "\n",
    "max_iters = 50000\n",
    "min_scatter_iters = 10000\n",
    "max_scatter_iters = min(max_iters, 70000)\n",
    "\n",
    "\n",
    "##### Filter and set up the dfs\n",
    "\n",
    "fdf = df[df['iteration'] <= max_iters]\n",
    "fdf = fdf.sort_values(['iteration'])\n",
    "ftdf = tdf\n",
    "\n",
    "grouped_df = fdf.groupby('iteration')\n",
    "iterations = [g for g in grouped_df.groups]\n",
    "assert iterations == sorted(iterations)\n",
    "\n",
    "##### Set up the graph \n",
    "\n",
    "num_rows = 1\n",
    "num_cols = 1+len(plot_k_cols)\n",
    "widths = [5 for _ in range(num_cols)]\n",
    "spaces = [2.7] + [1.1 for _ in range(num_cols-1)]\n",
    "width_ratios = [[widths[i], spaces[i]] for i in range(num_cols)]\n",
    "width_ratios = [item for sublist in width_ratios for item in sublist]\n",
    "width_ratios = width_ratios[:-1]\n",
    "\n",
    "gs = gridspec.GridSpec(num_rows, len(width_ratios),\n",
    "                      width_ratios=width_ratios,\n",
    "                      wspace=0.0)\n",
    "fig = plt.figure(figsize=(4.9*num_cols, 5.5*num_rows))\n",
    "axs = [fig.add_subplot(gs[0, 2*i]) for i in range(num_cols)]\n",
    "twin_axs = [axs[0].twinx()] + [axs[i].twiny() for i in range(1, len(plot_k_cols)+1)]\n",
    "handle_dict = {}\n",
    "\n",
    "\n",
    "'''\n",
    "### Subplot 0\n",
    "'''\n",
    "twin_ax = twin_axs[0]\n",
    "ax = axs[0]\n",
    "\n",
    "### processing kl\n",
    "kl_grouped = grouped_df['kl']\n",
    "kl = kl_grouped.median().to_numpy()\n",
    "kl_upper = kl_grouped.quantile(0.9).to_numpy()\n",
    "kl_lower = kl_grouped.quantile(0.1).to_numpy()\n",
    "\n",
    "### plotting kl\n",
    "kl_styles = get_kl_styles(0)\n",
    "twin_ax.plot(iterations, kl, **kl_styles['line'])\n",
    "twin_ax.fill_between(iterations, kl_upper, kl_lower, **kl_styles['fill'])\n",
    "\n",
    "for logM_idx, (plot_logM, plot_logM_col) in enumerate(zip(all_plot_logMs, all_plot_logM_cols)):\n",
    "  cov_styles = get_cov_styles(logM_idx)\n",
    "\n",
    "  cov_grouped = grouped_df[plot_logM_col]\n",
    "  cov = cov_grouped.median().to_numpy()\n",
    "  cov_upper = cov_grouped.quantile(0.9).to_numpy()\n",
    "  cov_lower = cov_grouped.quantile(0.1).to_numpy()\n",
    "\n",
    "  ax.plot(iterations, cov, **cov_styles['line'])\n",
    "  ax.fill_between(iterations, cov_lower, cov_upper, **cov_styles['fill'])\n",
    "\n",
    "\n",
    "'''\n",
    "### Subplots 1+\n",
    "'''\n",
    "seeds_df = fdf.groupby('seed')\n",
    "seeds_df = fdf[fdf['iteration'] <= max_scatter_iters]\n",
    "seeds_df = seeds_df[seeds_df['iteration'] >= min_scatter_iters]\n",
    "seeds_df = seeds_df.sort_values('iteration').groupby('seed')\n",
    "\n",
    "for seed, seed_df in seeds_df:\n",
    "\n",
    "  print(f\"\\n================ Seed {seed} =================\")\n",
    "\n",
    "  #### Extract iterations\n",
    "  iterations = seed_df['iteration'].to_numpy()\n",
    "  seed_max_iters = seed_df['iteration'].max()\n",
    "  assert np.all(iterations == sorted(iterations)), f\"Seed {seed} iterations are not sorted\"\n",
    "\n",
    "  #### Extract kl\n",
    "  kl = seed_df['kl'].to_numpy()\n",
    "  best_kl_idx = int(np.argmin(kl))\n",
    "\n",
    "  ##### Extract pcovs\n",
    "  for idx, (logM, logM_col, k_col) in enumerate(zip(plot_logMs, plot_logM_cols, plot_k_cols)):\n",
    "    print(f\"- Processing logM={logM} with {k_col}\")\n",
    "\n",
    "    ### Set axes\n",
    "    pcov_styles = get_cov_styles(plot_logMs.index(logM), seed)\n",
    "    kl_styles = get_kl_styles(plot_logMs.index(logM), seed)\n",
    "    ax = axs[idx+1]\n",
    "    twin_ax = twin_axs[idx+1]\n",
    "\n",
    "    ### Extract k-values, pcovs\n",
    "    k_values = seed_df[k_col].to_numpy()\n",
    "    logM_values = seed_df[logM_col].to_numpy()\n",
    "    assert len(k_values) == len(logM_values) and len(k_values) == len(iterations)\n",
    "\n",
    "    ### Scatter k-values, pcovs\n",
    "    ax.scatter(logM_values, k_values, **pcov_styles['scatter'])\n",
    "    twin_ax.scatter(kl, k_values, **kl_styles['scatter'])\n",
    "    if seed == 1:\n",
    "      _handles = []\n",
    "      pcov_scatter_style = pcov_styles['scatter'].copy()\n",
    "      pcov_scatter_style['s'] = 80\n",
    "      pcov_scatter_style['alpha'] = 1.0\n",
    "\n",
    "      kl_scatter_style = kl_styles['scatter'].copy()\n",
    "      kl_scatter_style['s'] = 80\n",
    "      kl_scatter_style['alpha'] = 0.9\n",
    "\n",
    "      pcov_x_style = pcov_styles['x'].copy()\n",
    "      pcov_x_style['s'] = 100\n",
    "      kl_x_style = kl_styles['x'].copy()\n",
    "      kl_x_style['s'] = 90\n",
    "\n",
    "      handle = ax.scatter([], [], **pcov_scatter_style)\n",
    "      _handles.append(handle)\n",
    "      handle = ax.scatter([], [], **kl_scatter_style)\n",
    "      _handles.append(handle)\n",
    "      handle = ax.scatter([], [], **pcov_x_style)\n",
    "      _handles.append(handle)\n",
    "      handle = ax.scatter([], [], **kl_x_style)\n",
    "      _handles.append(handle)\n",
    "\n",
    "      handle_dict[k_col] = _handles\n",
    "\n",
    "    tdf_row = ftdf[(ftdf['logM'] == logM) & (ftdf['seed'] == seed)]\n",
    "    assert len(tdf_row) == 1, f\"Seed {seed} logM={logM} has {len(tdf_row)} rows\"\n",
    "\n",
    "    t_iter = tdf_row['best_iter'].values[0]\n",
    "    t_idx = list(iterations).index(t_iter)\n",
    "    t_idx_iter = iterations[t_idx]\n",
    "    assert t_idx_iter == t_iter, f\"Seed {seed} logM={logM} t_idx_iter={t_idx_iter} != t_iter={t_iter}\"\n",
    "    print(f\"\\t- Tournament iter: {t_iter}\")\n",
    "\n",
    "    real_pcov = logM_values[t_idx]\n",
    "    real_k = k_values[t_idx]\n",
    "    print(f\"\\t- Real pcov: {real_pcov:.3f}; min pcov: {logM_values.min():.3f}\")\n",
    "    print(f\"\\t- Real k: {real_k:.3f}; max k: {k_values.max():.3f}\")\n",
    "\n",
    "    ax.scatter([real_pcov], [real_k], **pcov_styles['x'])\n",
    "    twin_ax.scatter([kl[best_kl_idx]], [k_values[best_kl_idx]], **kl_styles['x'])\n",
    "\n",
    "\n",
    "''' \n",
    "Formatting \n",
    "'''\n",
    "\n",
    "legend_fontsize = 15\n",
    "\n",
    "##### Subplot 0\n",
    "ax = axs[0]\n",
    "twin_ax = twin_axs[0]\n",
    "kl_label = r\"$\\mathsf{KL}$\"\n",
    "kl_label = str(kl_label)\n",
    "pcov_label = r\"$\\mathsf{Cov}_{N}$\"\n",
    "pcov_label = str(pcov_label)\n",
    "\n",
    "ax.set_xlim(0, max_iters+500)\n",
    "ax.set_ylim(bottom=0.0, top=1.0)\n",
    "twin_ax.set_ylim(bottom=0.0, top=2.0)\n",
    "twin_ax.grid(False)\n",
    "\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel(f\"{pcov_label}\")\n",
    "twin_ax.locator_params(axis='y', nbins=6)\n",
    "twin_ax.set_ylabel(f\"{kl_label}\")\n",
    "\n",
    "xticks = axs[0].get_xticks()\n",
    "xticks = [10000*i for i in range(6)]\n",
    "print(f\"xticks: {xticks}\")\n",
    "if len(xticks) > 0 and xticks[-1] > 50000:\n",
    "  xticks = xticks[:-1].tolist() + [50000]\n",
    "axs[0].set_xticks(xticks)\n",
    "axs[0].set_xticklabels([f'{int(x/1000)}k' for x in xticks])\n",
    "\n",
    "lines1, labels1 = ax.get_legend_handles_labels()\n",
    "lines2, labels2 = twin_ax.get_legend_handles_labels()\n",
    "legend = ax.legend(lines2 + lines1, labels2 + labels1, handlelength=0.8, fontsize=legend_fontsize, handletextpad=0.5)\n",
    "for i, handle in enumerate(legend.legend_handles):\n",
    "  linewidth = 3\n",
    "  if 'KL' in handle.get_label():\n",
    "    linewidth = 2.5\n",
    "  handle.set_linewidth(linewidth)\n",
    "\n",
    "\n",
    "##### Subplots 1+\n",
    "for k_idx, k_col in enumerate(plot_k_cols):\n",
    "  ax = axs[1+k_idx]\n",
    "  twin_ax = twin_axs[1+k_idx]\n",
    "  ax.set_zorder(twin_ax.get_zorder() - 1)\n",
    "  ax.patch.set_visible(False)\n",
    "\n",
    "  n_value = np.round(2**plot_logMs[k_idx], 1)\n",
    "  k_value = int(k_col.split('=')[-1])\n",
    "  n_value_str = str(n_value)\n",
    "  cov_label = r\"$\\mathsf{Cov}_{\" + n_value_str + r\"}$\"\n",
    "  cov_label = str(cov_label)\n",
    "  kl_label = r\"$\\mathsf{KL}$\"\n",
    "  kl_label = str(kl_label)\n",
    "\n",
    "  ax.set_xlabel(f\"{cov_label}\")\n",
    "  twin_ax.set_xlabel(f\"{kl_label}\")\n",
    "\n",
    "  if k_idx == 0:\n",
    "    ax.set_ylabel(r\"Pass@$\\,N$\")\n",
    "  ax.set_xlim(left=0)\n",
    "\n",
    "  twin_ax.set_xlim(0, 1.8)\n",
    "  twin_ax.locator_params(axis='x', nbins=5)\n",
    "  twin_ax.grid(False)\n",
    "\n",
    "  tick_increment =0.1 \n",
    "  if k_idx == 1:\n",
    "    tick_increment = 0.04\n",
    "  elif k_idx > 1:\n",
    "    tick_increment = 0.02\n",
    "  tick_total = 1 / tick_increment\n",
    "  y_min, y_max = ax.get_ylim()\n",
    "  if k_idx == len(plot_k_cols) - 1:\n",
    "    y_max = 1\n",
    "    y_min = 0.955\n",
    "  y_max =1\n",
    "  yticks = np.arange(np.ceil(y_min * tick_total) / tick_total, np.floor(y_max * tick_total) / tick_total + tick_increment, tick_increment)\n",
    "  yticks = [t for t in yticks if t <= 1]\n",
    "  if yticks[-1] <= 1:\n",
    "    yticks.append(1.0)\n",
    "  ax.set_yticks(yticks)\n",
    "  if k_idx == len(plot_k_cols) - 1:\n",
    "    ax.set_ylim(bottom=0.941)\n",
    "  elif k_idx == 0:\n",
    "    ax.set_ylim(bottom=0.75)\n",
    "  elif k_idx == 1:\n",
    "    ax.set_ylim(bottom=0.91)\n",
    "\n",
    "  x_nbins = 5 if k_idx == 0 else 3\n",
    "  ax.locator_params(axis='x', nbins=x_nbins)\n",
    "\n",
    "\n",
    "  lines1, labels1 = ax.get_legend_handles_labels()\n",
    "  lines2, labels2 = twin_ax.get_legend_handles_labels()\n",
    "  print(labels1)\n",
    "  _labels = [labels1[0], labels1[3], labels2[1]]\n",
    "  handles = handle_dict[plot_k_cols[k_idx]]\n",
    "  _lines = [(handles[0], handles[1]), handles[2], handles[3]]\n",
    "\n",
    "\n",
    "  ax.legend(\n",
    "    _lines, _labels,\n",
    "    handler_map={tuple: HandlerTuple(ndivide=None, pad=-0.3)},\n",
    "    loc=\"lower left\",\n",
    "    handletextpad=0.5,\n",
    "    fontsize=legend_fontsize,\n",
    "    )\n",
    "\n",
    "  for label in ax.get_legend().get_texts():\n",
    "    if 'Pass' in label.get_text():\n",
    "      label.set_weight('bold')\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd84ac1c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeae8de2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c39012a7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "coverage",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
