{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b653e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os, sys, json, random, re\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "import json, itertools, copy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "REPO_ROOT = os.path.abspath('..')\n",
    "sys.path.append(REPO_ROOT)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib.legend_handler import HandlerTuple\n",
    "from matplotlib.container import ErrorbarContainer \n",
    "from matplotlib.lines import Line2D\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334292c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_df(df, filter_dict):\n",
    "  _df = df.loc[df[list(filter_dict)].eq(pd.Series(filter_dict)).all(axis=1)]\n",
    "  return _df\n",
    "\n",
    "def get_grid(df, cols):\n",
    "  grid = {}\n",
    "  for col in cols:\n",
    "    unique_vals = df[col].unique()\n",
    "    if len(unique_vals) > 1:\n",
    "      grid[col] = unique_vals\n",
    "  return grid\n",
    "\n",
    "def grid_to_param_dicts(grid):\n",
    "  all_params = []\n",
    "  for values in itertools.product(*grid.values()):\n",
    "      param_dict = {key: value for key, value in zip(grid.keys(), values)}\n",
    "      all_params.append(param_dict)\n",
    "  return all_params\n",
    "\n",
    "def format_xticks(ax, axis):\n",
    "  if axis == 'x':\n",
    "    ticks = ax.get_xticks()\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels([f'{int(x/1000)}k' for x in ticks])\n",
    "  elif axis == 'y':\n",
    "    ticks = ax.get_yticks()\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels([f'{int(x/1000)}k' for x in ticks])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f19d515",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dir = os.path.join(REPO_ROOT, 'data')\n",
    "agg_filepath = os.path.join(load_dir, 'fig2_AGG.csv')\n",
    "avg_filepath = os.path.join(load_dir, 'fig2_AVG.csv')\n",
    "\n",
    "\n",
    "df = pd.read_csv(agg_filepath)\n",
    "print(f\"df (agg) shape: {df.shape}\")\n",
    "\n",
    "avg_df = pd.read_csv(avg_filepath)\n",
    "print(f\"avg df shape: {avg_df.shape}\")\n",
    "\n",
    "seeds = sorted([int(s) for s in df['seed'].unique()])\n",
    "print(f\"seeds: {seeds}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd49c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_cols = sorted(df.columns)\n",
    "avg_df_cols = sorted(avg_df.columns)\n",
    "\n",
    "logM_cols = [col for col in df_cols if col.startswith('M=')]\n",
    "k_cols = [col for col in df_cols if col.startswith('k=')]\n",
    "metric_cols = logM_cols + k_cols + ['correct', 'kl']\n",
    "\n",
    "axis_cols = ['iteration', 'eval']\n",
    "param_cols = sorted(list(set(df_cols) - set(metric_cols + axis_cols)))\n",
    "avg_param_cols = sorted(list(set(avg_df_cols) - set(metric_cols + axis_cols)))\n",
    "print(f\"df param_cols: {param_cols}\")\n",
    "print(f\"avg param_cols: {avg_param_cols}\")\n",
    "\n",
    "\n",
    "print(f\"\\n# df cols: {len(df_cols)}\")\n",
    "print(f\"\\t- {len(metric_cols)} metrics\")\n",
    "print(f\"\\t- {len(param_cols)} params\")\n",
    "print(f\"\\t- {len(axis_cols)} axes\")\n",
    "print(f\"total {len(param_cols + metric_cols + axis_cols)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd9899b",
   "metadata": {},
   "outputs": [],
   "source": [
    "majority_edge_shuffle_rules = sorted(df['majority_edge_shuffle_rule'].unique())\n",
    "minority_edge_shuffle_rules = sorted(df['minority_edge_shuffle_rule'].unique())\n",
    "print(f\"majority_edge_shuffle_rules: {majority_edge_shuffle_rules}\")\n",
    "print(f\"minority_edge_shuffle_rules: {minority_edge_shuffle_rules}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99fd49c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_logM = 4.0\n",
    "plot_logM_col = f\"M={plot_logM}\"\n",
    "logM_col = plot_logM_col\n",
    "\n",
    "filter_dict = {\n",
    "  'rule': 'vhard',\n",
    "  'num_chunks': 8,\n",
    "  'ratio': 0.05,\n",
    "  'minority_edge_shuffle_rule': 'bucketby',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eb9b838",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_context(\"paper\", font_scale=2)\n",
    "\n",
    "colors = sns.color_palette('cool', 3)\n",
    "darker_colors = [tuple(c * 0.7 for c in color[:3]) + (color[3],) for color in [mcolors.to_rgba(color) for color in colors]]\n",
    "darker_red = tuple(c * 0.7 for c in (1, 0, 0, 1))\n",
    "\n",
    "\n",
    "fdf = df.sort_values(['iteration'])\n",
    "favg_df = avg_df.sort_values(['iteration'])\n",
    "max_iters = fdf['iteration'].max()\n",
    "layers_list = sorted(list(fdf['layers'].unique()))\n",
    "\n",
    "num_rows = 1\n",
    "num_cols = 3\n",
    "fig, axs = plt.subplots(num_rows, num_cols, figsize=(4.8*num_cols, 4.3*num_rows))\n",
    "kl_axs = axs[0]\n",
    "pcov_axs = axs[1]\n",
    "ratio_axs = axs[2]\n",
    "\n",
    "qlist = ['std', 'sem', 'upper', 'lower', 'median']\n",
    "\n",
    "for layers in layers_list:\n",
    "  layers_df = fdf[fdf['layers'] == layers]\n",
    "  seeds = sorted([int(s) for s in layers_df['seed'].unique()])\n",
    "  iterations = sorted(layers_df['iteration'].unique())\n",
    "\n",
    "  avg_layers_df = favg_df[favg_df['layers'] == layers]\n",
    "\n",
    "  kl_dict = {\n",
    "    'avg' : avg_layers_df['kl'].to_numpy(),\n",
    "    **{f'{q}' : avg_layers_df[f'kl_{q}'].to_numpy() for q in qlist}\n",
    "  }\n",
    "  pcov_dict = {\n",
    "    'avg': avg_layers_df[logM_col].to_numpy(),\n",
    "    **{f'{q}' : avg_layers_df[f'{logM_col}_{q}'].to_numpy() for q in qlist}\n",
    "  }\n",
    "\n",
    "  seed_ratios = []\n",
    "  seed_ratios = np.full((len(seeds), len(pcov_dict['avg'])), np.nan)\n",
    "\n",
    "  for seed_idx, seed in enumerate(seeds):\n",
    "    seed_df = layers_df[layers_df['seed'] == seed]\n",
    "    iterations = seed_df['iteration'].to_numpy()\n",
    "    kl = seed_df['kl'].to_numpy()\n",
    "    value = seed_df[logM_col].to_numpy()\n",
    "    if len(iterations) > 51:\n",
    "      continue\n",
    "    ratio = np.divide(kl, value)\n",
    "    seed_ratios[seed_idx] = ratio\n",
    "\n",
    "  ratio_dict = {\n",
    "    'avg': np.nanmean(seed_ratios, axis=0),\n",
    "    'median': np.nanmedian(seed_ratios, axis=0),\n",
    "    'std': np.nanstd(seed_ratios, axis=0),\n",
    "    'sem': np.nanstd(seed_ratios, axis=0) / np.sqrt(len(seed_ratios)),\n",
    "    'upper': np.nanquantile(seed_ratios, 14/16, axis=0),\n",
    "    'lower': np.nanquantile(seed_ratios, 2/16, axis=0),\n",
    "  }\n",
    "\n",
    "  alpha = 0.2\n",
    "  linewidth=4\n",
    "  upper_fn = lambda x_dict: x_dict['upper'] \n",
    "  lower_fn = lambda x_dict: x_dict['lower'] \n",
    "  avg_key = 'median'\n",
    "  qname = avg_key[:3].upper()\n",
    "\n",
    "  for ax, results_dict in zip([ratio_axs, pcov_axs, kl_axs], [ratio_dict, pcov_dict, kl_dict]):\n",
    "    ax.plot(iterations, results_dict[avg_key], label=f'H={layers}', color=colors[layers_list.index(layers)], linewidth=linewidth)\n",
    "    ax.fill_between(iterations, lower_fn(results_dict), upper_fn(results_dict), color=colors[layers_list.index(layers)], alpha=alpha)\n",
    "\n",
    "kl_axs.set_ylim(0, 20)\n",
    "kl_axs.set_ylabel('KL')\n",
    "\n",
    "pcov_axs.set_ylim(0, 1.05)\n",
    "pcov_axs.set_ylabel(r'$\\mathsf{Cov}_{N=16}$')\n",
    "\n",
    "ratio_axs.set_ylabel(r\"Ratio $\\frac{\\mathsf{KL}}{\\mathsf{Cov}_{N=16}}$\")\n",
    "ratio_axs.set_ylim(bottom=0)\n",
    "\n",
    "for ax in axs:\n",
    "  ax.set_xlim(left=0, right=max_iters)\n",
    "  ax.set_xlabel(\"Iterations\")\n",
    "  format_xticks(ax, 'x')\n",
    "  ax.legend(loc='upper right', handlelength=1, fontsize=15)\n",
    "  ax.locator_params(axis='y', nbins=5)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07c35e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "coverage",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
