{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import logging\n",
    "\n",
    "# Module-level logger; the run-collection cell reports skipped runs through it.\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objs as go\n",
    "from tqdm import tqdm\n",
    "from wandb.sdk.wandb_run import Run\n",
    "\n",
    "# W&B public API client used by `get_runs`.\n",
    "api = wandb.Api()\n",
    "entity, project = \"ANONYMIZED\", \"cycle-consistent-model-merging\"  # set to your entity and project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_runs(entity, project, positive_tags, negative_tags=None):\n",
    "    \"\"\"Fetch the runs of the W&B project `entity/project` selected by tags.\n",
    "\n",
    "    Args:\n",
    "        entity: W&B entity that owns the project.\n",
    "        project: W&B project name.\n",
    "        positive_tags: tags that every returned run must carry.\n",
    "        negative_tags: optional tags that no returned run may carry.\n",
    "\n",
    "    Returns:\n",
    "        The result of the module-level `api.runs` query.\n",
    "    \"\"\"\n",
    "    # Collect every tag condition into ONE \"$and\" list: merging two dicts that\n",
    "    # both use the \"$and\" key would silently keep only the last one.\n",
    "    conditions = [{\"tags\": {\"$eq\": pos_tag}} for pos_tag in positive_tags]\n",
    "    if negative_tags:\n",
    "        conditions.append({\"tags\": {\"$nin\": list(negative_tags)}})\n",
    "    # MongoDB-style queries reject an empty \"$and\" array, so send no filter instead.\n",
    "    filters = {\"$and\": conditions} if conditions else {}\n",
    "\n",
    "    print(filters)\n",
    "    runs = api.runs(f\"{entity}/{project}\", filters=filters)\n",
    "\n",
    "    print(f\"There are {len(runs)} runs respecting these conditions.\")\n",
    "    return runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# W&B tags that every selected run must carry (all of them, see `get_runs`).\n",
    "tags = [\"merge_n_models\", \"resnet\", \"emnist\"]  # presumably other selectable tags: 2x, 4x, 8x, cifar100, vgg -- TODO confirm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'$and': [{'tags': {'$eq': 'merge_n_models'}}, {'tags': {'$eq': 'resnet'}}, {'tags': {'$eq': 'emnist'}}]}\n",
      "There are 6 runs respecting these conditions.\n"
     ]
    }
   ],
   "source": [
    "# Query W&B for the runs carrying all of `tags`.\n",
    "runs = get_runs(entity=entity, project=project, positive_tags=tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short merger names; they must match the values of `merger_mapping` (next cells).\n",
    "mergers = [\"frank_wolfe\", \"git_rebasin\", \"naive\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'frank_wolfe': {'repaired': {}, 'untouched': {}}, 'git_rebasin': {'repaired': {}, 'untouched': {}}, 'naive': {'repaired': {}, 'untouched': {}}}\n"
     ]
    }
   ],
   "source": [
    "# One empty metrics slot per merger, for both the untouched and the repaired merged model.\n",
    "exps = {name: dict(repaired={}, untouched={}) for name in mergers}\n",
    "print(exps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys into the W&B run config.\n",
    "seed_key, model_pair_key = \"matching/seed_index\", \"matching/model_seeds\"\n",
    "merger_key = \"matching/merger/_target_\"\n",
    "model_key = \"model/name\"\n",
    "\n",
    "# Merger class paths, as stored under `merger_key`, and the short names used throughout the notebook.\n",
    "gitrebasin_classname = \"ccmm.matching.merger.GitRebasinMerger\"\n",
    "frankwolfe_classname = \"ccmm.matching.merger.FrankWolfeSynchronizedMerger\"\n",
    "naive_classname = \"ccmm.matching.merger.DummyMerger\"\n",
    "\n",
    "merger_mapping = dict(\n",
    "    [\n",
    "        (gitrebasin_classname, \"git_rebasin\"),\n",
    "        (frankwolfe_classname, \"frank_wolfe\"),\n",
    "        (naive_classname, \"naive\"),\n",
    "    ]\n",
    ")"
   ]
  },
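  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect runs\n",
    "\n",
    "For each run, read the logged train/test accuracy and loss and file them under the run's merger, and under `untouched` (runs tagged `merged`) or `repaired`."
   ]
  },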
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the result slots so that re-running this cell never keeps stale entries.\n",
    "exps = {merger: {\"repaired\": {}, \"untouched\": {}} for merger in mergers}\n",
    "\n",
    "# Metric name in `exps` mapped to the key it is logged under in the W&B run history.\n",
    "metric_keys = {\n",
    "    \"train_acc\": \"acc/train\",\n",
    "    \"test_acc\": \"acc/test\",\n",
    "    \"train_loss\": \"loss/train\",\n",
    "    \"test_loss\": \"loss/test\",\n",
    "}\n",
    "\n",
    "\n",
    "def first_logged_value(run: Run, key: str):\n",
    "    \"\"\"Return the first value logged under `key` in the history of `run`, or None if there is none.\"\"\"\n",
    "    history = run.history(keys=[key])\n",
    "    if history.empty or key not in history.columns:\n",
    "        return None\n",
    "    # Positional access does not depend on the index labels of the returned frame.\n",
    "    return history[key].iloc[0]\n",
    "\n",
    "\n",
    "for run in tqdm(runs):\n",
    "    run: Run\n",
    "    cfg = run.config\n",
    "\n",
    "    if len(cfg) == 0:\n",
    "        pylogger.warning(\"Runs are still running, skipping\")\n",
    "        continue\n",
    "\n",
    "    if \"merged\" in cfg[\"core/tags\"]:\n",
    "        repaired_key = \"untouched\"\n",
    "    elif \"repaired\" in cfg[\"core/tags\"]:\n",
    "        repaired_key = \"repaired\"\n",
    "    else:\n",
    "        pylogger.warning(\"Run is neither merged nor repaired, skipping\")\n",
    "        continue\n",
    "\n",
    "    merger_mapped = merger_mapping.get(cfg[merger_key])\n",
    "    if merger_mapped not in exps:\n",
    "        pylogger.warning(\"Run %s uses unknown merger %s, skipping\", run.id, cfg[merger_key])\n",
    "        continue\n",
    "\n",
    "    metrics = {name: first_logged_value(run, key) for name, key in metric_keys.items()}\n",
    "    missing = [name for name, value in metrics.items() if value is None]\n",
    "    if missing:\n",
    "        pylogger.warning(\"Run %s has no logged values for %s, skipping\", run.id, missing)\n",
    "        continue\n",
    "\n",
    "    # Each (merger, repaired) slot holds one run; warn instead of silently overwriting.\n",
    "    if exps[merger_mapped][repaired_key]:\n",
    "        pylogger.warning(\"Several %s runs for merger %s, keeping run %s\", repaired_key, merger_mapped, run.id)\n",
    "\n",
    "    exps[merger_mapped][repaired_key] = metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'frank_wolfe': {'repaired': {'train_acc': 0.6081490516662598,\n",
       "   'test_acc': 0.6052404046058655,\n",
       "   'train_loss': 1.3268992900848389,\n",
       "   'test_loss': 1.3464902639389038},\n",
       "  'untouched': {'train_acc': 0.2715063989162445,\n",
       "   'test_acc': 0.2695673108100891,\n",
       "   'train_loss': 3.434561014175415,\n",
       "   'test_loss': 3.477875232696533}},\n",
       " 'git_rebasin': {'repaired': {'train_acc': 0.038942307233810425,\n",
       "   'test_acc': 0.03802884742617607,\n",
       "   'train_loss': 4.740589141845703,\n",
       "   'test_loss': 4.726057529449463},\n",
       "  'untouched': {'train_acc': 0.035817306488752365,\n",
       "   'test_acc': 0.03500000014901161,\n",
       "   'train_loss': 7.170246601104736,\n",
       "   'test_loss': 7.188019275665283}},\n",
       " 'naive': {'repaired': {'train_acc': 0.03745993599295616,\n",
       "   'test_acc': 0.03725961595773697,\n",
       "   'train_loss': 3.7505240440368652,\n",
       "   'test_loss': 3.7475671768188477},\n",
       "  'untouched': {'train_acc': 0.0395432710647583,\n",
       "   'test_acc': 0.04019230604171753,\n",
       "   'train_loss': 4.042140960693359,\n",
       "   'test_loss': 4.039624214172363}}}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metrics per merger, split into the untouched and the repaired merged model.\n",
    "exps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten `exps` into one record per (merger, repaired flag) pair that has metrics.\n",
    "metric_columns = (\"train_acc\", \"test_acc\", \"train_loss\", \"test_loss\")\n",
    "\n",
    "records = [\n",
    "    {\"merger\": f\"{merger_name}_{repaired_flag}\", **{column: metrics[column] for column in metric_columns}}\n",
    "    for merger_name, merger_repaired_data in exps.items()\n",
    "    for repaired_flag, metrics in merger_repaired_data.items()\n",
    "    if metrics\n",
    "]\n",
    "\n",
    "df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>merger</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>test_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>frank_wolfe_repaired</td>\n",
       "      <td>0.608149</td>\n",
       "      <td>0.605240</td>\n",
       "      <td>1.326899</td>\n",
       "      <td>1.346490</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>frank_wolfe_untouched</td>\n",
       "      <td>0.271506</td>\n",
       "      <td>0.269567</td>\n",
       "      <td>3.434561</td>\n",
       "      <td>3.477875</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>git_rebasin_repaired</td>\n",
       "      <td>0.038942</td>\n",
       "      <td>0.038029</td>\n",
       "      <td>4.740589</td>\n",
       "      <td>4.726058</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>git_rebasin_untouched</td>\n",
       "      <td>0.035817</td>\n",
       "      <td>0.035000</td>\n",
       "      <td>7.170247</td>\n",
       "      <td>7.188019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>naive_repaired</td>\n",
       "      <td>0.037460</td>\n",
       "      <td>0.037260</td>\n",
       "      <td>3.750524</td>\n",
       "      <td>3.747567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>naive_untouched</td>\n",
       "      <td>0.039543</td>\n",
       "      <td>0.040192</td>\n",
       "      <td>4.042141</td>\n",
       "      <td>4.039624</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  merger  train_acc  test_acc  train_loss  test_loss\n",
       "0   frank_wolfe_repaired   0.608149  0.605240    1.326899   1.346490\n",
       "1  frank_wolfe_untouched   0.271506  0.269567    3.434561   3.477875\n",
       "2   git_rebasin_repaired   0.038942  0.038029    4.740589   4.726058\n",
       "3  git_rebasin_untouched   0.035817  0.035000    7.170247   7.188019\n",
       "4         naive_repaired   0.037460  0.037260    3.750524   3.747567\n",
       "5        naive_untouched   0.039543  0.040192    4.042141   4.039624"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per merger/repaired-flag combination that has metrics.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX display name for every merger/repaired-flag key; $^\\dagger$ marks repaired models.\n",
    "matcher_to_latex_map = {\n",
    "    \"frank_wolfe_repaired\": r\"\\texttt{Frank-Wolfe}$^\\dagger$\",\n",
    "    \"git_rebasin_repaired\": r\"\\texttt{Git-Rebasin}$^\\dagger$\",\n",
    "    \"naive_untouched\": r\"\\texttt{Naive}\",\n",
    "    \"naive_repaired\": r\"\\texttt{Naive}$^\\dagger$\",\n",
    "    \"frank_wolfe_untouched\": r\"\\texttt{Frank-Wolfe}\",\n",
    "    \"git_rebasin_untouched\": r\"\\texttt{Git-Rebasin}\",\n",
    "}\n",
    "\n",
    "# Row order of the table.\n",
    "ordering = [\n",
    "    \"naive_untouched\",\n",
    "    \"naive_repaired\",\n",
    "    \"git_rebasin_untouched\",\n",
    "    \"git_rebasin_repaired\",\n",
    "    \"frank_wolfe_untouched\",\n",
    "    \"frank_wolfe_repaired\",\n",
    "]\n",
    "\n",
    "# Values missing from `ordering` would silently become NaN in the categorical below.\n",
    "unknown_mergers = set(df[\"merger\"]) - set(ordering)\n",
    "assert not unknown_mergers, f\"Mergers missing from `ordering`: {unknown_mergers}\"\n",
    "assert set(ordering).issubset(matcher_to_latex_map), \"Every entry of `ordering` needs a LaTeX name\"\n",
    "\n",
    "# The categorical dtype makes sort_values follow `ordering` instead of alphabetical order;\n",
    "# a new sorted frame is bound to `df` instead of mutating it in place.\n",
    "df = df.assign(merger=pd.Categorical(df[\"merger\"], categories=ordering, ordered=True)).sort_values(by=\"merger\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>merger</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>test_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>naive_untouched</td>\n",
       "      <td>0.039543</td>\n",
       "      <td>0.040192</td>\n",
       "      <td>4.042141</td>\n",
       "      <td>4.039624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>naive_repaired</td>\n",
       "      <td>0.037460</td>\n",
       "      <td>0.037260</td>\n",
       "      <td>3.750524</td>\n",
       "      <td>3.747567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>git_rebasin_untouched</td>\n",
       "      <td>0.035817</td>\n",
       "      <td>0.035000</td>\n",
       "      <td>7.170247</td>\n",
       "      <td>7.188019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>git_rebasin_repaired</td>\n",
       "      <td>0.038942</td>\n",
       "      <td>0.038029</td>\n",
       "      <td>4.740589</td>\n",
       "      <td>4.726058</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>frank_wolfe_untouched</td>\n",
       "      <td>0.271506</td>\n",
       "      <td>0.269567</td>\n",
       "      <td>3.434561</td>\n",
       "      <td>3.477875</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>frank_wolfe_repaired</td>\n",
       "      <td>0.608149</td>\n",
       "      <td>0.605240</td>\n",
       "      <td>1.326899</td>\n",
       "      <td>1.346490</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  merger  train_acc  test_acc  train_loss  test_loss\n",
       "5        naive_untouched   0.039543  0.040192    4.042141   4.039624\n",
       "4         naive_repaired   0.037460  0.037260    3.750524   3.747567\n",
       "3  git_rebasin_untouched   0.035817  0.035000    7.170247   7.188019\n",
       "2   git_rebasin_repaired   0.038942  0.038029    4.740589   4.726058\n",
       "1  frank_wolfe_untouched   0.271506  0.269567    3.434561   3.477875\n",
       "0   frank_wolfe_repaired   0.608149  0.605240    1.326899   1.346490"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same rows, now sorted by `ordering`.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Colormaps handed to `decimal_to_rgb_color` when shading the LaTeX table;\n",
    "# the reversed one is meant for lower-is-better metrics such as the loss.\n",
    "base_color = \"seagreen\"\n",
    "cmap = sns.light_palette(base_color, as_cmap=True)\n",
    "cmap_reverse = sns.light_palette(base_color, as_cmap=True, reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.plot import decimal_to_rgb_color\n",
    "\n",
    "# Shade each cell with `cmap` (accuracy) / `cmap_reverse` (loss); requires \\usepackage[table]{xcolor}.\n",
    "use_cell_colors = False\n",
    "# Losses are divided by this before color mapping; positions above 1 are clipped to 1.\n",
    "max_loss_value = 6.0\n",
    "# Rows intentionally left out of the table.\n",
    "skipped_mergers = {\"naive_repaired\"}\n",
    "\n",
    "table_caption = r\"Train and test accuracy and loss of the merged models for each matcher. $^\\dagger$ marks repaired models.\"\n",
    "table_label = \"tab:\" + \"_\".join(tags)\n",
    "\n",
    "# Five columns: matcher name, train/test accuracy, train/test loss.\n",
    "header = r\"\"\"\n",
    "\\begin{table}\n",
    "    \\begin{center}\n",
    "        \\begin{tabular}{lcccc}\n",
    "        \\toprule\n",
    "        \\textbf{Matcher} & \\multicolumn{2}{c}{\\textbf{Accuracy}} & \\multicolumn{2}{c}{\\textbf{Loss}} \\\\\n",
    "        \\cmidrule(lr){2-3} \\cmidrule(lr){4-5}\n",
    "                         & \\textbf{Train} & \\textbf{Test} & \\textbf{Train} & \\textbf{Test} \\\\\n",
    "        \\midrule\"\"\"\n",
    "\n",
    "\n",
    "def format_cell(value, color_position, colormap):\n",
    "    \"\"\"Format `value` with three decimals, optionally shaded by `colormap` at `color_position`.\"\"\"\n",
    "    if not use_cell_colors:\n",
    "        return f\"{value:.3f}\"\n",
    "    # NOTE(review): assumes decimal_to_rgb_color returns RGB(A) floats in [0, 1] -- confirm.\n",
    "    rgb = decimal_to_rgb_color(min(color_position, 1.0), colormap)[:3]\n",
    "    rgb_spec = \",\".join(f\"{channel:.3f}\" for channel in rgb)\n",
    "    return rf\"\\cellcolor[rgb]{{{rgb_spec}}}{value:.3f}\"\n",
    "\n",
    "\n",
    "body = \"\"\n",
    "for _, row in df.iterrows():\n",
    "    merger = row[\"merger\"]\n",
    "    if merger in skipped_mergers:\n",
    "        continue\n",
    "\n",
    "    cells = [\n",
    "        matcher_to_latex_map[merger],\n",
    "        format_cell(row[\"train_acc\"], row[\"train_acc\"], cmap),\n",
    "        format_cell(row[\"test_acc\"], row[\"test_acc\"], cmap),\n",
    "        format_cell(row[\"train_loss\"], row[\"train_loss\"] / max_loss_value, cmap_reverse),\n",
    "        format_cell(row[\"test_loss\"], row[\"test_loss\"] / max_loss_value, cmap_reverse),\n",
    "    ]\n",
    "    body += \"\\n        \" + \" & \".join(cells) + r\" \\\\\"\n",
    "\n",
    "footer = rf\"\"\"\n",
    "        \\bottomrule\n",
    "        \\end{{tabular}}\n",
    "    \\end{{center}}\n",
    "    \\caption{{{table_caption}}}\n",
    "    \\label{{{table_label}}}\n",
    "\\end{{table}}\"\"\"\n",
    "\n",
    "table = header + body + footer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\\begin{table}\n",
      "    \\begin{center}\n",
      "        \\begin{tabular}{lccc}\n",
      "        \\toprule\n",
      "        \\textbf{Matcher}        & \\multicolumn{2}{c}{\\textbf{Barrier}}                   \\\\\n",
      "                                & \\textbf{Train}                       & \\textbf{Test}   \\\\\n",
      "        \\midrule\n",
      "        \n",
      "                & \\texttt{Naive} &  0.040 & 0.040 & 4.042 & 4.040 \\\\\n",
      "                & \\texttt{Git-Rebasin} &  0.036 & 0.035 & 7.170 & 7.188 \\\\\n",
      "                & \\texttt{Git-Rebasin}$^\\dagger$ &  0.039 & 0.038 & 4.741 & 4.726 \\\\\n",
      "                & \\texttt{Frank-Wolfe} &  0.272 & 0.270 & 3.435 & 3.478 \\\\\n",
      "                & \\texttt{Frank-Wolfe}$^\\dagger$ &  0.608 & 0.605 & 1.327 & 1.346 \\\\\n",
      "        \\bottomrule\n",
      "        \\end{tabular}\n",
      "    \\end{center}\n",
      "    \\caption{Mean and standard deviation of the test and train loss barrier for each matcher.}\n",
      "    \\label{tab:MLP_loss_barrier}\n",
      "\\end{table}\n"
     ]
    }
   ],
   "source": [
    "# print() shows the raw LaTeX source (no string-repr escaping), ready to copy into the paper.\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
