{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed71663",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import from_networkx\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from sklearn.ensemble import RandomForestClassifier\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5d3d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def load_and_preview(filename, max_nodes=20):\n",
    "    \"\"\"Load a pickled mapping of sample id -> graph and print a preview of its first entry.\n",
    "\n",
    "    Prints the first graph's node and edge counts, then a short summary of the\n",
    "    ``\"feature\"`` attribute for up to ``max_nodes`` of its nodes.\n",
    "\n",
    "    Returns the loaded mapping.\n",
    "    \"\"\"\n",
    "    # NOTE(review): `base_path` is not assigned in any cell of this notebook; it has to\n",
    "    # exist (presumably a pathlib.Path, given the `/` operator) before this cell runs.\n",
    "    file_path = base_path / filename\n",
    "    # SECURITY: pickle.load can execute arbitrary code -- only open trusted files.\n",
    "    with open(file_path, \"rb\") as f:\n",
    "        graphs = pickle.load(f)\n",
    "\n",
    "    sample_id = list(graphs.keys())[0]\n",
    "    G = graphs[sample_id]\n",
    "\n",
    "    print(f\"Sample ID: {sample_id}\")\n",
    "    print(f\"Number of nodes: {G.number_of_nodes()}\")\n",
    "    print(f\"Number of edges: {G.number_of_edges()}\")\n",
    "\n",
    "    for node in list(G.nodes())[:max_nodes]:\n",
    "        feature = G.nodes[node].get(\"feature\", None)\n",
    "        print(f\"Node ID: {node}\")\n",
    "        if feature is None:\n",
    "            print(\"  Feature: None\")\n",
    "        elif isinstance(feature, np.ndarray):\n",
    "            print(f\"  Feature shape: {feature.shape}\")\n",
    "            print(f\"  Feature sample (first 5 dims): {feature[:5]}\")\n",
    "        else:\n",
    "            print(f\"  Feature type: {type(feature)}\")\n",
    "            print(f\"  Feature preview: {feature}\")\n",
    "\n",
    "    return graphs\n",
    "\n",
    "\n",
    "random_graph = load_and_preview(\"random_graph_graphs_with_embedding.pickle\")\n",
    "gpt_generated_graph = load_and_preview(\"gpt_generated_graph_graphs_with_embedding.pickle\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667ce1f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURE_DIM = 3072  # length of the zero vector substituted for a missing node feature\n",
    "\n",
    "\n",
    "def graphs_to_data_list(graphs, label, graph_type):\n",
    "    \"\"\"Convert a mapping of sample id -> networkx graph into a list of PyG ``Data`` objects.\n",
    "\n",
    "    The per-node ``\"feature\"`` attributes are stacked into ``data.x`` (float32); a node\n",
    "    whose feature is absent or ``None`` gets a zero vector of length ``FEATURE_DIM``.\n",
    "    Each ``Data`` also receives ``paper_id`` (the sample id), ``y`` (``label`` as a\n",
    "    one-element tensor) and ``graph_type``.\n",
    "    \"\"\"\n",
    "    data_list = []\n",
    "    for sample_id, G in graphs.items():\n",
    "        nodes = list(G.nodes())\n",
    "        features = []\n",
    "        for n in nodes:\n",
    "            feat = G.nodes[n].get(\"feature\")\n",
    "            if feat is None:\n",
    "                feat = np.zeros(FEATURE_DIM, dtype=np.float32)\n",
    "            features.append(torch.tensor(feat, dtype=torch.float32))\n",
    "        x = torch.stack(features)\n",
    "\n",
    "        # Strip the features from a copy before conversion; x is attached separately below.\n",
    "        H = G.copy()\n",
    "        for n in nodes:\n",
    "            # pop() rather than del: a node without the attribute was given the zero\n",
    "            # fallback above and must not raise KeyError here.\n",
    "            H.nodes[n].pop(\"feature\", None)\n",
    "\n",
    "        data = from_networkx(H)\n",
    "        data.x = x\n",
    "        data.paper_id = sample_id\n",
    "        data.y = torch.tensor([label])\n",
    "        data.graph_type = graph_type\n",
    "        data_list.append(data)\n",
    "    return data_list\n",
    "\n",
    "\n",
    "# Label convention used throughout the notebook: 1 = random graph, 0 = GPT-generated graph.\n",
    "random_data_list = graphs_to_data_list(random_graph, label=1, graph_type=\"Random\")\n",
    "gpt_data_list = graphs_to_data_list(gpt_generated_graph, label=0, graph_type=\"gpt\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd692125",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Pool both classes; graphs that share a paper id are kept in one group so they\n",
    "# always land in the same split.\n",
    "all_graphs = random_data_list + gpt_data_list\n",
    "graph_id_groups = defaultdict(list)\n",
    "for g in all_graphs:\n",
    "    g.graph_id = g.paper_id  # set graph_id for grouping\n",
    "    graph_id_groups[g.graph_id].append(g)\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "random_seed = 40\n",
    "\n",
    "# 70% of the ids go to train; the remaining 30% is halved into validation and test.\n",
    "train_ids, temp_ids = train_test_split(all_graph_ids, test_size=0.3, random_state=random_seed)\n",
    "val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=random_seed)\n",
    "\n",
    "\n",
    "def collect_by_ids(graph_ids):\n",
    "    \"\"\"Return the graphs of every group listed in ``graph_ids``, preserving id order.\"\"\"\n",
    "    collected = []\n",
    "    for gid in graph_ids:\n",
    "        collected.extend(graph_id_groups[gid])\n",
    "    return collected\n",
    "\n",
    "\n",
    "train_dataset = collect_by_ids(train_ids)\n",
    "val_dataset = collect_by_ids(val_ids)\n",
    "test_dataset = collect_by_ids(test_ids)\n",
    "\n",
    "\n",
    "def clean_and_index(dataset_split):\n",
    "    \"\"\"Number the graphs of a split 0..N-1 (``graph_index``) and drop helper attributes in place.\"\"\"\n",
    "    for position, graph in enumerate(dataset_split):\n",
    "        graph.graph_index = position\n",
    "        # graph_id and graph_type are removed when present.\n",
    "        for attr in ('graph_id', 'graph_type'):\n",
    "            if hasattr(graph, attr):\n",
    "                delattr(graph, attr)\n",
    "\n",
    "\n",
    "for split in (train_dataset, val_dataset, test_dataset):\n",
    "    clean_and_index(split)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3511431",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inspect_first_n(dataset_split, split_name, n=6):\n",
    "    \"\"\"Print keys, tensor shapes, label, index and a data preview for the first ``n`` graphs.\n",
    "\n",
    "    At most 50 node-feature rows and 50 edges are shown per graph.\n",
    "    \"\"\"\n",
    "    preview_cap = 50\n",
    "    header = f\"\\n--- First {n} samples from {split_name.upper()} ---\"\n",
    "    print(header)\n",
    "    for i, data in enumerate(dataset_split[:n]):\n",
    "        print(f\"\\nSample #{i} in {split_name}:\")\n",
    "        # attribute names stored on the graph object\n",
    "        print(\"  Keys            :\", data.keys())\n",
    "        # tensor shapes, label and position within the split\n",
    "        print(\"  x.shape         :\", tuple(data.x.shape))\n",
    "        print(\"  edge_index.shape:\", tuple(data.edge_index.shape))\n",
    "        print(\"  y (label)       :\", data.y.item())\n",
    "        print(\"  graph_index     :\", data.graph_index)\n",
    "\n",
    "        num_nf = min(preview_cap, data.x.size(0))\n",
    "        print(f\"  Sample node features (first {num_nf} rows):\")\n",
    "        print(data.x[:num_nf])\n",
    "\n",
    "        num_e = min(preview_cap, data.edge_index.size(1))\n",
    "        # edge_index holds one column per edge; transpose to get [source, target] pairs\n",
    "        edge_pairs = data.edge_index[:, :num_e].t().tolist()\n",
    "        print(f\"  Sample edges (first {num_e}):\", edge_pairs)\n",
    "\n",
    "\n",
    "# Same preview for each of the three splits\n",
    "for split_label, split_graphs in ((\"train\", train_dataset), (\"val\", val_dataset), (\"test\", test_dataset)):\n",
    "    inspect_first_n(split_graphs, split_label, n=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daaf6b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Mean-pool the node features of every training graph and bucket the vectors by label\n",
    "# (label 0 goes to class 0, anything else to class 1).\n",
    "class0_feats, class1_feats = [], []\n",
    "\n",
    "for data in train_dataset:\n",
    "    pooled = data.x.mean(dim=0).cpu().numpy()\n",
    "    bucket = class0_feats if data.y.item() == 0 else class1_feats\n",
    "    bucket.append(pooled)\n",
    "\n",
    "# One row per graph, one column per embedding dimension.\n",
    "class0_feats, class1_feats = np.vstack(class0_feats), np.vstack(class1_feats)\n",
    "\n",
    "# Per-dimension mean and standard deviation across the graphs of each class.\n",
    "print(\"Class 0 mean:\", np.mean(class0_feats, axis=0))\n",
    "print(\"Class 1 mean:\", np.mean(class1_feats, axis=0))\n",
    "print(\"Class 0 std:\", np.std(class0_feats, axis=0))\n",
    "print(\"Class 1 std:\", np.std(class1_feats, axis=0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec76f49f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# One mean-pooled embedding (plain list of numpy vectors) and one label per graph\n",
    "train_X, train_y = [], []\n",
    "for data in train_dataset:\n",
    "    train_X.append(data.x.mean(dim=0).numpy())\n",
    "    train_y.append(data.y.item())\n",
    "\n",
    "val_X, val_y = [], []\n",
    "for data in val_dataset:\n",
    "    val_X.append(data.x.mean(dim=0).numpy())\n",
    "    val_y.append(data.y.item())\n",
    "\n",
    "# Fit the forest on the training split and score it on the validation split\n",
    "rf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "rf.fit(train_X, train_y)\n",
    "pred = rf.predict(val_X)\n",
    "val_accuracy = accuracy_score(val_y, pred)\n",
    "print(\"Val accuracy (Random Forest, mean features):\", val_accuracy)\n",
    "\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "for report in (classification_report, confusion_matrix):\n",
    "    print(report(val_y, pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "848f3061",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.tree import plot_tree, _tree\n",
    "\n",
    "#  Print thresholds \n",
    "def print_split_thresholds(rf, num_trees=2, max_nodes=12, feature_names=None):\n",
    "    \"\"\"Print the split rule of internal nodes for the first ``num_trees`` trees of ``rf``.\n",
    "\n",
    "    Nodes are visited in node-id order and leaves are skipped; printing for a tree\n",
    "    stops once ``max_nodes`` rules have been shown. Features are labelled from\n",
    "    ``feature_names`` when given, otherwise as ``f<index>``.\n",
    "    \"\"\"\n",
    "    for t_idx, est in enumerate(rf.estimators_[:num_trees]):\n",
    "        tree = est.tree_\n",
    "        split_feature, split_threshold = tree.feature, tree.threshold\n",
    "        print(f\"\\n=== Tree #{t_idx} ===\")\n",
    "        remaining = max_nodes\n",
    "        for node_id in range(tree.node_count):\n",
    "            fid = split_feature[node_id]\n",
    "            if fid == _tree.TREE_UNDEFINED:  # leaf: nothing to report\n",
    "                continue\n",
    "            if feature_names is None:\n",
    "                fname = f\"f{fid}\"\n",
    "            else:\n",
    "                fname = feature_names[fid]\n",
    "            print(f\"node {node_id:4d}:  {fname} <= {split_threshold[node_id]:.6f}\")\n",
    "            remaining -= 1\n",
    "            if remaining <= 0:\n",
    "                break\n",
    "\n",
    "#  Plot a single tree (for better understanding) \n",
    "def plot_single_tree(rf, tree_idx=0, max_depth=3, feature_names=None, figsize=(16, 10)):\n",
    "    \"\"\"Draw tree ``tree_idx`` of the forest, showing at most ``max_depth`` levels.\"\"\"\n",
    "    est = rf.estimators_[tree_idx]\n",
    "    render_options = dict(filled=True, impurity=True, proportion=True, rounded=True, fontsize=8)\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    plot_tree(est, max_depth=max_depth, feature_names=feature_names, **render_options)\n",
    "    plt.title(f\"Random Forest - Tree #{tree_idx} (max_depth={max_depth})\")\n",
    "    plt.show()\n",
    "\n",
    "# Distribution of thresholds across the forest\n",
    "def collect_splits(rf):\n",
    "    \"\"\"Gather the feature index and threshold of every internal node over all trees.\n",
    "\n",
    "    Returns two parallel 1-D arrays ``(feature_indices, thresholds)``; both are empty\n",
    "    arrays when the forest holds no estimators.\n",
    "    \"\"\"\n",
    "    trees = [est.tree_ for est in rf.estimators_]\n",
    "    if not trees:\n",
    "        return np.array([]), np.array([])\n",
    "    # Leaves carry TREE_UNDEFINED as their feature; keep only real split nodes.\n",
    "    masks = [t.feature != _tree.TREE_UNDEFINED for t in trees]\n",
    "    feature_indices = np.concatenate([t.feature[m] for t, m in zip(trees, masks)])\n",
    "    thresholds = np.concatenate([t.threshold[m] for t, m in zip(trees, masks)])\n",
    "    return feature_indices, thresholds\n",
    "\n",
    "def plot_threshold_scatter(rf, feature_names=None, figsize=(12, 6)):\n",
    "    \"\"\"Scatter every split threshold of the forest against the index of the feature it tests.\n",
    "\n",
    "    ``feature_names`` is accepted for signature parity with the other helpers but is not used.\n",
    "    \"\"\"\n",
    "    split_features, split_thresholds = collect_splits(rf)\n",
    "    if split_features.size == 0:\n",
    "        print(\"No internal splits found.\")\n",
    "        return\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.scatter(split_features, split_thresholds, s=8, alpha=0.3)\n",
    "    plt.xlabel(\"Feature index\")\n",
    "    plt.ylabel(\"Threshold value\")\n",
    "    plt.title(\"All split thresholds across the forest\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "def plot_top_feature_split_counts(rf, top_k=20, figsize=(12, 6), feature_names=None):\n",
    "    \"\"\"Bar-plot the ``top_k`` features used most often as a split variable across the forest.\"\"\"\n",
    "    split_features, _ = collect_splits(rf)\n",
    "    if split_features.size == 0:\n",
    "        print(\"No internal splits found.\")\n",
    "        return\n",
    "\n",
    "    # counts[i] = number of internal nodes, over all trees, that split on feature i\n",
    "    counts = np.bincount(split_features, minlength=int(split_features.max()) + 1)\n",
    "    # ascending argsort -> take the tail -> reverse so the most-used feature comes first\n",
    "    top_idx = np.argsort(counts)[-top_k:][::-1]\n",
    "    if feature_names is None:\n",
    "        labels = [f\"f{i}\" for i in top_idx]\n",
    "    else:\n",
    "        labels = [feature_names[i] for i in top_idx]\n",
    "\n",
    "    positions = range(len(top_idx))\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.bar(positions, counts[top_idx])\n",
    "    plt.xticks(positions, labels, rotation=60, ha='right')\n",
    "    plt.ylabel(\"#Splits using feature\")\n",
    "    plt.title(f\"Top {top_k} features by split count\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Decision path for a single sample \n",
    "def print_decision_path_for_sample(rf, x, tree_idx=0, feature_names=None):\n",
    "    \"\"\"Print each split that sample ``x`` passes through in tree ``tree_idx``.\n",
    "\n",
    "    ``x`` is a 1-D feature vector. One line is printed per internal node on the path,\n",
    "    giving the sample's value for the tested feature, the threshold, and the side taken.\n",
    "    \"\"\"\n",
    "    est = rf.estimators_[tree_idx]\n",
    "    tree = est.tree_\n",
    "    row = x.reshape(1, -1)  # the estimator API expects a 2-D (n_samples, n_features) input\n",
    "    visited = est.decision_path(row).indices\n",
    "    leaf_id = est.apply(row)[0]\n",
    "    print(f\"\\nDecision path on Tree #{tree_idx}, ends at leaf {leaf_id}\")\n",
    "    for node_id in visited:\n",
    "        fid = tree.feature[node_id]\n",
    "        if fid == _tree.TREE_UNDEFINED:  # the leaf itself has no test\n",
    "            continue\n",
    "        thr = tree.threshold[node_id]\n",
    "        val = x[fid]\n",
    "        if feature_names is None:\n",
    "            fname = f\"f{fid}\"\n",
    "        else:\n",
    "            fname = feature_names[fid]\n",
    "        op = \">\" if val > thr else \"<=\"\n",
    "        print(f\"node {node_id:4d}:  ({fname} = {val:.6f}) {op} {thr:.6f}\")\n",
    "\n",
    "# Number of input features, read from the fitted forest. `train_X` is a plain Python\n",
    "# list in this notebook, so it has no `.shape`; the previous bare `except:` around\n",
    "# `train_X.shape[1]` hid that and would also have swallowed any unrelated error.\n",
    "D = rf.n_features_in_\n",
    "feature_names = [f\"emb[{i}]\" for i in range(D)]\n",
    "\n",
    "print_split_thresholds(rf, num_trees=2, max_nodes=12, feature_names=feature_names)\n",
    "plot_single_tree(rf, tree_idx=0, max_depth=3, feature_names=feature_names)\n",
    "plot_threshold_scatter(rf, feature_names=feature_names)\n",
    "plot_top_feature_split_counts(rf, top_k=20, feature_names=feature_names)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8210907",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import accuracy_score\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "# Mean-pooled graph embeddings (one row per graph) and integer labels for train / validation.\n",
    "X_train = np.stack([d.x.mean(0).cpu().numpy() for d in train_dataset]).astype(np.float64)\n",
    "y_train = np.array([d.y.item() for d in train_dataset], dtype=np.int64)\n",
    "X_val   = np.stack([d.x.mean(0).cpu().numpy() for d in val_dataset]).astype(np.float64)\n",
    "y_val   = np.array([d.y.item() for d in val_dataset], dtype=np.int64)\n",
    "\n",
    "# Per-dimension class means and sample standard deviations (ddof=1).\n",
    "X0 = X_train[y_train == 0]\n",
    "X1 = X_train[y_train == 1]\n",
    "mu0 = X0.mean(axis=0)\n",
    "mu1 = X1.mean(axis=0)\n",
    "s0  = X0.std(axis=0, ddof=1) + 1e-12  # epsilon keeps the ratio finite for constant dimensions\n",
    "s1  = X1.std(axis=0, ddof=1) + 1e-12\n",
    "\n",
    "# (1) ratio > 1 means class 1 is the tighter class in that dimension.\n",
    "ratio = s0 / s1\n",
    "frac_sigma1_smaller = (ratio > 1).mean()\n",
    "print(f\"(1) In {frac_sigma1_smaller*100:.1f}% of dimensions, σ1 < σ0 (class 1 is more compact).\")\n",
    "# `ratio` is a ratio of standard deviations, so it is labelled as σ0/σ1 rather than as a variance ratio.\n",
    "print(f\"    σ0/σ1 ratios (median/percentiles): median={np.median(ratio):.2f}, P75={np.percentile(ratio,75):.2f}, P90={np.percentile(ratio,90):.2f}\")\n",
    "\n",
    "# (2) Univariate separability of each dimension; max(auc, 1 - auc) ignores the direction.\n",
    "aucs = []\n",
    "for j in range(X_train.shape[1]):\n",
    "    auc = roc_auc_score(y_train, X_train[:, j])\n",
    "    aucs.append(max(auc, 1 - auc))\n",
    "aucs = np.array(aucs)\n",
    "r, p = pearsonr(ratio, aucs)\n",
    "print(f\"(2) Pearson correlation between (σ0/σ1) and univariate AUC: r={r:.3f}, p={p:.2e}\")\n",
    "\n",
    "top_by_ratio = np.argsort(-ratio)[:10]\n",
    "print(\"\\nTop 10 dims by (σ0/σ1) — larger → class 1 more compact:\")\n",
    "for j in top_by_ratio:\n",
    "    print(f\"emb[{j:4d}]  σ0/σ1={ratio[j]:.2f}   mu0={mu0[j]: .5f}  mu1={mu1[j]: .5f}  AUC={aucs[j]:.3f}\")\n",
    "\n",
    "top_by_auc = np.argsort(-aucs)[:10]\n",
    "print(\"\\nTop 10 dims by univariate AUC:\")\n",
    "for j in top_by_auc:\n",
    "    print(f\"emb[{j:4d}]  AUC={aucs[j]:.3f}  σ0/σ1={ratio[j]:.2f}   mu0={mu0[j]: .5f}  mu1={mu1[j]: .5f}\")\n",
    "\n",
    "\n",
    "# (3)/(4) Same forest on raw vs. globally standardised features. The scaler is fitted on\n",
    "# the training split only and then applied to both splits.\n",
    "scaler = StandardScaler().fit(X_train)\n",
    "Xtr_z = scaler.transform(X_train)\n",
    "Xva_z = scaler.transform(X_val)\n",
    "\n",
    "rf_raw = RandomForestClassifier(n_estimators=300, random_state=42, n_jobs=-1)\n",
    "rf_raw.fit(X_train, y_train)\n",
    "acc_raw = accuracy_score(y_val, rf_raw.predict(X_val))\n",
    "\n",
    "rf_z = RandomForestClassifier(n_estimators=300, random_state=42, n_jobs=-1)\n",
    "rf_z.fit(Xtr_z, y_train)\n",
    "acc_z = accuracy_score(y_val, rf_z.predict(Xva_z))\n",
    "\n",
    "print(f\"\\n(3) RF accuracy without normalization: {acc_raw:.4f}\")\n",
    "print(f\"(4) RF accuracy with global Z-score: {acc_z:.4f}\")\n",
    "# A global z-score is a positive affine map per feature: it preserves the ordering of\n",
    "# values and the σ0/σ1 ratio, and threshold splits depend only on that ordering. The two\n",
    "# accuracies are therefore expected to match, so this comparison is a sanity check and\n",
    "# cannot separate spread effects from mean shifts.\n",
    "print(\"Tree splits are invariant to per-feature rescaling, so (3) and (4) should agree up to floating-point effects; a gap here does not indicate a spread/scale effect.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "452b99a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.tree import _tree\n",
    "\n",
    "# Root node of the first tree: the embedding dimension it tests and its cut-off.\n",
    "first_tree = rf.estimators_[0].tree_\n",
    "j_root, thr_rf = first_tree.feature[0], first_tree.threshold[0]\n",
    "\n",
    "# Training values of that dimension in ascending order.\n",
    "vals_sorted = np.sort(np.asarray(train_X)[:, j_root])\n",
    "\n",
    "# Values sitting immediately below / above the threshold, clamped to the array bounds.\n",
    "k = np.searchsorted(vals_sorted, thr_rf)\n",
    "last_pos = len(vals_sorted) - 1\n",
    "left_val = vals_sorted[max(k - 1, 0)]\n",
    "right_val = vals_sorted[min(k, last_pos)]\n",
    "mid = (left_val + right_val) / 2\n",
    "\n",
    "report_lines = (\n",
    "    f\"root feature = emb[{j_root}]\",\n",
    "    f\"tree threshold = {thr_rf:.6f}\",\n",
    "    f\"adjacent values around thr: {left_val:.6f} , {right_val:.6f}\",\n",
    "    f\"their midpoint           : {mid:.6f}\",\n",
    ")\n",
    "for line in report_lines:\n",
    "    print(line)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e98b5f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "N_RUNS = 10\n",
    "all_acc = []\n",
    "all_f1_macro = []\n",
    "all_f1_weighted = []\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "\n",
    "\n",
    "def mean_pooled_lists(graph_ids):\n",
    "    \"\"\"Return (mean-pooled feature vectors, labels) for every graph in the given id groups.\"\"\"\n",
    "    graphs = [g for gid in graph_ids for g in graph_id_groups[gid]]\n",
    "    return [g.x.mean(dim=0).numpy() for g in graphs], [g.y.item() for g in graphs]\n",
    "\n",
    "\n",
    "# Everything assigned inside the loop uses a run-local name (run_*). Reassigning the\n",
    "# notebook-level train_dataset / val_dataset / train_X / rf here would leave them tied to\n",
    "# the last seed's split while test_dataset still belongs to the original split, so the\n",
    "# test cell below would score a forest on graphs it had been trained on. The graphs are\n",
    "# also no longer re-indexed per run, which used to overwrite graph_index on objects\n",
    "# shared with the original splits.\n",
    "for run_seed in range(N_RUNS):\n",
    "    # Split the paper ids 70 / 15 / 15; only train and validation are used per run.\n",
    "    run_train_ids, run_temp_ids = train_test_split(all_graph_ids, test_size=0.3, random_state=run_seed)\n",
    "    run_val_ids, _ = train_test_split(run_temp_ids, test_size=0.5, random_state=run_seed)\n",
    "\n",
    "    run_train_X, run_train_y = mean_pooled_lists(run_train_ids)\n",
    "    run_val_X, run_val_y = mean_pooled_lists(run_val_ids)\n",
    "\n",
    "    run_rf = RandomForestClassifier(n_estimators=100, random_state=run_seed)\n",
    "    run_rf.fit(run_train_X, run_train_y)\n",
    "    run_pred = run_rf.predict(run_val_X)\n",
    "\n",
    "    all_acc.append(accuracy_score(run_val_y, run_pred))\n",
    "    all_f1_macro.append(f1_score(run_val_y, run_pred, average='macro'))\n",
    "    all_f1_weighted.append(f1_score(run_val_y, run_pred, average='weighted'))\n",
    "\n",
    "print(f\"Random Forest (mean node features) - {N_RUNS} runs\")\n",
    "print(\"Val accuracy   : %.4f ± %.4f\" % (np.mean(all_acc), np.std(all_acc)))\n",
    "print(\"F1-macro       : %.4f ± %.4f\" % (np.mean(all_f1_macro), np.std(all_f1_macro)))\n",
    "print(\"F1-weighted    : %.4f ± %.4f\" % (np.mean(all_f1_weighted), np.std(all_f1_weighted)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9daecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the held-out test split with a forest fitted on the training ids of the SAME split.\n",
    "# The notebook-level `rf` may have been refitted on a different split by the multi-seed\n",
    "# loop, in which case it has already seen some of these test graphs. The training ids are\n",
    "# therefore rebuilt from `random_seed` and a fresh forest is fitted with the baseline settings.\n",
    "final_train_ids, _ = train_test_split(all_graph_ids, test_size=0.3, random_state=random_seed)\n",
    "final_train = [g for gid in final_train_ids for g in graph_id_groups[gid]]\n",
    "\n",
    "# Guard against leakage: no paper may contribute graphs to both sides.\n",
    "test_paper_ids = {g.paper_id for g in test_dataset}\n",
    "assert test_paper_ids.isdisjoint(g.paper_id for g in final_train), \"train/test splits share paper ids\"\n",
    "\n",
    "rf_test = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "rf_test.fit(\n",
    "    [g.x.mean(dim=0).numpy() for g in final_train],\n",
    "    [g.y.item() for g in final_train],\n",
    ")\n",
    "\n",
    "test_X = [data.x.mean(dim=0).numpy() for data in test_dataset]\n",
    "test_y = [data.y.item() for data in test_dataset]\n",
    "pred_test = rf_test.predict(test_X)\n",
    "print(\"Test accuracy (Random Forest):\", accuracy_score(test_y, pred_test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e5b050",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Node-level view: every node embedding of every graph, labelled with its graph's class.\n",
    "# Labels are taken from `g.y` (0 = GPT, 1 = Random) and drawn blue / red so that this\n",
    "# figure uses the same encoding as the graph-level plots in the following cells.\n",
    "all_node_embeddings = []\n",
    "all_labels = []\n",
    "\n",
    "for g in random_data_list + gpt_data_list:\n",
    "    node_feats = g.x.cpu().numpy()\n",
    "    all_node_embeddings.append(node_feats)\n",
    "    all_labels.extend([g.y.item()] * node_feats.shape[0])\n",
    "\n",
    "all_node_embeddings = np.vstack(all_node_embeddings)\n",
    "all_labels = np.array(all_labels)\n",
    "\n",
    "print(f\"Total node embeddings: {all_node_embeddings.shape}\")  # (total_nodes, embedding_dim)\n",
    "\n",
    "pca = PCA(n_components=2, random_state=42)\n",
    "embeddings_2d = pca.fit_transform(all_node_embeddings)\n",
    "\n",
    "is_random = all_labels == 1\n",
    "is_gpt = all_labels == 0\n",
    "\n",
    "plt.figure(figsize=(8, 8))\n",
    "# Random graphs are drawn first and GPT graphs on top, as before.\n",
    "plt.scatter(\n",
    "    embeddings_2d[is_random, 0], embeddings_2d[is_random, 1],\n",
    "    c='red', alpha=0.3, label='Random Graphs', s=10\n",
    ")\n",
    "plt.scatter(\n",
    "    embeddings_2d[is_gpt, 0], embeddings_2d[is_gpt, 1],\n",
    "    c='blue', alpha=0.3, label='GPT Graphs', s=10\n",
    ")\n",
    "plt.xlabel(\"PCA 1\")\n",
    "plt.ylabel(\"PCA 2\")\n",
    "plt.legend()\n",
    "plt.title(\"All Node Embeddings (PCA 2D)\\nRandom=Red, GPT=Blue\")\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "951da5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Prepare mean-pooled features for every graph exactly once.\n",
    "# `all_graphs` is used instead of train_dataset + val_dataset + test_dataset: those three\n",
    "# names are not guaranteed to come from the same split at this point, and concatenating\n",
    "# lists from different splits would duplicate some graphs and drop others.\n",
    "all_X = np.stack([data.x.mean(dim=0).cpu().numpy() for data in all_graphs])\n",
    "all_y = np.array([data.y.item() for data in all_graphs])\n",
    "\n",
    "# PCA down to two components\n",
    "pca = PCA(n_components=2, random_state=42)\n",
    "mean_X_pca = pca.fit_transform(all_X)\n",
    "\n",
    "# Plot PCA (label 0 = GPT in blue, label 1 = Random in red)\n",
    "plt.figure(figsize=(7,7))\n",
    "plt.scatter(mean_X_pca[all_y==0,0], mean_X_pca[all_y==0,1], c='blue', label='GPT', alpha=0.6)\n",
    "plt.scatter(mean_X_pca[all_y==1,0], mean_X_pca[all_y==1,1], c='red', label='Random', alpha=0.6)\n",
    "plt.xlabel(\"PCA 1\")\n",
    "plt.ylabel(\"PCA 2\")\n",
    "plt.title(\"PCA of Mean-Pooled Graph Embeddings\")\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9707fd4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Helper to compute mean-pooled features and their labels for a split\n",
    "def mean_pooled_X_y(split_dataset):\n",
    "    \"\"\"Return ``(X, y)``: one mean-pooled node-feature row and one label per graph in the split.\"\"\"\n",
    "    pooled = np.stack([data.x.mean(dim=0).cpu().numpy() for data in split_dataset])\n",
    "    labels = np.array([data.y.item() for data in split_dataset])\n",
    "    return pooled, labels\n",
    "\n",
    "splits = {\n",
    "    'Train': train_dataset,\n",
    "    'Validation': val_dataset,\n",
    "    'Test': test_dataset\n",
    "}\n",
    "\n",
    "# Loop-local names only. Assigning to all_X / all_y / mean_X_pca / pca here would replace\n",
    "# the all-graph projection from the previous cell with the last (Test) split's values,\n",
    "# and the histogram and violin cells further down read those names.\n",
    "for split_name, split_dataset in splits.items():\n",
    "    split_X, split_y = mean_pooled_X_y(split_dataset)\n",
    "\n",
    "    if len(split_X) > 2:\n",
    "        split_X_2d = PCA(n_components=2, random_state=42).fit_transform(split_X)\n",
    "    else:\n",
    "        # PCA is skipped for splits of at most two graphs; the first two raw feature\n",
    "        # columns are plotted instead.\n",
    "        split_X_2d = split_X\n",
    "\n",
    "    plt.figure(figsize=(7, 7))\n",
    "    plt.scatter(split_X_2d[split_y==0,0], split_X_2d[split_y==0,1], c='blue', label='GPT', alpha=0.6)\n",
    "    plt.scatter(split_X_2d[split_y==1,0], split_X_2d[split_y==1,1], c='red', label='Random', alpha=0.6)\n",
    "    plt.xlabel(\"PCA 1\")\n",
    "    plt.ylabel(\"PCA 2\")\n",
    "    plt.title(f\"PCA of Mean-Pooled Graph Embeddings: {split_name} Split\")\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4186f929",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Density histogram of the first principal component, one per class.\n",
    "# Reads whichever `mean_X_pca` / `all_y` were assigned most recently in the kernel.\n",
    "first_pc = mean_X_pca[:, 0]\n",
    "hist_style = dict(bins=40, alpha=0.5, density=True)\n",
    "\n",
    "plt.figure(figsize=(8,4))\n",
    "plt.hist(first_pc[all_y==0], color='blue', label='GPT', **hist_style)\n",
    "plt.hist(first_pc[all_y==1], color='red', label='Random', **hist_style)\n",
    "plt.xlabel(\"PCA Component 1\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.title(\"Distribution of First PCA Component (Mean-Pooled Embeddings)\")\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5996d187",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# First principal component per graph, with a readable class name (label 0 = GPT).\n",
    "df = pd.DataFrame({\n",
    "    'pca1': mean_X_pca[:, 0],\n",
    "    'category': ['GPT' if y==0 else 'Random' for y in all_y]\n",
    "})\n",
    "\n",
    "# Pin both the category order and the colour of each category. A plain colour list is\n",
    "# assigned in order of first appearance in `df`, which can swap the colours relative to\n",
    "# the other figures (GPT = blue, Random = red).\n",
    "class_colors = {'GPT': 'blue', 'Random': 'red'}\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "sns.violinplot(x='category', y='pca1', data=df, order=list(class_colors), palette=class_colors)\n",
    "plt.title(\"Violin Plot of First PCA Component\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11fc14a5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b803f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
