{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "import itertools\n",
    "from ccmm.utils.utils import l2_norm_models\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "from functools import partial\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "from ccmm.utils.utils import fuse_batch_norm_into_conv\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    linear_interpolate,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raise the level of chatty third-party loggers so that only warnings and errors reach the notebook output.\n",
    "for noisy_logger in (\"lightning.pytorch\", \"torch\", \"pytorch_lightning.accelerators.cuda\"):\n",
    "    logging.getLogger(noisy_logger).setLevel(logging.WARNING)\n",
    "\n",
    "# Logger used by the cells of this notebook.\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Typeset all figure text with LaTeX in a serif font (text.usetex needs a LaTeX installation).\n",
    "plt.rcParams.update({\"text.usetex\": True, \"font.family\": \"serif\"})\n",
    "sns.set_context(\"talk\")\n",
    "\n",
    "# Name of the colormap shared by the plots.\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "from ccmm.utils.plot import Palette\n",
    "\n",
    "# Project colour palette, loaded from the repository; shown as the cell output.\n",
    "palette = Palette(f\"{PROJECT_ROOT}/misc/palette2.json\")\n",
    "palette"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "# Clear the global Hydra instance first so that this cell can be run more than once per kernel.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "\n",
    "# Compose the `matching_n_models` config from ../conf without any override.\n",
    "initialize(version_base=None, config_path=\"../conf\", job_name=\"matching_n_models\")\n",
    "cfg = compose(config_name=\"matching_n_models\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `core_cfg` keeps the full composed config; from here on `cfg` is only its `matching` section.\n",
    "# NOTE: this cell is not idempotent. Running it twice in a row looks up `matching` inside the\n",
    "# matching section, so re-run the compose cell before re-running this one.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = core_cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of leading samples kept from each split when the evaluation subsets are built.\n",
    "num_train_samples = 5000\n",
    "num_test_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The transform configured for the test split is applied to both splits,\n",
    "# so train and test samples go through the same preprocessing.\n",
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "# Only the first `num_train_samples` / `num_test_samples` items of each split are used.\n",
    "train_indices = list(range(num_train_samples))\n",
    "test_indices = list(range(num_test_samples))\n",
    "train_subset = Subset(train_dataset, train_indices)\n",
    "test_subset = Subset(test_dataset, test_indices)\n",
    "\n",
    "train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer built from the matching config, with the progress bar and the model summary switched off.\n",
    "trainer = instantiate(\n",
    "    cfg.trainer,\n",
    "    enable_progress_bar=False,\n",
    "    enable_model_summary=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed):\n",
    "    \"\"\"Return the W&B artifact reference (version v0) of the model trained with `seed`.\"\"\"\n",
    "    return (\n",
    "        f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/\"\n",
    "        f\"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    "    )\n",
    "\n",
    "\n",
    "# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convention: the model with the larger symbol is the one that gets permuted, i.e. c -> b, b -> a and so on ...\n",
    "symbols = set(symbols_to_seed)\n",
    "sorted_symbols = sorted(symbols)\n",
    "\n",
    "# Pairs of symbols as produced by the helper.\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# Keep only the pairs in increasing order: (a, b), (a, c), (b, c), .. and never (b, a), (c, a) etc\n",
    "canonical_combinations = [(first, second) for (first, second) in all_combinations if first < second]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record which (source, target) pairs are going to be matched.\n",
    "pylogger.info(f\"Matching the following model pairs: {canonical_combinations}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the permutation specification with the builder configured for this architecture.\n",
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "# Sanity check on the first loaded model: the spec must list exactly the entries of its state dict.\n",
    "ref_model = list(models.values())[0]\n",
    "spec_entries = set(permutation_spec.layer_and_axes_to_perm.keys())\n",
    "state_dict_entries = set(ref_model.model.state_dict().keys())\n",
    "assert spec_entries == state_dict_entries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Git Re-Basin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.matcher import GitRebasinMatcher\n",
    "\n",
    "# Matcher used throughout this section; it operates on the permutation spec built above.\n",
    "matcher = GitRebasinMatcher(\n",
    "    permutation_spec=permutation_spec,\n",
    "    name=\"git_rebasin\",\n",
    ")\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import get_inverse_permutations\n",
    "\n",
    "# Model \"c\" is the fixed reference (the \"universe\") every model is matched against.\n",
    "universe_model = models[\"c\"]\n",
    "\n",
    "# symbol -> permutations into the universe, and symbol -> their inverses (filled in below)\n",
    "map_to_universe = dict.fromkeys(symbols)\n",
    "map_from_universe = dict.fromkeys(symbols)\n",
    "\n",
    "for symb, model in models.items():\n",
    "    # the permutation history returned by the matcher is not used in this section\n",
    "    permutations, perm_history = matcher(fixed=universe_model.model, permutee=model.model)\n",
    "    map_to_universe[symb] = permutations\n",
    "    map_from_universe[symb] = get_inverse_permutations(permutations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A -> B -> C: chain the pairwise maps, going through the universe at every hop\n",
    "\n",
    "perm = partial(apply_permutation_to_statedict, permutation_spec)\n",
    "\n",
    "# A -> B: apply A's map into the universe, then the inverse of B's map\n",
    "params_A = models[\"a\"].model.state_dict()\n",
    "A_to_univ = perm(map_to_universe[\"a\"], params_A)\n",
    "A_to_B = perm(map_from_universe[\"b\"], A_to_univ)\n",
    "\n",
    "# B -> C: apply B's map into the universe, then the inverse of C's map\n",
    "A_B_univ = perm(map_to_universe[\"b\"], A_to_B)\n",
    "A_B_C = perm(map_from_universe[\"c\"], A_B_univ)\n",
    "\n",
    "# Load the chained result into a copy of model A so that it can be evaluated\n",
    "ref_model = copy.deepcopy(models[\"a\"])\n",
    "ref_model.model.load_state_dict(A_B_C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the A -> B -> C model on the test subset.\n",
    "trainer.test(ref_model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PLOT: variance when merging using pairwise maps on different reference basins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the merged model for each reference model (hard-coded results, not computed in this notebook)\n",
    "categories = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n",
    "test_values = [78.55, 65.32, 78.81, 77.72, 78.61]\n",
    "train_values = [82.05, 67.28, 82.63, 81.36, 82.02]\n",
    "# Means are derived from the values above so the dashed lines cannot drift from the bars they summarise\n",
    "mean_test = float(np.mean(test_values))\n",
    "mean_train = float(np.mean(train_values))\n",
    "stddev_test = 5  # not drawn in the figure\n",
    "C2M3_test = 78.90\n",
    "C2M3_train = 82.41\n",
    "\n",
    "# Bar width\n",
    "bar_width = 0.35\n",
    "\n",
    "# Positions of bars on x-axis: one (test, train) pair per reference model, plus one pair for C^2M^3\n",
    "r1 = np.arange(len(categories))\n",
    "r2 = [x + bar_width for x in r1]\n",
    "c2m3_x = len(categories) + 1\n",
    "\n",
    "# Create the plot\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "\n",
    "# Bar plot for the pairwise test / train values\n",
    "bars1 = ax.bar(r1, test_values, color=palette[\"light red\"], width=bar_width, edgecolor=\"grey\", label=\"Pairwise -- test\")\n",
    "bars2 = ax.bar(\n",
    "    r2, train_values, color=palette[\"light orange\"], width=bar_width, edgecolor=\"grey\", label=\"Pairwise -- train\"\n",
    ")\n",
    "\n",
    "# Adding accuracy scores on top of the bars\n",
    "for x_pos, value in zip(r1, test_values):\n",
    "    ax.text(x_pos, value + 0.5, f\"{value}\", ha=\"center\", va=\"bottom\")\n",
    "\n",
    "for x_pos, value in zip(r2, train_values):\n",
    "    ax.text(x_pos, value + 0.5, f\"{value}\", ha=\"center\", va=\"bottom\")\n",
    "\n",
    "# Adding C^2M^3 bars\n",
    "ax.bar(\n",
    "    c2m3_x,\n",
    "    C2M3_test,\n",
    "    color=palette[\"dark blue\"],\n",
    "    width=bar_width,\n",
    "    edgecolor=\"grey\",\n",
    "    label=\"$C^2M^3$ -- test\",\n",
    ")\n",
    "ax.bar(\n",
    "    c2m3_x + bar_width,\n",
    "    C2M3_train,\n",
    "    color=palette[\"green\"],\n",
    "    width=bar_width,\n",
    "    edgecolor=\"grey\",\n",
    "    label=\"$C^2M^3$ -- train\",\n",
    ")\n",
    "\n",
    "# Highlight the pairwise means\n",
    "ax.axhline(y=mean_test, color=palette[\"light red\"], linestyle=\"--\", linewidth=1, label=\"Pairwise mean -- test\")\n",
    "ax.axhline(y=mean_train, color=palette[\"light orange\"], linestyle=\"--\", linewidth=1, label=\"Pairwise mean -- train\")\n",
    "\n",
    "# Adding accuracy score for C^2M^3\n",
    "ax.text(c2m3_x, C2M3_test + 0.5, f\"{C2M3_test}\", ha=\"center\", va=\"bottom\")\n",
    "ax.text(c2m3_x + bar_width, C2M3_train + 0.5, f\"{C2M3_train}\", ha=\"center\", va=\"bottom\")\n",
    "\n",
    "# Set y-axis limits\n",
    "ax.set_ylim(60, 85)\n",
    "\n",
    "# Add labels\n",
    "ax.set_ylabel(\"Accuracy\", fontweight=\"bold\")\n",
    "ax.set_title(\"Accuracy of the merged model for different reference models\", fontweight=\"bold\")\n",
    "ax.set_xticks(list(r1) + [c2m3_x])\n",
    "ax.set_xticklabels(categories + [\"universe\"])\n",
    "\n",
    "# Add legend\n",
    "ax.legend(ncol=3, loc=\"lower center\", bbox_to_anchor=(0.5, -0.25))\n",
    "\n",
    "# Finalise the layout before saving so that the PDF matches the displayed figure,\n",
    "# and make sure the output folder exists (savefig does not create it).\n",
    "plt.tight_layout()\n",
    "Path(\"figures\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(\"figures/pairwise_to_reference_barplot.pdf\", bbox_inches=\"tight\")\n",
    "# Show plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synchronized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.matcher import FrankWolfeSynchronizedMatcher\n",
    "\n",
    "# Synchronized matcher over all models, started from the identity and capped at 200 iterations.\n",
    "fw_matcher = FrankWolfeSynchronizedMatcher(\n",
    "    name=\"frank_wolfe\",\n",
    "    permutation_spec=permutation_spec,\n",
    "    initialization_method=\"identity\",\n",
    "    max_iter=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this rebinds `permutations` and `perm_history`, which the Git Re-Basin loop above also assigned.\n",
    "permutations, perm_history = fw_matcher(models, sorted_symbols, canonical_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import unfactor_permutations\n",
    "\n",
    "# models_permuted_pairwise[fixed][permutee] holds the model of `permutee` after applying the\n",
    "# pairwise permutation stored under pairwise_permutations[fixed][permutee].\n",
    "# The symbol is wrapped in a set before being removed: passing the bare string to set.difference\n",
    "# would remove its individual characters, which is only right for one-character symbols.\n",
    "models_permuted_pairwise = {\n",
    "    symbol: {other_symb: None for other_symb in set(symbols).difference({symbol})} for symbol in symbols\n",
    "}\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "for fixed, permutee in all_combinations:\n",
    "    # the copy of models[\"c\"] only serves as a container: its state dict is replaced just below\n",
    "    ref_model = copy.deepcopy(models[\"c\"])\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()\n",
    "    )\n",
    "\n",
    "    ref_model.model.load_state_dict(permuted_params)\n",
    "    models_permuted_pairwise[fixed][permutee] = ref_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model \"a\" permuted with the synchronized pairwise permutation whose fixed model is \"c\".\n",
    "A_to_C = models_permuted_pairwise[\"c\"][\"a\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the permuted model on the test subset.\n",
    "trainer.test(A_to_C, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
