{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c1ef1f-2372-44c0-9703-b8ac660cce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4441684d-fa5d-4f8d-9967-84753fd5a4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time \n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import umap.umap_ as umap\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
    "from node2vec import Node2Vec\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import spektral\n",
    "from spektral.layers import GCNConv, GATConv\n",
    "from spektral.layers import GraphSageConv\n",
    "from spektral.data import Graph, Dataset, BatchLoader\n",
    "from scipy.sparse import csr_matrix\n",
    "from spektral.datasets import Cora\n",
    "from torch_geometric.nn import DeepGraphInfomax, VGAE\n",
    "from torch_geometric.utils import from_networkx\n",
    "import scipy.sparse as sp\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
    "from scipy.sparse.csgraph import laplacian\n",
    "from scipy.sparse.linalg import eigsh\n",
    "from collections import Counter\n",
    "from sklearn.preprocessing import normalize\n",
    "from joblib import Parallel, delayed\n",
    "from torch_geometric.nn import GCNConv as PyG_GCNConv, VGAE as PyG_VGAE\n",
    "from torch_geometric.data import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e18457-6dbb-4928-adad-4b279400977c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 46\n",
    "\n",
    "# Set seed for Python's built-in random module\n",
    "random.seed(SEED)\n",
    "\n",
    "# Set seed for NumPy\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Set seed for TensorFlow\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Set seed for PyTorch\n",
    "torch.manual_seed(SEED)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b95802d-543f-4b3e-803a-cdbdb518595b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "from scipy.sparse import lil_matrix\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class PubMedDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Planetoid(root=\".\", name=\"PubMed\")  # Load PubMed dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "        \n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Convert edge_index to a sparse adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = lil_matrix((num_nodes, num_nodes), dtype=np.float32)\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected graph\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c18215-d57c-47a9-957d-5f98fa330e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dimensionality=150"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37b67b75-8102-41b3-b7d9-faac9a95ca1d",
   "metadata": {},
   "source": [
    "## Extracting modularity embedding and using it for classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e46cf19-a136-404f-9c2a-a2d8404b874f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Laplacian Eigenmaps Embedding\n",
    "def deepwalk_embedding(G, k=2, walk_length=10, num_walks=80, workers=4):\n",
    "    node2vec = Node2Vec(G, dimensions=k, walk_length=walk_length, num_walks=num_walks, workers=workers)\n",
    "    model = node2vec.fit(window=10, min_count=1, batch_words=4)\n",
    "    return np.array([model.wv[str(node)] for node in G.nodes()])\n",
    "\n",
    "# Node2Vec Embedding\n",
    "def node2vec_embedding(G, k=2, seed=SEED):\n",
    "    node2vec = Node2Vec(G, dimensions=k, walk_length=10, num_walks=100, workers=2, seed=seed)\n",
    "    model = node2vec.fit(window=10, min_count=1, batch_words=4)\n",
    "    return np.array([model.wv[str(node)] for node in G.nodes()])\n",
    "\n",
    "\n",
    "# VGAE Embedding \n",
    "class VGAEEncoder(torch.nn.Module):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super().__init__()\n",
    "        self.conv1 = PyG_GCNConv(in_channels, 2 * out_channels)  # Use PyG_GCNConv\n",
    "        self.conv_mu = PyG_GCNConv(2 * out_channels, out_channels)  # Separate layer for mu\n",
    "        self.conv_logstd = PyG_GCNConv(2 * out_channels, out_channels)  # Separate layer for logstd\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = torch.relu(self.conv1(x, edge_index))\n",
    "        mu = self.conv_mu(x, edge_index)\n",
    "        logstd = self.conv_logstd(x, edge_index)\n",
    "        return mu, logstd\n",
    "\n",
    "def vgae_embedding(data, k=128):\n",
    "    # Use one-hot encoded node IDs as features\n",
    "    num_nodes = data.num_nodes\n",
    "    x = torch.eye(num_nodes)\n",
    "\n",
    "    in_channels = x.shape[1]  # Feature dimension is equal to the number of nodes\n",
    "    model = PyG_VGAE(VGAEEncoder(in_channels, k))\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "    \n",
    "    for _ in tqdm(range(200)):\n",
    "        optimizer.zero_grad()\n",
    "        z = model.encode(x, data.edge_index)  # Use one-hot encoded features\n",
    "        loss = model.recon_loss(z, data.edge_index) + (1 / data.num_nodes) * model.kl_loss()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    return model.encode(x, data.edge_index).detach().numpy()\n",
    "\n",
    "# DGI Embedding\n",
    "def dgi_embedding(data, k=128):\n",
    "    class GCNEncoder(torch.nn.Module):\n",
    "        def __init__(self, in_channels, out_channels):\n",
    "            super().__init__()\n",
    "            self.conv1 = PyG_GCNConv(in_channels, 2 * out_channels)\n",
    "            self.conv2 = PyG_GCNConv(2 * out_channels, out_channels)\n",
    "\n",
    "        def forward(self, x, edge_index):\n",
    "            x = torch.relu(self.conv1(x, edge_index))\n",
    "            return self.conv2(x, edge_index)\n",
    "\n",
    "    # Initialize node features as random Gaussian matrix\n",
    "    num_nodes = data.num_nodes\n",
    "    x = torch.tensor(np.random.randn(num_nodes, k), dtype=torch.float)  \n",
    "\n",
    "    in_channels = x.shape[1]  # Now equal to k\n",
    "    model = DeepGraphInfomax(\n",
    "        hidden_channels=k,\n",
    "        encoder=GCNEncoder(in_channels, k),\n",
    "        summary=lambda z, *args, **kwargs: z.mean(dim=0),\n",
    "        corruption=lambda x, edge_index: (x[torch.randperm(x.size(0))], edge_index)\n",
    "    )\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "    for _ in tqdm(range(200)):\n",
    "        optimizer.zero_grad()\n",
    "        pos_z, neg_z, summary = model(x, data.edge_index)\n",
    "        loss = model.loss(pos_z, neg_z, summary)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    return pos_z.detach().numpy()\n",
    "\n",
    "# Unsupervised gradient ascent for modularity maximization\n",
    "def gradient_ascent_modularity_unsupervised(G, k=2, eta=0.01, iterations=1000, seed=SEED):\n",
    "    np.random.seed(seed)  # Ensure deterministic initialization\n",
    "\n",
    "    A = nx.to_numpy_array(G)\n",
    "    l = A.sum(axis=1)\n",
    "    m = np.sum(l) / 2\n",
    "    B = A - np.outer(l, l) / (2 * m)\n",
    "    n = B.shape[0]\n",
    "\n",
    "    S = np.random.randn(n, k)  # Random Initialization\n",
    "    S, _ = np.linalg.qr(S)  # Ensure initial orthonormality\n",
    "\n",
    "    for i in tqdm(range(iterations), desc=\"Gradient Ascent Progress\"):\n",
    "        grad = (1 / (2 * m)) * B @ S\n",
    "        S += eta * grad\n",
    "        S, _ = np.linalg.qr(S)  # Orthonormalize using QR decomposition\n",
    "\n",
    "    return S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08173103-6488-4269-b6a4-517a08cd583c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unsupervised gradient ascent for modularity maximization\n",
    "def gradient_ascent_modularity_unsupervised(G, k=2, eta=0.01, iterations=1000, seed=SEED):\n",
    "    np.random.seed(seed)  # Ensure deterministic initialization\n",
    "\n",
    "    A = nx.to_numpy_array(G)\n",
    "    l = A.sum(axis=1)\n",
    "    m = np.sum(l) / 2\n",
    "    B = A - np.outer(l, l) / (2 * m)\n",
    "    n = B.shape[0]\n",
    "\n",
    "    S = np.random.randn(n, k)  # Random Initialization\n",
    "    S, _ = np.linalg.qr(S)  # Ensure initial orthonormality\n",
    "\n",
    "    for i in tqdm(range(iterations), desc=\"Gradient Ascent Progress\"):\n",
    "        grad = (1 / (2 * m)) * B @ S\n",
    "        S += eta * grad\n",
    "        S, _ = np.linalg.qr(S)  # Orthonormalize using QR decomposition\n",
    "\n",
    "    return S\n",
    "    \n",
    "def perform_labeled_random_walks(G, label_mask, labels, num_walks, walk_length, walk_length_labelled=3):\n",
    "    walks = {node: [] for node in G.nodes()}\n",
    "    for node in G.nodes():\n",
    "        for _ in range(num_walks):\n",
    "            walk = [node]\n",
    "            labeled_count = 0\n",
    "            for _ in range(walk_length - 1):\n",
    "                cur = walk[-1]\n",
    "                neighbors = list(G.neighbors(cur))\n",
    "                if not neighbors:\n",
    "                    break\n",
    "                labeled_neighbors = [n for n in neighbors if label_mask[n]]\n",
    "                if labeled_neighbors and labeled_count < walk_length_labelled:\n",
    "                    next_node = random.choice(labeled_neighbors)\n",
    "                    labeled_count += 1\n",
    "                else:\n",
    "                    next_node = random.choice(neighbors)\n",
    "                walk.append(next_node)\n",
    "            walks[node].extend([n for n in walk if label_mask[n]])\n",
    "    return walks\n",
    "\n",
    "def compute_attention_weights(S, labeled_nodes):\n",
    "    weights = {}\n",
    "    for node, labeled in labeled_nodes.items():\n",
    "        if labeled:\n",
    "            similarities = {n: np.dot(S[node], S[n]) for n in labeled}\n",
    "            exp_sims = {n: np.exp(sim) for n, sim in similarities.items()}\n",
    "            total = sum(exp_sims.values())\n",
    "            weights[node] = {n: exp_sims[n] / total for n in labeled}\n",
    "    return weights\n",
    "\n",
    "def semi_supervised_gradient_ascent_modularity(G, labels, label_mask, k=2, eta=0.01, lambda_supervised=1.0, \n",
    "                                                      lambda_semi=2.0, iterations=5000, initialization='random',\n",
    "                                                      num_walks=10, walk_length=5, walk_length_labelled=3):\n",
    "    # Convert graph to sparse adjacency matrix\n",
    "    A = csr_matrix(nx.to_scipy_sparse_array(G, format='csr'))\n",
    "    degrees = np.array(A.sum(axis=1)).flatten()\n",
    "    m = G.number_of_edges()\n",
    "    n = A.shape[0]\n",
    "\n",
    "    # Initialize embeddings\n",
    "    if initialization == 'random':\n",
    "        S = np.random.randn(n, k)\n",
    "    S, _ = np.linalg.qr(S)\n",
    "\n",
    "    # Compute labeled random walks and attention weights\n",
    "    labeled_walks = perform_labeled_random_walks(G, label_mask, labels, num_walks, walk_length, walk_length_labelled)\n",
    "    attention_weights = compute_attention_weights(S, labeled_walks)\n",
    "\n",
    "    for _ in tqdm(range(iterations), desc=\"Gradient Ascent with Linear Modularity\"):\n",
    "        # Compute modularity gradient using linear approximation\n",
    "        neighbor_agg = A @ S  # Efficient aggregation of neighbor embeddings\n",
    "        global_correction = (degrees[:, None] / (2 * m)) * S.sum(axis=0)\n",
    "        grad_modularity = (1 / (2 * m)) * (neighbor_agg - global_correction)\n",
    "\n",
    "        # Compute supervised gradient\n",
    "        grad_supervised = np.zeros_like(S)\n",
    "        unique_labels = np.unique(labels[label_mask])\n",
    "        for label in unique_labels:\n",
    "            mask = (labels == label) & label_mask\n",
    "            mean_embedding = np.mean(S[mask], axis=0, keepdims=True)\n",
    "            grad_supervised[mask] = S[mask] - mean_embedding\n",
    "\n",
    "        # Compute semi-supervised gradient using adaptive attention\n",
    "        grad_semi_supervised = np.zeros_like(S)\n",
    "        for i in range(n):\n",
    "            if not label_mask[i] and i in attention_weights:\n",
    "                weighted_embedding = sum(weight * S[n] for n, weight in attention_weights[i].items())\n",
    "                grad_semi_supervised[i] = S[i] - weighted_embedding\n",
    "\n",
    "        # Update embeddings\n",
    "        grad_total = grad_modularity - lambda_supervised * grad_supervised - lambda_semi * grad_semi_supervised\n",
    "        S += eta * grad_total\n",
    "        S, _ = np.linalg.qr(S)\n",
    "\n",
    "    return S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1fa1ef-526c-4ecb-a4e5-6c344f0412c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_networkx(A):\n",
    "    return nx.from_scipy_sparse_array(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fbba5b-3f7c-4821-9af5-d9d6e982b9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = PubMedDataset()\n",
    "ground_truth_labels = dataset[0].y\n",
    "labels=np.argmax(ground_truth_labels,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92de7a7-bab8-4ee8-9601-e4f95f0386e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_to_be_masked=np.random.choice(np.arange(len(labels)),int(len(labels)*.7),replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0cf50e3-56e5-4187-b4ef-9f05cd3bc76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_labels=[]\n",
    "for i in np.arange(len(labels)):\n",
    "    if i in labels_to_be_masked:\n",
    "        masked_labels.append(-1)\n",
    "    else:\n",
    "        masked_labels.append(labels[i])\n",
    "masked_labels=np.array(masked_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77562c2a-8c15-4b37-a4d3-674776a0edf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_mask = masked_labels != -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "517544ea-62e6-495c-bce9-6733a551c959",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = dataset[0].x\n",
    "A = dataset[0].a\n",
    "G = convert_to_networkx(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17de4d41-86a4-40c3-bf4d-217511591a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Adjacency Matrix Shape:\", A.shape)\n",
    "print(\"Graph Nodes:\", G.number_of_nodes())\n",
    "print(\"Graph Edges:\", G.number_of_edges())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0346f826-0826-4b89-8328-054b3295fdc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert your preprocessed data into a PyTorch Geometric Data object\n",
    "X_py = Data(\n",
    "    x=torch.tensor(X, dtype=torch.float),  # Node features\n",
    "    edge_index=torch.tensor(np.array(A.nonzero()), dtype=torch.long),  # Edge indices\n",
    "    y=torch.tensor(labels, dtype=torch.long)  # Labels\n",
    ")\n",
    "\n",
    "# Ensure edge_index is in the correct shape (2, num_edges)\n",
    "X_py.edge_index = X_py.edge_index.to(torch.long)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81b4f8f4-20eb-4804-a2e7-b04c6e80d08d",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cada09-7de5-438e-b980-7d5e703615d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "# Dictionary for embeddings\n",
    "embedding_dict = {}\n",
    "execution_times = []  # List to store execution times\n",
    "\n",
    "def record_time(model_name, func, *args, **kwargs):\n",
    "    print(f\"Computing {model_name} embedding...\")\n",
    "    start_time = time.time()\n",
    "    result = func(*args, **kwargs)\n",
    "    end_time = time.time()\n",
    "    elapsed_time = end_time - start_time\n",
    "    print(f\"{model_name} embedding computed in {elapsed_time:.2f} seconds.\")\n",
    "    return result, elapsed_time\n",
    "\n",
    "def to_tf_tensor(x):\n",
    "    if isinstance(x, torch.Tensor):\n",
    "        x = x.detach().cpu().numpy()\n",
    "    return tf.convert_to_tensor(x, dtype=tf.float32)\n",
    "\n",
    "# DeepWalk\n",
    "X_deepwalk, elapsed = record_time(\"DeepWalk\", deepwalk_embedding, G, k=embedding_dimensionality)\n",
    "embedding_dict['deepwalk'] = to_tf_tensor(X_deepwalk)\n",
    "execution_times.append((\"DeepWalk\", elapsed))\n",
    "\n",
    "# VGAE\n",
    "X_vgae, elapsed = record_time(\"VGAE\", vgae_embedding, X_py, k=embedding_dimensionality)\n",
    "embedding_dict['vgae'] = to_tf_tensor(X_vgae)\n",
    "execution_times.append((\"VGAE\", elapsed))\n",
    "\n",
    "# DGI\n",
    "X_dgi, elapsed = record_time(\"DGI\", dgi_embedding, X_py, k=embedding_dimensionality)\n",
    "embedding_dict['dgi'] = to_tf_tensor(X_dgi)\n",
    "execution_times.append((\"DGI\", elapsed))\n",
    "\n",
    "X_modularity, elapsed = record_time(\"Modularity\", semi_supervised_gradient_ascent_modularity,\n",
    "                           G, labels, label_mask, k=embedding_dimensionality,\n",
    "                           eta=0.05, lambda_supervised=1.0, lambda_semi=2.0, iterations=200, initialization='random')\n",
    "embedding_dict['modularity'] = to_tf_tensor(X_modularity)\n",
    "execution_times.append((\"Modularity\", elapsed))\n",
    "\n",
    "# Node2Vec\n",
    "X_node2vec, elapsed = record_time(\"Node2Vec\", node2vec_embedding, G, k=embedding_dimensionality)\n",
    "embedding_dict['node2vec'] = to_tf_tensor(X_node2vec)\n",
    "execution_times.append((\"Node2Vec\", elapsed))\n",
    "\n",
    "# Random\n",
    "print(\"Generating Random embedding...\")\n",
    "start_time = time.time()\n",
    "shape = (len(ground_truth_labels), embedding_dimensionality)\n",
    "X_random = np.random.randn(*shape)\n",
    "embedding_dict['random'] = tf.convert_to_tensor(X_random, dtype=tf.float32)\n",
    "end_time = time.time()\n",
    "execution_times.append((\"Random\", end_time - start_time))\n",
    "print(f\"Random embedding generated in {end_time - start_time:.2f} seconds.\")\n",
    "\n",
    "# Given (original features)\n",
    "embedding_dict['given'] = to_tf_tensor(X)\n",
    "\n",
    "print(\"All embeddings computed and stored in the dictionary successfully.\")\n",
    "\n",
    "# Save execution times\n",
    "execution_df = pd.DataFrame(execution_times, columns=[\"Model\", \"Time (seconds)\"])\n",
    "filename = f\"./pubmed_analysis_results/embedding_execution_times_FUSE_pubmed_{SEED}.csv\"\n",
    "execution_df.to_csv(filename, index=False)\n",
    "print(f\"\\nExecution times saved to: {filename}\")\n",
    "print(execution_df)\n",
    "\n",
    "# Directory to save embeddings\n",
    "save_dir = f\"./pubmed_analysis_results/embeddings_pubmed_{SEED}\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Save embeddings as pickle files\n",
    "for name, emb in embedding_dict.items():\n",
    "    save_path = os.path.join(save_dir, f\"{name}_embedding.pkl\")\n",
    "    with open(save_path, \"wb\") as f:\n",
    "        pickle.dump(emb.numpy(), f)  # store as numpy array\n",
    "    print(f\"Saved {name} embedding to {save_path}\")\n",
    "\n",
    "print(\"\\n All embeddings saved as pickle files successfully.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a98e7ea-fd83-440b-a479-bed7c46185a4",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7297774-e5c4-4b87-9cc9-cab07deb7978",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_all_embeddings(all_embeddings, labels, label_mask):\n",
    "    \"\"\"\n",
    "    Visualize all embeddings in a grid with 4 columns per row using UMAP.\n",
    "\n",
    "    Parameters:\n",
    "    - all_embeddings: Dictionary where keys are embedding methods, and values are embeddings.\n",
    "    - labels: Labels (numpy array of shape [n_nodes]).\n",
    "    - label_mask: Boolean array indicating known labels (True for known, False for unknown).\n",
    "    \"\"\"\n",
    "    num_embeddings = len(all_embeddings)\n",
    "    num_rows = (num_embeddings + 3) // 4  # Ensure enough rows for all embeddings\n",
    "    fig, axes = plt.subplots(num_rows, 4, figsize=(8.27, 11.69))  # A4 size\n",
    "\n",
    "    for i, (embedding_type, embedding) in tqdm(enumerate(all_embeddings.items()), \n",
    "                                               total=num_embeddings, desc=\"Visualizing embeddings\"):\n",
    "        row, col = divmod(i, 4)\n",
    "        ax = axes[row, col] if num_rows > 1 else axes[col]  # Adjust for single-row case\n",
    "\n",
    "        # Ensure embedding is a NumPy array\n",
    "        if isinstance(embedding, tf.Tensor):\n",
    "            embedding = embedding.numpy()\n",
    "\n",
    "        # Reduce dimensionality using UMAP\n",
    "        reducer = umap.UMAP(n_components=2)\n",
    "        embedding_2d = reducer.fit_transform(embedding)\n",
    "\n",
    "        # Known labels\n",
    "        ax.scatter(embedding_2d[label_mask, 0], embedding_2d[label_mask, 1], \n",
    "                   c=labels[label_mask], cmap=\"Set1\", s=3, alpha=0.7, label=\"Known Labels\",\n",
    "                   edgecolors='none')\n",
    "\n",
    "        # Unknown labels\n",
    "        ax.scatter(embedding_2d[~label_mask, 0], embedding_2d[~label_mask, 1], \n",
    "                   c=labels[~label_mask], cmap=\"Set1\", s=5, alpha=0.7, \n",
    "                   label=\"Unknown Labels\", edgecolors='black', linewidths=0.2)\n",
    "\n",
    "        # Title with smaller font size\n",
    "        ax.set_title(embedding_type.upper(), fontsize=8, pad=2)\n",
    "\n",
    "        # Remove axis labels, ticks, and frames\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_frame_on(False)\n",
    "\n",
    "    # Remove empty subplots if num_embeddings is not a multiple of 4\n",
    "    for j in range(i + 1, num_rows * 4):\n",
    "        row, col = divmod(j, 4)\n",
    "        fig.delaxes(axes[row, col])\n",
    "\n",
    "    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.2, hspace=0.2)  # Adjust margins\n",
    "    save_path = \"./pubmed_analysis_results/embedding_grid_plot_pubmed.png\"\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Visualization saved to {save_path}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73add6f8-ad97-49f5-8f66-88df3b717390",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(true_labels, predicted_labels):\n",
    "    \"\"\"\n",
    "    Evaluate the model's performance using accuracy, F1-score, and confusion matrix.\n",
    "\n",
    "    Args:\n",
    "        true_labels (np.array): Ground truth labels (integers).\n",
    "        predicted_labels (np.array): Predicted labels (integers).\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing accuracy, F1-score, and confusion matrix.\n",
    "    \"\"\"\n",
    "    # Compute accuracy\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    \n",
    "    # Compute F1-score (macro-averaged)\n",
    "    f1 = f1_score(true_labels, predicted_labels, average='macro')\n",
    "    \n",
    "    # Compute confusion matrix\n",
    "    cm = confusion_matrix(true_labels, predicted_labels)\n",
    "\n",
    "    #\n",
    "    print(cm)\n",
    "    \n",
    "    # Return results as a dictionary\n",
    "    return {\n",
    "        'accuracy': accuracy,\n",
    "        'f1_score': f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc2b387-1e6d-49a3-b28f-23f59c1c789c",
   "metadata": {},
   "source": [
    "## Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f934d79-76c2-4d0e-81f4-486a1dc34d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoMaskGCNConv(GCNConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard mask\n",
    "        return super().call(inputs, mask=None)\n",
    "        \n",
    "class GCN(tf.keras.Model):\n",
    "    def __init__(self, n_labels, seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "        self.conv1 = NoMaskGCNConv(16, activation='relu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGCNConv(n_labels, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c02975-3867-49ee-bfd3-f74cd6d8dbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spektral.layers import GATConv\n",
    "import tensorflow as tf\n",
    "\n",
    "# Define a custom wrapper for GATConv that avoids mask issues\n",
    "class NoMaskGATConv(GATConv):\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return None\n",
    "\n",
    "    def call(self, inputs, training=None, mask=None):\n",
    "        # Explicitly discard the mask argument\n",
    "        return super().call(inputs, mask=None)\n",
    "\n",
    "\n",
    "# Define the GAT model using the NoMaskGATConv\n",
    "class GAT(tf.keras.Model):\n",
    "    def __init__(self, n_labels, num_heads=8, seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        # Use the custom NoMaskGATConv instead of the original GATConv\n",
    "        self.conv1 = NoMaskGATConv(16, attn_heads=num_heads, concat_heads=True, activation='elu', kernel_initializer=initializer)\n",
    "        self.conv2 = NoMaskGATConv(n_labels, attn_heads=1, concat_heads=False, activation='softmax', kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1805e34a-bed9-410b-a719-49a9df09c377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the GraphSAGE model\n",
    "class GraphSAGE(tf.keras.Model):\n",
    "    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=42):\n",
    "        super().__init__()\n",
    "        initializer = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "\n",
    "        self.conv1 = GraphSageConv(hidden_dim, activation='relu', aggregator=aggregator, kernel_initializer=initializer)\n",
    "        self.conv2 = GraphSageConv(n_labels, activation='softmax', aggregator=aggregator, kernel_initializer=initializer)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, a = inputs\n",
    "        intermediate_embeddings = self.conv1([x, a])  # Store intermediate embeddings\n",
    "        x = self.conv2([intermediate_embeddings, a])\n",
    "        return x, intermediate_embeddings  # Return both final output and intermediate embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7378d826-fa06-4a7d-8891-f46888320d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifiers=['gcn','gat','graphsage']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad91add-83d2-4b32-a97b-a7b48487c7fe",
   "metadata": {},
   "source": [
    "## Classification using different node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f36b9cb-be80-465d-852e-ded40e6f7eaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_sparse_tensor(adjacency):\n",
    "    \"\"\"Convert a SciPy sparse matrix into a reordered tf.sparse.SparseTensor.\"\"\"\n",
    "    coo = adjacency.tocoo()\n",
    "    # SparseTensor takes an (nnz, 2) int64 array of [row, col] index pairs\n",
    "    indices = np.column_stack((coo.row, coo.col)).astype(np.int64)\n",
    "    sparse_tensor = tf.sparse.SparseTensor(indices=indices, values=coo.data, dense_shape=coo.shape)\n",
    "    # Put the indices into the canonical row-major order sparse ops rely on\n",
    "    return tf.sparse.reorder(sparse_tensor)\n",
    "\n",
    "\n",
    "def train_and_evaluate(embedding_dict, embedding, classifier, ground_truth_labels=ground_truth_labels, masked_labels=masked_labels):\n",
    "    \"\"\"\n",
    "    Train a GNN classifier on one node embedding and score it on the masked nodes.\n",
    "\n",
    "    The model is fitted on the subgraph induced by the training nodes and is then\n",
    "    applied to the full graph. Besides its arguments, the function reads the\n",
    "    notebook-level names `A` (sparse adjacency matrix), `labels`,\n",
    "    `labels_to_be_masked` and `evaluate_model`.\n",
    "\n",
    "    Parameters:\n",
    "    - embedding_dict (dict): Maps an embedding name to its node-feature matrix.\n",
    "    - embedding (str): Key in embedding_dict of the features to use.\n",
    "    - classifier (str): Model to train: 'gcn', 'gat' or 'graphsage'.\n",
    "    - ground_truth_labels: One-hot encoded labels, one row per node.\n",
    "    - masked_labels: Per-node labels in which -1 marks a node excluded from training.\n",
    "\n",
    "    Returns:\n",
    "    - tuple: (results, emb). results is the dict returned by evaluate_model with\n",
    "      'model' and 'embedding' keys added; emb is the second output of the model\n",
    "      (its intermediate embeddings) for the full graph.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: If classifier is not one of the supported model names.\n",
    "    \"\"\"\n",
    "    # Fail fast on an unsupported name instead of hitting a NameError on `model` later\n",
    "    model_classes = {'gcn': GCN, 'gat': GAT, 'graphsage': GraphSAGE}\n",
    "    if classifier not in model_classes:\n",
    "        raise ValueError(f\"Unknown classifier {classifier!r}; expected one of {sorted(model_classes)}\")\n",
    "\n",
    "    print('embedding: ' + embedding.upper())\n",
    "    print('model: ' + classifier.upper())\n",
    "\n",
    "    X = embedding_dict[embedding]\n",
    "\n",
    "    print(\"Processing...\")\n",
    "    # Nodes whose masked label is not -1 are available for training\n",
    "    train_mask = masked_labels != -1\n",
    "\n",
    "    # Restrict features, labels and adjacency to the training nodes\n",
    "    X_train = X[train_mask]\n",
    "    Y_train = tf.cast(ground_truth_labels[train_mask], dtype='int32')\n",
    "    A_train_tensor = _to_sparse_tensor(A[train_mask, :][:, train_mask])\n",
    "\n",
    "    print(\"Training...\")\n",
    "    n_labels = ground_truth_labels.shape[1]  # Number of classes\n",
    "    model = model_classes[classifier](n_labels)\n",
    "\n",
    "    # One optimizer/loss pair, shared by compile() and the custom training loop\n",
    "    optimizer = Adam(learning_rate=0.01)\n",
    "    loss_fn = CategoricalCrossentropy()\n",
    "\n",
    "    # Compile the model (not strictly necessary when using GradientTape, but useful for metrics)\n",
    "    model.compile(optimizer=optimizer, loss=loss_fn, metrics=[CategoricalAccuracy()])\n",
    "\n",
    "    # Print shapes for debugging\n",
    "    print(f\"Shape of X_train: {X_train.shape}\")\n",
    "    print(f\"Shape of A_train_tensor: {A_train_tensor.shape}\")\n",
    "    print(f\"Shape of Y_train: {Y_train.shape}\")\n",
    "\n",
    "    # Training loop with GradientTape\n",
    "    epochs = 200\n",
    "    for epoch in range(epochs):\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Forward pass; the model returns (class scores, intermediate embeddings)\n",
    "            predictions, intermediate_embeddings = model([X_train, A_train_tensor])\n",
    "\n",
    "            # Supervised cross-entropy loss on the training nodes\n",
    "            supervised_loss = loss_fn(Y_train, predictions)\n",
    "\n",
    "        # Back-propagate and update the weights\n",
    "        gradients = tape.gradient(supervised_loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "        # Report loss and training accuracy every 10 epochs\n",
    "        if epoch % 10 == 0:\n",
    "            accuracy = CategoricalAccuracy()(Y_train, predictions)\n",
    "            print(f\"Epoch {epoch + 1}, Loss: {supervised_loss.numpy()}, Accuracy: {accuracy.numpy()}\")\n",
    "\n",
    "    print(\"Predicting...\")\n",
    "    # Run the trained model on the full graph (all nodes, full adjacency matrix)\n",
    "    A_full_tensor = _to_sparse_tensor(A)\n",
    "    predictions, emb = model([X, A_full_tensor])  # Shape: [num_nodes, n_labels]\n",
    "\n",
    "    # Convert predictions to class labels (integers)\n",
    "    predicted_labels = tf.argmax(predictions, axis=1).numpy()  # Shape: [num_nodes]\n",
    "\n",
    "    # Compare predicted and true labels on the masked nodes only\n",
    "    predicted_labels_masked = predicted_labels[labels_to_be_masked]\n",
    "    true_labels_masked = labels[labels_to_be_masked]\n",
    "    results = evaluate_model(true_labels_masked, predicted_labels_masked)\n",
    "\n",
    "    # Print the results\n",
    "    print(f\"Accuracy: {results['accuracy'] * 100:.2f}%\")\n",
    "    print(f\"F1-Score: {results['f1_score']:.4f}\")\n",
    "\n",
    "    results['model'] = classifier\n",
    "    results['embedding'] = embedding\n",
    "\n",
    "    # Return results and intermediate embeddings for visualization\n",
    "    return results, emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b95ec7-e23c-4d6e-8157-a2264ef0c7a8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train every classifier on every embedding, keeping the metrics of each run and\n",
    "# the embedding matrix it returns under the key '<embedding> with <classifier>'.\n",
    "all_results = []\n",
    "graph_embeddings_dict = {}\n",
    "for embedding_name in embedding_dict:\n",
    "    for classifier_name in classifiers:\n",
    "        run_results, run_embeddings = train_and_evaluate(embedding_dict, embedding_name, classifier_name)\n",
    "        all_results.append(run_results)\n",
    "        graph_embeddings_dict[f'{embedding_name} with {classifier_name}'] = run_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17527c3b-ef18-4b1a-abb6-48bcea7dc4a3",
   "metadata": {},
   "source": [
    "## Saving aggregate results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa504238-b2df-4efd-b68c-90532d749d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to DataFrame\n",
    "df = pd.DataFrame(all_results)\n",
    "\n",
    "# Define dataset name and seed\n",
    "dataset_name = \"pubmed\"\n",
    "seed_value = SEED\n",
    "\n",
    "# Make sure the output folder exists: to_csv does not create missing directories\n",
    "results_dir = './pubmed_analysis_results/'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "# Save as CSV file without sorting\n",
    "filename = results_dir + f\"{dataset_name}_seed{seed_value}_results.csv\"\n",
    "df.to_csv(filename, index=False)\n",
    "\n",
    "print(f\"Results saved as {filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9875500-5fdf-4d56-9a05-5d5d3cbceec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the input embeddings with the embeddings returned by the classifier runs\n",
    "# (keyed '<embedding> with <classifier>'); on a duplicate key the right operand wins.\n",
    "all_embeddings= embedding_dict | graph_embeddings_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "602c48bd-1f98-47c9-993a-10ec0ec2966d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reorder_dict(original_dict, key_order):\n",
    "    \"\"\"\n",
    "    Reorders a dictionary based on a given list of keys.\n",
    "\n",
    "    Parameters:\n",
    "    - original_dict (dict): The dictionary to reorder.\n",
    "    - key_order (list): The list specifying the desired key order.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A new dictionary with keys ordered as per key_order.\n",
    "    \"\"\"\n",
    "    return {key: original_dict[key] for key in key_order if key in original_dict}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0939b3-7a5a-4149-9130-86f5d6ca5141",
   "metadata": {},
   "outputs": [],
   "source": [
    "key_order = ['random', 'random with gcn', 'random with gat', 'random with graphsage', 'deepwalk', 'deepwalk with gcn', 'deepwalk with gat', 'deepwalk with graphsage', 'node2vec','node2vec with gcn', 'node2vec with gat', 'node2vec with graphsage', 'vgae', 'vgae with gcn', 'vgae with gat', 'vgae with graphsage', 'dgi', 'dgi with gcn', 'dgi with gat', 'dgi with graphsage', 'modularity', 'modularity with gcn', 'modularity with gat', 'modularity with graphsage', 'given', 'given with gcn', 'given with gat', 'given with graphsage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267d6703-3449-4e03-a5ed-105a5eba4eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rearrange the merged embeddings to follow key_order; entries whose key is not\n",
    "# listed in key_order are left out of the result.\n",
    "all_embeddings = reorder_dict(all_embeddings, key_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9db961-2936-4f72-b644-8282bf37c445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualise the reordered embeddings together with the node labels and label mask.\n",
    "# NOTE(review): visualize_all_embeddings is defined outside this section; presumably\n",
    "# it draws one plot per entry of all_embeddings - confirm against its definition.\n",
    "visualize_all_embeddings(all_embeddings, labels, label_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feadb072-eb24-4b56-95b7-9973fa46566c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import optuna\n",
    "\n",
    "def objective(trial):\n",
    "    \"\"\"\n",
    "    Optuna objective: build a modularity embedding and score it with GraphSAGE.\n",
    "\n",
    "    Samples the embedding hyperparameters from the trial, generates the embedding\n",
    "    with semi_supervised_gradient_ascent_modularity, trains a GraphSAGE classifier\n",
    "    on it through train_and_evaluate and records the elapsed wall-clock time\n",
    "    (embedding generation plus training) in the trial's \"runtime\" user attribute.\n",
    "\n",
    "    Parameters:\n",
    "    - trial (optuna.trial.Trial): Trial used to sample the hyperparameters.\n",
    "\n",
    "    Returns:\n",
    "    - The 'accuracy' entry of the train_and_evaluate results, to be maximized.\n",
    "\n",
    "    NOTE(review): train_and_evaluate scores the masked nodes, so the value tuned\n",
    "    here comes from the same nodes used for the reported results - confirm this\n",
    "    is intended.\n",
    "    \"\"\"\n",
    "    # Hyperparameter search space\n",
    "    eta = trial.suggest_float(\"eta\", 0.01, 1.0, step=0.01)\n",
    "    lambda_supervised = trial.suggest_float(\"lambda_supervised\", 0.5, 2.5, step=0.1)\n",
    "    lambda_semi = trial.suggest_float(\"lambda_semi\", 0.5, 2.5, step=0.1)\n",
    "    num_walks = trial.suggest_int(\"num_walks\", 5, 20, step=1)        # r\n",
    "    walk_length = trial.suggest_int(\"walk_length\", 1, 10, step=1)    # L\n",
    "    walk_length_labelled = trial.suggest_int(\"walk_length_labelled\", 1, 10, step=1)  # L'\n",
    "\n",
    "    # Measure runtime\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Generate embeddings with given lambdas\n",
    "    S = semi_supervised_gradient_ascent_modularity(\n",
    "        G,\n",
    "        labels=labels,\n",
    "        label_mask=label_mask,\n",
    "        k=150,\n",
    "        eta=eta,\n",
    "        lambda_supervised=lambda_supervised,\n",
    "        lambda_semi=lambda_semi,\n",
    "        iterations=200,\n",
    "        num_walks=num_walks,\n",
    "        walk_length=walk_length,\n",
    "        walk_length_labelled=walk_length_labelled\n",
    "    )\n",
    "\n",
    "    # Hand the trial's embedding over in its own dict so the notebook-level\n",
    "    # embedding_dict keeps its \"modularity\" entry instead of the last trial's\n",
    "    trial_embeddings = {\"modularity\": S}\n",
    "\n",
    "    # Only evaluate with GraphSAGE\n",
    "    results, _ = train_and_evaluate(trial_embeddings, \"modularity\", \"graphsage\")\n",
    "\n",
    "    # Compute runtime\n",
    "    runtime = time.time() - start_time\n",
    "\n",
    "    # Save runtime in trial's user attrs\n",
    "    trial.set_user_attr(\"runtime\", runtime)\n",
    "\n",
    "    # Return accuracy to maximize\n",
    "    return results[\"accuracy\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e92ea2a7-2231-411a-b1aa-1e2409231cfe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Create and run study\n",
    "# Seed the sampler so that, for identical objective values, the sequence of\n",
    "# suggested hyperparameters is the same on every run of the notebook\n",
    "study = optuna.create_study(direction=\"maximize\", sampler=optuna.samplers.TPESampler(seed=SEED))\n",
    "study.optimize(objective, n_trials=50)   # adjust n_trials as needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61bb1969-8568-4162-b6c1-f62abc55232c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the best trial found by the study\n",
    "best_trial = study.best_trial\n",
    "best_acc = best_trial.value\n",
    "print(\"Best Accuracy:\", best_acc)\n",
    "print(\"Best Params:\", best_trial.params)\n",
    "print(\"Time for Best Trial:\", best_trial.user_attrs[\"runtime\"], \"seconds\")\n",
    "\n",
    "# Gather (trial number, accuracy, runtime) for every trial that produced a value\n",
    "accuracies = []\n",
    "for t in study.trials:\n",
    "    if t.value is None:\n",
    "        continue\n",
    "    accuracies.append((t.number, t.value, t.user_attrs[\"runtime\"]))\n",
    "\n",
    "# Order the candidates from fastest to slowest\n",
    "accuracies_sorted = sorted(accuracies, key=lambda entry: entry[2])\n",
    "\n",
    "# Keep the trials whose accuracy is at least 95% of the best one\n",
    "good_acc_threshold = 0.95 * best_acc\n",
    "fastest_good = list(filter(lambda entry: entry[1] >= good_acc_threshold, accuracies_sorted))\n",
    "\n",
    "# The first remaining entry is the fastest trial that is still accurate enough\n",
    "if len(fastest_good) > 0:\n",
    "    trial_id, acc, runtime = fastest_good[0]\n",
    "    print(f\"Fastest trial with ≥95% of best acc: Trial {trial_id}, Acc={acc:.4f}, Time={runtime:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37a313d9-ed5d-410f-8621-e1a4c0405e07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.io as pio\n",
    "pio.renderers.default = \"notebook\"   # keep for interactive notebook display\n",
    "\n",
    "import optuna\n",
    "from optuna.visualization import (\n",
    "    plot_optimization_history,\n",
    "    plot_parallel_coordinate,\n",
    "    plot_slice,\n",
    "    plot_contour,\n",
    "    plot_param_importances\n",
    ")\n",
    "\n",
    "# Create output folder for saving plots\n",
    "import os\n",
    "output_dir = \"./pubmed_analysis_results/optuna_plots\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Output file stem -> Optuna plotting function, in display order\n",
    "plot_builders = {\n",
    "    \"optimization_history\": plot_optimization_history,\n",
    "    \"parallel_coordinates\": plot_parallel_coordinate,\n",
    "    \"slice_plot\": plot_slice,\n",
    "    \"contour_plot\": plot_contour,\n",
    "    \"param_importances\": plot_param_importances,\n",
    "}\n",
    "\n",
    "# Show each figure, then save it both as an HTML file and as a PNG image;\n",
    "# the figures stay available afterwards in `figures`, keyed by file stem\n",
    "figures = {}\n",
    "for plot_name, build_plot in plot_builders.items():\n",
    "    fig = build_plot(study)\n",
    "    fig.show()\n",
    "    fig.write_html(f\"{output_dir}/{plot_name}.html\")\n",
    "    fig.write_image(f\"{output_dir}/{plot_name}.png\")\n",
    "    figures[plot_name] = fig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3235ab-a0d8-4f02-9828-2438bd21a68f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "\n",
    "# Tabulate every trial together with its start and completion timestamps\n",
    "df = study.trials_dataframe(attrs=(\"number\", \"value\", \"params\", \"datetime_start\", \"datetime_complete\"))\n",
    "print(df.columns)\n",
    "\n",
    "# Duration of each trial in seconds, from its start/complete timestamps\n",
    "trial_duration = df[\"datetime_complete\"] - df[\"datetime_start\"]\n",
    "df[\"elapsed_time\"] = trial_duration.dt.total_seconds()\n",
    "\n",
    "# Rank the trials from highest to lowest objective value\n",
    "df_sorted = df.sort_values(by=\"value\", ascending=False)\n",
    "\n",
    "# Hyperparameter columns are the ones whose name starts with \"params_\"\n",
    "param_cols = []\n",
    "for column_name in df.columns:\n",
    "    if column_name.startswith(\"params_\"):\n",
    "        param_cols.append(column_name)\n",
    "\n",
    "# Print the ranked trials: number, value, elapsed time and hyperparameters\n",
    "cols_to_display = [\"number\", \"value\", \"elapsed_time\", *param_cols]\n",
    "print(df_sorted[cols_to_display])\n",
    "\n",
    "# Render the 2nd and 3rd best trials as a rich table\n",
    "best_2_and_3 = df_sorted.iloc[1:3][cols_to_display]\n",
    "display(best_2_and_3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40fd63cf-0b16-4813-9336-5d1950e1fb5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the rows ranked 2nd and 3rd by objective value\n",
    "second_best, third_best = df_sorted.iloc[1], df_sorted.iloc[2]\n",
    "\n",
    "# Print each of them under its own heading, restricted to the displayed columns\n",
    "for heading, trial_row in ((\"Second Best Trial:\", second_best), (\"\\nThird Best Trial:\", third_best)):\n",
    "    print(heading)\n",
    "    print(trial_row[cols_to_display])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a5553a-ee8f-43c5-810e-8bcd900f5665",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
