{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a9c465",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import from_networkx\n",
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f863fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_path = \"groundtruth_graph_graphs_with embedding.pickle\"\n",
    "\n",
    "with open(file_path, \"rb\") as f:\n",
    "    groundtruth_graph = pickle.load(f)\n",
    "\n",
    "    \n",
    "sample_id = list(groundtruth_graph.keys())[0]  \n",
    "G = groundtruth_graph[sample_id]\n",
    "\n",
    "print(f\"Sample ID: {sample_id}\")\n",
    "print(f\"Number of nodes: {G.number_of_nodes()}\")\n",
    "print(f\"Number of edges: {G.number_of_edges()}\")\n",
    "\n",
    "for node in list(G.nodes())[:20]: \n",
    "    feature = G.nodes[node].get(\"feature\", None)\n",
    "    print(f\"Node ID: {node}\")\n",
    "    if feature is None:\n",
    "        print(\"  Feature: None\")\n",
    "    elif isinstance(feature, np.ndarray):\n",
    "        print(f\"  Feature shape: {feature.shape}\")\n",
    "        print(f\"  Feature sample (first 5 dims): {feature[:5]}\")\n",
    "    else:\n",
    "        print(f\"  Feature type: {type(feature)}\")\n",
    "        print(f\"  Feature preview: {feature}\")\n",
    "\n",
    "\n",
    "file_path = \"gpt_generated_graph_graphs_with embedding.pickle\"\n",
    "with open(file_path, \"rb\") as f:\n",
    "    gpt_generated_graph = pickle.load(f)\n",
    "sample_id = list(gpt_generated_graph.keys())[0]  \n",
    "G = gpt_generated_graph[sample_id]\n",
    "\n",
    "print(f\"Sample ID: {sample_id}\")\n",
    "print(f\"Number of nodes: {G.number_of_nodes()}\")\n",
    "print(f\"Number of edges: {G.number_of_edges()}\")\n",
    "\n",
    "for node in list(G.nodes())[:20]: \n",
    "    feature = G.nodes[node].get(\"feature\", None)\n",
    "    print(f\"Node ID: {node}\")\n",
    "    if feature is None:\n",
    "        print(\"  Feature: None\")\n",
    "    elif isinstance(feature, np.ndarray):\n",
    "        print(f\"  Feature shape: {feature.shape}\")\n",
    "        print(f\"  Feature sample (first 5 dims): {feature[:5]}\")\n",
    "    else:\n",
    "        print(f\"  Feature type: {type(feature)}\")\n",
    "        print(f\"  Feature preview: {feature}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4e267e",
   "metadata": {},
   "outputs": [],
   "source": [
    "actual_data_list = []\n",
    "for sample_id, G in groundtruth_graph.items():\n",
    "    # Assume every node has a 3072-dim feature\n",
    "    features = []\n",
    "    nodes = list(G.nodes())\n",
    "    for n in nodes:\n",
    "        feat = G.nodes[n].get(\"feature\", np.zeros(3072, dtype=np.float32))\n",
    "        features.append(torch.tensor(feat, dtype=torch.float32))\n",
    "    x = torch.stack(features)\n",
    "    H = G.copy()\n",
    "    for n in nodes:\n",
    "        del H.nodes[n]['feature']\n",
    "    data = from_networkx(H)\n",
    "    data.x = x\n",
    "    data.paper_id = sample_id\n",
    "    data.y = torch.tensor([0]) \n",
    "    data.graph_type = \"groundtruth\" \n",
    "    actual_data_list.append(data)\n",
    "\n",
    "gpt_data_list = []\n",
    "for sample_id, G in gpt_generated_graph.items():\n",
    "    features = []\n",
    "    nodes = list(G.nodes())\n",
    "    for n in nodes:\n",
    "        feat = G.nodes[n].get(\"feature\", np.zeros(3072, dtype=np.float32))\n",
    "        features.append(torch.tensor(feat, dtype=torch.float32))\n",
    "    x = torch.stack(features)\n",
    "    H = G.copy()\n",
    "    for n in nodes:\n",
    "        del H.nodes[n]['feature']\n",
    "    data = from_networkx(H)\n",
    "    data.x = x\n",
    "    data.paper_id = sample_id\n",
    "    data.y = torch.tensor([1])\n",
    "    data.graph_type = \"gpt\"\n",
    "    gpt_data_list.append(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843cd50f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#  Combine all data and assign graph_id attribute\n",
    "all_graphs = actual_data_list + gpt_data_list\n",
    "for g in all_graphs:\n",
    "    g.graph_id = g.paper_id  # set graph_id for grouping\n",
    "\n",
    "#  Group all graphs by their graph_id\n",
    "graph_id_groups = defaultdict(list)\n",
    "for data in all_graphs:\n",
    "    graph_id_groups[data.graph_id].append(data)\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "random_seed = 40\n",
    "\n",
    "#  Split graph_ids into train (70%) and temp (30%)\n",
    "train_ids, temp_ids = train_test_split(\n",
    "    all_graph_ids,\n",
    "    test_size=0.3,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "#  Split temp into validation (15%) and test (15%)\n",
    "val_ids, test_ids = train_test_split(\n",
    "    temp_ids,\n",
    "    test_size=0.5,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "#  Helper to collect Data objects by ids\n",
    "def collect_by_ids(graph_ids):\n",
    "    return [data for gid in graph_ids for data in graph_id_groups[gid]]\n",
    "\n",
    "train_dataset = collect_by_ids(train_ids)\n",
    "val_dataset   = collect_by_ids(val_ids)\n",
    "test_dataset  = collect_by_ids(test_ids)\n",
    "\n",
    "# Assign simple graph_index and drop the original graph_id (optional cleanup)\n",
    "def clean_and_index(dataset_split):\n",
    "    for i, data in enumerate(dataset_split):\n",
    "        data.graph_index = i              # new 0…N-1 index\n",
    "        if hasattr(data, 'graph_id'):     # remove leak source\n",
    "            del data.graph_id\n",
    "        if hasattr(data, 'graph_type'):   # remove type attribute\n",
    "            del data.graph_type\n",
    "\n",
    "clean_and_index(train_dataset)\n",
    "clean_and_index(val_dataset)\n",
    "clean_and_index(test_dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e49c48",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def inspect_first_n(dataset_split, split_name, n=6):\n",
    "\n",
    "    print(f\"\\n--- First {n} samples from {split_name.upper()} ---\")\n",
    "    for i, data in enumerate(dataset_split[:n]):\n",
    "        print(f\"\\nSample #{i} in {split_name}:\")\n",
    "        # print attribute list\n",
    "        print(\"  Keys            :\", data.keys())\n",
    "        # print shapes\n",
    "        print(\"  x.shape         :\", tuple(data.x.shape))\n",
    "        print(\"  edge_index.shape:\", tuple(data.edge_index.shape))\n",
    "        print(\"  y (label)       :\", data.y.item())\n",
    "        print(\"  graph_index     :\", data.graph_index)\n",
    "        \n",
    "        num_nf = min(50, data.x.size(0))\n",
    "        print(f\"  Sample node features (first {num_nf} rows):\")\n",
    "        print(data.x[:num_nf])\n",
    "        \n",
    "        num_e = min(50, data.edge_index.size(1))\n",
    "        edges = data.edge_index[:, :num_e].t().tolist()\n",
    "        print(f\"  Sample edges (first {num_e}):\", edges)\n",
    "\n",
    "# Call for each split\n",
    "inspect_first_n(train_dataset, \"train\", n=6)\n",
    "inspect_first_n(val_dataset,   \"val\",   n=6)\n",
    "inspect_first_n(test_dataset,  \"test\",  n=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daaf6b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "class0_feats = []\n",
    "class1_feats = []\n",
    "\n",
    "for data in train_dataset:\n",
    "    feat_mean = data.x.mean(dim=0).cpu().numpy()\n",
    "    if data.y.item() == 0:\n",
    "        class0_feats.append(feat_mean)\n",
    "    else:\n",
    "        class1_feats.append(feat_mean)\n",
    "\n",
    "class0_feats = np.vstack(class0_feats)\n",
    "class1_feats = np.vstack(class1_feats)\n",
    "\n",
    "print(\"Class 0 mean:\", class0_feats.mean(axis=0))\n",
    "print(\"Class 1 mean:\", class1_feats.mean(axis=0))\n",
    "print(\"Class 0 std:\", class0_feats.std(axis=0))\n",
    "print(\"Class 1 std:\", class1_feats.std(axis=0))\n",
    "print(\"class0_feats shape:\", class0_feats.shape)\n",
    "print(\"class1_feats shape:\", class1_feats.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bc4430",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def graph_to_vector(data):\n",
    "    if data.x.size(0) == 0:\n",
    "        return np.zeros((data.x.size(1),), dtype=np.float32)\n",
    "    return data.x.mean(dim=0).cpu().numpy()\n",
    "\n",
    "def build_xy(dataset_split):\n",
    "    X = np.stack([graph_to_vector(d) for d in dataset_split], axis=0)\n",
    "    y = np.array([int(d.y.item()) for d in dataset_split], dtype=np.int64)\n",
    "    return X, y\n",
    "\n",
    "X_train, y_train = build_xy(train_dataset)\n",
    "X_val,   y_val   = build_xy(val_dataset)\n",
    "X_test,  y_test  = build_xy(test_dataset)\n",
    "\n",
    "rf_final = RandomForestClassifier(\n",
    "    n_estimators=400,\n",
    "    random_state=42,\n",
    "    bootstrap=True,\n",
    "    oob_score=True,\n",
    "    n_jobs=-1,\n",
    "    max_depth=12,\n",
    "    min_samples_leaf=10,\n",
    "    min_samples_split=20,\n",
    "    max_features=55  # ~sqrt(3072)\n",
    ")\n",
    "\n",
    "rf_final.fit(X_train, y_train)\n",
    "\n",
    "val_pred  = rf_final.predict(X_val)\n",
    "val_proba = rf_final.predict_proba(X_val)[:, 1]\n",
    "print(\"OOB accuracy:\", getattr(rf_final, \"oob_score_\", None))\n",
    "print(\"Val accuracy:\", accuracy_score(y_val, val_pred))\n",
    "print(\"Val ROC-AUC:\", roc_auc_score(y_val, val_proba))\n",
    "print(\"\\nValidation classification report:\\n\", classification_report(y_val, val_pred))\n",
    "print(\"Validation confusion matrix:\\n\", confusion_matrix(y_val, val_pred))\n",
    "\n",
    "test_pred  = rf_final.predict(X_test)\n",
    "test_proba = rf_final.predict_proba(X_test)[:, 1]\n",
    "print(\"\\nTEST accuracy:\", accuracy_score(y_test, test_pred))\n",
    "print(\"TEST ROC-AUC:\", roc_auc_score(y_test, test_proba))\n",
    "print(\"\\nTest classification report:\\n\", classification_report(y_test, test_pred))\n",
    "print(\"Test confusion matrix:\\n\", confusion_matrix(y_test, test_pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "978a16c8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import gridspec\n",
    "from sklearn import tree as sktree\n",
    "\n",
    "\n",
    "def _node_depths_for_tree(estimator):\n",
    "    t = estimator.tree_\n",
    "    L, R = t.children_left, t.children_right\n",
    "    depth = np.zeros(t.node_count, dtype=int)\n",
    "    stack = [(0, 0)]\n",
    "    while stack:\n",
    "        nid, d = stack.pop()\n",
    "        depth[nid] = d\n",
    "        if L[nid] != -1: stack.append((L[nid], d + 1))\n",
    "        if R[nid] != -1: stack.append((R[nid], d + 1))\n",
    "    return depth\n",
    "\n",
    "def _avg_leaf_depths_over_forest(rf, X):\n",
    "\n",
    "    if X is None or len(X) == 0:\n",
    "        return np.array([])\n",
    "    per_tree_depths = []\n",
    "    for est in rf.estimators_:\n",
    "        leaf_ids = est.apply(X)\n",
    "        depth = _node_depths_for_tree(est)\n",
    "        per_tree_depths.append(depth[leaf_ids])\n",
    "    per_tree_depths = np.stack(per_tree_depths, axis=1)  \n",
    "    return per_tree_depths.mean(axis=1)\n",
    "\n",
    "def plot_sample_tree(\n",
    "    rf, *, tree_idx=0, max_depth=3, feature_names=None, example_x=None,\n",
    "    fontsize=9, ax=None\n",
    "):\n",
    "\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    est = rf.estimators_[tree_idx]\n",
    "    # basic plot\n",
    "    sktree.plot_tree(\n",
    "        est,\n",
    "        max_depth=max_depth,\n",
    "        feature_names=feature_names,\n",
    "        filled=True,\n",
    "        impurity=False,\n",
    "        proportion=True,\n",
    "        rounded=True,\n",
    "        fontsize=fontsize,\n",
    "        ax=ax,\n",
    "    )\n",
    "    ax.set_title(f\"Sample tree #{tree_idx} (max_depth={max_depth})\", fontsize=fontsize+1)\n",
    "\n",
    "def plot_avg_leaf_depth_hist(\n",
    "    rf, X, y=None, *, bins=20, fontsize=9, ax=None\n",
    "):\n",
    "\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    avg_depths = _avg_leaf_depths_over_forest(rf, X)\n",
    "    ax.hist(avg_depths, bins=bins)\n",
    "    ax.set_xlabel(\"Average leaf depth per sample\", fontsize=fontsize)\n",
    "    ax.set_ylabel(\"Count\", fontsize=fontsize)\n",
    "    ax.set_title(\"Validation: average leaf depth distribution\", fontsize=fontsize+1)\n",
    "    ax.grid(True, alpha=0.25, linewidth=0.6)\n",
    "\n",
    "def plot_cumulative_impurity(\n",
    "    rf, *, max_depth_for_imp=20, coverages=(0.90, 0.95), fontsize=9, ax=None\n",
    "):\n",
    "\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    tbl = depthwise_gini_table(rf, max_depth=max_depth_for_imp)\n",
    "    depths = tbl[\"depths\"]\n",
    "    cum = tbl[\"cum_share\"]\n",
    "\n",
    "    if len(depths) == 0:\n",
    "        ax.text(0.5, 0.5, \"No internal splits found.\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
    "        return\n",
    "\n",
    "    ax.plot(depths, cum, marker=\"o\")\n",
    "    ax.set_ylim(0, 1.0)\n",
    "    ax.set_xlim(left=0, right=max(depths))\n",
    "    ax.set_xlabel(\"Depth\", fontsize=fontsize)\n",
    "    ax.set_ylabel(\"Cumulative share of ΔGini\", fontsize=fontsize)\n",
    "    ax.set_title(\"Cumulative ΔGini by depth\", fontsize=fontsize+1)\n",
    "    ax.grid(True, alpha=0.25, linewidth=0.6)\n",
    "\n",
    "    # Add vertical guides for requested coverage thresholds\n",
    "    for cov in coverages:\n",
    "        eff, _ = effective_depth_by_coverage(rf, coverage=cov, max_depth=max_depth_for_imp)\n",
    "        if eff is not None:\n",
    "            ax.axvline(eff, linestyle=\"--\")\n",
    "            ax.text(eff, 0.02, f\"{int(cov*100)}%→d={eff}\", rotation=90,\n",
    "                    va=\"bottom\", ha=\"right\", fontsize=fontsize)\n",
    "\n",
    "\n",
    "def plot_rf_custom_layout(\n",
    "    rf,\n",
    "    feature_names=None,\n",
    "    example_x=None,\n",
    "    X_val=None,\n",
    "    y_val=None,\n",
    "    *,\n",
    "    figsize=(19.5, 6), dpi=600,\n",
    "    left_width=7.9, right_width=1.9,    \n",
    "    right_top_frac=0.4, right_bottom_frac=0.4,  \n",
    "    wspace=0.1, hspace_right=0.30,       # spacings\n",
    "    margins=(0.06, 0.98, 0.05, 0.99),    # outer margins (bottom, top, left, right)\n",
    "    tree_idx=0, max_depth_tree=3, max_depth_for_imp=20,\n",
    "    fontsize_tree=9, fontsize_hist=9, fontsize_imp=9,\n",
    "):\n",
    "    total_frac = right_top_frac + right_bottom_frac\n",
    "    if total_frac <= 0:\n",
    "        raise ValueError(\"right_top_frac + right_bottom_frac must be > 0\")\n",
    "    f_top = right_top_frac / total_frac\n",
    "    f_bot = right_bottom_frac / total_frac\n",
    "\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "\n",
    "    outer = gridspec.GridSpec(\n",
    "        nrows=1, ncols=2, figure=fig,\n",
    "        width_ratios=[left_width, right_width],\n",
    "        wspace=wspace\n",
    "    )\n",
    "\n",
    "    # Left panel\n",
    "    ax_left = fig.add_subplot(outer[0, 0])\n",
    "\n",
    "    # Right panel split into top/bottom\n",
    "    right = gridspec.GridSpecFromSubplotSpec(\n",
    "        2, 1, subplot_spec=outer[0, 1],\n",
    "        height_ratios=[f_top, f_bot],\n",
    "        hspace=hspace_right\n",
    "    )\n",
    "    ax_top    = fig.add_subplot(right[0, 0])  # histogram of avg leaf depth\n",
    "    ax_bottom = fig.add_subplot(right[1, 0])  # cumulative ΔGini\n",
    "\n",
    "    plot_sample_tree(\n",
    "        rf, tree_idx=tree_idx, max_depth=max_depth_tree,\n",
    "        feature_names=feature_names, example_x=example_x,\n",
    "        fontsize=fontsize_tree, ax=ax_left\n",
    "    )\n",
    "\n",
    "    plot_avg_leaf_depth_hist(\n",
    "        rf, X_val, y_val,\n",
    "        bins=20, fontsize=fontsize_hist, ax=ax_top\n",
    "    )\n",
    "\n",
    "    plot_cumulative_impurity(\n",
    "        rf, max_depth_for_imp=max_depth_for_imp,\n",
    "        coverages=(0.90, 0.95),\n",
    "        fontsize=fontsize_imp, ax=ax_bottom\n",
    "    )\n",
    "\n",
    "    b, t, l, r = margins\n",
    "    fig.subplots_adjust(left=l, right=r, bottom=b, top=t)\n",
    "\n",
    "    return fig, {\"left\": ax_left, \"top\": ax_top, \"bottom\": ax_bottom}\n",
    "\n",
    "\n",
    "try:\n",
    "    D = getattr(rf_final, \"n_features_in_\", np.asarray(X_val).shape[1])\n",
    "except Exception:\n",
    "    D = np.asarray(X_val).shape[1]\n",
    "feature_names = [f\"emb[{i}]\" for i in range(D)]\n",
    "example_x = np.asarray(X_val)[0] if X_val is not None and len(X_val) > 0 else None\n",
    "\n",
    "fig, axes = plot_rf_custom_layout(\n",
    "    rf_final,\n",
    "    feature_names=feature_names,\n",
    "    example_x=example_x,\n",
    "    X_val=X_val, y_val=y_val,    \n",
    "    figsize=(19.5, 7), dpi=300,\n",
    "    left_width=7.9, right_width=1.9,\n",
    "    right_top_frac=0.4, right_bottom_frac=0.4,\n",
    "    wspace=0.1, hspace_right=0.30,\n",
    "    margins=(0.06, 0.98, 0.05, 0.99),\n",
    "    tree_idx=0, max_depth_tree=3, max_depth_for_imp=20,\n",
    "    fontsize_tree=9, fontsize_hist=9, fontsize_imp=9,\n",
    ")\n",
    "\n",
    "plt.show()\n",
    "fig.savefig(\"RF.png\", dpi=600, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d9ecac4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import gridspec\n",
    "import re  \n",
    "\n",
    "def plot_rf_custom_layout(\n",
    "    rf,\n",
    "    feature_names=None,\n",
    "    example_x=None,\n",
    "    val_X=None,\n",
    "    val_y=None,\n",
    "    *,\n",
    "    figsize=(19.5, 12), dpi=900,\n",
    "    top_height=10.2, bottom_height=6.2,   # controls row heights\n",
    "    wspace_bottom=0.1, hspace_rows=0.1,\n",
    "    margins=(0.06, 0.98, 0.0, 0.99),    # (bottom, top, left, right)\n",
    "    tree_idx=0, max_depth_tree=3, max_depth_for_imp=20,\n",
    "    fontsize_tree=13, fontsize_hist=13, fontsize_imp=13,\n",
    "):\n",
    "\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "\n",
    "    outer = gridspec.GridSpec(\n",
    "        nrows=2, ncols=2, figure=fig,\n",
    "        height_ratios=[top_height, bottom_height],\n",
    "        wspace=0.0,         # no col spacing needed for the top row (spans both)\n",
    "        hspace=hspace_rows  # space between the two rows\n",
    "    )\n",
    "\n",
    "    ax_tree = fig.add_subplot(outer[0, :])\n",
    "\n",
    "       bottom = gridspec.GridSpecFromSubplotSpec(\n",
    "        1, 2, subplot_spec=outer[1, :], wspace=wspace_bottom\n",
    "    )\n",
    "    ax_hist   = fig.add_subplot(bottom[0, 0])\n",
    "    ax_imp    = fig.add_subplot(bottom[0, 1])\n",
    "\n",
    "    ax_hist.set_prop_cycle(plt.cycler(color=['#357ABD']))\n",
    "    ax_imp.set_prop_cycle(plt.cycler(color=['#357ABD']))\n",
    "\n",
    "    plot_sample_tree(\n",
    "        rf, tree_idx=tree_idx, max_depth=max_depth_tree,\n",
    "        feature_names=feature_names, example_x=example_x,\n",
    "        fontsize=fontsize_tree, ax=ax_tree\n",
    "    )\n",
    "\n",
    "    for txt in getattr(ax_tree, \"texts\", []):\n",
    "        original = txt.get_text()\n",
    "        def _to_pct(m):\n",
    "            val = float(m.group(0))\n",
    "            return f\"{val * 100:.1f}%\"\n",
    "        updated = re.sub(r\"0?\\.\\d+\", _to_pct, original)\n",
    "        if updated != original:\n",
    "            txt.set_text(updated)\n",
    "\n",
    "    plot_avg_leaf_depth_hist(\n",
    "        rf, val_X=val_X, val_y=val_y,\n",
    "        bins=20, fontsize=fontsize_hist, ax=ax_hist\n",
    "    )\n",
    "\n",
    "    plot_cumulative_impurity(\n",
    "        rf, max_depth_for_imp=max_depth_for_imp,\n",
    "        fontsize=fontsize_imp, ax=ax_imp\n",
    "    )\n",
    "\n",
    "    # Margins\n",
    "    b, t, l, r = margins\n",
    "    fig.subplots_adjust(left=l, right=r, bottom=b, top=t)\n",
    "\n",
    "    return fig, {\"tree\": ax_tree, \"hist\": ax_hist, \"imp\": ax_imp}\n",
    "\n",
    "D = getattr(rf_final, \"n_features_in_\", np.asarray(val_X).shape[1])\n",
    "feature_names = [f\"emb[{i}]\" for i in range(D)]\n",
    "example_x = np.asarray(val_X)[0] if val_X is not None else None\n",
    "\n",
    "fig, axes = plot_rf_custom_layout(\n",
    "    rf_final,\n",
    "    feature_names=feature_names,\n",
    "    example_x=example_x,\n",
    "    val_X=val_X, val_y=val_y,\n",
    "    figsize=(19.5, 12), dpi=300,\n",
    "    top_height=10.2, bottom_height=6.2, \n",
    "    wspace_bottom=0.1, hspace_rows=0.1,\n",
    "    margins=(0.06, 0.98, 0.02, 0.99),\n",
    "    tree_idx=0, max_depth_tree=3, max_depth_for_imp=20,\n",
    "    fontsize_tree=13, fontsize_hist=13, fontsize_imp=13,\n",
    ")\n",
    "plt.show()\n",
    "fig.savefig(\"RF.png\", dpi=900, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b0cdf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import RepeatedStratifiedKFold\n",
    "\n",
    "all_dataset = train_dataset + val_dataset\n",
    "\n",
    "X = np.array([d.x.mean(dim=0).cpu().numpy() for d in all_dataset])\n",
    "y = np.array([d.y.item() for d in all_dataset])\n",
    "\n",
    "groups = None\n",
    "\n",
    "rf_params = dict(\n",
    "    n_estimators=400,\n",
    "    random_state=42,\n",
    "    bootstrap=True,\n",
    "    oob_score=True,\n",
    "    n_jobs=-1,\n",
    "    max_depth=12,\n",
    "    min_samples_leaf=10,\n",
    "    min_samples_split=20,\n",
    "    max_features=55\n",
    ")\n",
    "\n",
    "accs, aucs = [], []\n",
    "\n",
    "if groups is None:\n",
    "    splitter = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=42)\n",
    "    splits = splitter.split(X, y)\n",
    "else:\n",
    "    try:\n",
    "        from sklearn.model_selection import StratifiedGroupKFold\n",
    "        sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)\n",
    "        splits = sgkf.split(X, y, groups=groups)\n",
    "    except Exception:\n",
    "        from sklearn.model_selection import GroupKFold\n",
    "        gkf = GroupKFold(n_splits=5)\n",
    "        splits = gkf.split(X, groups=groups)  \n",
    "for fold, (tr, va) in enumerate(splits, 1):\n",
    "    rf = RandomForestClassifier(**rf_params)\n",
    "    rf.fit(X[tr], y[tr])\n",
    "    y_pred  = rf.predict(X[va])\n",
    "    y_proba = rf.predict_proba(X[va])[:, 1]\n",
    "    acc = accuracy_score(y[va], y_pred)\n",
    "    auc = roc_auc_score(y[va], y_proba)\n",
    "    accs.append(acc); aucs.append(auc)\n",
    "    print(f\"Fold {fold:02d}: Acc={acc:.4f}  AUC={auc:.4f}\")\n",
    "\n",
    "print(\"\\n=== Cross-Validation Summary ({} folds) ===\".format(len(accs)))\n",
    "print(\"Accuracy:  mean={:.4f}  std={:.4f}\".format(np.mean(accs), np.std(accs)))\n",
    "print(\"ROC-AUC :  mean={:.4f}  std={:.4f}\".format(np.mean(aucs), np.std(aucs)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea9ea62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "importances = rf.feature_importances_\n",
    "plt.plot(importances)\n",
    "plt.title(\"Feature importances (Random Forest)\")\n",
    "plt.xlabel(\"Feature Index\")\n",
    "plt.ylabel(\"Importance\")\n",
    "plt.show()\n",
    "print(\"Top features:\", np.argsort(importances)[-10:][::-1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e5b050",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "all_node_embeddings = []\n",
    "all_labels = [] \n",
    "\n",
    "# For actual graphs\n",
    "for g in actual_data_list:\n",
    "    all_node_embeddings.append(g.x.cpu().numpy())\n",
    "    all_labels.extend([0] * g.x.shape[0])\n",
    "\n",
    "# For GPT-generated graphs\n",
    "for g in gpt_data_list:\n",
    "    all_node_embeddings.append(g.x.cpu().numpy())\n",
    "    all_labels.extend([1] * g.x.shape[0])\n",
    "\n",
    "all_node_embeddings = np.vstack(all_node_embeddings)\n",
    "all_labels = np.array(all_labels)\n",
    "\n",
    "print(f\"Total node embeddings: {all_node_embeddings.shape}\")  # (total_nodes, 3072)\n",
    "\n",
    "# Project all embeddings to 2D using PCA\n",
    "pca = PCA(n_components=2, random_state=42)\n",
    "embeddings_2d = pca.fit_transform(all_node_embeddings)\n",
    "\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.scatter(\n",
    "    embeddings_2d[all_labels == 0, 0], embeddings_2d[all_labels == 0, 1],\n",
    "    c='blue', alpha=0.3, label='Actual Graphs', s=10\n",
    ")\n",
    "plt.scatter(\n",
    "    embeddings_2d[all_labels == 1, 0], embeddings_2d[all_labels == 1, 1],\n",
    "    c='red', alpha=0.3, label='GPT Graphs', s=10\n",
    ")\n",
    "plt.xlabel(\"PCA 1\")\n",
    "plt.ylabel(\"PCA 2\")\n",
    "plt.legend()\n",
    "plt.title(\"All Node Embeddings (PCA 2D)\\nActual=Blue, GPT=Red\")\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcefd660",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "def get_embeddings_and_labels(dataset, actual_data_list, gpt_data_list):\n",
    "    all_node_embeddings = []\n",
    "    all_labels = []\n",
    "\n",
    "    actual_ids = set(id(g) for g in actual_data_list)\n",
    "    gpt_ids = set(id(g) for g in gpt_data_list)\n",
    "\n",
    "    for g in dataset:\n",
    "        if id(g) in actual_ids:\n",
    "            label = 0\n",
    "        elif id(g) in gpt_ids:\n",
    "            label = 1\n",
    "        else:\n",
    "            label = -1  \n",
    "        all_node_embeddings.append(g.x.cpu().numpy())\n",
    "        all_labels.extend([label] * g.x.shape[0])\n",
    "\n",
    "    all_node_embeddings = np.vstack(all_node_embeddings)\n",
    "    all_labels = np.array(all_labels)\n",
    "    return all_node_embeddings, all_labels\n",
    "\n",
    "splits = {\n",
    "    'Train': train_dataset,\n",
    "    'Validation': val_dataset,\n",
    "    'Test': test_dataset\n",
    "}\n",
    "\n",
    "for split_name, split_dataset in splits.items():\n",
    "    embeddings, labels = get_embeddings_and_labels(split_dataset, actual_data_list, gpt_data_list)\n",
    "\n",
    "    # Project to 2D using PCA\n",
    "    if len(embeddings) > 2:\n",
    "        pca = PCA(n_components=2, random_state=42)\n",
    "        embeddings_2d = pca.fit_transform(embeddings)\n",
    "    else:\n",
    "        embeddings_2d = embeddings  \n",
    "\n",
    "    # Plot\n",
    "    plt.figure(figsize=(8, 8))\n",
    "    plt.scatter(\n",
    "        embeddings_2d[labels == 0, 0], embeddings_2d[labels == 0, 1],\n",
    "        c='blue', alpha=0.3, label='Actual Graphs', s=10\n",
    "    )\n",
    "    plt.scatter(\n",
    "        embeddings_2d[labels == 1, 0], embeddings_2d[labels == 1, 1],\n",
    "        c='red', alpha=0.3, label='GPT Graphs', s=10\n",
    "    )\n",
    "    plt.xlabel(\"PCA 1\")\n",
    "    plt.ylabel(\"PCA 2\")\n",
    "    plt.legend()\n",
    "    plt.title(f\"{split_name} Node Embeddings (PCA 2D)\\nActual=Blue, GPT=Red\")\n",
    "    plt.grid(True)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "951da5e7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\n",
    "all_X = np.stack([data.x.mean(dim=0).numpy() for data in train_dataset + val_dataset + test_dataset])\n",
    "all_y = np.array([data.y.item() for data in train_dataset + val_dataset + test_dataset])\n",
    "\n",
    "pca = PCA(n_components=2, random_state=42)\n",
    "mean_X_pca = pca.fit_transform(all_X)\n",
    "plt.figure(figsize=(7,7))\n",
    "plt.scatter(mean_X_pca[all_y==0,0], mean_X_pca[all_y==0,1], c='blue', label='Actual', alpha=0.6)\n",
    "plt.scatter(mean_X_pca[all_y==1,0], mean_X_pca[all_y==1,1], c='red', label='GPT', alpha=0.6)\n",
    "plt.xlabel(\"PCA 1\")\n",
    "plt.ylabel(\"PCA 2\")\n",
    "plt.title(\"PCA of Mean-Pooled Graph Embeddings\")\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20913b15",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def mean_pooled_X_y(split_dataset):\n",
    "    all_X = np.stack([data.x.mean(dim=0).cpu().numpy() for data in split_dataset])\n",
    "    all_y = np.array([data.y.item() for data in split_dataset])\n",
    "    return all_X, all_y\n",
    "\n",
    "splits = {\n",
    "    'Train': train_dataset,\n",
    "    'Validation': val_dataset,\n",
    "    'Test': test_dataset\n",
    "}\n",
    "\n",
    "for split_name, split_dataset in splits.items():\n",
    "    all_X, all_y = mean_pooled_X_y(split_dataset)\n",
    "    \n",
    "    if len(all_X) > 2:\n",
    "        pca = PCA(n_components=2, random_state=42)\n",
    "        mean_X_pca = pca.fit_transform(all_X)\n",
    "    else:\n",
    "        mean_X_pca = all_X  \n",
    "\n",
    "    plt.figure(figsize=(7, 7))\n",
    "    plt.scatter(mean_X_pca[all_y==0,0], mean_X_pca[all_y==0,1], c='blue', label='Actual', alpha=0.6)\n",
    "    plt.scatter(mean_X_pca[all_y==1,0], mean_X_pca[all_y==1,1], c='red', label='GPT', alpha=0.6)\n",
    "    plt.xlabel(\"PCA 1\")\n",
    "    plt.ylabel(\"PCA 2\")\n",
    "    plt.title(f\"PCA of Mean-Pooled Graph Embeddings: {split_name} Split\")\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4186f929",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,4))\n",
    "plt.hist(mean_X_pca[all_y==0, 0], bins=40, alpha=0.5, color='blue', label='Actual', density=True)\n",
    "plt.hist(mean_X_pca[all_y==1, 0], bins=40, alpha=0.5, color='red', label='GPT', density=True)\n",
    "plt.xlabel(\"PCA Component 1\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.title(\"Distribution of First PCA Component (Mean-Pooled Embeddings)\")\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5996d187",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df = pd.DataFrame({\n",
    "    'pca1': mean_X_pca[:, 0],\n",
    "    'category': ['Actual' if y==0 else 'GPT' for y in all_y]\n",
    "})\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "sns.violinplot(x='category', y='pca1', data=df, palette=['blue', 'red'])\n",
    "plt.title(\"Violin Plot of First PCA Component\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11fc14a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "importances = rf.feature_importances_\n",
    "top_k = 30 \n",
    "\n",
    "indices = np.argsort(importances)[::-1][:top_k]\n",
    "\n",
    "print(\"Top important feature indices and their importances:\")\n",
    "for i in indices:\n",
    "    print(f\"Feature {i}: Importance = {importances[i]:.5f}\")\n",
    "\n",
    "plt.figure(figsize=(12,6))\n",
    "plt.bar(range(top_k), importances[indices])\n",
    "plt.xticks(range(top_k), indices, rotation=45)\n",
    "plt.title(\"Top 30 Feature Importances in Random Forest\")\n",
    "plt.xlabel(\"Feature Index\")\n",
    "plt.ylabel(\"Importance\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d9453f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25, 8))\n",
    "box = plt.boxplot(\n",
    "    data_to_plot,\n",
    "    labels=labels,\n",
    "    patch_artist=True,\n",
    "    showfliers=False\n",
    ")\n",
    "\n",
    "for i, patch in enumerate(box['boxes']):\n",
    "    if i % 2 == 0:\n",
    "        patch.set_facecolor('#4f8df7')  \n",
    "    else:\n",
    "        patch.set_facecolor('#f95d6a')  \n",
    "\n",
    "for i, median in enumerate(box['medians']):\n",
    "    if i % 2 == 0:\n",
    "        median.set(color='#0050b3', linewidth=2)  \n",
    "    else:\n",
    "        median.set(color='#b22222', linewidth=2) \n",
    "\n",
    "plt.title(\"Boxplot of Important Features (Blue: Class 0, Red: Class 1)\")\n",
    "plt.ylabel(\"Feature Value\")\n",
    "plt.xlabel(\"Feature Index and Class\")\n",
    "plt.xticks(rotation=90)\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6696acf7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
