{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "93c1ef1f-2372-44c0-9703-b8ac660cce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4441684d-fa5d-4f8d-9967-84753fd5a4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time \n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import umap.umap_ as umap\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
    "from node2vec import Node2Vec\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import spektral\n",
    "from spektral.layers import GCNConv, GATConv\n",
    "from spektral.layers import GraphSageConv\n",
    "from spektral.data import Graph, Dataset, BatchLoader\n",
    "from scipy.sparse import csr_matrix\n",
    "from spektral.datasets import Cora\n",
    "from torch_geometric.nn import DeepGraphInfomax, VGAE\n",
    "from torch_geometric.utils import from_networkx\n",
    "import scipy.sparse as sp\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
    "from scipy.sparse.csgraph import laplacian\n",
    "from scipy.sparse.linalg import eigsh\n",
    "from collections import Counter\n",
    "from sklearn.preprocessing import normalize\n",
    "from joblib import Parallel, delayed\n",
    "from torch_geometric.nn import GCNConv as PyG_GCNConv, VGAE as PyG_VGAE\n",
    "from torch_geometric.data import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "46e18457-6dbb-4928-adad-4b279400977c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025\n",
    "\n",
    "# Set seed for Python's built-in random module\n",
    "random.seed(SEED)\n",
    "\n",
    "# Set seed for NumPy\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Set seed for TensorFlow\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Set seed for PyTorch\n",
    "torch.manual_seed(SEED)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5b95802d-543f-4b3e-803a-cdbdb518595b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class CiteSeerDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Planetoid(root=\".\", name=\"CiteSeer\")  # Load CiteSeer dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "        \n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Convert edge_index to a sparse adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = csr_matrix((num_nodes, num_nodes))  # Initialize sparse matrix\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected graph\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "90c18215-d57c-47a9-957d-5f98fa330e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dimensionality=150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2d1fa1ef-526c-4ecb-a4e5-6c344f0412c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_networkx(A):\n",
    "    return nx.from_scipy_sparse_array(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f7fbba5b-3f7c-4821-9af5-d9d6e982b9ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\user\\anaconda3\\envs\\ContrastiveFUSE\\lib\\site-packages\\torch_geometric\\data\\dataset.py:238: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):\n",
      "C:\\Users\\user\\anaconda3\\envs\\ContrastiveFUSE\\lib\\site-packages\\torch_geometric\\data\\dataset.py:246: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):\n",
      "C:\\Users\\user\\anaconda3\\envs\\ContrastiveFUSE\\lib\\site-packages\\torch_geometric\\io\\fs.py:215: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(f, map_location)\n",
      "C:\\Users\\user\\anaconda3\\envs\\ContrastiveFUSE\\lib\\site-packages\\scipy\\sparse\\_index.py:108: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
      "  self._set_intXint(row, col, x.flat[0])\n"
     ]
    }
   ],
   "source": [
    "dataset = CiteSeerDataset()\n",
    "ground_truth_labels = dataset[0].y\n",
    "labels=np.argmax(ground_truth_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c92de7a7-bab8-4ee8-9601-e4f95f0386e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_file = \"C:/Users/user/Benchmarking_fuse/masks/Citeseer/30_70/Citeseer_30_70_masked_indices_seed2025.npy\"\n",
    "labels_to_be_masked = np.load(mask_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3d5ef760-dd97-4ac2-8b4a-7109d91d08b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2328"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(labels_to_be_masked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0cf50e3-56e5-4187-b4ef-9f05cd3bc76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_labels=[]\n",
    "for i in np.arange(len(labels)):\n",
    "    if i in labels_to_be_masked:\n",
    "        masked_labels.append(-1)\n",
    "    else:\n",
    "        masked_labels.append(labels[i])\n",
    "masked_labels=np.array(masked_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7f436be3-ba62-406f-9523-09713d59829b",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_mask = masked_labels != -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "517544ea-62e6-495c-bce9-6733a551c959",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = dataset[0].x\n",
    "A = dataset[0].a\n",
    "G = convert_to_networkx(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "17de4d41-86a4-40c3-bf4d-217511591a0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adjacency Matrix Shape: (3327, 3327)\n",
      "Graph Nodes: 3327\n",
      "Graph Edges: 4552\n"
     ]
    }
   ],
   "source": [
    "print(\"Adjacency Matrix Shape:\", A.shape)\n",
    "print(\"Graph Nodes:\", G.number_of_nodes())\n",
    "print(\"Graph Edges:\", G.number_of_edges())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0346f826-0826-4b89-8328-054b3295fdc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert your preprocessed data into a PyTorch Geometric Data object\n",
    "X_py = Data(\n",
    "    x=torch.tensor(X, dtype=torch.float),  # Node features\n",
    "    edge_index=torch.tensor(np.array(A.nonzero()), dtype=torch.long),  # Edge indices\n",
    "    y=torch.tensor(labels, dtype=torch.long)  # Labels\n",
    ")\n",
    "\n",
    "# Ensure edge_index is in the correct shape (2, num_edges)\n",
    "X_py.edge_index = X_py.edge_index.to(torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2acc40a-d4c0-48e1-8fc6-da99cccc02e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "81b4f8f4-20eb-4804-a2e7-b04c6e80d08d",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "49cada09-7de5-438e-b980-7d5e703615d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Loading DEEPWALK embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/deepwalk_embedding_30_70_2025.pkl...\n",
      " Saved DEEPWALK embedding to ./citeseer_analysis_results/embeddings/30_70/deepwalk_embedding_30_70_2025.pkl\n",
      " Loading NODE2VEC embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/node2vec_embedding_30_70_2025.pkl...\n",
      " Saved NODE2VEC embedding to ./citeseer_analysis_results/embeddings/30_70/node2vec_embedding_30_70_2025.pkl\n",
      " Loading FUSE embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/fuse_embedding_30_70_2025.pkl...\n",
      " Saved FUSE embedding to ./citeseer_analysis_results/embeddings/30_70/fuse_embedding_30_70_2025.pkl\n",
      " Loading VGAE embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/vgae_embedding_30_70_2025.pkl...\n",
      " Saved VGAE embedding to ./citeseer_analysis_results/embeddings/30_70/vgae_embedding_30_70_2025.pkl\n",
      " Loading DGI embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/dgi_embedding_30_70_2025.pkl...\n",
      " Saved DGI embedding to ./citeseer_analysis_results/embeddings/30_70/dgi_embedding_30_70_2025.pkl\n",
      " Loading RANDOM embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/random_embedding_30_70_2025.pkl...\n",
      " Saved RANDOM embedding to ./citeseer_analysis_results/embeddings/30_70/random_embedding_30_70_2025.pkl\n",
      " Loading GIVEN embedding from C:/Users/user/Benchmarking_fuse/benchmark_outputs/citeseer/30-70/given_embedding_30_70_2025.pkl...\n",
      " Saved GIVEN embedding to ./citeseer_analysis_results/embeddings/30_70/given_embedding_30_70_2025.pkl\n",
      "\n",
      " All embeddings loaded into memory and re-saved successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# -------------------------------------\n",
    "# Configuration\n",
    "# -------------------------------------\n",
    "SEED = 2025\n",
    "dataset_name = \"citeseer\"\n",
    "split_folder = \"30-70\"\n",
    "\n",
    "# Input embeddings directory\n",
    "load_dir = f\"C:/Users/user/Benchmarking_fuse/benchmark_outputs/{dataset_name}/{split_folder}/\"\n",
    "\n",
    "# Output save directory\n",
    "save_dir = f\"./citeseer_analysis_results/embeddings/{split_folder.replace('-', '_')}/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Embedding filenames to load\n",
    "embedding_files = {\n",
    "    \"deepwalk\": f\"deepwalk_embedding_30_70_{SEED}.pkl\",\n",
    "    \"node2vec\": f\"node2vec_embedding_30_70_{SEED}.pkl\",\n",
    "    \"fuse\": f\"fuse_embedding_30_70_{SEED}.pkl\",\n",
    "    \"vgae\": f\"vgae_embedding_30_70_{SEED}.pkl\",\n",
    "    \"dgi\": f\"dgi_embedding_30_70_{SEED}.pkl\",\n",
    "    \"random\": f\"random_embedding_30_70_{SEED}.pkl\",\n",
    "    \"given\": f\"given_embedding_30_70_{SEED}.pkl\"\n",
    "}\n",
    "\n",
    "# Dictionary to store embeddings\n",
    "embedding_dict = {}\n",
    "\n",
    "# -------------------------------------\n",
    "# Utility: convert numpy to tf.Tensor\n",
    "# -------------------------------------\n",
    "def to_tf_tensor(x):\n",
    "    \"\"\"Convert numpy array to TensorFlow tensor.\"\"\"\n",
    "    if isinstance(x, tf.Tensor):\n",
    "        return x\n",
    "    return tf.convert_to_tensor(x, dtype=tf.float32)\n",
    "\n",
    "# -------------------------------------\n",
    "# Load and re-save embeddings\n",
    "# -------------------------------------\n",
    "for name, filename in embedding_files.items():\n",
    "    filepath = os.path.join(load_dir, filename)\n",
    "    if not os.path.exists(filepath):\n",
    "        print(f\" Warning: {name} embedding file not found at {filepath}. Skipping.\")\n",
    "        continue\n",
    "\n",
    "    print(f\" Loading {name.upper()} embedding from {filepath}...\")\n",
    "    with open(filepath, \"rb\") as f:\n",
    "        emb = pickle.load(f)\n",
    "\n",
    "    # Convert to TensorFlow tensor\n",
    "    embedding_dict[name] = to_tf_tensor(np.array(emb, dtype=float))\n",
    "\n",
    "    # Save again to new organized directory\n",
    "    save_path = os.path.join(save_dir, f\"{name}_embedding_30_70_{SEED}.pkl\")\n",
    "    with open(save_path, \"wb\") as f:\n",
    "        pickle.dump(embedding_dict[name].numpy(), f)\n",
    "    print(f\" Saved {name.upper()} embedding to {save_path}\")\n",
    "\n",
    "print(\"\\n All embeddings loaded into memory and re-saved successfully.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a98e7ea-fd83-440b-a479-bed7c46185a4",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e7297774-e5c4-4b87-9cc9-cab07deb7978",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_all_embeddings(all_embeddings, labels, label_mask):\n",
    "    \"\"\"\n",
    "    Visualize all embeddings in a grid with 4 columns per row using UMAP.\n",
    "\n",
    "    Parameters:\n",
    "    - all_embeddings: Dictionary where keys are embedding methods, and values are embeddings.\n",
    "    - labels: Labels (numpy array of shape [n_nodes]).\n",
    "    - label_mask: Boolean array indicating known labels (True for known, False for unknown).\n",
    "    \"\"\"\n",
    "    num_embeddings = len(all_embeddings)\n",
    "    num_rows = (num_embeddings + 3) // 4  # Ensure enough rows for all embeddings\n",
    "    fig, axes = plt.subplots(num_rows, 4, figsize=(8.27, 11.69))  # A4 size\n",
    "\n",
    "    for i, (embedding_type, embedding) in tqdm(enumerate(all_embeddings.items()), \n",
    "                                               total=num_embeddings, desc=\"Visualizing embeddings\"):\n",
    "        row, col = divmod(i, 4)\n",
    "        ax = axes[row, col] if num_rows > 1 else axes[col]  # Adjust for single-row case\n",
    "\n",
    "        # Ensure embedding is a NumPy array\n",
    "        if isinstance(embedding, tf.Tensor):\n",
    "            embedding = embedding.numpy()\n",
    "\n",
    "        # Reduce dimensionality using UMAP\n",
    "        reducer = umap.UMAP(n_components=2)\n",
    "        embedding_2d = reducer.fit_transform(embedding)\n",
    "\n",
    "        # Known labels\n",
    "        ax.scatter(embedding_2d[label_mask, 0], embedding_2d[label_mask, 1], \n",
    "                   c=labels[label_mask], cmap=\"Set1\", s=3, alpha=0.7, label=\"Known Labels\",\n",
    "                   edgecolors='none')\n",
    "\n",
    "        # Unknown labels\n",
    "        ax.scatter(embedding_2d[~label_mask, 0], embedding_2d[~label_mask, 1], \n",
    "                   c=labels[~label_mask], cmap=\"Set1\", s=5, alpha=0.7, \n",
    "                   label=\"Unknown Labels\", edgecolors='black', linewidths=0.2)\n",
    "\n",
    "        # Title with smaller font size\n",
    "        ax.set_title(embedding_type.upper(), fontsize=8, pad=2)\n",
    "\n",
    "        # Remove axis labels, ticks, and frames\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_frame_on(False)\n",
    "\n",
    "    # Remove empty subplots if num_embeddings is not a multiple of 4\n",
    "    for j in range(i + 1, num_rows * 4):\n",
    "        row, col = divmod(j, 4)\n",
    "        fig.delaxes(axes[row, col])\n",
    "\n",
    "    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.2, hspace=0.2)  # Adjust margins\n",
    "    save_path = \"./citeseer_analysis_results/embedding_grid_plot_citeseer_30_70.png\"\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Visualization saved to {save_path}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "73add6f8-ad97-49f5-8f66-88df3b717390",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(true_labels, predicted_labels):\n",
    "    \"\"\"\n",
    "    Evaluate the model's performance using accuracy, F1-score, and confusion matrix.\n",
    "\n",
    "    Args:\n",
    "        true_labels (np.array): Ground truth labels (integers).\n",
    "        predicted_labels (np.array): Predicted labels (integers).\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing accuracy, F1-score, and confusion matrix.\n",
    "    \"\"\"\n",
    "    # Compute accuracy\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    \n",
    "    # Compute F1-score (macro-averaged)\n",
    "    f1 = f1_score(true_labels, predicted_labels, average='macro')\n",
    "    \n",
    "    # Compute confusion matrix\n",
    "    cm = confusion_matrix(true_labels, predicted_labels)\n",
    "\n",
    "    #\n",
    "    print(cm)\n",
    "    \n",
    "    # Return results as a dictionary\n",
    "    return {\n",
    "        'accuracy': accuracy,\n",
    "        'f1_score': f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc2b387-1e6d-49a3-b28f-23f59c1c789c",
   "metadata": {},
   "source": [
    "## Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8f934d79-76c2-4d0e-81f4-486a1dc34d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoMaskGCNConv(GCNConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard mask\n",
    "        return super().call(inputs, mask=None)\n",
    "        \n",
    "class GCN(tf.keras.Model):\n",
    "    def __init__(self, n_labels, seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "        self.conv1 = NoMaskGCNConv(16, activation='relu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGCNConv(n_labels, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "75c02975-3867-49ee-bfd3-f74cd6d8dbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spektral.layers import GATConv\n",
    "import tensorflow as tf\n",
    "\n",
    "# Define a custom wrapper for GATConv that avoids mask issues\n",
    "class NoMaskGATConv(GATConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard the mask argument\n",
    "        return super().call(inputs, mask=None)\n",
    "\n",
    "\n",
    "# Define the GAT model using the NoMaskGATConv\n",
    "class GAT(tf.keras.Model):\n",
    "    def __init__(self, n_labels, num_heads=8, seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # Use the custom NoMaskGATConv instead of the original GATConv\n",
    "        self.conv1 = NoMaskGATConv(16, attn_heads=num_heads, concat_heads=True, activation='elu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGATConv(n_labels, attn_heads=1, concat_heads=False, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1805e34a-bed9-410b-a719-49a9df09c377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the GraphSAGE model\n",
    "class GraphSAGE(tf.keras.Model):\n",
    "    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        self.conv1 = GraphSageConv(hidden_dim, activation='relu', aggregator=aggregator, kernel_initializer=initializer)\n",
    "        self.conv2 = GraphSageConv(n_labels, activation='softmax', aggregator=aggregator, kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7378d826-fa06-4a7d-8891-f46888320d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifiers=['gcn','gat','graphsage']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad91add-83d2-4b32-a97b-a7b48487c7fe",
   "metadata": {},
   "source": [
    "## Classification using different node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3a2d104c-8a7a-42dc-8892-c479585380c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(\n",
    "    embedding_dict, embedding, classifier,\n",
    "    ground_truth_labels=ground_truth_labels,\n",
    "    masked_labels=masked_labels\n",
    "):\n",
    "    \"\"\"\n",
    "    Version that trains ONLY on the training subgraph:\n",
    "    - Uses X_train and A_train\n",
    "    - No test nodes seen during training\n",
    "    - Inference happens on full graph\n",
    "    \"\"\"\n",
    "\n",
    "    print(f\"\\nEmbedding: {embedding.upper()}\")\n",
    "    print(f\"Model: {classifier.upper()}\")\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 1. Construct training subgraph\n",
    "    # -------------------------------------------------\n",
    "    train_mask = masked_labels != -1\n",
    "    train_idx = np.where(train_mask)[0]\n",
    "\n",
    "    # Features for training nodes only\n",
    "    X_train = tf.convert_to_tensor(embedding_dict[embedding][train_mask], dtype=tf.float32)\n",
    "\n",
    "    # Labels for training (one-hot)\n",
    "    Y_train = tf.convert_to_tensor(ground_truth_labels[train_mask], dtype=tf.float32)\n",
    "\n",
    "    # Build training adjacency\n",
    "    A_train = A[train_mask][:, train_mask]   # induced subgraph\n",
    "\n",
    "    A_coo = A_train.tocoo()\n",
    "    A_train_tensor = tf.sparse.SparseTensor(\n",
    "        indices = np.column_stack([A_coo.row, A_coo.col]),\n",
    "        values  = A_coo.data.astype(np.float32),\n",
    "        dense_shape = A_coo.shape\n",
    "    )\n",
    "    A_train_tensor = tf.sparse.reorder(A_train_tensor)\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 2. Build classifier on TRAIN SUBGRAPH only\n",
    "    # -------------------------------------------------\n",
    "    n_classes = ground_truth_labels.shape[1]\n",
    "\n",
    "    if classifier == 'gcn':\n",
    "        model = GCN(n_classes)\n",
    "    elif classifier == 'gat':\n",
    "        model = GAT(n_classes)\n",
    "    elif classifier == 'graphsage':\n",
    "        model = GraphSAGE(n_classes)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown classifier: \" + classifier)\n",
    "\n",
    "    optimizer = Adam(learning_rate=0.01)\n",
    "    loss_fn = CategoricalCrossentropy()\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 3. Training (on training subgraph only)\n",
    "    # -------------------------------------------------\n",
    "    print(\"Training on TRAINING SUBGRAPH ONLY...\")\n",
    "\n",
    "    epochs = 200\n",
    "    for epoch in range(epochs):\n",
    "        with tf.GradientTape() as tape:\n",
    "\n",
    "            preds_train, _ = model([X_train, A_train_tensor], training=True)\n",
    "\n",
    "            loss = loss_fn(Y_train, preds_train)\n",
    "\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            acc = CategoricalAccuracy()(Y_train, preds_train).numpy()\n",
    "            print(f\"Epoch {epoch} | Loss={loss.numpy():.4f} | Train Acc={acc:.4f}\")\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 4. Inference — NOW use the full graph\n",
    "    # -------------------------------------------------\n",
    "    print(\"Predicting on FULL GRAPH...\")\n",
    "\n",
    "    X_full = tf.convert_to_tensor(embedding_dict[embedding], dtype=tf.float32)\n",
    "\n",
    "    A_full = A\n",
    "    A_coo = A_full.tocoo()\n",
    "    A_full_tensor = tf.sparse.SparseTensor(\n",
    "        indices = np.column_stack([A_coo.row, A_coo.col]),\n",
    "        values  = A_coo.data.astype(np.float32),\n",
    "        dense_shape = A_coo.shape\n",
    "    )\n",
    "    A_full_tensor = tf.sparse.reorder(A_full_tensor)\n",
    "\n",
    "    preds_all, emb_full = model([X_full, A_full_tensor], training=False)\n",
    "\n",
    "    predicted_labels = tf.argmax(preds_all, axis=1).numpy()\n",
    "\n",
    "    # Evaluate only on masked nodes (test nodes)\n",
    "    predicted_test = predicted_labels[labels_to_be_masked]\n",
    "    true_test = labels[labels_to_be_masked]\n",
    "\n",
    "    results = evaluate_model(true_test, predicted_test)\n",
    "\n",
    "    print(f\"Accuracy: {results['accuracy']*100:.2f}%\")\n",
    "    print(f\"F1 Score: {results['f1_score']:.4f}\")\n",
    "\n",
    "    results[\"embedding\"] = embedding\n",
    "    results[\"model\"] = classifier\n",
    "\n",
    "    return results, emb_full\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a8b95ec7-e23c-4d6e-8157-a2264ef0c7a8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Embedding: DEEPWALK\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=2.0992 | Train Acc=0.1502\n",
      "Epoch 20 | Loss=1.3801 | Train Acc=0.4244\n",
      "Epoch 40 | Loss=1.2385 | Train Acc=0.4885\n",
      "Epoch 60 | Loss=1.1377 | Train Acc=0.5165\n",
      "Epoch 80 | Loss=1.0763 | Train Acc=0.5425\n",
      "Epoch 100 | Loss=1.0372 | Train Acc=0.5616\n",
      "Epoch 120 | Loss=1.0124 | Train Acc=0.5656\n",
      "Epoch 140 | Loss=0.9952 | Train Acc=0.5716\n",
      "Epoch 160 | Loss=0.9818 | Train Acc=0.5806\n",
      "Epoch 180 | Loss=0.9720 | Train Acc=0.5816\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 33  68  20  24  26  19]\n",
      " [ 37 222  63  30  35  30]\n",
      " [ 14  96 282  42  14  27]\n",
      " [ 43 118  72 209  27  26]\n",
      " [ 23  51  20  28 256  33]\n",
      " [  8  70  33  27  26 176]]\n",
      "Accuracy: 50.60%\n",
      "F1 Score: 0.4790\n",
      "\n",
      "Embedding: DEEPWALK\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7874 | Train Acc=0.1652\n",
      "Epoch 20 | Loss=1.3029 | Train Acc=0.4995\n",
      "Epoch 40 | Loss=1.1205 | Train Acc=0.5656\n",
      "Epoch 60 | Loss=1.0549 | Train Acc=0.5706\n",
      "Epoch 80 | Loss=0.9863 | Train Acc=0.5926\n",
      "Epoch 100 | Loss=0.9612 | Train Acc=0.6016\n",
      "Epoch 120 | Loss=0.9136 | Train Acc=0.6086\n",
      "Epoch 140 | Loss=0.9097 | Train Acc=0.6026\n",
      "Epoch 160 | Loss=0.9104 | Train Acc=0.5896\n",
      "Epoch 180 | Loss=0.9015 | Train Acc=0.5976\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 36  69  16  25  29  15]\n",
      " [ 19 276  49  21  27  25]\n",
      " [ 16  73 326  36   9  15]\n",
      " [ 26  85  49 287  27  21]\n",
      " [  5  40   8  14 312  32]\n",
      " [ 17  56  18  16  26 207]]\n",
      "Accuracy: 62.03%\n",
      "F1 Score: 0.5809\n",
      "\n",
      "Embedding: DEEPWALK\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8491 | Train Acc=0.1752\n",
      "Epoch 20 | Loss=1.3255 | Train Acc=0.6917\n",
      "Epoch 40 | Loss=1.2400 | Train Acc=0.7628\n",
      "Epoch 60 | Loss=1.2030 | Train Acc=0.7978\n",
      "Epoch 80 | Loss=1.1749 | Train Acc=0.8318\n",
      "Epoch 100 | Loss=1.1557 | Train Acc=0.8468\n",
      "Epoch 120 | Loss=1.1440 | Train Acc=0.8519\n",
      "Epoch 140 | Loss=1.1343 | Train Acc=0.8669\n",
      "Epoch 160 | Loss=1.1291 | Train Acc=0.8679\n",
      "Epoch 180 | Loss=1.1312 | Train Acc=0.8709\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 38  66  20  15  32  19]\n",
      " [ 22 245  49  25  33  43]\n",
      " [  9  64 318  46  15  23]\n",
      " [ 34  61  53 294  29  24]\n",
      " [ 23  28  13   6 310  31]\n",
      " [ 19  47  22  15  31 206]]\n",
      "Accuracy: 60.61%\n",
      "F1 Score: 0.5669\n",
      "\n",
      "Embedding: NODE2VEC\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=2.3863 | Train Acc=0.1391\n",
      "Epoch 20 | Loss=1.4255 | Train Acc=0.4364\n",
      "Epoch 40 | Loss=1.2765 | Train Acc=0.4765\n",
      "Epoch 60 | Loss=1.1750 | Train Acc=0.5185\n",
      "Epoch 80 | Loss=1.1065 | Train Acc=0.5345\n",
      "Epoch 100 | Loss=1.0588 | Train Acc=0.5425\n",
      "Epoch 120 | Loss=1.0272 | Train Acc=0.5556\n",
      "Epoch 140 | Loss=1.0072 | Train Acc=0.5596\n",
      "Epoch 160 | Loss=0.9934 | Train Acc=0.5686\n",
      "Epoch 180 | Loss=0.9829 | Train Acc=0.5726\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 41  57  26  29  15  22]\n",
      " [ 51 192  72  29  20  53]\n",
      " [ 10  73 280  68  14  30]\n",
      " [ 31 109  48 205  26  76]\n",
      " [ 42  53  33  30 215  38]\n",
      " [ 16  52  36  41  23 172]]\n",
      "Accuracy: 47.47%\n",
      "F1 Score: 0.4528\n",
      "\n",
      "Embedding: NODE2VEC\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7936 | Train Acc=0.1241\n",
      "Epoch 20 | Loss=1.2790 | Train Acc=0.5005\n",
      "Epoch 40 | Loss=1.1565 | Train Acc=0.5335\n",
      "Epoch 60 | Loss=1.0643 | Train Acc=0.5566\n",
      "Epoch 80 | Loss=0.9537 | Train Acc=0.5966\n",
      "Epoch 100 | Loss=0.9337 | Train Acc=0.5966\n",
      "Epoch 120 | Loss=0.9337 | Train Acc=0.5946\n",
      "Epoch 140 | Loss=0.8547 | Train Acc=0.6296\n",
      "Epoch 160 | Loss=0.9031 | Train Acc=0.6086\n",
      "Epoch 180 | Loss=0.8457 | Train Acc=0.6316\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 37  60  14  34  30  15]\n",
      " [ 20 256  47  30  23  41]\n",
      " [  2  64 305  61  20  23]\n",
      " [ 25  78  50 307  21  14]\n",
      " [ 18  50  15  29 268  31]\n",
      " [  5  59  22  29  22 203]]\n",
      "Accuracy: 59.11%\n",
      "F1 Score: 0.5555\n",
      "\n",
      "Embedding: NODE2VEC\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8022 | Train Acc=0.2062\n",
      "Epoch 20 | Loss=1.3280 | Train Acc=0.6737\n",
      "Epoch 40 | Loss=1.2435 | Train Acc=0.7658\n",
      "Epoch 60 | Loss=1.1947 | Train Acc=0.8088\n",
      "Epoch 80 | Loss=1.1672 | Train Acc=0.8348\n",
      "Epoch 100 | Loss=1.1555 | Train Acc=0.8418\n",
      "Epoch 120 | Loss=1.1571 | Train Acc=0.8448\n",
      "Epoch 140 | Loss=1.1352 | Train Acc=0.8649\n",
      "Epoch 160 | Loss=1.1274 | Train Acc=0.8719\n",
      "Epoch 180 | Loss=1.1245 | Train Acc=0.8789\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 29  44  26  31  37  23]\n",
      " [ 20 215  68  32  44  38]\n",
      " [ 12  44 319  57  13  30]\n",
      " [ 26  66  50 289  37  27]\n",
      " [ 17  26  24  31 283  30]\n",
      " [  7  51  24  34  27 197]]\n",
      "Accuracy: 57.22%\n",
      "F1 Score: 0.5290\n",
      "\n",
      "Embedding: FUSE\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7920 | Train Acc=0.1441\n",
      "Epoch 20 | Loss=1.5844 | Train Acc=0.4374\n",
      "Epoch 40 | Loss=1.3787 | Train Acc=0.5455\n",
      "Epoch 60 | Loss=1.2145 | Train Acc=0.5546\n",
      "Epoch 80 | Loss=1.1011 | Train Acc=0.5746\n",
      "Epoch 100 | Loss=1.0293 | Train Acc=0.5866\n",
      "Epoch 120 | Loss=0.9852 | Train Acc=0.5936\n",
      "Epoch 140 | Loss=0.9575 | Train Acc=0.5976\n",
      "Epoch 160 | Loss=0.9394 | Train Acc=0.6006\n",
      "Epoch 180 | Loss=0.9268 | Train Acc=0.6016\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 45  24  23  37  36  25]\n",
      " [ 29 216  65  26  47  34]\n",
      " [  6  30 344  46  27  22]\n",
      " [ 26  26  67 295  59  22]\n",
      " [ 20  16  18  21 316  20]\n",
      " [ 11  17  29  32  37 214]]\n",
      "Accuracy: 61.43%\n",
      "F1 Score: 0.5763\n",
      "\n",
      "Embedding: FUSE\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7917 | Train Acc=0.1431\n",
      "Epoch 20 | Loss=1.2956 | Train Acc=0.6627\n",
      "Epoch 40 | Loss=0.8712 | Train Acc=0.6507\n",
      "Epoch 60 | Loss=0.8495 | Train Acc=0.6286\n",
      "Epoch 80 | Loss=0.7747 | Train Acc=0.6577\n",
      "Epoch 100 | Loss=0.7417 | Train Acc=0.6687\n",
      "Epoch 120 | Loss=0.7576 | Train Acc=0.6637\n",
      "Epoch 140 | Loss=0.7809 | Train Acc=0.6547\n",
      "Epoch 160 | Loss=0.7673 | Train Acc=0.6577\n",
      "Epoch 180 | Loss=0.7823 | Train Acc=0.6507\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 47  28  26  35  32  22]\n",
      " [ 17 221  73  22  49  35]\n",
      " [  8  20 352  50  19  26]\n",
      " [ 21  22  66 313  47  26]\n",
      " [ 20  16  21  10 319  25]\n",
      " [ 14  14  30  37  34 211]]\n",
      "Accuracy: 62.84%\n",
      "F1 Score: 0.5896\n",
      "\n",
      "Embedding: FUSE\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8052 | Train Acc=0.2022\n",
      "Epoch 20 | Loss=1.0390 | Train Acc=0.9890\n",
      "Epoch 40 | Loss=1.0123 | Train Acc=0.9990\n",
      "Epoch 60 | Loss=1.0016 | Train Acc=1.0000\n",
      "Epoch 80 | Loss=0.9954 | Train Acc=1.0000\n",
      "Epoch 100 | Loss=0.9917 | Train Acc=1.0000\n",
      "Epoch 120 | Loss=0.9894 | Train Acc=1.0000\n",
      "Epoch 140 | Loss=0.9877 | Train Acc=1.0000\n",
      "Epoch 160 | Loss=0.9866 | Train Acc=1.0000\n",
      "Epoch 180 | Loss=0.9858 | Train Acc=1.0000\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 47  24  20  35  48  16]\n",
      " [ 39 209  70  27  51  21]\n",
      " [ 12  30 344  44  27  18]\n",
      " [ 27  28  54 308  61  17]\n",
      " [ 20  18  18  14 324  17]\n",
      " [ 11  16  24  37  38 214]]\n",
      "Accuracy: 62.11%\n",
      "F1 Score: 0.5836\n",
      "\n",
      "Embedding: VGAE\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7919 | Train Acc=0.1491\n",
      "Epoch 20 | Loss=1.3750 | Train Acc=0.4484\n",
      "Epoch 40 | Loss=1.2083 | Train Acc=0.4975\n",
      "Epoch 60 | Loss=1.1040 | Train Acc=0.5375\n",
      "Epoch 80 | Loss=1.0420 | Train Acc=0.5506\n",
      "Epoch 100 | Loss=1.0056 | Train Acc=0.5636\n",
      "Epoch 120 | Loss=0.9833 | Train Acc=0.5716\n",
      "Epoch 140 | Loss=0.9679 | Train Acc=0.5776\n",
      "Epoch 160 | Loss=0.9565 | Train Acc=0.5816\n",
      "Epoch 180 | Loss=0.9479 | Train Acc=0.5826\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 26  38  23  51  26  26]\n",
      " [ 43 155  68  71  35  45]\n",
      " [ 20  78 246  95  15  21]\n",
      " [ 33  71  62 236  44  49]\n",
      " [ 32  28  24  63 241  23]\n",
      " [ 16  46  37  52  41 148]]\n",
      "Accuracy: 45.19%\n",
      "F1 Score: 0.4223\n",
      "\n",
      "Embedding: VGAE\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7906 | Train Acc=0.1542\n",
      "Epoch 20 | Loss=1.3371 | Train Acc=0.4805\n",
      "Epoch 40 | Loss=1.2708 | Train Acc=0.5095\n",
      "Epoch 60 | Loss=1.1669 | Train Acc=0.5656\n",
      "Epoch 80 | Loss=1.1496 | Train Acc=0.5385\n",
      "Epoch 100 | Loss=1.0382 | Train Acc=0.5946\n",
      "Epoch 120 | Loss=1.0427 | Train Acc=0.5796\n",
      "Epoch 140 | Loss=0.9332 | Train Acc=0.6146\n",
      "Epoch 160 | Loss=0.8704 | Train Acc=0.6246\n",
      "Epoch 180 | Loss=0.8507 | Train Acc=0.6406\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 37  45  17  26  40  25]\n",
      " [ 26 214  50  46  43  38]\n",
      " [ 15  47 325  47  14  27]\n",
      " [ 23  37  56 290  64  25]\n",
      " [ 12  25   9  20 308  37]\n",
      " [ 15  21  28  34  35 207]]\n",
      "Accuracy: 59.32%\n",
      "F1 Score: 0.5520\n",
      "\n",
      "Embedding: VGAE\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8587 | Train Acc=0.1732\n",
      "Epoch 20 | Loss=1.2975 | Train Acc=0.7297\n",
      "Epoch 40 | Loss=1.2010 | Train Acc=0.8178\n",
      "Epoch 60 | Loss=1.1561 | Train Acc=0.8549\n",
      "Epoch 80 | Loss=1.1233 | Train Acc=0.8849\n",
      "Epoch 100 | Loss=1.1036 | Train Acc=0.9029\n",
      "Epoch 120 | Loss=1.0872 | Train Acc=0.9159\n",
      "Epoch 140 | Loss=1.0779 | Train Acc=0.9179\n",
      "Epoch 160 | Loss=1.0737 | Train Acc=0.9259\n",
      "Epoch 180 | Loss=1.0640 | Train Acc=0.9339\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 27  33  19  36  42  33]\n",
      " [ 14 211  45  42  42  63]\n",
      " [ 12  62 265  57  23  56]\n",
      " [ 10  51  44 285  59  46]\n",
      " [ 10  33  14  22 294  38]\n",
      " [  3  28  22  36  34 217]]\n",
      "Accuracy: 55.80%\n",
      "F1 Score: 0.5151\n",
      "\n",
      "Embedding: DGI\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=4.8407 | Train Acc=0.1441\n",
      "Epoch 20 | Loss=1.8037 | Train Acc=0.1942\n",
      "Epoch 40 | Loss=1.7683 | Train Acc=0.2062\n",
      "Epoch 60 | Loss=1.7590 | Train Acc=0.2062\n",
      "Epoch 80 | Loss=1.7548 | Train Acc=0.2062\n",
      "Epoch 100 | Loss=1.7529 | Train Acc=0.2062\n",
      "Epoch 120 | Loss=1.7522 | Train Acc=0.2062\n",
      "Epoch 140 | Loss=1.7519 | Train Acc=0.2062\n",
      "Epoch 160 | Loss=1.7518 | Train Acc=0.2062\n",
      "Epoch 180 | Loss=1.7518 | Train Acc=0.2062\n",
      "Predicting on FULL GRAPH...\n",
      "[[  0   0   0 190   0   0]\n",
      " [  0   0   0 417   0   0]\n",
      " [  0   0   0 475   0   0]\n",
      " [  0   0   0 495   0   0]\n",
      " [  0   0   0 411   0   0]\n",
      " [  0   0   0 340   0   0]]\n",
      "Accuracy: 21.26%\n",
      "F1 Score: 0.0584\n",
      "\n",
      "Embedding: DGI\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.9439 | Train Acc=0.1351\n",
      "Epoch 20 | Loss=1.7838 | Train Acc=0.2072\n",
      "Epoch 40 | Loss=1.7558 | Train Acc=0.2102\n",
      "Epoch 60 | Loss=1.7294 | Train Acc=0.2492\n",
      "Epoch 80 | Loss=1.7549 | Train Acc=0.2142\n",
      "Epoch 100 | Loss=1.7348 | Train Acc=0.2312\n",
      "Epoch 120 | Loss=1.7331 | Train Acc=0.2252\n",
      "Epoch 140 | Loss=1.7307 | Train Acc=0.2292\n",
      "Epoch 160 | Loss=1.7294 | Train Acc=0.2372\n",
      "Epoch 180 | Loss=1.7198 | Train Acc=0.2232\n",
      "Predicting on FULL GRAPH...\n",
      "[[  0   0  27  26 136   1]\n",
      " [  0   1  77  38 301   0]\n",
      " [  0   0  86  26 362   1]\n",
      " [  0   2  96  45 352   0]\n",
      " [  0   0  44  17 350   0]\n",
      " [  0   0  62  47 226   5]]\n",
      "Accuracy: 20.92%\n",
      "F1 Score: 0.1148\n",
      "\n",
      "Embedding: DGI\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7895 | Train Acc=0.2062\n",
      "Epoch 20 | Loss=1.7506 | Train Acc=0.2082\n",
      "Epoch 40 | Loss=1.7439 | Train Acc=0.2172\n",
      "Epoch 60 | Loss=1.7310 | Train Acc=0.2272\n",
      "Epoch 80 | Loss=1.7384 | Train Acc=0.2322\n",
      "Epoch 100 | Loss=1.7027 | Train Acc=0.2793\n",
      "Epoch 120 | Loss=1.6723 | Train Acc=0.3083\n",
      "Epoch 140 | Loss=1.6450 | Train Acc=0.3524\n",
      "Epoch 160 | Loss=1.6177 | Train Acc=0.3754\n",
      "Epoch 180 | Loss=1.6381 | Train Acc=0.3303\n",
      "Predicting on FULL GRAPH...\n",
      "[[  0  44  49  60  24  13]\n",
      " [  0 115  92 130  50  30]\n",
      " [  0  67 188  92  38  90]\n",
      " [  0  70 103 203  70  49]\n",
      " [  0  47  96 106  97  65]\n",
      " [  0  78  82  75  42  63]]\n",
      "Accuracy: 28.61%\n",
      "F1 Score: 0.2383\n",
      "\n",
      "Embedding: RANDOM\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=2.7719 | Train Acc=0.1291\n",
      "Epoch 20 | Loss=1.3067 | Train Acc=0.4515\n",
      "Epoch 40 | Loss=1.0989 | Train Acc=0.5506\n",
      "Epoch 60 | Loss=0.9813 | Train Acc=0.5866\n",
      "Epoch 80 | Loss=0.9271 | Train Acc=0.6006\n",
      "Epoch 100 | Loss=0.9074 | Train Acc=0.6006\n",
      "Epoch 120 | Loss=0.8990 | Train Acc=0.6026\n",
      "Epoch 140 | Loss=0.8943 | Train Acc=0.6026\n",
      "Epoch 160 | Loss=0.8916 | Train Acc=0.6036\n",
      "Epoch 180 | Loss=0.8897 | Train Acc=0.6036\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 17  37  38  44  27  27]\n",
      " [ 26  99  78  70  84  60]\n",
      " [ 19 143 119  97  38  59]\n",
      " [ 37  84 106 128  74  66]\n",
      " [ 22 111  64  98  70  46]\n",
      " [ 24  63  53  62  72  66]]\n",
      "Accuracy: 21.43%\n",
      "F1 Score: 0.2001\n",
      "\n",
      "Embedding: RANDOM\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8158 | Train Acc=0.1331\n",
      "Epoch 20 | Loss=1.3208 | Train Acc=0.4775\n",
      "Epoch 40 | Loss=1.0686 | Train Acc=0.5696\n",
      "Epoch 60 | Loss=0.9014 | Train Acc=0.6136\n",
      "Epoch 80 | Loss=0.8470 | Train Acc=0.6306\n",
      "Epoch 100 | Loss=0.8719 | Train Acc=0.6116\n",
      "Epoch 120 | Loss=0.8160 | Train Acc=0.6496\n",
      "Epoch 140 | Loss=0.7736 | Train Acc=0.6637\n",
      "Epoch 160 | Loss=0.7781 | Train Acc=0.6597\n",
      "Epoch 180 | Loss=0.7969 | Train Acc=0.6456\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 13  36  28  38  46  29]\n",
      " [ 17 158  63  71  70  38]\n",
      " [ 18  77 186  81  63  50]\n",
      " [ 20  60  79 172  98  66]\n",
      " [ 16  64  57  47 180  47]\n",
      " [ 21  44  71  47  66  91]]\n",
      "Accuracy: 34.36%\n",
      "F1 Score: 0.3113\n",
      "\n",
      "Embedding: RANDOM\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8531 | Train Acc=0.1642\n",
      "Epoch 20 | Loss=1.3101 | Train Acc=0.7508\n",
      "Epoch 40 | Loss=1.1710 | Train Acc=0.8949\n",
      "Epoch 60 | Loss=1.1077 | Train Acc=0.9399\n",
      "Epoch 80 | Loss=1.0730 | Train Acc=0.9610\n",
      "Epoch 100 | Loss=1.0531 | Train Acc=0.9730\n",
      "Epoch 120 | Loss=1.0392 | Train Acc=0.9790\n",
      "Epoch 140 | Loss=1.0329 | Train Acc=0.9830\n",
      "Epoch 160 | Loss=1.0247 | Train Acc=0.9850\n",
      "Epoch 180 | Loss=1.0197 | Train Acc=0.9890\n",
      "Predicting on FULL GRAPH...\n",
      "[[  1  46  30  35  43  35]\n",
      " [  2 119  78  60  72  86]\n",
      " [  1 127 100  68 111  68]\n",
      " [  5 135  78  80  96 101]\n",
      " [  0 112  84  49 106  60]\n",
      " [  0  79  68  33  66  94]]\n",
      "Accuracy: 21.48%\n",
      "F1 Score: 0.1880\n",
      "\n",
      "Embedding: GIVEN\n",
      "Model: GCN\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7928 | Train Acc=0.1381\n",
      "Epoch 20 | Loss=0.9126 | Train Acc=0.6006\n",
      "Epoch 40 | Loss=0.8896 | Train Acc=0.6026\n",
      "Epoch 60 | Loss=0.8870 | Train Acc=0.6026\n",
      "Epoch 80 | Loss=0.8863 | Train Acc=0.6026\n",
      "Epoch 100 | Loss=0.8859 | Train Acc=0.6026\n",
      "Epoch 120 | Loss=0.8858 | Train Acc=0.6026\n",
      "Epoch 140 | Loss=0.8854 | Train Acc=0.6026\n",
      "Epoch 160 | Loss=0.8852 | Train Acc=0.6036\n",
      "Epoch 180 | Loss=0.8851 | Train Acc=0.6036\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 48  58  14  25  24  21]\n",
      " [ 31 253  50  31  20  32]\n",
      " [ 13  35 344  51  11  21]\n",
      " [ 13  20  60 361  31  10]\n",
      " [ 24  17   9  16 329  16]\n",
      " [ 14   9  25  26  36 230]]\n",
      "Accuracy: 67.23%\n",
      "F1 Score: 0.6310\n",
      "\n",
      "Embedding: GIVEN\n",
      "Model: GAT\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.7923 | Train Acc=0.1301\n",
      "Epoch 20 | Loss=0.7533 | Train Acc=0.6647\n",
      "Epoch 40 | Loss=0.7586 | Train Acc=0.6557\n",
      "Epoch 60 | Loss=0.7409 | Train Acc=0.6677\n",
      "Epoch 80 | Loss=0.7295 | Train Acc=0.6657\n",
      "Epoch 100 | Loss=0.7318 | Train Acc=0.6627\n",
      "Epoch 120 | Loss=0.6904 | Train Acc=0.6797\n",
      "Epoch 140 | Loss=0.7537 | Train Acc=0.6486\n",
      "Epoch 160 | Loss=0.7601 | Train Acc=0.6517\n",
      "Epoch 180 | Loss=0.7572 | Train Acc=0.6557\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 36  64  19  26  30  15]\n",
      " [ 16 289  58  14  25  15]\n",
      " [  1  22 398  30  13  11]\n",
      " [  4  17  48 398  20   8]\n",
      " [  8  24   8   6 351  14]\n",
      " [  1  14  28  12  25 260]]\n",
      "Accuracy: 74.40%\n",
      "F1 Score: 0.6886\n",
      "\n",
      "Embedding: GIVEN\n",
      "Model: GRAPHSAGE\n",
      "Training on TRAINING SUBGRAPH ONLY...\n",
      "Epoch 0 | Loss=1.8091 | Train Acc=0.2062\n",
      "Epoch 20 | Loss=1.0352 | Train Acc=0.9690\n",
      "Epoch 40 | Loss=1.0043 | Train Acc=0.9980\n",
      "Epoch 60 | Loss=0.9931 | Train Acc=1.0000\n",
      "Epoch 80 | Loss=0.9886 | Train Acc=1.0000\n",
      "Epoch 100 | Loss=0.9865 | Train Acc=1.0000\n",
      "Epoch 120 | Loss=0.9854 | Train Acc=1.0000\n",
      "Epoch 140 | Loss=0.9847 | Train Acc=1.0000\n",
      "Epoch 160 | Loss=0.9843 | Train Acc=1.0000\n",
      "Epoch 180 | Loss=0.9840 | Train Acc=1.0000\n",
      "Predicting on FULL GRAPH...\n",
      "[[ 55  70  12  22  17  14]\n",
      " [ 21 310  40  14  14  18]\n",
      " [ 12  44 354  34  12  19]\n",
      " [ 30  29  42 362  16  16]\n",
      " [ 36  26   7   6 309  27]\n",
      " [  4  18  16  12  22 268]]\n",
      "Accuracy: 71.22%\n",
      "F1 Score: 0.6740\n"
     ]
    }
   ],
   "source": [
    "all_results=[]\n",
    "graph_embeddings_dict={}\n",
    "for emb in embedding_dict.keys():\n",
    "    for clf in classifiers:\n",
    "        results, embedding_matrix = train_and_evaluate(embedding_dict, emb, clf)\n",
    "        all_results.append(results)\n",
    "        key_string= emb + ' with ' + clf\n",
    "        graph_embeddings_dict[key_string]=embedding_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17527c3b-ef18-4b1a-abb6-48bcea7dc4a3",
   "metadata": {},
   "source": [
    "## Saving aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "fa504238-b2df-4efd-b68c-90532d749d70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results saved as ./citeseer_analysis_results/results/30_70/citeseer_30_70_2025_results.csv\n"
     ]
    }
   ],
   "source": [
    "# Convert to DataFrame\n",
    "df = pd.DataFrame(all_results)\n",
    "\n",
    "# Define dataset name and seed\n",
    "dataset_name = \"citeseer\"\n",
    "seed_value = SEED\n",
    "\n",
    "# Save as CSV file without sorting\n",
    "filename = f\"{dataset_name}_30_70_{SEED}_results.csv\"\n",
    "filename='./citeseer_analysis_results/results/30_70/'+filename\n",
    "df.to_csv(filename, index=False)\n",
    "\n",
    "print(f\"Results saved as {filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e9875500-5fdf-4d56-9a05-5d5d3cbceec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_embeddings= embedding_dict | graph_embeddings_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "602c48bd-1f98-47c9-993a-10ec0ec2966d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reorder_dict(original_dict, key_order):\n",
    "    \"\"\"\n",
    "    Reorders a dictionary based on a given list of keys.\n",
    "\n",
    "    Parameters:\n",
    "    - original_dict (dict): The dictionary to reorder.\n",
    "    - key_order (list): The list specifying the desired key order.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A new dictionary with keys ordered as per key_order.\n",
    "    \"\"\"\n",
    "    return {key: original_dict[key] for key in key_order if key in original_dict}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3a0939b3-7a5a-4149-9130-86f5d6ca5141",
   "metadata": {},
   "outputs": [],
   "source": [
    "key_order = ['random', 'random with gcn', 'random with gat', 'random with graphsage', 'deepwalk', 'deepwalk with gcn', 'deepwalk with gat', 'deepwalk with graphsage', 'node2vec','node2vec with gcn', 'node2vec with gat', 'node2vec with graphsage', 'vgae', 'vgae with gcn', 'vgae with gat', 'vgae with graphsage', 'dgi', 'dgi with gcn', 'dgi with gat', 'dgi with graphsage', 'fuse', 'fuse with gcn', 'fuse with gat', 'fuse with graphsage', 'given', 'given with gcn', 'given with gat', 'given with graphsage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "267d6703-3449-4e03-a5ed-105a5eba4eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_embeddings = reorder_dict(all_embeddings, key_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3a9db961-2936-4f72-b644-8282bf37c445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize_all_embeddings(all_embeddings, labels, label_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "dd3c4b57-1af7-49e2-b092-620517ffaf88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "🔍 Evaluating Raw Embeddings...\n",
      "\n",
      " DEEPWALK | Silhouette=0.0578, DB=3.8248, NMI=0.2004, ARI=0.1017, V=0.2004\n",
      " NODE2VEC | Silhouette=0.0501, DB=4.2505, NMI=0.1999, ARI=0.1083, V=0.1999\n",
      " FUSE | Silhouette=0.3379, DB=7.9979, NMI=0.0623, ARI=0.0297, V=0.0623\n",
      " VGAE | Silhouette=0.0379, DB=4.6632, NMI=0.1304, ARI=0.0533, V=0.1304\n",
      " DGI | Silhouette=0.3655, DB=0.8477, NMI=0.0232, ARI=0.0122, V=0.0232\n",
      " RANDOM | Silhouette=0.0054, DB=9.1165, NMI=0.0017, ARI=-0.0004, V=0.0017\n",
      " GIVEN | Silhouette=0.0053, DB=8.4735, NMI=0.2750, ARI=0.2271, V=0.2750\n",
      "\n",
      "🔍 Evaluating GNN Embeddings...\n",
      "\n",
      " DEEPWALK WITH GCN | Silhouette=0.4845, DB=1.0343, NMI=0.0682, ARI=0.0032, V=0.0682\n",
      " DEEPWALK WITH GAT | Silhouette=0.1558, DB=2.0543, NMI=0.3096, ARI=0.2316, V=0.3096\n",
      " DEEPWALK WITH GRAPHSAGE | Silhouette=0.2403, DB=1.4843, NMI=0.2706, ARI=0.1415, V=0.2706\n",
      " NODE2VEC WITH GCN | Silhouette=0.3917, DB=1.0486, NMI=0.0705, ARI=0.0159, V=0.0705\n",
      " NODE2VEC WITH GAT | Silhouette=0.1196, DB=2.3845, NMI=0.2752, ARI=0.2066, V=0.2752\n",
      " NODE2VEC WITH GRAPHSAGE | Silhouette=0.1871, DB=1.5549, NMI=0.1868, ARI=0.0941, V=0.1868\n",
      " FUSE WITH GCN | Silhouette=0.2207, DB=1.2694, NMI=0.2510, ARI=0.1459, V=0.2510\n",
      " FUSE WITH GAT | Silhouette=0.3302, DB=1.1640, NMI=0.4049, ARI=0.4300, V=0.4049\n",
      " FUSE WITH GRAPHSAGE | Silhouette=0.5295, DB=0.7009, NMI=0.4377, ARI=0.4723, V=0.4377\n",
      " VGAE WITH GCN | Silhouette=0.4262, DB=1.5107, NMI=0.1073, ARI=0.0020, V=0.1073\n",
      " VGAE WITH GAT | Silhouette=0.0705, DB=2.9288, NMI=0.1892, ARI=0.1122, V=0.1892\n",
      " VGAE WITH GRAPHSAGE | Silhouette=0.2009, DB=1.6532, NMI=0.1159, ARI=0.0571, V=0.1159\n",
      " DGI WITH GCN | Silhouette=nan, DB=nan, NMI=0.0000, ARI=0.0000, V=0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\user\\anaconda3\\envs\\ContrastiveFUSE\\lib\\site-packages\\sklearn\\base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (6). Possibly due to duplicate points in X.\n",
      "  return fit_method(estimator, *args, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " DGI WITH GAT | Silhouette=0.2900, DB=1.0324, NMI=0.0361, ARI=0.0199, V=0.0361\n",
      " DGI WITH GRAPHSAGE | Silhouette=0.2471, DB=1.1839, NMI=0.0203, ARI=0.0076, V=0.0203\n",
      " RANDOM WITH GCN | Silhouette=0.1080, DB=2.3962, NMI=0.0117, ARI=0.0012, V=0.0117\n",
      " RANDOM WITH GAT | Silhouette=0.0125, DB=5.2370, NMI=0.0332, ARI=0.0216, V=0.0332\n",
      " RANDOM WITH GRAPHSAGE | Silhouette=0.1032, DB=2.1691, NMI=0.0036, ARI=-0.0001, V=0.0036\n",
      " GIVEN WITH GCN | Silhouette=0.4348, DB=1.0642, NMI=0.0318, ARI=-0.0005, V=0.0318\n",
      " GIVEN WITH GAT | Silhouette=0.2794, DB=1.3230, NMI=0.5240, ARI=0.5416, V=0.5240\n",
      " GIVEN WITH GRAPHSAGE | Silhouette=0.3272, DB=1.2578, NMI=0.5194, ARI=0.5467, V=0.5194\n",
      "\n",
      "📊 Combined clustering metrics saved to: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_2025.csv\n",
      "            Source                Embedding  Silhouette  Davies-Bouldin  \\\n",
      "0   Raw Embeddings                 deepwalk    0.057816        3.824772   \n",
      "1   Raw Embeddings                 node2vec    0.050067        4.250537   \n",
      "2   Raw Embeddings                     fuse    0.337931        7.997896   \n",
      "3   Raw Embeddings                     vgae    0.037882        4.663208   \n",
      "4   Raw Embeddings                      dgi    0.365487        0.847746   \n",
      "5   Raw Embeddings                   random    0.005420        9.116462   \n",
      "6   Raw Embeddings                    given    0.005319        8.473548   \n",
      "7   GNN Embeddings        deepwalk with gcn    0.484547        1.034266   \n",
      "8   GNN Embeddings        deepwalk with gat    0.155783        2.054328   \n",
      "9   GNN Embeddings  deepwalk with graphsage    0.240329        1.484262   \n",
      "10  GNN Embeddings        node2vec with gcn    0.391709        1.048613   \n",
      "11  GNN Embeddings        node2vec with gat    0.119633        2.384549   \n",
      "12  GNN Embeddings  node2vec with graphsage    0.187088        1.554855   \n",
      "13  GNN Embeddings            fuse with gcn    0.220702        1.269431   \n",
      "14  GNN Embeddings            fuse with gat    0.330219        1.164030   \n",
      "15  GNN Embeddings      fuse with graphsage    0.529539        0.700932   \n",
      "16  GNN Embeddings            vgae with gcn    0.426167        1.510731   \n",
      "17  GNN Embeddings            vgae with gat    0.070456        2.928776   \n",
      "18  GNN Embeddings      vgae with graphsage    0.200936        1.653225   \n",
      "19  GNN Embeddings             dgi with gcn         NaN             NaN   \n",
      "20  GNN Embeddings             dgi with gat    0.290029        1.032446   \n",
      "21  GNN Embeddings       dgi with graphsage    0.247136        1.183855   \n",
      "22  GNN Embeddings          random with gcn    0.108034        2.396166   \n",
      "23  GNN Embeddings          random with gat    0.012497        5.237035   \n",
      "24  GNN Embeddings    random with graphsage    0.103250        2.169083   \n",
      "25  GNN Embeddings           given with gcn    0.434821        1.064175   \n",
      "26  GNN Embeddings           given with gat    0.279411        1.322976   \n",
      "27  GNN Embeddings     given with graphsage    0.327156        1.257846   \n",
      "\n",
      "         NMI       ARI  V-Measure  \n",
      "0   0.200424  0.101740   0.200424  \n",
      "1   0.199929  0.108296   0.199929  \n",
      "2   0.062345  0.029653   0.062345  \n",
      "3   0.130396  0.053331   0.130396  \n",
      "4   0.023178  0.012220   0.023178  \n",
      "5   0.001680 -0.000429   0.001680  \n",
      "6   0.275023  0.227123   0.275023  \n",
      "7   0.068197  0.003200   0.068197  \n",
      "8   0.309559  0.231614   0.309559  \n",
      "9   0.270558  0.141473   0.270558  \n",
      "10  0.070515  0.015874   0.070515  \n",
      "11  0.275191  0.206573   0.275191  \n",
      "12  0.186794  0.094097   0.186794  \n",
      "13  0.251002  0.145948   0.251002  \n",
      "14  0.404941  0.430015   0.404941  \n",
      "15  0.437723  0.472259   0.437723  \n",
      "16  0.107291  0.002017   0.107291  \n",
      "17  0.189180  0.112237   0.189180  \n",
      "18  0.115919  0.057136   0.115919  \n",
      "19  0.000000  0.000000   0.000000  \n",
      "20  0.036079  0.019910   0.036079  \n",
      "21  0.020342  0.007608   0.020342  \n",
      "22  0.011672  0.001239   0.011672  \n",
      "23  0.033183  0.021620   0.033183  \n",
      "24  0.003625 -0.000089   0.003625  \n",
      "25  0.031836 -0.000528   0.031836  \n",
      "26  0.523997  0.541621   0.523997  \n",
      "27  0.519361  0.546686   0.519361  \n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import (\n",
    "    silhouette_score,\n",
    "    davies_bouldin_score,\n",
    "    normalized_mutual_info_score,\n",
    "    adjusted_rand_score,\n",
    "    v_measure_score\n",
    ")\n",
    "import pandas as pd\n",
    "# from benchmarking_utils import load_dataset  # or your label loader\n",
    "\n",
    "# Load true labels\n",
    "# ds = load_dataset(\"cora\")\n",
    "# labels = ds[\"labels\"]\n",
    "ds=dataset\n",
    "# Choose which embedding sources to evaluate\n",
    "sources = {\n",
    "    \"Raw Embeddings\": embedding_dict,\n",
    "    \"GNN Embeddings\": graph_embeddings_dict  # << the key fix\n",
    "}\n",
    "\n",
    "metrics_results = []\n",
    "\n",
    "for source_name, emb_source in sources.items():\n",
    "    print(f\"\\n Evaluating {source_name}...\\n\")\n",
    "\n",
    "    for name, emb_tensor in emb_source.items():\n",
    "        emb = emb_tensor.numpy() if isinstance(emb_tensor, tf.Tensor) else np.array(emb_tensor)\n",
    "\n",
    "        n_clusters = len(np.unique(labels))\n",
    "        kmeans = KMeans(n_clusters=n_clusters, random_state=SEED, n_init=10)\n",
    "        cluster_labels = kmeans.fit_predict(emb)\n",
    "\n",
    "        try:\n",
    "            silhouette = silhouette_score(emb, cluster_labels)\n",
    "        except Exception:\n",
    "            silhouette = np.nan\n",
    "        try:\n",
    "            db_index = davies_bouldin_score(emb, cluster_labels)\n",
    "        except Exception:\n",
    "            db_index = np.nan\n",
    "\n",
    "        nmi = normalized_mutual_info_score(labels, cluster_labels)\n",
    "        ari = adjusted_rand_score(labels, cluster_labels)\n",
    "        v_measure = v_measure_score(labels, cluster_labels)\n",
    "\n",
    "        metrics_results.append({\n",
    "            \"Source\": source_name,\n",
    "            \"Embedding\": name,\n",
    "            \"Silhouette\": silhouette,\n",
    "            \"Davies-Bouldin\": db_index,\n",
    "            \"NMI\": nmi,\n",
    "            \"ARI\": ari,\n",
    "            \"V-Measure\": v_measure\n",
    "        })\n",
    "\n",
    "        print(f\" {name.upper()} | Silhouette={silhouette:.4f}, DB={db_index:.4f}, NMI={nmi:.4f}, ARI={ari:.4f}, V={v_measure:.4f}\")\n",
    "\n",
    "# Save results\n",
    "metrics_df = pd.DataFrame(metrics_results)\n",
    "metrics_path = os.path.join(save_dir, f\"clustering_metrics_combined_{dataset_name}_{split_folder.replace('-', '_')}_{SEED}.csv\")\n",
    "metrics_df.to_csv(metrics_path, index=False)\n",
    "print(f\"\\n Combined clustering metrics saved to: {metrics_path}\")\n",
    "print(metrics_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b064e5fa-8b3e-4423-8773-7d63112ffd9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Loaded: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_42.csv\n",
      "✅ Loaded: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_46.csv\n",
      "✅ Loaded: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_123.csv\n",
      "✅ Loaded: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_999.csv\n",
      "✅ Loaded: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_2025.csv\n",
      "\n",
      "📊 Averaged metrics saved to: ./citeseer_analysis_results/embeddings/30_70/clustering_metrics_combined_citeseer_30_70_averaged.csv\n",
      "\n",
      "            Source                Embedding  Silhouette_Mean  Silhouette_STD  \\\n",
      "0   GNN Embeddings        deepwalk with gat         0.138693        0.011631   \n",
      "1   GNN Embeddings        deepwalk with gcn         0.482555        0.008660   \n",
      "2   GNN Embeddings  deepwalk with graphsage         0.236790        0.021600   \n",
      "3   GNN Embeddings             dgi with gat         0.300266        0.168681   \n",
      "4   GNN Embeddings             dgi with gcn         0.766260        0.404850   \n",
      "5   GNN Embeddings       dgi with graphsage         0.304702        0.150384   \n",
      "6   GNN Embeddings            fuse with gat         0.348933        0.014372   \n",
      "7   GNN Embeddings            fuse with gcn         0.234606        0.015360   \n",
      "8   GNN Embeddings      fuse with graphsage         0.548453        0.014707   \n",
      "9   GNN Embeddings           given with gat         0.281221        0.012400   \n",
      "10  GNN Embeddings           given with gcn         0.407967        0.067659   \n",
      "11  GNN Embeddings     given with graphsage         0.331445        0.013857   \n",
      "12  GNN Embeddings        node2vec with gat         0.117680        0.008494   \n",
      "13  GNN Embeddings        node2vec with gcn         0.443299        0.043646   \n",
      "14  GNN Embeddings  node2vec with graphsage         0.186993        0.019345   \n",
      "15  GNN Embeddings          random with gat         0.015676        0.002716   \n",
      "16  GNN Embeddings          random with gcn         0.119114        0.017867   \n",
      "17  GNN Embeddings    random with graphsage         0.099110        0.008441   \n",
      "18  GNN Embeddings            vgae with gat         0.081840        0.010954   \n",
      "19  GNN Embeddings            vgae with gcn         0.422298        0.024560   \n",
      "20  GNN Embeddings      vgae with graphsage         0.195917        0.039850   \n",
      "21  Raw Embeddings                 deepwalk         0.057193        0.001319   \n",
      "22  Raw Embeddings                      dgi         0.271157        0.127350   \n",
      "23  Raw Embeddings                     fuse         0.233121        0.233456   \n",
      "24  Raw Embeddings                    given         0.004653        0.002026   \n",
      "25  Raw Embeddings                 node2vec         0.048109        0.001907   \n",
      "26  Raw Embeddings                   random         0.005212        0.000168   \n",
      "27  Raw Embeddings                     vgae         0.039155        0.001530   \n",
      "\n",
      "     DB_Mean    DB_STD  NMI_Mean   NMI_STD  ARI_Mean   ARI_STD  VMeasure_Mean  \\\n",
      "0   2.218561  0.116478  0.296034  0.014810  0.213411  0.023162       0.296034   \n",
      "1   1.037939  0.081823  0.068185  0.012742  0.004059  0.001233       0.068185   \n",
      "2   1.555659  0.064538  0.231934  0.038736  0.137707  0.026515       0.231934   \n",
      "3   1.278573  0.964225  0.044896  0.008248  0.028384  0.010747       0.044896   \n",
      "4   0.343756  0.595403  0.006322  0.009794  0.000446  0.000469       0.006322   \n",
      "5   1.163975  0.580470  0.026437  0.008097  0.010754  0.009427       0.026437   \n",
      "6   1.130565  0.059418  0.415821  0.010862  0.445377  0.012611       0.415821   \n",
      "7   1.163935  0.096813  0.242183  0.012066  0.127996  0.016153       0.242183   \n",
      "8   0.667308  0.028278  0.441775  0.011364  0.474610  0.010246       0.441775   \n",
      "9   1.382086  0.061525  0.522175  0.011812  0.536210  0.016502       0.522175   \n",
      "10  1.141208  0.078947  0.042958  0.016973  0.000920  0.001838       0.042958   \n",
      "11  1.366499  0.067774  0.511873  0.012335  0.531870  0.013314       0.511873   \n",
      "12  2.412752  0.044101  0.267712  0.011819  0.184941  0.017892       0.267712   \n",
      "13  1.093033  0.146395  0.073625  0.014221  0.006686  0.006284       0.073625   \n",
      "14  1.662332  0.114901  0.198242  0.037152  0.108802  0.020811       0.198242   \n",
      "15  4.962667  0.286306  0.040973  0.010110  0.025385  0.006817       0.040973   \n",
      "16  2.247266  0.108564  0.010598  0.002990 -0.000131  0.001495       0.010598   \n",
      "17  2.251765  0.065385  0.004934  0.000854  0.000950  0.000948       0.004934   \n",
      "18  2.634882  0.192949  0.173315  0.011056  0.090713  0.015531       0.173315   \n",
      "19  1.460641  0.059155  0.099525  0.017101  0.004304  0.003666       0.099525   \n",
      "20  1.760782  0.165800  0.132160  0.018321  0.063316  0.012695       0.132160   \n",
      "21  3.706866  0.458895  0.217242  0.019544  0.111275  0.014287       0.217242   \n",
      "22  1.291024  0.793071  0.025312  0.004564  0.016482  0.005021       0.025312   \n",
      "23  6.725252  0.984131  0.150317  0.126038  0.106428  0.117715       0.150317   \n",
      "24  8.539071  0.133389  0.219911  0.050314  0.176782  0.041195       0.219911   \n",
      "25  4.220846  0.124992  0.192686  0.008432  0.094454  0.012350       0.192686   \n",
      "26  9.139941  0.042841  0.002511  0.000901  0.000202  0.000677       0.002511   \n",
      "27  4.610268  0.090871  0.123349  0.013170  0.048768  0.008193       0.123349   \n",
      "\n",
      "    VMeasure_STD  \n",
      "0       0.014810  \n",
      "1       0.012742  \n",
      "2       0.038736  \n",
      "3       0.008248  \n",
      "4       0.009794  \n",
      "5       0.008097  \n",
      "6       0.010862  \n",
      "7       0.012066  \n",
      "8       0.011364  \n",
      "9       0.011812  \n",
      "10      0.016973  \n",
      "11      0.012335  \n",
      "12      0.011819  \n",
      "13      0.014221  \n",
      "14      0.037152  \n",
      "15      0.010110  \n",
      "16      0.002990  \n",
      "17      0.000854  \n",
      "18      0.011056  \n",
      "19      0.017101  \n",
      "20      0.018321  \n",
      "21      0.019544  \n",
      "22      0.004564  \n",
      "23      0.126038  \n",
      "24      0.050314  \n",
      "25      0.008432  \n",
      "26      0.000901  \n",
      "27      0.013170  \n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# ----------------------------------------\n",
    "# Configuration\n",
    "# ----------------------------------------\n",
    "dataset_name = \"citeseer\"\n",
    "split_folder = \"30_70\"\n",
    "seeds = [42, 46, 123, 999, 2025]\n",
    "base_dir = f\"./citeseer_analysis_results/embeddings/{split_folder}/\"\n",
    "\n",
    "# Output file\n",
    "output_path = os.path.join(base_dir, f\"clustering_metrics_combined_{dataset_name}_{split_folder}_averaged.csv\")\n",
    "\n",
    "# ----------------------------------------\n",
    "# Load and combine all CSVs\n",
    "# ----------------------------------------\n",
    "dfs = []\n",
    "for seed in seeds:\n",
    "    filename = f\"clustering_metrics_combined_{dataset_name}_{split_folder}_{seed}.csv\"\n",
    "    filepath = os.path.join(base_dir, filename)\n",
    "    \n",
    "    if os.path.exists(filepath):\n",
    "        df = pd.read_csv(filepath)\n",
    "        df[\"Seed\"] = seed\n",
    "        dfs.append(df)\n",
    "        print(f\" Loaded: {filepath}\")\n",
    "    else:\n",
    "        print(f\" Missing file for seed {seed}: {filepath}\")\n",
    "\n",
    "if not dfs:\n",
    "    raise FileNotFoundError(\"No CSV files found — check your directory and filenames.\")\n",
    "\n",
    "# Merge all seed data\n",
    "combined_df = pd.concat(dfs, ignore_index=True)\n",
    "\n",
    "# ----------------------------------------\n",
    "# Average metrics across seeds\n",
    "# ----------------------------------------\n",
    "agg_df = (\n",
    "    combined_df\n",
    "    .groupby([\"Source\", \"Embedding\"])\n",
    "    .agg({\n",
    "        \"Silhouette\": [\"mean\", \"std\"],\n",
    "        \"Davies-Bouldin\": [\"mean\", \"std\"],\n",
    "        \"NMI\": [\"mean\", \"std\"],\n",
    "        \"ARI\": [\"mean\", \"std\"],\n",
    "        \"V-Measure\": [\"mean\", \"std\"]\n",
    "    })\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# Flatten column names\n",
    "agg_df.columns = [\n",
    "    \"Source\", \"Embedding\",\n",
    "    \"Silhouette_Mean\", \"Silhouette_STD\",\n",
    "    \"DB_Mean\", \"DB_STD\",\n",
    "    \"NMI_Mean\", \"NMI_STD\",\n",
    "    \"ARI_Mean\", \"ARI_STD\",\n",
    "    \"VMeasure_Mean\", \"VMeasure_STD\"\n",
    "]\n",
    "\n",
    "# Save the averaged results\n",
    "agg_df.to_csv(output_path, index=False)\n",
    "print(f\"\\n Averaged metrics saved to: {output_path}\\n\")\n",
    "\n",
    "# Display summary\n",
    "print(agg_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b553ebf-f9b9-4616-91a0-ba65d188e749",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
