{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "import itertools\n",
    "from ccmm.utils.utils import l2_norm_models\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "from functools import partial\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "from ccmm.utils.utils import fuse_batch_norm_into_conv\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "import autograd.numpy as anp\n",
    "\n",
    "import torch\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "import numpy as np\n",
    "from scipy.linalg import eig\n",
    "from numpy.linalg import svd\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "import scipy\n",
    "import json\n",
    "\n",
    "# import pymanopt\n",
    "# import pymanopt.manifolds\n",
    "# import pymanopt.optimizers\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    linear_interpolate,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "from ccmm.pl_modules.pl_module import MyLightningModule\n",
    "\n",
    "from ccmm.matching.utils import load_permutations\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves, cumulative_sum\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "logging.getLogger(\"lightning.pytorch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"torch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"pytorch_lightning.accelerators.cuda\").setLevel(logging.WARNING)\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = compose(config_name=\"func_maps\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set these values to a positive number $x \\in \\{1, \\dots, N\\}$ to use only the first $x$ samples of the corresponding dataset, or leave them at $-1$ to use the whole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_samples = -1\n",
    "num_train_samples = -1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "num_train_samples = len(train_dataset) if num_train_samples < 0 else num_train_samples\n",
    "\n",
    "train_subset = Subset(train_dataset, list(range(num_train_samples)))\n",
    "train_loader = DataLoader(train_subset, batch_size=1000, num_workers=cfg.num_workers)\n",
    "\n",
    "num_test_samples = len(test_dataset) if num_test_samples < 0 else num_test_samples\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False, max_epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we define a standard MLP with an input layer, 3 hidden layers and an output layer. We use ReLU as the activation function and log_softmax as the output function. The forward pass also returns the activations of every layer (the last entry being the pre-softmax logits), since we use them to match the networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input=28 * 28, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.input = input\n",
    "        self.layer0 = nn.Linear(input, 512)\n",
    "        self.layer1 = nn.Linear(512, 512)\n",
    "        self.layer2 = nn.Linear(512, 512)\n",
    "        self.layer3 = nn.Linear(512, 256)\n",
    "        self.layer4 = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, self.input)\n",
    "\n",
    "        h0 = nn.functional.relu(self.layer0(x))\n",
    "\n",
    "        h1 = nn.functional.relu(self.layer1(h0))\n",
    "\n",
    "        h2 = nn.functional.relu(self.layer2(h1))\n",
    "\n",
    "        h3 = nn.functional.relu(self.layer3(h2))\n",
    "\n",
    "        h4 = self.layer4(h3)\n",
    "\n",
    "        embeddings = [h0, h1, h2, h3, h4]\n",
    "\n",
    "        return nn.functional.log_softmax(h4, dim=-1), embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Permutation specs tell us which permutations to apply to which layers. A `PermutationSpec` has two attributes:\n",
    "- `perm_to_layers_and_axes`: a dictionary that maps each permutation matrix to the layers it permutes, specifying the axis it acts on, e.g. `{'P0': [('conv1.weight', 0), ('conv2.weight', 1)], 'P1': ...}`\n",
    "- `layer_and_axes_to_perm`: a dictionary that maps each layer to a tuple as long as the number of dimensions of the layer, where each entry is the permutation acting on that dimension (or `None`), e.g. `{ 'conv2.weight': ('P1', 'P0', None, None), 'conv3.weight': ...}`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.permutation_spec import MLPPermutationSpecBuilder\n",
    "\n",
    "permutation_spec_builder = MLPPermutationSpecBuilder(4)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and test first model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.seed_index = 0\n",
    "seed_index_everything(cfg)\n",
    "model_a = MyLightningModule(MLP(), num_classes=10)\n",
    "\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=5)\n",
    "trainer.fit(model_a, train_loader)\n",
    "\n",
    "trainer.test(model_a, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and test second model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(test_loader))\n",
    "x, y = batch\n",
    "x = x[0]\n",
    "\n",
    "logits, embeddings = model_a(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the gradient of logits with respect to inputs\n",
    "logits[0][0].backward(retain_graph=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grads = model_a.model.layer0.weight.grad\n",
    "print(grads.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.seed_index = 1\n",
    "seed_index_everything(cfg)\n",
    "\n",
    "model_b = MyLightningModule(MLP(), num_classes=10)\n",
    "trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=5)\n",
    "trainer.fit(model_b, train_loader)\n",
    "\n",
    "trainer.test(model_b, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the permutations obtained from `git_rebasin` as ground truth to visualize the functional maps; we compute them with the `weight_matching` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import weight_matching\n",
    "\n",
    "permutations = weight_matching(permutation_spec, model_a.model.state_dict(), model_b.model.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Focus on a single layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_idx = 1\n",
    "\n",
    "perm_gt = permutations[f\"P_{layer_idx}\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Descriptor 1: weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_a_weights = model_a.model.state_dict()[f\"layer{layer_idx}.weight\"]\n",
    "layer_b_weights = model_b.model.state_dict()[f\"layer{layer_idx}.weight\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W_a = layer_a_weights.detach().numpy()\n",
    "W_b = layer_b_weights.detach().numpy()\n",
    "\n",
    "W_a.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Descriptor 2: Activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run a forward pass over a single batch of size `num_activations` to obtain the activations of both models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_activations = 10000\n",
    "train_loader = DataLoader(train_subset, batch_size=num_activations, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch in train_loader:\n",
    "\n",
    "    x, y = batch\n",
    "    # model returns logits and a list of embeddings, so we take the embeddings\n",
    "    features_a = model_a.model(x)[-1]\n",
    "    features_b = model_b.model(x)[-1]\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (descriptor_dim, num_neurons), where descriptor_dim is the number of samples for which we are considering the neuron activation\n",
    "layer_a = features_a[layer_idx]\n",
    "layer_b = features_b[layer_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_neurons = layer_a.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalize to have unit norm\n",
    "\n",
    "layer_a = layer_a / (torch.norm(layer_a, dim=0) + 1e-6)\n",
    "layer_b = layer_b / (torch.norm(layer_b, dim=0) + 1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(layer_a.shape, layer_b.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import to_np\n",
    "\n",
    "layer_a = to_np(layer_a)\n",
    "layer_b = to_np(layer_b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Descriptor 3: denoised activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (num_samples, num_neurons)\n",
    "layer_a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def svd_threshold(matrix, variance_threshold=0.99):\n",
    "    # Compute SVD\n",
    "    U, S, Vt = np.linalg.svd(matrix, full_matrices=False)\n",
    "\n",
    "    # Calculate the cumulative variance explained by the singular values\n",
    "    total_variance = np.sum(S**2)\n",
    "    explained_variance = np.cumsum(S**2) / total_variance\n",
    "\n",
    "    # Determine the number of singular values needed to explain the desired threshold of variance\n",
    "    num_components = np.argmax(explained_variance >= variance_threshold) + 1\n",
    "\n",
    "    # Select the subset of singular values and vectors explaining the desired variance\n",
    "    U_reduced = U[:, :num_components]\n",
    "    S_reduced = S[:num_components]\n",
    "    Vt_reduced = Vt[:num_components, :]\n",
    "\n",
    "    return U_reduced, S_reduced, Vt_reduced, explained_variance\n",
    "\n",
    "\n",
    "def svd_num_components(matrix, num_components=10):\n",
    "    # matrix is ~ (num_samples, num_neurons)\n",
    "    num_samples, num_neurons = matrix.shape\n",
    "    K = num_components\n",
    "\n",
    "    U, S, Vt = np.linalg.svd(matrix, full_matrices=False)\n",
    "\n",
    "    assert U.shape == (num_samples, num_neurons)\n",
    "    assert S.shape == (num_neurons,)\n",
    "    assert Vt.shape == (num_neurons, num_neurons)\n",
    "\n",
    "    U_reduced = U[:, :K]\n",
    "    S_reduced = S[:K]\n",
    "    Vt_reduced = Vt[:K, :]\n",
    "\n",
    "    return U_reduced, S_reduced, Vt_reduced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (num_samples, num_comps), (num_comps), (num_comps, num_neurons)\n",
    "num_components = 256\n",
    "U_a, S_a, Vt_a = svd_num_components(layer_a, num_components=num_components)\n",
    "U_b, S_b, Vt_b = svd_num_components(layer_b, num_components=num_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(U_a.shape, S_a.shape, Vt_a.shape)\n",
    "print(U_b.shape, S_b.shape, Vt_b.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Reconstruction error "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# express each layer as a linear combination of the singular vectors\n",
    "layer_a_reconstructed = U_a @ np.diag(S_a) @ Vt_a\n",
    "layer_b_reconstructed = U_b @ np.diag(S_b) @ Vt_b\n",
    "\n",
    "layer_a_reconstructed.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if the reconstruction is close to the original layer by computing the norm\n",
    "np.linalg.norm(layer_a_reconstructed - layer_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if the reconstruction is close to the original layer by computing the norm\n",
    "np.linalg.norm(layer_b_reconstructed - layer_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the norm of the two models for comparison\n",
    "np.linalg.norm(layer_a - layer_b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Descriptor 4: eigenneurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale each right singular vector by 1 / sqrt(singular value) (eps avoids division by zero)\n",
    "# and transpose to one row per neuron: (num_neurons, num_components).\n",
    "# The inverse is taken on the singular values before building the diagonal matrix, so the\n",
    "# off-diagonal entries stay zero.\n",
    "eigenneurons_a = (np.diag(1 / (S_a**0.5 + 1e-6)) @ Vt_a).T\n",
    "eigenneurons_b = (np.diag(1 / (S_b**0.5 + 1e-6)) @ Vt_b).T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functional maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_descriptors(descriptor_type):\n",
    "\n",
    "    if descriptor_type == \"weights\":\n",
    "        X, Y = W_a, W_b\n",
    "    elif descriptor_type == \"features\":\n",
    "        X, Y = layer_a.T, layer_b.T\n",
    "    elif descriptor_type == \"features_denoised\":\n",
    "        X, Y = layer_a_reconstructed.T, layer_b_reconstructed.T\n",
    "    elif descriptor_type == \"eigenneurons\":\n",
    "        X, Y = eigenneurons_a, eigenneurons_b\n",
    "    else:\n",
    "        raise ValueError(\"Invalid value for use_weights_or_features\")\n",
    "\n",
    "    return X, Y\n",
    "\n",
    "\n",
    "descriptor_type = \"features_denoised\"  # weights, features, features_denoised, eigenneurons\n",
    "X, Y = get_descriptors(descriptor_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the KNN graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_laplacian(A, normalized=True):\n",
    "\n",
    "    D = np.diag(np.sum(A, axis=1)) + 1e-6\n",
    "\n",
    "    assert not np.any(D < 0)\n",
    "\n",
    "    L = D - A\n",
    "\n",
    "    if normalized:\n",
    "        D_inv_sqrt = np.diag(1 / np.sqrt(np.diag(D)))\n",
    "        L = D_inv_sqrt @ L @ D_inv_sqrt\n",
    "        L = (L + L.T) / 2\n",
    "\n",
    "    assert not np.any(np.isnan(L))\n",
    "\n",
    "    evals, evecs = np.linalg.eigh(L)\n",
    "\n",
    "    idx = evals.argsort()\n",
    "    evals = evals[idx]\n",
    "    evecs = evecs[:, idx]\n",
    "\n",
    "    return A, L, evals, evecs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_knn_graph(X, radius=None, num_neighbors=None, mode=\"distance\"):\n",
    "    \"\"\"Build a symmetric neighborhood graph over the rows of X.\n",
    "\n",
    "    Args:\n",
    "        X: (num_points, dim) array, one node per row.\n",
    "        radius: if given, link every pair of points within `radius` of each other; takes\n",
    "            precedence over `num_neighbors`.\n",
    "        num_neighbors: otherwise, link each point to its `num_neighbors` nearest neighbors.\n",
    "            The query set is X itself, so every point is returned as its own first neighbor\n",
    "            and that self-loop is removed below.\n",
    "        mode: \"distance\" (edges weighted by distance) or \"connectivity\" (0/1 edges).\n",
    "\n",
    "    Returns:\n",
    "        (num_points, num_points) dense symmetric adjacency matrix with a zero diagonal.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if neither `radius` nor `num_neighbors` is given.\n",
    "    \"\"\"\n",
    "    if radius is not None:\n",
    "        Xneigh = NearestNeighbors(radius=radius).fit(X)\n",
    "        # radius_neighbors_graph honours the radius; kneighbors_graph would ignore it and\n",
    "        # fall back to the default n_neighbors=5, making every radius produce the same graph\n",
    "        X_knn_graph = Xneigh.radius_neighbors_graph(X, mode=mode)\n",
    "    elif num_neighbors is not None:\n",
    "        Xneigh = NearestNeighbors(n_neighbors=num_neighbors).fit(X)\n",
    "        X_knn_graph = Xneigh.kneighbors_graph(X, mode=mode)\n",
    "    else:\n",
    "        raise ValueError(\"Either radius or num_neighbors must be provided\")\n",
    "\n",
    "    # (num_points, num_points)\n",
    "    X_adj = X_knn_graph.toarray()\n",
    "    np.fill_diagonal(X_adj, 0)\n",
    "\n",
    "    # the kNN relation is not symmetric: average the graph with its transpose\n",
    "    X_adj_sym = (X_adj + X_adj.T) / 2\n",
    "\n",
    "    assert np.allclose(X_adj_sym, X_adj_sym.T), \"Adjacences are not symmetric\"\n",
    "\n",
    "    return X_adj_sym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Functional maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_func_map(\n",
    "    X, Y, P, radius=None, num_neighbors=None, mode=\"distance\", normalize_lap=True, num_eigenvectors=50\n",
    "):\n",
    "    \"\"\"Express the correspondence P in the Laplacian eigenbases of the two neuron graphs.\n",
    "\n",
    "    A neighborhood graph is built over the rows of X and over the rows of Y. With Phi and Psi\n",
    "    the first `num_eigenvectors` Laplacian eigenvectors of the two graphs, the functional map\n",
    "    is C = Phi^T @ P @ Psi.\n",
    "\n",
    "    Args:\n",
    "        X, Y: descriptors of the two models' neurons, one row per neuron.\n",
    "        P: (len(X), len(Y)) permutation matrix between the neurons of X and Y.\n",
    "        radius, num_neighbors, mode: forwarded to build_knn_graph.\n",
    "        normalize_lap: use the symmetric normalized Laplacian.\n",
    "        num_eigenvectors: number of eigenvectors (smallest eigenvalues first) to keep.\n",
    "\n",
    "    Returns:\n",
    "        C of shape (min(num_eigenvectors, len(X)), min(num_eigenvectors, len(Y))); all zeros\n",
    "        when either graph has no edges.\n",
    "    \"\"\"\n",
    "    X_adj_sym = build_knn_graph(X, radius, num_neighbors, mode)\n",
    "    Y_adj_sym = build_knn_graph(Y, radius, num_neighbors, mode)\n",
    "\n",
    "    # slicing [:, :num_eigenvectors] clips to the number of nodes; use the same sizes for the\n",
    "    # degenerate result so C always has the same shape\n",
    "    k_x = min(num_eigenvectors, X_adj_sym.shape[0])\n",
    "    k_y = min(num_eigenvectors, Y_adj_sym.shape[0])\n",
    "\n",
    "    # e.g. num_neighbors=1 only finds each point itself, whose self-loop is then removed\n",
    "    if X_adj_sym.sum() == 0 or Y_adj_sym.sum() == 0:\n",
    "        return np.zeros((k_x, k_y))\n",
    "\n",
    "    _, _, _, Xevecs = build_laplacian(X_adj_sym, normalize_lap)\n",
    "    _, _, _, Yevecs = build_laplacian(Y_adj_sym, normalize_lap)\n",
    "\n",
    "    Phi = Xevecs[:, :k_x]\n",
    "    Psi = Yevecs[:, :k_y]\n",
    "    return Phi.T @ P @ Psi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_func_maps(func_maps, fig_name, vmin, vmax, param_name=\"k\", param_values=None, nrows=7, ncols=7):\n",
    "    \"\"\"Plot a grid of functional maps and save it to figures/<fig_name>.png.\n",
    "\n",
    "    Args:\n",
    "        func_maps: sequence of 2D arrays; the first nrows * ncols of them are drawn.\n",
    "        fig_name: file name (without extension) inside the `figures` directory.\n",
    "        vmin, vmax: color limits shared by all panels.\n",
    "        param_name: name shown in each panel title, e.g. \"k\" or \"radius\".\n",
    "        param_values: value shown in the title of each panel; defaults to range(1, 100, 2),\n",
    "            the num_neighbors values of the k sweep.\n",
    "        nrows, ncols: grid size.\n",
    "    \"\"\"\n",
    "    if param_values is None:\n",
    "        param_values = range(1, 100, 2)\n",
    "\n",
    "    fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20), squeeze=False)\n",
    "\n",
    "    # axs.flat walks the grid row by row, so idx == row * ncols + col\n",
    "    for idx, ax in enumerate(axs.flat):\n",
    "        if idx < len(func_maps):\n",
    "            ax.imshow(func_maps[idx], cmap=cmap_name, vmin=vmin, vmax=vmax)\n",
    "            ax.set_title(f\"{param_name}={param_values[idx]}\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    # savefig does not create missing directories\n",
    "    Path(\"figures\").mkdir(parents=True, exist_ok=True)\n",
    "    plt.savefig(f\"figures/{fig_name}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "P = perm_indices_to_perm_matrix(perm_gt).numpy()\n",
    "normalize_lap = True\n",
    "mode = \"connectivity\"  # connectivity or distance\n",
    "num_eigenvectors = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func_maps_neighbors = [\n",
    "    compute_func_map(\n",
    "        X, Y, P, num_neighbors=k, mode=mode, normalize_lap=normalize_lap, num_eigenvectors=num_eigenvectors\n",
    "    )\n",
    "    for k in range(1, 100, 2)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_name = f\"func_maps_{descriptor_type}_{mode}_normalizeLap_{normalize_lap}_numEigenvectors_{num_eigenvectors}\"\n",
    "plot_func_maps(func_maps_neighbors, plot_name, vmin=-0.6, vmax=0.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func_maps_radius = [compute_func_map(X, Y, P, radius=r) for r in np.linspace(0.01, 1, 50)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_func_maps(func_maps_radius, f\"func_maps_{descriptor_type}_radius\", vmin=-0.5, vmax=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transfer indicator function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mode = \"distance\"\n",
    "normalize_lap = True\n",
    "num_neighbors = 80\n",
    "num_eigenvectors = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "descriptor_type = \"features_denoised\"  # weights, features, features_denoised, eigenneurons\n",
    "X, Y = get_descriptors(descriptor_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "C = compute_func_map(\n",
    "    X, Y, P, num_neighbors=num_neighbors, mode=mode, normalize_lap=normalize_lap, num_eigenvectors=num_eigenvectors\n",
    ")\n",
    "\n",
    "plt.imshow(C, cmap=cmap_name, vmin=-0.6, vmax=0.6)\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "consider an indicator function $f$ \n",
    " \n",
    "$f(x_i) = 1$ for some $i$, 0 otherwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indicator_func = np.zeros((num_neurons,))\n",
    "selected_neuron_idx = 32\n",
    "indicator_func[selected_neuron_idx] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Project them onto the eigenvectors. Since the identity matrix is just the stack of all indicator functions, we do not need to do this explicitly: transferring the identity through the functional map transfers every indicator function at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_adj_sym = build_knn_graph(X, num_neighbors=num_neighbors, mode=mode)\n",
    "XA, XL, Xevals, Xevecs = build_laplacian(X_adj_sym, normalize_lap)\n",
    "\n",
    "Y_adj_sym = build_knn_graph(Y, num_neighbors=num_neighbors, mode=mode)\n",
    "YA, YL, Yevals, Yevecs = build_laplacian(Y_adj_sym, normalize_lap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# average weighted node degree of the X graph (mean row sum of the adjacency matrix)\n",
    "X_adj_sym.sum(axis=1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Phi = Xevecs[:, :num_eigenvectors].real\n",
    "Psi = Yevecs[:, :num_eigenvectors].real\n",
    "P_tilde = Psi @ C @ Phi.T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take the row-wise argmax of `P_tilde` (the result is not guaranteed to be a permutation)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mapped_points_argmax = P_tilde.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Solve a linear assignment problem (LAP) to get a proper permutation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.weight_matching import solve_linear_assignment_problem\n",
    "\n",
    "P_tilde_lap = solve_linear_assignment_problem(P_tilde.T, return_matrix=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing permutation matrices\n",
    "For each point, we map it to the other graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_matrix_to_perm_indices\n",
    "\n",
    "# mapped_points[i] = j means that the i-th point in the first set is mapped to the j-th point in the second set\n",
    "mapped_points = perm_matrix_to_perm_indices(P_tilde_lap)\n",
    "mapped_points[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_exact_matchings = (mapped_points == perm_gt).sum().item()\n",
    "num_exact_matchings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For every neuron we compute the shortest-path length (in hops, on the graph of the second model) between its predicted match and its ground-truth match, and plot:\n",
    "* x axis: a radius ranging from 0 to the diameter of the graph\n",
    "* y axis: the fraction of matchings whose predicted point lies within that radius of the ground-truth point\n",
    "* for radius=0 we count only the exact matchings; for radius=diameter every matching between connected nodes counts as correct\n",
    "* the curve therefore grows from the exact-match rate towards 100%: the faster it reaches 100%, the better the matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "\n",
    "\n",
    "def bfs_shortest_distance(adj, start):\n",
    "    # Initialize distances with infinity\n",
    "    n = adj.shape[0]\n",
    "    distance = [np.inf] * n\n",
    "    distance[start] = 0\n",
    "    queue = deque([start])\n",
    "\n",
    "    while queue:\n",
    "        current = queue.popleft()\n",
    "        for i in range(n):\n",
    "            if adj[current, i] > 0 and distance[i] == np.inf:\n",
    "                queue.append(i)\n",
    "                distance[i] = distance[current] + 1\n",
    "    return distance\n",
    "\n",
    "\n",
    "def bfs_shortest_path(adj, u, v):\n",
    "    # Number of nodes\n",
    "    n = adj.shape[0]\n",
    "    # To keep track of visited nodes to prevent revisiting\n",
    "    visited = [False] * n\n",
    "    # To keep track of the path\n",
    "    parent = [-1] * n\n",
    "\n",
    "    # Queue for BFS\n",
    "    queue = deque([u])\n",
    "    visited[u] = True\n",
    "\n",
    "    # Perform BFS\n",
    "    while queue:\n",
    "        current = queue.popleft()\n",
    "\n",
    "        # If we've reached the target node, break\n",
    "        if current == v:\n",
    "            break\n",
    "\n",
    "        # Check all adjacent nodes\n",
    "        for i in range(n):\n",
    "            if adj[current, i] > 0 and not visited[i]:\n",
    "                queue.append(i)\n",
    "                visited[i] = True\n",
    "                parent[i] = current\n",
    "\n",
    "    # Reconstruct the path from u to v\n",
    "    path = []\n",
    "    if visited[v]:\n",
    "        while v != -1:\n",
    "            path.append(v)\n",
    "            v = parent[v]\n",
    "        path.reverse()\n",
    "\n",
    "    return path if path else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(num_neurons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _graph_distance(adj, source, target):\n",
    "    # Hop count of a shortest source -> target path in adj, or inf if unreachable.\n",
    "    path = bfs_shortest_path(adj, source, target)\n",
    "    return np.inf if path is None else len(path) - 1\n",
    "\n",
    "\n",
    "# For every neuron: distance in the graph Y_adj_sym between its predicted match\n",
    "# (mapped_points) and its ground-truth match (perm_gt).\n",
    "path_lengths = [\n",
    "    _graph_distance(Y_adj_sym, mapped_points[i].item(), perm_gt[i].item())\n",
    "    for i in range(num_neurons)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Idea for a better matching (not implemented here):\n",
    "#   look for argmin_{P_tilde} P_tilde - Psi C Phi^T\n",
    "#   multiplying by Psi^T on the left gives Psi^T P_tilde - C Phi^T   (Phi^T ~ (N, K))\n",
    "#   then look for the binary P_tilde that minimizes this measure:\n",
    "#   P_tilde_i = nearest_neighbor(Psi^T _i, C Phi^T _i)\n",
    "\n",
    "# First ten graph distances between predicted and ground-truth matches\n",
    "path_lengths[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Histogram of graph distances: distance -> number of neurons at that distance\n",
    "path_length_count = Counter()\n",
    "path_length_count.update(path_lengths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise by the number of measured distances, so the frequencies always sum to 1\n",
    "# even if the global num_neurons is reassigned later (the kNN graph plot cell does this).\n",
    "total_paths = sum(path_length_count.values())\n",
    "path_length_frequencies = {k: v / total_paths for k, v in path_length_count.items()}\n",
    "path_length_frequencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_graph_diameter(adj):\n",
    "    n = adj.shape[0]\n",
    "    diameter = 0\n",
    "\n",
    "    for u in range(n):\n",
    "        distance = bfs_shortest_distance(adj, u)\n",
    "        # Update the diameter with the maximum distance found from this node\n",
    "        max_distance = max(distance)\n",
    "\n",
    "        if max_distance > diameter and max_distance != np.inf:\n",
    "            diameter = max_distance\n",
    "\n",
    "    return diameter\n",
    "\n",
    "\n",
    "# path_lengths are measured on Y_adj_sym, so the radius range must come from the\n",
    "# diameter of that same graph.\n",
    "diameter = compute_graph_diameter(Y_adj_sym)\n",
    "print(diameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give every radius from 0 to the diameter an entry, using 0 when no neuron\n",
    "# has exactly that distance.\n",
    "radiuses = range(diameter + 1)\n",
    "print(radiuses)\n",
    "\n",
    "for r in radiuses:\n",
    "    path_length_frequencies.setdefault(r, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably cumulative_sum accumulates the frequencies over increasing radius\n",
    "# (defined in an earlier cell) -- NOTE(review): confirm.\n",
    "ys = cumulative_sum(path_length_frequencies)\n",
    "xs = radiuses\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(xs, [ys[x] for x in xs], marker=\"o\")\n",
    "ax.set(\n",
    "    xlabel=\"radius (graph hops)\",\n",
    "    ylabel=\"cumulative fraction of neurons\",\n",
    "    title=\"Graph distance between predicted and ground-truth matches\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 10  # neighbours per node in both kNN graphs\n",
    "\n",
    "\n",
    "def _fit_knn_graph(points, n_neighbors):\n",
    "    # Fit a NearestNeighbors index on points and return it together with the\n",
    "    # connectivity graph of points queried against that index.\n",
    "    index = NearestNeighbors(n_neighbors=n_neighbors)\n",
    "    index.fit(points)\n",
    "    return index, index.kneighbors_graph(points, mode=\"connectivity\")\n",
    "\n",
    "\n",
    "# (num_neurons, num_neurons)\n",
    "Xneigh, X_knn_graph = _fit_knn_graph(X, k)\n",
    "Yneigh, Y_knn_graph = _fit_knn_graph(Y, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pca_coords_3d(data):\n",
    "    # Fit a 3-component PCA on data.T and return its three component rows as\n",
    "    # the (x, y, z) coordinates used for plotting.\n",
    "    pca = PCA(n_components=3)\n",
    "    pca.fit(data.T)\n",
    "    return pca.components_[0, :], pca.components_[1, :], pca.components_[2, :]\n",
    "\n",
    "\n",
    "Xx, Xy, Xz = _pca_coords_3d(X)\n",
    "Yx, Yy, Yz = _pca_coords_3d(Y)\n",
    "\n",
    "# Create the 3D axes directly: adding 3D subplots over the 2D axes returned by a\n",
    "# plain plt.subplots call leaves the empty 2D frames visible underneath.\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5), subplot_kw={\"projection\": \"3d\"})\n",
    "ax[0].scatter(Xx, Xy, Xz, c=\"tab:blue\")\n",
    "ax[1].scatter(Yx, Yy, Yz, c=\"tab:red\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_knn_graph_3d(axis, knn_graph, xs, ys, zs, n, point_color):\n",
    "    # Draw every edge (i, j) with knn_graph[i, j] > 0 among the first n nodes, then\n",
    "    # the points. Only the positive entries are visited instead of all n^2 pairs;\n",
    "    # element-wise indexing of the (presumably scipy sparse) kNN matrix is slow.\n",
    "    rows, cols = (knn_graph[:n, :n] > 0).nonzero()\n",
    "    for i, j in zip(rows, cols):\n",
    "        axis.plot([xs[i], xs[j]], [ys[i], ys[j]], [zs[i], zs[j]], \"b-\", alpha=0.5)\n",
    "    axis.scatter(xs, ys, zs, c=point_color)\n",
    "\n",
    "\n",
    "num_neurons = W_a.shape[0]\n",
    "\n",
    "# Create 3D axes directly so no empty 2D axes from plt.subplots remain underneath.\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5), subplot_kw={\"projection\": \"3d\"})\n",
    "_plot_knn_graph_3d(ax[0], X_knn_graph, Xx, Xy, Xz, num_neurons, \"tab:blue\")\n",
    "_plot_knn_graph_3d(ax[1], Y_knn_graph, Yx, Yy, Yz, num_neurons, \"tab:red\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hic sunt leones (\"here be lions\"): experimental scratch work, you can ignore this part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build Laplacian-related quantities for each kNN graph. The names suggest\n",
    "# (adjacency, Laplacian, eigenvalues, eigenvectors) -- NOTE(review): confirm against\n",
    "# build_laplacian (defined elsewhere), including what the positional True flag toggles.\n",
    "XA, XL, Xevals, Xevecs = build_laplacian(X_knn_graph, True)\n",
    "YA, YL, Yevals, Yevecs = build_laplacian(Y_knn_graph, True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve a linear assignment problem (LAP) in the reduced space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "# Match columns of the two reconstructed layers by maximising their total\n",
    "# inner-product similarity.\n",
    "# Alternative objective kept for reference:\n",
    "# _, ci = linear_sum_assignment(U_a.T @ U_b + Vt_a.T @ Vt_b.T, maximize=True)\n",
    "reduced_space_similarity = layer_a_reconstructed.T @ layer_b_reconstructed\n",
    "_, ci = linear_sum_assignment(reduced_space_similarity, maximize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the assignment indices into an explicit permutation matrix (numpy array).\n",
    "ci_tensor = torch.tensor(ci)\n",
    "perm_matrix = perm_indices_to_perm_matrix(ci_tensor).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.shape(perm_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reorder the columns of layer_b_reconstructed: apply perm_matrix to the rows of\n",
    "# its transpose, then transpose back.\n",
    "layer_b_reconstructed_perm = (perm_matrix @ layer_b_reconstructed.T).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _unit_columns(mat, eps=1e-6):\n",
    "    # Scale each column to (approximately) unit L2 norm; eps guards against zero columns.\n",
    "    return mat / (np.linalg.norm(mat, axis=0) + eps)\n",
    "\n",
    "\n",
    "layer_b_recon_perm_norm = _unit_columns(layer_b_reconstructed_perm)\n",
    "layer_a_norm = _unit_columns(layer_a)\n",
    "layer_b_norm = _unit_columns(layer_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum over columns of the dot products between matched, column-normalised layers\n",
    "recon_alignment = np.trace(layer_b_recon_perm_norm.T @ layer_a_norm)\n",
    "recon_alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same score without any permutation, as a baseline\n",
    "baseline_alignment = np.trace(layer_b_norm.T @ layer_a_norm)\n",
    "baseline_alignment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve the LAP in the original space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Similarity between every row of layer_a and every row of layer_b.\n",
    "sim_matrix_orig_space = layer_a @ layer_b.T\n",
    "\n",
    "# Maximise the total similarity. Negating the matrix while also passing\n",
    "# maximize=True would select the least similar assignment instead.\n",
    "_, ci = linear_sum_assignment(sim_matrix_orig_space, maximize=True)\n",
    "perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Permute the rows of layer_b, normalise its columns like layer_a_norm, and score\n",
    "# the alignment as the sum of element-wise products (trace of A B^T).\n",
    "layer_b_perm = perm_matrix @ layer_b\n",
    "layer_b_perm_norm = layer_b_perm / (np.linalg.norm(layer_b_perm, axis=0) + 1e-6)\n",
    "orig_space_alignment = np.trace(layer_a_norm @ layer_b_perm_norm.T)\n",
    "orig_space_alignment"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
