{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Left multiplying by a permutation matrix permutes the rows, right multiplying permutes the columns."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.express as px\n",
    "import torch\n",
    "from scipy.spatial.distance import cdist\n",
    "import numpy as np\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import pygmtools as pygm\n",
    "import functools\n",
    "\n",
    "pygm.set_backend(\"numpy\")\n",
    "\n",
    "\n",
    "def time_decorator(func):\n",
    "    def wrapper(*args, **kwargs):\n",
    "        start_time = time.time()\n",
    "        result = func(*args, **kwargs)\n",
    "        end_time = time.time()\n",
    "        print(f\"Function {func.__name__} took {end_time - start_time} seconds to run.\")\n",
    "        return result\n",
    "\n",
    "    return wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@time_decorator\n",
    "def create_permuted_matrices(n, d, plot=False):\n",
    "    Wa = torch.rand(n * d).reshape(n, d)\n",
    "\n",
    "    gt_perm_ba_indices = torch.randperm(n)\n",
    "    P_BA_gt = torch.eye(n)[gt_perm_ba_indices]\n",
    "\n",
    "    if plot:\n",
    "        fig = px.imshow(P_BA_gt)\n",
    "        fig.show()\n",
    "\n",
    "    Wb = P_BA_gt @ Wa\n",
    "\n",
    "    P_AB_gt = P_BA_gt.T\n",
    "\n",
    "    assert torch.all(P_AB_gt @ Wb == Wa)\n",
    "\n",
    "    return Wa, Wb, P_AB_gt\n",
    "\n",
    "\n",
    "Wa, Wb, P_AB_gt = create_permuted_matrices(n=32, d=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Affinity matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@time_decorator\n",
    "def build_affinity_matrix_inefficient(Wa, Wb):\n",
    "    Wa = torch.tensor(Wa)\n",
    "    Wb = torch.tensor(Wb)\n",
    "\n",
    "    num_neurons = len(Wa)\n",
    "    num_matchings = num_neurons**2\n",
    "\n",
    "    S = torch.zeros((num_matchings, num_matchings))\n",
    "\n",
    "    for xi in tqdm(range(num_neurons)):\n",
    "        p_xi = Wa[xi, :]\n",
    "        for yi in range(num_neurons):\n",
    "            p_yi = Wb[yi, :]\n",
    "            for xj in range(num_neurons):\n",
    "                p_xj = Wa[xj, :]\n",
    "                dx = torch.norm(p_xi - p_xj, p=2)\n",
    "                for yj in range(num_neurons):\n",
    "                    p_yj = Wb[yj, :]\n",
    "                    dy = torch.norm(p_yi - p_yj, p=2)\n",
    "                    S[xi * num_neurons + yi, xj * num_neurons + yj] = torch.exp(-torch.abs(dx - dy) / 1e-2)\n",
    "\n",
    "    return S.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "affinity = build_affinity_matrix_inefficient(Wa, Wb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert np.all(affinity == affinity.T)\n",
    "fig = plt.imshow(affinity)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from backports.strenum import StrEnum\n",
    "from enum import auto\n",
    "\n",
    "\n",
    "class DiagContent(StrEnum):\n",
    "    \"\"\"Enum for diagonal content of affinity matrix\"\"\"\n",
    "\n",
    "    ONES = auto()\n",
    "    SIMILARITIES = auto()\n",
    "\n",
    "\n",
    "@time_decorator\n",
    "def build_affinity_matrix_vectorized(Wa, Wb, diag: DiagContent):\n",
    "    Wa = torch.tensor(Wa)\n",
    "    Wb = torch.tensor(Wb)\n",
    "\n",
    "    num_neurons = Wa.size(0)\n",
    "    num_matchings = num_neurons**2\n",
    "\n",
    "    # Compute all pairwise Euclidean distances for Wa and Wb\n",
    "    Wa_distances = torch.cdist(Wa, Wa, p=2)\n",
    "    Wb_distances = torch.cdist(Wb, Wb, p=2)\n",
    "\n",
    "    # Prepare the distance matrices for broadcasting\n",
    "    Wa_distances = Wa_distances.view(num_neurons, 1, num_neurons, 1).expand(-1, num_neurons, -1, num_neurons)\n",
    "    Wb_distances = Wb_distances.view(1, num_neurons, 1, num_neurons).expand(num_neurons, -1, num_neurons, -1)\n",
    "\n",
    "    S = torch.exp(-torch.abs(Wa_distances - Wb_distances) / 1e-2)\n",
    "\n",
    "    S = S.reshape(num_matchings, num_matchings)\n",
    "\n",
    "    if diag == DiagContent.ONES:\n",
    "        diag_matrix = torch.eye(num_matchings)\n",
    "\n",
    "    elif diag == DiagContent.SIMILARITIES:\n",
    "        Wa_Wb_distances = torch.cdist(Wa, Wb, p=2)\n",
    "        Wa_Wb_sim = torch.exp(-torch.abs(Wa_Wb_distances) / 1e-2)\n",
    "\n",
    "        Wa_Wb_sim = Wa_Wb_sim.reshape(num_matchings, 1)\n",
    "\n",
    "        diag_matrix = torch.diag(Wa_Wb_sim.squeeze())\n",
    "\n",
    "    mask = torch.eye(num_matchings, dtype=torch.bool)\n",
    "    S[mask] = 0\n",
    "\n",
    "    S = S + diag_matrix\n",
    "\n",
    "    return S.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diag_content = DiagContent.ONES\n",
    "affinity_vectorized = build_affinity_matrix_vectorized(Wa, Wb, diag=diag_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(affinity_vectorized)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "assert np.all(np.abs(affinity_vectorized.T - affinity_vectorized) < 5e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if diag_content == DiagContent.ONES:\n",
    "    assert np.all(np.abs(np.diag(affinity_vectorized) - 1) < 5e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@time_decorator\n",
    "def get_principal_eigenvector(M):\n",
    "    num_neurons = torch.sqrt(torch.tensor(M.size(0))).int()\n",
    "\n",
    "    values, vectors = torch.linalg.eigh(M)\n",
    "    principal_eigenvector = vectors[:, torch.argmax(values)]\n",
    "\n",
    "    principal_eigenvector = principal_eigenvector.reshape(num_neurons, num_neurons)\n",
    "    principal_eigenvector = torch.abs(principal_eigenvector)\n",
    "\n",
    "    return principal_eigenvector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "principal_eigenvector = get_principal_eigenvector(torch.tensor(affinity_vectorized))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_aa = torch.cdist(Wa, Wa, p=2)\n",
    "dist_bb = torch.cdist(Wb, Wb, p=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_neurons = len(Wa)\n",
    "\n",
    "dist_aa_batched = np.expand_dims(dist_aa, axis=0)\n",
    "dist_bb_batched = np.expand_dims(dist_bb, axis=0)\n",
    "num_neurons_batched = np.expand_dims(num_neurons, axis=0)\n",
    "\n",
    "conn1, edge1, ne1 = pygm.utils.dense_to_sparse(dist_aa_batched)\n",
    "conn2, edge2, ne2 = pygm.utils.dense_to_sparse(dist_bb_batched)\n",
    "\n",
    "gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1)\n",
    "inner_prod_aff_fn = pygm.utils.inner_prod_aff_fn\n",
    "\n",
    "K = pygm.utils.build_aff_mat(\n",
    "    node_feat1=None,\n",
    "    edge_feat1=edge1,\n",
    "    connectivity1=conn1,\n",
    "    node_feat2=None,\n",
    "    edge_feat2=edge2,\n",
    "    connectivity2=conn2,\n",
    "    n1=num_neurons_batched,\n",
    "    ne1=None,\n",
    "    n2=num_neurons_batched,\n",
    "    ne2=None,\n",
    "    edge_aff_fn=inner_prod_aff_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = pygm.sm(K, num_neurons_batched, num_neurons_batched).squeeze()\n",
    "\n",
    "X = pygm.hungarian(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(X)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matching_accuracy = (P_AB_gt * X).sum() / num_neurons\n",
    "matching_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(principal_eigenvector)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(P_AB_gt)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import Tensor\n",
    "\n",
    "\n",
    "@time_decorator\n",
    "def extract_matching_leordeanu(principal_eigenvector: Tensor):\n",
    "    \"\"\"\n",
    "    principal_eigenvector: shape (num_neurons, num_neurons)\n",
    "    \"\"\"\n",
    "\n",
    "    num_neurons = principal_eigenvector.shape[0]\n",
    "\n",
    "    # Initialize the solution vector\n",
    "    x = torch.zeros((num_neurons, num_neurons)).type_as(principal_eigenvector).long()\n",
    "\n",
    "    # Initialize masks for rows and columns\n",
    "    row_mask = torch.ones(num_neurons, dtype=torch.bool)\n",
    "    col_mask = torch.ones(num_neurons, dtype=torch.bool)\n",
    "\n",
    "    while True:\n",
    "        # Apply masks to principal eigenvector\n",
    "        masked_principal_eigenvector = principal_eigenvector.clone()\n",
    "        masked_principal_eigenvector[~row_mask, :] = 0\n",
    "        masked_principal_eigenvector[:, ~col_mask] = 0\n",
    "\n",
    "        # Find the maximum value and its index\n",
    "        flat_index = masked_principal_eigenvector.argmax()\n",
    "\n",
    "        i, j = np.unravel_index(flat_index.item(), (num_neurons, num_neurons))\n",
    "\n",
    "        assignment_value = masked_principal_eigenvector[(i, j)]\n",
    "        if assignment_value == 0:\n",
    "            break\n",
    "\n",
    "        # Update the solution vector\n",
    "        x[i, j] = 1\n",
    "\n",
    "        # Update the masks to exclude row i and column j\n",
    "        row_mask[i] = False\n",
    "        col_mask[j] = False\n",
    "\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "P_AB_leordeanu = extract_matching_leordeanu(principal_eigenvector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import linear_sum_assignment\n",
    "import scipy\n",
    "\n",
    "\n",
    "def extract_matching_lap(principal_eigenvector):\n",
    "    num_neurons = principal_eigenvector.shape[0]\n",
    "\n",
    "    principal_eigenvector = principal_eigenvector.cpu().numpy()\n",
    "\n",
    "    row_ind, col_ind = linear_sum_assignment(principal_eigenvector.max() - principal_eigenvector)\n",
    "    P_AB = scipy.sparse.coo_matrix(\n",
    "        (np.ones(num_neurons), (row_ind, col_ind)), shape=(num_neurons, num_neurons)\n",
    "    ).toarray()\n",
    "\n",
    "    return torch.tensor(P_AB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "P_AB_lap = extract_matching_lap(principal_eigenvector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.all(P_AB_leordeanu == P_AB_lap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.all(P_AB_leordeanu == P_AB_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EigenvectorPostprocess(StrEnum):\n",
    "    LEORDEANU = auto()\n",
    "    LAP = auto()\n",
    "\n",
    "\n",
    "def compare_matching_algorithms(num_neurons, dim, eigenvec_postprocess: EigenvectorPostprocess):\n",
    "    Wa, Wb, P_AB_gt = create_permuted_matrices(num_neurons, dim)\n",
    "\n",
    "    affinity = build_affinity_matrix_vectorized(Wa, Wb, diag=DiagContent.ONES)\n",
    "    affinity = torch.tensor(affinity).cuda()\n",
    "\n",
    "    principal_eigenvector = get_principal_eigenvector(affinity)\n",
    "\n",
    "    if eigenvec_postprocess == EigenvectorPostprocess.LEORDEANU:\n",
    "        P_AB = extract_matching_leordeanu(principal_eigenvector)\n",
    "    elif eigenvec_postprocess == EigenvectorPostprocess.LAP:\n",
    "        P_AB = extract_matching_lap(principal_eigenvector)\n",
    "\n",
    "    matching_accuracy = (P_AB_gt.cuda() * P_AB.cuda()).sum() / num_neurons\n",
    "\n",
    "    assert torch.all(P_AB.cpu().float() @ Wb.cpu() == Wa.cpu())\n",
    "\n",
    "    return matching_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_accuracies = []\n",
    "for i in tqdm(range(20)):\n",
    "    acc = compare_matching_algorithms(128, 256, EigenvectorPostprocess.LEORDEANU)\n",
    "    print(acc)\n",
    "\n",
    "    all_accuracies.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot all the accuracies\n",
    "fig = px.histogram([acc.cpu().numpy() for acc in all_accuracies])\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
