{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure styling: serif fonts, seaborn's \"talk\" context, and LaTeX text rendering.\n",
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "# These third-party loggers only emit warnings and above.\n",
    "for noisy_logger_name in (\"lightning.pytorch\", \"torch\", \"pytorch_lightning.accelerators.cuda\"):\n",
    "    logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)\n",
    "\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import compose, initialize\n",
    "from typing import Dict, List\n",
    "\n",
    "# Reset the global Hydra instance, then initialize it against the ../conf directory.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=\"../conf\", job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the `matching_n_models` config without any overrides.\n",
    "cfg = compose(\n",
    "    config_name=\"matching_n_models\",\n",
    "    overrides=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the full config as `core_cfg`; from here on `cfg` is only its `matching` section.\n",
    "# NOTE(review): `cfg` is rebound here, so re-running this cell on its own would look up\n",
    "# `matching` on the already-narrowed config -- re-run the compose cell first.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = core_cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of leading samples kept from the train and test splits.\n",
    "num_train_samples = 5000\n",
    "num_test_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The transform configured for the test split is applied to both splits.\n",
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "# Keep only the first `num_train_samples` / `num_test_samples` items of each split.\n",
    "train_subset = Subset(train_dataset, list(range(num_train_samples)))\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "train_loader = DataLoader(train_subset, batch_size=2000, num_workers=cfg.num_workers)\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the trainer from the matching config with the progress bar and model summary switched off.\n",
    "trainer = instantiate(\n",
    "    cfg.trainer,\n",
    "    enable_progress_bar=False,\n",
    "    enable_model_summary=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# {a: 1, b: 2, c: 3, ..}, i.e. model symbol -> seed\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed) -> str:\n",
    "    \"\"\"Return the W&B artifact reference (version v0) for the model identified by `seed`.\"\"\"\n",
    "    return (\n",
    "        f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/\"\n",
    "        f\"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    "    )\n",
    "\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols)\n",
    "\n",
    "# (a, b), (a, c), (b, c), ...\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc\n",
    "# `symbols` is a set of strings, whose iteration order can differ between interpreter runs,\n",
    "# so the pairs are sorted to keep the matching order reproducible.\n",
    "canonical_combinations = sorted((source, target) for (source, target) in all_combinations if source < target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record which (fixed, permutee) pairs are about to be matched.\n",
    "pylogger.info(\n",
    "    f\"Matching the following model pairs: {canonical_combinations}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "# Sanity check: the spec must name exactly the entries of the reference model's state dict.\n",
    "ref_model = list(models.values())[0]\n",
    "spec_layer_names = set(permutation_spec.layer_and_axes_to_perm.keys())\n",
    "model_layer_names = set(ref_model.model.state_dict().keys())\n",
    "assert spec_layer_names == model_layer_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the matcher from the config, handing it the permutation specification.\n",
    "matcher = instantiate(\n",
    "    cfg.matcher,\n",
    "    permutation_spec=permutation_spec,\n",
    ")\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the matcher on the canonical pairs; it returns the per-symbol permutations and `perm_history`.\n",
    "permutations, perm_history = matcher(\n",
    "    models,\n",
    "    symbols=sorted_symbols,\n",
    "    combinations=canonical_combinations,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every model to the CPU.\n",
    "models = {symb: models[symb].to(\"cpu\") for symb in models}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models to universe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_matrix_to_perm_indices\n",
    "\n",
    "# Work on deep copies so the original models stay untouched.\n",
    "models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}\n",
    "\n",
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    # Invert each permutation: indices -> matrix -> transpose (the inverse of a permutation matrix) -> indices.\n",
    "    perms_to_universe = {\n",
    "        perm_name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(perm).T)\n",
    "        for perm_name, perm in permutations[symbol].items()\n",
    "    }\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())\n",
    "    model.model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import unfactor_permutations\n",
    "\n",
    "# fixed symbol -> permutee symbol -> private copy of the permutee model.\n",
    "# Each entry starts as an unpermuted copy of models[permutee]; a model is never paired with itself.\n",
    "models_permuted_pairwise = {\n",
    "    fixed: {permutee: copy.deepcopy(models[permutee]) for permutee in symbols if permutee != fixed}\n",
    "    for fixed in symbols\n",
    "}\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "# Only the canonical (fixed < permutee) pairs receive permuted weights;\n",
    "# the remaining entries keep the permutee's original weights.\n",
    "for fixed, permutee in canonical_combinations:\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()\n",
    "    )\n",
    "    models_permuted_pairwise[fixed][permutee].model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check performance of models before and after permutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every symbol, evaluate the universe-permuted copy first and then the original model.\n",
    "for symbol in models_permuted_to_universe:\n",
    "    permuted_model = models_permuted_to_universe[symbol]\n",
    "    original_model = models[symbol]\n",
    "\n",
    "    trainer.test(permuted_model, test_loader)\n",
    "    trainer.test(original_model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze models as vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flatten models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each model is flattened into a single vector of its parameters.\n",
    "flat_models = {\n",
    "    symbol: torch.nn.utils.parameters_to_vector(model.parameters())\n",
    "    for symbol, model in models.items()\n",
    "}\n",
    "flat_models_permuted_to_universe = {\n",
    "    symbol: torch.nn.utils.parameters_to_vector(permuted_model.parameters())\n",
    "    for symbol, permuted_model in models_permuted_to_universe.items()\n",
    "}\n",
    "\n",
    "# Same nesting as `models_permuted_pairwise`: outer symbol -> inner symbol -> flattened model.\n",
    "flat_models_permuted_pairwise = {\n",
    "    fixed_symbol: {\n",
    "        other_symb: torch.nn.utils.parameters_to_vector(permuted_model.parameters())\n",
    "        for other_symb, permuted_model in permuted_models.items()\n",
    "    }\n",
    "    for fixed_symbol, permuted_models in models_permuted_pairwise.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import to_np\n",
    "\n",
    "# Per model: norm of the original vector, norm of the universe-permuted vector, and norm of their difference.\n",
    "norms = {\"model\": [], \"permuted\": [], \"diff\": []}\n",
    "for symbol in models:\n",
    "    original_vector = flat_models[symbol]\n",
    "    universe_vector = flat_models_permuted_to_universe[symbol]\n",
    "\n",
    "    norms[\"model\"].append(to_np(torch.norm(original_vector)))\n",
    "    norms[\"permuted\"].append(to_np(torch.norm(universe_vector)))\n",
    "    norms[\"diff\"].append(to_np(torch.norm(original_vector - universe_vector)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# One row per model symbol, one column per entry of `norms`, coerced to numeric dtypes.\n",
    "df = pd.DataFrame(norms, index=models.keys()).apply(pd.to_numeric)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "sns.heatmap(df, annot=True, cmap=\"viridis\", ax=ax)\n",
    "ax.set_title(\"Model Norms Comparison\")\n",
    "ax.set_ylabel(\"Model Symbol\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute gradient norms of the merged (universe) model and of model a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available, otherwise stay on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model_a_universe = copy.deepcopy(models_permuted_to_universe[\"a\"]).to(device)\n",
    "model_a = copy.deepcopy(models[\"a\"]).to(device)\n",
    "# Starts out as a copy of model a; it only provides the module that the merged weights are loaded into.\n",
    "merged_model = copy.deepcopy(models[\"a\"]).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the universe-permuted parameter vectors, one row per model.\n",
    "torch.stack(list(flat_models_permuted_to_universe.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise mean over the models of the stacked universe-permuted vectors.\n",
    "torch.stack(list(flat_models_permuted_to_universe.values())).mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the universe-permuted parameter vectors and load the result into `merged_model`.\n",
    "stacked_universe_params = torch.stack(list(flat_models_permuted_to_universe.values()))\n",
    "merged_model_params = torch.mean(stacked_universe_params, dim=0)\n",
    "\n",
    "merged_model_universe = vector_to_state_dict(merged_model_params, merged_model)\n",
    "merged_model.load_state_dict(merged_model_universe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flattened parameters of the merged model, followed by those of model a.\n",
    "parameters = torch.nn.utils.parameters_to_vector(merged_model.parameters())\n",
    "model_a_parameters = torch.nn.utils.parameters_to_vector(model_a.parameters())\n",
    "\n",
    "print(parameters)\n",
    "print(model_a_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Feed the batches to whichever device the models live on instead of assuming \"cuda\".\n",
    "grad_device = next(merged_model.parameters()).device\n",
    "\n",
    "merged_model.zero_grad()\n",
    "model_a.zero_grad()\n",
    "\n",
    "# Gradients are not reset inside the loop, so each `.grad` ends up holding the sum of the\n",
    "# per-batch gradients over the whole test loader.\n",
    "for x, y in tqdm(test_loader):\n",
    "    x = x.to(grad_device)\n",
    "    y = y.to(grad_device)\n",
    "\n",
    "    loss_universe = loss_fn(merged_model(x), y)\n",
    "    loss = loss_fn(model_a(x), y)\n",
    "\n",
    "    loss_universe.backward()\n",
    "    loss.backward()\n",
    "\n",
    "\n",
    "def _accumulated_grad_l2_norm(model) -> float:\n",
    "    \"\"\"Return the L2 norm of all accumulated parameter gradients of `model`.\n",
    "\n",
    "    Parameters without a gradient (`grad is None`) are skipped.\n",
    "    \"\"\"\n",
    "    squared_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None)\n",
    "    # The square root turns the summed squared per-parameter norms into an actual norm.\n",
    "    return math.sqrt(squared_norm)\n",
    "\n",
    "\n",
    "gradient_norm_merged = _accumulated_grad_l2_norm(merged_model)\n",
    "gradient_norm = _accumulated_grad_l2_norm(model_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the gradient statistics of the merged model and of model a.\n",
    "pylogger.info(\n",
    "    f\"Gradient norm merged: {gradient_norm_merged}\"\n",
    ")\n",
    "pylogger.info(\n",
    "    f\"Gradient norm: {gradient_norm}\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
