{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c8c2ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import networkx as nx\n",
    "from networkx import from_dict_of_lists, subgraph\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np \n",
    "import seaborn as sns \n",
    "import matplotlib.gridspec as gridspec\n",
    "import pickle \n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import random\n",
    "from scipy.sparse import coo_matrix, csr_matrix, csc_matrix\n",
    "import gc\n",
    "from tqdm import tqdm\n",
    "import csv\n",
    "from pathlib import Path\n",
    "\n",
    "# Record library versions (pandas, networkx, matplotlib, numpy, seaborn) so the\n",
    "# pickled outputs written below can be traced back to the environment that produced them.\n",
    "for _module in (pd, nx, plt.matplotlib, np, sns):\n",
    "    print(_module.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61c736e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# All input/output paths are relative to the notebook's working directory.\n",
    "base_path = Path(\".\")\n",
    "\n",
    "# Paper IDs excluded from the analysis (CSV with a 'paper_id' column).\n",
    "removed_ids_path = base_path / \"removed_paperIDs.csv\"\n",
    "# newline='' is what the csv module expects; utf-8-sig also strips a leading BOM that\n",
    "# would otherwise corrupt the 'paper_id' header and make row['paper_id'] raise KeyError.\n",
    "with open(removed_ids_path, mode='r', newline='', encoding='utf-8-sig') as csv_file:\n",
    "    removed_paper_ids = {row['paper_id'].strip() for row in csv.DictReader(csv_file)}\n",
    "\n",
    "# Load original graph data.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.\n",
    "file_path = base_path / \"graphs_random_v2.pickle\"\n",
    "with open(file_path, \"rb\") as handle:\n",
    "    graphs = pickle.load(handle)\n",
    "    \n",
    "# Define color map and graph types: node colors stored in the pickled graphs are\n",
    "# matched against entries of seaborn's 'colorblind' palette.\n",
    "palette = sns.color_palette('colorblind')\n",
    "color_map = dict(\n",
    "    green=palette[2],\n",
    "    yellow=palette[8],\n",
    "    blue=palette[0],\n",
    "    orange=palette[3],\n",
    "    grey=palette[7],\n",
    ")\n",
    "# Colors whose nodes belong to the GPT-generated graph ...\n",
    "gpt_colors = [color_map[name] for name in ('green', 'yellow', 'blue', 'orange')]\n",
    "# ... and to the ground-truth graph.\n",
    "groundtruth_colors = [color_map[name] for name in ('green', 'grey', 'blue')]\n",
    "blue_color = tuple(color_map['blue'])\n",
    "\n",
    "subgraphs = {}\n",
    "filtered_out_paper_ids = []\n",
    "# GPT citation targets: nodes of these colors get an edge from every blue node.\n",
    "citation_colors = (color_map['yellow'], color_map['orange'])\n",
    "\n",
    "for paper_id, graph_data in graphs.items():\n",
    "    paper_id = str(paper_id).strip()\n",
    "    if paper_id in removed_paper_ids:\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        nodes = graph_data['nodes']\n",
    "        edges = graph_data['edges']\n",
    "        node_colors = graph_data['node_color']\n",
    "\n",
    "        # (node id, color tuple) pairs, converted once and reused by every filter below.\n",
    "        colored_nodes = [(int(node), tuple(color)) for node, color in zip(nodes, node_colors)]\n",
    "        node_color_map = dict(colored_nodes)\n",
    "\n",
    "        # Build full graph\n",
    "        G = nx.DiGraph()\n",
    "        G.add_nodes_from([int(node) for node in nodes])\n",
    "        G.add_edges_from([(int(u), int(v)) for u, v in edges])\n",
    "\n",
    "        # Extract ground truth subgraph\n",
    "        gt_nodes = [node for node, color in colored_nodes if color in groundtruth_colors]\n",
    "        groundtruth_graph = G.subgraph(gt_nodes).copy()\n",
    "\n",
    "        # Extract GPT-generated subgraph\n",
    "        gpt_nodes = [node for node, color in colored_nodes if color in gpt_colors]\n",
    "        gpt_generated_graph = G.subgraph(gpt_nodes).copy()\n",
    "\n",
    "        # Link every blue node to every yellow/orange node in the GPT graph, using two\n",
    "        # pre-filtered lists instead of rescanning all nodes for each blue node.\n",
    "        # NOTE: only the node count of gpt_generated_graph is used below, so these edges\n",
    "        # do not influence the saved output.\n",
    "        blue_nodes = [node for node, color in colored_nodes if color == blue_color]\n",
    "        cited_nodes = [node for node, color in colored_nodes if color in citation_colors]\n",
    "        gpt_generated_graph.add_edges_from(\n",
    "            (source, target) for source in blue_nodes for target in cited_nodes\n",
    "        )\n",
    "\n",
    "        # Balance node counts by randomly dropping non-blue nodes from the larger graph.\n",
    "        gt_nodes_set = set(groundtruth_graph.nodes)\n",
    "        gpt_nodes_set = set(gpt_generated_graph.nodes)\n",
    "\n",
    "        if len(gt_nodes_set) != len(gpt_nodes_set):\n",
    "            # Fresh RNG seeded per paper: same draw sequence as random.seed(42) followed by\n",
    "            # random.choice, but without resetting the global `random` module state.\n",
    "            rng = random.Random(42)\n",
    "            gt_non_blue = [n for n in gt_nodes_set if node_color_map.get(n) != blue_color]\n",
    "            gpt_non_blue = [n for n in gpt_nodes_set if node_color_map.get(n) != blue_color]\n",
    "\n",
    "            while len(gt_nodes_set) > len(gpt_nodes_set) and gt_non_blue:\n",
    "                removed = rng.choice(gt_non_blue)\n",
    "                groundtruth_graph.remove_node(removed)\n",
    "                gt_nodes_set.remove(removed)\n",
    "                gt_non_blue.remove(removed)\n",
    "\n",
    "            while len(gpt_nodes_set) > len(gt_nodes_set) and gpt_non_blue:\n",
    "                removed = rng.choice(gpt_non_blue)\n",
    "                gpt_generated_graph.remove_node(removed)\n",
    "                gpt_nodes_set.remove(removed)\n",
    "                gpt_non_blue.remove(removed)\n",
    "\n",
    "            # Could not equalise (ran out of non-blue nodes to drop): exclude this paper.\n",
    "            if len(gt_nodes_set) != len(gpt_nodes_set):\n",
    "                filtered_out_paper_ids.append(paper_id)\n",
    "                continue\n",
    "\n",
    "        # Only the (balanced) ground-truth graph is kept, under the 'Random' key.\n",
    "        subgraphs[paper_id] = {'Random': groundtruth_graph}\n",
    "\n",
    "    except Exception as e:\n",
    "        # Malformed entries (missing keys, non-int ids, ...) are reported and skipped.\n",
    "        print(f\"Error processing {paper_id}: {e}\")\n",
    "        filtered_out_paper_ids.append(paper_id)\n",
    "\n",
    "# Persist the balanced ground-truth ('Random') graphs, keyed by paper-ID string.\n",
    "output_path = base_path / \"random_graphs_filtered_v2.pickle\"\n",
    "with output_path.open(\"wb\") as handle:\n",
    "    pickle.dump(subgraphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4951195",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the GPT-generated and ground-truth graphs with the random baseline graphs.\n",
    "\n",
    "# Filtered GPT + ground-truth graphs.\n",
    "filtered_graphs_path = base_path / \"updated_subgraphs_v2.pickle\"\n",
    "with filtered_graphs_path.open(\"rb\") as f:\n",
    "    filtered_graphs = pickle.load(f)\n",
    "\n",
    "# Random graphs were saved with string keys; normalise them to int for lookups.\n",
    "random_graphs_path = base_path / \"random_graphs_filtered_v2.pickle\"\n",
    "with random_graphs_path.open(\"rb\") as f:\n",
    "    raw_random_graphs = pickle.load(f)\n",
    "random_graphs = dict((int(key), value) for key, value in raw_random_graphs.items())\n",
    "\n",
    "# Original full graphs; in this cell they are only used to recover per-node colors.\n",
    "graphs_path = base_path / \"graphs_v2.pickle\"\n",
    "with graphs_path.open(\"rb\") as f:\n",
    "    full_graphs = pickle.load(f)\n",
    "\n",
    "# Merge and enrich with node color attributes\n",
    "merged_graphs = {}\n",
    "papers_without_colors = []\n",
    "\n",
    "for paper_id in filtered_graphs:\n",
    "    paper_id_int = int(paper_id)\n",
    "    paper_graphs = filtered_graphs[paper_id]\n",
    "\n",
    "    gt_graph = paper_graphs.get('groundtruth_graph')\n",
    "    gpt_graph = paper_graphs.get('gpt_generated_graph')\n",
    "    random_graph = random_graphs.get(paper_id_int, {}).get('Random')\n",
    "\n",
    "    # Build node color map from original full graph. The key type of graphs_v2.pickle\n",
    "    # is not checked anywhere and other pickles here are keyed by str, so fall back to\n",
    "    # the str key instead of silently leaving this paper uncolored.\n",
    "    full_graph_data = full_graphs.get(paper_id_int)\n",
    "    if full_graph_data is None:\n",
    "        full_graph_data = full_graphs.get(str(paper_id_int))\n",
    "\n",
    "    if full_graph_data:\n",
    "        node_color_map = {\n",
    "            int(node): tuple(color)\n",
    "            for node, color in zip(full_graph_data['nodes'], full_graph_data['node_color'])\n",
    "        }\n",
    "    else:\n",
    "        node_color_map = {}\n",
    "        papers_without_colors.append(paper_id_int)\n",
    "\n",
    "    # Attach color to each graph's nodes (mutates the loaded graph objects in place)\n",
    "    for graph in (gt_graph, gpt_graph, random_graph):\n",
    "        if graph is None:\n",
    "            continue\n",
    "        for node in graph.nodes:\n",
    "            if node in node_color_map:\n",
    "                graph.nodes[node]['color'] = node_color_map[node]\n",
    "\n",
    "    merged_graphs[paper_id_int] = {\n",
    "        'groundtruth_graph': gt_graph,\n",
    "        'gpt_generated_graph': gpt_graph,\n",
    "        'random_graph': random_graph\n",
    "    }\n",
    "\n",
    "# Surface papers whose nodes could not be colored instead of failing silently.\n",
    "if papers_without_colors:\n",
    "    print(f\"No color data found for {len(papers_without_colors)} paper(s), e.g. {papers_without_colors[:10]}\")\n",
    "\n",
    "# Persist the merged graphs, keyed by int paper ID.\n",
    "merged_path = base_path / \"merged_graphs.pickle\"\n",
    "with merged_path.open(\"wb\") as f:\n",
    "    pickle.dump(merged_graphs, f, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c02f311",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load merged graphs. NOTE: this cell overwrites the same file at the end, so on a\n",
    "# re-run it reads already-cleaned graphs rather than the output of the merge cell.\n",
    "merged_path = base_path / \"merged_graphs.pickle\"\n",
    "with merged_path.open(\"rb\") as f:\n",
    "    merged_graphs = pickle.load(f)\n",
    "\n",
    "def clean_graph(g):\n",
    "    \"\"\"Return an undirected, cleaned copy of ``g``, or None if ``g`` is None.\n",
    "\n",
    "    The returned graph has self-loops removed, is restricted to its largest\n",
    "    connected component, and has any 'color' node attribute stripped.\n",
    "    An empty graph is returned as an empty undirected graph instead of being\n",
    "    passed to nx.is_connected, which raises NetworkXPointlessConcept on it.\n",
    "    \"\"\"\n",
    "    if g is None:\n",
    "        return None\n",
    "\n",
    "    # Convert to undirected (edge direction is discarded)\n",
    "    G = nx.Graph(g)\n",
    "\n",
    "    # Materialise the self-loops first so the graph is not mutated while the\n",
    "    # selfloop_edges generator is still iterating over it.\n",
    "    G.remove_edges_from(list(nx.selfloop_edges(G)))\n",
    "\n",
    "    # Keep only the largest connected component (ties go to the first one found).\n",
    "    if G.number_of_nodes() > 0 and not nx.is_connected(G):\n",
    "        largest_component_nodes = max(nx.connected_components(G), key=len)\n",
    "        G = G.subgraph(largest_component_nodes).copy()\n",
    "\n",
    "    # Remove color attribute from nodes\n",
    "    for _, node_attrs in G.nodes(data=True):\n",
    "        node_attrs.pop('color', None)\n",
    "\n",
    "    return G\n",
    "\n",
    "# Process all graphs: every entry keeps the same three keys as merged_graphs.\n",
    "GRAPH_KEYS = ('groundtruth_graph', 'gpt_generated_graph', 'random_graph')\n",
    "cleaned_graphs = {}\n",
    "\n",
    "# Loop variable is `paper_graphs` (not `graphs`) so the raw-graph dict loaded in the\n",
    "# filtering cell is not silently overwritten in the kernel namespace.\n",
    "for paper_id, paper_graphs in merged_graphs.items():\n",
    "    cleaned_graphs[paper_id] = {key: clean_graph(paper_graphs.get(key)) for key in GRAPH_KEYS}\n",
    "\n",
    "# NOTE: this overwrites the merged_graphs.pickle read at the top of this cell, so the\n",
    "# node color attributes are gone from disk until the merge cell is re-run.\n",
    "cleaned_path = base_path / \"merged_graphs.pickle\"\n",
    "with cleaned_path.open(\"wb\") as f:\n",
    "    pickle.dump(cleaned_graphs, f, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
