{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import logging\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "from tqdm import tqdm\n",
    "from wandb.sdk.wandb_run import Run\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import numpy as np\n",
    "import plotly.graph_objs as go\n",
    "import plotly.io as pio\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.plot import Palette\n",
    "\n",
    "# Colors used by the figures below, read from the palette file under PROJECT_ROOT.\n",
    "palette_path = f\"{PROJECT_ROOT}/misc/palette2.json\"\n",
    "palette = Palette(palette_path)\n",
    "\n",
    "# Render figure text through LaTeX, in a serif font.\n",
    "plt.rcParams[\"text.usetex\"] = True\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "\n",
    "pylogger = logging.getLogger(__name__)\n",
    "\n",
    "# Last expression of the cell: display the loaded palette.\n",
    "palette"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle used by get_runs to query Weights & Biases.\n",
    "api = wandb.Api()\n",
    "\n",
    "# set to your entity and project\n",
    "entity = \"ANONYMIZED\"\n",
    "project = \"cycle-consistent-model-merging\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_runs(entity, project, positive_tags, negative_tags=None):\n",
    "    \"\"\"Fetch the W&B runs of ``entity/project`` selected by tag.\n",
    "\n",
    "    Args:\n",
    "        entity: W&B entity (user or team) that owns the project.\n",
    "        project: name of the W&B project.\n",
    "        positive_tags: tags that every returned run must carry.\n",
    "        negative_tags: optional tags; runs carrying any of them are excluded.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``api.runs`` returns for the assembled filter.\n",
    "    \"\"\"\n",
    "    conditions = [{\"tags\": {\"$eq\": pos_tag}} for pos_tag in positive_tags]\n",
    "\n",
    "    if negative_tags:\n",
    "        # This argument used to be accepted but ignored; it now excludes the listed tags.\n",
    "        conditions.append({\"tags\": {\"$nin\": list(negative_tags)}})\n",
    "\n",
    "    # Avoid sending an empty \"$and\" clause when no tag condition was requested.\n",
    "    filters = {\"$and\": conditions} if conditions else {}\n",
    "    print(filters)\n",
    "\n",
    "    runs = api.runs(entity + \"/\" + project, filters=filters)\n",
    "\n",
    "    print(f\"There are {len(runs)} runs respecting these conditions.\")\n",
    "    return runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "considered_model = \"ResNet\"  # ResNet or MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag that selects the runs of the chosen architecture: \"4x\" for ResNet, \"mlp\" otherwise.\n",
    "if considered_model == \"ResNet\":\n",
    "    considered_model_tag = \"4x\"\n",
    "else:\n",
    "    considered_model_tag = \"mlp\"\n",
    "\n",
    "tags = [\"scaling\", considered_model_tag]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs = get_runs(entity, project, positive_tags=tags)  # negative_tags=[\"git_rebasin\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys looked up in each run's config.\n",
    "seed_key = \"matching/seed_index\"\n",
    "model_pair_key = \"matching/model_seeds\"\n",
    "merger_key = \"matching/merger/_target_\"\n",
    "model_key = \"model/name\"\n",
    "\n",
    "# Class paths of the mergers, as stored under merger_key.\n",
    "gitrebasin_classname = \"ccmm.matching.merger.GitRebasinMerger\"\n",
    "frankwolfe_classname = \"ccmm.matching.merger.FrankWolfeSynchronizedMerger\"\n",
    "naive_classname = \"ccmm.matching.merger.DummyMerger\"\n",
    "\n",
    "# Class path -> short merger name used throughout the notebook.\n",
    "merger_mapping = dict(\n",
    "    [\n",
    "        (gitrebasin_classname, \"git_rebasin\"),\n",
    "        (frankwolfe_classname, \"frank_wolfe\"),\n",
    "        (naive_classname, \"naive\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Short names of the mergers that are collected and plotted.\n",
    "mergers = [\"git_rebasin\", \"frank_wolfe\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_num_models = 11\n",
    "\n",
    "# For every merger, one (initially empty) result dict per slot 0..max_num_models.\n",
    "exps = {\n",
    "    merger_name: [dict() for _ in range(max_num_models + 1)]\n",
    "    for merger_name in mergers\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of each collected metric -> W&B history key it is read from.\n",
    "metric_keys = {\n",
    "    \"train_acc\": \"acc/train\",\n",
    "    \"test_acc\": \"acc/test\",\n",
    "    \"train_loss\": \"loss/train\",\n",
    "    \"test_loss\": \"loss/test\",\n",
    "}\n",
    "\n",
    "# Fill exps[merger][num_models] with the metrics of every usable run.\n",
    "for run in tqdm(runs):\n",
    "    run: Run\n",
    "    cfg = run.config\n",
    "\n",
    "    if len(cfg) == 0:\n",
    "        pylogger.warning(\"Runs are still running, skipping\")\n",
    "        continue\n",
    "\n",
    "    # The number of merged models is also the slot of this run inside exps[merger].\n",
    "    num_models = len(cfg[model_pair_key])\n",
    "\n",
    "    # Kept at cell scope on purpose: the plotting cell uses it in the output file name.\n",
    "    model_name = cfg[model_key]\n",
    "\n",
    "    # Unknown merger classes map to None, which is never a key of exps.\n",
    "    merger_mapped = merger_mapping.get(cfg[merger_key])\n",
    "\n",
    "    if merger_mapped not in exps or num_models > max_num_models:\n",
    "        # No slot exists for this run (merger not collected, or too many models):\n",
    "        # skip it instead of failing with a KeyError / IndexError.\n",
    "        pylogger.warning(f\"Skipping run {run.name}: merger {cfg[merger_key]} with {num_models} models\")\n",
    "        continue\n",
    "\n",
    "    # Row 0 of the history frame of each metric.\n",
    "    exps[merger_mapped][num_models] = {\n",
    "        name: run.history(keys=[key])[key][0] for name, key in metric_keys.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten exps into one row per (merger, number of models).\n",
    "#\n",
    "# exps has structure {merger_name: [slot_0, slot_1, ...]}, where slot_n is either an empty dict\n",
    "# (no run with n models was collected) or the metrics of the model merged from n seeds:\n",
    "# {\"train_acc\": ..., \"test_acc\": ..., \"train_loss\": ..., \"test_loss\": ...}\n",
    "records = []\n",
    "\n",
    "for merger_name, merger_data in exps.items():\n",
    "    for num_models, results in enumerate(merger_data):\n",
    "        if len(results) == 0:\n",
    "            continue\n",
    "\n",
    "        records.append(\n",
    "            {\n",
    "                \"merger\": merger_name,\n",
    "                # Keep the model count with the row so it never has to be guessed from row order.\n",
    "                \"num_models\": num_models,\n",
    "                \"train_acc\": results[\"train_acc\"],\n",
    "                \"test_acc\": results[\"test_acc\"],\n",
    "                \"train_loss\": results[\"train_loss\"],\n",
    "                \"test_loss\": results[\"test_loss\"],\n",
    "            }\n",
    "        )\n",
    "\n",
    "\n",
    "df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One frame per merger, holding only the rows of that merger.\n",
    "merger_dfs = {\n",
    "    merger_name: df.loc[df[\"merger\"] == merger_name]\n",
    "    for merger_name in mergers\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index every frame by the number of merged models, which the plot uses as x-axis.\n",
    "# The rows of a merger follow the order of the non-empty slots of exps[merger], and the position of\n",
    "# a slot is its model count; reading the counts from there keeps the x-axis correct even when some\n",
    "# model counts have no collected run (a fixed range starting at 2 would shift the later points).\n",
    "for merger_name, merger_df in merger_dfs.items():\n",
    "    merger_df.index = [num_models for num_models, results in enumerate(exps[merger_name]) if len(results) > 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Metric suffix of the dataframe columns -> panel title.\n",
    "pretty_metric = {\n",
    "    \"acc\": \"Accuracy\",\n",
    "    \"loss\": \"Loss\",\n",
    "}\n",
    "\n",
    "# Line color per merger and split.\n",
    "color_map = {\n",
    "    \"git_rebasin\": {\n",
    "        \"train\": palette[\"light red\"],\n",
    "        \"test\": palette[\"light red\"],\n",
    "    },\n",
    "    \"frank_wolfe\": {\n",
    "        \"train\": palette[\"green\"],\n",
    "        \"test\": palette[\"green\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "# Line style per merger and split: solid for train, dotted for test.\n",
    "dash_map = {\n",
    "    \"git_rebasin\": {\n",
    "        \"train\": \"-\",\n",
    "        \"test\": \":\",\n",
    "    },\n",
    "    \"frank_wolfe\": {\n",
    "        \"train\": \"-\",\n",
    "        \"test\": \":\",\n",
    "    },\n",
    "}\n",
    "\n",
    "# Merger name -> name shown in the legend.\n",
    "merger_map = {\n",
    "    \"git_rebasin\": r\"MergeMany\",\n",
    "    \"frank_wolfe\": r\"$C^2M^3$\",\n",
    "}\n",
    "\n",
    "# One panel per metric: accuracy on the left, loss on the right.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
    "\n",
    "for merger_name, merger_df in merger_dfs.items():\n",
    "    for ax, metric in zip(axes, [\"acc\", \"loss\"]):\n",
    "        # Train curve first, then test, so the legend entries keep that order.\n",
    "        for split in [\"train\", \"test\"]:\n",
    "            ax.plot(\n",
    "                merger_df.index,\n",
    "                merger_df[f\"{split}_{metric}\"],\n",
    "                linestyle=dash_map[merger_name][split],\n",
    "                color=color_map[merger_name][split],\n",
    "                label=f\"{merger_map[merger_name]} ({split})\",\n",
    "            )\n",
    "\n",
    "        ax.set_title(pretty_metric[metric])\n",
    "        ax.set_xlabel(\"Number of models\")\n",
    "        ax.set_ylabel(metric.capitalize())\n",
    "\n",
    "# Both panels show the same curves, so a single figure-level legend is built from the first one.\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc=\"lower center\", bbox_to_anchor=(0.55, -0.2), ncol=2)\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "\n",
    "# savefig does not create missing directories, so make sure the output folder exists.\n",
    "Path(\"figures\").mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# model_name is the value left behind by the run-collection loop.\n",
    "plt.savefig(f\"figures/scaling_exp_{model_name}.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
