{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import logging\n",
    "import math\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import json\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import pytorch_lightning\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "from tqdm import tqdm\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    load_permutations,\n",
    "    perm_indices_to_perm_matrix,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    fuse_batch_norm_into_conv,\n",
    "    get_interpolated_loss_acc_curves,\n",
    "    l2_norm_models,\n",
    "    linear_interpolate,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    normalize_unit_norm,\n",
    "    project_onto,\n",
    "    save_factored_permutations,\n",
    "    vector_to_state_dict,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.plot import Palette\n",
    "\n",
    "# Render figure text through LaTeX with a serif font (needs a working LaTeX installation).\n",
    "plt.rcParams[\"text.usetex\"] = True\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "# Colour palette read from misc/palette2.json under the project root; shown as the cell output.\n",
    "palette = Palette(f\"{PROJECT_ROOT}/misc/palette2.json\")\n",
    "palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only let warnings (and above) through for the chatty Lightning / torch loggers.\n",
    "for noisy_logger in (\"lightning.pytorch\", \"torch\", \"pytorch_lightning.accelerators.cuda\"):\n",
    "    logging.getLogger(noisy_logger).setLevel(logging.WARNING)\n",
    "\n",
    "# Logger used for the notebook's own messages.\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from typing import Dict, List\n",
    "\n",
    "import hydra\n",
    "from hydra import compose, initialize\n",
    "\n",
    "# Clear any global Hydra state left by a previous run of this cell before initialising again.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=\"../conf\", job_name=\"matching\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the \"matching\" config with the ViT model and CIFAR-10 dataset overrides.\n",
    "cfg = compose(\n",
    "    config_name=\"matching\",\n",
    "    overrides=[\"model=vit\", \"dataset=cifar10\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a handle on the full composed config, then narrow `cfg` to its `matching` section.\n",
    "# NOTE: this rebinds `cfg`; re-running the cell without re-composing the config first\n",
    "# would make `core_cfg` point at the `matching` section instead of the full config.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = core_cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes of the evaluation subsets: the first N samples of each split are used.\n",
    "num_test_samples = 5000\n",
    "num_train_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The test-time transform is instantiated once and applied to both splits.\n",
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "# Restrict each split to its leading `num_*_samples` examples.\n",
    "train_subset = Subset(train_dataset, list(range(num_train_samples)))\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "# Loaders over the two subsets (no shuffling is requested).\n",
    "train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer instantiated from `cfg.trainer`, with the progress bar and model summary switched off.\n",
    "trainer = instantiate(\n",
    "    cfg.trainer,\n",
    "    enable_progress_bar=False,\n",
    "    enable_model_summary=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed) -> str:\n",
    "    \"\"\"Return the W&B artifact reference (`:latest` alias) of the model trained with `seed`.\"\"\"\n",
    "    return f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:latest\"\n",
    "\n",
    "\n",
    "# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()\n",
    "}\n",
    "\n",
    "# Deep copies of the inner models' state dicts, kept so the weights can be restored after permuting.\n",
    "model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.permutation_spec import PermutationSpec, PermutationSpecBuilder\n",
    "\n",
    "\n",
    "class ViTPermutationSpecBuilder(PermutationSpecBuilder):\n",
    "    \"\"\"Permutation specification builder for the ViT architecture matched in this notebook.\n",
    "\n",
    "    Every state-dict key is mapped to a tuple with one entry per tensor axis: the name of\n",
    "    the permutation acting on that axis, or ``None`` when the axis is not permuted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, depth) -> None:\n",
    "        # Number of transformer blocks; forwarded to `transformer_block_axes`.\n",
    "        self.depth = depth\n",
    "\n",
    "    def create_permutation_spec(self) -> PermutationSpec:\n",
    "        \"\"\"Build the spec covering the patch embedding, the transformer blocks and the head.\n",
    "\n",
    "        Returns:\n",
    "            The result of `permutation_spec_from_axes_to_perm` applied to the axes mapping.\n",
    "        \"\"\"\n",
    "        axes_to_perm = {\n",
    "            # layer norm\n",
    "            \"to_patch_embedding.to_patch_tokens.1.weight\": (None,),  # (3*c*16)\n",
    "            \"to_patch_embedding.to_patch_tokens.1.bias\": (None,),  # (3*c*16)\n",
    "            # linear\n",
    "            \"to_patch_embedding.to_patch_tokens.2.weight\": (\"P_in\", None),  # (dim, 3*c*16)\n",
    "            \"to_patch_embedding.to_patch_tokens.2.bias\": (\"P_in\",),  # dim\n",
    "            \"pos_embedding\": (None, None, \"P_in\"),  # (1, p+1, dim) probably P_transf_in or its own P\n",
    "            \"cls_token\": (None, None, \"P_in\"),  # (1, 1, dim) probably P_transf_in or its own P\n",
    "            # transformer blocks, chained from \"P_in\" to \"P_last\"\n",
    "            **transformer_block_axes(self.depth, p_in=\"P_in\", p_out=\"P_last\"),\n",
    "            # layer norm\n",
    "            \"mlp_head.0.weight\": (\"P_last\",),  # (dim, )\n",
    "            \"mlp_head.0.bias\": (\"P_last\",),  # (dim,)\n",
    "            # linear\n",
    "            \"mlp_head.1.bias\": (None,),  # (num_classes)\n",
    "            \"mlp_head.1.weight\": (None, \"P_last\"),  # (num_classes, dim)\n",
    "        }\n",
    "\n",
    "        return self.permutation_spec_from_axes_to_perm(axes_to_perm)\n",
    "\n",
    "    # Backward-compatible alias: the method was previously named `create_permutation`.\n",
    "    create_permutation = create_permutation_spec\n",
    "\n",
    "\n",
    "def transformer_block_axes(depth, p_in, p_out):\n",
    "    \"\"\"Return the axis-to-permutation mapping for all transformer blocks.\n",
    "\n",
    "    Args:\n",
    "        depth: number of transformer blocks.\n",
    "        p_in: name of the permutation acting on the input of the first block.\n",
    "        p_out: name of the permutation acting on the output of the last block.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each block parameter name to a tuple holding, per tensor axis,\n",
    "        the name of the permutation applied to it (``None`` = axis left as is).\n",
    "    \"\"\"\n",
    "    axes = {}\n",
    "\n",
    "    for layer in range(depth):\n",
    "        prefix = f\"transformer.layers.{layer}\"\n",
    "\n",
    "        # Blocks are chained: block i reuses the output permutation of block i - 1 as its\n",
    "        # input permutation, with `p_in` / `p_out` closing the chain at the two ends.\n",
    "        perm_in = f\"P{layer-1}_out\" if layer > 0 else p_in\n",
    "        perm_out = f\"P{layer}_out\" if layer < depth - 1 else p_out\n",
    "        perm_qk = f\"P{layer}_attn_QK\"\n",
    "        perm_mlp = f\"P{layer}_mlp_out\"\n",
    "\n",
    "        # --- Attention (shape notes are the original author's) ---\n",
    "        # layer norm, (dim,)\n",
    "        axes[f\"{prefix}.0.norm.weight\"] = (perm_in,)\n",
    "        axes[f\"{prefix}.0.norm.bias\"] = (perm_in,)\n",
    "        axes[f\"{prefix}.0.temperature\"] = (None,)\n",
    "        # heads, (head_dim, dim): queries and keys share the same row permutation\n",
    "        axes[f\"{prefix}.0.to_q.weight\"] = (perm_qk, perm_in)\n",
    "        axes[f\"{prefix}.0.to_k.weight\"] = (perm_qk, perm_in)\n",
    "        axes[f\"{prefix}.0.to_v.weight\"] = (None, perm_in)\n",
    "        # attention out, (dim, dim): left unpermuted for now\n",
    "        # (alternatives noted by the author: rows \"{block_ind}_attn_out\", cols \"P{block_ind}_attn_V\")\n",
    "        axes[f\"{prefix}.0.to_out.0.weight\"] = (None, None)\n",
    "        axes[f\"{prefix}.0.to_out.0.bias\"] = (None,)\n",
    "        # shortcut, (dim, dim)\n",
    "        # TODO: the rows perm of this one should be the product of block_in and {block_ind}_attn_out?\n",
    "        axes[f\"{prefix}.1.identity\"] = (perm_in, None)\n",
    "\n",
    "        # --- MLP ---\n",
    "        # layer norm, (dim,): left unpermuted for now\n",
    "        axes[f\"{prefix}.2.net.0.weight\"] = (None,)\n",
    "        axes[f\"{prefix}.2.net.0.bias\"] = (None,)\n",
    "        # linear 1, (mlp_dim, dim)\n",
    "        axes[f\"{prefix}.2.net.1.weight\"] = (perm_mlp, None)\n",
    "        axes[f\"{prefix}.2.net.1.bias\"] = (perm_mlp,)\n",
    "        # linear 2, (dim, mlp_dim)\n",
    "        axes[f\"{prefix}.2.net.4.weight\"] = (perm_out, perm_mlp)\n",
    "        axes[f\"{prefix}.2.net.4.bias\"] = (perm_out,)\n",
    "        # shortcut 2, (dim, dim)\n",
    "        axes[f\"{prefix}.3.identity\"] = (None, perm_out)\n",
    "\n",
    "    return axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutation_spec_builder = ViTPermutationSpecBuilder(core_cfg.model.depth)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "# Sanity check: the spec must name exactly the parameters found in the model's state dict.\n",
    "ref_model = list(models.values())[0]\n",
    "spec_keys = set(permutation_spec.layer_and_axes_to_perm.keys())\n",
    "model_keys = set(ref_model.model.state_dict().keys())\n",
    "assert spec_keys == model_keys, (\n",
    "    \"Permutation spec and model state dict disagree. \"\n",
    "    f\"Missing from spec: {sorted(model_keys - spec_keys)}; not in model: {sorted(spec_keys - model_keys)}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Git Re-Basin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.matcher import GitRebasinMatcher\n",
    "from ccmm.matching.utils import get_inverse_permutations\n",
    "\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols)\n",
    "\n",
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "fixed_symbol = \"a\"\n",
    "permutee_symbol = \"b\"\n",
    "\n",
    "# Both models of the pair are moved to the CPU.\n",
    "fixed_model = models[fixed_symbol].cpu()\n",
    "permutee_model = models[permutee_symbol].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a\n",
    "# `symbols - {symb}` removes the symbol itself; `symbols.difference(symb)` would instead remove the\n",
    "# individual characters of the string, which is only equivalent for one-character symbols.\n",
    "gitrebasin_permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}\n",
    "\n",
    "matcher = GitRebasinMatcher(name=\"git_rebasin\", permutation_spec=permutation_spec)\n",
    "\n",
    "# The matcher returns the permutations mapping the permutee onto the fixed model, plus a history.\n",
    "gitrebasin_permutations[fixed_symbol][permutee_symbol], perm_history = matcher(\n",
    "    fixed=fixed_model.model, permutee=permutee_model.model\n",
    ")\n",
    "\n",
    "# The opposite direction (fixed -> permutee) is obtained by inverting those permutations.\n",
    "gitrebasin_permutations[permutee_symbol][fixed_symbol] = get_inverse_permutations(\n",
    "    gitrebasin_permutations[fixed_symbol][permutee_symbol]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.evaluate_matched_models import evaluate_pair_of_models\n",
    "\n",
    "# Reset the models to the weights saved in `model_orig_weights` before permuting.\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "pylogger.info(f\"Permuting model {permutee_symbol} into {fixed_symbol}.\")\n",
    "\n",
    "# perms[a, b] maps b -> a\n",
    "updated_params = {\n",
    "    fixed_symbol: {\n",
    "        permutee_symbol: apply_permutation_to_statedict(\n",
    "            permutation_spec,\n",
    "            gitrebasin_permutations[fixed_symbol][permutee_symbol],\n",
    "            models[permutee_symbol].model.state_dict(),\n",
    "        )\n",
    "    }\n",
    "}\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "# Values handed to the evaluation as `lambdas` (presumably interpolation coefficients).\n",
    "lambdas = [0.5, 1.0]\n",
    "\n",
    "gitrebasin_results = evaluate_pair_of_models(\n",
    "    models, fixed_symbol, permutee_symbol, updated_params, train_loader, test_loader, lambdas, core_cfg\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the \"test_loss_barrier\" entry of the Git Re-Basin evaluation results.\n",
    "gitrebasin_results[\"test_loss_barrier\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Frank-Wolfe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.frank_wolfe_matching import collect_perm_sizes, frank_wolfe_weight_matching_trial\n",
    "from ccmm.matching.matcher import FrankWolfeMatcher\n",
    "\n",
    "# State dicts of the fixed (a) and permutee (b) inner models.\n",
    "params_a = fixed_model.model.state_dict()\n",
    "params_b = permutee_model.model.state_dict()\n",
    "# Permutation sizes derived from the spec and model a's parameters\n",
    "# (presumably a {permutation name: size} mapping - TODO confirm against `collect_perm_sizes`).\n",
    "perm_sizes = collect_perm_sizes(permutation_spec, params_a)\n",
    "\n",
    "# Initialization scheme passed to the Frank-Wolfe trial.\n",
    "initialization_method = \"identity\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when one is available; otherwise fall back to the CPU instead of failing.\n",
    "fw_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "perm_matrices, perm_matrices_history, new_obj, all_step_sizes = frank_wolfe_weight_matching_trial(\n",
    "    params_a,\n",
    "    params_b,\n",
    "    perm_sizes,\n",
    "    initialization_method,\n",
    "    permutation_spec,\n",
    "    200,  # positional argument; presumably the maximum number of iterations - TODO confirm\n",
    "    device=fw_device,\n",
    "    return_step_sizes=True,\n",
    "    global_step_size=True,\n",
    ")\n",
    "\n",
    "# Reset the models to the weights saved in `model_orig_weights` after the matching run.\n",
    "restore_original_weights(models, model_orig_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step size recorded at every Frank-Wolfe iteration.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(list(all_step_sizes), color=palette[\"light red\"])\n",
    "ax.set(xlabel=\"Iteration\", ylabel=\"Step size\", title=\"Step size\")\n",
    "\n",
    "# savefig does not create missing directories, so make sure the output folder exists first.\n",
    "figures_dir = Path(\"figures\")\n",
    "figures_dir.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figures_dir / \"convergence_step_sizes.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import solve_linear_assignment_problem\n",
    "\n",
    "# D[a][b] holds the permutations mapping b -> a.\n",
    "# `symbols - {symb}` removes the symbol itself; `symbols.difference(symb)` would instead remove the\n",
    "# individual characters of the string, which is only equivalent for one-character symbols.\n",
    "fw_permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}\n",
    "\n",
    "# Each matrix returned by the Frank-Wolfe trial is passed through the linear assignment solver\n",
    "# (presumably to turn it into a hard permutation - TODO confirm).\n",
    "fw_permutations[fixed_symbol][permutee_symbol] = {\n",
    "    p: solve_linear_assignment_problem(perm) for p, perm in perm_matrices.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder for the permuted parameters of the permutee mapped onto the fixed model;\n",
    "# the entry is filled in by the evaluation cell that follows.\n",
    "updated_params = {fixed_symbol: {permutee_symbol: None}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.evaluate_matched_models import evaluate_pair_of_models\n",
    "\n",
    "# Reset the models to the weights saved in `model_orig_weights` before permuting.\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "pylogger.info(f\"Permuting model {permutee_symbol} into {fixed_symbol}.\")\n",
    "\n",
    "# perms[a, b] maps b -> a\n",
    "updated_params[fixed_symbol][permutee_symbol] = apply_permutation_to_statedict(\n",
    "    permutation_spec,\n",
    "    fw_permutations[fixed_symbol][permutee_symbol],\n",
    "    models[permutee_symbol].model.state_dict(),\n",
    ")\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "# Values handed to the evaluation as `lambdas` (presumably interpolation coefficients).\n",
    "lambdas = [0.0, 0.5, 1]  # np.linspace(0, 1, num=4)\n",
    "\n",
    "fw_results = evaluate_pair_of_models(\n",
    "    models, fixed_symbol, permutee_symbol, updated_params, train_loader, test_loader, lambdas, core_cfg\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the \"test_loss_barrier\" entry of the Frank-Wolfe evaluation results.\n",
    "fw_results[\"test_loss_barrier\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
