{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "import itertools\n",
    "from ccmm.utils.utils import l2_norm_models\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "from functools import partial\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "from ccmm.utils.utils import fuse_batch_norm_into_conv\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    linear_interpolate,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "from ccmm.pl_modules.pl_module import MyLightningModule\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as anp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymanopt\n",
    "import pymanopt.manifolds\n",
    "import pymanopt.optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "import numpy as np\n",
    "from scipy.linalg import eig\n",
    "from numpy.linalg import svd\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "import scipy\n",
    "import json\n",
    "\n",
    "\n",
    "def build_laplacian(knn_graph, normalized=True):\n",
    "\n",
    "    A = (knn_graph + knn_graph.T).astype(float)\n",
    "    A = A.toarray()\n",
    "\n",
    "    D = np.diag(np.sum(A, axis=1))\n",
    "    L = D - A\n",
    "\n",
    "    if normalized:\n",
    "        D_inv_sqrt = np.diag(1 / (np.sqrt(np.diag(D)) + 1e-6))\n",
    "        L = D_inv_sqrt @ L @ D_inv_sqrt\n",
    "\n",
    "    evals, evecs = eig(L)\n",
    "    evals = evals.real\n",
    "\n",
    "    idx = evals.argsort()\n",
    "    evals = evals[idx]\n",
    "    evecs = evecs[:, idx]\n",
    "\n",
    "    return A, L, evals, evecs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "logging.getLogger(\"lightning.pytorch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"torch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"pytorch_lightning.accelerators.cuda\").setLevel(logging.WARNING)\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = compose(config_name=\"func_maps\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_samples = -1\n",
    "num_train_samples = -1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "num_train_samples = len(train_dataset) if num_train_samples < 0 else num_train_samples\n",
    "train_subset = Subset(train_dataset, list(range(num_train_samples)))\n",
    "train_loader = DataLoader(train_subset, batch_size=512, num_workers=cfg.num_workers)\n",
    "\n",
    "num_test_samples = len(test_dataset) if num_test_samples < 0 else num_test_samples\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False, max_epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input=28 * 28, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.input = input\n",
    "        self.layer0 = nn.Linear(input, 512)\n",
    "        self.layer1 = nn.Linear(512, 512)\n",
    "        self.layer2 = nn.Linear(512, 512)\n",
    "        self.layer3 = nn.Linear(512, 256)\n",
    "        self.layer4 = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, self.input)\n",
    "\n",
    "        h0 = nn.functional.relu(self.layer0(x))\n",
    "\n",
    "        h1 = nn.functional.relu(self.layer1(h0))\n",
    "\n",
    "        h2 = nn.functional.relu(self.layer2(h1))\n",
    "\n",
    "        h3 = nn.functional.relu(self.layer3(h2))\n",
    "\n",
    "        h4 = self.layer4(h3)\n",
    "\n",
    "        embeddings = [h0, h1, h2, h3, h4]\n",
    "\n",
    "        return nn.functional.log_softmax(h4, dim=-1), embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.permutation_spec import MLPPermutationSpecBuilder\n",
    "\n",
    "permutation_spec_builder = MLPPermutationSpecBuilder(4)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.seed_index = 0\n",
    "seed_index_everything(cfg)\n",
    "model_a = MyLightningModule(MLP(), num_classes=10)\n",
    "\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=50)\n",
    "trainer.fit(model_a, train_loader)\n",
    "\n",
    "trainer.test(model_a, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.seed_index = 1\n",
    "seed_index_everything(cfg)\n",
    "\n",
    "model_b = MyLightningModule(MLP(), num_classes=10)\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)\n",
    "trainer.fit(model_b, train_loader)\n",
    "\n",
    "trainer.test(model_b, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import weight_matching\n",
    "\n",
    "permutations = weight_matching(permutation_spec, model_a.model.state_dict(), model_b.model.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import apply_permutation_to_statedict\n",
    "\n",
    "updated_params = apply_permutation_to_statedict(permutation_spec, permutations, model_b.model.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "model_b_perm = copy.deepcopy(model_b)\n",
    "model_b_perm.model.load_state_dict(updated_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lambdas = np.linspace(0, 1, 10)\n",
    "\n",
    "all_results = {\"naive\": [], \"matched\": []}\n",
    "\n",
    "for lambd in lambdas:\n",
    "\n",
    "    model_interp = copy.deepcopy(model_b)\n",
    "    model_naive = copy.deepcopy(model_b)\n",
    "\n",
    "    naive_interp_params = linear_interpolate(model_a=model_a, model_b=model_b, lambd=lambd)\n",
    "\n",
    "    model_naive.load_state_dict(naive_interp_params)\n",
    "\n",
    "    model_interp_params = linear_interpolate(model_a=model_a, model_b=model_b_perm, lambd=lambd)\n",
    "\n",
    "    model_interp.load_state_dict(model_interp_params)\n",
    "\n",
    "    trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)\n",
    "    results = trainer.test(model_interp, test_loader)\n",
    "    results_naive = trainer.test(model_naive, test_loader)\n",
    "\n",
    "    all_results[\"naive\"].append(results_naive)\n",
    "    all_results[\"matched\"].append(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot lambdas as x and test accuracy as y\n",
    "\n",
    "plt.plot(lambdas, [x[0][\"loss/test\"] for x in all_results[\"naive\"]], label=\"naive\")\n",
    "plt.plot(lambdas, [x[0][\"loss/test\"] for x in all_results[\"matched\"]], label=\"matched\")\n",
    "\n",
    "plt.legend()\n",
    "plt.xlabel(\"$\\lambda$\")\n",
    "plt.ylabel(\"Test Accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
