{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c8c2ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import gc\n",
    "import json\n",
    "import pickle\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "from networkx import from_dict_of_lists, subgraph\n",
    "from scipy.sparse import coo_matrix, csr_matrix, csc_matrix\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Report the installed versions of the core libraries, one per line, in the\n",
    "# order: pandas, networkx, matplotlib, numpy, seaborn.\n",
    "for library in (pd, nx, plt.matplotlib, np, sns):\n",
    "    print(library.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61c736e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Resolve every data file relative to the notebook's working directory.\n",
    "# Defined in this cell (not only in a later one) so the cell runs on a fresh kernel.\n",
    "base_path = Path(\".\")\n",
    "\n",
    "# Load removed paper IDs\n",
    "removed_ids_path = base_path / \"removed paperIDs.csv\"\n",
    "removed_paper_ids = set()\n",
    "\n",
    "# newline=\"\" is how the csv module expects its input files to be opened.\n",
    "with open(removed_ids_path, mode=\"r\", newline=\"\") as csv_file:\n",
    "    reader = csv.DictReader(csv_file)\n",
    "    for row in reader:\n",
    "        removed_paper_ids.add(row[\"paper_id\"].strip())\n",
    "\n",
    "# Load original graph data\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
    "file_path = base_path / \"graphs_random_subfield_v2.pickle\"\n",
    "with open(file_path, \"rb\") as handle:\n",
    "    graphs = pickle.load(handle)\n",
    "\n",
    "\n",
    "# Define color map and graph types \n",
    "color_map = {\n",
    "    'green': sns.color_palette('colorblind')[2],\n",
    "    'yellow': sns.color_palette('colorblind')[8],\n",
    "    'blue': sns.color_palette('colorblind')[0],\n",
    "    'orange': sns.color_palette('colorblind')[3],\n",
    "    'grey': sns.color_palette('colorblind')[7]\n",
    "}\n",
    "gpt_colors = [color_map['green'], color_map['yellow'], color_map['blue'], color_map['orange']]\n",
    "groundtruth_colors = [color_map['green'], color_map['grey'], color_map['blue']]\n",
    "blue_color = tuple(color_map['blue'])\n",
    "# Colors of the nodes that receive an edge from every blue node in the GPT graph.\n",
    "citation_target_colors = [color_map['yellow'], color_map['orange']]\n",
    "\n",
    "subgraphs = {}\n",
    "filtered_out_paper_ids = []\n",
    "\n",
    "for paper_id, graph_data in graphs.items():\n",
    "    paper_id = str(paper_id).strip()\n",
    "    if paper_id in removed_paper_ids:\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        nodes = graph_data['nodes']\n",
    "        edges = graph_data['edges']\n",
    "        node_colors = graph_data['node_color']\n",
    "\n",
    "        node_color_map = {\n",
    "            int(node): tuple(color) for node, color in zip(nodes, node_colors)\n",
    "        }\n",
    "\n",
    "        # Build full graph\n",
    "        G = nx.DiGraph()\n",
    "        G.add_nodes_from([int(node) for node in nodes])\n",
    "        G.add_edges_from([(int(u), int(v)) for u, v in edges])\n",
    "\n",
    "        # Extract ground truth subgraph\n",
    "        gt_nodes = [int(node) for node, color in zip(nodes, node_colors)\n",
    "                    if tuple(color) in groundtruth_colors]\n",
    "        groundtruth_graph = G.subgraph(gt_nodes).copy()\n",
    "\n",
    "        # Extract GPT-generated subgraph\n",
    "        gpt_nodes = [int(node) for node, color in zip(nodes, node_colors)\n",
    "                     if tuple(color) in gpt_colors]\n",
    "        gpt_generated_graph = G.subgraph(gpt_nodes).copy()\n",
    "\n",
    "        # Add an edge from every blue node to every yellow/orange node.\n",
    "        # Both node lists are collected once instead of rescanning all nodes\n",
    "        # for each blue node; the resulting edges and their order are the same.\n",
    "        blue_nodes = [int(node) for node, color in zip(nodes, node_colors)\n",
    "                      if tuple(color) == color_map['blue']]\n",
    "        target_nodes = [int(node) for node, color in zip(nodes, node_colors)\n",
    "                        if tuple(color) in citation_target_colors]\n",
    "        gpt_generated_graph.add_edges_from(\n",
    "            (blue_node, target_node)\n",
    "            for blue_node in blue_nodes\n",
    "            for target_node in target_nodes\n",
    "        )\n",
    "\n",
    "        # Balance node counts by dropping random non-blue nodes from the larger graph\n",
    "        gt_nodes_set = set(groundtruth_graph.nodes)\n",
    "        gpt_nodes_set = set(gpt_generated_graph.nodes)\n",
    "\n",
    "        if len(gt_nodes_set) != len(gpt_nodes_set):\n",
    "            # Re-seeded for every paper, so each paper's removal sequence does not\n",
    "            # depend on how many random draws earlier papers consumed.\n",
    "            random.seed(42)\n",
    "            gt_non_blue = [n for n in gt_nodes_set if node_color_map.get(n) != blue_color]\n",
    "            gpt_non_blue = [n for n in gpt_nodes_set if node_color_map.get(n) != blue_color]\n",
    "\n",
    "            while len(gt_nodes_set) > len(gpt_nodes_set) and gt_non_blue:\n",
    "                removed = random.choice(gt_non_blue)\n",
    "                groundtruth_graph.remove_node(removed)\n",
    "                gt_nodes_set.remove(removed)\n",
    "                gt_non_blue.remove(removed)\n",
    "\n",
    "            while len(gpt_nodes_set) > len(gt_nodes_set) and gpt_non_blue:\n",
    "                removed = random.choice(gpt_non_blue)\n",
    "                gpt_generated_graph.remove_node(removed)\n",
    "                gpt_nodes_set.remove(removed)\n",
    "                gpt_non_blue.remove(removed)\n",
    "\n",
    "            # Still unequal: one side ran out of removable (non-blue) nodes.\n",
    "            if len(gt_nodes_set) != len(gpt_nodes_set):\n",
    "                filtered_out_paper_ids.append(paper_id)\n",
    "                continue\n",
    "\n",
    "        # Only the balanced ground-truth-colored subgraph is stored (under the key\n",
    "        # 'Random'); the GPT graph built above is used solely for node-count balancing.\n",
    "        subgraphs[paper_id] = {\n",
    "            'Random': groundtruth_graph\n",
    "        }\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failing paper, record it, and keep going.\n",
    "        print(f\"Error processing {paper_id}: {e}\")\n",
    "        filtered_out_paper_ids.append(paper_id)\n",
    "\n",
    "\n",
    "# Save the result \n",
    "output_path = base_path / \"random_subfield_graphs_filtered_v2.pickle\"\n",
    "with open(output_path, \"wb\") as handle:\n",
    "    pickle.dump(subgraphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4951195",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the ground-truth / GPT graphs with the random graphs, one entry per paper.\n",
    "\n",
    "# All inputs and outputs are resolved against the current folder.\n",
    "base_path = Path(\".\")\n",
    "\n",
    "# Ground-truth and GPT graphs, keyed by paper id\n",
    "filtered_graphs_path = base_path / \"updated_subgraphs_v2.pickle\"\n",
    "with open(filtered_graphs_path, \"rb\") as f:\n",
    "    filtered_graphs = pickle.load(f)\n",
    "\n",
    "# Random graphs; their keys are normalised to int so lookups below match\n",
    "random_graphs_path = base_path / \"random_subfield_graphs_filtered_v2.pickle\"\n",
    "with open(random_graphs_path, \"rb\") as f:\n",
    "    raw_random_graphs = pickle.load(f)\n",
    "random_graphs = {}\n",
    "for raw_key, random_entry in raw_random_graphs.items():\n",
    "    random_graphs[int(raw_key)] = random_entry\n",
    "\n",
    "# Original graph records, used only to look up node colors\n",
    "graphs_path = base_path / \"graphs_v2.pickle\"\n",
    "with open(graphs_path, \"rb\") as f:\n",
    "    full_graphs = pickle.load(f)\n",
    "\n",
    "# Merge the three graph kinds and tag nodes with their color\n",
    "merged_graphs = {}\n",
    "\n",
    "for paper_id, graph_bundle in filtered_graphs.items():\n",
    "    paper_id_int = int(paper_id)\n",
    "\n",
    "    gt_graph = graph_bundle.get('groundtruth_graph')\n",
    "    gpt_graph = graph_bundle.get('gpt_generated_graph')\n",
    "    random_graph = random_graphs.get(paper_id_int, {}).get('Random')\n",
    "\n",
    "    # node id -> color tuple, taken from the original record when one exists\n",
    "    node_color_map = {}\n",
    "    full_graph_data = full_graphs.get(paper_id_int)\n",
    "    if full_graph_data:\n",
    "        node_color_map = dict(\n",
    "            (int(node), tuple(color))\n",
    "            for node, color in zip(full_graph_data['nodes'], full_graph_data['node_color'])\n",
    "        )\n",
    "\n",
    "    # Store the color as a node attribute on every available graph\n",
    "    for graph in (gt_graph, gpt_graph, random_graph):\n",
    "        if graph is None:\n",
    "            continue\n",
    "        for node in graph.nodes:\n",
    "            if node not in node_color_map:\n",
    "                continue\n",
    "            graph.nodes[node]['color'] = node_color_map[node]\n",
    "\n",
    "    merged_graphs[paper_id_int] = dict(\n",
    "        groundtruth_graph=gt_graph,\n",
    "        gpt_generated_graph=gpt_graph,\n",
    "        random_graph=random_graph,\n",
    "    )\n",
    "\n",
    "# Persist the merged, color-annotated graphs\n",
    "merged_path = base_path / \"merged_subfield_graphs.pickle\"\n",
    "with open(merged_path, \"wb\") as f:\n",
    "    pickle.dump(merged_graphs, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c02f311",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Define base path as current folder (assigned before its first use in this cell)\n",
    "base_path = Path(\".\")\n",
    "\n",
    "# Load merged graphs \n",
    "merged_path = base_path / \"merged_subfield_graphs.pickle\"\n",
    "with open(merged_path, \"rb\") as f:\n",
    "    merged_graphs = pickle.load(f)\n",
    "\n",
    "\n",
    "def clean_graph(g):\n",
    "    \"\"\"Return a cleaned, undirected copy of ``g``.\n",
    "\n",
    "    The copy has its self-loops removed, is reduced to its largest connected\n",
    "    component when it is not connected, and carries no ``color`` node attribute.\n",
    "\n",
    "    Args:\n",
    "        g: A networkx graph, or ``None``.\n",
    "\n",
    "    Returns:\n",
    "        ``None`` when ``g`` is ``None``; otherwise a new ``nx.Graph``. A graph\n",
    "        without nodes is returned as an empty ``nx.Graph``.\n",
    "    \"\"\"\n",
    "    if g is None:\n",
    "        return None\n",
    "\n",
    "    # Convert to undirected\n",
    "    G = nx.Graph(g)\n",
    "\n",
    "    # Remove self-loops (materialised first so no edge is removed while the\n",
    "    # self-loop generator is still being consumed)\n",
    "    G.remove_edges_from(list(nx.selfloop_edges(G)))\n",
    "\n",
    "    # nx.is_connected raises NetworkXPointlessConcept on a graph with no nodes,\n",
    "    # so that case is returned before the connectivity check.\n",
    "    if G.number_of_nodes() == 0:\n",
    "        return G\n",
    "\n",
    "    if not nx.is_connected(G):\n",
    "        largest_component_nodes = max(nx.connected_components(G), key=len)\n",
    "        G = G.subgraph(largest_component_nodes).copy()\n",
    "\n",
    "    # Remove color attribute from nodes\n",
    "    for node in G.nodes:\n",
    "        G.nodes[node].pop('color', None)\n",
    "\n",
    "    return G\n",
    "\n",
    "#  Process all graphs \n",
    "cleaned_graphs = {}\n",
    "\n",
    "for paper_id, graphs in merged_graphs.items():\n",
    "    gt_graph = clean_graph(graphs.get('groundtruth_graph'))\n",
    "    gpt_graph = clean_graph(graphs.get('gpt_generated_graph'))\n",
    "    random_graph = clean_graph(graphs.get('random_graph'))\n",
    "\n",
    "    cleaned_graphs[paper_id] = {\n",
    "        'groundtruth_graph': gt_graph,\n",
    "        'gpt_generated_graph': gpt_graph,\n",
    "        'random_graph': random_graph\n",
    "    }\n",
    "\n",
    "# Save cleaned graphs (this overwrites the merged file that was read above)\n",
    "cleaned_path = base_path / \"merged_subfield_graphs.pickle\"\n",
    "with open(cleaned_path, \"wb\") as f:\n",
    "    pickle.dump(cleaned_graphs, f, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040b1080",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_pickle_path = base_path / \"merged_subfield_graphs.pickle\"\n",
    "\n",
    "with open(merged_pickle_path, 'rb') as handle:\n",
    "    merged_graphs = pickle.load(handle)\n",
    "\n",
    "# Explore the content\n",
    "print(\"Type of loaded object:\", type(merged_graphs))\n",
    "\n",
    "# first_item falls back to None for an empty container instead of raising.\n",
    "if isinstance(merged_graphs, dict):\n",
    "    print(\"Number of keys:\", len(merged_graphs))\n",
    "    print(\"Sample keys:\", list(merged_graphs.keys())[:5])\n",
    "    first_item = next(iter(merged_graphs.values()), None)\n",
    "elif isinstance(merged_graphs, list):\n",
    "    print(\"Length of list:\", len(merged_graphs))\n",
    "    if merged_graphs:\n",
    "        print(\"Type of first item:\", type(merged_graphs[0]))\n",
    "        first_item = merged_graphs[0]\n",
    "    else:\n",
    "        first_item = None\n",
    "else:\n",
    "    print(\"Loaded object is of unrecognized type\")\n",
    "    first_item = merged_graphs\n",
    "\n",
    "print(\"\\nDetails of first item:\")\n",
    "if hasattr(first_item, '__dict__'):\n",
    "    for attr, val in vars(first_item).items():\n",
    "        print(f\"{attr}: {type(val)}\")\n",
    "else:\n",
    "    print(first_item)\n",
    "\n",
    "# Inspect one sample in detail: the entry at index 3 when it exists, otherwise\n",
    "# the last available entry. Skipped when the object is not a non-empty dict.\n",
    "sample_keys = list(merged_graphs.keys()) if isinstance(merged_graphs, dict) else []\n",
    "if sample_keys:\n",
    "    sample_id = sample_keys[min(3, len(sample_keys) - 1)]\n",
    "    sample_graphs = merged_graphs[sample_id]\n",
    "\n",
    "    for key, graph in sample_graphs.items():\n",
    "        print(f\"\\n--- {key} ---\")\n",
    "        if graph is None:\n",
    "            # A graph kind may be stored as None for a sample; nothing to inspect.\n",
    "            print(\"Graph is None\")\n",
    "            continue\n",
    "        print(\"Type:\", type(graph))\n",
    "        print(\"Number of nodes:\", graph.number_of_nodes())\n",
    "        print(\"Number of edges:\", graph.number_of_edges())\n",
    "        print(\"Sample nodes:\", list(graph.nodes(data=True))[:3])\n",
    "        print(\"Sample edges:\", list(graph.edges(data=True))[:3])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "769b1337",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from collections import Counter\n",
    "\n",
    "# Top-level structure of the loaded object\n",
    "print(f\"Type: {type(merged_graphs)}\")\n",
    "print(f\"Number of samples: {len(merged_graphs)}\")\n",
    "\n",
    "# Show the stored graph kinds for the first three samples only\n",
    "for position, (sample_id, graphs_dict) in enumerate(merged_graphs.items()):\n",
    "    if position >= 3:\n",
    "        break\n",
    "    print(f\"\\nSample ID: {sample_id}\")\n",
    "    print(f\"Type of graphs_dict: {type(graphs_dict)}\")\n",
    "    for gtype, graph in graphs_dict.items():\n",
    "        print(f\"  Graph type: {gtype}, Type: {type(graph)}\")\n",
    "\n",
    "# Number of samples that have a non-None graph of each kind\n",
    "graph_type_counts = Counter(\n",
    "    gtype\n",
    "    for graphs_dict in merged_graphs.values()\n",
    "    for gtype in (\"random_graph\", \"groundtruth_graph\", \"gpt_generated_graph\")\n",
    "    if graphs_dict.get(gtype) is not None\n",
    ")\n",
    "\n",
    "# Report the tallies\n",
    "print(\"Graph counts (non-None only):\")\n",
    "for gtype, count in graph_type_counts.items():\n",
    "    print(f\"{gtype}: {count}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f42343",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Centrality Computation Functions \n",
    "def calculate_centrality_measures(G):\n",
    "    \"\"\"Compute degree, betweenness, closeness and eigenvector centrality for ``G``.\n",
    "\n",
    "    Returns a dict with the keys ``degree_centrality``, ``betweenness_centrality``,\n",
    "    ``closeness_centrality`` and ``eigenvector_centrality``, each holding the\n",
    "    per-node mapping returned by networkx. Every entry is an empty dict when\n",
    "    ``G`` is ``None`` or has no nodes; the eigenvector entry is also an empty\n",
    "    dict when networkx raises a ``NetworkXException`` while computing it.\n",
    "    \"\"\"\n",
    "    measure_names = (\n",
    "        'degree_centrality',\n",
    "        'betweenness_centrality',\n",
    "        'closeness_centrality',\n",
    "        'eigenvector_centrality',\n",
    "    )\n",
    "    if G is None or G.number_of_nodes() == 0:\n",
    "        return {name: {} for name in measure_names}\n",
    "\n",
    "    centralities = {}\n",
    "    centralities['degree_centrality'] = nx.degree_centrality(G)\n",
    "    centralities['betweenness_centrality'] = nx.betweenness_centrality(G)\n",
    "    centralities['closeness_centrality'] = nx.closeness_centrality(G)\n",
    "\n",
    "    # Eigenvector centrality may fail (up to 1000 iterations are allowed);\n",
    "    # fall back to an empty mapping in that case.\n",
    "    try:\n",
    "        eigenvector = nx.eigenvector_centrality(G, max_iter=1000)\n",
    "    except nx.NetworkXException:\n",
    "        eigenvector = {}\n",
    "    centralities['eigenvector_centrality'] = eigenvector\n",
    "\n",
    "    return centralities\n",
    "\n",
    "def compute_features(centralities):\n",
    "    \"\"\"Reduce per-node centrality mappings to their graph-level means.\n",
    "\n",
    "    ``centralities`` must contain the keys ``degree_centrality``,\n",
    "    ``betweenness_centrality``, ``closeness_centrality`` and\n",
    "    ``eigenvector_centrality``. Returns a dict with one ``Mean ...`` entry per\n",
    "    measure; an empty mapping yields ``0`` (``0.0`` for the eigenvector entry).\n",
    "    \"\"\"\n",
    "    def _mean_or(mapping, fallback):\n",
    "        # Mean of the mapping's values, or the fallback when it is empty.\n",
    "        values = list(mapping.values())\n",
    "        if not values:\n",
    "            return fallback\n",
    "        return np.mean(values)\n",
    "\n",
    "    return {\n",
    "        'Mean Degree Centrality': _mean_or(centralities['degree_centrality'], 0),\n",
    "        'Mean Betweenness Centrality': _mean_or(centralities['betweenness_centrality'], 0),\n",
    "        'Mean Closeness Centrality': _mean_or(centralities['closeness_centrality'], 0),\n",
    "        'Mean Eigenvector Centrality': _mean_or(centralities['eigenvector_centrality'], 0.0),\n",
    "    }\n",
    "# Display label for each stored graph key\n",
    "graph_mapping = {\n",
    "    'groundtruth_graph': 'Groundtruth',\n",
    "    'gpt_generated_graph': 'GPT',\n",
    "    'random_graph': 'Random'\n",
    "}\n",
    "\n",
    "# One feature row per (sample, available graph kind)\n",
    "features_data = []\n",
    "for sample_id, graph_types in merged_graphs.items():\n",
    "    for key, graph_type in graph_mapping.items():\n",
    "        stored_graph = graph_types.get(key)\n",
    "        if stored_graph is None:\n",
    "            continue\n",
    "\n",
    "        centralities = calculate_centrality_measures(stored_graph.to_undirected())\n",
    "        features = compute_features(centralities)\n",
    "        features.update(graph_type=graph_type, sample_id=sample_id)\n",
    "        features_data.append(features)\n",
    "\n",
    "\n",
    "df_features = pd.DataFrame(features_data)\n",
    "print(df_features.columns.tolist())\n",
    "\n",
    "def compute_common_bins(df, feature_name, num_bins=90):\n",
    "    \"\"\"Return shared histogram bin edges spanning ``df[feature_name]``.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame containing the column ``feature_name``.\n",
    "        feature_name: Column whose minimum and maximum bound the bins.\n",
    "        num_bins: Number of equal-width bins to produce.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``num_bins + 1`` evenly spaced edges from the column minimum\n",
    "        to its maximum (``num_bins`` bins need one more edge than bins).\n",
    "    \"\"\"\n",
    "    min_val = df[feature_name].min()\n",
    "    max_val = df[feature_name].max()\n",
    "    return np.linspace(min_val, max_val, num_bins + 1)\n",
    "\n",
    "def plot_feature_distributions(df, feature_name):\n",
    "    \"\"\"Overlay histograms of ``feature_name`` for Random, Groundtruth and GPT graphs.\n",
    "\n",
    "    All three histograms share the bins from ``compute_common_bins`` and are\n",
    "    drawn on a new 6x3 figure whose y-axis is fixed to 0-1500; the figure is\n",
    "    shown with ``plt.show()``.\n",
    "    \"\"\"\n",
    "    common_bins = compute_common_bins(df, feature_name)\n",
    "    plt.figure(figsize=(6, 3))\n",
    "\n",
    "    # (graph_type value, color, legend label, alpha), in drawing order\n",
    "    layers = (\n",
    "        ('Random', 'grey', 'Random Graphs', 0.3),\n",
    "        ('Groundtruth', 'red', 'Groundtruth Graphs', 0.5),\n",
    "        ('GPT', 'blue', 'GPT Graphs', 0.5),\n",
    "    )\n",
    "    for graph_type, color, label, alpha in layers:\n",
    "        subset = df[df['graph_type'] == graph_type]\n",
    "        sns.histplot(data=subset, x=feature_name, color=color, label=label, bins=common_bins, kde=False, alpha=alpha)\n",
    "\n",
    "    plt.xlabel(feature_name)\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.ylim(0, 1500)\n",
    "    plt.legend()\n",
    "    plt.grid(False)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Graph-level features to draw, one figure each\n",
    "features_to_plot = [\n",
    "    'Mean Degree Centrality',\n",
    "    'Mean Closeness Centrality',\n",
    "    'Mean Eigenvector Centrality',\n",
    "]\n",
    "\n",
    "for feature in features_to_plot:\n",
    "    plot_feature_distributions(df_features, feature_name=feature)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15a8b8e8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Node-level centrality / clustering values for every available graph\n",
    "data = []\n",
    "\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "\n",
    "        G_und = G.to_undirected()\n",
    "        deg = nx.degree_centrality(G_und)\n",
    "        close = nx.closeness_centrality(G_und)\n",
    "        cluster = nx.clustering(G_und)\n",
    "\n",
    "        # Eigenvector centrality is only attempted for graphs with more than two\n",
    "        # nodes; otherwise, or when networkx fails, every node gets NaN.\n",
    "        nan_scores = {n: np.nan for n in G_und.nodes()}\n",
    "        if G_und.number_of_nodes() > 2:\n",
    "            try:\n",
    "                eig = nx.eigenvector_centrality(G_und, max_iter=1000)\n",
    "            except nx.NetworkXException:\n",
    "                # Narrowed from a bare `except:` so KeyboardInterrupt and\n",
    "                # unrelated errors are no longer swallowed.\n",
    "                eig = nan_scores\n",
    "        else:\n",
    "            eig = nan_scores\n",
    "\n",
    "        for node in G_und.nodes():\n",
    "            data.append({\n",
    "                \"sample_id\": sample_id,\n",
    "                \"graph_type\": gtype,\n",
    "                \"node\": node,\n",
    "                \"Degree Centrality\": deg.get(node, np.nan),\n",
    "                \"Closeness Centrality\": close.get(node, np.nan),\n",
    "                \"Eigenvector Centrality\": eig.get(node, np.nan),\n",
    "                \"Clustering Coefficient\": cluster.get(node, np.nan),\n",
    "            })\n",
    "\n",
    "# Create node-level DataFrame\n",
    "df_nodes = pd.DataFrame(data)\n",
    "\n",
    "feature_names = [\n",
    "    \"Degree Centrality\",\n",
    "    \"Closeness Centrality\",\n",
    "    \"Eigenvector Centrality\",\n",
    "    \"Clustering Coefficient\",\n",
    "]\n",
    "\n",
    "# Per-graph means (computed before relabelling, so graph_type keeps the raw keys)\n",
    "df_graphs = df_nodes.groupby(['sample_id', 'graph_type'])[feature_names].mean().reset_index()\n",
    "\n",
    "# Map the internal graph keys to display labels before melting. The keys must\n",
    "# match the stored values exactly; the ground-truth key is all lowercase.\n",
    "df_nodes[\"graph_type\"] = df_nodes[\"graph_type\"].replace({\n",
    "    \"groundtruth_graph\": \"Ground truth graphs\",\n",
    "    \"gpt_generated_graph\": \"Generated graphs\",\n",
    "    \"random_graph\": \"Random graphs\"\n",
    "})\n",
    "\n",
    "# Long format: one row per (sample, graph type, node, feature)\n",
    "df_melted_nodes = pd.melt(\n",
    "    df_nodes,\n",
    "    id_vars=['sample_id', 'graph_type', 'node'],\n",
    "    value_vars=feature_names,\n",
    "    var_name='Feature',\n",
    "    value_name='Value'\n",
    ")\n",
    "\n",
    "\n",
    "def plot_node_features(ax):\n",
    "    \"\"\"Draw grouped boxplots of the node-level features on ``ax``.\n",
    "\n",
    "    Reads the module-level ``df_melted_nodes`` frame and draws one box per\n",
    "    (Feature, graph_type) pair with outlier markers hidden (``fliersize=0``).\n",
    "    The legend's upper-right corner is anchored at axes coordinates\n",
    "    (-0.04, 1.01). Axis labels are left empty.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "    \"\"\"\n",
    "    custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "    label_font = {'fontsize': 2.6 , 'fontfamily': 'serif'}\n",
    "    tick_fontsize = 2\n",
    "    legend_font = {'fontsize': 1.6, 'title_fontsize': 1.8, 'prop': {'family': 'serif', 'size': 1.6}}\n",
    "\n",
    "    # No separate outlier table is built: outliers are not drawn (fliersize=0)\n",
    "    # and nothing else in this function consumed one.\n",
    "    sns.boxplot(\n",
    "        data=df_melted_nodes,\n",
    "        x='Feature',\n",
    "        y='Value',\n",
    "        hue='graph_type',\n",
    "        palette=custom_palette,\n",
    "        fliersize=0,\n",
    "        linewidth=0.2,\n",
    "        ax=ax\n",
    "    )\n",
    "\n",
    "    # Keep only the first entry per graph type in the legend\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    n_types = df_melted_nodes['graph_type'].nunique()\n",
    "    legend = ax.legend(\n",
    "        handles[:n_types],\n",
    "        labels[:n_types],\n",
    "        loc='upper right',\n",
    "        bbox_to_anchor=(-0.04, 1.01),\n",
    "        frameon=True,\n",
    "        edgecolor='black',\n",
    "        **legend_font\n",
    "    )\n",
    "    legend.get_frame().set_linewidth(0.1)\n",
    "\n",
    "    ax.set_ylabel('', fontdict=label_font)\n",
    "    ax.set_xlabel('', fontdict=label_font)\n",
    "\n",
    "    # Serif tick labels, thin spines and small ticks\n",
    "    for label in ax.get_xticklabels():\n",
    "        label.set_family('serif')\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(0.17)\n",
    "    ax.tick_params(width=0.1, length=1.2)\n",
    "    ax.tick_params(axis='x', labelsize=tick_fontsize, pad=1)\n",
    "    ax.tick_params(axis='y', labelsize=tick_fontsize, pad=1)\n",
    "# Draw the node-level feature boxplots on a new 300-dpi figure and show it.\n",
    "fig, ax = plt.subplots(dpi=300)\n",
    "plot_node_features(ax) \n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96e6cd2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node count, edge count and mean degree (2m / n) of every available graph\n",
    "avgdeg_data = []\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is not None:\n",
    "            G_und = G.to_undirected()\n",
    "            n, m = G_und.number_of_nodes(), G_und.number_of_edges()\n",
    "            avgdeg_data.append(dict(\n",
    "                sample_id=sample_id,\n",
    "                graph_type=gtype,\n",
    "                num_nodes=n,\n",
    "                num_edges=m,\n",
    "                # NaN for a graph without nodes (avoids division by zero)\n",
    "                avg_degree=(2 * m / n) if n > 0 else np.nan,\n",
    "            ))\n",
    "\n",
    "# One row per (sample, graph type)\n",
    "df_avgdeg = pd.DataFrame(avgdeg_data)\n",
    "def plot_edges_vs_avg_degree(ax, df_avgdeg):\n",
    "    \"\"\"Scatter mean degree against edge count on ``ax`` with reference curves.\n",
    "\n",
    "    Points come from ``df_avgdeg`` (columns ``num_edges``, ``avg_degree``,\n",
    "    ``graph_type``, ``num_nodes``) and are colored by graph type. Dashed curves\n",
    "    show a tree (n - 1 edges) and a complete graph (n(n - 1)/2 edges) for\n",
    "    n = 2 .. the largest observed node count. Only the two curves are listed\n",
    "    in the legend; the axes are limited to x in [1, 50] and y in [1, 8].\n",
    "    \"\"\"\n",
    "    custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "\n",
    "    marker_style = dict(alpha=0.99, edgecolor='white', marker='o', linewidth=0.1, s=2.2)\n",
    "    sns.scatterplot(\n",
    "        data=df_avgdeg,\n",
    "        x='num_edges',\n",
    "        y='avg_degree',\n",
    "        hue='graph_type',\n",
    "        palette=custom_palette,\n",
    "        ax=ax,\n",
    "        legend=False,\n",
    "        **marker_style\n",
    "    )\n",
    "\n",
    "    # Reference curves, parameterised by the node count n\n",
    "    n_vals = np.arange(2, df_avgdeg['num_nodes'].max() + 1)\n",
    "    tree_line, = ax.plot(\n",
    "        n_vals - 1, 2 * (n_vals - 1) / n_vals,\n",
    "        'g--', label='Tree Graph', linewidth=0.3\n",
    "    )\n",
    "    complete_line, = ax.plot(\n",
    "        n_vals * (n_vals - 1) // 2, n_vals - 1,\n",
    "        'r--', label='Complete Graph', linewidth=0.3\n",
    "    )\n",
    "\n",
    "    # Axis labels and limits\n",
    "    ax.set_xlabel('Number of Edges', fontsize=2, fontfamily='serif', labelpad=0.3)\n",
    "    ax.set_ylabel('Mean Degree', fontsize=2, fontfamily='serif', labelpad=0.2)\n",
    "    ax.set_xlim(1, 50)\n",
    "    ax.set_ylim(1, 8)\n",
    "\n",
    "    # Tick labels, frame and tick marks\n",
    "    ax.tick_params(labelsize=2.4, pad=1)\n",
    "    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        tick_label.set_family('serif')\n",
    "    for frame_side in ax.spines.values():\n",
    "        frame_side.set_color('black')\n",
    "        frame_side.set_linewidth(0.2)\n",
    "    ax.tick_params(width=0.1, length=1.1)\n",
    "\n",
    "    # Legend restricted to the two reference curves\n",
    "    legend = ax.legend(\n",
    "        handles=[tree_line, complete_line],\n",
    "        fontsize=1.9,\n",
    "        title_fontsize=1.9,\n",
    "        loc='upper left',\n",
    "        frameon=True,\n",
    "        edgecolor='black'\n",
    "    )\n",
    "    legend.get_frame().set_linewidth(0.06)\n",
    "# Draw the edges-vs-mean-degree scatter on a new 300-dpi figure and show it.\n",
    "fig, ax = plt.subplots(dpi=300)\n",
    "plot_edges_vs_avg_degree(ax, df_avgdeg)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6d8158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colors assigned to the graph types, in order of first appearance in the data.\n",
    "custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "def plot_edges_vs_avg_degree_joint(df_avgdeg):\n",
    "    \"\"\"Joint plot of edge count against mean degree with KDE marginals.\n",
    "\n",
    "    Draws a scatter of ``num_edges`` (y) against ``avg_degree`` (x) from\n",
    "    ``df_avgdeg``, colored by ``graph_type`` using the module-level\n",
    "    ``custom_palette``, adds dashed tree / complete-graph reference curves and\n",
    "    one filled KDE per graph type on each marginal axis.\n",
    "\n",
    "    Side effects: writes ``avg_degree_vs_edges_subfield.png`` (600 dpi) to the\n",
    "    current working directory and shows the figure. Returns ``None``.\n",
    "    \"\"\"\n",
    "    graph_types = df_avgdeg['graph_type'].unique()\n",
    "    # Pair each graph type with a palette color, in order of appearance.\n",
    "    color_map = dict(zip(graph_types, custom_palette))\n",
    "\n",
    "    g = sns.JointGrid(\n",
    "        data=df_avgdeg,\n",
    "        x=\"avg_degree\",   \n",
    "        y=\"num_edges\",     \n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        space=0.05\n",
    "    )\n",
    "\n",
    "    # Scatter is drawn directly on the joint axes (no legend).\n",
    "    sns.scatterplot(\n",
    "        data=df_avgdeg,\n",
    "        x='avg_degree',\n",
    "        y='num_edges',\n",
    "        hue='graph_type',\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.3,\n",
    "        s=39,\n",
    "        ax=g.ax_joint,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    # Reference curves for n = 2 .. largest observed node count:\n",
    "    # a tree has n - 1 edges, a complete graph has n(n - 1)/2 edges.\n",
    "    n_vals = np.arange(2, df_avgdeg['num_nodes'].max() + 1)\n",
    "    tree_avg_deg = 2 * (n_vals - 1) / n_vals\n",
    "    tree_edges = n_vals - 1\n",
    "    g.ax_joint.plot(tree_avg_deg, tree_edges, 'g--', label='Tree Graph', linewidth=0.7)\n",
    "\n",
    "    complete_avg_deg = n_vals - 1\n",
    "    complete_edges = n_vals * (n_vals - 1) // 2\n",
    "    g.ax_joint.plot(complete_avg_deg, complete_edges, 'r--', label='Complete Graph', linewidth=0.7)\n",
    "\n",
    "    g.ax_joint.set_xlabel('Mean Degree', fontsize=9, fontfamily='serif', labelpad=0.1)\n",
    "    g.ax_joint.set_ylabel('Number of Edges', fontsize=9, fontfamily='serif', labelpad=0.1)\n",
    "    g.ax_joint.set_xlim(1, 8)\n",
    "    g.ax_joint.set_ylim(1, 50)\n",
    "    g.ax_joint.tick_params(labelsize=7, width=0.9, length=1.3)\n",
    "\n",
    "    # Keep only the bottom and left spines visible.\n",
    "    g.ax_joint.spines['bottom'].set_color('black')\n",
    "    g.ax_joint.spines['left'].set_color('black')\n",
    "    g.ax_joint.spines['bottom'].set_linewidth(0.9)\n",
    "    g.ax_joint.spines['left'].set_linewidth(0.9)\n",
    "    g.ax_joint.spines['top'].set_color('none')\n",
    "    g.ax_joint.spines['right'].set_color('none')\n",
    "\n",
    "    for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # KDE Marginals\n",
    "    for gtype, color in zip(graph_types, custom_palette):\n",
    "        sub = df_avgdeg[df_avgdeg['graph_type'] == gtype]\n",
    "        sns.kdeplot(sub['avg_degree'], ax=g.ax_marg_x, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "        sns.kdeplot(y=sub['num_edges'], ax=g.ax_marg_y, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "\n",
    "    # limit marginal axes\n",
    "    g.ax_marg_y.set_xlim(0, 0.06)\n",
    "    g.ax_marg_x.set_ylim(0, 3.18)\n",
    "\n",
    "    # NOTE(review): the PNG is written before plt.tight_layout() is called, so the\n",
    "    # saved file does not include that layout adjustment -- confirm this is intended.\n",
    "    g.figure.savefig(\"avg_degree_vs_edges_subfield.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "plot_edges_vs_avg_degree_joint(df_avgdeg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ce86b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "\n",
    "# Edge and node counts of every available graph, as stored (no conversion)\n",
    "edges_nodes_data = [\n",
    "    {\n",
    "        \"sample_id\": sample_id,\n",
    "        \"graph_type\": gtype,\n",
    "        \"num_edges\": G.number_of_edges(),\n",
    "        \"num_nodes\": G.number_of_nodes()\n",
    "    }\n",
    "    for sample_id, graphs_dict in merged_graphs.items()\n",
    "    for gtype, G in graphs_dict.items()\n",
    "    if G is not None\n",
    "]\n",
    "\n",
    "df_edges_nodes = pd.DataFrame(edges_nodes_data)\n",
    "\n",
    "# Graph types in order of first appearance, each paired with a palette color\n",
    "graph_types = df_edges_nodes['graph_type'].unique()\n",
    "color_map = dict(zip(graph_types, custom_palette))\n",
    "def plot_nodes_vs_edges(ax, df_edges_nodes):\n",
    "    \"\"\"Scatter edge count against node count on ``ax`` with tree / complete bounds.\n",
    "\n",
    "    Points come from ``df_edges_nodes`` (columns ``num_nodes``, ``num_edges``,\n",
    "    ``graph_type``) and are colored by graph type. Dashed curves show the edge\n",
    "    count of a tree (n - 1) and of a complete graph (n(n - 1)/2) for n = 1..50.\n",
    "    The axes are limited to x in [0, 50] and y in [0, 100]. No legend is drawn.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        df_edges_nodes: DataFrame with one row per graph.\n",
    "    \"\"\"\n",
    "    custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "    graph_types = df_edges_nodes['graph_type'].unique()\n",
    "    color_map = dict(zip(graph_types, custom_palette))\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.1,\n",
    "        s=2.2,\n",
    "        ax=ax,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    # Lower (tree) and upper (complete graph) bounds on the edge count\n",
    "    x = np.arange(1, 51)\n",
    "    ax.plot(x, x - 1, 'g--', label='Tree (min edges)', linewidth=0.3)\n",
    "    ax.plot(x, x * (x - 1) // 2, 'r--', label='Complete (max edges)', linewidth=0.3)\n",
    "\n",
    "    # Thin black frame\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "        spine.set_linewidth(0.2)\n",
    "\n",
    "    ax.tick_params(width=0.1, length=1.1)\n",
    "\n",
    "    ax.set_xlim(0, 50)\n",
    "    ax.set_ylim(0, 100)\n",
    "    ax.set_xlabel(\"Number of Nodes\", fontsize=2, fontfamily='serif', labelpad=0.2)\n",
    "    ax.set_ylabel(\"Number of Edges\", fontsize=2, fontfamily='serif', labelpad=0.008)\n",
    "\n",
    "    ax.tick_params(labelsize=2.4, pad=1)\n",
    "    for label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # The legend handle/label bookkeeping that used to follow was never passed\n",
    "    # to ax.legend(), so it has been dropped; the figure is unchanged.\n",
    "\n",
    "# Standalone preview of the nodes-vs-edges scatter on its own figure.\n",
    "# `fig` and `ax` remain notebook-level names after this cell runs.\n",
    "fig, ax = plt.subplots(dpi=300, figsize=(6, 4))\n",
    "plot_nodes_vs_edges(ax=ax, df_edges_nodes=df_edges_nodes)\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e46630",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_nodes_vs_edges_joint(df_edges_nodes, palette=None, save_path=None):\n",
    "    \"\"\"Joint plot of edge count against node count with KDE marginals.\n",
    "\n",
    "    The joint panel is a scatter coloured by ``graph_type`` with two dashed\n",
    "    reference curves for n = 1..50: ``n - 1`` (tree) and\n",
    "    ``n * (n - 1) // 2`` (complete graph). Each marginal shows one filled\n",
    "    KDE per graph type. The figure is written to disk and then shown.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_edges_nodes : pandas.DataFrame\n",
    "        Needs the columns ``graph_type``, ``num_nodes`` and ``num_edges``.\n",
    "    palette : sequence of colours, optional\n",
    "        One colour per graph type, matched in order of first appearance.\n",
    "        Defaults to the three-colour palette used by the other plots in\n",
    "        this notebook, so the function no longer needs a global\n",
    "        ``custom_palette`` to exist.\n",
    "    save_path : path-like, optional\n",
    "        Output image. Defaults to\n",
    "        ``base_path / \"nodes_vs_edges_subfield.png\"``, which is where the\n",
    "        composite-figure cell reads this image from.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    if palette is None:\n",
    "        palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "    if save_path is None:\n",
    "        save_path = base_path / \"nodes_vs_edges_subfield.png\"\n",
    "\n",
    "    graph_types = df_edges_nodes['graph_type'].unique()\n",
    "    color_map = dict(zip(graph_types, palette))\n",
    "\n",
    "    g = sns.JointGrid(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        space=0.05\n",
    "    )\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.5,\n",
    "        s=40,\n",
    "        ax=g.ax_joint,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    # Reference curves: tree (n - 1 edges) and complete graph (n(n-1)/2 edges).\n",
    "    x = np.arange(1, 51)\n",
    "    g.ax_joint.plot(x, x - 1, 'g--', linewidth=0.7)\n",
    "    g.ax_joint.plot(x, x * (x - 1) // 2, 'r--', linewidth=0.7)\n",
    "    g.ax_joint.set_xlim(0, 50)\n",
    "    g.ax_joint.set_ylim(0, 100)\n",
    "    g.ax_joint.set_xlabel(\"Number of Nodes\", fontsize=9, fontfamily='serif', labelpad=0.9)\n",
    "    g.ax_joint.set_ylabel(\"Number of Edges\", fontsize=9, fontfamily='serif', labelpad=0.9)\n",
    "    g.ax_joint.tick_params(labelsize=7, width=0.9, length=1.3)\n",
    "\n",
    "    # Keep only the bottom and left spines visible.\n",
    "    for side in ('bottom', 'left'):\n",
    "        g.ax_joint.spines[side].set_color('black')\n",
    "        g.ax_joint.spines[side].set_linewidth(0.9)\n",
    "    for side in ('top', 'right'):\n",
    "        g.ax_joint.spines[side].set_color('none')\n",
    "\n",
    "    for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # KDE marginals: one filled density per graph type on each margin.\n",
    "    for gtype, color in color_map.items():\n",
    "        sub = df_edges_nodes[df_edges_nodes['graph_type'] == gtype]\n",
    "        sns.kdeplot(sub['num_nodes'], ax=g.ax_marg_x, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "        sns.kdeplot(y=sub['num_edges'], ax=g.ax_marg_y, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "\n",
    "    # The legend lists the two reference curves only; graph types are not listed.\n",
    "    tree_line = Line2D([0], [0], linestyle='--', color='green', linewidth=0.7, label='Tree graph')\n",
    "    complete_line = Line2D([0], [0], linestyle='--', color='red', linewidth=0.7, label='Complete graph')\n",
    "    legend_handles = [tree_line, complete_line]\n",
    "    g.ax_joint.legend(handles=legend_handles, fontsize=7, loc='upper left', frameon=True)\n",
    "\n",
    "    # Apply the layout before saving so the file matches the displayed figure.\n",
    "    g.figure.tight_layout()\n",
    "    g.figure.savefig(save_path, dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "plot_nodes_vs_edges_joint(df_edges_nodes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec93e82",
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "\n",
    "# Joint scatter of degree centrality against clustering coefficient, coloured\n",
    "# by graph type. jointplot is figure-level and builds its own axes, so no\n",
    "# `ax` argument is passed.\n",
    "g = sns.jointplot(\n",
    "    data=df_graphs,\n",
    "    x=\"Degree Centrality\",\n",
    "    y=\"Clustering Coefficient\",\n",
    "    hue=\"graph_type\",\n",
    "    palette=custom_palette,\n",
    "    kind=\"scatter\",\n",
    "    alpha=0.9,\n",
    "    edgecolor='white',\n",
    "    marker='o',\n",
    "    linewidth=0.5,\n",
    "    s=52\n",
    ")\n",
    "\n",
    "# Cap the density axis of the right-hand marginal.\n",
    "g.ax_marg_y.set_xlim(0, 3)\n",
    "\n",
    "# Drop the legend before saving so the PNG matches the displayed figure.\n",
    "joint_legend = g.ax_joint.get_legend()\n",
    "if joint_legend is not None:\n",
    "    joint_legend.remove()\n",
    "\n",
    "plt.tight_layout()\n",
    "# Save under base_path: the composite-figure cell reads this file from there.\n",
    "g.savefig(base_path / \"centrality_vs_clusteringsubfield.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f6f038",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph degree summary: maximum degree, mean degree and their ratio.\n",
    "# Entries of merged_graphs whose graph is None are skipped.\n",
    "engineer_data = []\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "\n",
    "        degrees = [deg for _, deg in G.degree()]\n",
    "        if not degrees:\n",
    "            # Graph without nodes: no statistic is defined.\n",
    "            max_deg = mean_deg = max_mean_ratio = np.nan\n",
    "        else:\n",
    "            max_deg = np.max(degrees)\n",
    "            mean_deg = np.mean(degrees)\n",
    "            # A graph with no edges has mean degree 0; the ratio is undefined.\n",
    "            max_mean_ratio = max_deg / mean_deg if mean_deg > 0 else np.nan\n",
    "\n",
    "        record = {\n",
    "            \"sample_id\": sample_id,\n",
    "            \"graph_type\": gtype,\n",
    "            \"max_degree\": max_deg,\n",
    "            \"Mean Degree\": mean_deg,\n",
    "            \"Max Mean Ratio\": max_mean_ratio\n",
    "        }\n",
    "        engineer_data.append(record)\n",
    "\n",
    "df_deg = pd.DataFrame(engineer_data)\n",
    "\n",
    "# Joint scatter of mean degree against the max/mean degree ratio, coloured by\n",
    "# graph type. jointplot is figure-level and builds its own axes, so no `ax`\n",
    "# argument is passed.\n",
    "g = sns.jointplot(\n",
    "    data=df_deg,\n",
    "    x=\"Mean Degree\",\n",
    "    y=\"Max Mean Ratio\",\n",
    "    hue=\"graph_type\",\n",
    "    palette=custom_palette,\n",
    "    kind=\"scatter\",\n",
    "    alpha=0.9,\n",
    "    edgecolor='white',\n",
    "    marker='o',\n",
    "    linewidth=0.5,\n",
    "    s=52\n",
    ")\n",
    "# Cap the density axis of the top marginal.\n",
    "g.ax_marg_x.set_ylim(0, 0.7)\n",
    "\n",
    "joint_legend = g.ax_joint.get_legend()\n",
    "if joint_legend is not None:\n",
    "    joint_legend.remove()\n",
    "\n",
    "# Apply the layout before saving so the PNG matches the displayed figure.\n",
    "plt.tight_layout()\n",
    "# Save under base_path: the composite-figure cell reads this file from there.\n",
    "g.savefig(base_path / \"mean_vs_ratiosubfield.png\", dpi=600, bbox_inches='tight')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a1a9925",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "\n",
    "def _show_image(ax, filename):\n",
    "    \"\"\"Draw the image file ``base_path / filename`` on ``ax`` and hide the axes frame.\"\"\"\n",
    "    ax.imshow(mpimg.imread(base_path / filename))\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "\n",
    "# Composite figure: previously saved PNG panels plus one live plot.\n",
    "fig = plt.figure(figsize=(5.5, 4), dpi=300)\n",
    "\n",
    "# Panel rectangles as [left, bottom, width, height] in figure fractions.\n",
    "# Panels overlap, so they are created in this fixed order.\n",
    "panel_rects = [\n",
    "    [0.034, 0.30, 0.45, 0.7],     # ax_img\n",
    "    [0.5,  0.5, 0.42, 0.3],       # ax_plot\n",
    "    [0.014, 0.095, 0.24, 0.487],  # ax_centrality\n",
    "    [0.235, 0.18, 0.27, 0.316],   # ax_ratio\n",
    "    [0.48, 0.187, 0.25, 0.298],   # ax_avg\n",
    "    [0.7, 0.187, 0.25, 0.298],    # ax_node\n",
    "]\n",
    "ax_img, ax_plot, ax_centrality, ax_ratio, ax_avg, ax_node = [\n",
    "    fig.add_axes(rect) for rect in panel_rects\n",
    "]\n",
    "\n",
    "# Screenshot panel\n",
    "_show_image(ax_img, \"Screenshot_2025-07-23_at_17.47.44.png\")\n",
    "\n",
    "# Live plot drawn directly on its axes\n",
    "plot_node_features(ax_plot)\n",
    "\n",
    "# Pre-rendered panels\n",
    "_show_image(ax_centrality, \"centrality_vs_clusteringsubfield.png\")  # centrality vs clustering\n",
    "_show_image(ax_ratio, \"mean_vs_ratiosubfield.png\")                  # mean degree vs ratio\n",
    "_show_image(ax_avg, \"avg_degree_vs_edges_subfield.png\")\n",
    "_show_image(ax_node, \"nodes_vs_edges_subfield.png\")\n",
    "\n",
    "fig.savefig(base_path / \"high_res_outputsubfield.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24465165",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
