{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5858c994-317a-4bcb-95b7-29e3a3eedfbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports sit at the top of the cell so it runs on a fresh kernel\n",
    "# (Restart Kernel -> Run All) without relying on names left over from\n",
    "# cells that were executed out of order.\n",
    "import tensorflow as tf\n",
    "from spektral.layers import GATConv, GCNConv, GraphSageConv\n",
    "\n",
    "\n",
    "class NoMaskGCNConv(GCNConv):\n",
    "    \"\"\"GCNConv variant that never emits or consumes a Keras mask.\n",
    "\n",
    "    `compute_mask` always reports no mask, and `call` hands `mask=None`\n",
    "    to the parent layer regardless of what the caller supplied.\n",
    "    \"\"\"\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        \"\"\"Always report that this layer produces no mask.\"\"\"\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        \"\"\"Run the parent convolution on `inputs` with the mask dropped.\n",
    "\n",
    "        `training` is accepted in the signature but is not forwarded to the\n",
    "        parent `call`; any incoming `mask` is replaced by None.\n",
    "        \"\"\"\n",
    "        # Explicitly discard mask\n",
    "        return super().call(inputs, mask=None)\n",
    "       \n",
    "class GCN(tf.keras.Model):\n",
    "    \"\"\"Two-layer graph convolutional classifier.\n",
    "\n",
    "    A 16-unit ReLU `NoMaskGCNConv` feeds a softmax `NoMaskGCNConv` that has\n",
    "    `n_labels` outputs. Calling the model returns the output of both layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, seed=42):\n",
    "        \"\"\"Build both layers.\n",
    "\n",
    "        n_labels: width of the softmax output layer.\n",
    "        seed: seed given to the Glorot-uniform kernel initializer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # One seeded initializer object is handed to both layers.\n",
    "        glorot = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "        self.conv1 = NoMaskGCNConv(\n",
    "            16,\n",
    "            activation='relu',\n",
    "            kernel_initializer=glorot,\n",
    "        )\n",
    "        self.conv2 = NoMaskGCNConv(\n",
    "            n_labels,\n",
    "            activation='softmax',\n",
    "            kernel_initializer=glorot,\n",
    "        )\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        \"\"\"Apply both layers and return `(final_output, hidden_output)`.\n",
    "\n",
    "        `inputs` is unpacked into two items that are passed to each layer as\n",
    "        `[x, a]` (presumably node features and adjacency - confirm against\n",
    "        the caller). `training` is not used inside this method.\n",
    "        \"\"\"\n",
    "        features, adjacency = inputs\n",
    "        hidden = self.conv1([features, adjacency])\n",
    "        scores = self.conv2([hidden, adjacency])\n",
    "        return scores, hidden\n",
    "\n",
    "from spektral.layers import GATConv\n",
    "import tensorflow as tf\n",
    "\n",
    "# GATConv subclass that sidesteps mask handling entirely\n",
    "class NoMaskGATConv(GATConv):\n",
    "    \"\"\"GATConv variant that neither produces nor uses a Keras mask.\"\"\"\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        \"\"\"Always report that this layer produces no mask.\"\"\"\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        \"\"\"Delegate to the parent `call` with the mask forced to None.\n",
    "\n",
    "        `training` is accepted in the signature but is not forwarded.\n",
    "        \"\"\"\n",
    "        # Whatever mask came in is dropped here.\n",
    "        outputs = super().call(inputs, mask=None)\n",
    "        return outputs\n",
    "\n",
    "\n",
    "# GAT classifier assembled from two NoMaskGATConv layers\n",
    "class GAT(tf.keras.Model):\n",
    "    \"\"\"Two-layer graph attention classifier.\n",
    "\n",
    "    The first layer uses 16 channels, `num_heads` attention heads with\n",
    "    `concat_heads=True` and ELU activation; the second is a single-head\n",
    "    layer with `concat_heads=False` and softmax over `n_labels` outputs.\n",
    "\n",
    "    Only `kernel_initializer` is seeded here; every other initializer of\n",
    "    the layers keeps its default.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, num_heads=8, seed=42):\n",
    "        \"\"\"Build both attention layers.\n",
    "\n",
    "        n_labels: width of the softmax output layer.\n",
    "        num_heads: number of attention heads in the first layer.\n",
    "        seed: seed given to the Glorot-uniform kernel initializer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        glorot = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # NoMaskGATConv is used in place of the plain GATConv.\n",
    "        self.conv1 = NoMaskGATConv(\n",
    "            16,\n",
    "            attn_heads=num_heads,\n",
    "            concat_heads=True,\n",
    "            activation='elu',\n",
    "            kernel_initializer=glorot,\n",
    "        )\n",
    "        self.conv2 = NoMaskGATConv(\n",
    "            n_labels,\n",
    "            attn_heads=1,\n",
    "            concat_heads=False,\n",
    "            activation='softmax',\n",
    "            kernel_initializer=glorot,\n",
    "        )\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Apply both layers and return `(final_output, hidden_output)`.\"\"\"\n",
    "        features, adjacency = inputs\n",
    "        hidden = self.conv1([features, adjacency])  # first-layer output is kept\n",
    "        scores = self.conv2([hidden, adjacency])\n",
    "        return scores, hidden  # final output first, hidden output second\n",
    "\n",
    "# Define the GraphSAGE model\n",
    "class GraphSAGE(tf.keras.Model):\n",
    "    \"\"\"Two-layer GraphSAGE classifier.\n",
    "\n",
    "    A `hidden_dim`-unit ReLU `GraphSageConv` feeds a softmax `GraphSageConv`\n",
    "    with `n_labels` outputs. Calling the model returns the output of both\n",
    "    layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=42):\n",
    "        \"\"\"Build both GraphSAGE layers.\n",
    "\n",
    "        n_labels: width of the softmax output layer.\n",
    "        hidden_dim: width of the first layer.\n",
    "        aggregator: neighbourhood aggregation, forwarded to the layers'\n",
    "            `aggregate` option.\n",
    "        seed: seed given to the Glorot-uniform kernel initializer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # GraphSageConv names this option `aggregate`. Passing it under the\n",
    "        # keyword `aggregator` is not recognised by the layer, so the value\n",
    "        # chosen by the caller would not take effect.\n",
    "        self.conv1 = GraphSageConv(hidden_dim, activation='relu', aggregate=aggregator, kernel_initializer=initializer)\n",
    "        self.conv2 = GraphSageConv(n_labels, activation='softmax', aggregate=aggregator, kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Apply both layers and return `(final_output, hidden_output)`.\"\"\"\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings\n",
    "\n",
    "classifiers=['gcn','gat','graphsage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "79f37499-55ed-45c4-b1cc-4caacf4d0cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3d3005b0-168d-4dcf-95dd-0ee081beed0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Processing dataset: AmazonPhotos ===\n",
      "\n",
      "Processing split: 30_70\n",
      "\n",
      "Running classifier: GCN\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 1.9806, Accuracy: 0.2537\n",
      "Epoch 40, Loss: 1.6833, Accuracy: 0.3896\n",
      "Epoch 60, Loss: 1.5538, Accuracy: 0.4112\n",
      "Epoch 80, Loss: 1.4160, Accuracy: 0.4591\n",
      "Epoch 100, Loss: 1.2314, Accuracy: 0.5706\n",
      "Epoch 120, Loss: 0.9836, Accuracy: 0.6676\n",
      "Epoch 140, Loss: 0.8322, Accuracy: 0.7047\n",
      "Epoch 160, Loss: 0.7299, Accuracy: 0.7499\n",
      "Epoch 180, Loss: 0.6656, Accuracy: 0.7897\n",
      "Epoch 200, Loss: 0.6232, Accuracy: 0.7986\n",
      "[[  5  19   1  71  11  25   2 107]\n",
      " [  6 639  11  75 109  27  68 244]\n",
      " [  3  65 130  27  57  14  94  93]\n",
      " [  6  42   5  86 286   6  34 189]\n",
      " [  1   6   1   2 373   8 176  34]\n",
      " [  4  31   4   0   0 265 116 135]\n",
      " [  5  60  37  23 698  17 432  69]\n",
      " [  2   2   0   1  90   4  19 104]]\n",
      "Seed 42 execution time: 4.41 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 1.9316, Accuracy: 0.2838\n",
      "Epoch 40, Loss: 1.7720, Accuracy: 0.2905\n",
      "Epoch 60, Loss: 1.4769, Accuracy: 0.5371\n",
      "Epoch 80, Loss: 1.0289, Accuracy: 0.6348\n",
      "Epoch 100, Loss: 0.8437, Accuracy: 0.7126\n",
      "Epoch 120, Loss: 0.7356, Accuracy: 0.7573\n",
      "Epoch 140, Loss: 0.6569, Accuracy: 0.7953\n",
      "Epoch 160, Loss: 0.6036, Accuracy: 0.8130\n",
      "Epoch 180, Loss: 0.5668, Accuracy: 0.8254\n",
      "Epoch 200, Loss: 0.5368, Accuracy: 0.8351\n",
      "[[  68   20    4   21    0    1  144    0]\n",
      " [  19  904   48   18   12   20  127    5]\n",
      " [   0   47  349    2   57    0   16    0]\n",
      " [   3  398   50   85    7    0   62   14]\n",
      " [   0   38   71   11  238    0  256    0]\n",
      " [   7  110   13    3   16  246  190    0]\n",
      " [  38   85   59    4   62   11 1088    1]\n",
      " [   3  149   11    5    0    0   57    0]]\n",
      "Seed 46 execution time: 3.93 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.8164, Accuracy: 0.3022\n",
      "Epoch 40, Loss: 1.5268, Accuracy: 0.4270\n",
      "Epoch 60, Loss: 1.1097, Accuracy: 0.6334\n",
      "Epoch 80, Loss: 0.8628, Accuracy: 0.7022\n",
      "Epoch 100, Loss: 0.7759, Accuracy: 0.7547\n",
      "Epoch 120, Loss: 0.7200, Accuracy: 0.7750\n",
      "Epoch 140, Loss: 0.6767, Accuracy: 0.7847\n",
      "Epoch 160, Loss: 0.6455, Accuracy: 0.7975\n",
      "Epoch 180, Loss: 0.6130, Accuracy: 0.8064\n",
      "Epoch 200, Loss: 0.5838, Accuracy: 0.8130\n",
      "[[ 39  38   3  42  21  20  82  16]\n",
      " [ 11 574 270 201  78  12  15   7]\n",
      " [  0 146 311   7   9   0   2   0]\n",
      " [  3 109   4 432  66   3   2   5]\n",
      " [  0  19  70  81 436   3   3   1]\n",
      " [  0 204   9   4   6 349  10   0]\n",
      " [  5  64  31 246 748   2 215  12]\n",
      " [  2  13   1 155  43   0   0   8]]\n",
      "Seed 123 execution time: 4.73 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 1.7978, Accuracy: 0.3496\n",
      "Epoch 40, Loss: 1.6940, Accuracy: 0.4252\n",
      "Epoch 60, Loss: 1.4899, Accuracy: 0.5146\n",
      "Epoch 80, Loss: 1.1720, Accuracy: 0.6217\n",
      "Epoch 100, Loss: 0.9205, Accuracy: 0.6991\n",
      "Epoch 120, Loss: 0.7797, Accuracy: 0.7482\n",
      "Epoch 140, Loss: 0.7016, Accuracy: 0.7677\n",
      "Epoch 160, Loss: 0.6531, Accuracy: 0.7854\n",
      "Epoch 180, Loss: 0.6209, Accuracy: 0.7956\n",
      "Epoch 200, Loss: 0.5957, Accuracy: 0.8080\n",
      "[[ 43   9   2  32  41   0 132   3]\n",
      " [  3 592  27 138 169  12 182  35]\n",
      " [  0  99 260  69  46   1  16   0]\n",
      " [  1  32   7 109 398   1  68  14]\n",
      " [  0  21   4   5 454   0 109  21]\n",
      " [  0  38   7   0 153 310  36   1]\n",
      " [ 20  54   3   9 314   0 824 124]\n",
      " [  0   5   0   3  74   0 141   4]]\n",
      "Seed 2025 execution time: 3.73 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 2.0435, Accuracy: 0.2052\n",
      "Epoch 40, Loss: 1.9748, Accuracy: 0.2309\n",
      "Epoch 60, Loss: 1.7495, Accuracy: 0.3183\n",
      "Epoch 80, Loss: 1.4523, Accuracy: 0.4672\n",
      "Epoch 100, Loss: 1.1751, Accuracy: 0.5838\n",
      "Epoch 120, Loss: 1.2407, Accuracy: 0.5395\n",
      "Epoch 140, Loss: 0.9034, Accuracy: 0.7008\n",
      "Epoch 160, Loss: 0.7747, Accuracy: 0.7522\n",
      "Epoch 180, Loss: 0.6944, Accuracy: 0.7832\n",
      "Epoch 200, Loss: 0.6352, Accuracy: 0.8023\n",
      "[[  0 173  60   1   0   0   3  22]\n",
      " [  4 875 100  28  14   7  44 129]\n",
      " [  0  27 402   0   1  12   3  20]\n",
      " [  0 270  21  56   5   2  17 259]\n",
      " [  0 204  15  16  24   0  52 311]\n",
      " [  0  27 106   1   0 316 102   7]\n",
      " [  2 222  60  36  24   6  75 899]\n",
      " [  0  14   5  11   0   0   0 189]]\n",
      "Seed 999 execution time: 4.32 seconds\n",
      "\n",
      "Running classifier: GAT\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 1.5791, Accuracy: 0.4272\n",
      "Epoch 40, Loss: 0.8761, Accuracy: 0.7742\n",
      "Epoch 60, Loss: 0.4418, Accuracy: 0.8792\n",
      "Epoch 80, Loss: 0.2327, Accuracy: 0.9318\n",
      "Epoch 100, Loss: 0.1376, Accuracy: 0.9664\n",
      "Epoch 120, Loss: 0.0852, Accuracy: 0.9792\n",
      "Epoch 140, Loss: 0.0565, Accuracy: 0.9889\n",
      "Epoch 160, Loss: 0.0394, Accuracy: 0.9934\n",
      "Epoch 180, Loss: 0.0285, Accuracy: 0.9951\n",
      "Epoch 200, Loss: 0.0212, Accuracy: 0.9956\n",
      "[[164  41   0   4  20   4   6   2]\n",
      " [ 77 495  80 122  89  51 163 102]\n",
      " [ 72  73  81  42  54   3  48 110]\n",
      " [132  63   6 162 160  38  39  54]\n",
      " [ 68  23  11   6 350   4 123  16]\n",
      " [ 73  33  18  19 154 137  19 102]\n",
      " [317  30   9 123 414  17 382  49]\n",
      " [ 34   3   0   5  44   3  41  92]]\n",
      "Seed 42 execution time: 15.85 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 1.6395, Accuracy: 0.4169\n",
      "Epoch 40, Loss: 0.8784, Accuracy: 0.7604\n",
      "Epoch 60, Loss: 0.4480, Accuracy: 0.8842\n",
      "Epoch 80, Loss: 0.2384, Accuracy: 0.9315\n",
      "Epoch 100, Loss: 0.1367, Accuracy: 0.9602\n",
      "Epoch 120, Loss: 0.0810, Accuracy: 0.9801\n",
      "Epoch 140, Loss: 0.0497, Accuracy: 0.9889\n",
      "Epoch 160, Loss: 0.0334, Accuracy: 0.9929\n",
      "Epoch 180, Loss: 0.0246, Accuracy: 0.9938\n",
      "Epoch 200, Loss: 0.0188, Accuracy: 0.9956\n",
      "[[  75    7  122    3    0    1   31   19]\n",
      " [  35  242  135  126   46   92  410   67]\n",
      " [  23   20  370   12    9    5   19   13]\n",
      " [   7   54  111   77   30    9  273   58]\n",
      " [   6   22  104   24  185    6  257   10]\n",
      " [   2   22   69   24   25  219  217    7]\n",
      " [  17   28   96   68   51   50 1004   34]\n",
      " [   0    4   28    8    3    6  122   54]]\n",
      "Seed 46 execution time: 16.06 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.6034, Accuracy: 0.4230\n",
      "Epoch 40, Loss: 0.8881, Accuracy: 0.7587\n",
      "Epoch 60, Loss: 0.4773, Accuracy: 0.8619\n",
      "Epoch 80, Loss: 0.2747, Accuracy: 0.9131\n",
      "Epoch 100, Loss: 0.1678, Accuracy: 0.9471\n",
      "Epoch 120, Loss: 0.1066, Accuracy: 0.9696\n",
      "Epoch 140, Loss: 0.0705, Accuracy: 0.9815\n",
      "Epoch 160, Loss: 0.0507, Accuracy: 0.9859\n",
      "Epoch 180, Loss: 0.0383, Accuracy: 0.9885\n",
      "Epoch 200, Loss: 0.0310, Accuracy: 0.9899\n",
      "[[123  19   7   1   9   4  41  57]\n",
      " [ 60 562 136  87  34  20 131 138]\n",
      " [ 13  75 335  23   5   2   7  15]\n",
      " [ 87 120  19 133  21  11  59 174]\n",
      " [ 31  86  80  25  49   6  66 270]\n",
      " [ 32 159 173  17  28 113  20  40]\n",
      " [ 12  50  25  30  40   9 949 208]\n",
      " [ 16  13   6   8   1   2  20 156]]\n",
      "Seed 123 execution time: 15.78 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 1.6278, Accuracy: 0.4469\n",
      "Epoch 40, Loss: 0.9652, Accuracy: 0.7562\n",
      "Epoch 60, Loss: 0.4987, Accuracy: 0.8642\n",
      "Epoch 80, Loss: 0.2754, Accuracy: 0.9177\n",
      "Epoch 100, Loss: 0.1570, Accuracy: 0.9540\n",
      "Epoch 120, Loss: 0.0984, Accuracy: 0.9704\n",
      "Epoch 140, Loss: 0.0666, Accuracy: 0.9814\n",
      "Epoch 160, Loss: 0.0467, Accuracy: 0.9872\n",
      "Epoch 180, Loss: 0.0355, Accuracy: 0.9885\n",
      "Epoch 200, Loss: 0.0290, Accuracy: 0.9912\n",
      "[[  8   9  54  25  14   2 142   8]\n",
      " [  6 188 228 152 164  73 282  65]\n",
      " [  0   9 139  44 122  15 113  49]\n",
      " [  4  33 124 189 158  26  70  26]\n",
      " [ 19  11  43  55 171   1 278  36]\n",
      " [  1  26  26  10  84 213 180   5]\n",
      " [  7  31  49 129  85   2 970  75]\n",
      " [  0   5  28  45  25   2  69  53]]\n",
      "Seed 2025 execution time: 15.65 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 1.5921, Accuracy: 0.4149\n",
      "Epoch 40, Loss: 0.8424, Accuracy: 0.7886\n",
      "Epoch 60, Loss: 0.4114, Accuracy: 0.8954\n",
      "Epoch 80, Loss: 0.2050, Accuracy: 0.9455\n",
      "Epoch 100, Loss: 0.1140, Accuracy: 0.9681\n",
      "Epoch 120, Loss: 0.0692, Accuracy: 0.9809\n",
      "Epoch 140, Loss: 0.0478, Accuracy: 0.9863\n",
      "Epoch 160, Loss: 0.0343, Accuracy: 0.9907\n",
      "Epoch 180, Loss: 0.0252, Accuracy: 0.9929\n",
      "Epoch 200, Loss: 0.0190, Accuracy: 0.9942\n",
      "[[  0  40  40   3   3   0 106  67]\n",
      " [  4 490 168  88  74  41 164 172]\n",
      " [  0  11 352   3  15   7  55  22]\n",
      " [  1 116  85 132  54  11 107 124]\n",
      " [  1  91  84  54  89  13 139 151]\n",
      " [  3  33 178   5  24 181  57  78]\n",
      " [  3  77 133  60  33  13 794 211]\n",
      " [  2   9  17  10   1   1  29 150]]\n",
      "Seed 999 execution time: 16.68 seconds\n",
      "\n",
      "Running classifier: GRAPHSAGE\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 1.7551, Accuracy: 0.4511\n",
      "Epoch 40, Loss: 1.5232, Accuracy: 0.7703\n",
      "Epoch 60, Loss: 1.4198, Accuracy: 0.8623\n",
      "Epoch 80, Loss: 1.3663, Accuracy: 0.9044\n",
      "Epoch 100, Loss: 1.3350, Accuracy: 0.9354\n",
      "Epoch 120, Loss: 1.3166, Accuracy: 0.9473\n",
      "Epoch 140, Loss: 1.3047, Accuracy: 0.9544\n",
      "Epoch 160, Loss: 1.2964, Accuracy: 0.9602\n",
      "Epoch 180, Loss: 1.2915, Accuracy: 0.9646\n",
      "Epoch 200, Loss: 1.2862, Accuracy: 0.9672\n",
      "[[ 45  40  35  55  18  24  10  14]\n",
      " [ 11 632  43 191  75  73 123  31]\n",
      " [  0 100 185  66  21  10  56  45]\n",
      " [  3 112  27 305  75  70  31  31]\n",
      " [  4  54  18  48 281  37 104  55]\n",
      " [  1  44  19  15  38 367  48  23]\n",
      " [  7 100  66 241 170 137 523  97]\n",
      " [  1  17   7  83  14  16  43  41]]\n",
      "Seed 42 execution time: 13.13 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 1.7753, Accuracy: 0.4164\n",
      "Epoch 40, Loss: 1.5395, Accuracy: 0.7423\n",
      "Epoch 60, Loss: 1.4258, Accuracy: 0.8581\n",
      "Epoch 80, Loss: 1.3731, Accuracy: 0.9041\n",
      "Epoch 100, Loss: 1.3439, Accuracy: 0.9288\n",
      "Epoch 120, Loss: 1.3269, Accuracy: 0.9434\n",
      "Epoch 140, Loss: 1.3093, Accuracy: 0.9540\n",
      "Epoch 160, Loss: 1.2980, Accuracy: 0.9598\n",
      "Epoch 180, Loss: 1.2908, Accuracy: 0.9682\n",
      "Epoch 200, Loss: 1.2853, Accuracy: 0.9699\n",
      "[[ 13  65  26  74   5   6  66   3]\n",
      " [ 11 717  26 133  64  40 150  12]\n",
      " [  0  90 247  49  46   8  26   5]\n",
      " [  0 172  11 313  24  19  63  17]\n",
      " [ 10 137  19  62 185  50 141  10]\n",
      " [  2  66  27  58  17 262 142  11]\n",
      " [  9 219  56 229  41  61 706  27]\n",
      " [  2  82   7  61  10   7  34  22]]\n",
      "Seed 46 execution time: 13.92 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.7479, Accuracy: 0.4649\n",
      "Epoch 40, Loss: 1.5186, Accuracy: 0.7790\n",
      "Epoch 60, Loss: 1.4078, Accuracy: 0.8681\n",
      "Epoch 80, Loss: 1.3568, Accuracy: 0.9096\n",
      "Epoch 100, Loss: 1.3283, Accuracy: 0.9316\n",
      "Epoch 120, Loss: 1.3116, Accuracy: 0.9471\n",
      "Epoch 140, Loss: 1.2993, Accuracy: 0.9532\n",
      "Epoch 160, Loss: 1.2914, Accuracy: 0.9554\n",
      "Epoch 180, Loss: 1.2871, Accuracy: 0.9599\n",
      "Epoch 200, Loss: 1.2819, Accuracy: 0.9607\n",
      "[[ 61  34  50  35  21   8  48   4]\n",
      " [  1 730  63 127  64  57 106  20]\n",
      " [  0 141 249  37   7   4  28   9]\n",
      " [  2 147  32 290  42  29  54  28]\n",
      " [  7 186  53  34 170  37  93  33]\n",
      " [  0 282  32  29  26 136  64  13]\n",
      " [ 10 249  88  74  78  16 760  48]\n",
      " [  1  41  29  25  10   5  79  32]]\n",
      "Seed 123 execution time: 13.65 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 1.7983, Accuracy: 0.4204\n",
      "Epoch 40, Loss: 1.5528, Accuracy: 0.7265\n",
      "Epoch 60, Loss: 1.4340, Accuracy: 0.8451\n",
      "Epoch 80, Loss: 1.3766, Accuracy: 0.8969\n",
      "Epoch 100, Loss: 1.3446, Accuracy: 0.9150\n",
      "Epoch 120, Loss: 1.3292, Accuracy: 0.9310\n",
      "Epoch 140, Loss: 1.3116, Accuracy: 0.9425\n",
      "Epoch 160, Loss: 1.3040, Accuracy: 0.9500\n",
      "Epoch 180, Loss: 1.2942, Accuracy: 0.9544\n",
      "Epoch 200, Loss: 1.2892, Accuracy: 0.9562\n",
      "[[  4 103   4  68  19   2  55   7]\n",
      " [  3 688  18 132  93  47 165  12]\n",
      " [  1 225  79  85  33  14  31  23]\n",
      " [  1 166   6 239  57  31 121   9]\n",
      " [  3 128  13 100  91  28 220  31]\n",
      " [  1  85   9  10  42 321  73   4]\n",
      " [  6 264  10 162  92  45 749  20]\n",
      " [  2  67   5  46  20   5  71  11]]\n",
      "Seed 2025 execution time: 13.53 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 1.7398, Accuracy: 0.4849\n",
      "Epoch 40, Loss: 1.5042, Accuracy: 0.7899\n",
      "Epoch 60, Loss: 1.3982, Accuracy: 0.8830\n",
      "Epoch 80, Loss: 1.3466, Accuracy: 0.9251\n",
      "Epoch 100, Loss: 1.3184, Accuracy: 0.9473\n",
      "Epoch 120, Loss: 1.3012, Accuracy: 0.9597\n",
      "Epoch 140, Loss: 1.2901, Accuracy: 0.9654\n",
      "Epoch 160, Loss: 1.2846, Accuracy: 0.9712\n",
      "Epoch 180, Loss: 1.2775, Accuracy: 0.9743\n",
      "Epoch 200, Loss: 1.2725, Accuracy: 0.9743\n",
      "[[  3 158  11  29  14   5  24  15]\n",
      " [  2 839  38 111  46  69  72  24]\n",
      " [  0 119 237  58   8  13  19  11]\n",
      " [  0 257  31 238  33  18  40  13]\n",
      " [  0 245  55  86  78  32  68  58]\n",
      " [  0 159  24  14  21 274  44  23]\n",
      " [  9 491  52 350  36  21 301  64]\n",
      " [  0  48  10  89   6   7  18  41]]\n",
      "Seed 999 execution time: 13.85 seconds\n",
      "\n",
      "Saved results for split 30_70 at: .\\results\\AmazonPhotos\\30_70\\AmazonPhotos_analysis_results_30_70.csv\n",
      "\n",
      "Processing split: 70_30\n",
      "\n",
      "Running classifier: GCN\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 2.0408, Accuracy: 0.2164\n",
      "Epoch 40, Loss: 2.0035, Accuracy: 0.2215\n",
      "Epoch 60, Loss: 1.9692, Accuracy: 0.2215\n",
      "Epoch 80, Loss: 1.9455, Accuracy: 0.2213\n",
      "Epoch 100, Loss: 1.9291, Accuracy: 0.2854\n",
      "Epoch 120, Loss: 1.9143, Accuracy: 0.2858\n",
      "Epoch 140, Loss: 1.8834, Accuracy: 0.3010\n",
      "Epoch 160, Loss: 1.8321, Accuracy: 0.3371\n",
      "Epoch 180, Loss: 1.7364, Accuracy: 0.3931\n",
      "Epoch 200, Loss: 1.6373, Accuracy: 0.4253\n",
      "[[  0  18   0   0   0   1  92   0]\n",
      " [  0 316   0   0   0   5 170   0]\n",
      " [  0 106   0   0   0   3  87   0]\n",
      " [  0  59   0   0   0  15 211   0]\n",
      " [  0  70   0   0   0  29 158   0]\n",
      " [  0  14   0   0   0 174  75   0]\n",
      " [  0  55   0   0   0  97 400   0]\n",
      " [  0  51   0   0   0   2  58   0]]\n",
      "Seed 42 execution time: 4.86 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 2.0708, Accuracy: 0.1813\n",
      "Epoch 40, Loss: 2.0259, Accuracy: 0.2163\n",
      "Epoch 60, Loss: 1.9890, Accuracy: 0.2207\n",
      "Epoch 80, Loss: 1.9644, Accuracy: 0.2211\n",
      "Epoch 100, Loss: 1.9489, Accuracy: 0.2211\n",
      "Epoch 120, Loss: 1.9392, Accuracy: 0.2207\n",
      "Epoch 140, Loss: 1.9333, Accuracy: 0.2681\n",
      "Epoch 160, Loss: 1.9298, Accuracy: 0.2675\n",
      "Epoch 180, Loss: 1.9276, Accuracy: 0.2660\n",
      "Epoch 200, Loss: 1.9261, Accuracy: 0.2653\n",
      "[[  0   0   0   2   0   0 100   0]\n",
      " [  0  21   0   1   0   0 471   0]\n",
      " [  0   2   0   0   0   0 225   0]\n",
      " [  0  12   0   0   0   0 277   0]\n",
      " [  0   1   0   0   0   0 236   0]\n",
      " [  0   1   0   0   0   0 247   0]\n",
      " [  0   7   0   0   0   0 565   0]\n",
      " [  0   5   0   0   0   0  88   0]]\n",
      "Seed 46 execution time: 4.95 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.9486, Accuracy: 0.2230\n",
      "Epoch 40, Loss: 1.7249, Accuracy: 0.3112\n",
      "Epoch 60, Loss: 1.5592, Accuracy: 0.4649\n",
      "Epoch 80, Loss: 1.4313, Accuracy: 0.5308\n",
      "Epoch 100, Loss: 1.3399, Accuracy: 0.5504\n",
      "Epoch 120, Loss: 1.2884, Accuracy: 0.5588\n",
      "Epoch 140, Loss: 1.2448, Accuracy: 0.5694\n",
      "Epoch 160, Loss: 1.2112, Accuracy: 0.5805\n",
      "Epoch 180, Loss: 1.1771, Accuracy: 0.6029\n",
      "Epoch 200, Loss: 1.1415, Accuracy: 0.6150\n",
      "[[  0  11   2  18   0   9  60   0]\n",
      " [  0 353   9  32   0  11  95   0]\n",
      " [  0   8 161  22   0   1  21   0]\n",
      " [  0 106  33  83   0   4  52   0]\n",
      " [  0 109   6  49   0   2  83   0]\n",
      " [  0   7   0   1   0 203  21   0]\n",
      " [  0 196  10  73   0  70 241   0]\n",
      " [  0  27  13  33   0   1  16   0]]\n",
      "Seed 123 execution time: 5.16 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 2.0315, Accuracy: 0.2182\n",
      "Epoch 40, Loss: 1.9935, Accuracy: 0.2294\n",
      "Epoch 60, Loss: 1.9548, Accuracy: 0.2324\n",
      "Epoch 80, Loss: 1.9151, Accuracy: 0.2625\n",
      "Epoch 100, Loss: 1.8575, Accuracy: 0.3020\n",
      "Epoch 120, Loss: 1.7534, Accuracy: 0.3499\n",
      "Epoch 140, Loss: 1.6018, Accuracy: 0.4558\n",
      "Epoch 160, Loss: 1.3671, Accuracy: 0.5276\n",
      "Epoch 180, Loss: 1.1821, Accuracy: 0.5997\n",
      "Epoch 200, Loss: 1.0108, Accuracy: 0.6676\n",
      "[[ 49  42   2  13   0   9   3   0]\n",
      " [  0 338  52  31   2  56  16   2]\n",
      " [  0  32  97   1   1   6  64   3]\n",
      " [  2  46  22 122  15  27  27   7]\n",
      " [  0   7   2  15 124  18  89   4]\n",
      " [  0  47   6  19  26  68  55   5]\n",
      " [  5  11   8   6  12  31 495   9]\n",
      " [  1  10   8  15  15   5  42  11]]\n",
      "Seed 2025 execution time: 4.95 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 2.0900, Accuracy: 0.2248\n",
      "Epoch 40, Loss: 1.9371, Accuracy: 0.2559\n",
      "Epoch 60, Loss: 1.8155, Accuracy: 0.2724\n",
      "Epoch 80, Loss: 1.6940, Accuracy: 0.3049\n",
      "Epoch 100, Loss: 1.5844, Accuracy: 0.4238\n",
      "Epoch 120, Loss: 1.4580, Accuracy: 0.4997\n",
      "Epoch 140, Loss: 1.3456, Accuracy: 0.5464\n",
      "Epoch 160, Loss: 1.7167, Accuracy: 0.4113\n",
      "Epoch 180, Loss: 1.4889, Accuracy: 0.4787\n",
      "Epoch 200, Loss: 1.2439, Accuracy: 0.5955\n",
      "[[  0   9   0  12  26   1  64   0]\n",
      " [  0 352   0   1   7  32 126   0]\n",
      " [  0   1 163   0  22   2  25   0]\n",
      " [  0 116   2   8  19  18  98   0]\n",
      " [  0  24  18   1 108  18  95   0]\n",
      " [  0  22   4   0   2 184  31   0]\n",
      " [  0  95   5  10  39  43 363   0]\n",
      " [  0  36   0   4   5   2  51   0]]\n",
      "Seed 999 execution time: 4.95 seconds\n",
      "\n",
      "Running classifier: GAT\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 1.7838, Accuracy: 0.3735\n",
      "Epoch 40, Loss: 1.1735, Accuracy: 0.7246\n",
      "Epoch 60, Loss: 0.6142, Accuracy: 0.8417\n",
      "Epoch 80, Loss: 0.3540, Accuracy: 0.9040\n",
      "Epoch 100, Loss: 0.2248, Accuracy: 0.9364\n",
      "Epoch 120, Loss: 0.1592, Accuracy: 0.9531\n",
      "Epoch 140, Loss: 0.1190, Accuracy: 0.9643\n",
      "Epoch 160, Loss: 0.0900, Accuracy: 0.9732\n",
      "Epoch 180, Loss: 0.0703, Accuracy: 0.9799\n",
      "Epoch 200, Loss: 0.0558, Accuracy: 0.9854\n",
      "[[ 62   4   1  13   0   1  30   0]\n",
      " [  8 366   6  55  12  11  23  10]\n",
      " [  1   0 186   1   1   1   5   1]\n",
      " [  5  48   8 195   6   4  13   6]\n",
      " [  3   3  11  11 213   2  14   0]\n",
      " [  5   2   2  25   6 214   7   2]\n",
      " [  0  12   8  26  31   1 446  28]\n",
      " [  1   6   4   5   3   0  25  67]]\n",
      "Seed 42 execution time: 45.02 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 1.7759, Accuracy: 0.3849\n",
      "Epoch 40, Loss: 1.1364, Accuracy: 0.7199\n",
      "Epoch 60, Loss: 0.6042, Accuracy: 0.8521\n",
      "Epoch 80, Loss: 0.3740, Accuracy: 0.9022\n",
      "Epoch 100, Loss: 0.2460, Accuracy: 0.9291\n",
      "Epoch 120, Loss: 0.1764, Accuracy: 0.9479\n",
      "Epoch 140, Loss: 0.1333, Accuracy: 0.9606\n",
      "Epoch 160, Loss: 0.1023, Accuracy: 0.9680\n",
      "Epoch 180, Loss: 0.0787, Accuracy: 0.9761\n",
      "Epoch 200, Loss: 0.0601, Accuracy: 0.9820\n",
      "[[ 83   3   0   8   6   1   1   0]\n",
      " [  6 390  13  36  23   4   7  14]\n",
      " [  1  12 192   5   7   1   4   5]\n",
      " [ 11  59   3 177   9   5  16   9]\n",
      " [  0  21  10   1 182   3  19   1]\n",
      " [  1   6   1  12   1 204   4  19]\n",
      " [  4  13   4  14   8   3 516  10]\n",
      " [  1   6   0   4   2   0  21  59]]\n",
      "Seed 46 execution time: 46.51 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.7750, Accuracy: 0.3653\n",
      "Epoch 40, Loss: 1.1085, Accuracy: 0.7267\n",
      "Epoch 60, Loss: 0.5765, Accuracy: 0.8461\n",
      "Epoch 80, Loss: 0.3490, Accuracy: 0.9061\n",
      "Epoch 100, Loss: 0.2231, Accuracy: 0.9351\n",
      "Epoch 120, Loss: 0.1576, Accuracy: 0.9527\n",
      "Epoch 140, Loss: 0.1172, Accuracy: 0.9642\n",
      "Epoch 160, Loss: 0.0882, Accuracy: 0.9744\n",
      "Epoch 180, Loss: 0.0662, Accuracy: 0.9818\n",
      "Epoch 200, Loss: 0.0490, Accuracy: 0.9879\n",
      "[[ 79   4   0  11   1   1   3   1]\n",
      " [  7 392   2  54  16   5  19   5]\n",
      " [  0  10 187   2   6   4   4   0]\n",
      " [  0  24   2 215  14   2  15   6]\n",
      " [  0  15   1  11 195   1  24   2]\n",
      " [  1  11   3   3   2 208   2   2]\n",
      " [  2  11   4  13  15   6 512  27]\n",
      " [  0   3   1  15   0  12  11  48]]\n",
      "Seed 123 execution time: 48.05 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 1.7847, Accuracy: 0.3819\n",
      "Epoch 40, Loss: 1.2076, Accuracy: 0.6937\n",
      "Epoch 60, Loss: 0.6122, Accuracy: 0.8388\n",
      "Epoch 80, Loss: 0.3408, Accuracy: 0.9070\n",
      "Epoch 100, Loss: 0.2210, Accuracy: 0.9400\n",
      "Epoch 120, Loss: 0.1593, Accuracy: 0.9528\n",
      "Epoch 140, Loss: 0.1218, Accuracy: 0.9661\n",
      "Epoch 160, Loss: 0.0952, Accuracy: 0.9746\n",
      "Epoch 180, Loss: 0.0757, Accuracy: 0.9788\n",
      "Epoch 200, Loss: 0.0596, Accuracy: 0.9833\n",
      "[[101   5   1   1   1   1   6   2]\n",
      " [  5 361  13  54  17  25  18   4]\n",
      " [  1  10 176   1   3   5   8   0]\n",
      " [  4  53   5 184   3   7   5   7]\n",
      " [  0   7   2   5 229   4  12   0]\n",
      " [  0   6   0   1   2 216   1   0]\n",
      " [  4  31   4  39  30  17 426  26]\n",
      " [  6   4   0   9   8   0  14  66]]\n",
      "Seed 2025 execution time: 46.49 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 1.7883, Accuracy: 0.3589\n",
      "Epoch 40, Loss: 1.2066, Accuracy: 0.6879\n",
      "Epoch 60, Loss: 0.6226, Accuracy: 0.8382\n",
      "Epoch 80, Loss: 0.3436, Accuracy: 0.9080\n",
      "Epoch 100, Loss: 0.2180, Accuracy: 0.9389\n",
      "Epoch 120, Loss: 0.1518, Accuracy: 0.9573\n",
      "Epoch 140, Loss: 0.1108, Accuracy: 0.9687\n",
      "Epoch 160, Loss: 0.0837, Accuracy: 0.9763\n",
      "Epoch 180, Loss: 0.0652, Accuracy: 0.9827\n",
      "Epoch 200, Loss: 0.0510, Accuracy: 0.9875\n",
      "[[ 90   5   4   3   0   0   5   5]\n",
      " [ 10 361  29  56  17   2  19  24]\n",
      " [  4  16 178   0   1   5   5   4]\n",
      " [  3  41   9 156  30   1   4  17]\n",
      " [  0  20   3  11 210   1  17   2]\n",
      " [  4  11   7   2   2 207   6   4]\n",
      " [  1  13  28  13  23   2 453  22]\n",
      " [  1  13   1   5   3   0  18  57]]\n",
      "Seed 999 execution time: 45.74 seconds\n",
      "\n",
      "Running classifier: GRAPHSAGE\n",
      "\n",
      "Seed: 42\n",
      "Epoch 20, Loss: 1.9026, Accuracy: 0.3046\n",
      "Epoch 40, Loss: 1.7119, Accuracy: 0.5614\n",
      "Epoch 60, Loss: 1.5042, Accuracy: 0.7903\n",
      "Epoch 80, Loss: 1.4171, Accuracy: 0.8630\n",
      "Epoch 100, Loss: 1.3797, Accuracy: 0.8913\n",
      "Epoch 120, Loss: 1.3601, Accuracy: 0.9108\n",
      "Epoch 140, Loss: 1.3452, Accuracy: 0.9195\n",
      "Epoch 160, Loss: 1.3354, Accuracy: 0.9279\n",
      "Epoch 180, Loss: 1.3285, Accuracy: 0.9326\n",
      "Epoch 200, Loss: 1.3202, Accuracy: 0.9370\n",
      "[[ 75  20   0   5   3   4   3   1]\n",
      " [  3 394   4  44   7  12  24   3]\n",
      " [  0  10 174   4   3   0   4   1]\n",
      " [  2  31   2 215  12   9  11   3]\n",
      " [  1  13   3   7 216   4  13   0]\n",
      " [  1  14   1   6   1 230  10   0]\n",
      " [  4  45   1  22  25   6 433  16]\n",
      " [  0  14   0  18  12   0  20  47]]\n",
      "Seed 42 execution time: 19.80 seconds\n",
      "\n",
      "Seed: 46\n",
      "Epoch 20, Loss: 1.8908, Accuracy: 0.3237\n",
      "Epoch 40, Loss: 1.6536, Accuracy: 0.6153\n",
      "Epoch 60, Loss: 1.4749, Accuracy: 0.8161\n",
      "Epoch 80, Loss: 1.3969, Accuracy: 0.8773\n",
      "Epoch 100, Loss: 1.3624, Accuracy: 0.9048\n",
      "Epoch 120, Loss: 1.3439, Accuracy: 0.9190\n",
      "Epoch 140, Loss: 1.3304, Accuracy: 0.9293\n",
      "Epoch 160, Loss: 1.3212, Accuracy: 0.9372\n",
      "Epoch 180, Loss: 1.3104, Accuracy: 0.9441\n",
      "Epoch 200, Loss: 1.3037, Accuracy: 0.9507\n",
      "[[ 82   4   1   7   3   2   2   1]\n",
      " [  4 380   8  39  17  13  22  10]\n",
      " [  0  27 185   4   2   4   5   0]\n",
      " [  2  53  11 188   7   2  14  12]\n",
      " [  0  15   1   6 193   3  19   0]\n",
      " [  0  15   1   0   5 218   8   1]\n",
      " [  3  25  10  14  18   3 491   8]\n",
      " [  2   7   2   4   5   1  24  48]]\n",
      "Seed 46 execution time: 20.12 seconds\n",
      "\n",
      "Seed: 123\n",
      "Epoch 20, Loss: 1.8611, Accuracy: 0.3430\n",
      "Epoch 40, Loss: 1.6003, Accuracy: 0.6879\n",
      "Epoch 60, Loss: 1.4551, Accuracy: 0.8277\n",
      "Epoch 80, Loss: 1.3916, Accuracy: 0.8768\n",
      "Epoch 100, Loss: 1.3613, Accuracy: 0.9055\n",
      "Epoch 120, Loss: 1.3433, Accuracy: 0.9188\n",
      "Epoch 140, Loss: 1.3318, Accuracy: 0.9298\n",
      "Epoch 160, Loss: 1.3195, Accuracy: 0.9381\n",
      "Epoch 180, Loss: 1.3122, Accuracy: 0.9432\n",
      "Epoch 200, Loss: 1.3077, Accuracy: 0.9466\n",
      "[[ 73  11   0   9   2   1   4   0]\n",
      " [  2 394   4  36  22  10  27   5]\n",
      " [  0  38 153   5   9   2   6   0]\n",
      " [  1  42   7 197  14   6  10   1]\n",
      " [  0  21   3  10 195   2  15   3]\n",
      " [  0  22   2   1   8 196   3   0]\n",
      " [  4  27   4  20  18   5 506   6]\n",
      " [  1  15   0  13   3   0  21  37]]\n",
      "Seed 123 execution time: 20.87 seconds\n",
      "\n",
      "Seed: 2025\n",
      "Epoch 20, Loss: 1.8995, Accuracy: 0.3497\n",
      "Epoch 40, Loss: 1.7082, Accuracy: 0.5550\n",
      "Epoch 60, Loss: 1.5105, Accuracy: 0.7731\n",
      "Epoch 80, Loss: 1.4177, Accuracy: 0.8564\n",
      "Epoch 100, Loss: 1.3775, Accuracy: 0.8896\n",
      "Epoch 120, Loss: 1.3539, Accuracy: 0.9096\n",
      "Epoch 140, Loss: 1.3411, Accuracy: 0.9231\n",
      "Epoch 160, Loss: 1.3314, Accuracy: 0.9280\n",
      "Epoch 180, Loss: 1.3218, Accuracy: 0.9381\n",
      "Epoch 200, Loss: 1.3183, Accuracy: 0.9407\n",
      "[[ 63  34   0  10   1   2   8   0]\n",
      " [  0 409  13  27  12  10  22   4]\n",
      " [  0  26 162   4   3   0   9   0]\n",
      " [  0  57   6 178   4   4  14   5]\n",
      " [  0  22   0  11 204   3  13   6]\n",
      " [  0  18   3   2   7 188   8   0]\n",
      " [  1  45   2  14  17   6 486   6]\n",
      " [  1  29   0  18   5   0  27  27]]\n",
      "Seed 2025 execution time: 20.38 seconds\n",
      "\n",
      "Seed: 999\n",
      "Epoch 20, Loss: 1.8936, Accuracy: 0.3309\n",
      "Epoch 40, Loss: 1.6723, Accuracy: 0.6020\n",
      "Epoch 60, Loss: 1.4958, Accuracy: 0.7803\n",
      "Epoch 80, Loss: 1.4139, Accuracy: 0.8628\n",
      "Epoch 100, Loss: 1.3751, Accuracy: 0.8970\n",
      "Epoch 120, Loss: 1.3525, Accuracy: 0.9144\n",
      "Epoch 140, Loss: 1.3387, Accuracy: 0.9260\n",
      "Epoch 160, Loss: 1.3285, Accuracy: 0.9317\n",
      "Epoch 180, Loss: 1.3180, Accuracy: 0.9376\n",
      "Epoch 200, Loss: 1.3125, Accuracy: 0.9418\n",
      "[[ 73   8   4  18   4   0   4   1]\n",
      " [  2 423   2  35  20   8  26   2]\n",
      " [  0  18 176   3   1   2  13   0]\n",
      " [  2  58   4 155  12   2  23   5]\n",
      " [  0  25   2   8 208   5  14   2]\n",
      " [  1  22   2   4   0 205   9   0]\n",
      " [  1  47   1   4  12   4 480   6]\n",
      " [  1  24   1   9   1   1  24  37]]\n",
      "Seed 999 execution time: 20.35 seconds\n",
      "\n",
      "Saved results for split 70_30 at: .\\results\\AmazonPhotos\\70_30\\AmazonPhotos_analysis_results_70_30.csv\n",
      "\n",
      "=== Dataset AmazonPhotos completed ===\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import tensorflow as tf\n",
    "from sklearn.metrics import accuracy_score, f1_score, confusion_matrix\n",
    "from torch_geometric.datasets import Planetoid, WikiCS, Amazon\n",
    "import networkx as nx\n",
    "import time\n",
    "\n",
    "# Configuration\n",
    "BASE_DIR = \".\"\n",
    "#DATASETS = [\"Cora\", \"CiteSeer\", \"PubMed\", \"WikiCS\", \"AmazonPhotos\"]\n",
    "DATASETS = [\"AmazonPhotos\"]\n",
    "SEEDS = [42, 46, 123, 2025, 999]\n",
    "CLASSIFIERS = ['gcn', 'gat', 'graphsage']\n",
    "SPLITS = ['30_70', '70_30']\n",
    "OUTPUT_DIR = os.path.join(BASE_DIR, \"output\")\n",
    "MASKS_DIR = os.path.join(BASE_DIR, \"masks\")\n",
    "RESULTS_DIR = os.path.join(BASE_DIR, \"results\")\n",
    "\n",
    "os.makedirs(RESULTS_DIR, exist_ok=True)\n",
    "\n",
    "def set_global_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    tf.random.set_seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "def convert_to_networkx(A):\n",
    "    return nx.from_scipy_sparse_array(A)\n",
    "\n",
    "def evaluate_model(true_labels, predicted_labels):\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    f1 = f1_score(true_labels, predicted_labels, average='macro')\n",
    "    cm = confusion_matrix(true_labels, predicted_labels)\n",
    "    print(cm)\n",
    "    return {'accuracy': accuracy, 'f1_score': f1}\n",
    "\n",
    "for dataset_name in DATASETS:\n",
    "    print(f\"\\n=== Processing dataset: {dataset_name} ===\")\n",
    "\n",
    "    dataset_results_dir = os.path.join(RESULTS_DIR, dataset_name)\n",
    "    os.makedirs(dataset_results_dir, exist_ok=True)\n",
    "\n",
    "    for split in SPLITS:\n",
    "        print(f\"\\nProcessing split: {split}\")\n",
    "\n",
    "        split_results = []\n",
    "        data_dir = os.path.join(BASE_DIR, \"data\", dataset_name)\n",
    "\n",
    "        # Load dataset\n",
    "        if dataset_name in [\"Cora\", \"CiteSeer\", \"PubMed\"]:\n",
    "            dataset = Planetoid(root=data_dir, name=dataset_name)[0]\n",
    "        elif dataset_name == \"WikiCS\":\n",
    "            dataset = WikiCS(root=data_dir)[0]\n",
    "        elif dataset_name == \"AmazonPhotos\":\n",
    "            dataset = Amazon(root=data_dir, name='Photo')[0]  # fix: use 'Photo' with capital P\n",
    "        else:\n",
    "            raise NotImplementedError(f\"Dataset {dataset_name} loader not implemented\")\n",
    "\n",
    "        ground_truth_labels = dataset.y.numpy()\n",
    "        labels = ground_truth_labels  # Already integer class labels\n",
    "\n",
    "        # Build adjacency matrix\n",
    "        edge_index = dataset.edge_index.numpy()\n",
    "        num_nodes = dataset.num_nodes\n",
    "        A = sp.coo_matrix(\n",
    "            (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),\n",
    "            shape=(num_nodes, num_nodes)\n",
    "        )\n",
    "        X_full = dataset.x.numpy()\n",
    "\n",
    "        for classifier in CLASSIFIERS:\n",
    "            print(f\"\\nRunning classifier: {classifier.upper()}\")\n",
    "\n",
    "            all_accuracies = []\n",
    "            all_f1_scores = []\n",
    "            total_execution_time = 0.0\n",
    "\n",
    "            for seed in SEEDS:\n",
    "                print(f\"\\nSeed: {seed}\")\n",
    "                set_global_seed(seed)\n",
    "\n",
    "                mask_file = os.path.join(MASKS_DIR, dataset_name, split,\n",
    "                                         f\"{dataset_name}_{split}_masked_indices_seed{seed}.npy\")\n",
    "                labels_to_be_masked = np.load(mask_file)\n",
    "\n",
    "                masked_labels = np.array([\n",
    "                    -1 if i in labels_to_be_masked else labels[i]\n",
    "                    for i in range(len(labels))\n",
    "                ])\n",
    "                label_mask = masked_labels != -1\n",
    "\n",
    "                emb_file = os.path.join(OUTPUT_DIR, \"embeddings\", dataset_name, split,\n",
    "                                        f\"{dataset_name}_{split}_seed{seed}_mnmf.pkl\")\n",
    "                embedding_matrix = pd.read_pickle(emb_file)\n",
    "\n",
    "                start_time = time.time()\n",
    "                result, _ = train_and_evaluate(\n",
    "                    embedding_dict={classifier: embedding_matrix},\n",
    "                    embedding=classifier,\n",
    "                    classifier=classifier,\n",
    "                    ground_truth_labels=dataset.y.numpy(),\n",
    "                    masked_labels=masked_labels\n",
    "                )\n",
    "                end_time = time.time()\n",
    "                duration = end_time - start_time\n",
    "                total_execution_time += duration\n",
    "                print(f\"Seed {seed} execution time: {duration:.2f} seconds\")\n",
    "\n",
    "                all_accuracies.append(result['accuracy'])\n",
    "                all_f1_scores.append(result['f1_score'])\n",
    "\n",
    "            avg_accuracy = np.mean(all_accuracies)\n",
    "            std_accuracy = np.std(all_accuracies)\n",
    "            avg_f1 = np.mean(all_f1_scores)\n",
    "            std_f1 = np.std(all_f1_scores)\n",
    "            avg_time = total_execution_time / len(SEEDS)\n",
    "\n",
    "            summary = {\n",
    "                'classifier': classifier,\n",
    "                'average_accuracy': f\"{avg_accuracy:.4f} ± {std_accuracy:.4f}\",\n",
    "                'average_f1_score': f\"{avg_f1:.4f} ± {std_f1:.4f}\",\n",
    "                'average_execution_time_sec': round(avg_time, 2)\n",
    "            }\n",
    "            split_results.append(summary)\n",
    "\n",
    "        # Save results per split\n",
    "        split_dir = os.path.join(dataset_results_dir, split)\n",
    "        os.makedirs(split_dir, exist_ok=True)\n",
    "\n",
    "        df_split = pd.DataFrame(split_results)\n",
    "        filename = os.path.join(split_dir, f\"{dataset_name}_analysis_results_{split}.csv\")\n",
    "        df_split.to_csv(filename, index=False)\n",
    "        print(f\"\\nSaved results for split {split} at: {filename}\")\n",
    "\n",
    "    print(f\"\\n=== Dataset {dataset_name} completed ===\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "628b87fc-9c35-4bac-85bd-f04b8e55fceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.sparse as sp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "58684375-e2e7-42f9-bb24-4feda4f705df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Define training and evaluation function ----\n",
    "def train_and_evaluate(embedding_dict, embedding, classifier, ground_truth_labels, masked_labels):\n",
    "    X = embedding_dict[embedding]\n",
    "    num_emb_nodes = X.shape[0]  # number of nodes in embeddings\n",
    "\n",
    "    # Subset labels to match embedding size\n",
    "    masked_labels_sub = masked_labels[:num_emb_nodes]\n",
    "    ground_truth_labels_sub = ground_truth_labels[:num_emb_nodes]\n",
    "    train_mask = masked_labels_sub != -1\n",
    "\n",
    "    # Training data\n",
    "    X_train = X[train_mask]\n",
    "    Y_train = ground_truth_labels_sub[train_mask]\n",
    "    Y_train = tf.one_hot(Y_train, depth=len(np.unique(ground_truth_labels)))\n",
    "    Y_train = tf.cast(Y_train, dtype='float32')\n",
    "\n",
    "    # Build adjacency matrix for embedding nodes\n",
    "    # Convert global adjacency A to CSR for safe slicing\n",
    "    A_csr = A.tocsr()  # Ensure A is in CSR format\n",
    "    A_sub = A_csr[:num_emb_nodes, :num_emb_nodes]  # slice embedding nodes\n",
    "    A_train = A_sub[train_mask, :][:, train_mask]   # slice training nodes\n",
    "    A_train_coo = A_train.tocoo()\n",
    "    indices_train = np.column_stack((A_train_coo.row, A_train_coo.col))\n",
    "    A_train_tensor = tf.sparse.SparseTensor(\n",
    "        indices=indices_train,\n",
    "        values=A_train_coo.data,\n",
    "        dense_shape=A_train_coo.shape\n",
    "    )\n",
    "    A_train_tensor = tf.sparse.reorder(A_train_tensor)\n",
    "\n",
    "    # Initialize model\n",
    "    n_labels = Y_train.shape[1]\n",
    "    if classifier == 'gcn':\n",
    "        model = GCN(n_labels)\n",
    "    elif classifier == 'gat':\n",
    "        model = GAT(n_labels)\n",
    "    elif classifier == 'graphsage':\n",
    "        model = GraphSAGE(n_labels)\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown classifier {classifier}\")\n",
    "\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
    "    loss_fn = tf.keras.losses.CategoricalCrossentropy()\n",
    "    acc_metric = tf.keras.metrics.CategoricalAccuracy()\n",
    "\n",
    "    # Training loop with prints every 20 epochs\n",
    "    epochs = 200\n",
    "    for epoch in range(epochs):\n",
    "        with tf.GradientTape() as tape:\n",
    "            predictions, _ = model([X_train, A_train_tensor])\n",
    "            loss = loss_fn(Y_train, predictions)\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if (epoch + 1) % 20 == 0:\n",
    "            acc_metric.update_state(Y_train, predictions)\n",
    "            print(f\"Epoch {epoch+1}, Loss: {loss.numpy():.4f}, Accuracy: {acc_metric.result().numpy():.4f}\")\n",
    "            acc_metric.reset_state()\n",
    "\n",
    "    # Full graph prediction (restricted to embedding nodes)\n",
    "    A_full_coo = A_sub.tocoo()\n",
    "    indices_full = np.column_stack((A_full_coo.row, A_full_coo.col))\n",
    "    A_full_tensor = tf.sparse.SparseTensor(\n",
    "        indices=indices_full,\n",
    "        values=A_full_coo.data,\n",
    "        dense_shape=A_full_coo.shape\n",
    "    )\n",
    "    A_full_tensor = tf.sparse.reorder(A_full_tensor)\n",
    "\n",
    "    predictions, emb = model([X, A_full_tensor])\n",
    "    predicted_labels = tf.argmax(predictions, axis=1).numpy()\n",
    "\n",
    "    # Evaluate only masked nodes\n",
    "    masked_indices = np.where(masked_labels_sub == -1)[0]\n",
    "    predicted_masked = predicted_labels[masked_indices]\n",
    "    true_masked = ground_truth_labels_sub[masked_indices]\n",
    "\n",
    "    results = evaluate_model(true_masked, predicted_masked)\n",
    "    results['model'] = classifier\n",
    "    results['embedding'] = embedding\n",
    "\n",
    "    return results, emb\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bab48072-6b13-4c61-8ece-a3c97a519651",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7310b48c-c9f6-41f2-8c94-59fbd56c1821",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
