{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "from ccmm.utils.utils import fuse_batch_norm_into_conv\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    linear_interpolate_state_dicts,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "logging.getLogger(\"lightning.pytorch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"torch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"pytorch_lightning.accelerators.cuda\").setLevel(logging.WARNING)\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = compose(config_name=\"matching_n_models\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_samples = 5000\n",
    "num_train_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "train_subset = Subset(train_dataset, list(range(num_train_samples)))\n",
    "train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)\n",
    "\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[int, str] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "artifact_path = (\n",
    "    lambda seed: f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    ")\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols, reverse=False)\n",
    "\n",
    "# (a, b), (a, c), (b, c), ...\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc\n",
    "canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pylogger.info(f\"Matching the following model pairs: {canonical_combinations}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "ref_model = list(models.values())[0]\n",
    "assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {symb: model.to(\"cpu\") for symb, model in models.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models to universe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_matrix_to_perm_indices\n",
    "\n",
    "models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}\n",
    "\n",
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    perms_to_universe = {}\n",
    "\n",
    "    for perm_name, perm in permutations[symbol].items():\n",
    "        perm = perm_indices_to_perm_matrix(perm)\n",
    "        perm_to_universe = perm.T\n",
    "        perm_to_universe = perm_matrix_to_perm_indices(perm_to_universe)\n",
    "        perms_to_universe[perm_name] = perm_to_universe\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())\n",
    "    models_permuted_to_universe[symbol].model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import unfactor_permutations\n",
    "\n",
    "models_permuted_pairwise = {\n",
    "    symbol: {other_symb: None for other_symb in set(symbols).difference(symbol)} for symbol in symbols\n",
    "}\n",
    "\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "for fixed, permutee in all_combinations:\n",
    "    ref_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()\n",
    "    )\n",
    "    ref_model.model.load_state_dict(permuted_params)\n",
    "\n",
    "    models_permuted_pairwise[fixed][permutee] = ref_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check performance of models before and after permutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    trainer.test(models_permuted_to_universe[symbol], test_loader)\n",
    "    trainer.test(models[symbol], test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze models as vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flatten models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "other_symbs = {symbol: set(symbols).difference(symbol) for symbol in symbols}\n",
    "print(other_symbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}\n",
    "flat_models_permuted_to_universe = {\n",
    "    symbol: torch.nn.utils.parameters_to_vector(model.parameters())\n",
    "    for symbol, model in models_permuted_to_universe.items()\n",
    "}\n",
    "\n",
    "flat_models_permuted_pairwise = {\n",
    "    symbol: {\n",
    "        other_symb: torch.nn.utils.parameters_to_vector(models_permuted_pairwise[symbol][other_symb].parameters())\n",
    "        for other_symb in other_symbs[symbol]\n",
    "    }\n",
    "    for symbol in symbols\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for symbol in symbols:\n",
    "    flat_models_permuted_pairwise[symbol][symbol] = flat_models[symbol]\n",
    "    models_permuted_pairwise[symbol][symbol] = models[symbol]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyzing variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = next(iter(train_loader))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_layerwise_vars(model_a, model_b, ref_model):\n",
    "    vvs = []\n",
    "\n",
    "    for lamb in [0, 0.5, 1]:\n",
    "        interp_params = linear_interpolate_state_dicts(t1=model_a.state_dict(), t2=model_b.state_dict(), lam=lamb)\n",
    "        ref_model.load_state_dict(interp_params)\n",
    "\n",
    "        vv = []\n",
    "\n",
    "        for i in [1, 4, 7, 9, 12, 14, 17, 19]:\n",
    "            subnet = ref_model.model.embedder[:i]\n",
    "\n",
    "            with torch.no_grad():\n",
    "                out = subnet(inputs)\n",
    "\n",
    "            out = out.permute(1, 0, 2, 3).reshape(out.size(1), -1)\n",
    "            avg_var = out.var(1).mean()\n",
    "            vv.append(avg_var.item())\n",
    "\n",
    "        vvs.append(np.array(vv))\n",
    "\n",
    "    # lists of layerwise variances for endpoint A, midpoint 0.5, endpoint B\n",
    "    vv0, vva, vv1 = vvs\n",
    "    return vv0, vva, vv1\n",
    "\n",
    "\n",
    "def get_layerwise_ratios(model_a, model_b):\n",
    "    \"\"\"\n",
    "    Returns a list of ratios between the variance of the weight-interpolation midpoint and the averaged variances of the two endpoints.\n",
    "    \"\"\"\n",
    "\n",
    "    ref_model = copy.deepcopy(model_a)\n",
    "\n",
    "    vv0, vva, vv1 = get_layerwise_vars(model_a, model_b, ref_model)\n",
    "\n",
    "    vv00 = (vv0 + vv1) / 2\n",
    "\n",
    "    ratio = vva / vv00\n",
    "\n",
    "    return ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layerwise_ratios_naive = get_layerwise_ratios(models[\"a\"], models[\"c\"])\n",
    "# layerwise_ratios_permuted = get_layerwise_ratios(models_permuted_pairwise[\"a\"][\"c\"], models[\"a\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.plot(layerwise_ratios_naive, label=\"Without neuron matching\")\n",
    "# plt.plot(layerwise_ratios_permuted, label=\"With neuron matching\")\n",
    "\n",
    "# plt.ylim([0, 1])\n",
    "\n",
    "# plt.xlabel(\"Layer index\")\n",
    "# plt.ylabel(r\"$ \\frac{ \\sigma_{0.5} } { ( \\sigma_0 + \\sigma_1 ) / 2}$\")\n",
    "# plt.title(\"VGG11 layerwise variance ratios\")\n",
    "\n",
    "# plt.legend()\n",
    "\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.repair import (\n",
    "    replace_conv_layers,\n",
    "    make_tracked_net,\n",
    "    reset_bn_stats,\n",
    "    ResetConv,\n",
    "    compute_goal_statistics_two_models,\n",
    "    compute_goal_statistics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wrap networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: use 'b' instead when it's available\n",
    "\n",
    "model_a = copy.deepcopy(models[\"a\"])\n",
    "model_b = copy.deepcopy(models[\"c\"])\n",
    "model_b_perm = copy.deepcopy(models_permuted_pairwise[\"a\"][\"c\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## calculate the statistics of every hidden unit in the endpoint networks\n",
    "\n",
    "model_a_wrapped = make_tracked_net(model_a).cuda()\n",
    "model_b_perm_wrapped = make_tracked_net(model_b_perm).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the results are still the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "check = False\n",
    "\n",
    "if check:\n",
    "    trainer.test(model_a, test_loader)\n",
    "    trainer.test(model_a_wrapped, test_loader)\n",
    "\n",
    "    trainer.test(model_b_perm, test_loader)\n",
    "    trainer.test(model_b_perm_wrapped, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reset batch norm stats "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_bn_stats(model_a_wrapped.cuda())\n",
    "reset_bn_stats(model_b_perm_wrapped.cuda())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the results are still the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "check = False\n",
    "\n",
    "if check:\n",
    "    trainer.test(model_a, test_loader)\n",
    "    trainer.test(model_a_wrapped, test_loader)\n",
    "\n",
    "    trainer.test(model_b_perm, test_loader)\n",
    "    trainer.test(model_b_perm_wrapped, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create interpolated model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interp_model = copy.deepcopy(model_a)\n",
    "interp_model.load_state_dict(\n",
    "    linear_interpolate_state_dicts(t1=model_a.state_dict(), t2=model_b_perm.state_dict(), lam=0.5)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_interp_wrapped = make_tracked_net(interp_model).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_goal_statistics_two_models(model_a_wrapped, model_interp_wrapped, model_b_perm_wrapped)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fuse batch norm layers into convolutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset the tracked mean/var and fuse rescalings back into conv layers\n",
    "reset_bn_stats(model_interp_wrapped.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_model = copy.deepcopy(model_a)\n",
    "\n",
    "# fuse the rescaling+shift coefficients back into conv layers\n",
    "fused_interp = fuse_tracked_net(model_interp_wrapped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate model_a\n",
    "model_a_wrapped.eval()\n",
    "trainer.test(model_a_wrapped, test_loader)\n",
    "\n",
    "# evaluate model_b_perm\n",
    "model_b_perm_wrapped.eval()\n",
    "trainer.test(model_b_perm_wrapped, test_loader)\n",
    "\n",
    "# evaluate fused_interp\n",
    "fused_interp.eval()\n",
    "repaired_results = trainer.test(model_interp_wrapped, test_loader)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_interpolated_model(lambd, model_a, model_b, ref_model):\n",
    "    interp_params = linear_interpolate_state_dicts(\n",
    "        t1=model_a.model.state_dict(), t2=model_b.model.state_dict(), lam=lambd\n",
    "    )\n",
    "\n",
    "    ref_model.model.load_state_dict(interp_params)\n",
    "\n",
    "    test_results = trainer.test(ref_model, test_loader, verbose=False)[0]\n",
    "\n",
    "    return test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_model = copy.deepcopy(models[\"a\"])\n",
    "lambd = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "naive_interp_results = evaluate_interpolated_model(lambd, model_a=models[\"a\"], model_b=models[\"c\"], ref_model=ref_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perm_interp_results = evaluate_interpolated_model(\n",
    "    lambd, model_a=models[\"a\"], model_b=models_permuted_pairwise[\"a\"][\"c\"], ref_model=ref_model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pylogger.info(\n",
    "    f\"naive: {naive_interp_results['loss/test']}, matched: {perm_interp_results['loss/test']}, repaired: {repaired_results['loss/test']}\"\n",
    ")\n",
    "pylogger.info(\n",
    "    f\"naive: {naive_interp_results['acc/test']}, matched: {perm_interp_results['acc/test']}, repaired: {repaired_results['acc/test']}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repair over N models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "mean_model_params = torch.stack([model for model in flat_models_permuted_to_universe.values()]).mean(dim=0)\n",
    "\n",
    "merged_model.load_state_dict(vector_to_state_dict(mean_model_params, merged_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_model_wrapped = make_tracked_net(merged_model).cuda()\n",
    "\n",
    "merged_model_wrapped.eval()\n",
    "trainer.test(merged_model_wrapped, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wrapped_models = [make_tracked_net(models_permuted_to_universe[symbol]).cuda() for symbol in symbols]\n",
    "\n",
    "for model in wrapped_models:\n",
    "    reset_bn_stats(model.cuda())\n",
    "\n",
    "compute_goal_statistics(merged_model_wrapped, wrapped_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reset_bn_stats(merged_model_wrapped.cuda())\n",
    "merged_model_wrapped.eval()\n",
    "trainer.test(merged_model_wrapped, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import average_models\n",
    "\n",
    "\n",
    "for ref_symbol in symbols:\n",
    "    merged_model = copy.deepcopy(models[ref_symbol])\n",
    "\n",
    "    all_models_permuted_to_ref = {symb: models_permuted_pairwise[ref_symbol][symb] for symb in symbols}\n",
    "\n",
    "    model_params = {symbol: model.state_dict() for symbol, model in all_models_permuted_to_ref.items()}\n",
    "\n",
    "    mean_params = average_models(model_params)\n",
    "\n",
    "    merged_model.load_state_dict(mean_params)\n",
    "\n",
    "    results = trainer.test(merged_model, test_loader, verbose=True)\n",
    "\n",
    "    merged_model_wrapped = make_tracked_net(merged_model).cuda()\n",
    "\n",
    "    wrapped_models = [make_tracked_net(models[symbol]).cuda() for symbol in symbols]\n",
    "\n",
    "    for model in wrapped_models:\n",
    "        reset_bn_stats(model.cuda())\n",
    "\n",
    "    compute_goal_statistics(merged_model_wrapped, wrapped_models)\n",
    "    reset_bn_stats(merged_model_wrapped.cuda())\n",
    "\n",
    "    results = trainer.test(merged_model_wrapped, test_loader, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Git-rebasin merge many"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.merger import GitRebasinMerger\n",
    "\n",
    "git_rebasin_merger = GitRebasinMerger(name=\"git_rebasin_merger\", permutation_spec=permutation_spec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "merged_model = git_rebasin_merger(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_model_wrapped = make_tracked_net(merged_model).cuda()\n",
    "wrapped_models = [make_tracked_net(models[symbol]).cuda() for symbol in symbols]\n",
    "\n",
    "for model in wrapped_models:\n",
    "    reset_bn_stats(model.cuda())\n",
    "\n",
    "compute_goal_statistics(merged_model_wrapped, wrapped_models)\n",
    "reset_bn_stats(merged_model_wrapped.cuda())\n",
    "\n",
    "results = trainer.test(merged_model_wrapped, test_loader, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Git-rebasin merge many wrt reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import PermutationSpec, weight_matching\n",
    "\n",
    "\n",
    "def merge_wrt_model(ref_model_id, models):\n",
    "    model_params = [copy.deepcopy(model.model.state_dict()) for model in models.values()]\n",
    "\n",
    "    num_models = len(model_params)\n",
    "    ref_model_params = model_params[ref_model_id]\n",
    "\n",
    "    other_model_ids = [i for i in range(num_models) if i != ref_model_id]\n",
    "    permutations = []\n",
    "\n",
    "    for other_model_id in other_model_ids:\n",
    "        other_model_params = copy.deepcopy(model_params[other_model_id])\n",
    "\n",
    "        permutation = weight_matching(\n",
    "            permutation_spec,\n",
    "            fixed=ref_model_params,\n",
    "            permutee=other_model_params,\n",
    "        )\n",
    "\n",
    "        permutations.append(permutation)\n",
    "\n",
    "        other_model_params = apply_permutation_to_statedict(permutation_spec, permutation, other_model_params)\n",
    "\n",
    "        model_params[other_model_id] = other_model_params\n",
    "\n",
    "    mean_params = average_models(model_params)\n",
    "    merged_model = copy.deepcopy(models[list(models.keys())[0]])\n",
    "    merged_model.model.load_state_dict(mean_params)\n",
    "\n",
    "    return merged_model, permutations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "all_permutations = []\n",
    "all_merged_models = []\n",
    "\n",
    "for symbol_ind, symbol in enumerate(symbols):\n",
    "    merged_model, permutations = merge_wrt_model(ref_model_id=symbol_ind, models=copy.deepcopy(models))\n",
    "\n",
    "    all_permutations.append(permutations)\n",
    "    all_merged_models.append(merged_model)\n",
    "\n",
    "    results = trainer.test(merged_model, test_loader, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_results = []\n",
    "for symbol_ind, symbol in enumerate(symbols):\n",
    "    merged_model = all_merged_models[symbol_ind]\n",
    "    permutations = all_permutations[symbol_ind]\n",
    "\n",
    "    results = trainer.test(merged_model, test_loader, verbose=True)[0]\n",
    "    all_results.append(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_acc = np.mean([result[\"acc/test\"] for result in all_results])\n",
    "std_acc = np.std([result[\"acc/test\"] for result in all_results])\n",
    "pylogger.info(f\"${round(mean_acc, 4)} \\pm {round(std_acc, 4)}$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_loss = np.mean([result[\"loss/test\"] for result in all_results])\n",
    "std_loss = np.std([result[\"loss/test\"] for result in all_results])\n",
    "pylogger.info(f\"${round(mean_loss, 4)} \\pm {round(std_loss, 4)}$\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
