{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from ccmm.utils.plot import Palette\n",
    "from nn_core.common import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the plotting palette from misc/palette2.json under the project root.\n",
    "palette_file = f\"{PROJECT_ROOT}/misc/palette2.json\"\n",
    "palette = Palette(palette_file)\n",
    "\n",
    "# A bare name on the last line makes the notebook display the object.\n",
    "palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global figure styling: serif fonts, seaborn's \"talk\" context, text rendered through LaTeX.\n",
    "matplotlib.rc(\"font\", family=\"serif\")\n",
    "sns.set_context(\"talk\")\n",
    "matplotlib.rc(\"text\", usetex=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merging using an arbitrary model as the reference point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five train / test accuracy values, held as NumPy arrays for the summary statistics.\n",
    "train_accs = np.array([82.05, 67.28, 82.63, 81.36, 82.02])\n",
    "test_accs = np.array([78.55, 65.32, 78.81, 77.72, 78.61])\n",
    "\n",
    "# Print mean and standard deviation (NumPy default, ddof=0): train first, then test.\n",
    "for split_accs in (train_accs, test_accs):\n",
    "    print(split_accs.mean(), split_accs.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching accuracies for different seeds on Git Re-Basin"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train accuracies keyed by a 2-tuple of integers; every entry is a NumPy array of nine floats.\n",
    "# NOTE(review): the tuples presumably identify the pair of models being matched and the nine\n",
    "# values the nine matching seeds - confirm against the experiment logs.\n",
    "train_accs = {\n",
    "    (1, 2): np.array(\n",
    "        [\n",
    "            0.7619400024414062,\n",
    "            0.7817599773406982,\n",
    "            0.7826399803161621,\n",
    "            0.7992200255393982,\n",
    "            0.7707399725914001,\n",
    "            0.7551400065422058,\n",
    "            0.7847200036048889,\n",
    "            0.7534199953079224,\n",
    "            0.8099200129508972,\n",
    "        ]\n",
    "    ),\n",
    "    (1, 3): np.array(\n",
    "        [\n",
    "            0.6714800000190735,\n",
    "            0.6920400261878967,\n",
    "            0.6909800171852112,\n",
    "            0.6874200105667114,\n",
    "            0.624239981174469,\n",
    "            0.6920199990272522,\n",
    "            0.6646000146865845,\n",
    "            0.7092000246047974,\n",
    "            0.6822999715805054,\n",
    "        ]\n",
    "    ),\n",
    "    (2, 3): np.array(\n",
    "        [\n",
    "            0.7508599758148193,\n",
    "            0.7400599718093872,\n",
    "            0.7481200098991394,\n",
    "            0.7238600254058838,\n",
    "            0.7573999762535095,\n",
    "            0.741919994354248,\n",
    "            0.7041199803352356,\n",
    "            0.7325999736785889,\n",
    "            0.7776399850845337,\n",
    "        ]\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Test accuracies, laid out with the same keys and array length as train_accs.\n",
    "test_accs = {\n",
    "    (1, 2): np.array(\n",
    "        [\n",
    "            0.727400004863739,\n",
    "            0.7450000047683716,\n",
    "            0.7450000047683716,\n",
    "            0.765500009059906,\n",
    "            0.7368999719619751,\n",
    "            0.7258999943733215,\n",
    "            0.7419000267982483,\n",
    "            0.7160000205039978,\n",
    "            0.7760000228881836,\n",
    "        ]\n",
    "    ),\n",
    "    (1, 3): np.array(\n",
    "        [\n",
    "            0.6434999704360962,\n",
    "            0.6574000120162964,\n",
    "            0.6657999753952026,\n",
    "            0.6507999897003174,\n",
    "            0.6031000018119812,\n",
    "            0.6621999740600586,\n",
    "            0.6305000185966492,\n",
    "            0.6744999885559082,\n",
    "            0.6492999792098999,\n",
    "        ]\n",
    "    ),\n",
    "    (2, 3): np.array(\n",
    "        [\n",
    "            0.7014999985694885,\n",
    "            0.7060999870300293,\n",
    "            0.711899995803833,\n",
    "            0.6819000244140625,\n",
    "            0.7226999998092651,\n",
    "            0.7037000060081482,\n",
    "            0.6658999919891357,\n",
    "            0.6990000009536743,\n",
    "            0.7368999719619751,\n",
    "        ]\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key selecting which entry of train_accs / test_accs is formatted into the LaTeX rows.\n",
    "current_tuple = (2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_latex_cells(accs):\n",
    "    \"\"\"Format one table row (without its terminator) for an array of accuracies.\n",
    "\n",
    "    Args:\n",
    "        accs: 1-D NumPy array of accuracy values.\n",
    "\n",
    "    Returns:\n",
    "        The cells joined with \" & \": each accuracy and the mean with two\n",
    "        decimals, then the standard deviation and the max-min range with three.\n",
    "    \"\"\"\n",
    "    cells = [f\"{acc:.2f}\" for acc in accs]\n",
    "    cells.append(f\"{accs.mean():.2f}\")\n",
    "    cells.append(f\"{accs.std():.3f}\")\n",
    "    cells.append(f\"{accs.max() - accs.min():.3f}\")\n",
    "    return \" & \".join(cells)\n",
    "\n",
    "\n",
    "# LaTeX row terminator: a literal double backslash.\n",
    "ROW_END = \"\\\\\\\\\"\n",
    "\n",
    "# Train row first, then the test row on the following output line.\n",
    "latex_row_str = (\n",
    "    format_latex_cells(train_accs[current_tuple]) + ROW_END + \" \\n\"\n",
    "    + format_latex_cells(test_accs[current_tuple]) + ROW_END\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print (rather than a bare expression) so the embedded newline is rendered, not shown as an escape.\n",
    "print(latex_row_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed-wise accuracies copied from the table, one value per seed.\n",
    "train_accuracy = [0.76, 0.78, 0.78, 0.80, 0.77, 0.76, 0.78, 0.75, 0.81]\n",
    "test_accuracy = [0.73, 0.75, 0.75, 0.77, 0.74, 0.73, 0.74, 0.72, 0.78]\n",
    "\n",
    "# Seeds are numbered from 1, one per accuracy value.\n",
    "seeds = list(range(1, len(train_accuracy) + 1))\n",
    "\n",
    "# Frank-Wolfe results are the same for every seed, so repeat one value per seed.\n",
    "fw_train_accuracy = [0.78 for _ in seeds]\n",
    "fw_test_accuracy = [0.75 for _ in seeds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-seed Git Re-Basin accuracy (solid lines with markers) against the constant\n",
    "# Frank-Wolfe accuracy (dashed lines); colour distinguishes train from test.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Git Re-Basin train and test accuracy\n",
    "ax.plot(seeds, train_accuracy, marker=\"o\", linestyle=\"-\", color=palette[\"light red\"], label=\"Git Re-Basin - train\")\n",
    "ax.plot(seeds, test_accuracy, marker=\"o\", linestyle=\"-\", color=palette[\"green\"], label=\"Git Re-Basin - test\")\n",
    "\n",
    "# Frank-Wolfe train and test accuracy\n",
    "ax.plot(seeds, fw_train_accuracy, linestyle=\"--\", color=palette[\"light red\"], label=\"Frank-Wolfe - train\")\n",
    "ax.plot(seeds, fw_test_accuracy, linestyle=\"--\", color=palette[\"green\"], label=\"Frank-Wolfe - test\")\n",
    "\n",
    "# Axis labels, with the legend placed below the axes in two columns\n",
    "ax.set_xlabel(\"Seed\")\n",
    "ax.set_ylabel(\"Accuracy\")\n",
    "ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.15), ncol=2)\n",
    "\n",
    "# savefig does not create missing directories, so make sure the output folder exists first.\n",
    "figure_path = Path(\"figures/git-re-basin-variance.pdf\")\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
