{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e79e8ea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# import torch\n",
    "# from torch_geometric.data import InMemoryDataset\n",
    "\n",
    "# class MyGraphDataset(InMemoryDataset):\n",
    "#     def __init__(self, data_list=None):\n",
    "#         super().__init__('.')\n",
    "#         if data_list is not None:\n",
    "#             self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "# paths ---\n",
    "# orig_ds_path = base_path / \"my_graph_dataset_groundtruth_random.pt\"\n",
    "# csv_path     = base_path / \"small_graph_ids.csv\"\n",
    "# out_path     = base_path / \"my_graph_dataset_groundtruth_randomfiltered.pt\"\n",
    "\n",
    "# # load original dataset ---\n",
    "# dataset = torch.load(orig_ds_path)\n",
    "\n",
    "# # read IDs to remove ---\n",
    "# ids_to_remove = set(pd.read_csv(csv_path)['graph_id'].tolist())\n",
    "\n",
    "# #  filter out any Data objects whose .graph_id is in that set ---\n",
    "# filtered_list = [data for data in dataset \n",
    "#                  if getattr(data, 'graph_id', None) not in ids_to_remove]\n",
    "\n",
    "# #  wrap back into your dataset class and save ---\n",
    "# filtered_ds = MyGraphDataset(filtered_list)\n",
    "# torch.save(filtered_ds, out_path)\n",
    "\n",
    "# print(f\"Original size: {len(dataset)}\")\n",
    "# print(f\"Filtered size: {len(filtered_ds)}\")\n",
    "# print(f\"Saved filtered dataset to {out_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9850b9d0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_networkx\n",
    "import torch\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "\n",
    "class MyGraphDataset(InMemoryDataset):\n",
    "    def __init__(self, data_list=None):\n",
    "        super().__init__('.')\n",
    "        if data_list is not None:\n",
    "            self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "new_path = base_path / \"my_graph_dataset_groundtruth_random.pt\"\n",
    "dataset = torch.load(new_path)\n",
    "\n",
    "graph_types = [data.graph_type for data in dataset]\n",
    "type_counts = Counter(graph_types)\n",
    "\n",
    "print(\"\\nGraph categories in the dataset:\")\n",
    "for category, count in type_counts.items():\n",
    "    print(f\"  {category}: {count}\")\n",
    "print(f\"Total graphs: {len(dataset)}\")\n",
    "\n",
    "sample = dataset[401]\n",
    "print(f\"Feature shape of graph: {sample.x.shape}\")\n",
    "print(f\"\\nGraph category (type): {sample.graph_type}\")\n",
    "print(f\"Graph ID: {sample.graph_id}\")\n",
    "\n",
    "if hasattr(sample, 'y'):\n",
    "    print(f\"Graph Label (y): {sample.y.item()}\")\n",
    "else:\n",
    "    print(\"Graph Label (y): Not assigned\")\n",
    "\n",
    "print(\"\\nNode features:\")\n",
    "for i, feat in enumerate(sample.x):\n",
    "    print(f\"Node {i}: {feat.tolist()}\")\n",
    "\n",
    "# G = to_networkx(sample, to_undirected=True)\n",
    "# plt.figure(figsize=(6, 6))\n",
    "# nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500)\n",
    "# plt.title(\"Sample Graph from Loaded Dataset\")\n",
    "# plt.show()\n",
    "\n",
    "sample = dataset[400]\n",
    "print(f\"Feature shape of graph: {sample.x.shape}\")\n",
    "print(f\"\\nGraph category (type): {sample.graph_type}\")\n",
    "print(f\"Graph ID: {sample.graph_id}\")\n",
    "\n",
    "if hasattr(sample, 'y'):\n",
    "    print(f\"Graph Label (y): {sample.y.item()}\")\n",
    "else:\n",
    "    print(\"Graph Label (y): Not assigned\")\n",
    "\n",
    "print(\"\\nNode features:\")\n",
    "for i, feat in enumerate(sample.x):\n",
    "    print(f\"Node {i}: {feat.tolist()}\")\n",
    "\n",
    "# G = to_networkx(sample, to_undirected=True)\n",
    "# plt.figure(figsize=(6, 6))\n",
    "# nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500)\n",
    "# plt.title(\"Sample Graph from Loaded Dataset\")\n",
    "# plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f20aa2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from torch_geometric.loader import DataLoader \n",
    "class MyGraphDataset(InMemoryDataset):\n",
    "    def __init__(self, data_list=None):\n",
    "        super().__init__('.')\n",
    "        if data_list is not None:\n",
    "            self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "\n",
    "#Group all graphs by their original graph_id\n",
    "graph_id_groups = defaultdict(list)\n",
    "for data in dataset:\n",
    "    graph_id_groups[data.graph_id].append(data)\n",
    "\n",
    "all_graph_ids = list(graph_id_groups.keys())\n",
    "\n",
    "random_seed = 42\n",
    "\n",
    "# Split graph_ids into train (70%) and temp (30%)\n",
    "train_ids, temp_ids = train_test_split(\n",
    "    all_graph_ids,\n",
    "    test_size=0.3,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "# Split temp into validation (15%) and test (15%)\n",
    "val_ids, test_ids = train_test_split(\n",
    "    temp_ids,\n",
    "    test_size=0.5,\n",
    "    random_state=random_seed\n",
    ")\n",
    "\n",
    "# Helper to collect your Data objects back into lists\n",
    "def collect_by_ids(graph_ids):\n",
    "    return [data for gid in graph_ids for data in graph_id_groups[gid]]\n",
    "\n",
    "train_dataset = collect_by_ids(train_ids)\n",
    "val_dataset   = collect_by_ids(val_ids)\n",
    "test_dataset  = collect_by_ids(test_ids)\n",
    "\n",
    "# Assign simple graph_index and drop the original graph_id\n",
    "def clean_and_index(dataset_split):\n",
    "    for i, data in enumerate(dataset_split):\n",
    "        data.graph_index = i              # new 0…N-1 index\n",
    "        if hasattr(data, 'graph_id'):     # remove leak source\n",
    "            del data.graph_id\n",
    "        if hasattr(data, 'graph_type'):   # remove type attribute\n",
    "            del data.graph_type\n",
    "            \n",
    "clean_and_index(train_dataset)\n",
    "clean_and_index(val_dataset)\n",
    "clean_and_index(test_dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a2bd679",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def inspect_first_n(dataset_split, split_name, n=3):\n",
    "    \"\"\"\n",
    "    Print keys, shapes, labels, and a small sample of node features\n",
    "    and edges for the first n Data objects in the split.\n",
    "    \"\"\"\n",
    "    print(f\"\\n--- First {n} samples from {split_name.upper()} ---\")\n",
    "    for i, data in enumerate(dataset_split[:n]):\n",
    "        print(f\"\\nSample #{i} in {split_name}:\")\n",
    "        # print attribute list\n",
    "        print(\"  Keys            :\", data.keys())\n",
    "        # print shapes\n",
    "        print(\"  x.shape         :\", tuple(data.x.shape))\n",
    "        print(\"  edge_index.shape:\", tuple(data.edge_index.shape))\n",
    "        print(\"  y (label)       :\", data.y.item())\n",
    "        print(\"  graph_index     :\", data.graph_index)\n",
    "        \n",
    "        # print first 5 node-features rows\n",
    "        num_nf = min(50, data.x.size(0))\n",
    "        print(f\"  Sample node features (first {num_nf} rows):\")\n",
    "        print(data.x[:num_nf])\n",
    "        \n",
    "        # print first 5 edges as (source, target) tuples\n",
    "        num_e = min(50, data.edge_index.size(1))\n",
    "        edges = data.edge_index[:, :num_e].t().tolist()\n",
    "        print(f\"  Sample edges (first {num_e}):\", edges)\n",
    "\n",
    "# Call for each split\n",
    "inspect_first_n(train_dataset, \"train\", n=3)\n",
    "inspect_first_n(val_dataset,   \"val\",   n=3)\n",
    "inspect_first_n(test_dataset,  \"test\",  n=3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7227c7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class0_feats = []\n",
    "class1_feats = []\n",
    "\n",
    "for data in train_dataset:\n",
    "    feat_mean = data.x.mean(dim=0).cpu().numpy()\n",
    "    if data.y.item() == 0:\n",
    "        class0_feats.append(feat_mean)\n",
    "    else:\n",
    "        class1_feats.append(feat_mean)\n",
    "\n",
    "class0_feats = np.vstack(class0_feats)\n",
    "class1_feats = np.vstack(class1_feats)\n",
    "\n",
    "print(\"Class 0 mean:\", class0_feats.mean(axis=0))\n",
    "print(\"Class 1 mean:\", class1_feats.mean(axis=0))\n",
    "print(\"Class 0 std:\", class0_feats.std(axis=0))\n",
    "print(\"Class 1 std:\", class1_feats.std(axis=0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "006bc152",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = class0_feats.shape[1]\n",
    "plt.figure(figsize=(num_features*3, 5))\n",
    "\n",
    "for i in range(num_features):\n",
    "    plt.subplot(1, num_features, i+1)\n",
    "    plt.boxplot([class0_feats[:, i], class1_feats[:, i]], labels=['Class 0', 'Class 1'])\n",
    "    plt.title(f'Feature {i}')\n",
    "    plt.ylabel('Feature value')\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1478c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = class0_feats.shape[1]\n",
    "plt.figure(figsize=(num_features*3, 5))\n",
    "\n",
    "for i in range(num_features):\n",
    "    plt.subplot(1, num_features, i+1)\n",
    "    plt.hist(class0_feats[:, i], bins=15, alpha=0.5, label='Class 0')\n",
    "    plt.hist(class1_feats[:, i], bins=15, alpha=0.5, label='Class 1')\n",
    "    plt.title(f'Feature {i}')\n",
    "    plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efeef12d",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.vstack([class0_feats, class1_feats])\n",
    "y = np.array([0]*len(class0_feats) + [1]*len(class1_feats))\n",
    "\n",
    "pca = PCA(n_components=2)\n",
    "X_pca = pca.fit_transform(X)\n",
    "\n",
    "plt.figure(figsize=(6,6))\n",
    "plt.scatter(X_pca[y==0, 0], X_pca[y==0, 1], label='Class 0', alpha=0.6)\n",
    "plt.scatter(X_pca[y==1, 0], X_pca[y==1, 1], label='Class 1', alpha=0.6)\n",
    "plt.title('PCA of Graph Mean Features')\n",
    "plt.xlabel('PC1')\n",
    "plt.ylabel('PC2')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50c4c247",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# For all graphs in train/val set:\n",
    "train_X = []\n",
    "train_y = []\n",
    "for data in train_dataset:\n",
    "    # Mean of all features across nodes\n",
    "    train_X.append(data.x.mean(dim=0).numpy())\n",
    "    train_y.append(data.y.item())\n",
    "\n",
    "val_X = []\n",
    "val_y = []\n",
    "for data in val_dataset:\n",
    "    val_X.append(data.x.mean(dim=0).numpy())\n",
    "    val_y.append(data.y.item())\n",
    "\n",
    "clf = LogisticRegression().fit(train_X, train_y)\n",
    "pred = clf.predict(val_X)\n",
    "print(\"Val accuracy (all features):\", accuracy_score(val_y, pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d93ba315",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Prepare data\n",
    "train_X = [data.x.mean(dim=0).numpy() for data in train_dataset]\n",
    "train_y = [data.y.item() for data in train_dataset]\n",
    "val_X   = [data.x.mean(dim=0).numpy() for data in val_dataset]\n",
    "val_y   = [data.y.item() for data in val_dataset]\n",
    "\n",
    "# Train Random Forest\n",
    "rf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "rf.fit(train_X, train_y)\n",
    "pred = rf.predict(val_X)\n",
    "print(\"Val accuracy (Random Forest, mean features):\", accuracy_score(val_y, pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01327c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "all_acc = []\n",
    "all_f1_macro = []\n",
    "all_f1_weighted = []\n",
    "\n",
    "N_RUNS = 10  \n",
    "\n",
    "all_graph_ids = list(set([data.graph_id for data in dataset]))\n",
    "graph_id_groups = defaultdict(list)\n",
    "for data in dataset:\n",
    "    graph_id_groups[data.graph_id].append(data)\n",
    "\n",
    "def collect_by_ids(graph_ids):\n",
    "    return [data for gid in graph_ids for data in graph_id_groups[gid]]\n",
    "\n",
    "def clean_and_index(dataset_split):\n",
    "    for i, data in enumerate(dataset_split):\n",
    "        data.graph_index = i\n",
    "        if hasattr(data, 'graph_id'):\n",
    "            del data.graph_id\n",
    "        if hasattr(data, 'graph_type'):\n",
    "            del data.graph_type\n",
    "\n",
    "for run_seed in range(N_RUNS):\n",
    "    # Split\n",
    "    train_ids, temp_ids = train_test_split(\n",
    "        all_graph_ids, test_size=0.3, random_state=run_seed\n",
    "    )\n",
    "    val_ids, test_ids = train_test_split(\n",
    "        temp_ids, test_size=0.5, random_state=run_seed\n",
    "    )\n",
    "\n",
    "    train_dataset = collect_by_ids(train_ids)\n",
    "    val_dataset   = collect_by_ids(val_ids)\n",
    "    test_dataset  = collect_by_ids(test_ids)\n",
    "\n",
    "    clean_and_index(train_dataset)\n",
    "    clean_and_index(val_dataset)\n",
    "    clean_and_index(test_dataset)\n",
    "\n",
    "    # Prepare data (mean node features)\n",
    "    train_X = [data.x.mean(dim=0).numpy() for data in train_dataset]\n",
    "    train_y = [data.y.item() for data in train_dataset]\n",
    "    val_X   = [data.x.mean(dim=0).numpy() for data in val_dataset]\n",
    "    val_y   = [data.y.item() for data in val_dataset]\n",
    "\n",
    "    rf = RandomForestClassifier(n_estimators=100, random_state=run_seed)\n",
    "    rf.fit(train_X, train_y)\n",
    "    pred = rf.predict(val_X)\n",
    "\n",
    "    acc = accuracy_score(val_y, pred)\n",
    "    f1_macro = f1_score(val_y, pred, average='macro')\n",
    "    f1_weighted = f1_score(val_y, pred, average='weighted')\n",
    "\n",
    "    all_acc.append(acc)\n",
    "    all_f1_macro.append(f1_macro)\n",
    "    all_f1_weighted.append(f1_weighted)\n",
    "\n",
    "print(f\"Random Forest (mean node features) - {N_RUNS} runs\")\n",
    "print(\"Val accuracy   : %.4f ± %.4f\" % (np.mean(all_acc), np.std(all_acc)))\n",
    "print(\"F1-macro       : %.4f ± %.4f\" % (np.mean(all_f1_macro), np.std(all_f1_macro)))\n",
    "print(\"F1-weighted    : %.4f ± %.4f\" % (np.mean(all_f1_weighted), np.std(all_f1_weighted)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f62e288",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
