{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e7d4774-f42b-46da-9862-106fa8a04611",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "import numpy as np\n",
    "from scipy.linalg import eig\n",
    "from scipy.optimize import linear_sum_assignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d67e121-b411-4acf-a81e-a8246da3ee48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MUTAG graphs through TUDataset (root directory: 'data').\n",
    "# Alternative dataset used during experimentation: name='FIRSTMM_DB'.\n",
    "dataset = TUDataset(root='data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d28d6363-0951-4a09-aec6-7836809fbdd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# node_num = 10\n",
    "# for i in range(len(dataset)):\n",
    "#     if len(dataset[i].x) == node_num: print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76de7f66-4d37-44b4-a568-c02d1742d3db",
   "metadata": {},
   "outputs": [],
   "source": [
    "def edgeindexToAdjmat(edge_tensor, num_nodes=None, max_nodes=100):\n",
    "    '''Convert a (2, E) COO edge index into a dense 0/1 adjacency matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    edge_tensor : (2, E) integer tensor or array_like\n",
    "        Row 0 holds source node indices, row 1 target node indices\n",
    "        (e.g. dataset[i].edge_index). CPU torch tensors and numpy arrays\n",
    "        are both accepted.\n",
    "    num_nodes : int, optional\n",
    "        Side length of the returned matrix. Pass data.num_nodes so that\n",
    "        isolated nodes with the highest indices are kept; by default the\n",
    "        size is inferred as max(edge index) + 1, which drops them.\n",
    "    max_nodes : int or None, default 100\n",
    "        Edges touching any node index >= max_nodes are discarded and the\n",
    "        matrix is capped at max_nodes x max_nodes. The default reproduces\n",
    "        the previously hard-coded crop; pass None to keep every edge.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray of int, shape (N, N)\n",
    "        adj[s, t] == 1 iff the (cropped) edge list contains the edge s -> t.\n",
    "        An empty edge list yields a (0, 0) matrix unless num_nodes is given.\n",
    "    '''\n",
    "    edges = np.asarray(edge_tensor)\n",
    "\n",
    "    if max_nodes is not None:\n",
    "        # keep only edges whose endpoints are both below the cap\n",
    "        edges = edges[:, (edges < max_nodes).all(axis=0)]\n",
    "        if num_nodes is not None:\n",
    "            num_nodes = min(num_nodes, max_nodes)\n",
    "\n",
    "    if num_nodes is None:\n",
    "        # max() raises on an empty array, so guard the edgeless case\n",
    "        num_nodes = int(edges.max()) + 1 if edges.size else 0\n",
    "\n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    # vectorised scatter instead of a Python loop; repeated edges simply\n",
    "    # write 1 into the same cell again\n",
    "    adj_matrix[edges[0], edges[1]] = 1\n",
    "    return adj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcd3c00f-ae69-4a1d-8b41-b25ff7ff86ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grampa(A, B, eta, return_similarity=False):\n",
    "    '''Align two equal-size undirected graphs with GRAMPA spectral matching.\n",
    "\n",
    "    Builds the GRAMPA similarity matrix\n",
    "\n",
    "        X = sum_{i,j} w(lambda_i, mu_j) * u_i u_i^T J v_j v_j^T,\n",
    "        w(x, y) = 1 / ((x - y)^2 + eta^2),\n",
    "\n",
    "    where (lambda_i, u_i) and (mu_j, v_j) are the eigenpairs of A and B and\n",
    "    J is the all-ones matrix, then rounds X^T to a permutation matrix with a\n",
    "    maximum-weight linear assignment.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    A, B : array_like, shape (n, n)\n",
    "        Symmetric adjacency matrices of the two graphs.\n",
    "    eta : float\n",
    "        Width of the spectral weight w; should be non-zero, otherwise\n",
    "        coinciding eigenvalues divide by zero.\n",
    "    return_similarity : bool, default False\n",
    "        If True, also return the matrix that was rounded (X^T).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    P : ndarray of float, shape (n, n)\n",
    "        Permutation matrix with P[b, a] == 1 when node b of B is matched to\n",
    "        node a of A, so P @ A @ P.T approximates B for a good matching.\n",
    "    S : ndarray, shape (n, n)\n",
    "        Only when return_similarity is True: X^T, rows indexed by nodes of B.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If A and B are not square, of equal size, and symmetric.\n",
    "    '''\n",
    "    A = np.asarray(A, dtype=float)\n",
    "    B = np.asarray(B, dtype=float)\n",
    "    if A.ndim != 2 or A.shape[0] != A.shape[1] or A.shape != B.shape:\n",
    "        raise ValueError(f'A and B must be square and of equal size, got {A.shape} and {B.shape}')\n",
    "    if not (np.allclose(A, A.T) and np.allclose(B, B.T)):\n",
    "        raise ValueError('A and B must be symmetric (undirected graphs)')\n",
    "    n = A.shape[0]\n",
    "\n",
    "    # eigh instead of eig: for symmetric input it returns real eigenvalues and an\n",
    "    # orthonormal eigenbasis. eig gives no orthogonality guarantee inside repeated\n",
    "    # eigenvalues (common for adjacency matrices), so U.T would not invert U.\n",
    "    lam, U = np.linalg.eigh(A)\n",
    "    mu, V = np.linalg.eigh(B)\n",
    "\n",
    "    # weight[i, j] = w(lambda_i, mu_j), i.e. MATLAB 1 ./ ((lambda - mu').^2 + eta^2).\n",
    "    # The row index must follow A's eigenvalues to line up with U.T @ J @ V below.\n",
    "    weight = 1.0 / ((lam[:, None] - mu[None, :]) ** 2 + eta ** 2)\n",
    "    # U.T @ J @ V equals the outer product of the column sums of U and V.\n",
    "    coeff = weight * np.outer(U.sum(axis=0), V.sum(axis=0))\n",
    "    X = U @ coeff @ V.T  # rows: nodes of A, columns: nodes of B\n",
    "\n",
    "    similarity = X.T\n",
    "    row_ind, col_ind = linear_sum_assignment(similarity, maximize=True)\n",
    "    P = np.zeros((n, n))\n",
    "    P[row_ind, col_ind] = 1\n",
    "\n",
    "    if return_similarity:\n",
    "        return P, similarity\n",
    "    return P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4b5a27-dbe0-45f7-a0ff-a94d85951e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Benchmark: adjacency construction for graphs 0 and 7 plus GRAMPA with eta = 0.2.\n",
    "eta = 0.2\n",
    "graph_a = edgeindexToAdjmat(dataset[0].edge_index)\n",
    "graph_b = edgeindexToAdjmat(dataset[7].edge_index)\n",
    "P = grampa(graph_a, graph_b, eta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474dbf05-743f-47d0-9e52-813c99c43caa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "adj_mat_A = edgeindexToAdjmat(dataset[0].edge_index)\n",
    "adj_mat_B = edgeindexToAdjmat(dataset[7].edge_index)\n",
    "\n",
    "eta = 0.5\n",
    "P = grampa(adj_mat_A, adj_mat_B, eta)\n",
    "\n",
    "# Show the matching as one index per row instead of dumping the full n x n\n",
    "# permutation matrix: entry b is the column holding the 1 in row b of P.\n",
    "print(P.argmax(axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b748923e-2579-4e28-8261-d3b18b22c618",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Benchmark: adjacency construction for graphs 0 and 7 plus GRAMPA with eta = 0.5.\n",
    "P = grampa(edgeindexToAdjmat(dataset[0].edge_index),\n",
    "           edgeindexToAdjmat(dataset[7].edge_index),\n",
    "           0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1052ea-ef1f-402d-bf69-f0ffeb935f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# NOTE(review): repeats the eta = 0.5 benchmark of the previous cell; one copy is enough.\n",
    "adj_mat_A, adj_mat_B = (edgeindexToAdjmat(dataset[k].edge_index) for k in (0, 7))\n",
    "P = grampa(adj_mat_A, adj_mat_B, eta=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "113867fc-546f-49d6-861f-bd13097d54b5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9 (torch)",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
