{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c1ef1f-2372-44c0-9703-b8ac660cce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render inline matplotlib figures as SVG (vector graphics) rather than PNG.\n",
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4441684d-fa5d-4f8d-9967-84753fd5a4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time \n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import umap.umap_ as umap\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
    "from node2vec import Node2Vec\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import spektral\n",
    "from spektral.layers import GCNConv, GATConv\n",
    "from spektral.layers import GraphSageConv\n",
    "from spektral.data import Graph, Dataset, BatchLoader\n",
    "from scipy.sparse import csr_matrix\n",
    "from spektral.datasets import Cora\n",
    "from torch_geometric.nn import DeepGraphInfomax, VGAE\n",
    "from torch_geometric.utils import from_networkx\n",
    "import scipy.sparse as sp\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
    "from scipy.sparse.csgraph import laplacian\n",
    "from scipy.sparse.linalg import eigsh\n",
    "from collections import Counter\n",
    "from sklearn.preprocessing import normalize\n",
    "from joblib import Parallel, delayed\n",
    "from torch_geometric.nn import GCNConv as PyG_GCNConv, VGAE as PyG_VGAE\n",
    "from torch_geometric.data import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e18457-6dbb-4928-adad-4b279400977c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "\n",
    "\n",
    "def _seed_everything(seed):\n",
    "    \"\"\"Seed every RNG this notebook relies on: Python, NumPy, TensorFlow and PyTorch.\n",
    "\n",
    "    CUDA generators (current device and all devices) are seeded only when a GPU is visible.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    tf.random.set_seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if not torch.cuda.is_available():\n",
    "        return\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "\n",
    "_seed_everything(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b95802d-543f-4b3e-803a-cdbdb518595b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "from scipy.sparse import lil_matrix\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class PubMedDataset(Dataset):\n",
    "    \"\"\"Spektral Dataset exposing the PyG Planetoid PubMed graph as a single `Graph`.\n",
    "\n",
    "    The graph carries node features `x`, a symmetric binary float32 adjacency `a`\n",
    "    (SciPy LIL matrix) and one-hot labels `y`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        \"\"\"Load PubMed through Planetoid and convert it to spektral's format.\n",
    "\n",
    "        Returns:\n",
    "            list[Graph]: a one-element list holding the whole PubMed graph.\n",
    "        \"\"\"\n",
    "        dataset = Planetoid(root=\".\", name=\"PubMed\")  # Load PubMed dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "\n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Build the symmetric adjacency in one vectorized step instead of an\n",
    "        # element-by-element Python loop over every edge into a LIL matrix.\n",
    "        num_nodes = x.shape[0]\n",
    "        src, dst = edge_index[0], edge_index[1]\n",
    "        rows = np.concatenate([src, dst])  # both directions -> undirected graph\n",
    "        cols = np.concatenate([dst, src])\n",
    "        adj = sp.coo_matrix(\n",
    "            (np.ones(rows.shape[0], dtype=np.float32), (rows, cols)),\n",
    "            shape=(num_nodes, num_nodes),\n",
    "        ).tocsr()\n",
    "        # tocsr() sums duplicate (i, j) entries; reset them to 1 so `a` stays binary.\n",
    "        adj.data[:] = 1.0\n",
    "        adj = adj.tolil()  # `a` is returned as a LIL matrix\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c18215-d57c-47a9-957d-5f98fa330e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): not referenced anywhere else in this notebook; the embeddings are\n",
    "# loaded pre-computed from disk below, so this value has no effect here.\n",
    "embedding_dimensionality=150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1fa1ef-526c-4ecb-a4e5-6c344f0412c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_networkx(A):\n",
    "    \"\"\"Wrap a SciPy sparse adjacency matrix as a NetworkX graph.\n",
    "\n",
    "    Thin adapter over ``nx.from_scipy_sparse_array`` with its default options.\n",
    "    \"\"\"\n",
    "    graph = nx.from_scipy_sparse_array(A)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fbba5b-3f7c-4821-9af5-d9d6e982b9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = PubMedDataset()\n",
    "# One-hot label matrix (n_nodes, n_classes) and the integer class id of each node.\n",
    "ground_truth_labels = dataset[0].y\n",
    "labels = ground_truth_labels.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92de7a7-bab8-4ee8-9601-e4f95f0386e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the nodes whose labels are hidden during training and used for evaluation.\n",
    "# The default path is machine-specific; set PUBMED_MASK_FILE to load the mask from elsewhere.\n",
    "mask_file = os.environ.get(\n",
    "    \"PUBMED_MASK_FILE\",\n",
    "    \"C:/Users/user/Benchmarking_fuse/masks/Pubmed/70_30/Pubmed_70_30_masked_indices_seed42.npy\",\n",
    ")\n",
    "labels_to_be_masked = np.load(mask_file)\n",
    "\n",
    "# Fail early if the mask does not fit this graph (indices must lie in [0, n_nodes)).\n",
    "if labels_to_be_masked.size and (\n",
    "    labels_to_be_masked.min() < 0 or labels_to_be_masked.max() >= len(labels)\n",
    "):\n",
    "    raise ValueError(\n",
    "        f\"Mask file {mask_file} contains node indices outside [0, {len(labels)})\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5ef760-dd97-4ac2-8b4a-7109d91d08b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of held-out (masked) nodes.\n",
    "len(labels_to_be_masked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0cf50e3-56e5-4187-b4ef-9f05cd3bc76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy of `labels` where every held-out node is replaced by -1 (\"label unknown\").\n",
    "# np.isin gives the same membership test as `i in labels_to_be_masked`, but\n",
    "# vectorized instead of one linear scan of the mask array per node.\n",
    "is_masked = np.isin(np.arange(len(labels)), labels_to_be_masked)\n",
    "masked_labels = np.where(is_masked, -1, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f436be3-ba62-406f-9523-09713d59829b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True where the node's label is visible during training.\n",
    "label_mask = np.not_equal(masked_labels, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "517544ea-62e6-495c-bce9-6733a551c959",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node features, sparse adjacency matrix and the equivalent NetworkX graph.\n",
    "X, A = dataset[0].x, dataset[0].a\n",
    "G = convert_to_networkx(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17de4d41-86a4-40c3-bf4d-217511591a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: adjacency size vs. the node/edge counts of the NetworkX graph.\n",
    "print(\"Adjacency Matrix Shape:\", A.shape)\n",
    "print(\"Graph Nodes:\", G.number_of_nodes())\n",
    "print(\"Graph Edges:\", G.number_of_edges())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0346f826-0826-4b89-8328-054b3295fdc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch Geometric view of the same graph (features, edges, integer labels).\n",
    "# A is built symmetric, so A.nonzero() already lists both directions of every edge.\n",
    "# NOTE(review): X_py is not used by any later cell in this notebook.\n",
    "X_py = Data(\n",
    "    x=torch.tensor(X, dtype=torch.float),  # Node features\n",
    "    edge_index=torch.tensor(np.array(A.nonzero()), dtype=torch.long),  # Edge indices\n",
    "    y=torch.tensor(labels, dtype=torch.long)  # Labels\n",
    ")\n",
    "\n",
    "# Ensure edge_index has the layout PyG expects: shape (2, num_edges).\n",
    "assert X_py.edge_index.dim() == 2 and X_py.edge_index.size(0) == 2, X_py.edge_index.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2acc40a-d4c0-48e1-8fc6-da99cccc02e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "81b4f8f4-20eb-4804-a2e7-b04c6e80d08d",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cada09-7de5-438e-b980-7d5e703615d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# -------------------------------------\n",
    "# Configuration (SEED is set once in the seeding cell at the top)\n",
    "# -------------------------------------\n",
    "dataset_name = \"pubmed\"\n",
    "split_folder = \"70-30\"\n",
    "split_tag = split_folder.replace('-', '_')  # \"70_30\": the form used in file names\n",
    "\n",
    "# Input embeddings directory (machine-specific)\n",
    "load_dir = f\"C:/Users/user/Benchmarking_fuse/benchmark_outputs/{dataset_name}/{split_folder}/\"\n",
    "\n",
    "# Output save directory\n",
    "save_dir = f\"./pubmed_analysis_results/embeddings/{split_tag}/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Embedding filenames to load: \"<method>_embedding_<split>_<seed>.pkl\"\n",
    "embedding_methods = [\"deepwalk\", \"node2vec\", \"fuse\", \"vgae\", \"dgi\", \"random\", \"given\"]\n",
    "embedding_files = {\n",
    "    name: f\"{name}_embedding_{split_tag}_{SEED}.pkl\" for name in embedding_methods\n",
    "}\n",
    "\n",
    "# Dictionary to store embeddings\n",
    "embedding_dict = {}\n",
    "\n",
    "# -------------------------------------\n",
    "# Utility: convert numpy to tf.Tensor\n",
    "# -------------------------------------\n",
    "def to_tf_tensor(x):\n",
    "    \"\"\"Convert numpy array to a float32 TensorFlow tensor (tf.Tensors pass through unchanged).\"\"\"\n",
    "    if isinstance(x, tf.Tensor):\n",
    "        return x\n",
    "    return tf.convert_to_tensor(x, dtype=tf.float32)\n",
    "\n",
    "# -------------------------------------\n",
    "# Load, validate and re-save embeddings\n",
    "# -------------------------------------\n",
    "for name, filename in embedding_files.items():\n",
    "    filepath = os.path.join(load_dir, filename)\n",
    "    if not os.path.exists(filepath):\n",
    "        print(f\" Warning: {name} embedding file not found at {filepath}. Skipping.\")\n",
    "        continue\n",
    "\n",
    "    print(f\" Loading {name.upper()} embedding from {filepath}...\")\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.\n",
    "    with open(filepath, \"rb\") as f:\n",
    "        emb = pickle.load(f)\n",
    "\n",
    "    emb = np.array(emb, dtype=float)\n",
    "    # One row per node is required, otherwise masks and labels silently misalign later.\n",
    "    if emb.ndim != 2 or emb.shape[0] != len(labels):\n",
    "        raise ValueError(\n",
    "            f\"{name} embedding has shape {emb.shape}; expected ({len(labels)}, d)\"\n",
    "        )\n",
    "\n",
    "    # Convert to TensorFlow tensor\n",
    "    embedding_dict[name] = to_tf_tensor(emb)\n",
    "\n",
    "    # Save again to new organized directory\n",
    "    save_path = os.path.join(save_dir, f\"{name}_embedding_{split_tag}_{SEED}.pkl\")\n",
    "    with open(save_path, \"wb\") as f:\n",
    "        pickle.dump(embedding_dict[name].numpy(), f)\n",
    "    print(f\" Saved {name.upper()} embedding to {save_path}\")\n",
    "\n",
    "print(\"\\n All embeddings loaded into memory and re-saved successfully.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a98e7ea-fd83-440b-a479-bed7c46185a4",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7297774-e5c4-4b87-9cc9-cab07deb7978",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_all_embeddings(all_embeddings, labels, label_mask, name_map=None,\n",
    "                             save_path=\"./pubmed_analysis_results/embedding_grid_plot_pubmed_70_30.png\",\n",
    "                             random_state=None):\n",
    "    \"\"\"\n",
    "    Visualize all embeddings in a grid with 4 columns per row using UMAP.\n",
    "\n",
    "    Known-label nodes are drawn as small dots; unknown (held-out) nodes are drawn\n",
    "    with thin black outlines. Both are coloured by their true label.\n",
    "\n",
    "    Parameters:\n",
    "    - all_embeddings: Dictionary where keys are embedding methods, and values are\n",
    "      embeddings (NumPy arrays or tf.Tensors, one row per node).\n",
    "    - labels: Labels (numpy array of shape [n_nodes]).\n",
    "    - label_mask: Boolean array indicating known labels (True for known, False for unknown).\n",
    "    - name_map: Optional {key: subplot title}. Defaults to the notebook-level\n",
    "      `pretty_map` if defined; keys without an entry are used as-is.\n",
    "    - save_path: Output image path; missing parent folders are created.\n",
    "    - random_state: Forwarded to umap.UMAP; pass an int for reproducible layouts.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if all_embeddings is empty.\n",
    "    \"\"\"\n",
    "    if not all_embeddings:\n",
    "        raise ValueError(\"all_embeddings is empty; nothing to visualize\")\n",
    "    if name_map is None:\n",
    "        name_map = globals().get(\"pretty_map\", {})\n",
    "\n",
    "    labels = np.asarray(labels)\n",
    "    label_mask = np.asarray(label_mask, dtype=bool)\n",
    "    # One shared colour scale for both scatters, so a class keeps the same colour\n",
    "    # even when the known or unknown subset lacks some classes.\n",
    "    vmin, vmax = labels.min(), labels.max()\n",
    "\n",
    "    n_cols = 4\n",
    "    num_embeddings = len(all_embeddings)\n",
    "    num_rows = (num_embeddings + n_cols - 1) // n_cols  # Ensure enough rows for all embeddings\n",
    "    # squeeze=False keeps `axes` 2-D even when there is a single row.\n",
    "    fig, axes = plt.subplots(num_rows, n_cols, figsize=(8.27, 11.69), squeeze=False)  # A4 size\n",
    "\n",
    "    for i, (embedding_type, embedding) in tqdm(enumerate(all_embeddings.items()),\n",
    "                                               total=num_embeddings, desc=\"Visualizing embeddings\"):\n",
    "        row, col = divmod(i, n_cols)\n",
    "        ax = axes[row, col]\n",
    "\n",
    "        # Ensure embedding is a NumPy array\n",
    "        if isinstance(embedding, tf.Tensor):\n",
    "            embedding = embedding.numpy()\n",
    "\n",
    "        # Reduce dimensionality using UMAP\n",
    "        reducer = umap.UMAP(n_components=2, random_state=random_state)\n",
    "        embedding_2d = reducer.fit_transform(embedding)\n",
    "\n",
    "        # Known labels\n",
    "        ax.scatter(embedding_2d[label_mask, 0], embedding_2d[label_mask, 1],\n",
    "                   c=labels[label_mask], cmap=\"Set1\", vmin=vmin, vmax=vmax, s=3, alpha=0.7,\n",
    "                   label=\"Known Labels\", edgecolors='none')\n",
    "\n",
    "        # Unknown labels\n",
    "        ax.scatter(embedding_2d[~label_mask, 0], embedding_2d[~label_mask, 1],\n",
    "                   c=labels[~label_mask], cmap=\"Set1\", vmin=vmin, vmax=vmax, s=5, alpha=0.7,\n",
    "                   label=\"Unknown Labels\", edgecolors='black', linewidths=0.2)\n",
    "\n",
    "        # Title with smaller font size\n",
    "        ax.set_title(name_map.get(embedding_type, embedding_type), fontsize=8, pad=2)\n",
    "\n",
    "        # Remove axis labels, ticks, and frames\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_frame_on(False)\n",
    "\n",
    "    # Remove empty subplots if num_embeddings is not a multiple of n_cols\n",
    "    for j in range(num_embeddings, num_rows * n_cols):\n",
    "        row, col = divmod(j, n_cols)\n",
    "        fig.delaxes(axes[row, col])\n",
    "\n",
    "    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.2, hspace=0.2)  # Adjust margins\n",
    "    out_dir = os.path.dirname(save_path)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)  # savefig does not create folders\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Visualization saved to {save_path}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73add6f8-ad97-49f5-8f66-88df3b717390",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(true_labels, predicted_labels, return_confusion_matrix=False):\n",
    "    \"\"\"\n",
    "    Evaluate the model's performance using accuracy and macro-averaged F1-score.\n",
    "\n",
    "    The confusion matrix is always printed. It is only returned when\n",
    "    return_confusion_matrix=True, so by default the dict holds scalars only and\n",
    "    stays flat when collected into a DataFrame/CSV.\n",
    "\n",
    "    Args:\n",
    "        true_labels (np.array): Ground truth labels (integers).\n",
    "        predicted_labels (np.array): Predicted labels (integers).\n",
    "        return_confusion_matrix (bool): Also include the confusion matrix in the result.\n",
    "\n",
    "    Returns:\n",
    "        dict: 'accuracy' and 'f1_score', plus 'confusion_matrix' when requested.\n",
    "    \"\"\"\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    f1 = f1_score(true_labels, predicted_labels, average='macro')\n",
    "\n",
    "    cm = confusion_matrix(true_labels, predicted_labels)\n",
    "    print(cm)\n",
    "\n",
    "    results = {\n",
    "        'accuracy': accuracy,\n",
    "        'f1_score': f1\n",
    "    }\n",
    "    if return_confusion_matrix:\n",
    "        results['confusion_matrix'] = cm\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc2b387-1e6d-49a3-b28f-23f59c1c789c",
   "metadata": {},
   "source": [
    "## Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f934d79-76c2-4d0e-81f4-486a1dc34d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoMaskGCNConv(GCNConv):\n",
    "    \"\"\"spektral GCNConv that never emits a Keras mask and ignores any incoming one.\"\"\"\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        # No output mask is propagated to downstream layers.\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard mask\n",
    "        return super().call(inputs, mask=None)\n",
    "        \n",
    "class GCN(tf.keras.Model):\n",
    "    \"\"\"Two-layer GCN node classifier.\n",
    "\n",
    "    conv1: 16 ReLU units; its output is also returned as the node embedding.\n",
    "    conv2: softmax over `n_labels` classes.\n",
    "    Both layers are given the same GlorotUniform initializer object seeded with `seed`.\n",
    "\n",
    "    NOTE(review): spektral documents GCNConv's adjacency input as a normalized\n",
    "    matrix (GCNConv.preprocess); this notebook's training code passes the raw\n",
    "    binary adjacency. Confirm that is intended.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "        self.conv1 = NoMaskGCNConv(16, activation='relu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGCNConv(n_labels, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        \"\"\"Return (class probabilities, conv1 node embeddings) for inputs = [x, a].\n",
    "\n",
    "        `training` is not used in this method.\n",
    "        \"\"\"\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c02975-3867-49ee-bfd3-f74cd6d8dbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spektral.layers import GATConv\n",
    "import tensorflow as tf\n",
    "\n",
    "# Define a custom wrapper for GATConv that avoids mask issues\n",
    "class NoMaskGATConv(GATConv):\n",
    "    \"\"\"spektral GATConv that never emits a Keras mask and ignores any incoming one.\"\"\"\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        # No output mask is propagated to downstream layers.\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard the mask argument\n",
    "        return super().call(inputs, mask=None)\n",
    "\n",
    "\n",
    "# Define the GAT model using the NoMaskGATConv\n",
    "class GAT(tf.keras.Model):\n",
    "    \"\"\"Two-layer graph attention node classifier.\n",
    "\n",
    "    conv1: `num_heads` attention heads with 16 channels each, heads concatenated,\n",
    "    ELU activation; its output is also returned as the node embedding.\n",
    "    conv2: a single head, softmax over `n_labels` classes.\n",
    "    Both layers are given the same GlorotUniform initializer object seeded with `seed`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, num_heads=8, seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # Use the custom NoMaskGATConv instead of the original GATConv\n",
    "        self.conv1 = NoMaskGATConv(16, attn_heads=num_heads, concat_heads=True, activation='elu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGATConv(n_labels, attn_heads=1, concat_heads=False, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # inputs = [node features, adjacency]; returns (class probabilities, conv1 output)\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1805e34a-bed9-410b-a719-49a9df09c377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the GraphSAGE model\n",
    "class GraphSAGE(tf.keras.Model):\n",
    "    \"\"\"Two-layer GraphSAGE node classifier.\n",
    "\n",
    "    conv1: `hidden_dim` ReLU units; its output is also returned as the node embedding.\n",
    "    conv2: softmax over `n_labels` classes. Both use the `aggregator` neighbourhood\n",
    "    aggregation and share one GlorotUniform initializer object seeded with `seed`.\n",
    "    Unlike GCN/GAT in this notebook, spektral's GraphSageConv is used directly\n",
    "    (no mask-discarding wrapper).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=40):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        self.conv1 = GraphSageConv(hidden_dim, activation='relu', aggregator=aggregator, kernel_initializer=initializer)\n",
    "        self.conv2 = GraphSageConv(n_labels, activation='softmax', aggregator=aggregator, kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # inputs = [node features, adjacency]; returns (class probabilities, conv1 output)\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7378d826-fa06-4a7d-8891-f46888320d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GNN classifiers trained on top of every embedding (names accepted by train_and_evaluate).\n",
    "classifiers=['gcn','gat','graphsage']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad91add-83d2-4b32-a97b-a7b48487c7fe",
   "metadata": {},
   "source": [
    "## Classification using different node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a2d104c-8a7a-42dc-8892-c479585380c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_sparse_tensor(adj):\n",
    "    \"\"\"Convert a SciPy sparse matrix to a reordered float32 tf.sparse.SparseTensor.\"\"\"\n",
    "    adj_coo = adj.tocoo()\n",
    "    sparse = tf.sparse.SparseTensor(\n",
    "        indices=np.column_stack([adj_coo.row, adj_coo.col]),\n",
    "        values=adj_coo.data.astype(np.float32),\n",
    "        dense_shape=adj_coo.shape,\n",
    "    )\n",
    "    return tf.sparse.reorder(sparse)\n",
    "\n",
    "\n",
    "def _build_classifier(classifier, n_classes):\n",
    "    \"\"\"Instantiate the GNN named by `classifier` ('gcn', 'gat' or 'graphsage').\"\"\"\n",
    "    if classifier == 'gcn':\n",
    "        return GCN(n_classes)\n",
    "    if classifier == 'gat':\n",
    "        return GAT(n_classes)\n",
    "    if classifier == 'graphsage':\n",
    "        return GraphSAGE(n_classes)\n",
    "    raise ValueError(\"Unknown classifier: \" + classifier)\n",
    "\n",
    "\n",
    "def train_and_evaluate(\n",
    "    embedding_dict, embedding, classifier,\n",
    "    ground_truth_labels=ground_truth_labels,\n",
    "    masked_labels=masked_labels,\n",
    "    adjacency=None,\n",
    "    epochs=200,\n",
    "    learning_rate=0.01,\n",
    "):\n",
    "    \"\"\"\n",
    "    Train a GNN on the training subgraph only, then evaluate it on the held-out nodes.\n",
    "\n",
    "    - Training nodes are those with masked_labels != -1; the adjacency is restricted\n",
    "      to the subgraph they induce, so no held-out node is seen during training.\n",
    "    - Inference happens on the full graph; metrics are computed on the held-out\n",
    "      nodes (masked_labels == -1), so training and evaluation always use the same split.\n",
    "\n",
    "    Args:\n",
    "        embedding_dict: {name: node-feature matrix with one row per node}.\n",
    "        embedding: key of embedding_dict used as node features.\n",
    "        classifier: 'gcn', 'gat' or 'graphsage'.\n",
    "        ground_truth_labels: one-hot label matrix of shape (n_nodes, n_classes).\n",
    "        masked_labels: integer labels with -1 marking held-out nodes.\n",
    "        adjacency: SciPy sparse adjacency of the full graph (defaults to the global A).\n",
    "        epochs: number of full-batch training steps.\n",
    "        learning_rate: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        (results, emb_full): metrics dict (accuracy, f1_score, embedding, model) and\n",
    "        the first-layer node embeddings of the full graph.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if classifier is not a supported name.\n",
    "    \"\"\"\n",
    "    if adjacency is None:\n",
    "        adjacency = A\n",
    "\n",
    "    print(f\"\\nEmbedding: {embedding.upper()}\")\n",
    "    print(f\"Model: {classifier.upper()}\")\n",
    "\n",
    "    ground_truth_labels = np.asarray(ground_truth_labels)\n",
    "    train_mask = np.asarray(masked_labels) != -1\n",
    "    train_idx = np.flatnonzero(train_mask)\n",
    "    test_idx = np.flatnonzero(~train_mask)\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 1. Construct training subgraph\n",
    "    # -------------------------------------------------\n",
    "    X_full = tf.convert_to_tensor(embedding_dict[embedding], dtype=tf.float32)\n",
    "    X_train = tf.gather(X_full, train_idx)  # features of training nodes only\n",
    "    Y_train = tf.convert_to_tensor(ground_truth_labels[train_mask], dtype=tf.float32)\n",
    "\n",
    "    adj_csr = adjacency.tocsr()\n",
    "    # NOTE(review): the raw binary adjacency is fed to every model, including\n",
    "    # spektral's GCNConv, which documents a normalized adjacency input\n",
    "    # (GCNConv.preprocess). Confirm this is intended.\n",
    "    A_train_tensor = _to_sparse_tensor(adj_csr[train_idx][:, train_idx])  # induced subgraph\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 2. Build classifier\n",
    "    # -------------------------------------------------\n",
    "    n_classes = ground_truth_labels.shape[1]\n",
    "    model = _build_classifier(classifier, n_classes)\n",
    "    optimizer = Adam(learning_rate=learning_rate)\n",
    "    loss_fn = CategoricalCrossentropy()\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 3. Full-batch training on the training subgraph only\n",
    "    # -------------------------------------------------\n",
    "    print(\"Training on TRAINING SUBGRAPH ONLY...\")\n",
    "    for epoch in range(epochs):\n",
    "        with tf.GradientTape() as tape:\n",
    "            preds_train, _ = model([X_train, A_train_tensor], training=True)\n",
    "            loss = loss_fn(Y_train, preds_train)\n",
    "\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            acc = CategoricalAccuracy()(Y_train, preds_train).numpy()\n",
    "            print(f\"Epoch {epoch} | Loss={loss.numpy():.4f} | Train Acc={acc:.4f}\")\n",
    "\n",
    "    # -------------------------------------------------\n",
    "    # 4. Inference on the full graph\n",
    "    # -------------------------------------------------\n",
    "    print(\"Predicting on FULL GRAPH...\")\n",
    "    A_full_tensor = _to_sparse_tensor(adj_csr)\n",
    "    preds_all, emb_full = model([X_full, A_full_tensor], training=False)\n",
    "    predicted_labels = tf.argmax(preds_all, axis=1).numpy()\n",
    "\n",
    "    # Evaluate only on held-out nodes, derived from the same masked_labels used for\n",
    "    # training so that a caller-supplied split is also evaluated consistently.\n",
    "    true_test = np.argmax(ground_truth_labels, axis=1)[test_idx]\n",
    "    predicted_test = predicted_labels[test_idx]\n",
    "    results = evaluate_model(true_test, predicted_test)\n",
    "\n",
    "    print(f\"Accuracy: {results['accuracy']*100:.2f}%\")\n",
    "    print(f\"F1 Score: {results['f1_score']:.4f}\")\n",
    "\n",
    "    results[\"embedding\"] = embedding\n",
    "    results[\"model\"] = classifier\n",
    "\n",
    "    return results, emb_full\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b95ec7-e23c-4d6e-8157-a2264ef0c7a8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train every classifier on every embedding; keep the metrics and the full-graph\n",
    "# hidden representations, keyed \"<embedding> with <classifier>\".\n",
    "all_results = []\n",
    "graph_embeddings_dict = {}\n",
    "for emb in embedding_dict:\n",
    "    for clf in classifiers:\n",
    "        results, embedding_matrix = train_and_evaluate(embedding_dict, emb, clf)\n",
    "        all_results.append(results)\n",
    "        key_string = f\"{emb} with {clf}\"\n",
    "        graph_embeddings_dict[key_string] = embedding_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17527c3b-ef18-4b1a-abb6-48bcea7dc4a3",
   "metadata": {},
   "source": [
    "## Saving aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa504238-b2df-4efd-b68c-90532d749d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to DataFrame\n",
    "df = pd.DataFrame(all_results)\n",
    "\n",
    "# Define dataset name and seed\n",
    "dataset_name = \"pubmed\"\n",
    "seed_value = SEED\n",
    "\n",
    "# Save as CSV file without sorting\n",
    "filename = f\"{dataset_name}_70_30_{SEED}_results.csv\"\n",
    "filename='./pubmed_analysis_results/results/70_30/'+filename\n",
    "df.to_csv(filename, index=False)\n",
    "\n",
    "print(f\"Results saved as {filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9875500-5fdf-4d56-9a05-5d5d3cbceec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw embeddings plus the GNN hidden representations (right-hand entries win on key clashes).\n",
    "all_embeddings = {**embedding_dict, **graph_embeddings_dict}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "602c48bd-1f98-47c9-993a-10ec0ec2966d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reorder_dict(original_dict, key_order):\n",
    "    \"\"\"\n",
    "    Return a new dict whose entries follow `key_order`.\n",
    "\n",
    "    Keys in `key_order` that are absent from `original_dict` are skipped, and keys of\n",
    "    `original_dict` not listed in `key_order` are dropped.\n",
    "\n",
    "    Parameters:\n",
    "    - original_dict (dict): The dictionary to reorder.\n",
    "    - key_order (list): The list specifying the desired key order.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A new dictionary with keys ordered as per key_order.\n",
    "    \"\"\"\n",
    "    ordered = {}\n",
    "    for key in key_order:\n",
    "        if key not in original_dict:\n",
    "            continue\n",
    "        ordered[key] = original_dict[key]\n",
    "    return ordered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0939b3-7a5a-4149-9130-86f5d6ca5141",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display order for the UMAP grid: each raw embedding followed by its three GNN variants.\n",
    "key_order = ['random', 'random with gcn', 'random with gat', 'random with graphsage', 'deepwalk', 'deepwalk with gcn', 'deepwalk with gat', 'deepwalk with graphsage', 'node2vec','node2vec with gcn', 'node2vec with gat', 'node2vec with graphsage', 'vgae', 'vgae with gcn', 'vgae with gat', 'vgae with graphsage', 'dgi', 'dgi with gcn', 'dgi with gat', 'dgi with graphsage', 'fuse', 'fuse with gcn', 'fuse with gat', 'fuse with graphsage', 'given', 'given with gcn', 'given with gat', 'given with graphsage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9547f62-6711-4b3b-afd6-133849783fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subplot titles, index-aligned with key_order (GraphSAGE is shortened to SAGE).\n",
    "pretty_names = ['Random', 'Random with GCN', 'Random with GAT', 'Random with SAGE', 'Deepwalk', 'Deepwalk with GCN', 'Deepwalk with GAT', 'Deepwalk with SAGE', 'Node2Vec','Node2Vec with GCN', 'Node2Vec with GAT', 'Node2Vec with SAGE', 'VGAE', 'VGAE with GCN', 'VGAE with GAT', 'VGAE with SAGE', 'DGI', 'DGI with GCN', 'DGI with GAT', 'DGI with SAGE', 'FUSE', 'FUSE with GCN', 'FUSE with GAT', 'FUSE with SAGE', 'Given', 'Given with GCN', 'Given with GAT', 'Given with SAGE']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3d8bb5d-4a89-42c0-9dbf-0425d3e66166",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map internal keys to display titles; zip() would silently drop entries if the\n",
    "# two lists ever got out of step, so check their lengths first.\n",
    "assert len(key_order) == len(pretty_names), \"key_order and pretty_names must be index-aligned\"\n",
    "pretty_map = dict(zip(key_order, pretty_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267d6703-3449-4e03-a5ed-105a5eba4eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only embeddings listed in key_order, in that order (absent keys are skipped).\n",
    "all_embeddings = reorder_dict(all_embeddings, key_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9db961-2936-4f72-b644-8282bf37c445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP grid of every embedding; held-out nodes are drawn with black outlines.\n",
    "visualize_all_embeddings(all_embeddings, labels, label_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd3c4b57-1af7-49e2-b092-620517ffaf88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import (\n",
    "    silhouette_score,\n",
    "    davies_bouldin_score,\n",
    "    normalized_mutual_info_score,\n",
    "    adjusted_rand_score,\n",
    "    v_measure_score\n",
    ")\n",
    "\n",
    "# Unsupervised check: k-means with k = number of classes on every embedding, scored\n",
    "# with internal indices (Silhouette, Davies-Bouldin) and against the true labels.\n",
    "sources = {\n",
    "    \"Raw Embeddings\": embedding_dict,\n",
    "    \"GNN Embeddings\": graph_embeddings_dict,\n",
    "}\n",
    "n_clusters = len(np.unique(labels))  # identical for every embedding\n",
    "\n",
    "metrics_results = []\n",
    "\n",
    "for source_name, emb_source in sources.items():\n",
    "    print(f\"\\n Evaluating {source_name}...\\n\")\n",
    "\n",
    "    for name, emb_tensor in emb_source.items():\n",
    "        emb = emb_tensor.numpy() if isinstance(emb_tensor, tf.Tensor) else np.array(emb_tensor)\n",
    "\n",
    "        kmeans = KMeans(n_clusters=n_clusters, random_state=SEED, n_init=10)\n",
    "        cluster_labels = kmeans.fit_predict(emb)\n",
    "\n",
    "        # sklearn raises ValueError for degenerate clusterings (e.g. a single cluster);\n",
    "        # record NaN for that case only and let any other failure surface.\n",
    "        try:\n",
    "            silhouette = silhouette_score(emb, cluster_labels)\n",
    "        except ValueError:\n",
    "            silhouette = np.nan\n",
    "        try:\n",
    "            db_index = davies_bouldin_score(emb, cluster_labels)\n",
    "        except ValueError:\n",
    "            db_index = np.nan\n",
    "\n",
    "        nmi = normalized_mutual_info_score(labels, cluster_labels)\n",
    "        ari = adjusted_rand_score(labels, cluster_labels)\n",
    "        v_measure = v_measure_score(labels, cluster_labels)\n",
    "\n",
    "        metrics_results.append({\n",
    "            \"Source\": source_name,\n",
    "            \"Embedding\": name,\n",
    "            \"Silhouette\": silhouette,\n",
    "            \"Davies-Bouldin\": db_index,\n",
    "            \"NMI\": nmi,\n",
    "            \"ARI\": ari,\n",
    "            \"V-Measure\": v_measure\n",
    "        })\n",
    "\n",
    "        print(f\" {name.upper()} | Silhouette={silhouette:.4f}, DB={db_index:.4f}, NMI={nmi:.4f}, ARI={ari:.4f}, V={v_measure:.4f}\")\n",
    "\n",
    "# Save results\n",
    "metrics_df = pd.DataFrame(metrics_results)\n",
    "metrics_path = os.path.join(save_dir, f\"clustering_metrics_combined_{dataset_name}_{split_folder.replace('-', '_')}_{SEED}.csv\")\n",
    "metrics_df.to_csv(metrics_path, index=False)\n",
    "print(f\"\\n Combined clustering metrics saved to: {metrics_path}\")\n",
    "print(metrics_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b064e5fa-8b3e-4423-8773-7d63112ffd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------------------\n",
    "# Configuration\n",
    "# ----------------------------------------\n",
    "# Dedicated names so this cell does not overwrite dataset_name, split_folder or df\n",
    "# used by the cells above.\n",
    "agg_dataset_name = \"pubmed\"\n",
    "agg_split = \"70_30\"  # must match the split this notebook writes its metrics for\n",
    "seeds = [42, 46, 123, 999, 2025]\n",
    "base_dir = f\"./pubmed_analysis_results/embeddings/{agg_split}/\"\n",
    "\n",
    "# Output file\n",
    "output_path = os.path.join(base_dir, f\"clustering_metrics_combined_{agg_dataset_name}_{agg_split}_averaged.csv\")\n",
    "\n",
    "# ----------------------------------------\n",
    "# Load and combine all per-seed CSVs\n",
    "# ----------------------------------------\n",
    "dfs = []\n",
    "for seed in seeds:\n",
    "    filepath = os.path.join(base_dir, f\"clustering_metrics_combined_{agg_dataset_name}_{agg_split}_{seed}.csv\")\n",
    "\n",
    "    if os.path.exists(filepath):\n",
    "        seed_df = pd.read_csv(filepath)\n",
    "        seed_df[\"Seed\"] = seed\n",
    "        dfs.append(seed_df)\n",
    "        print(f\" Loaded: {filepath}\")\n",
    "    else:\n",
    "        print(f\" Missing file for seed {seed}: {filepath}\")\n",
    "\n",
    "if not dfs:\n",
    "    raise FileNotFoundError(\"No CSV files found — check your directory and filenames.\")\n",
    "\n",
    "# Merge all seed data\n",
    "combined_df = pd.concat(dfs, ignore_index=True)\n",
    "\n",
    "# ----------------------------------------\n",
    "# Average metrics across seeds (std is NaN when only one seed is present)\n",
    "# ----------------------------------------\n",
    "agg_df = (\n",
    "    combined_df\n",
    "    .groupby([\"Source\", \"Embedding\"])\n",
    "    .agg({\n",
    "        \"Silhouette\": [\"mean\", \"std\"],\n",
    "        \"Davies-Bouldin\": [\"mean\", \"std\"],\n",
    "        \"NMI\": [\"mean\", \"std\"],\n",
    "        \"ARI\": [\"mean\", \"std\"],\n",
    "        \"V-Measure\": [\"mean\", \"std\"]\n",
    "    })\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# Flatten column names (order follows the .agg() dict above)\n",
    "agg_df.columns = [\n",
    "    \"Source\", \"Embedding\",\n",
    "    \"Silhouette_Mean\", \"Silhouette_STD\",\n",
    "    \"DB_Mean\", \"DB_STD\",\n",
    "    \"NMI_Mean\", \"NMI_STD\",\n",
    "    \"ARI_Mean\", \"ARI_STD\",\n",
    "    \"VMeasure_Mean\", \"VMeasure_STD\"\n",
    "]\n",
    "\n",
    "# Save the averaged results\n",
    "agg_df.to_csv(output_path, index=False)\n",
    "print(f\"\\n Averaged metrics saved to: {output_path}\\n\")\n",
    "\n",
    "# Display summary\n",
    "print(agg_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ab1739c-4e36-44e4-b8b4-b389395f3884",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
