{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.express as px\n",
    "import torch\n",
    "from scipy.spatial.distance import cdist\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.arange(3 * 2).reshape(3, 2).float()\n",
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "d = 6\n",
    "weight_matrix = torch.arange(n * d).reshape(n, d).float()\n",
    "# weight_matrix = torch.rand(n*d).reshape(n, d)\n",
    "weight_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gt_perm_indices = torch.randperm(n)\n",
    "gt_perm_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Wa = weight_matrix[gt_perm_indices]\n",
    "Wa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gt_perm_ba_indices = torch.randperm(n)\n",
    "gt_perm_ba_indices\n",
    "P_BA_gt = torch.eye(n)[gt_perm_ba_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(P_BA_gt)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Wb = P_BA_gt.T @ Wa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Wb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.all(P_BA_gt @ Wb == Wa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_matrix = Wa @ Wb.T\n",
    "fig = plt.imshow(sim_matrix)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_aa = torch.cdist(Wa, Wa)\n",
    "dist_bb = torch.cdist(Wb, Wb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig1 = plt.figure()\n",
    "fig1 = plt.imshow(dist_aa)\n",
    "fig2 = plt.figure()\n",
    "fig2 = plt.imshow(dist_bb)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "ri, ci = linear_sum_assignment(sim_matrix.detach().numpy(), maximize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.matching_utils import perm_indices_to_perm_matrix\n",
    "\n",
    "\n",
    "perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(perm_matrix)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.all(perm_matrix == P_BA_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var_percentage = 5\n",
    "K = 10\n",
    "k = 1\n",
    "\n",
    "var_a = dist_aa.max() * var_percentage\n",
    "var_b = dist_bb.max() * var_percentage\n",
    "assert var_a > 0 and var_b > 0\n",
    "\n",
    "print(var_a, var_b)\n",
    "dist_bb_perm = dist_bb @ perm_matrix.T\n",
    "\n",
    "kernel_A = torch.exp(-(dist_aa).pow(2) / (2 * var_a))\n",
    "kernel_B = torch.exp(-(dist_bb_perm).pow(2) / (2 * var_b))\n",
    "\n",
    "new_sim_matrix = kernel_A @ kernel_B.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(kernel_A)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel_A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_perm_kernel_B = torch.exp(-(dist_bb).pow(2) / (2 * var_b))\n",
    "fig = plt.imshow(pre_perm_kernel_B)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(kernel_B)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel_B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.imshow(new_sim_matrix)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_sim_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
