{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "72f3cf56-a77a-4663-82c2-8590eb125705",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from scipy.stats import pearsonr\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "917efc0d-1292-4b30-86d5-2c8a045edf82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MUTAG is the active benchmark; use name='SW-620H' instead to run on that TU dataset.\n",
    "dataset = TUDataset(root='data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e069d9e7-c40d-407e-8337-ad91f8c1ee9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_eigenvalues_sum(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    # Take log of eigenvalues, square them, and sum\n",
    "    eigenvalues_sum = np.sum(np.log(eigenvalues)**2)\n",
    "    \n",
    "    return eigenvalues_sum\n",
    "\n",
    "def calculate_eigenvalues(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    return eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7826ce1-5cb0-4ffc-b9d2-691383ad6c3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def edgeindexToAdjmat(edge_tensor):\n",
    "    edge_lis = edge_tensor.numpy()\n",
    "    edge_list = edge_lis[:, (edge_lis < 100).all(axis=0)]\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    # Create an empty adjacency matrix\n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    # Populate the adjacency matrix based on the edges\n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "\n",
    "    return adj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea4901a-dd8f-4e0e-a5a0-2f3199813e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grampa(A, B, eta):\n",
    "    n = A.shape[0]\n",
    "\n",
    "    Lambda, U = eig(A)\n",
    "    Mu, V = eig(B) # eigen vectors and eigen values\n",
    "\n",
    "    temp = (np.tile(Lambda, (n, 1)) - np.tile(Mu, (n, 1)).T)**2 + eta**2 # (lambda - mu').^2 + eta^2;\n",
    "\n",
    "    coeff = 1 / temp\n",
    "    coeff = coeff * (U.T @ np.ones((n, n)) @ V) # coeff .* (U' * ones(n) * V);\n",
    "    X = U @ coeff @ V.T\n",
    "    Y = np.real(X.T) # linear_sum_assignment doesnt work on complex values\n",
    "\n",
    "    row_ind, col_ind = linear_sum_assignment(Y, maximize = True) # cost maximization\n",
    "    P = np.zeros((n, n))\n",
    "    P[row_ind, col_ind] = 1\n",
    "\n",
    "    return P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4cc7231-4c74-42eb-aacc-e31fb2e07895",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_ids(line):\n",
    "    \"\"\"Parse one comma-separated line of integer node ids; a blank line gives [].\"\"\"\n",
    "    return [int(tok) for tok in line.strip().split(',') if tok.strip()]\n",
    "\n",
    "\n",
    "def _edges_to_adjacency(src, dst, num_nodes):\n",
    "    \"\"\"Return a (num_nodes, num_nodes) int 0/1 matrix with adj[src[k], dst[k]] = 1.\n",
    "\n",
    "    Raises ValueError if src and dst differ in length or an id is >= num_nodes.\n",
    "    \"\"\"\n",
    "    src = np.asarray(src, dtype=int)\n",
    "    dst = np.asarray(dst, dtype=int)\n",
    "    if src.shape != dst.shape:\n",
    "        raise ValueError(f'source/target id lists differ in length: {src.size} vs {dst.size}')\n",
    "    if src.size and max(src.max(), dst.max()) >= num_nodes:\n",
    "        raise ValueError(f'edge list references node ids >= {num_nodes}')\n",
    "    adj = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    adj[src, dst] = 1\n",
    "    return adj\n",
    "\n",
    "\n",
    "def _regularized_laplacian(adj, reg):\n",
    "    \"\"\"Return D - adj + reg * I, where D is the diagonal matrix of row sums of adj.\"\"\"\n",
    "    return np.diag(adj.sum(axis=1)) - adj + reg * np.eye(adj.shape[0])\n",
    "\n",
    "\n",
    "def ggd(num_1, num_2, eta=0.35, reg=0.0001):\n",
    "    \"\"\"Spectral distance between two graphs of the global `dataset`.\n",
    "\n",
    "    The smaller graph (fewer rows in .x) is compared with a coarsened copy of\n",
    "    the larger one read from 'coarsening/output/<n>_mutag_coarsened.txt', where\n",
    "    <n> is the smaller node count. Lines 2*k and 2*k + 1 of that file hold the\n",
    "    comma-separated source and target ids of graph k, with k the larger\n",
    "    graph's dataset index. The coarsened graph is aligned to the smaller one\n",
    "    with grampa, both Laplacians are shifted by reg * I (a Laplacian is\n",
    "    singular), and calculate_eigenvalues_sum of the pair is returned.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_1, num_2 : int\n",
    "        Indices into the global `dataset`.\n",
    "    eta : float, default 0.35\n",
    "        Regularization passed to grampa.\n",
    "    reg : float, default 0.0001\n",
    "        Diagonal shift added to both Laplacians before comparison.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a coarsened edge references a node id >= the smaller node count.\n",
    "    \"\"\"\n",
    "    g1, g2 = dataset[num_1], dataset[num_2]\n",
    "    line_index = num_2  # dataset index of the larger graph (its coarsened copy is read)\n",
    "    if len(g1.x) > len(g2.x):\n",
    "        g1, g2 = g2, g1\n",
    "        line_index = num_1\n",
    "    length = len(g1.x)\n",
    "\n",
    "    # Size by node count rather than max edge id so isolated nodes are kept and\n",
    "    # both matrices have the same shape, which grampa requires.\n",
    "    edges = g1.edge_index.numpy()\n",
    "    adj_small = _edges_to_adjacency(edges[0], edges[1], length)\n",
    "\n",
    "    file_path = 'coarsening/output/' + str(length) + '_mutag_coarsened.txt'\n",
    "    with open(file_path, 'r') as file:\n",
    "        lines = file.readlines()\n",
    "    src = _parse_ids(lines[line_index * 2])\n",
    "    dst = _parse_ids(lines[line_index * 2 + 1])\n",
    "    # Assumes the coarsened graph has exactly `length` nodes, as the file name suggests.\n",
    "    adj_coarse = _edges_to_adjacency(src, dst, length)\n",
    "\n",
    "    # grampa sets P[b, a] = 1 for node a of adj_small matched to node b of\n",
    "    # adj_coarse, so P.T @ adj_coarse @ P puts the coarsened graph in adj_small's\n",
    "    # node order before the two are compared.\n",
    "    P = grampa(adj_small, adj_coarse, eta)\n",
    "    adj_coarse_aligned = P.T @ adj_coarse @ P\n",
    "\n",
    "    nl1 = _regularized_laplacian(adj_small, reg)\n",
    "    nl2 = _regularized_laplacian(adj_coarse_aligned, reg)\n",
    "    return calculate_eigenvalues_sum(nl1, nl2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba429b69-afc5-4bc7-8a48-ffece4364c96",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "41d66804-1933-4dee-99ff-27c7de0c6cbc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cdd6fcd-d198-4d06-a289-94c4149089fb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f8b7b21-e7c7-4685-92f1-55c7b3726a24",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9be7b6f-296e-489d-8faa-7dbdf7f7893e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
