{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c1ef1f-2372-44c0-9703-b8ac660cce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4441684d-fa5d-4f8d-9967-84753fd5a4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time \n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import umap.umap_ as umap\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
    "from node2vec import Node2Vec\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import spektral\n",
    "from spektral.layers import GCNConv, GATConv\n",
    "from spektral.layers import GraphSageConv\n",
    "from spektral.data import Graph, Dataset, BatchLoader\n",
    "from scipy.sparse import csr_matrix\n",
    "from spektral.datasets import Cora\n",
    "from torch_geometric.nn import DeepGraphInfomax, VGAE\n",
    "from torch_geometric.utils import from_networkx\n",
    "import scipy.sparse as sp\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
    "from scipy.sparse.csgraph import laplacian\n",
    "from scipy.sparse.linalg import eigsh\n",
    "from collections import Counter\n",
    "from sklearn.preprocessing import normalize\n",
    "from joblib import Parallel, delayed\n",
    "from torch_geometric.nn import GCNConv as PyG_GCNConv, VGAE as PyG_VGAE\n",
    "from torch_geometric.data import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e18457-6dbb-4928-adad-4b279400977c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 46\n",
    "\n",
    "# Set seed for Python's built-in random module\n",
    "random.seed(SEED)\n",
    "\n",
    "# Set seed for NumPy\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Set seed for TensorFlow\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Set seed for PyTorch\n",
    "torch.manual_seed(SEED)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b95802d-543f-4b3e-803a-cdbdb518595b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class CiteSeerDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Planetoid(root=\".\", name=\"CiteSeer\")  # Load CiteSeer dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "        \n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Convert edge_index to a sparse adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = csr_matrix((num_nodes, num_nodes))  # Initialize sparse matrix\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected graph\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c18215-d57c-47a9-957d-5f98fa330e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dimensionality=150"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37b67b75-8102-41b3-b7d9-faac9a95ca1d",
   "metadata": {},
   "source": [
    "## Extracting modularity embedding and using it for classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e46cf19-a136-404f-9c2a-a2d8404b874f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unsupervised gradient ascent for modularity maximization\n",
    "def gradient_ascent_modularity_unsupervised(G, k=2, eta=0.01, iterations=1000, seed=42):\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    A = nx.to_numpy_array(G)\n",
    "    l = np.array(A.sum(axis=1)).flatten()\n",
    "    m = np.sum(l) / 2\n",
    "    n = A.shape[0]\n",
    "\n",
    "    S = np.random.randn(n, k)\n",
    "    S, _ = np.linalg.qr(S)\n",
    "\n",
    "    modularity_scores = []\n",
    "\n",
    "    for i in tqdm(range(iterations), desc=\"Gradient Ascent Progress\"):\n",
    "        neighbor_agg = A @ S\n",
    "        global_correction = (l[:, None] / (2 * m)) * S.sum(axis=0)\n",
    "        grad_modularity = (1 / (2 * m)) * (neighbor_agg - global_correction)\n",
    "        S += eta * grad_modularity\n",
    "        S, _ = np.linalg.qr(S)\n",
    "\n",
    "        # ---- Modularity Score: Q = (1/2m) * Tr(S^T B S)\n",
    "        B_S = neighbor_agg - np.outer(l, l @ S) / (2 * m)\n",
    "        Q = (1 / (2 * m)) * np.trace(S.T @ B_S)\n",
    "        modularity_scores.append(Q)\n",
    "\n",
    "    return S, modularity_scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08173103-6488-4269-b6a4-517a08cd583c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_labeled_random_walks(G, label_mask, labels, num_walks, walk_length, walk_length_labelled=3):\n",
    "    walks = {node: [] for node in G.nodes()}\n",
    "    for node in G.nodes():\n",
    "        for _ in range(num_walks):\n",
    "            walk = [node]\n",
    "            labeled_count = 0\n",
    "            for _ in range(walk_length - 1):\n",
    "                cur = walk[-1]\n",
    "                neighbors = list(G.neighbors(cur))\n",
    "                if not neighbors:\n",
    "                    break\n",
    "                labeled_neighbors = [n for n in neighbors if label_mask[n]]\n",
    "                if labeled_neighbors and labeled_count < walk_length_labelled:\n",
    "                    next_node = random.choice(labeled_neighbors)\n",
    "                    labeled_count += 1\n",
    "                else:\n",
    "                    next_node = random.choice(neighbors)\n",
    "                walk.append(next_node)\n",
    "            walks[node].extend([n for n in walk if label_mask[n]])\n",
    "    return walks\n",
    "\n",
    "def compute_attention_weights(S, labeled_nodes):\n",
    "    weights = {}\n",
    "    for node, labeled in labeled_nodes.items():\n",
    "        if labeled:\n",
    "            similarities = {n: np.dot(S[node], S[n]) for n in labeled}\n",
    "            exp_sims = {n: np.exp(sim) for n, sim in similarities.items()}\n",
    "            total = sum(exp_sims.values())\n",
    "            weights[node] = {n: exp_sims[n] / total for n in labeled}\n",
    "    return weights\n",
    "\n",
    "def modularity_unsupervised(G, k=2, eta=0.01,  lambda_unsupervised=1e3, iterations=5000, initialization='random',\n",
    "                                                      num_walks=10, walk_length=5, walk_length_labelled=3):\n",
    "    # Convert graph to sparse adjacency matrix\n",
    "    A = csr_matrix(nx.to_scipy_sparse_array(G, format='csr'))\n",
    "    degrees = np.array(A.sum(axis=1)).flatten()\n",
    "    m = G.number_of_edges()\n",
    "    n = A.shape[0]\n",
    "\n",
    "    # Initialize embeddings\n",
    "    if initialization == 'random':\n",
    "        S = np.random.randn(n, k)\n",
    "    S, _ = np.linalg.qr(S)\n",
    "\n",
    "    for _ in tqdm(range(iterations), desc=\"Gradient Ascent with Linear Modularity\"):\n",
    "        # Compute modularity gradient using linear approximation\n",
    "        neighbor_agg = A @ S  # Efficient aggregation of neighbor embeddings\n",
    "        global_correction = (degrees[:, None] / (2 * m)) * S.sum(axis=0)\n",
    "        grad_modularity = (1 / (2 * m)) * (neighbor_agg - global_correction) * lambda_unsupervised\n",
    "\n",
    "        # Update embeddings\n",
    "        grad_total = lambda_unsupervised * grad_modularity\n",
    "        S += eta * grad_total\n",
    "        S, _ = np.linalg.qr(S)\n",
    "\n",
    "    return S\n",
    "\n",
    "\n",
    "\n",
    "def modularity_attention_unsupervised_semisupervised(G, labels, label_mask, k=2, eta=0.01,  lambda_unsupervised=1e3,\n",
    "                                                      lambda_semi=2.0, iterations=5000, initialization='random',\n",
    "                                                      num_walks=10, walk_length=5, walk_length_labelled=3):\n",
    "    # Convert graph to sparse adjacency matrix\n",
    "    A = csr_matrix(nx.to_scipy_sparse_array(G, format='csr'))\n",
    "    degrees = np.array(A.sum(axis=1)).flatten()\n",
    "    m = G.number_of_edges()\n",
    "    n = A.shape[0]\n",
    "\n",
    "    # Initialize embeddings\n",
    "    if initialization == 'random':\n",
    "        S = np.random.randn(n, k)\n",
    "    S, _ = np.linalg.qr(S)\n",
    "\n",
    "    # Compute labeled random walks and attention weights\n",
    "    labeled_walks = perform_labeled_random_walks(G, label_mask, labels, num_walks, walk_length, walk_length_labelled)\n",
    "    attention_weights = compute_attention_weights(S, labeled_walks)\n",
    "\n",
    "    for _ in tqdm(range(iterations), desc=\"Gradient Ascent with Linear Modularity\"):\n",
    "        # Compute modularity gradient using linear approximation\n",
    "        neighbor_agg = A @ S  # Efficient aggregation of neighbor embeddings\n",
    "        global_correction = (degrees[:, None] / (2 * m)) * S.sum(axis=0)\n",
    "        grad_modularity = (1 / (2 * m)) * (neighbor_agg - global_correction) * lambda_unsupervised\n",
    "\n",
    "        # Compute semi-supervised gradient using adaptive attention\n",
    "        grad_semi_supervised = np.zeros_like(S)\n",
    "        for i in range(n):\n",
    "            if not label_mask[i] and i in attention_weights:\n",
    "                weighted_embedding = sum(weight * S[n] for n, weight in attention_weights[i].items())\n",
    "                grad_semi_supervised[i] = S[i] - weighted_embedding\n",
    "\n",
    "        # Update embeddings\n",
    "        grad_total = lambda_unsupervised * grad_modularity - lambda_semi * grad_semi_supervised\n",
    "        S += eta * grad_total\n",
    "        S, _ = np.linalg.qr(S)\n",
    "\n",
    "    return S\n",
    "\n",
    "\n",
    "def attention_semisupervised(G, labels, label_mask,  k=2, eta=0.01,\n",
    "                                                      lambda_semi=2.0, iterations=5000, initialization='random',\n",
    "                                                      num_walks=10, walk_length=5, walk_length_labelled=3):\n",
    "    # Convert graph to sparse adjacency matrix\n",
    "    A = csr_matrix(nx.to_scipy_sparse_array(G, format='csr'))\n",
    "    degrees = np.array(A.sum(axis=1)).flatten()\n",
    "    m = G.number_of_edges()\n",
    "    n = A.shape[0]\n",
    "\n",
    "    # Initialize embeddings\n",
    "    if initialization == 'random':\n",
    "        S = np.random.randn(n, k)\n",
    "    S, _ = np.linalg.qr(S)\n",
    "\n",
    "    # Compute labeled random walks and attention weights\n",
    "    labeled_walks = perform_labeled_random_walks(G, label_mask, labels, num_walks, walk_length, walk_length_labelled)\n",
    "    attention_weights = compute_attention_weights(S, labeled_walks)\n",
    "\n",
    "    for _ in tqdm(range(iterations), desc=\"Gradient Ascent without Linear Modularity\"):\n",
    "\n",
    "\n",
    "        # Compute semi-supervised gradient using adaptive attention\n",
    "        grad_semi_supervised = np.zeros_like(S)\n",
    "        for i in range(n):\n",
    "            if not label_mask[i] and i in attention_weights:\n",
    "                weighted_embedding = sum(weight * S[n] for n, weight in attention_weights[i].items())\n",
    "                grad_semi_supervised[i] = S[i] - weighted_embedding\n",
    "\n",
    "        # Update embeddings\n",
    "        grad_total = - lambda_semi * grad_semi_supervised\n",
    "        S += eta * grad_total\n",
    "        S, _ = np.linalg.qr(S)\n",
    "\n",
    "    return S\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1fa1ef-526c-4ecb-a4e5-6c344f0412c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_networkx(A):\n",
    "    return nx.from_scipy_sparse_array(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fbba5b-3f7c-4821-9af5-d9d6e982b9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = CiteSeerDataset()\n",
    "ground_truth_labels = dataset[0].y\n",
    "labels=np.argmax(ground_truth_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92de7a7-bab8-4ee8-9601-e4f95f0386e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#mask_file = \"/Users/sujan/Modularity based semi supervised learning/masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed46.npy\"\n",
    "mask_file = \"/Users/sujan/Modularity based semi supervised learning/masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed46.npy\"\n",
    "\n",
    "labels_to_be_masked = np.load(mask_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0cf50e3-56e5-4187-b4ef-9f05cd3bc76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_labels=[]\n",
    "for i in np.arange(len(labels)):\n",
    "    if i in labels_to_be_masked:\n",
    "        masked_labels.append(-1)\n",
    "    else:\n",
    "        masked_labels.append(labels[i])\n",
    "masked_labels=np.array(masked_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77562c2a-8c15-4b37-a4d3-674776a0edf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_mask = masked_labels != -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "517544ea-62e6-495c-bce9-6733a551c959",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = dataset[0].x\n",
    "A = dataset[0].a\n",
    "G = convert_to_networkx(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17de4d41-86a4-40c3-bf4d-217511591a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Adjacency Matrix Shape:\", A.shape)\n",
    "print(\"Graph Nodes:\", G.number_of_nodes())\n",
    "print(\"Graph Edges:\", G.number_of_edges())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0346f826-0826-4b89-8328-054b3295fdc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert your preprocessed data into a PyTorch Geometric Data object\n",
    "X_py = Data(\n",
    "    x=torch.tensor(X, dtype=torch.float),  # Node features\n",
    "    edge_index=torch.tensor(np.array(A.nonzero()), dtype=torch.long),  # Edge indices\n",
    "    y=torch.tensor(labels, dtype=torch.long)  # Labels\n",
    ")\n",
    "\n",
    "# Ensure edge_index is in the correct shape (2, num_edges)\n",
    "X_py.edge_index = X_py.edge_index.to(torch.long)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81b4f8f4-20eb-4804-a2e7-b04c6e80d08d",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0433eb15-9557-49a3-9390-4ca2836303fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import umap\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "def plot_embedding_with_labels_umap(X, y, title=\"UMAP Embedding\"):\n",
    "    \"\"\"\n",
    "    Plots 2D UMAP embeddings using a custom 6-class color palette inspired by the uploaded modularity image.\n",
    "\n",
    "    Parameters:\n",
    "    - X: High-dimensional data (shape: [n_samples, n_features])\n",
    "    - y: Integer labels (shape: [n_samples], values 0 through 5)\n",
    "    - title: Title for the plot\n",
    "    \"\"\"\n",
    "    # Step 1: Reduce dimensionality to 2D using UMAP\n",
    "    reducer = umap.UMAP(n_components=2, random_state=42)\n",
    "    X_2d = reducer.fit_transform(X)\n",
    "\n",
    "    # Step 2: Custom 6-class color palette\n",
    "    custom_colors = [\n",
    "        \"#e41a1c\",  # Red\n",
    "        \"#377eb8\",  # Blue\n",
    "        \"#984ea3\",  # Purple\n",
    "        \"#ffff33\",  # Yellow\n",
    "        \"#f781bf\",  # Pink\n",
    "        \"#999999\"   # Gray\n",
    "    ]\n",
    "    cmap = ListedColormap(custom_colors)\n",
    "\n",
    "    # Step 3: Create the plot\n",
    "    plt.figure(figsize=(8, 8))\n",
    "    scatter = plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap=cmap,\n",
    "                          s=8, alpha=0.75, edgecolors='k', linewidths=0.2)\n",
    "    plt.colorbar(scatter, ticks=np.unique(y), label='Class')\n",
    "\n",
    "    plt.title(title)\n",
    "    plt.axis(\"off\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cada09-7de5-438e-b980-7d5e703615d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dictionary for embeddings\n",
    "embedding_dict = {}\n",
    "execution_times = []  # List to store execution times\n",
    "\n",
    "# Compute embeddings and store them with time tracking\n",
    "def record_time(model_name, func, *args, **kwargs):\n",
    "    print(f\"Computing {model_name} embedding...\")\n",
    "    start_time = time.time()\n",
    "    result = func(*args, **kwargs)\n",
    "    end_time = time.time()\n",
    "    elapsed_time = end_time - start_time\n",
    "    execution_times.append((model_name, elapsed_time))\n",
    "    print(f\"{model_name} embedding computed in {elapsed_time:.2f} seconds.\")\n",
    "    return result\n",
    "\n",
    "X_modularity_unsupervised = record_time(\"Modularity Unsupervised\", modularity_unsupervised,\n",
    "                           G, k=embedding_dimensionality,\n",
    "                           eta=0.05, lambda_unsupervised=1e5, iterations=200, initialization='random')\n",
    "embedding_dict['modularity unsupervised'] = X_modularity_unsupervised\n",
    "\n",
    "plot_embedding_with_labels_umap(X_modularity_unsupervised, labels, title=\"Modularity Unsupervised Embedding\")\n",
    "\n",
    "X_attention_semisupervised = record_time(\"Attention Semi-Supervised\", attention_semisupervised,\n",
    "                           G, labels, label_mask, k=embedding_dimensionality,\n",
    "                           eta=0.05, lambda_semi=2.0, iterations=200, initialization='random')\n",
    "embedding_dict['attention semisupervised'] = X_attention_semisupervised\n",
    "\n",
    "plot_embedding_with_labels_umap(X_attention_semisupervised, labels, title=\"Attention Semisupervised Embedding\")\n",
    "\n",
    "\n",
    "X_modularity_attention_unsupervised_semisupervised = record_time(\"Modularity Attention Un-Semi-Supervised\", modularity_attention_unsupervised_semisupervised,\n",
    "                           G, labels, label_mask, k=embedding_dimensionality,\n",
    "                           eta=0.05, lambda_semi=2.0, iterations=200, initialization='random')\n",
    "embedding_dict['attention unsemisupervised'] = X_modularity_attention_unsupervised_semisupervised\n",
    "\n",
    "plot_embedding_with_labels_umap(X_modularity_attention_unsupervised_semisupervised, labels, title=\"Modularity Attention Un-Semi-Supervised Embedding\")\n",
    "\n",
    "print(\"All embeddings computed and stored in the dictionary successfully.\")\n",
    "\n",
    "# Store execution times in a DataFrame and save\n",
    "execution_df = pd.DataFrame(execution_times, columns=[\"Model\", \"Time (seconds)\"])\n",
    "#execution_df.to_csv(\"/Users/sujan/Modularity based semi supervised learning/Ablation_study/CiteSeer/70_30/embedding_execution_times_citeseer_70_30_\"+str(SEED)+\".csv\", index=False)\n",
    "execution_df.to_csv(\"/Users/sujan/Modularity based semi supervised learning/Ablation_study/CiteSeer/30_70/embedding_execution_times_citeseer_30_70_\"+str(SEED)+\".csv\", index=False)\n",
    "\n",
    "print(\"\\nExecution times saved to 'embedding_execution_times.csv'.\")\n",
    "print(execution_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a98e7ea-fd83-440b-a479-bed7c46185a4",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7297774-e5c4-4b87-9cc9-cab07deb7978",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import umap\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "def visualize_all_embeddings(all_embeddings, labels, label_mask, key_order=None, save_path=\"./unsupervised_analysis_results/embedding_grid_plot_citeseer_unsupervised_seed46.png\"):\n",
    "    \"\"\"\n",
    "    Visualize all embeddings in a grid with 4 columns per row using UMAP.\n",
    "\n",
    "    Parameters:\n",
    "    - all_embeddings: Dictionary with embedding method names as keys and embeddings as values.\n",
    "    - labels: Ground truth labels (numpy array of shape [n_nodes]).\n",
    "    - label_mask: Boolean array (True for known labels, False for masked).\n",
    "    - key_order: Optional list to enforce embedding visualization order.\n",
    "    - save_path: Where to save the final figure.\n",
    "    \"\"\"\n",
    "    keys = key_order if key_order is not None else list(all_embeddings.keys())\n",
    "    num_embeddings = len(keys)\n",
    "    num_cols = 4\n",
    "    num_rows = (num_embeddings + num_cols - 1) // num_cols  # Ceiling division\n",
    "\n",
    "    fig, axes = plt.subplots(num_rows, num_cols, figsize=(11.69, 8.27))  # Landscape A4\n",
    "    axes = np.array(axes).reshape(num_rows, num_cols)  # Ensure 2D\n",
    "\n",
    "    for i, embedding_type in tqdm(enumerate(keys), total=num_embeddings, desc=\"Visualizing embeddings\"):\n",
    "        embedding = all_embeddings.get(embedding_type)\n",
    "        if embedding is None:\n",
    "            print(f\" Skipping missing embedding: {embedding_type}\")\n",
    "            continue\n",
    "\n",
    "        row, col = divmod(i, num_cols)\n",
    "        ax = axes[row, col]\n",
    "\n",
    "        if isinstance(embedding, tf.Tensor):\n",
    "            embedding = embedding.numpy()\n",
    "\n",
    "        reducer = umap.UMAP(n_components=2, random_state=42)\n",
    "        embedding_2d = reducer.fit_transform(embedding)\n",
    "\n",
    "        # Plot known labels\n",
    "        ax.scatter(embedding_2d[label_mask, 0], embedding_2d[label_mask, 1],\n",
    "                   c=labels[label_mask], cmap=\"Set1\", s=3, alpha=0.7,\n",
    "                   label=\"Known Labels\", edgecolors='none')\n",
    "\n",
    "        # Plot unknown labels\n",
    "        ax.scatter(embedding_2d[~label_mask, 0], embedding_2d[~label_mask, 1],\n",
    "                   c=labels[~label_mask], cmap=\"Set1\", s=5, alpha=0.6,\n",
    "                   label=\"Masked Labels\", edgecolors='black', linewidths=0.2)\n",
    "\n",
    "        ax.set_title(embedding_type.upper(), fontsize=8, pad=2)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_frame_on(False)\n",
    "\n",
    "    # Remove unused subplots\n",
    "    for j in range(num_embeddings, num_rows * num_cols):\n",
    "        row, col = divmod(j, num_cols)\n",
    "        fig.delaxes(axes[row, col])\n",
    "\n",
    "    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05,\n",
    "                        wspace=0.2, hspace=0.2)\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\" Visualization saved to: {save_path}\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73add6f8-ad97-49f5-8f66-88df3b717390",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(true_labels, predicted_labels):\n",
    "    \"\"\"\n",
    "    Evaluate the model's performance using accuracy, F1-score, and confusion matrix.\n",
    "\n",
    "    Args:\n",
    "        true_labels (np.array): Ground truth labels (integers).\n",
    "        predicted_labels (np.array): Predicted labels (integers).\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing accuracy, F1-score, and confusion matrix.\n",
    "    \"\"\"\n",
    "    # Compute accuracy\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    \n",
    "    # Compute F1-score (macro-averaged)\n",
    "    f1 = f1_score(true_labels, predicted_labels, average='macro')\n",
    "    \n",
    "    # Compute confusion matrix\n",
    "    cm = confusion_matrix(true_labels, predicted_labels)\n",
    "\n",
    "    #\n",
    "    print(cm)\n",
    "    \n",
    "    # Return results as a dictionary\n",
    "    return {\n",
    "        'accuracy': accuracy,\n",
    "        'f1_score': f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc2b387-1e6d-49a3-b28f-23f59c1c789c",
   "metadata": {},
   "source": [
    "## Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f934d79-76c2-4d0e-81f4-486a1dc34d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoMaskGCNConv(GCNConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard mask\n",
    "        return super().call(inputs, mask=None)\n",
    "        \n",
    "class GCN(tf.keras.Model):\n",
    "    def __init__(self, n_labels, seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "        self.conv1 = NoMaskGCNConv(16, activation='relu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGCNConv(n_labels, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c02975-3867-49ee-bfd3-f74cd6d8dbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spektral.layers import GATConv\n",
    "import tensorflow as tf\n",
    "\n",
    "# Define a custom wrapper for GATConv that avoids mask issues\n",
    "class NoMaskGATConv(GATConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard the mask argument\n",
    "        return super().call(inputs, mask=None)\n",
    "\n",
    "\n",
    "# Define the GAT model using the NoMaskGATConv\n",
    "class GAT(tf.keras.Model):\n",
    "    def __init__(self, n_labels, num_heads=8, seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # Use the custom NoMaskGATConv instead of the original GATConv\n",
    "        self.conv1 = NoMaskGATConv(16, attn_heads=num_heads, concat_heads=True, activation='elu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGATConv(n_labels, attn_heads=1, concat_heads=False, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1805e34a-bed9-410b-a719-49a9df09c377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the GraphSAGE model\n",
    "class GraphSAGE(tf.keras.Model):\n",
    "    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        self.conv1 = GraphSageConv(hidden_dim, activation='relu', aggregator=aggregator, kernel_initializer=initializer)\n",
    "        self.conv2 = GraphSageConv(n_labels, activation='softmax', aggregator=aggregator, kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7378d826-fa06-4a7d-8891-f46888320d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifiers=['gcn','gat','graphsage']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad91add-83d2-4b32-a97b-a7b48487c7fe",
   "metadata": {},
   "source": [
    "## Classification using different node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f36b9cb-be80-465d-852e-ded40e6f7eaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(embedding_dict, embedding, classifier, ground_truth_labels=ground_truth_labels, masked_labels=masked_labels):\n",
    "    \"the labels have to be one hot encoded\"\n",
    "    \"model take values: gcn, gat, graphsage\"\n",
    "    print('embedding: ' + embedding.upper())\n",
    "    print('model: ' + classifier.upper())\n",
    "\n",
    "    X = embedding_dict[embedding]\n",
    "\n",
    "    print(\"Processing...\")\n",
    "    # Create boolean mask for training\n",
    "    train_mask = masked_labels != -1\n",
    "\n",
    "    # Split the data into training and prediction sets\n",
    "    X_train = X[train_mask]  # Training node features\n",
    "    Y_train = ground_truth_labels[train_mask]  # Training labels (one-hot encoded)\n",
    "    Y_train = tf.cast(Y_train, dtype='int32')\n",
    "    \n",
    "    # Reduce the adjacency matrix to only include training nodes\n",
    "    A_train = A[train_mask, :][:, train_mask]  # Correctly reduce the adjacency matrix\n",
    "    \n",
    "    # Convert sparse adjacency matrix to COO format\n",
    "    A_coo = A_train.tocoo()\n",
    "    indices = np.column_stack((A_coo.row, A_coo.col))  # Corrected indices format\n",
    "    values = A_coo.data\n",
    "    shape = A_coo.shape  # Shape: (num_nodes, num_nodes)\n",
    "    \n",
    "    # Create a sparse tensor for the adjacency matrix\n",
    "    A_train_tensor = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=shape)\n",
    "    \n",
    "    # Ensure the sparse tensor is ordered correctly\n",
    "    A_train_tensor = tf.sparse.reorder(A_train_tensor)\n",
    "\n",
    "    print(\"Training...\")\n",
    "    # Initialize the model\n",
    "    if classifier == 'gcn':\n",
    "        n_labels = ground_truth_labels.shape[1]  # Number of classes\n",
    "        model = GCN(n_labels)\n",
    "    elif classifier == 'gat':\n",
    "        n_labels = ground_truth_labels.shape[1]  # Number of classes\n",
    "        model = GAT(n_labels)\n",
    "    elif classifier == 'graphsage':\n",
    "        n_labels = ground_truth_labels.shape[1]  # Number of classes\n",
    "        model = GraphSAGE(n_labels)\n",
    "    \n",
    "    # Compile the model (not strictly necessary when using GradientTape, but useful for metrics)\n",
    "    model.compile(\n",
    "        optimizer=Adam(learning_rate=0.01),\n",
    "        loss=CategoricalCrossentropy(),\n",
    "        metrics=[CategoricalAccuracy()]\n",
    "    )\n",
    "    \n",
    "    # Print shapes for debugging\n",
    "    print(f\"Shape of X_train: {X_train.shape}\")\n",
    "    print(f\"Shape of A_train_tensor: {A_train_tensor.shape}\")\n",
    "    print(f\"Shape of Y_train: {Y_train.shape}\")\n",
    "    \n",
    "    # Define the optimizer and loss function\n",
    "    optimizer = Adam(learning_rate=0.01)\n",
    "    loss_fn = CategoricalCrossentropy()\n",
    "    \n",
    "    # Training loop with GradientTape\n",
    "    epochs = 200\n",
    "    for epoch in range(epochs):\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Forward pass\n",
    "            predictions, intermediate_embeddings = model([X_train, A_train_tensor])  # Unpack both outputs\n",
    "                \n",
    "            # Compute supervised loss (cross-entropy)\n",
    "            supervised_loss = loss_fn(Y_train, predictions)\n",
    "            \n",
    "        # Compute gradients\n",
    "        gradients = tape.gradient(supervised_loss, model.trainable_variables)\n",
    "        \n",
    "        # Update weights\n",
    "        optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "        \n",
    "        # Print loss and accuracy for monitoring\n",
    "        if epoch % 10 == 0:\n",
    "            accuracy = CategoricalAccuracy()(Y_train, predictions)\n",
    "            print(f\"Epoch {epoch + 1}, Loss: {supervised_loss.numpy()}, Accuracy: {accuracy.numpy()}\")\n",
    "\n",
    "    print(\"Predicting...\")\n",
    "    # Prepare the full graph for prediction\n",
    "    X_full = X  # Full node features\n",
    "    A_full = A  # Full adjacency matrix\n",
    "    \n",
    "    # Convert the full adjacency matrix to COO format\n",
    "    A_full_coo = A_full.tocoo()\n",
    "    indices_full = np.column_stack((A_full_coo.row, A_full_coo.col))\n",
    "    values_full = A_full_coo.data\n",
    "    shape_full = A_full_coo.shape\n",
    "    \n",
    "    # Create a sparse tensor for the full adjacency matrix\n",
    "    A_full_tensor = tf.sparse.SparseTensor(indices=indices_full, values=values_full, dense_shape=shape_full)\n",
    "    A_full_tensor = tf.sparse.reorder(A_full_tensor)\n",
    "    \n",
    "    # Make predictions for all nodes\n",
    "    predictions, emb = model([X_full, A_full_tensor])  # Shape: [num_nodes, n_labels]\n",
    "\n",
    "    # Convert predictions to class labels (integers)\n",
    "    predicted_labels = tf.argmax(predictions, axis=1).numpy()  # Shape: [num_nodes]\n",
    "    \n",
    "    # Extract predictions for the masked nodes\n",
    "    predicted_labels_masked = predicted_labels[labels_to_be_masked]\n",
    "\n",
    "    # True labels for the masked nodes\n",
    "    true_labels_masked = labels[labels_to_be_masked]\n",
    "    \n",
    "    # Predicted labels for the masked nodes\n",
    "    predicted_labels_masked = predicted_labels[labels_to_be_masked]\n",
    "    \n",
    "    # Evaluate the model's performance\n",
    "    results = evaluate_model(true_labels_masked, predicted_labels_masked)\n",
    "    \n",
    "    # Print the results\n",
    "    print(f\"Accuracy: {results['accuracy'] * 100:.2f}%\")\n",
    "    print(f\"F1-Score: {results['f1_score']:.4f}\")\n",
    "\n",
    "    results['model'] = classifier\n",
    "    results['embedding'] = embedding\n",
    "\n",
    "    # Return results and intermediate embeddings for visualization\n",
    "    return results, emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b95ec7-e23c-4d6e-8157-a2264ef0c7a8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "all_results=[]\n",
    "graph_embeddings_dict={}\n",
    "for emb in embedding_dict.keys():\n",
    "    for clf in classifiers:\n",
    "        results, embedding_matrix = train_and_evaluate(embedding_dict, emb, clf)\n",
    "        all_results.append(results)\n",
    "        key_string= emb + ' with ' + clf\n",
    "        graph_embeddings_dict[key_string]=embedding_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17527c3b-ef18-4b1a-abb6-48bcea7dc4a3",
   "metadata": {},
   "source": [
    "## Saving aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa504238-b2df-4efd-b68c-90532d749d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to DataFrame\n",
    "df = pd.DataFrame(all_results)\n",
    "\n",
    "# Define dataset name and seed\n",
    "dataset_name = \"citeseer\"\n",
    "seed_value = SEED\n",
    "\n",
    "# Save as CSV file without sorting\n",
    "#filename = f\"{dataset_name}_seed{seed_value}_results_ablation_70_30.csv\"\n",
    "#filename='/Users/sujan/Modularity based semi supervised learning/Ablation_study/CiteSeer/70_30/'+filename\n",
    "filename = f\"{dataset_name}_seed{seed_value}_results_ablation_30_70.csv\"\n",
    "filename='/Users/sujan/Modularity based semi supervised learning/Ablation_study/CiteSeer/30_70/'+filename\n",
    "df.to_csv(filename, index=False)\n",
    "\n",
    "print(f\"Results saved as {filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fbf8080-e7d6-428c-b80d-558dbc304180",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
