{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2ff281a-6cd9-4875-ac3a-1463ac6db9c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch_geometric\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "from scipy.stats import pearsonr\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from tqdm import tqdm\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d633dd0e-e788-4246-91e1-145ddbca09f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MUTAG graph-classification data from the TU collection, stored under ./data.\n",
    "dataset = TUDataset(root='data', name='MUTAG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0f44246-67cb-4151-bc34-185b91e7dca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_eigenvalues_sum(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    # Take log of eigenvalues, square them, and sum\n",
    "    eigenvalues_sum = np.sum(np.log(eigenvalues)**2)\n",
    "    \n",
    "    return eigenvalues_sum\n",
    "\n",
    "def calculate_eigenvalues(a, b):\n",
    "    # Calculate m = a^-1 * b\n",
    "    m = np.linalg.inv(a) @ b\n",
    "    \n",
    "    # Calculate eigenvalues of m\n",
    "    eigenvalues, _ = np.linalg.eig(m)\n",
    "    \n",
    "    return eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20519828-713a-4aec-b37b-eeb4bf7319be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grampa(A, B, eta):\n",
    "    n = A.shape[0]\n",
    "\n",
    "    Lambda, U = eig(A)\n",
    "    Mu, V = eig(B) # eigen vectors and eigen values\n",
    "\n",
    "    temp = (np.tile(Lambda, (n, 1)) - np.tile(Mu, (n, 1)).T)**2 + eta**2 # (lambda - mu').^2 + eta^2;\n",
    "\n",
    "    coeff = 1 / temp\n",
    "    coeff = coeff * (U.T @ np.ones((n, n)) @ V) # coeff .* (U' * ones(n) * V);\n",
    "    X = U @ coeff @ V.T\n",
    "    Y = np.real(X.T) # linear_sum_assignment doesnt work on complex values\n",
    "\n",
    "    row_ind, col_ind = linear_sum_assignment(Y, maximize = True) # cost maximization\n",
    "    P = np.zeros((n, n))\n",
    "    P[row_ind, col_ind] = 1\n",
    "\n",
    "    return P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a08ce5c-6749-4269-a087-3b3f6f8d367f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def edgeindexToAdjmat(edge_tensor):\n",
    "    edge_lis = edge_tensor.numpy()\n",
    "    edge_list = edge_lis[:, (edge_lis < 100).all(axis=0)]\n",
    "    \n",
    "    num_nodes = np.max(edge_list) + 1\n",
    "    \n",
    "    # Create an empty adjacency matrix\n",
    "    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    \n",
    "    # Populate the adjacency matrix based on the edges\n",
    "    for i in range(edge_list.shape[1]):\n",
    "        source_node = edge_list[0, i]\n",
    "        target_node = edge_list[1, i]\n",
    "        adj_matrix[source_node, target_node] = 1\n",
    "\n",
    "    return adj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b984fc-b55a-4119-86ce-c1cf44ea2b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _edges_to_adjacency(sources, targets, num_nodes):\n",
    "    '''Dense num_nodes x num_nodes 0/1 adjacency matrix from parallel source/target id lists.'''\n",
    "    adj = np.zeros((num_nodes, num_nodes), dtype=int)\n",
    "    adj[np.asarray(sources, dtype=np.int64), np.asarray(targets, dtype=np.int64)] = 1\n",
    "    return adj\n",
    "\n",
    "\n",
    "def _parse_id_line(line):\n",
    "    '''Parse one comma-separated line of node ids; a blank line means no edges.'''\n",
    "    line = line.strip()\n",
    "    return [int(tok) for tok in line.split(',')] if line else []\n",
    "\n",
    "\n",
    "def _regularised_laplacian(adj, reg):\n",
    "    '''Graph Laplacian D - A shifted by reg * I so it is invertible.'''\n",
    "    degree = np.diag(adj.sum(axis=1))\n",
    "    return degree - adj + reg * np.eye(adj.shape[0])\n",
    "\n",
    "\n",
    "def ggd(num_1, num_2, eta=0.35, reg=0.0001, data=None,\n",
    "        coarsened_dir='coarsening/output'):\n",
    "    '''Graph geodesic distance (GGD) between two graphs of a dataset.\n",
    "\n",
    "    The smaller graph is compared with the larger one, which is read back\n",
    "    already coarsened to the same node count from a precomputed file. The\n",
    "    coarsened graph is aligned to the smaller one with grampa(), and the\n",
    "    distance is calculate_eigenvalues_sum(L1, L2) of the regularised\n",
    "    Laplacians L = D - A + reg * I.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_1, num_2 : int\n",
    "        Indices of the two graphs in `data`.\n",
    "    eta : float, optional\n",
    "        GRAMPA kernel width (default 0.35).\n",
    "    reg : float, optional\n",
    "        Diagonal shift added to both Laplacians (default 1e-4).\n",
    "    data : indexable dataset, optional\n",
    "        Defaults to the notebook-global `dataset`. Indices must follow the\n",
    "        order the coarsened files were written in -- presumably the\n",
    "        unshuffled TUDataset order (TODO confirm).\n",
    "    coarsened_dir : str, optional\n",
    "        Directory holding the '<n>_mutag_coarsened.txt' files.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value of calculate_eigenvalues_sum for the two Laplacians.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the source and target lines of the coarsened graph differ in length.\n",
    "    '''\n",
    "    if data is None:\n",
    "        data = dataset\n",
    "    g1, g2 = data[num_1], data[num_2]\n",
    "    line_index = num_2\n",
    "    length = min(len(g1.x), len(g2.x))\n",
    "    # Make g1 the smaller graph; the larger one is read from the file\n",
    "    # coarsened to `length` nodes, at the position of its dataset index.\n",
    "    if len(g1.x) > len(g2.x):\n",
    "        g1, g2 = g2, g1\n",
    "        line_index = num_1\n",
    "\n",
    "    # Size by node count (== length) rather than max edge id, so isolated\n",
    "    # trailing nodes are kept and the two matrices have matching shapes.\n",
    "    edges1 = g1.edge_index.cpu().numpy()\n",
    "    adj_matrix1 = _edges_to_adjacency(edges1[0], edges1[1], length)\n",
    "\n",
    "    file_path = f'{coarsened_dir}/{length}_mutag_coarsened.txt'\n",
    "    with open(file_path, 'r') as file:\n",
    "        lines = file.readlines()\n",
    "    # Each graph occupies two lines: source ids, then target ids.\n",
    "    sources2 = _parse_id_line(lines[line_index * 2])\n",
    "    targets2 = _parse_id_line(lines[line_index * 2 + 1])\n",
    "    if len(sources2) != len(targets2):\n",
    "        raise ValueError(f'malformed edge lines for graph {line_index} in {file_path}')\n",
    "    num_nodes2 = max([length] + [node + 1 for node in sources2 + targets2])\n",
    "    adj_matrix2 = _edges_to_adjacency(sources2, targets2, num_nodes2)\n",
    "\n",
    "    # grampa returns P with P[b, a] = 1 pairing node b of A2 with node a of A1,\n",
    "    # so P.T @ A2 @ P lists the coarsened graph in g1's node order.\n",
    "    P = grampa(adj_matrix1, adj_matrix2, eta)\n",
    "    adj_matrix2 = P.T @ adj_matrix2 @ P\n",
    "\n",
    "    nl1 = _regularised_laplacian(adj_matrix1, reg)\n",
    "    nl2 = _regularised_laplacian(adj_matrix2, reg)\n",
    "    return calculate_eigenvalues_sum(nl1, nl2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ca36b3-2a12-40bb-b724-c1b9a069dd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import BatchNorm1d, Linear, ReLU, Sequential, BCEWithLogitsLoss\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_add_pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255f4e24-1e3d-44b0-8443-319446f47560",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TU Dataset\n",
    "# Keep `dataset` in its original order: ggd() indexes the precomputed\n",
    "# coarsening files by dataset position, and an unseeded shuffle cannot\n",
    "# match any fixed file order. Shuffle a separate view for the train/test split.\n",
    "torch.manual_seed(42)  # makes the shuffle and later torch RNG use reproducible\n",
    "dataset = TUDataset('data', name='MUTAG')\n",
    "shuffled_dataset = dataset.shuffle()\n",
    "num_test = len(shuffled_dataset) // 10\n",
    "train_dataset = shuffled_dataset[num_test:]\n",
    "test_dataset = shuffled_dataset[:num_test]\n",
    "train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b37222-ece7-496b-a12c-edaa6ec7c230",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    '''\n",
    "    GIN graph classifier: up to 3 GINConv layers, sum pooling, 2-layer MLP head.\n",
    "\n",
    "    The head ends in Linear(dim, 1), i.e. one logit per graph (trained with\n",
    "    BCEWithLogitsLoss in this notebook); out_channels is accepted but unused.\n",
    "    forward() applies only the first int(L - 1) conv layers.\n",
    "    '''\n",
    "    def __init__(self, in_channels, dim, out_channels, L):\n",
    "        super().__init__()\n",
    "\n",
    "        def gin_layer(d_in):\n",
    "            # GINConv wrapping a 2-layer ReLU MLP that maps d_in -> dim features.\n",
    "            return GINConv(\n",
    "                Sequential(Linear(d_in, dim), ReLU(),\n",
    "                           Linear(dim, dim), ReLU()))\n",
    "\n",
    "        # torch.nn.ModuleList (not a plain Python list) registers the convs as\n",
    "        # submodules: their weights appear in parameters() (so the optimizer\n",
    "        # trains them), move with .to(device) and are saved in state_dict().\n",
    "        self.convs = torch.nn.ModuleList(\n",
    "            [gin_layer(in_channels), gin_layer(dim), gin_layer(dim)])\n",
    "        num_convs = int(L - 1)\n",
    "        if not 0 <= num_convs <= len(self.convs):\n",
    "            raise ValueError(\n",
    "                f'Net supports 0-{len(self.convs)} conv layers (L - 1), got L={L}')\n",
    "        self.lin1 = Linear(dim, dim)\n",
    "        self.lin2 = Linear(dim, 1)\n",
    "        self.L = L\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        '''Return a (num_graphs, 1) tensor of logits for the batched graphs.'''\n",
    "        for conv in self.convs[:int(self.L - 1)]:\n",
    "            x = conv(x, edge_index)\n",
    "        x = global_add_pool(x, batch)\n",
    "        x = self.lin1(x).relu()\n",
    "        x = self.lin2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a34a794f-bf3a-4ed6-8b0c-90e922c132eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# `args` is never defined in this notebook (no argparse parser is run), so the\n",
    "# layer count is set here. Net applies NUM_LAYERS - 1 GIN conv layers, so 4\n",
    "# uses all three (matching its '3-layer GIN' intent) -- TODO confirm the value.\n",
    "NUM_LAYERS = 4\n",
    "model = Net(dataset.num_features, 32, dataset.num_classes, NUM_LAYERS).to(device)\n",
    "# Move the conv layers explicitly as well, in case Net keeps them in a plain\n",
    "# Python list that Module.to() would not traverse.\n",
    "for l in range(int(NUM_LAYERS - 1)):\n",
    "    model.convs[l] = model.convs[l].to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "criterion = BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df2c71b2-701a-47fe-a953-e7e6b624b047",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    '''Run one optimisation pass over train_loader.\n",
    "\n",
    "    Uses the notebook globals model, optimizer, criterion, train_loader and\n",
    "    device. Returns the loss averaged over all training graphs.\n",
    "    '''\n",
    "    model.train()\n",
    "    loss_sum = 0\n",
    "    for graphs in train_loader:\n",
    "        graphs = graphs.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(graphs.x, graphs.edge_index, graphs.batch)\n",
    "        batch_loss = criterion(logits[:, 0], graphs.y.float())\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # Weight by batch size so the final division gives a per-graph mean.\n",
    "        loss_sum += batch_loss.item() * graphs.num_graphs\n",
    "    return loss_sum / len(train_loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9ebc5a-df10-47a8-a1f3-63e502f44f2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test(loader):\n",
    "    '''Accuracy of the global `model` on `loader`: sigmoid(logit) > 0.5 vs. label.'''\n",
    "    model.eval()\n",
    "    num_correct = 0\n",
    "    for graphs in loader:\n",
    "        graphs = graphs.to(device)\n",
    "        logits = model(graphs.x, graphs.edge_index, graphs.batch)\n",
    "        predicted = (torch.sigmoid(logits[:, 0]) > 0.5).int()\n",
    "        num_correct += int((predicted == graphs.y).sum())\n",
    "    return num_correct / len(loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3abd9a24-251c-490f-b872-b2340781ce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 epochs; after each, report training loss and train/test accuracy.\n",
    "for epoch in range(1, 100 + 1):\n",
    "    loss = train()\n",
    "    train_acc = test(train_loader)\n",
    "    test_acc = test(test_loader)\n",
    "    print('Epoch: {:03d}, Loss: {:.4f}, Train Acc: {:.4f} Test Acc: {:.4f}'.format(\n",
    "        epoch, loss, train_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47399be4-f9d2-47a8-a7b7-edcb35f60c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For random graph pairs, record the distance between GIN outputs (oo)\n",
    "# and the graph geodesic distance from ggd() (tt).\n",
    "random.seed(0)  # reproducible pair sampling\n",
    "model.eval()\n",
    "oo = []\n",
    "tt = []\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    for i in tqdm(range(200)):\n",
    "        a = random.randint(0, len(dataset) - 1)\n",
    "        # Iteration 2 deliberately pairs a graph with itself\n",
    "        # (presumably as a zero-distance sanity point).\n",
    "        b = a if i == 2 else random.randint(0, len(dataset) - 1)\n",
    "        # Inputs must live on the model's device (CUDA when available).\n",
    "        g_a = dataset[a].to(device)\n",
    "        g_b = dataset[b].to(device)\n",
    "\n",
    "        # Each graph forms its own batch: every node belongs to graph 0.\n",
    "        batch_a = torch.zeros(len(g_a.x), dtype=torch.int64, device=device)\n",
    "        batch_b = torch.zeros(len(g_b.x), dtype=torch.int64, device=device)\n",
    "        output_a = model(g_a.x, g_a.edge_index, batch_a)\n",
    "        output_b = model(g_b.x, g_b.edge_index, batch_b)\n",
    "\n",
    "        ggd_val = ggd(a, b)\n",
    "\n",
    "        oo.append(torch.norm(output_a - output_b).item())\n",
    "        tt.append(ggd_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42565f84-899c-4cf0-8dec-df2d3bd580bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter GIN output distance against GGD and report their Pearson correlation.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(oo, tt)\n",
    "ax.set_xlabel('GIN output distance')\n",
    "ax.set_ylabel('GGD')\n",
    "ax.set_title('GIN output distance vs. graph geodesic distance (MUTAG)')\n",
    "fig.savefig('gnn_plot.png', dpi=120)\n",
    "print('Pearson correlation: {}'.format(pearsonr(np.array(oo), np.array(tt))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e848e7-0d8c-4004-b00a-6712b5e22983",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54b28581-1600-4175-b8cf-79a9e3af3b9f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96ac6a5d-676e-4ffd-854f-5848ecb70fe1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
