{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import logging\n",
    "import math\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import json\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import pytorch_lightning\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "from tqdm import tqdm\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    load_permutations,\n",
    "    perm_indices_to_perm_matrix,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    fuse_batch_norm_into_conv,\n",
    "    get_interpolated_loss_acc_curves,\n",
    "    l2_norm_models,\n",
    "    linear_interpolate,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    normalize_unit_norm,\n",
    "    project_onto,\n",
    "    save_factored_permutations,\n",
    "    vector_to_state_dict,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Typeset figure text with LaTeX, using a serif font family.\n",
    "plt.rcParams[\"text.usetex\"] = True\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "from ccmm.utils.plot import Palette\n",
    "\n",
    "# Colour palette loaded from the project's misc folder; the bare name displays it.\n",
    "palette = Palette(f\"{PROJECT_ROOT}/misc/palette2.json\")\n",
    "palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raise the level of the framework loggers so only warnings and above are emitted.\n",
    "for noisy_logger in (\"lightning.pytorch\", \"torch\", \"pytorch_lightning.accelerators.cuda\"):\n",
    "    logging.getLogger(noisy_logger).setLevel(logging.WARNING)\n",
    "\n",
    "# Logger used by the notebook itself.\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "# Clear the global Hydra instance first so this cell can be re-run,\n",
    "# then initialise Hydra against the ../conf directory.\n",
    "global_hydra = hydra.core.global_hydra.GlobalHydra.instance()\n",
    "global_hydra.clear()\n",
    "initialize(version_base=None, config_path=\"../conf\", job_name=\"matching\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the \"matching\" configuration with an empty list of overrides.\n",
    "cfg = compose(\n",
    "    config_name=\"matching\",\n",
    "    overrides=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a handle on the full config, then narrow `cfg` to its `matching` section.\n",
    "# NOTE: `cfg` is rebound here, so this cell is not idempotent; re-run the cell\n",
    "# that composes `cfg` before running this one again.\n",
    "core_cfg = cfg  # NOQA\n",
    "cfg = core_cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of examples kept from the train and test datasets when building the subsets.\n",
    "num_train_samples = 5_000\n",
    "num_test_samples = 5_000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The transform configured for the test split is applied to both datasets.\n",
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "# Keep only the first `num_train_samples` / `num_test_samples` examples.\n",
    "train_indices = list(range(num_train_samples))\n",
    "test_indices = list(range(num_test_samples))\n",
    "train_subset = Subset(train_dataset, train_indices)\n",
    "test_subset = Subset(test_dataset, test_indices)\n",
    "\n",
    "train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the trainer described by cfg.trainer, with progress bar and model summary disabled.\n",
    "trainer = instantiate(\n",
    "    cfg.trainer,\n",
    "    enable_progress_bar=False,\n",
    "    enable_model_summary=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# Maps each model symbol to the seed it was trained with: {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed):\n",
    "    \"\"\"Return the W&B artifact path (version ``v0``) of the model trained with ``seed``.\"\"\"\n",
    "    return (\n",
    "        f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/\"\n",
    "        f\"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    "    )\n",
    "\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()\n",
    "}\n",
    "# Deep copies of the initial weights, so the models can be restored after being modified.\n",
    "model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "# Sanity check: the layers named in the spec must be exactly the state-dict keys\n",
    "# of the first loaded model.\n",
    "ref_model = list(models.values())[0]\n",
    "spec_layer_names = set(permutation_spec.layer_and_axes_to_perm.keys())\n",
    "model_layer_names = set(ref_model.model.state_dict().keys())\n",
    "assert spec_layer_names == model_layer_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Git Re-Basin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)\n",
    "# pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "from ccmm.matching.utils import get_inverse_permutations\n",
    "\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols, reverse=False)\n",
    "fixed, permutee = \"a\", \"b\"\n",
    "fixed_model, permutee_model = models[fixed], models[permutee]\n",
    "\n",
    "# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a\n",
    "# `symbols - {symb}` removes the symbol itself. `symbols.difference(symb)` would iterate over the\n",
    "# characters of the string instead, which is only correct for single-character symbols.\n",
    "permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}\n",
    "\n",
    "# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)\n",
    "# permutations[fixed][permutee], perm_history = matcher(\n",
    "#     fixed=fixed_model.model, permutee=permutee_model.model\n",
    "# )\n",
    "\n",
    "# permutations[permutee][fixed] = get_inverse_permutations(permutations[fixed][permutee])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Frank-Wolfe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.frank_wolfe_matching import collect_perm_sizes, frank_wolfe_weight_matching_trial\n",
    "from ccmm.matching.matcher import FrankWolfeMatcher\n",
    "\n",
    "# Initialisation method name handed to the Frank-Wolfe trial.\n",
    "initialization_method = \"identity\"\n",
    "\n",
    "# State dicts of the fixed model (a) and of the model to be permuted (b).\n",
    "params_a, params_b = fixed_model.model.state_dict(), permutee_model.model.state_dict()\n",
    "perm_sizes = collect_perm_sizes(permutation_spec, params_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available and fall back to the CPU otherwise, so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "perm_matrices, perm_matrices_history, new_obj, all_step_sizes = frank_wolfe_weight_matching_trial(\n",
    "    params_a,\n",
    "    params_b,\n",
    "    perm_sizes,\n",
    "    initialization_method,\n",
    "    permutation_spec,\n",
    "    200,  # NOTE(review): presumably the iteration budget; confirm against the function signature\n",
    "    device=device,\n",
    "    return_step_sizes=True,\n",
    "    global_step_size=True,\n",
    ")\n",
    "\n",
    "# Reset every model to the weights saved at load time.\n",
    "restore_original_weights(models, model_orig_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the step sizes returned by the Frank-Wolfe trial, in the order they were returned.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(list(all_step_sizes), color=palette[\"light red\"])\n",
    "\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Step size\")\n",
    "ax.set_title(\"Step size\")\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target folder exists.\n",
    "figure_path = Path(\"figures\") / \"convergence_step_sizes.pdf\"\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(figure_path, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import solve_linear_assignment_problem\n",
    "\n",
    "# Solve one linear assignment problem per entry of `perm_matrices` and store the\n",
    "# results as the permutations mapping `permutee` onto `fixed`.\n",
    "solved_perms = {}\n",
    "for perm_name, perm_matrix in perm_matrices.items():\n",
    "    solved_perms[perm_name] = solve_linear_assignment_problem(perm_matrix)\n",
    "\n",
    "permutations[fixed][permutee] = solved_perms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested placeholder: updated_params[fixed][permutee] starts as None and is filled in later.\n",
    "updated_params = {fixed: dict.fromkeys([permutee])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.evaluate_matched_models import evaluate_pair_of_models\n",
    "\n",
    "# Four evenly spaced values in [0, 1], passed to the evaluation below.\n",
    "lambdas = np.linspace(0, 1, num=4)\n",
    "\n",
    "# Start from the saved weights before permuting anything.\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "pylogger.info(f\"Permuting model {permutee} into {fixed}.\")\n",
    "\n",
    "# perms[a, b] maps b -> a\n",
    "permutee_state_dict = models[permutee].model.state_dict()\n",
    "updated_params[fixed][permutee] = apply_permutation_to_statedict(\n",
    "    permutation_spec, permutations[fixed][permutee], permutee_state_dict\n",
    ")\n",
    "restore_original_weights(models, model_orig_weights)\n",
    "\n",
    "results = evaluate_pair_of_models(\n",
    "    models, fixed, permutee, updated_params, train_loader, test_loader, lambdas, core_cfg\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the \"test_loss_barrier\" entry of the evaluation results.\n",
    "test_loss_barrier = results[\"test_loss_barrier\"]\n",
    "test_loss_barrier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
