{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import logging\n",
    "\n",
    "from tqdm import tqdm\n",
    "from wandb.sdk.wandb_run import Run\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import numpy as np\n",
    "import plotly.graph_objs as go\n",
    "import plotly.io as pio\n",
    "import matplotlib.pyplot as plt\n",
    "from nn_core.common import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.plot import Palette\n",
    "\n",
    "palette = Palette(f\"{PROJECT_ROOT}/misc/palette2.json\")\n",
    "\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": True,\n",
    "        \"font.family\": \"serif\",\n",
    "    }\n",
    ")\n",
    "\n",
    "pylogger = logging.getLogger(__name__)\n",
    "palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()\n",
    "entity, project = \"ANONYMIZED\", \"cycle-consistent-model-merging\"  # set to your entity and project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_runs(entity, project, positive_tags, negative_tags=None):\n",
    "    filters_pos_tags = {\"$and\": [{\"tags\": {\"$eq\": pos_tag}} for pos_tag in positive_tags]}\n",
    "    filters_neg_tags = {}\n",
    "\n",
    "    print(filters_pos_tags)\n",
    "    filters = {**filters_pos_tags, **filters_neg_tags}\n",
    "    runs = api.runs(entity + \"/\" + project, filters=filters)\n",
    "\n",
    "    print(f\"There are {len(runs)} runs respecting these conditions.\")\n",
    "    return runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tags = [\"width_exp\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs = get_runs(entity, project, positive_tags=tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_key = \"matching/seed_index\"\n",
    "model_pair_key = \"matching/model_seeds\"\n",
    "\n",
    "merger_key = \"matching/merger/_target_\"\n",
    "\n",
    "gitrebasin_classname = \"ccmm.matching.merger.GitRebasinMerger\"\n",
    "frankwolfe_classname = \"ccmm.matching.merger.FrankWolfeSynchronizedMerger\"\n",
    "naive_classname = \"ccmm.matching.merger.DummyMerger\"\n",
    "\n",
    "model_key = \"model/name\"\n",
    "merger_mapping = {\n",
    "    gitrebasin_classname: \"git_rebasin\",\n",
    "    frankwolfe_classname: \"frank_wolfe\",\n",
    "    naive_classname: \"naive\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merger_mapping = {\n",
    "    gitrebasin_classname: \"git_rebasin\",\n",
    "    frankwolfe_classname: \"frank_wolfe\",\n",
    "    naive_classname: \"naive\",\n",
    "}\n",
    "\n",
    "mergers = [\"frank_wolfe\", \"git_rebasin\", \"naive\"]\n",
    "\n",
    "widths = [1, 2, 4, 8, 16]\n",
    "exps = {\n",
    "    merger: {\"repaired\": {width: None for width in widths}, \"untouched\": {width: None for width in widths}}\n",
    "    for merger in mergers\n",
    "}\n",
    "print(exps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for run in tqdm(runs):\n",
    "    run: Run\n",
    "    cfg = run.config\n",
    "\n",
    "    if len(cfg) == 0:\n",
    "        pylogger.warning(\"Runs are still running, skipping\")\n",
    "        continue\n",
    "\n",
    "    num_models = len(cfg[\"matching/model_seeds\"])\n",
    "\n",
    "    hist = run.scan_history()\n",
    "\n",
    "    merger_mapped = merger_mapping[cfg[merger_key]]\n",
    "\n",
    "    if \"merged\" in cfg[\"core/tags\"]:\n",
    "        repaired_key = \"untouched\"\n",
    "    elif \"repaired\" in cfg[\"core/tags\"]:\n",
    "        repaired_key = \"repaired\"\n",
    "    else:\n",
    "        pylogger.warning(\"Run is neither merged nor repaired, skipping\")\n",
    "        continue\n",
    "\n",
    "    train_acc = run.history(keys=[\"acc/train\"])[\"acc/train\"][0]\n",
    "    test_acc = run.history(keys=[\"acc/test\"])[\"acc/test\"][0]\n",
    "\n",
    "    train_loss = run.history(keys=[\"loss/train\"])[\"loss/train\"][0]\n",
    "    test_loss = run.history(keys=[\"loss/test\"])[\"loss/test\"][0]\n",
    "\n",
    "    width = cfg[\"model/widen_factor\"]\n",
    "    exps[merger_mapped][repaired_key][width] = {\n",
    "        \"train_acc\": train_acc,\n",
    "        \"test_acc\": test_acc,\n",
    "        \"train_loss\": train_loss,\n",
    "        \"test_loss\": test_loss,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "records = []\n",
    "\n",
    "for merger_name, merger_repaired_data in exps.items():\n",
    "    for repaired_flag, width_data in merger_repaired_data.items():\n",
    "        for width, metrics in width_data.items():\n",
    "            if metrics:\n",
    "                record = {\n",
    "                    \"merger\": merger_name + \"_\" + repaired_flag,\n",
    "                    \"train_acc\": metrics[\"train_acc\"],\n",
    "                    \"test_acc\": metrics[\"test_acc\"],\n",
    "                    \"train_loss\": metrics[\"train_loss\"],\n",
    "                    \"test_loss\": metrics[\"test_loss\"],\n",
    "                    \"width\": width,\n",
    "                }\n",
    "\n",
    "                records.append(record)\n",
    "\n",
    "df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "import plotly.io as pio\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
    "pretty_metric = {\n",
    "    \"acc\": \"$Accuracy$\",\n",
    "    \"loss\": \"$Loss$\",\n",
    "}\n",
    "\n",
    "color_map = {\n",
    "    \"train\": \"blue\",\n",
    "    \"test\": \"red\",\n",
    "}\n",
    "\n",
    "legend_pos = {\"x\": 0.8, \"y\": 0.9}\n",
    "\n",
    "\n",
    "repaired_symbol = lambda repaired_flag: \"^\\dagger\" if repaired_flag == \"repaired\" else \"\"\n",
    "\n",
    "style_map = {\n",
    "    \"repaired\": {\"dash\": \"dash\", \"symbol\": \"circle\"},\n",
    "    \"untouched\": {\"dash\": \"solid\", \"symbol\": \"square\"},\n",
    "}\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=1, cols=2, subplot_titles=[r\"$\\text{Accuracy}$\", r\"$\\text{Loss}$\"]\n",
    ")  # , horizontal_spacing=spacing, vertical_spacing=spacing)  # Adjust spacing as needed\n",
    "\n",
    "for metric_ind, metric in enumerate([\"acc\", \"loss\"]):\n",
    "    for repaired_flag in [\"repaired\", \"untouched\"]:\n",
    "        df_repaired = df[df[\"merger\"].str.contains(repaired_flag)]\n",
    "\n",
    "        dash_style = style_map[repaired_flag][\"dash\"]\n",
    "\n",
    "        symbol = repaired_symbol(repaired_flag)\n",
    "        fig.add_trace(\n",
    "            go.Scatter(\n",
    "                x=df_repaired[\"width\"],\n",
    "                y=df_repaired[f\"train_{metric}\"],\n",
    "                mode=\"lines+markers\",\n",
    "                name=r\"$\\text{Train}\" + symbol + r\"$\",\n",
    "                line=dict(color=color_map[\"train\"], dash=dash_style, width=1),\n",
    "                showlegend=True if metric_ind == 0 else False,\n",
    "            ),\n",
    "            row=1,\n",
    "            col=metric_ind + 1,\n",
    "        )\n",
    "\n",
    "        fig.add_trace(\n",
    "            go.Scatter(\n",
    "                x=df_repaired[\"width\"],\n",
    "                y=df_repaired[f\"test_{metric}\"],\n",
    "                mode=\"lines+markers\",\n",
    "                name=r\"$\\text{Test}\" + symbol + r\"$\",\n",
    "                line=dict(color=color_map[\"test\"], dash=dash_style, width=1),\n",
    "                showlegend=True if metric_ind == 0 else False,\n",
    "            ),\n",
    "            row=1,\n",
    "            col=metric_ind + 1,\n",
    "        )\n",
    "\n",
    "        fig.update_xaxes(title_text=r\"$\\text{Width}$\", row=1, col=metric_ind + 1)\n",
    "\n",
    "fig.update_layout(\n",
    "    legend=dict(x=legend_pos[\"x\"], y=legend_pos[\"y\"], bgcolor=\"rgba(255,255,255,0.)\"),\n",
    "    width=600,\n",
    "    height=300,\n",
    "    font=dict(size=22, family=\"Times New Roman\"),\n",
    "    margin=dict(l=50, r=50, t=50, b=50),\n",
    ")\n",
    "fig.update_annotations(font_size=25)\n",
    "\n",
    "\n",
    "fig.show()\n",
    "pio.write_image(fig, f\"figures/width_exp.pdf\", format=\"pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Configuration\n",
    "pretty_metric = {\n",
    "    \"acc\": \"Accuracy\",\n",
    "    \"loss\": \"Loss\",\n",
    "}\n",
    "\n",
    "color_map = {\n",
    "    \"train\": palette[\"light red\"],\n",
    "    \"test\": palette[\"green\"],\n",
    "}\n",
    "\n",
    "\n",
    "repaired_symbol = lambda repaired_flag: \"†\" if repaired_flag == \"repaired\" else \"\"\n",
    "\n",
    "style_map = {\n",
    "    \"repaired\": {\"dash\": \"--\", \"symbol\": \"o\"},\n",
    "    \"untouched\": {\"dash\": \"-\", \"symbol\": \"s\"},\n",
    "}\n",
    "\n",
    "# Create subplots\n",
    "fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
    "\n",
    "for metric_ind, metric in enumerate([\"acc\", \"loss\"]):\n",
    "    ax = axes[metric_ind]\n",
    "    for repaired_flag in [\"repaired\", \"untouched\"]:\n",
    "        df_repaired = df[df[\"merger\"].str.contains(repaired_flag)]\n",
    "        dash_style = style_map[repaired_flag][\"dash\"]\n",
    "        marker_style = style_map[repaired_flag][\"symbol\"]\n",
    "        symbol = repaired_symbol(repaired_flag)\n",
    "\n",
    "        ax.plot(\n",
    "            df_repaired[\"width\"],\n",
    "            df_repaired[f\"train_{metric}\"],\n",
    "            marker=marker_style,\n",
    "            linestyle=dash_style,\n",
    "            color=color_map[\"train\"],\n",
    "            label=f\"Train {symbol}\",\n",
    "        )\n",
    "        ax.plot(\n",
    "            df_repaired[\"width\"],\n",
    "            df_repaired[f\"test_{metric}\"],\n",
    "            marker=marker_style,\n",
    "            linestyle=dash_style,\n",
    "            color=color_map[\"test\"],\n",
    "            label=f\"Test {symbol}\",\n",
    "        )\n",
    "\n",
    "    ax.set_title(pretty_metric[metric])\n",
    "    ax.set_xlabel(\"Width\")\n",
    "    ax.set_ylabel(metric.capitalize())\n",
    "\n",
    "    if metric_ind == 1:\n",
    "        ax.legend(loc=\"upper right\", ncol=1)\n",
    "\n",
    "# Adjust layout and show plot\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "\n",
    "plt.savefig(\"figures/width_exp.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
