{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot styling: serif fonts, seaborn's \"talk\" context and LaTeX text rendering.\n",
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "\n",
    "# Only let warnings (and above) through for the chattier library loggers.\n",
    "for _noisy_logger in (\"lightning.pytorch\", \"torch\", \"pytorch_lightning.accelerators.cuda\"):\n",
    "    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)\n",
    "\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules automatically before each cell is executed.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "# Clear the global Hydra instance first, then initialise it again for this notebook.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "# The config directory is given as the relative path ../conf.\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the `matching_n_models` config with an empty list of overrides.\n",
    "cfg = compose(config_name=\"matching_n_models\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a handle on the full composed config; `cfg` itself is narrowed to its `matching` section.\n",
    "# NOTE(review): `cfg` is rebound here, so re-running this cell without re-running the compose cell would read\n",
    "# `.matching` from the already narrowed config.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `epoch` is not referenced by any other visible cell of this notebook — confirm it is still needed.\n",
    "epoch = 99\n",
    "# Number of 2D points drawn from the Sobol sequence, i.e. number of loss evaluations.\n",
    "num_sampled_points = 500  # 2048\n",
    "# Number of leading test-set examples kept for each evaluation.\n",
    "num_test_samples = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed) -> str:\n",
    "    \"\"\"Return the artifact reference `<entity>/<project>/<model_identifier>_<seed>:v0` for the given seed.\"\"\"\n",
    "    return f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    "\n",
    "\n",
    "# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols, reverse=False)\n",
    "\n",
    "# (a, b), (a, c), (b, c), ...\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc\n",
    "# `symbols` is a set of strings, whose iteration order can change between kernel restarts (string hashing is\n",
    "# randomised per process); sorting the pairs keeps the matching order reproducible.\n",
    "canonical_combinations = sorted((source, target) for (source, target) in all_combinations if source < target)\n",
    "\n",
    "pylogger.info(f\"Matching the following model pairs: {canonical_combinations}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the permutation specification with the builder configured for this architecture.\n",
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "# Sanity check against the first loaded model: the spec must name exactly the entries of its state dict.\n",
    "ref_model = list(models.values())[0]\n",
    "spec_entry_names = set(permutation_spec.layer_and_axes_to_perm.keys())\n",
    "ref_entry_names = set(ref_model.model.state_dict().keys())\n",
    "assert spec_entry_names == ref_entry_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the configured matcher, passing it the permutation specification.\n",
    "matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the matcher on all models. `permutations` is indexed by model symbol further down; `perm_history` is not\n",
    "# read again by any visible cell of this notebook.\n",
    "permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every model to the CPU; the dict is rebuilt from the modules returned by `.to`.\n",
    "models = {symb: model.to(\"cpu\") for symb, model in models.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One flat vector per model, obtained by concatenating all of its parameters.\n",
    "flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "# The test-time transform is applied to both splits.\n",
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers)\n",
    "\n",
    "# Evaluate on the first `num_test_samples` test examples only. The count is clamped to the dataset size so\n",
    "# that a smaller test set does not yield out-of-range indices when the loader is iterated.\n",
    "test_subset = Subset(test_dataset, list(range(min(num_test_samples, len(test_dataset)))))\n",
    "\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample points in the param space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scrambled 2D Sobol sequence, seeded from the config.\n",
    "sobol_sampler = qmc.Sobol(d=2, scramble=True, seed=cfg.seed_index)\n",
    "unit_square_points = sobol_sampler.random(num_sampled_points)\n",
    "\n",
    "# Stretch the unit square to [-0.5, 1.5] x [-0.5, 1.5].\n",
    "eval_points = qmc.scale(unit_square_points, [-0.5, -0.5], [1.5, 1.5])\n",
    "\n",
    "pylogger.info(eval_points[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on deep copies so that the entries of `models` keep their original weights.\n",
    "models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}\n",
    "\n",
    "\n",
    "# Replace the weights of every copy by their permuted version, using the permutation stored for its symbol.\n",
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    unpermuted_params = model.model.state_dict()\n",
    "    permuted_params = apply_permutation_to_statedict(permutation_spec, permutations[symbol], unpermuted_params)\n",
    "    model.model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat parameter vector of every model after its permutation from `permutations` has been applied.\n",
    "flat_models_permuted_to_universe = {\n",
    "    symbol: torch.nn.utils.parameters_to_vector(model.parameters())\n",
    "    for symbol, model in models_permuted_to_universe.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import unfactor_permutations\n",
    "\n",
    "models_permuted_to_ref_model = {symbol: copy.deepcopy(model) for symbol, model in models.items()}\n",
    "\n",
    "ref_model_symb = \"a\"\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "# Only the copy of model \"b\" is permuted here; every other copy keeps the weights it was copied with.\n",
    "model_b_copy = models_permuted_to_ref_model[\"b\"]\n",
    "permuted_params = apply_permutation_to_statedict(\n",
    "    permutation_spec, pairwise_permutations[ref_model_symb][\"b\"], model_b_copy.model.state_dict()\n",
    ")\n",
    "model_b_copy.model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat parameter vectors of the copies above (only the entry for \"b\" has been permuted).\n",
    "flat_models_permuted_to_ref_model = {\n",
    "    symbol: torch.nn.utils.parameters_to_vector(model.parameters())\n",
    "    for symbol, model in models_permuted_to_ref_model.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proj = lambda a, b: torch.dot(a, b) / torch.dot(b, b) * b\n",
    "norm = lambda a: torch.sqrt(torch.dot(a, a))\n",
    "normalize = lambda a: a / norm(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short aliases for the flat parameter vectors used by the cells below.\n",
    "model_a_flat = flat_models[\"a\"]\n",
    "model_b_flat = flat_models[\"b\"]\n",
    "\n",
    "# model b after applying its entry of `permutations`\n",
    "model_b_flat_permuted = flat_models_permuted_to_universe[\"b\"]\n",
    "\n",
    "# model b after applying the pairwise permutation `pairwise_permutations[\"a\"][\"b\"]`\n",
    "model_b_flat_permuted_pairwise = flat_models_permuted_to_ref_model[\"b\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating basis vectors\n",
    "\n",
    "# model_a_flat is the origin\n",
    "# basis1 is the vector from model_a_flat to model_b_flat\n",
    "# 2 basis vectors: one points from theta_a towards theta_b, the other (after orthogonalisation) towards pi(theta_b)\n",
    "\n",
    "\n",
    "def get_basis_vectors(origin_model, model_b_flat, model_b_flat_permuted):\n",
    "    \"\"\"Build an orthonormal basis of the plane through three flat parameter vectors.\n",
    "\n",
    "    Args:\n",
    "        origin_model: flat parameter vector used as the origin of the plane.\n",
    "        model_b_flat: flat parameter vector that defines the first axis.\n",
    "        model_b_flat_permuted: flat parameter vector that defines the second axis after orthogonalisation.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(basis1_normed, basis2_normed, scale)`, where `scale` is the norm of\n",
    "        `model_b_flat - origin_model`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the normalised bases contain non-finite values, which happens when the three vectors\n",
    "            coincide or are collinear (the normalisation then divides by zero).\n",
    "    \"\"\"\n",
    "    basis1 = model_b_flat - origin_model\n",
    "    scale = norm(basis1)\n",
    "    basis1_normed = normalize(basis1)\n",
    "\n",
    "    # vector from the origin to pi(theta_b)\n",
    "    origin_to_pi_b = model_b_flat_permuted - origin_model\n",
    "    # Gram-Schmidt step: drop the component along basis1 so that the two axes are orthogonal\n",
    "    basis2 = origin_to_pi_b - proj(origin_to_pi_b, basis1)\n",
    "    basis2_normed = normalize(basis2)\n",
    "\n",
    "    # Fail loudly instead of propagating NaN/inf weights into every later evaluation.\n",
    "    if not (torch.isfinite(basis1_normed).all() and torch.isfinite(basis2_normed).all()):\n",
    "        raise ValueError(\"The three models do not span a plane: cannot build two orthonormal basis vectors.\")\n",
    "\n",
    "    return basis1_normed, basis2_normed, scale\n",
    "\n",
    "\n",
    "# Plane defined by model a (origin), model b (first axis) and the pairwise-permuted model b (second axis).\n",
    "basis1_normed, basis2_normed, scale = get_basis_vectors(\n",
    "    origin_model=model_a_flat, model_b_flat=model_b_flat, model_b_flat_permuted=model_b_flat_permuted_pairwise\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project2d(theta):\n",
    "    \"\"\"Return the plane coordinates of flat parameter vector `theta` as a NumPy array of two values.\n",
    "\n",
    "    Coordinates are taken relative to `model_a_flat`, along `basis1_normed` and `basis2_normed`, and are\n",
    "    divided by `scale`.\n",
    "    \"\"\"\n",
    "    offset = theta - model_a_flat\n",
    "    coords = torch.stack([torch.dot(offset, basis1_normed), torch.dot(offset, basis2_normed)]) / scale\n",
    "    return coords.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def get_pentagon_vertices(center_x, center_y, radius):\n",
    "    pentagon_vertices = []\n",
    "    for i in range(5):\n",
    "        angle_deg = 72 * i  # 72 degrees between each point\n",
    "        angle_rad = math.radians(angle_deg)  # Convert to radians\n",
    "        x = radius * math.cos(angle_rad) + center_x\n",
    "        y = radius * math.sin(angle_rad) + center_y\n",
    "        pentagon_vertices.append((x, y))\n",
    "\n",
    "    return np.array(pentagon_vertices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def represent_barycentric_coordinates(x):\n",
    "    \"\"\"Express 2D point `x` as an affine combination of the pentagon vertices.\n",
    "\n",
    "    Solves for coefficients `z` such that the weighted sum of the vertices equals `x` and the weights sum\n",
    "    to 1. The system has 3 equations and 5 unknowns, so the solution returned by `np.linalg.lstsq` is used.\n",
    "    Coefficients are not constrained to be non-negative.\n",
    "\n",
    "    Args:\n",
    "        x: array-like whose first axis has length 2 (x and y coordinate).\n",
    "\n",
    "    Returns:\n",
    "        Array of coefficients whose first axis has length 5, one entry per pentagon vertex.\n",
    "    \"\"\"\n",
    "    origins = get_pentagon_vertices(0.5, 0.5, 0.9)\n",
    "    x = np.asarray(x)\n",
    "\n",
    "    # Append the sum-to-one constraint as an extra equation: without it the coefficients reproduce `x`\n",
    "    # but are not barycentric.\n",
    "    A = np.concatenate([origins.transpose(1, 0), np.ones((1, origins.shape[0]))], axis=0)\n",
    "    b = np.concatenate([x, np.ones((1,) + x.shape[1:])], axis=0)\n",
    "\n",
    "    z, _, _, _ = np.linalg.lstsq(A, b, rcond=None)\n",
    "\n",
    "    # Both the position equations and the sum-to-one equation must hold.\n",
    "    assert np.allclose(np.dot(A, z), b)\n",
    "\n",
    "    return z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `x` is rebound in the plotting cell before it is read by any visible cell — looks like scratch code.\n",
    "x = eval_points[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: also solve for the barycentric coordinates of the models in the high-dimensional space.\n",
    "# To enforce that the coefficients sum up to 1, scale the equation encoding that constraint by a large scalar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer used to run the test loop; the progress bar and the model summary are switched off.\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_one(xy, model_flat, basis1_normed, basis2_normed, scale, model, trainer, test_loader):\n",
    "    \"\"\"Evaluate the network whose weights sit at plane coordinates `xy`.\n",
    "\n",
    "    The flat parameters `model_flat + scale * (xy[0] * basis1_normed + xy[1] * basis2_normed)` are converted\n",
    "    with `vector_to_state_dict` and loaded into `model.model`, overwriting its current weights. The return\n",
    "    value is whatever `trainer.test(model, test_loader, verbose=False)` returns.\n",
    "    \"\"\"\n",
    "    displacement = basis1_normed * xy[0] + basis2_normed * xy[1]\n",
    "    new_flat_params = model_flat + scale * displacement\n",
    "\n",
    "    model.model.load_state_dict(vector_to_state_dict(new_flat_params, model.model))\n",
    "\n",
    "    return trainer.test(model, test_loader, verbose=False)\n",
    "\n",
    "\n",
    "# `eval_one` overwrites the weights of the model it receives, so evaluate on a dedicated copy instead of on\n",
    "# one of the models stored in the dictionaries above. Model \"a\" is copied because it is the origin of the plane.\n",
    "# NOTE(review): any state that `vector_to_state_dict` does not overwrite comes from this copy — confirm.\n",
    "eval_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "eval_results = np.array(\n",
    "    [\n",
    "        eval_one(xy, model_a_flat, basis1_normed, basis2_normed, scale, eval_model, trainer, test_loader)\n",
    "        for xy in tqdm(eval_points)\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the test loss out of every evaluation result (first entry of each result, key \"loss/test\").\n",
    "test_losses = np.array([res[0][\"loss/test\"] for res in eval_results])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular grid over the sampled square (np.linspace uses its default number of points per axis).\n",
    "xi = np.linspace(-0.5, 1.5)\n",
    "yi = np.linspace(-0.5, 1.5)\n",
    "grid_x, grid_y = np.meshgrid(xi, yi)\n",
    "\n",
    "# Triangulate the scattered evaluation points and interpolate the losses linearly on the grid.\n",
    "triang = tri.Triangulation(eval_points[:, 0], eval_points[:, 1])\n",
    "# Cap the losses at 0.55 so that the contouring is not completely saturated by wildly large values.\n",
    "capped_losses = np.clip(test_losses, None, 0.55)\n",
    "interpolator = tri.LinearTriInterpolator(triang, capped_losses)\n",
    "\n",
    "zi = interpolator(grid_x, grid_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "# Number of levels shared by the contour lines and the filled contours.\n",
    "num_levels = 13\n",
    "# Thin, semi-transparent grey iso-lines of the interpolated loss.\n",
    "plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors=\"grey\", alpha=0.5)\n",
    "# Alternative colormaps that were tried:\n",
    "# cmap_name = \"RdGy\"\n",
    "# cmap_name = \"RdYlBu\"\n",
    "# cmap_name = \"Spectral\"\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "# cmap_name = \"YlOrBr_r\"\n",
    "# cmap_name = \"RdBu\"\n",
    "\n",
    "# See https://stackoverflow.com/a/18926541/3880977\n",
    "def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):\n",
    "    \"\"\"Return a new colormap built from `n` colours of `cmap` sampled evenly between `minval` and `maxval`.\n",
    "\n",
    "    The new colormap is named `trunc(<cmap.name>,<minval>,<maxval>)`, with both bounds printed to two decimals.\n",
    "    \"\"\"\n",
    "    truncated_name = \"trunc({n},{a:.2f},{b:.2f})\".format(n=cmap.name, a=minval, b=maxval)\n",
    "    sampled_colors = cmap(np.linspace(minval, maxval, n))\n",
    "    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)\n",
    "\n",
    "\n",
    "# NOTE(review): maxval=1.5 lies above the nominal [0, 1] range of a colormap — confirm this is intended\n",
    "# (0.9 is noted as an alternative).\n",
    "cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 1.5)  # 0.9)\n",
    "plt.contourf(xi, yi, zi, levels=num_levels, cmap=cmap, extend=\"both\")\n",
    "\n",
    "# White crosses at the projections of theta_A, theta_B and pi(theta_B), drawn above the contours.\n",
    "for marked_model_flat in (model_a_flat, model_b_flat, model_b_flat_permuted_pairwise):\n",
    "    x, y = project2d(marked_model_flat)\n",
    "    plt.scatter([x], [y], marker=\"x\", color=\"white\", zorder=10)\n",
    "\n",
    "label_bboxes = dict(facecolor=\"tab:grey\", boxstyle=\"round\", edgecolor=\"none\", alpha=0.5)\n",
    "# Text styling shared by the three model labels.\n",
    "label_style = dict(color=\"white\", fontsize=24, bbox=label_bboxes)\n",
    "\n",
    "# Labels for theta_A and theta_B sit at fixed positions below the points (0, 0) and (1, 0).\n",
    "plt.text(-0.075, -0.1, r\"${\\bf \\Theta_A}$\", horizontalalignment=\"right\", verticalalignment=\"top\", **label_style)\n",
    "plt.text(1.075, -0.1, r\"${\\bf \\Theta_B}$\", horizontalalignment=\"left\", verticalalignment=\"top\", **label_style)\n",
    "\n",
    "# The label for pi(theta_B) is placed relative to its projected position.\n",
    "x, y = project2d(model_b_flat_permuted_pairwise)\n",
    "plt.text(\n",
    "    x - 0.075,\n",
    "    y + 0.1,\n",
    "    r\"${\\bf \\pi(\\Theta_B)}$\",\n",
    "    horizontalalignment=\"right\",\n",
    "    verticalalignment=\"bottom\",\n",
    "    **label_style,\n",
    ")\n",
    "\n",
    "# https://github.com/matplotlib/matplotlib/issues/17284#issuecomment-772820638\n",
    "# The curved connector between (1, 0) and the projection of pi(theta_B) is drawn in two passes that share\n",
    "# the same curvature: first the dashed line only, then the arrow head only.\n",
    "connectionstyle = \"arc3,rad=-0.3\"\n",
    "dashed_line_props = dict(\n",
    "    arrowstyle=\"-\",\n",
    "    edgecolor=\"white\",\n",
    "    facecolor=\"none\",\n",
    "    linewidth=5,\n",
    "    linestyle=(0, (5, 3)),\n",
    "    shrinkA=20,\n",
    "    shrinkB=15,\n",
    "    connectionstyle=connectionstyle,\n",
    ")\n",
    "arrow_head_props = dict(\n",
    "    arrowstyle=\"<|-\",\n",
    "    edgecolor=\"none\",\n",
    "    facecolor=\"white\",\n",
    "    mutation_scale=40,\n",
    "    linewidth=0,\n",
    "    shrinkA=12.5,\n",
    "    shrinkB=15,\n",
    "    connectionstyle=connectionstyle,\n",
    ")\n",
    "for curved_props in (dashed_line_props, arrow_head_props):\n",
    "    plt.annotate(\"\", xy=(1, 0), xytext=(x, y), arrowprops=curved_props)\n",
    "\n",
    "# Faint straight segments ending at (0, 0): one starting at the projection of pi(theta_B), one at (1, 0).\n",
    "for segment_start in ((x, y), (1, 0)):\n",
    "    plt.annotate(\n",
    "        \"\",\n",
    "        xy=(0, 0),\n",
    "        xytext=segment_start,\n",
    "        arrowprops=dict(\n",
    "            arrowstyle=\"-\",\n",
    "            edgecolor=\"white\",\n",
    "            alpha=0.5,\n",
    "            facecolor=\"none\",\n",
    "            linewidth=2,\n",
    "            linestyle=\"-\",\n",
    "            shrinkA=10,\n",
    "            shrinkB=10,\n",
    "        ),\n",
    "    )\n",
    "\n",
    "# plt.gca().add_artist(\n",
    "#     AnnotationBbox(\n",
    "#         OffsetImage(\n",
    "#             plt.imread(\n",
    "#                 \"https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/check-mark-button_2705.png\"\n",
    "#             ),\n",
    "#             zoom=0.1,\n",
    "#         ),\n",
    "#         (x / 2, y / 2),\n",
    "#         frameon=False,\n",
    "#     )\n",
    "# )\n",
    "# plt.gca().add_artist(\n",
    "#     AnnotationBbox(\n",
    "#         OffsetImage(\n",
    "#             plt.imread(\n",
    "#                 \"https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/cross-mark_274c.png\"\n",
    "#             ),\n",
    "#             zoom=0.1,\n",
    "#         ),\n",
    "#         (0.5, 0),\n",
    "#         frameon=False,\n",
    "#     )\n",
    "# )\n",
    "\n",
    "# \"Git Re-Basin\" box\n",
    "#   box_x = 0.5 * (arrow_start[0] + arrow_stop[0])\n",
    "#   box_y = 0.5 * (arrow_start[1] + arrow_stop[1])\n",
    "# box_x = 0.5 * (arrow_start[0] + arrow_stop[0]) + 0.325\n",
    "# box_y = 0.5 * (arrow_start[1] + arrow_stop[1]) + 0.2\n",
    "\n",
    "box_x = 0.5\n",
    "box_y = 1.3\n",
    "git_rebasin_text = r\"$C^2M^2$\"\n",
    "# Placement and size shared by both passes below.\n",
    "box_text_style = dict(fontsize=24, horizontalalignment=\"center\", verticalalignment=\"center\")\n",
    "\n",
    "# Pass 1 draws the box only: the text colour is fully transparent.\n",
    "plt.text(\n",
    "    box_x,\n",
    "    box_y,\n",
    "    git_rebasin_text,\n",
    "    color=(0.0, 0.0, 0.0, 0.0),\n",
    "    bbox=dict(boxstyle=\"round\", fc=(1, 1, 1, 1), ec=\"black\", pad=0.4),\n",
    "    **box_text_style,\n",
    ")\n",
    "# Pass 2 draws the text only, in opaque black and slightly lower than the box centre.\n",
    "plt.text(box_x, box_y - 0.0115, git_rebasin_text, color=(0.0, 0.0, 0.0, 1.0), **box_text_style)\n",
    "\n",
    "# plt.colorbar()\n",
    "# Restrict the visible region of the plot.\n",
    "plt.xlim(-0.4, 1.4)\n",
    "plt.ylim(-0.45, 1.3)\n",
    "#   plt.xlim(-0.9, 1.9)\n",
    "#   plt.ylim(-0.9, 1.9)\n",
    "# plt.xticks([])\n",
    "# plt.yticks([])\n",
    "plt.tight_layout()\n",
    "# The figure is written to a relative path at 300 dpi.\n",
    "plt.savefig(\"resnet_cifar_loss_contour.png\", dpi=300)\n",
    "# plt.savefig(\"resnet_cifar_mlp_loss_contour.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
