{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5d3d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import from_networkx\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from sklearn.ensemble import RandomForestClassifier\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1fd27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "file_path = base_path / \"groundtruth_graph_graphs_with_embedding.pickle\"\n",
    "with open(file_path, \"rb\") as f:\n",
    "    groundtruth_graph = pickle.load(f)\n",
    "    \n",
    "sample_id = list(groundtruth_graph.keys())[0]  \n",
    "G = groundtruth_graph[sample_id]\n",
    "\n",
    "print(f\"Sample ID: {sample_id}\")\n",
    "print(f\"Number of nodes: {G.number_of_nodes()}\")\n",
    "print(f\"Number of edges: {G.number_of_edges()}\")\n",
    "\n",
    "for node in list(G.nodes())[:20]: \n",
    "    feature = G.nodes[node].get(\"feature\", None)\n",
    "    print(f\"Node ID: {node}\")\n",
    "    if feature is None:\n",
    "        print(\"  Feature: None\")\n",
    "    elif isinstance(feature, np.ndarray):\n",
    "        print(f\"  Feature shape: {feature.shape}\")\n",
    "        print(f\"  Feature sample (first 5 dims): {feature[:5]}\")\n",
    "    else:\n",
    "        print(f\"  Feature type: {type(feature)}\")\n",
    "        print(f\"  Feature preview: {feature}\")\n",
    "\n",
    "file_path = base_path / \"random_graph_graphs_with_embedding.pickle\"\n",
    "with open(file_path, \"rb\") as f:\n",
    "    random_graph = pickle.load(f)\n",
    "    \n",
    "sample_id = list(random_graph.keys())[0]  \n",
    "G = random_graph[sample_id]\n",
    "\n",
    "print(f\"Sample ID: {sample_id}\")\n",
    "print(f\"Number of nodes: {G.number_of_nodes()}\")\n",
    "print(f\"Number of edges: {G.number_of_edges()}\")\n",
    "\n",
    "for node in list(G.nodes())[:20]: \n",
    "    feature = G.nodes[node].get(\"feature\", None)\n",
    "    print(f\"Node ID: {node}\")\n",
    "    if feature is None:\n",
    "        print(\"  Feature: None\")\n",
    "    elif isinstance(feature, np.ndarray):\n",
    "        print(f\"  Feature shape: {feature.shape}\")\n",
    "        print(f\"  Feature sample (first 5 dims): {feature[:5]}\")\n",
    "    else:\n",
    "        print(f\"  Feature type: {type(feature)}\")\n",
    "        print(f\"  Feature preview: {feature}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02d1336",
   "metadata": {},
   "outputs": [],
   "source": [
    "random_data_list = []\n",
    "for sample_id, G in random_graph.items():\n",
    "    features = []\n",
    "    nodes = list(G.nodes())\n",
    "    for n in nodes:\n",
    "        feat = G.nodes[n].get(\"feature\", np.zeros(3072, dtype=np.float32))\n",
    "        features.append(torch.tensor(feat, dtype=torch.float32))\n",
    "    x = torch.stack(features)\n",
    "    H = G.copy()\n",
    "    for n in nodes:\n",
    "        del H.nodes[n]['feature']\n",
    "    data = from_networkx(H)\n",
    "    data.x = x\n",
    "    data.paper_id = sample_id\n",
    "    data.y = torch.tensor([1]) \n",
    "    data.graph_type = \"Random\" \n",
    "    random_data_list.append(data)\n",
    "    \n",
    "actual_data_list = []\n",
    "for sample_id, G in groundtruth_graph.items():\n",
    "    features = []\n",
    "    nodes = list(G.nodes())\n",
    "    for n in nodes:\n",
    "        feat = G.nodes[n].get(\"feature\", np.zeros(3072, dtype=np.float32))\n",
    "        features.append(torch.tensor(feat, dtype=torch.float32))\n",
    "    x = torch.stack(features)\n",
    "    H = G.copy()\n",
    "    for n in nodes:\n",
    "        del H.nodes[n]['feature']\n",
    "    data = from_networkx(H)\n",
    "    data.x = x\n",
    "    data.paper_id = sample_id\n",
    "    data.y = torch.tensor([0]) \n",
    "    data.graph_type = \"groundtruth\" \n",
    "    actual_data_list.append(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843cd50f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "all_graphs = actual_data_list + random_data_list\n",
    "for g in all_graphs:\n",
    "    g.graph_id = g.paper_id  # set graph_id for grouping\n",
    "\n",
    "graph_id_groups = defaultdict(list)\n",
    "for data in all_graphs:\n",
    "    graph_id_groups[data.graph_id].append(data)\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "random_seed = 40\n",
    "\n",
    "train_ids, temp_ids = train_test_split(\n",
    "    all_graph_ids,\n",
    "    test_size=0.3,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "val_ids, test_ids = train_test_split(\n",
    "    temp_ids,\n",
    "    test_size=0.5,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "def collect_by_ids(graph_ids):\n",
    "    return [data for gid in graph_ids for data in graph_id_groups[gid]]\n",
    "\n",
    "train_dataset = collect_by_ids(train_ids)\n",
    "val_dataset   = collect_by_ids(val_ids)\n",
    "test_dataset  = collect_by_ids(test_ids)\n",
    "\n",
    "def clean_and_index(dataset_split):\n",
    "    for i, data in enumerate(dataset_split):\n",
    "        data.graph_index = i              # new 0…N-1 index\n",
    "        if hasattr(data, 'graph_id'):     # remove leak source\n",
    "            del data.graph_id\n",
    "        if hasattr(data, 'graph_type'):   # remove type attribute\n",
    "            del data.graph_type\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e49c48",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def inspect_first_n(dataset_split, split_name, n=6):\n",
    "\n",
    "    print(f\"\\n--- First {n} samples from {split_name.upper()} ---\")\n",
    "    for i, data in enumerate(dataset_split[:n]):\n",
    "        print(f\"\\nSample #{i} in {split_name}:\")\n",
    "        # print attribute list\n",
    "        print(\"  Keys            :\", data.keys())\n",
    "        # print shapes\n",
    "        print(\"  x.shape         :\", tuple(data.x.shape))\n",
    "        print(\"  edge_index.shape:\", tuple(data.edge_index.shape))\n",
    "        print(\"  y (label)       :\", data.y.item())\n",
    "        print(\"  graph_index     :\", data.graph_index)\n",
    "        \n",
    "        num_nf = min(50, data.x.size(0))\n",
    "        print(f\"  Sample node features (first {num_nf} rows):\")\n",
    "        print(data.x[:num_nf])\n",
    "        \n",
    "        num_e = min(50, data.edge_index.size(1))\n",
    "        edges = data.edge_index[:, :num_e].t().tolist()\n",
    "        print(f\"  Sample edges (first {num_e}):\", edges)\n",
    "\n",
    "# Call for each split\n",
    "inspect_first_n(train_dataset, \"train\", n=6)\n",
    "inspect_first_n(val_dataset,   \"val\",   n=6)\n",
    "inspect_first_n(test_dataset,  \"test\",  n=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6696acf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class0_feats = []\n",
    "class1_feats = []\n",
    "\n",
    "for data in train_dataset:\n",
    "    feat_mean = data.x.mean(dim=0).cpu().numpy()\n",
    "    if data.y.item() == 0:\n",
    "        class0_feats.append(feat_mean)\n",
    "    else:\n",
    "        class1_feats.append(feat_mean)\n",
    "\n",
    "class0_feats = np.vstack(class0_feats)\n",
    "class1_feats = np.vstack(class1_feats)\n",
    "\n",
    "print(\"Class 0 mean:\", class0_feats.mean(axis=0))\n",
    "print(\"Class 1 mean:\", class1_feats.mean(axis=0))\n",
    "print(\"Class 0 std:\", class0_feats.std(axis=0))\n",
    "print(\"Class 1 std:\", class1_feats.std(axis=0))\n",
    "print(\"class0_feats shape:\", class0_feats.shape)\n",
    "print(\"class1_feats shape:\", class1_feats.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f60ac6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Prepare data\n",
    "train_X = [data.x.mean(dim=0).numpy() for data in train_dataset]\n",
    "train_y = [data.y.item() for data in train_dataset]\n",
    "val_X   = [data.x.mean(dim=0).numpy() for data in val_dataset]\n",
    "val_y   = [data.y.item() for data in val_dataset]\n",
    "\n",
    "# Train Random Forest\n",
    "rf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "rf.fit(train_X, train_y)\n",
    "pred = rf.predict(val_X)\n",
    "print(\"Val accuracy (Random Forest, mean features):\", accuracy_score(val_y, pred))\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "print(classification_report(val_y, pred))\n",
    "print(confusion_matrix(val_y, pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8850060f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "N_RUNS = 10\n",
    "all_acc = []\n",
    "all_f1_macro = []\n",
    "all_f1_weighted = []\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "\n",
    "def collect_by_ids(graph_ids):\n",
    "    return [data for gid in graph_ids for data in graph_id_groups[gid]]\n",
    "\n",
    "def clean_and_index(dataset_split):\n",
    "    for i, data in enumerate(dataset_split):\n",
    "        data.graph_index = i\n",
    "        if hasattr(data, 'graph_id'):\n",
    "            del data.graph_id\n",
    "        if hasattr(data, 'graph_type'):\n",
    "            del data.graph_type\n",
    "\n",
    "for run_seed in range(N_RUNS):\n",
    "    # Split\n",
    "    train_ids, temp_ids = train_test_split(all_graph_ids, test_size=0.3, random_state=run_seed)\n",
    "    val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=run_seed)\n",
    "\n",
    "    train_dataset = collect_by_ids(train_ids)\n",
    "    val_dataset   = collect_by_ids(val_ids)\n",
    "\n",
    "    clean_and_index(train_dataset)\n",
    "    clean_and_index(val_dataset)\n",
    "\n",
    "    # Prepare data\n",
    "    train_X = [data.x.mean(dim=0).numpy() for data in train_dataset]\n",
    "    train_y = [data.y.item() for data in train_dataset]\n",
    "    val_X   = [data.x.mean(dim=0).numpy() for data in val_dataset]\n",
    "    val_y   = [data.y.item() for data in val_dataset]\n",
    "\n",
    "    rf = RandomForestClassifier(n_estimators=100, random_state=run_seed)\n",
    "    rf.fit(train_X, train_y)\n",
    "    pred = rf.predict(val_X)\n",
    "\n",
    "    acc = accuracy_score(val_y, pred)\n",
    "    f1_macro = f1_score(val_y, pred, average='macro')\n",
    "    f1_weighted = f1_score(val_y, pred, average='weighted')\n",
    "\n",
    "    all_acc.append(acc)\n",
    "    all_f1_macro.append(f1_macro)\n",
    "    all_f1_weighted.append(f1_weighted)\n",
    "\n",
    "print(f\"Random Forest (mean node features) - {N_RUNS} runs\")\n",
    "print(\"Val accuracy   : %.4f ± %.4f\" % (np.mean(all_acc), np.std(all_acc)))\n",
    "print(\"F1-macro       : %.4f ± %.4f\" % (np.mean(all_f1_macro), np.std(all_f1_macro)))\n",
    "print(\"F1-weighted    : %.4f ± %.4f\" % (np.mean(all_f1_weighted), np.std(all_f1_weighted)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f92cfb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X = [data.x.mean(dim=0).numpy() for data in test_dataset]\n",
    "test_y = [data.y.item() for data in test_dataset]\n",
    "pred_test = rf.predict(test_X)\n",
    "print(\"Test accuracy (Random Forest):\", accuracy_score(test_y, pred_test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd24c111",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "all_node_embeddings = []\n",
    "all_labels = [] \n",
    "\n",
    "# For actual graphs\n",
    "for g in actual_data_list:\n",
    "    all_node_embeddings.append(g.x.cpu().numpy())\n",
    "    all_labels.extend([0] * g.x.shape[0])\n",
    "\n",
    "# For graphs\n",
    "for g in random_data_list:\n",
    "    all_node_embeddings.append(g.x.cpu().numpy())\n",
    "    all_labels.extend([1] * g.x.shape[0])\n",
    "\n",
    "all_node_embeddings = np.vstack(all_node_embeddings)\n",
    "all_labels = np.array(all_labels)\n",
    "\n",
    "print(f\"Total node embeddings: {all_node_embeddings.shape}\")  # (total_nodes, 3072)\n",
    "\n",
    "# Project all embeddings to 2D using PCA\n",
    "pca = PCA(n_components=2, random_state=42)\n",
    "embeddings_2d = pca.fit_transform(all_node_embeddings)\n",
    "\n",
    "# Plot with color by label\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.scatter(\n",
    "    embeddings_2d[all_labels == 0, 0], embeddings_2d[all_labels == 0, 1],\n",
    "    c='blue', alpha=0.3, label='Actual Graphs', s=10\n",
    ")\n",
    "plt.scatter(\n",
    "    embeddings_2d[all_labels == 1, 0], embeddings_2d[all_labels == 1, 1],\n",
    "    c='red', alpha=0.3, label='random graphs', s=10\n",
    ")\n",
    "plt.xlabel(\"PCA 1\")\n",
    "plt.ylabel(\"PCA 2\")\n",
    "plt.legend()\n",
    "plt.title(\"All Node Embeddings (PCA 2D)\\nActual=Blue, GPT=Red\")\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "761527f1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def mean_pooled_X_y(split_dataset):\n",
    "    all_X = np.stack([data.x.mean(dim=0).cpu().numpy() for data in split_dataset])\n",
    "    all_y = np.array([data.y.item() for data in split_dataset])\n",
    "    return all_X, all_y\n",
    "\n",
    "splits = {\n",
    "    'Train': train_dataset,\n",
    "    'Validation': val_dataset,\n",
    "    'Test': test_dataset\n",
    "}\n",
    "\n",
    "for split_name, split_dataset in splits.items():\n",
    "    all_X, all_y = mean_pooled_X_y(split_dataset)\n",
    "    \n",
    "    if len(all_X) > 2:\n",
    "        pca = PCA(n_components=2, random_state=42)\n",
    "        mean_X_pca = pca.fit_transform(all_X)\n",
    "    else:\n",
    "        mean_X_pca = all_X  # fallback if only 1-2 samples\n",
    "\n",
    "    plt.figure(figsize=(7, 7))\n",
    "    plt.scatter(mean_X_pca[all_y==0,0], mean_X_pca[all_y==0,1], c='blue', label='Ground truth', alpha=0.6)\n",
    "    plt.scatter(mean_X_pca[all_y==1,0], mean_X_pca[all_y==1,1], c='red', label='Random', alpha=0.6)\n",
    "    plt.xlabel(\"PCA 1\")\n",
    "    plt.ylabel(\"PCA 2\")\n",
    "    plt.title(f\"PCA of Mean-Pooled Graph Embeddings: {split_name} Split\")\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614f8b28",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b10195",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
