{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "from ccmm.utils.utils import fuse_batch_norm_into_conv\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure styling: serif fonts, seaborn \"talk\" context and LaTeX text rendering.\n",
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "# NOTE(review): usetex presumably needs a working LaTeX installation -- confirm on the target machine.\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "# Raise the level of the Lightning / torch loggers to WARNING.\n",
    "logging.getLogger(\"lightning.pytorch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"torch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"pytorch_lightning.accelerators.cuda\").setLevel(logging.WARNING)\n",
    "# Logger used by the cells of this notebook.\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "# Clear the global Hydra instance before initializing it again.\n",
    "# NOTE(review): presumably this is what allows the cell to be re-run in the same kernel -- confirm.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the \"matching_n_tasks\" configuration with no overrides.\n",
    "cfg = compose(config_name=\"matching_n_tasks\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the full config as `core_cfg`; from here on `cfg` is only its `matching` section.\n",
    "# NOTE(review): rebinding `cfg` makes this cell non-idempotent -- running it twice without\n",
    "# re-running the compose cell looks up `matching` on the matching section itself.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): neither constant is referenced by another cell of this notebook --\n",
    "# confirm they are still needed.\n",
    "num_test_samples = 5000\n",
    "num_train_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "\n",
    "datamodule: pl.LightningDataModule = hydra.utils.instantiate(core_cfg.nn.data, _recursive_=False)\n",
    "\n",
    "# One loader per task index: later cells use index 0 as the whole-dataset loader and\n",
    "# indices 1.. as the per-task loaders.\n",
    "test_dataloaders = []\n",
    "train_dataloaders = []\n",
    "\n",
    "for task_ind in range(datamodule.num_tasks + 1):\n",
    "    # the same datamodule is reconfigured in place for each task index and set up again\n",
    "    datamodule.task_ind = task_ind\n",
    "    datamodule.transform_func = hydra.utils.instantiate(core_cfg.dataset.transform_func, _recursive_=True)\n",
    "    datamodule.setup()\n",
    "    # test_dataloader() is indexed: only its first loader is kept\n",
    "    test_dataloaders.append(datamodule.test_dataloader()[0])\n",
    "    train_dataloaders.append(datamodule.train_dataloader())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer used by the evaluation cells below; progress bar and model summary are disabled.\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(task: int) -> str:\n",
    "    \"\"\"Return the W&B artifact reference (version v0) of the model for task index `task`.\"\"\"\n",
    "    return (\n",
    "        f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/\"\n",
    "        f\"{core_cfg.model.model_identifier}_T{task}_{cfg.seed_index}:v0\"\n",
    "    )\n",
    "\n",
    "\n",
    "# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}\n",
    "# NOTE(review): the symbols here come from task indices, while `symbols_to_seed` is built\n",
    "# from cfg.model_seeds -- confirm the two always produce the same symbols.\n",
    "models: Dict[str, LightningModule] = {\n",
    "    map_model_seed_to_symbol(task): load_model_from_artifact(run, artifact_path(task))\n",
    "    for task in range(cfg.num_tasks + 1)\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "# \"o\" is the model trained over all the dataset: it is left out of the matching\n",
    "symbols = set(symbols_to_seed) - {\"o\"}\n",
    "\n",
    "sorted_symbols = sorted(symbols)\n",
    "\n",
    "# every pair of symbols: (a, b), (a, c), (b, c), ...\n",
    "# NOTE(review): `all_combinations` is iterated again in a later cell -- this assumes it is a\n",
    "# re-iterable collection rather than a one-shot generator; confirm.\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# keep only the pairs in alphabetical order, i.e. (a, b), (a, c), (b, c) but not (b, a), (c, a)\n",
    "canonical_combinations = [(first, second) for first, second in all_combinations if first < second]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record which (source, target) pairs are handed to the matcher.\n",
    "pylogger.info(f\"Matching the following model pairs: {canonical_combinations}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.sinkhorn_matching import get_perm_dict\n",
    "\n",
    "\n",
    "# Work on a copy of model \"a\" so the loaded model itself is not handed to get_perm_dict.\n",
    "ref_model = copy.deepcopy(models[\"a\"])\n",
    "# NOTE(review): hard-coded input shape, presumably a single 3x32x32 image -- confirm it matches the dataset.\n",
    "dummy_input = torch.randn(1, 3, 32, 32)\n",
    "\n",
    "perm_dict, map_param_index, map_prev_param_index = get_perm_dict(ref_model, dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "# map each layer to the matrix permuting its rows\n",
    "# pprint(map_prev_param_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map each layer to the matrix permuting its columns\n",
    "# pprint(map_param_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the permutation dictionary returned by get_perm_dict.\n",
    "print(perm_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.graph import graph_permutations_to_perm_spec\n",
    "\n",
    "# Build the permutation spec from the outputs of get_perm_dict for the reference model.\n",
    "perm_spec = graph_permutations_to_perm_spec(ref_model, perm_dict, map_param_index, map_prev_param_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ref_model = list(models.values())[0]\n",
    "# assert set(perm_spec.layer_and_axes_to_perm.keys()) == set(ref_model.state_dict().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sorted(set(perm_spec.layer_and_axes_to_perm.keys()).difference(set(ref_model.state_dict().keys())))\n",
    "\n",
    "# sorted(set(ref_model.state_dict().keys()).difference(set(perm_spec.layer_and_axes_to_perm.keys())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.permutation_spec import PermutationSpecBuilder\n",
    "from ccmm.matching.permutation_spec import PermutationSpec\n",
    "\n",
    "# The spec keys carry a 6-character prefix that must be removed so they line up with\n",
    "# `model.model.state_dict()`, which is what the permutations are applied to below.\n",
    "# NOTE(review): the prefix is presumably \"model.\", the attribute name of the wrapped\n",
    "# network inside the Lightning module -- confirm.\n",
    "# Only keys that actually start with the prefix are stripped, so re-running this cell\n",
    "# (where `perm_spec` already holds stripped keys) no longer chops valid names.\n",
    "WRAPPER_PREFIX = \"model.\"\n",
    "layer_and_axes_to_perm = {\n",
    "    (k[len(WRAPPER_PREFIX) :] if k.startswith(WRAPPER_PREFIX) else k): v\n",
    "    for k, v in perm_spec.layer_and_axes_to_perm.items()\n",
    "}\n",
    "perm_spec_builder = PermutationSpecBuilder()\n",
    "perm_spec = perm_spec_builder.permutation_spec_from_axes_to_perm(layer_and_axes_to_perm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the matcher described by the config, handing it the permutation spec built above.\n",
    "matcher = instantiate(cfg.matcher, permutation_spec=perm_spec)\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `permutations` is indexed by symbol and then by permutation name in the cells below.\n",
    "permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every model to the CPU before permuting and flattening it.\n",
    "models = {symb: model.to(\"cpu\") for symb, model in models.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models to universe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def permute_batchnorm(model, perm, perm_dict, map_param_index):\n",
    "\n",
    "    for name, module in model.named_modules():\n",
    "\n",
    "        if \"BatchNorm\" in str(type(module)):\n",
    "\n",
    "            if name + \".weight\" in map_param_index:\n",
    "\n",
    "                if module.running_mean is None and module.running_var is None:\n",
    "                    continue\n",
    "\n",
    "                i = perm_dict[map_param_index[name + \".weight\"]]\n",
    "\n",
    "                index = torch.argmax(perm[i], dim=1) if i is not None else torch.arange(module.running_mean.shape[0])\n",
    "\n",
    "                module.running_mean.copy_(module.running_mean[index, ...])\n",
    "                module.running_var.copy_(module.running_var[index, ...])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_matrix_to_perm_indices\n",
    "\n",
    "# Iterate the symbols in sorted order so that the layout of the dict does not depend on\n",
    "# the iteration order of the `symbols` set.\n",
    "models_permuted_to_universe = {symbol: copy.deepcopy(models[symbol]) for symbol in sorted_symbols}\n",
    "\n",
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    perms_to_universe = {}\n",
    "\n",
    "    for perm_name, perm in permutations[symbol].items():\n",
    "        # Go through the matrix form to transpose the permutation (the transpose of a\n",
    "        # permutation matrix is its inverse), then convert back to indices.\n",
    "        perm_matrix = perm_indices_to_perm_matrix(perm)\n",
    "        perms_to_universe[perm_name] = perm_matrix_to_perm_indices(perm_matrix.T)\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(perm_spec, perms_to_universe, model.model.state_dict())\n",
    "    # `model` is the deep copy stored in models_permuted_to_universe[symbol]\n",
    "    model.model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import unfactor_permutations\n",
    "\n",
    "# fixed symbol -> {permuted symbol -> model}; filled in by the loop below.\n",
    "# `{symbol}` removes the symbol itself; passing the bare string to `difference` would\n",
    "# remove its individual characters instead.\n",
    "models_permuted_pairwise = {\n",
    "    symbol: {other_symb: None for other_symb in set(symbols) - {symbol}} for symbol in symbols\n",
    "}\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "for fixed, permutee in all_combinations:\n",
    "    # Use a dedicated name for the container model. Rebinding the notebook-level\n",
    "    # `ref_model` here left it pointing at the last model stored in\n",
    "    # `models_permuted_pairwise`, which later cells pass around as a scratch model.\n",
    "    permuted_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        perm_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()\n",
    "    )\n",
    "    permuted_model.model.load_state_dict(permuted_params)\n",
    "    models_permuted_pairwise[fixed][permutee] = permuted_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check performance of models before and after permutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loader = train_dataloaders[0]\n",
    "# for symbol, model in models_permuted_to_universe.items():\n",
    "#     trainer.test(models_permuted_to_universe[symbol], loader)\n",
    "#     trainer.test(models[symbol], loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze models as vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flatten models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shorthand for the flattening routine used three times below.\n",
    "to_vector = torch.nn.utils.parameters_to_vector\n",
    "\n",
    "# symbol -> flat parameter vector of the original model\n",
    "flat_models = {symbol: to_vector(model.parameters()) for symbol, model in models.items()}\n",
    "\n",
    "# symbol -> flat parameter vector of the model permuted to the universe\n",
    "flat_models_permuted_to_universe = {\n",
    "    symbol: to_vector(model.parameters()) for symbol, model in models_permuted_to_universe.items()\n",
    "}\n",
    "\n",
    "# fixed symbol -> {permuted symbol -> flat parameter vector}\n",
    "flat_models_permuted_pairwise = {\n",
    "    fixed: {permutee: to_vector(permuted.parameters()) for permutee, permuted in permuted_by_symbol.items()}\n",
    "    for fixed, permuted_by_symbol in models_permuted_pairwise.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpolation curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_interpolation(model_a, model_b, lamb):\n",
    "    return (1 - lamb) * model_a + lamb * model_b\n",
    "\n",
    "\n",
    "def get_interp_curve(lambdas, model_a, model_b, ref_model, test_loader):\n",
    "    \"\"\"Evaluate the test loss and accuracy along the linear path between two flat models.\n",
    "\n",
    "    For every interpolation coefficient, the blended vector is loaded into\n",
    "    `ref_model.model` (which is therefore overwritten) and tested with the\n",
    "    notebook-level `trainer`.\n",
    "\n",
    "    Args:\n",
    "        lambdas: iterable of interpolation coefficients.\n",
    "        model_a: flat parameters of the first endpoint.\n",
    "        model_b: flat parameters of the second endpoint.\n",
    "        ref_model: scratch model receiving the interpolated weights.\n",
    "        test_loader: loader handed to `trainer.test`.\n",
    "\n",
    "    Returns:\n",
    "        Two lists, (losses, accuracies), with one entry per coefficient.\n",
    "    \"\"\"\n",
    "    losses, accuracies = [], []\n",
    "\n",
    "    for lamb in lambdas:\n",
    "        blended = linear_interpolation(model_a=model_a, model_b=model_b, lamb=lamb)\n",
    "        ref_model.model.load_state_dict(vector_to_state_dict(blended, ref_model.model))\n",
    "\n",
    "        # trainer.test returns a list: the metrics of interest are in its first entry\n",
    "        metrics = trainer.test(ref_model, test_loader, verbose=False)[0]\n",
    "        losses.append(metrics[\"loss/test\"])\n",
    "        accuracies.append(metrics[\"acc/test\"])\n",
    "\n",
    "    return losses, accuracies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Average in the universe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Container for the merged weights: a copy of model \"a\" whose parameters are overwritten below.\n",
    "merged_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "# Element-wise mean of the flat parameter vectors of the models permuted to the universe.\n",
    "# (The flattened copy of `merged_model` computed here before was overwritten without being read.)\n",
    "vec = torch.stack(list(flat_models_permuted_to_universe.values())).mean(dim=0)\n",
    "\n",
    "# Only parameters are averaged: the flat vectors were built from `model.parameters()`.\n",
    "torch.nn.utils.vector_to_parameters(vec, merged_model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-repair evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From here on `symbols` is a sorted list, paired in order with the per-task loaders (index 1 onwards).\n",
    "symbols = sorted(symbols)\n",
    "\n",
    "for symbol, loader in zip(symbols, test_dataloaders[1:]):\n",
    "    pylogger.info(f\"Symbol: {symbol}\")\n",
    "\n",
    "    # baseline: the model associated with this symbol\n",
    "    pylogger.info(\"Task specific\")\n",
    "    trainer.test(models[symbol], loader)\n",
    "\n",
    "    # averaged model, before any repair\n",
    "    pylogger.info(\"Merged model\")\n",
    "    trainer.test(merged_model, loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same train loader (index 0) repeated once per symbol.\n",
    "# NOTE(review): presumably repair_model expects one loader per model -- confirm.\n",
    "train_dataloaders_repeated = [train_dataloaders[0]] * len(symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.repair import repair_model\n",
    "\n",
    "\n",
    "# Repair the averaged model using the universe-aligned models and the repeated train loaders.\n",
    "# NOTE(review): what repair_model recomputes is not visible from this notebook -- see ccmm.matching.repair.\n",
    "repaired_model = repair_model(merged_model, models_permuted_to_universe, train_dataloaders_repeated)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation: merged model vs task-specific models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted symbols are paired in order with the per-task loaders (index 1 onwards).\n",
    "symbols = sorted(symbols)\n",
    "\n",
    "for symbol, loader in zip(symbols, test_dataloaders[1:]):\n",
    "    pylogger.info(f\"Symbol: {symbol}\")\n",
    "\n",
    "    # baseline: the model associated with this symbol\n",
    "    pylogger.info(\"Task specific\")\n",
    "    trainer.test(models[symbol], loader)\n",
    "\n",
    "    # averaged model after the repair step\n",
    "    pylogger.info(\"Merged model\")\n",
    "    trainer.test(repaired_model, loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation: merged model vs task-specific models on the whole dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index 0 holds the test loader built with task_ind == 0, used here as the whole-dataset loader.\n",
    "global_loader = test_dataloaders[0]\n",
    "\n",
    "# Every loaded model is evaluated, not only those whose symbol took part in the matching.\n",
    "for symbol, model in models.items():\n",
    "    pylogger.info(f\"Symbol: {symbol}\")\n",
    "    trainer.test(model, global_loader)\n",
    "\n",
    "# The model evaluated under this label is the repaired one.\n",
    "pylogger.info(\"Merged model\")\n",
    "trainer.test(repaired_model, global_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation: merged model vs task-specific models on tasks different from the training tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for symbol, model in models.items():\n",
    "#     pylogger.info(f'Symbol: {symbol}')\n",
    "#     trainer.test(model, global_loader)\n",
    "\n",
    "# pylogger.info('Merged model')\n",
    "# trainer.test(repaired_model, global_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot linear mode connectivity (LMC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_lmc(values, lambdas, labels, axis=None):\n",
    "\n",
    "    num_curves = len(values)\n",
    "    transparencies = np.linspace(0.5, 1, num_curves)\n",
    "    linewidths = np.linspace(2.0, 4.0, num_curves)\n",
    "\n",
    "    for i, (val, label) in enumerate(zip(values, labels)):\n",
    "        if axis is None:\n",
    "            axis = plt\n",
    "\n",
    "        axis.plot(lambdas, val, label=label, alpha=transparencies[i], linewidth=linewidths[i])\n",
    "\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model \"b\" permuted to align with model \"a\".\n",
    "p_perm_to_a = models_permuted_pairwise[\"a\"][\"b\"]\n",
    "# Coefficients of the interpolation: both endpoints and the midpoint.\n",
    "lambdas = np.linspace(0, 1, 3)\n",
    "\n",
    "from ccmm.utils.utils import get_interpolated_loss_acc_curves\n",
    "\n",
    "loss, acc = get_interpolated_loss_acc_curves(\n",
    "    model_a=copy.deepcopy(models[\"a\"]),\n",
    "    model_b=p_perm_to_a,\n",
    "    lambdas=lambdas,\n",
    "    # Fresh scratch model: the notebook-level `ref_model` is rebound by the pairwise\n",
    "    # permutation cell to one of the models stored in `models_permuted_pairwise`, so it\n",
    "    # may be the very object passed as `model_b`.\n",
    "    ref_model=copy.deepcopy(models[\"a\"]),\n",
    "    trainer=trainer,\n",
    "    loader=test_dataloaders[0],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss along the interpolation path computed in the previous cell.\n",
    "values = [loss]\n",
    "labels = [\"Loss\"]\n",
    "plot_lmc(values, lambdas, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy along the same interpolation path.\n",
    "values = [acc]\n",
    "labels = [\"Acc\"]\n",
    "plot_lmc(values, lambdas, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aligning to a reference model with git-rebasin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.merger import GitRebasinPairwiseMerger\n",
    "\n",
    "# Copies of the task models handed to the merger, so the originals in `models` are not passed in.\n",
    "task_models = {symbol: copy.deepcopy(models[symbol]) for symbol in sorted_symbols}\n",
    "\n",
    "# reference symbol -> model merged with that symbol as reference.\n",
    "# Starts empty: every entry is assigned in the loop below, so pre-filling the dict with\n",
    "# deep copies of the task models only cost time and memory.\n",
    "ref_to_merged_model = {}\n",
    "\n",
    "for symbol in sorted_symbols:\n",
    "    merger = GitRebasinPairwiseMerger(name=\"git_rebasin\", permutation_spec=perm_spec, ref_model_symbol=symbol)\n",
    "\n",
    "    merged = merger(task_models, repair=False)\n",
    "\n",
    "    # keep a private copy in case the merger reuses the returned object\n",
    "    ref_to_merged_model[symbol] = copy.deepcopy(merged)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "symbols = sorted(symbols)\n",
    "\n",
    "# Sorted symbols are paired in order with the per-task loaders (index 1 onwards).\n",
    "for symbol, loader in zip(sorted_symbols, test_dataloaders[1:]):\n",
    "    pylogger.info(f\"Task: {symbol}\")\n",
    "\n",
    "    # baseline: the copy of the task model that was handed to the merger\n",
    "    pylogger.info(\"Task specific\")\n",
    "    trainer.test(task_models[symbol], loader)\n",
    "\n",
    "    # merged models, one per choice of reference symbol\n",
    "    for ref_symbol in ref_to_merged_model:\n",
    "        pylogger.info(f\"Ref symbol: {ref_symbol}\")\n",
    "        trainer.test(ref_to_merged_model[ref_symbol], loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
