{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70eaaf03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import networkx as nx\n",
    "from networkx import from_dict_of_lists, subgraph\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np \n",
    "import seaborn as sns \n",
    "import matplotlib.gridspec as gridspec\n",
    "import pickle \n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import random\n",
    "from scipy.stats import mannwhitneyu\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import random\n",
    "import csv\n",
    "from pathlib import Path\n",
    "\n",
    "print(pd.__version__)\n",
    "print(nx.__version__)\n",
    "print(plt.matplotlib.__version__)\n",
    "print(np.__version__)\n",
    "print(sns.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5e95ae",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Define base path as current folder\n",
    "base_path = Path(\".\")\n",
    "\n",
    "file_path = base_path / \"graphs_v2.pickle\"\n",
    "\n",
    "try:\n",
    "    with open(file_path, \"rb\") as handle:\n",
    "        graphs = pickle.load(handle)\n",
    "except FileNotFoundError:\n",
    "    pass\n",
    "\n",
    "    \n",
    "#seperate graphs to groung truth and GPT subgraphs\n",
    "\n",
    "# Define the color map directly for known color names\n",
    "color_map = {\n",
    "    'green': sns.color_palette('colorblind')[2],\n",
    "    'yellow': sns.color_palette('colorblind')[8],\n",
    "    'blue': sns.color_palette('colorblind')[0],\n",
    "    'orange': sns.color_palette('colorblind')[3],\n",
    "    'grey': sns.color_palette('colorblind')[7]\n",
    "}\n",
    "\n",
    "# Define the colors for the subgraphs\n",
    "gpt_colors = [color_map['green'], color_map['yellow'], color_map['blue'], color_map['orange']]\n",
    "groundtruth_colors = [color_map['green'], color_map['grey'], color_map['blue']]\n",
    "\n",
    "subgraphs = {}\n",
    "filtered_out_paper_ids = []\n",
    "for paper_id, graph_data in graphs.items():\n",
    "    try:\n",
    "        nodes = graph_data['nodes']\n",
    "        edges = graph_data['edges']\n",
    "        node_colors = graph_data['node_color']\n",
    "        \n",
    "        # Create NetworkX graph\n",
    "        G = nx.DiGraph()\n",
    "        G.add_nodes_from(nodes)\n",
    "        G.add_edges_from(edges)\n",
    "        \n",
    "        # Create groundtruth_graph subgraph\n",
    "        groundtruth_graph_nodes = [node for node, color in zip(nodes, node_colors) if tuple(color) in groundtruth_colors]\n",
    "        groundtruth_graph = G.subgraph(groundtruth_graph_nodes).copy()\n",
    "        \n",
    "        # Create gpt_generated_graph subgraph\n",
    "        gpt_graph_nodes = [node for node, color in zip(nodes, node_colors) if tuple(color) in gpt_colors]\n",
    "        gpt_generated_graph = G.subgraph(gpt_graph_nodes).copy()\n",
    "        \n",
    "        # Add edges between blue and yellow/orange nodes\n",
    "        for node, color in zip(nodes, node_colors):\n",
    "            if tuple(color) == color_map['blue']:\n",
    "                for other_node, other_color in zip(nodes, node_colors):\n",
    "                    if tuple(other_color) == color_map['yellow']:\n",
    "                        gpt_generated_graph.add_edge(node, other_node)\n",
    "                    elif tuple(other_color) == color_map['orange']:\n",
    "                        gpt_generated_graph.add_edge(node, other_node)\n",
    "        \n",
    "        #Remove if either graph has fewer than 3 nodes\n",
    "        if len(groundtruth_graph.nodes) < 3 or len(gpt_generated_graph.nodes) < 3:\n",
    "            filtered_out_paper_ids.append(paper_id)\n",
    "            continue\n",
    "            \n",
    "        # Store the subgraphs in the dictionary\n",
    "        subgraphs[paper_id] = {\n",
    "            'groundtruth_graph': groundtruth_graph,\n",
    "            'gpt_generated_graph': gpt_generated_graph\n",
    "        }\n",
    "    except Exception as e:\n",
    "        filtered_out_paper_ids.append(paper_id)\n",
    "\n",
    "        \n",
    "\n",
    "def clean_graph(graph):\n",
    "    \"\"\"Remove self-loops and keep largest connected component.\"\"\"\n",
    "    graph.remove_edges_from(nx.selfloop_edges(graph))\n",
    "    if isinstance(graph, nx.DiGraph):\n",
    "        components = list(nx.weakly_connected_components(graph))\n",
    "    else:\n",
    "        components = list(nx.connected_components(graph))\n",
    "    if components:\n",
    "        largest_cc = max(components, key=len)\n",
    "        graph = graph.subgraph(largest_cc).copy()\n",
    "    return graph\n",
    "\n",
    "# Save the filtered subgraphs\n",
    "output_main_path = base_path / \"filtered_subgraphs_v2.pickle\"\n",
    "with open(output_main_path, \"wb\") as handle:\n",
    "    pickle.dump(subgraphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "# Save the removed paper IDs\n",
    "filtered_ids_csv_path = base_path / \"removed_paperIDs.csv\"\n",
    "with open(filtered_ids_csv_path, mode=\"w\", newline=\"\") as csv_file:\n",
    "    writer = csv.writer(csv_file)\n",
    "    writer.writerow([\"paper_id\"])\n",
    "    for paper_id in filtered_out_paper_ids:\n",
    "        writer.writerow([str(paper_id).strip()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8e22a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the original full graph structure \n",
    "raphs_path = base_path / \"graphs_v2.pickle\"\n",
    "with open(graphs_path, \"rb\") as f:\n",
    "    graphs = pickle.load(f)\n",
    "# Define the color map directly for known color names\n",
    "color_map = {\n",
    "    'green': sns.color_palette('colorblind')[2],\n",
    "    'yellow': sns.color_palette('colorblind')[8],\n",
    "    'blue': sns.color_palette('colorblind')[0],\n",
    "    'orange': sns.color_palette('colorblind')[3],\n",
    "    'grey': sns.color_palette('colorblind')[7]\n",
    "}\n",
    "\n",
    "# Define the colors for the subgraphs\n",
    "gpt_colors = [color_map['green'], color_map['yellow'], color_map['blue'], color_map['orange']]\n",
    "groundtruth_colors = [color_map['green'], color_map['grey'], color_map['blue']]\n",
    "\n",
    "output_main_path = base_path / \"filtered_subgraphs_v2.pickle\"\n",
    "with open(output_main_path, \"wb\") as handle:\n",
    "    pickle.dump(subgraphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    \n",
    "# Function to balance the nodes of two graphs while preserving blue nodes\n",
    "def balance_graph_nodes_and_record(subgraphs, graphs, color_map):\n",
    "    removed_nodes_list = [] \n",
    "    \n",
    "    for paper_id, subgraph_data in subgraphs.items():\n",
    "        groundtruth_graph = subgraph_data['groundtruth_graph']\n",
    "        gpt_graph = subgraph_data['gpt_generated_graph']\n",
    "\n",
    "        groundtruth_nodes = set(groundtruth_graph.nodes)\n",
    "        gpt_nodes = set(gpt_graph.nodes)\n",
    "\n",
    "        # Skip if both graphs already have the same number of nodes\n",
    "        if len(groundtruth_nodes) == len(gpt_nodes):\n",
    "            continue\n",
    "\n",
    "        # Identify non-blue nodes that can be removed\n",
    "        blue_color = tuple(color_map['blue'])\n",
    "\n",
    "        groundtruth_non_blue = [\n",
    "            node for node in groundtruth_nodes\n",
    "            if tuple(graphs[paper_id]['node_color'][graphs[paper_id]['nodes'].index(node)]) != blue_color\n",
    "        ]\n",
    "        gpt_non_blue = [\n",
    "            node for node in gpt_nodes\n",
    "            if tuple(graphs[paper_id]['node_color'][graphs[paper_id]['nodes'].index(node)]) != blue_color\n",
    "        ]\n",
    "\n",
    "        # Balance the number of nodes in both subgraphs\n",
    "        while len(groundtruth_nodes) > len(gpt_nodes) and groundtruth_non_blue:\n",
    "            removed_node = random.choice(groundtruth_non_blue)\n",
    "            groundtruth_non_blue.remove(removed_node)\n",
    "            groundtruth_nodes.remove(removed_node)\n",
    "            groundtruth_graph.remove_node(removed_node)  \n",
    "            removed_nodes_list.append((paper_id, removed_node, None))  \n",
    "\n",
    "        while len(gpt_nodes) > len(groundtruth_nodes) and gpt_non_blue:\n",
    "            removed_node = random.choice(gpt_non_blue)\n",
    "            gpt_non_blue.remove(removed_node)\n",
    "            gpt_nodes.remove(removed_node)\n",
    "            gpt_graph.remove_node(removed_node)  \n",
    "            removed_nodes_list.append((paper_id, None, removed_node)) \n",
    "\n",
    "    return subgraphs, removed_nodes_list\n",
    "\n",
    "# Apply the balancing function\n",
    "balanced_subgraphs, removed_nodes_list = balance_graph_nodes_and_record(subgraphs, graphs, color_map)\n",
    "\n",
    "# Save the updated subgraphs to a new pickle file\n",
    "new_pickle_path = base_path / \"updated_subgraphs_v2.pickle\"\n",
    "with open(new_pickle_path, \"wb\") as handle:\n",
    "    pickle.dump(balanced_subgraphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "# Save the removed nodes to a CSV file\n",
    "csv_path = base_path / \"removed_nodes_v2.csv\"\n",
    "removed_nodes_df = pd.DataFrame(\n",
    "    removed_nodes_list,\n",
    "    columns=[\"paper_id\", \"removed_node_groundtruth_id\", \"removed_node_gpt_id\"]\n",
    ")\n",
    "removed_nodes_df.to_csv(csv_path, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
