{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f76502",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from torch_geometric.utils import from_networkx\n",
    "from torch_geometric.utils import to_networkx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc294c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = Path(\".\")\n",
    "merged_pickle_path = base_path / \"merged_graphs.pickle\"\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code - only load files you trust.\n",
    "with open(merged_pickle_path, 'rb') as handle:\n",
    "    merged_graphs = pickle.load(handle)\n",
    "\n",
    "# Explore the content\n",
    "print(\"Type of loaded object:\", type(merged_graphs))\n",
    "\n",
    "if isinstance(merged_graphs, dict):\n",
    "    print(\"Number of keys:\", len(merged_graphs))\n",
    "    print(\"Sample keys:\", list(merged_graphs.keys())[:5])\n",
    "    # Default of None avoids StopIteration on an empty dict.\n",
    "    first_item = next(iter(merged_graphs.values()), None)\n",
    "elif isinstance(merged_graphs, list):\n",
    "    print(\"Length of list:\", len(merged_graphs))\n",
    "    first_item = merged_graphs[0] if merged_graphs else None\n",
    "    print(\"Type of first item:\", type(first_item))\n",
    "else:\n",
    "    print(\"Loaded object is of unrecognized type\")\n",
    "    first_item = merged_graphs\n",
    "\n",
    "print(\"\\nDetails of first item:\")\n",
    "if hasattr(first_item, '__dict__'):\n",
    "    for attr, val in vars(first_item).items():\n",
    "        print(f\"{attr}: {type(val)}\")\n",
    "else:\n",
    "    print(first_item)\n",
    "\n",
    "# Inspect the graphs of one sample. This cell and every later cell assume\n",
    "# merged_graphs is a dict of {sample_id: {graph_type: networkx graph}}.\n",
    "SAMPLE_INDEX = 3  # clamped to the last sample when fewer are available\n",
    "assert isinstance(merged_graphs, dict) and merged_graphs, (\n",
    "    \"expected a non-empty dict of {sample_id: {graph_type: graph}}\"\n",
    ")\n",
    "sample_keys = list(merged_graphs.keys())\n",
    "sample_id = sample_keys[min(SAMPLE_INDEX, len(sample_keys) - 1)]\n",
    "sample_graphs = merged_graphs[sample_id]\n",
    "\n",
    "for key, graph in sample_graphs.items():\n",
    "    print(f\"\\n--- {key} ---\")\n",
    "    print(\"Type:\", type(graph))\n",
    "    print(\"Number of nodes:\", graph.number_of_nodes())\n",
    "    print(\"Number of edges:\", graph.number_of_edges())\n",
    "    print(\"Sample nodes:\", list(graph.nodes(data=True))[:3])\n",
    "    print(\"Sample edges:\", list(graph.edges(data=True))[:3])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26589536",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_node_features(G: nx.Graph, max_iter: int = 500) -> nx.Graph:\n",
    "    \"\"\"Attach a 5-value structural feature list to every node of ``G``.\n",
    "\n",
    "    Sets ``G.nodes[node]['x'] = [degree centrality, closeness centrality,\n",
    "    eigenvector centrality, clustering coefficient, number of edges in G]``.\n",
    "    The last entry is a graph-level value repeated on every node.\n",
    "\n",
    "    ``G`` is modified in place and also returned, so calls can be chained.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G : nx.Graph\n",
    "        Graph to annotate; the callers in this notebook pass an undirected copy.\n",
    "    max_iter : int, default 500\n",
    "        Iteration cap forwarded to the eigenvector-centrality solvers.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    nx.Graph\n",
    "        The same ``G`` object, now carrying an ``'x'`` attribute on each node.\n",
    "    \"\"\"\n",
    "    if G.number_of_nodes() == 0:\n",
    "        # Centralities are undefined on the null graph; nothing to annotate.\n",
    "        return G\n",
    "\n",
    "    deg_cent = nx.degree_centrality(G)\n",
    "    close_cent = nx.closeness_centrality(G)\n",
    "    try:\n",
    "        eigen_cent = nx.eigenvector_centrality_numpy(G, max_iter=max_iter)\n",
    "    except (nx.NetworkXException, TypeError, ValueError, RuntimeError):\n",
    "        # The sparse ARPACK solver cannot handle graphs with fewer than 3 nodes,\n",
    "        # and recent NetworkX releases raise AmbiguousSolution on disconnected\n",
    "        # graphs. Power iteration handles both and is also unit-L2-normalised.\n",
    "        eigen_cent = nx.eigenvector_centrality(G, max_iter=max_iter)\n",
    "    cluster_coef = nx.clustering(G)\n",
    "    num_edges = float(G.number_of_edges())  # graph-level, identical for all nodes\n",
    "\n",
    "    for node in G.nodes():\n",
    "        G.nodes[node]['x'] = [\n",
    "            deg_cent[node],\n",
    "            close_cent[node],\n",
    "            eigen_cent[node],\n",
    "            cluster_coef[node],\n",
    "            num_edges,\n",
    "        ]\n",
    "\n",
    "    return G\n",
    "\n",
    "# NOTE(review): later cells redefine convert_to_pyg_data with graph_type/graph_id\n",
    "# arguments, so this version only applies until the next cell runs.\n",
    "def convert_to_pyg_data(graph: nx.Graph):\n",
    "    \"\"\"Build a PyG ``Data`` object with structural node features from ``graph``.\n",
    "\n",
    "    An undirected copy of ``graph`` is annotated by ``extract_node_features``\n",
    "    and converted with ``from_networkx``; ``x`` is then set to a float tensor\n",
    "    holding each node's feature list, in node-iteration order.\n",
    "    \"\"\"\n",
    "    undirected = extract_node_features(graph.to_undirected())\n",
    "    data = from_networkx(undirected)\n",
    "    feature_rows = [attrs['x'] for _, attrs in undirected.nodes(data=True)]\n",
    "    data.x = torch.tensor(feature_rows, dtype=torch.float)\n",
    "    return data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9efb98e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the merged file (random, ground-truth and GPT graphs) into a dataset containing only the GPT and ground-truth graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b5b7ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redefines convert_to_pyg_data from the feature cell, adding graph metadata.\n",
    "def convert_to_pyg_data(graph: nx.Graph, graph_type: str, graph_id: int):\n",
    "    \"\"\"Convert one NetworkX graph into a PyG ``Data`` object tagged with metadata.\n",
    "\n",
    "    The graph is made undirected, annotated by ``extract_node_features`` and\n",
    "    relabelled to consecutive integers before conversion, so ``x`` rows follow\n",
    "    node order 0..n-1. ``graph_type`` and ``graph_id`` are stored verbatim;\n",
    "    ``graph_id`` is the merged_graphs key (annotated int - confirm for your data).\n",
    "    \"\"\"\n",
    "    graph = graph.to_undirected()\n",
    "    graph = extract_node_features(graph)\n",
    "    graph = nx.convert_node_labels_to_integers(graph)\n",
    "\n",
    "    pyg_data = from_networkx(graph)\n",
    "    pyg_data.x = torch.tensor([graph.nodes[n]['x'] for n in graph.nodes()], dtype=torch.float)\n",
    "\n",
    "    # Store the type verbatim so no other graph type can be mislabelled as GPT-generated.\n",
    "    pyg_data.graph_type = graph_type\n",
    "    pyg_data.graph_id = graph_id\n",
    "    return pyg_data\n",
    "\n",
    "# Binary task: ground-truth graphs -> label 0, GPT-generated graphs -> label 1.\n",
    "# Single source of truth for both the labels and the summary printed below.\n",
    "LABELS = {'groundtruth_graph': 0, 'gpt_generated_graph': 1}\n",
    "\n",
    "filtered_data_list = []\n",
    "\n",
    "for sample_id, graph_dict in merged_graphs.items():\n",
    "    for graph_type, label in LABELS.items():\n",
    "        if graph_type in graph_dict:\n",
    "            pyg_data = convert_to_pyg_data(graph_dict[graph_type], graph_type, graph_id=sample_id)\n",
    "            pyg_data.y = torch.tensor([label], dtype=torch.long)\n",
    "            filtered_data_list.append(pyg_data)\n",
    "\n",
    "# Fail early with a readable message rather than inside collate() when nothing matched.\n",
    "assert filtered_data_list, \"No groundtruth_graph / gpt_generated_graph entries found in merged_graphs\"\n",
    "\n",
    "print(f\"Kept {len(filtered_data_list)} graphs (removed 'random_graph' category).\")\n",
    "print(\"Labels assigned: \" + \", \".join(f\"{name} → {label}\" for name, label in LABELS.items()))\n",
    "\n",
    "class MyGraphDataset(InMemoryDataset):\n",
    "    \"\"\"Thin InMemoryDataset that collates an already-built list of PyG Data objects.\"\"\"\n",
    "\n",
    "    def __init__(self, data_list=None):\n",
    "        super().__init__('.')\n",
    "        if data_list is not None:\n",
    "            self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "filtered_dataset = MyGraphDataset(filtered_data_list)\n",
    "\n",
    "# Save to file.\n",
    "# NOTE(review): this pickles the MyGraphDataset instance itself, so loading it\n",
    "# needs the class definition in scope and, on torch>=2.6, weights_only=False.\n",
    "new_path = base_path / \"my_graph_dataset_gpt_groundtruth.pt\"\n",
    "torch.save(filtered_dataset, new_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef28e95d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the merged file into a dataset containing only the random and ground-truth graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c093cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redefines convert_to_pyg_data (same behaviour as the previous dataset cell).\n",
    "def convert_to_pyg_data(graph: nx.Graph, graph_type: str, graph_id: int):\n",
    "    \"\"\"Convert one NetworkX graph into a PyG ``Data`` object tagged with metadata.\n",
    "\n",
    "    The graph is made undirected, annotated by ``extract_node_features`` and\n",
    "    relabelled to consecutive integers before conversion, so ``x`` rows follow\n",
    "    node order 0..n-1. ``graph_type`` and ``graph_id`` are stored verbatim;\n",
    "    ``graph_id`` is the merged_graphs key (annotated int - confirm for your data).\n",
    "    \"\"\"\n",
    "    graph = graph.to_undirected()\n",
    "    graph = extract_node_features(graph)\n",
    "    graph = nx.convert_node_labels_to_integers(graph)\n",
    "\n",
    "    pyg_data = from_networkx(graph)\n",
    "    pyg_data.x = torch.tensor([graph.nodes[n]['x'] for n in graph.nodes()], dtype=torch.float)\n",
    "\n",
    "    # Assign metadata\n",
    "    pyg_data.graph_type = graph_type\n",
    "    pyg_data.graph_id = graph_id\n",
    "    return pyg_data\n",
    "\n",
    "# Binary task: ground-truth graphs -> label 0, random graphs -> label 1.\n",
    "# Single source of truth for both the labels and the summary printed below.\n",
    "LABELS = {'groundtruth_graph': 0, 'random_graph': 1}\n",
    "\n",
    "filtered_data_list = []\n",
    "\n",
    "for sample_id, graph_dict in merged_graphs.items():\n",
    "    for graph_type, label in LABELS.items():\n",
    "        if graph_type in graph_dict:\n",
    "            # graph_type is already recorded on the Data object by convert_to_pyg_data.\n",
    "            pyg_data = convert_to_pyg_data(graph_dict[graph_type], graph_type, graph_id=sample_id)\n",
    "            pyg_data.y = torch.tensor([label], dtype=torch.long)\n",
    "            filtered_data_list.append(pyg_data)\n",
    "\n",
    "# Fail early with a readable message rather than inside collate() when nothing matched.\n",
    "assert filtered_data_list, \"No groundtruth_graph / random_graph entries found in merged_graphs\"\n",
    "\n",
    "print(f\"Kept {len(filtered_data_list)} graphs (removed 'gpt_generated_graph' category).\")\n",
    "print(\"Labels assigned: \" + \", \".join(f\"{name} → {label}\" for name, label in LABELS.items()))\n",
    "\n",
    "class MyGraphDataset(InMemoryDataset):\n",
    "    \"\"\"Thin InMemoryDataset that collates an already-built list of PyG Data objects.\"\"\"\n",
    "\n",
    "    def __init__(self, data_list=None):\n",
    "        super().__init__('.')\n",
    "        if data_list is not None:\n",
    "            self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "filtered_dataset = MyGraphDataset(filtered_data_list)\n",
    "\n",
    "# Save to file.\n",
    "# NOTE(review): this pickles the MyGraphDataset instance itself, so loading it\n",
    "# needs the class definition in scope and, on torch>=2.6, weights_only=False.\n",
    "new_path = base_path / \"my_graph_dataset_groundtruth_random.pt\"\n",
    "torch.save(filtered_dataset, new_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6852e8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the merged file into a dataset containing only the random and GPT graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "100781d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redefines convert_to_pyg_data (same behaviour as the previous dataset cells).\n",
    "def convert_to_pyg_data(graph: nx.Graph, graph_type: str, graph_id: int):\n",
    "    \"\"\"Convert one NetworkX graph into a PyG ``Data`` object tagged with metadata.\n",
    "\n",
    "    The graph is made undirected, annotated by ``extract_node_features`` and\n",
    "    relabelled to consecutive integers before conversion, so ``x`` rows follow\n",
    "    node order 0..n-1. ``graph_type`` and ``graph_id`` are stored verbatim;\n",
    "    ``graph_id`` is the merged_graphs key (annotated int - confirm for your data).\n",
    "    \"\"\"\n",
    "    graph = graph.to_undirected()\n",
    "    graph = extract_node_features(graph)\n",
    "    graph = nx.convert_node_labels_to_integers(graph)\n",
    "\n",
    "    pyg_data = from_networkx(graph)\n",
    "    pyg_data.x = torch.tensor([graph.nodes[n]['x'] for n in graph.nodes()], dtype=torch.float)\n",
    "\n",
    "    # Metadata\n",
    "    pyg_data.graph_type = graph_type\n",
    "    pyg_data.graph_id = graph_id\n",
    "    return pyg_data\n",
    "\n",
    "# Binary task: GPT-generated graphs -> label 0, random graphs -> label 1.\n",
    "# Single source of truth for both the labels and the summary printed below.\n",
    "LABELS = {'gpt_generated_graph': 0, 'random_graph': 1}\n",
    "\n",
    "filtered_data_list = []\n",
    "\n",
    "for sample_id, graph_dict in merged_graphs.items():\n",
    "    for graph_type, label in LABELS.items():\n",
    "        if graph_type in graph_dict:\n",
    "            # graph_type is already recorded on the Data object by convert_to_pyg_data.\n",
    "            pyg_data = convert_to_pyg_data(graph_dict[graph_type], graph_type, graph_id=sample_id)\n",
    "            pyg_data.y = torch.tensor([label], dtype=torch.long)\n",
    "            filtered_data_list.append(pyg_data)\n",
    "\n",
    "# Fail early with a readable message rather than inside collate() when nothing matched.\n",
    "assert filtered_data_list, \"No gpt_generated_graph / random_graph entries found in merged_graphs\"\n",
    "\n",
    "print(f\"Kept {len(filtered_data_list)} graphs (removed 'groundtruth_graph' category).\")\n",
    "print(\"Labels assigned: \" + \", \".join(f\"{name} → {label}\" for name, label in LABELS.items()))\n",
    "\n",
    "class MyGraphDataset(InMemoryDataset):\n",
    "    \"\"\"Thin InMemoryDataset that collates an already-built list of PyG Data objects.\"\"\"\n",
    "\n",
    "    def __init__(self, data_list=None):\n",
    "        super().__init__('.')\n",
    "        if data_list is not None:\n",
    "            self.data, self.slices = self.collate(data_list)\n",
    "\n",
    "filtered_dataset = MyGraphDataset(filtered_data_list)\n",
    "\n",
    "# Save to file.\n",
    "# NOTE(review): this pickles the MyGraphDataset instance itself, so loading it\n",
    "# needs the class definition in scope and, on torch>=2.6, weights_only=False.\n",
    "new_path = base_path / \"my_graph_dataset_gpt_random.pt\"\n",
    "torch.save(filtered_dataset, new_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc55934e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
