{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bfdb60a-8800-4d4e-b346-d0222e9a88d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from scipy.stats import pearsonr\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9c8fdd-83e6-4521-975c-7f7631f5919b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = TUDataset('data', name='SW-620H')\n",
    "dataset = TUDataset('data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9e9023-2593-4276-bd7f-6c85ddc3dfb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "length_dict = {}\n",
    "\n",
    "for graph_num in range(len(dataset)):\n",
    "  length = len(dataset[graph_num].x)\n",
    "\n",
    "  if length in length_dict: length_dict[length].append(graph_num)\n",
    "  else: length_dict[length] = [graph_num]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc655c8-c29f-43ad-bc09-d37549b40b62",
   "metadata": {},
   "outputs": [],
   "source": [
    "length_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182ee690-0520-404e-9745-269481161813",
   "metadata": {},
   "outputs": [],
   "source": [
    "pairs = []\n",
    "for value in length_dict.values():\n",
    "    # print(value)\n",
    "    # print(len(value))\n",
    "    if len(value) >= 2:\n",
    "        for i in range(len(value)):\n",
    "            for j in range(i+1, len(value)):\n",
    "                pair = [value[i], value[j]]\n",
    "                pairs.append(pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b66ff58-66f0-4638-9d27-37c4dd526d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_eigenvalues_sum(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    # Take log of eigenvalues, square them, and sum\n",
    "    eigenvalues_sum = np.sum(np.log(eigenvalues)**2)\n",
    "    \n",
    "    return eigenvalues_sum\n",
    "\n",
    "def calculate_eigenvalues(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    return eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f49ccbc-a8c0-4153-b2d8-eb3657b5b37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import logm, norm\n",
    "\n",
    "def log_euclidean_distance(A, B):\n",
    "    # Compute the matrix logarithms of A and B\n",
    "    log_A = logm(A)\n",
    "    log_B = logm(B)\n",
    "    # print(log_A)\n",
    "    # print(log_B)\n",
    "    \n",
    "    # Compute the Frobenius norm of the difference between log_A and log_B\n",
    "    distance = norm(log_A - log_B, 'fro')\n",
    "    \n",
    "    return distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f181c43-0be9-4431-862a-cefc8c90c514",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "le_vals = []\n",
    "ai_vals = []\n",
    "\n",
    "import random\n",
    "\n",
    "# random.shuffle(pairs)\n",
    "for pair in pairs:\n",
    "    g1 = dataset[pair[0]]\n",
    "    g2 = dataset[pair[1]]\n",
    "\n",
    "    edge_tensor = g1.edge_index\n",
    "    edge_list = edge_tensor.numpy()\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "\n",
    "    # nnl1 = normalized_laplacian(adj_matrix)\n",
    "    \n",
    "    degree_matrix = np.diag(np.sum(adj_matrix, axis=1))\n",
    "    l1 = degree_matrix - adj_matrix\n",
    "    \n",
    "    edge_tensor = g2.edge_index\n",
    "    edge_list = edge_tensor.numpy()\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "\n",
    "    # nnl2 = normalized_laplacian(adj_matrix)\n",
    "    \n",
    "    degree_matrix = np.diag(np.sum(adj_matrix, axis=1))\n",
    "    l2 = degree_matrix - adj_matrix\n",
    "    \n",
    "    shape = l1.shape\n",
    "    a = 0.1  # You can change this to any value you want\n",
    "    diagonal_matrix = np.diag([a] * min(shape))\n",
    "    nl1 = l1 + diagonal_matrix\n",
    "    \n",
    "    shape = l2.shape\n",
    "    # a = 0.0001  # You can change this to any value you want\n",
    "    diagonal_matrix = np.diag([a] * min(shape))\n",
    "    nl2 = l2 + diagonal_matrix\n",
    "\n",
    "    # nnll1 = nnl1 + diagonal_matrix\n",
    "    # nnll2 = nnl2 + diagonal_matrix\n",
    "    \n",
    "    # ai_vals.append(calculate_eigenvalues_sum(nl1, nl2))\n",
    "    le_vals.append(log_euclidean_distance(nl1, nl2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0dedf81-b165-4ab3-94c4-d1efd6b19bd7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.scatter(ai_vals, le_vals, color='blue', label='Data Points', s=10)\n",
    "# plt.title('Scatter Plot of tmd_vals vs riem_vals')\n",
    "plt.xlabel('GGD using AI metric')\n",
    "plt.ylabel('GGD using LE metric')\n",
    "# plt.legend()\n",
    "# plt.grid(True)\n",
    "# plt.show()\n",
    "plt.savefig('aivsle.png', dpi= 600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771daffb-9ccc-4003-87d3-900ae717d1d0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58fc1275-394c-4e46-b72c-844b79794f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "# Assuming ai_vals and le_vals are your arrays with complex numbers\n",
    "# Extract the real parts of both arrays\n",
    "ai_vals_real = np.real(ai_vals)\n",
    "le_vals_real = np.real(le_vals)\n",
    "\n",
    "# Calculate the Pearson correlation using the real parts\n",
    "correlation_coefficient, p_value = pearsonr(ai_vals_real, le_vals_real)\n",
    "\n",
    "print(\"Pearson correlation coefficient:\", correlation_coefficient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6948028-643b-4f7f-bde5-062ca6c1dc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a = 0.00001  # You can change this to any value you want\n",
    "# diagonal_matrix = np.diag([a] * 6)\n",
    "# print(diagonal_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d045fa17-4f66-40b5-bcd6-674d757a9041",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
