{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3ff4f1-fa66-4a59-ac6d-8e8ac0e22330",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from scipy.stats import pearsonr\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbe1dbda-643e-4e1c-982f-1c3b3f4845ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TUDataset('data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee8f1a9-2194-4e78-a11f-687bf24108dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "length_dict = {}\n",
    "\n",
    "for graph_num in range(len(dataset)):\n",
    "  length = len(dataset[graph_num].x)\n",
    "\n",
    "  if length in length_dict: length_dict[length].append(graph_num)\n",
    "  else: length_dict[length] = [graph_num]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c77494c6-71a0-4749-9150-49bed5df2a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# length_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2957a5b5-fd62-4797-8334-ba2544e424fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "pairs = []\n",
    "for value in length_dict.values():\n",
    "    # print(value)\n",
    "    # print(len(value))\n",
    "    if len(value) >= 2:\n",
    "        for i in range(len(value)):\n",
    "            for j in range(i+1, len(value)):\n",
    "                pair = [value[i], value[j]]\n",
    "                pairs.append(pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e867fd66-5983-4784-9eb7-b7f3a9e7483a",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "945852f3-df59-482d-b0b9-096a6d1d78ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_eigenvalues_sum(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    # Take log of eigenvalues, square them, and sum\n",
    "    eigenvalues_sum = np.sum(np.log(eigenvalues)**2)\n",
    "    \n",
    "    return eigenvalues_sum\n",
    "\n",
    "def calculate_eigenvalues(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    return eigenvalues\n",
    "\n",
    "def calculate_partial_eigenvalues_sum(a, b, n=2):\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    sorted_eigenvalues = sorted(eigenvalues, reverse=True)\n",
    "    ind = int(n/2)\n",
    "    trimed_eigenvalue_list_1 = sorted_eigenvalues[:ind]\n",
    "    trimed_eigenvalue_list_2 = sorted_eigenvalues[-1*ind:]\n",
    "    return np.sum(np.log(trimed_eigenvalue_list_1)**2) + np.sum(np.log(trimed_eigenvalue_list_2)**2)\n",
    "\n",
    "\n",
    "def calculate_partial(a, b, n=2):\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    logvalues = np.log(eigenvalues)\n",
    "    log2values = logvalues**2\n",
    "    sorted_vales = sorted(log2values, reverse = True)\n",
    "    trimed = sorted_vales[:n]\n",
    "    return np.sum(trimed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3209ebd8-d669-4edc-9d56-cc7b90d9a6d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "riem_vals = []\n",
    "partial_vals_2 = []\n",
    "\n",
    "\n",
    "for pair in pairs:\n",
    "    g1 = dataset[pair[0]]\n",
    "    g2 = dataset[pair[1]]\n",
    "\n",
    "    edge_tensor = g1.edge_index\n",
    "    edge_list = edge_tensor.numpy()\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "    \n",
    "    degree_matrix = np.diag(np.sum(adj_matrix, axis=1))\n",
    "    l1 = degree_matrix - adj_matrix\n",
    "    \n",
    "    edge_tensor = g2.edge_index\n",
    "    edge_list = edge_tensor.numpy()\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "    \n",
    "    degree_matrix = np.diag(np.sum(adj_matrix, axis=1))\n",
    "    l2 = degree_matrix - adj_matrix\n",
    "    \n",
    "    shape = l1.shape\n",
    "    a = 0.0001  # You can change this to any value you want\n",
    "    diagonal_matrix = np.diag([a] * min(shape))\n",
    "    nl1 = l1 + diagonal_matrix\n",
    "    \n",
    "    shape = l2.shape\n",
    "    # a = 0.0001  # You can change this to any value you want\n",
    "    diagonal_matrix = np.diag([a] * min(shape))\n",
    "    nl2 = l2 + diagonal_matrix\n",
    "    \n",
    "    riem_vals.append(calculate_eigenvalues_sum(nl1, nl2))\n",
    "    partial_vals_2.append(calculate_partial_eigenvalues_sum(nl1, nl2, 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da66cd5b-032c-4f24-b883-ad383ffbbb9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "\n",
    "# Extract real parts\n",
    "riem_vals_real = [np.real(val) for val in riem_vals]\n",
    "partial_vals_2_real = [np.real(val) for val in partial_vals_2]\n",
    "\n",
    "# Scatter plot\n",
    "plt.scatter(riem_vals_real, partial_vals_2_real, color='blue', label='Data')\n",
    "plt.xlabel('riem distance')\n",
    "plt.ylabel('riem distance with 2 eigenvalues')\n",
    "plt.title('mutag')\n",
    "plt.show()\n",
    "\n",
    "# Calculate Pearson correlation coefficient\n",
    "correlation_coefficient, _ = pearsonr(riem_vals_real, partial_vals_2_real)\n",
    "print(\"Pearson correlation coefficient:\", correlation_coefficient)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b60f8b1-9214-467a-9e3e-12c79d97c811",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "\n",
    "# Extract real parts\n",
    "riem_vals_real = [np.real(val) for val in riem_vals]\n",
    "partial_vals_2_real = [np.real(val) for val in partial_vals_2]\n",
    "\n",
    "# Scatter plot\n",
    "plt.scatter(riem_vals_real, partial_vals_2_real, color='blue', label='Data')\n",
    "plt.xlabel('riem distance')\n",
    "plt.ylabel('riem distance with 2 eigenvalues')\n",
    "plt.title('pc-3h')\n",
    "plt.show()\n",
    "\n",
    "# Calculate Pearson correlation coefficient\n",
    "correlation_coefficient, _ = pearsonr(riem_vals_real, partial_vals_2_real)\n",
    "print(\"Pearson correlation coefficient:\", correlation_coefficient)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa814ac2-94ea-4a5d-b61b-c8a516504d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "\n",
    "# Extract real parts\n",
    "riem_vals_real = [np.real(val) for val in riem_vals]\n",
    "partial_vals_2_real = [np.real(val) for val in partial_vals_2]\n",
    "\n",
    "# Scatter plot\n",
    "plt.scatter(riem_vals_real, partial_vals_2_real, color='blue', label='Data')\n",
    "plt.xlabel('riem distance')\n",
    "plt.ylabel('riem distance with 2 eigenvalues')\n",
    "plt.title('pc-3h')\n",
    "plt.show()\n",
    "\n",
    "# Calculate Pearson correlation coefficient\n",
    "correlation_coefficient, _ = pearsonr(riem_vals_real, partial_vals_2_real)\n",
    "print(\"Pearson correlation coefficient:\", correlation_coefficient)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415d5021-7652-427e-b6b1-a5866b55ff50",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
