{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f76502",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pickle\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib.image as mpimg\n",
    "from networkx.exception import PowerIterationFailedConvergence\n",
    "from matplotlib.ticker import MultipleLocator\n",
    "from pathlib import Path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc294c09",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Root directory of the input pickle and of every figure saved below.\n",
    "# NOTE(review): base_path is not defined anywhere else in this notebook; the\n",
    "# current directory is a placeholder - point it at the directory holding the pickle.\n",
    "base_path = Path(\".\")\n",
    "# Directory containing Slide3.png for the composite figure.\n",
    "# NOTE(review): also not defined elsewhere; defaults to base_path - confirm.\n",
    "base_path2 = base_path\n",
    "\n",
    "merged_pickle_path = base_path / \"merged_subfield_graphs.pickle\"\n",
    "\n",
    "# pickle.load can execute arbitrary code: only load a file you produced yourself.\n",
    "with open(merged_pickle_path, \"rb\") as handle:\n",
    "    merged_graphs = pickle.load(handle)\n",
    "\n",
    "# Explore the content\n",
    "print(\"Type of loaded object:\", type(merged_graphs))\n",
    "\n",
    "if isinstance(merged_graphs, dict):\n",
    "    print(\"Number of keys:\", len(merged_graphs))\n",
    "    print(\"Sample keys:\", list(merged_graphs.keys())[:5])\n",
    "    first_item = next(iter(merged_graphs.values()), None)\n",
    "elif isinstance(merged_graphs, list):\n",
    "    print(\"Length of list:\", len(merged_graphs))\n",
    "    first_item = merged_graphs[0] if merged_graphs else None\n",
    "    print(\"Type of first item:\", type(first_item))\n",
    "else:\n",
    "    print(\"Loaded object is of unrecognized type\")\n",
    "    first_item = merged_graphs\n",
    "\n",
    "print(\"\\nDetails of first item:\")\n",
    "if hasattr(first_item, '__dict__'):\n",
    "    for attr, val in vars(first_item).items():\n",
    "        print(f\"{attr}: {type(val)}\")\n",
    "else:\n",
    "    print(first_item)\n",
    "\n",
    "# Every later cell iterates merged_graphs.items() as\n",
    "# {sample_id: {graph_key: graph or None}}; fail fast with a clear message otherwise.\n",
    "if not isinstance(merged_graphs, dict):\n",
    "    raise TypeError(\"Expected merged_graphs to be a dict of {sample_id: {graph_key: graph}}, got \" + type(merged_graphs).__name__)\n",
    "if not merged_graphs:\n",
    "    raise ValueError(\"merged_graphs is empty\")\n",
    "\n",
    "# Inspect one sample in detail; index 3 is an arbitrary pick, clamped for short dicts.\n",
    "sample_keys = list(merged_graphs.keys())\n",
    "sample_id = sample_keys[min(3, len(sample_keys) - 1)]\n",
    "sample_graphs = merged_graphs[sample_id]\n",
    "\n",
    "for key, graph in sample_graphs.items():\n",
    "    print(f\"\\n--- {key} ---\")\n",
    "    if graph is None:\n",
    "        # Later cells skip missing graphs, so report it here rather than crash.\n",
    "        print(\"Graph is missing (None)\")\n",
    "        continue\n",
    "    print(\"Type:\", type(graph))\n",
    "    print(\"Number of nodes:\", graph.number_of_nodes())\n",
    "    print(\"Number of edges:\", graph.number_of_edges())\n",
    "    print(\"Sample nodes:\", list(graph.nodes(data=True))[:3])\n",
    "    print(\"Sample edges:\", list(graph.edges(data=True))[:3])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d99b7dee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Centrality Computation Functions\n",
    "def calculate_centrality_measures(G):\n",
    "    \"\"\"Return per-node centrality scores for graph ``G``.\n",
    "\n",
    "    The result always has the keys 'degree_centrality', 'betweenness_centrality',\n",
    "    'closeness_centrality' and 'eigenvector_centrality', each mapping to a\n",
    "    {node: score} dict. All four dicts are empty when ``G`` is None or has no\n",
    "    nodes. Only the eigenvector dict is empty when networkx raises a\n",
    "    NetworkXException for it (e.g. power iteration not converging in 1000 steps).\n",
    "    \"\"\"\n",
    "    measure_names = (\n",
    "        'degree_centrality',\n",
    "        'betweenness_centrality',\n",
    "        'closeness_centrality',\n",
    "        'eigenvector_centrality',\n",
    "    )\n",
    "    if G is None or G.number_of_nodes() == 0:\n",
    "        return {name: {} for name in measure_names}\n",
    "\n",
    "    scores = {\n",
    "        'degree_centrality': nx.degree_centrality(G),\n",
    "        'betweenness_centrality': nx.betweenness_centrality(G),\n",
    "        'closeness_centrality': nx.closeness_centrality(G),\n",
    "    }\n",
    "    try:\n",
    "        eigenvector_scores = nx.eigenvector_centrality(G, max_iter=1000)\n",
    "    except nx.NetworkXException:\n",
    "        # Best effort: keep the other measures and leave this one empty.\n",
    "        eigenvector_scores = {}\n",
    "    scores['eigenvector_centrality'] = eigenvector_scores\n",
    "    return scores\n",
    "\n",
    "def compute_features(centralities):\n",
    "    \"\"\"Average each centrality measure over the nodes of one graph.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    centralities : dict\n",
    "        Output of ``calculate_centrality_measures``: measure name -> {node: score}.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps 'Mean Degree Centrality', 'Mean Betweenness Centrality',\n",
    "        'Mean Closeness Centrality' and 'Mean Eigenvector Centrality' to the mean\n",
    "        score. A measure with no scores (an empty graph, or eigenvector centrality\n",
    "        that failed to converge) yields NaN rather than 0, so an undefined value\n",
    "        is not counted as a genuine zero in the histograms. This matches the NaN\n",
    "        convention of the node-level cell further down.\n",
    "    \"\"\"\n",
    "    feature_sources = {\n",
    "        'Mean Degree Centrality': 'degree_centrality',\n",
    "        'Mean Betweenness Centrality': 'betweenness_centrality',\n",
    "        'Mean Closeness Centrality': 'closeness_centrality',\n",
    "        'Mean Eigenvector Centrality': 'eigenvector_centrality',\n",
    "    }\n",
    "    features = {}\n",
    "    for feature_name, measure in feature_sources.items():\n",
    "        scores = list(centralities[measure].values())\n",
    "        features[feature_name] = np.mean(scores) if scores else np.nan\n",
    "    return features\n",
    "# Raw graph keys in merged_graphs -> labels used in the histograms.\n",
    "graph_mapping = {\n",
    "    'groundtruth_graph': 'Groundtruth',\n",
    "    'gpt_generated_graph': 'GPT',\n",
    "    'random_graph': 'Random'\n",
    "}\n",
    "\n",
    "# One row of mean-centrality features per (sample, graph type), computed on the\n",
    "# undirected view of each graph; missing (None) graphs are skipped.\n",
    "features_data = []\n",
    "for sample_id, graph_types in merged_graphs.items():\n",
    "    for key, graph_type in graph_mapping.items():\n",
    "        G = graph_types.get(key)\n",
    "        if G is None:\n",
    "            continue\n",
    "        G = G.to_undirected()\n",
    "        features = compute_features(calculate_centrality_measures(G))\n",
    "        features.update(graph_type=graph_type, sample_id=sample_id)\n",
    "        features_data.append(features)\n",
    "\n",
    "df_features = pd.DataFrame(features_data)\n",
    "print(df_features.columns.tolist())\n",
    "\n",
    "def compute_common_bins(df, feature_name, num_bins=90):\n",
    "    \"\"\"Return ``num_bins`` evenly spaced bin edges spanning ``df[feature_name]``.\n",
    "\n",
    "    Sharing one set of edges across graph types keeps overlaid histograms\n",
    "    comparable. ``num_bins`` counts edges, so there are ``num_bins - 1`` bins.\n",
    "    NaNs are ignored by min/max. If every value is identical, the range is\n",
    "    widened by 0.5 on each side, because otherwise all edges coincide and the\n",
    "    bins collapse to zero width.\n",
    "    \"\"\"\n",
    "    min_val = df[feature_name].min()\n",
    "    max_val = df[feature_name].max()\n",
    "    if min_val == max_val:\n",
    "        min_val, max_val = min_val - 0.5, max_val + 0.5\n",
    "    return np.linspace(min_val, max_val, num_bins)\n",
    "\n",
    "def plot_feature_distributions(df, feature_name, ylim=(0, 1500)):\n",
    "    \"\"\"Overlay histograms of one graph-level feature for each graph type.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Must contain ``feature_name`` and a 'graph_type' column holding\n",
    "        'Random', 'Groundtruth' and/or 'GPT'.\n",
    "    feature_name : str\n",
    "        Column to histogram; all three layers share the same bin edges.\n",
    "    ylim : tuple of (bottom, top) or None, optional\n",
    "        y-axis limits. The default keeps the fixed (0, 1500) range; pass None\n",
    "        to autoscale when bin counts would exceed it and be clipped.\n",
    "    \"\"\"\n",
    "    common_bins = compute_common_bins(df, feature_name)\n",
    "    # (graph_type value, colour, legend label, alpha). Random is drawn first and\n",
    "    # most transparent so it stays behind the two graph sets being compared.\n",
    "    layers = [\n",
    "        ('Random', 'grey', 'Random Graphs', 0.3),\n",
    "        ('Groundtruth', 'red', 'Groundtruth Graphs', 0.5),\n",
    "        ('GPT', 'blue', 'GPT Graphs', 0.5),\n",
    "    ]\n",
    "\n",
    "    plt.figure(figsize=(6, 3))\n",
    "    for graph_type, color, label, alpha in layers:\n",
    "        sns.histplot(data=df[df['graph_type'] == graph_type], x=feature_name, color=color,\n",
    "                     label=label, bins=common_bins, kde=False, alpha=alpha)\n",
    "\n",
    "    plt.xlabel(feature_name)\n",
    "    plt.ylabel('Frequency')\n",
    "    if ylim is not None:\n",
    "        plt.ylim(*ylim)\n",
    "    plt.legend()\n",
    "    plt.grid(False)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Graph-level features to histogram (Mean Betweenness Centrality is computed\n",
    "# above but not plotted).\n",
    "features_to_plot = [\n",
    "    'Mean Degree Centrality',\n",
    "    'Mean Closeness Centrality',\n",
    "    'Mean Eigenvector Centrality',\n",
    "]\n",
    "for feature_name in features_to_plot:\n",
    "    plot_feature_distributions(df_features, feature_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67eb573d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node-level metrics: one record per node of every (sample, graph type) graph,\n",
    "# computed on the undirected view; missing (None) graphs are skipped.\n",
    "records = []\n",
    "\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "        G_und = G.to_undirected()\n",
    "        nodes = list(G_und.nodes())\n",
    "\n",
    "        # Centralities / clustering\n",
    "        deg = nx.degree_centrality(G_und)\n",
    "        close = nx.closeness_centrality(G_und)\n",
    "        cluster = nx.clustering(G_und)\n",
    "\n",
    "        # Eigenvector centrality is only attempted for graphs with more than two\n",
    "        # nodes and at least one edge; otherwise, or if power iteration does not\n",
    "        # converge, every node gets NaN.\n",
    "        eig = {n: np.nan for n in nodes}\n",
    "        if G_und.number_of_nodes() > 2 and G_und.number_of_edges() > 0:\n",
    "            try:\n",
    "                eig = nx.eigenvector_centrality(G_und, max_iter=1000, tol=1e-06)\n",
    "            except PowerIterationFailedConvergence:\n",
    "                pass\n",
    "\n",
    "        records.extend(\n",
    "            {\n",
    "                \"sample_id\": sample_id,\n",
    "                \"graph_type\": gtype,\n",
    "                \"node\": n,\n",
    "                \"Degree Centrality\": deg.get(n, np.nan),\n",
    "                \"Closeness Centrality\": close.get(n, np.nan),\n",
    "                \"Eigenvector Centrality\": eig.get(n, np.nan),\n",
    "                \"Clustering Coefficient\": cluster.get(n, np.nan),\n",
    "            }\n",
    "            for n in nodes\n",
    "        )\n",
    "\n",
    "# Node-level DataFrame\n",
    "df_nodes = pd.DataFrame.from_records(records)\n",
    "\n",
    "# Human-readable graph-type labels; keys missing from the map keep their raw name in df_nodes.\n",
    "GRAPH_TYPE_MAP = {\n",
    "    \"groundtruth_graph\": \"Ground truth\",\n",
    "    \"gpt_generated_graph\": \"Generated\",\n",
    "    \"random_graph\": \"Random\",\n",
    "}\n",
    "df_nodes[\"Graph Type\"] = df_nodes[\"graph_type\"].map(GRAPH_TYPE_MAP).fillna(df_nodes[\"graph_type\"])\n",
    "\n",
    "feature_names = [\n",
    "    \"Degree Centrality\",\n",
    "    \"Closeness Centrality\",\n",
    "    \"Eigenvector Centrality\",\n",
    "    \"Clustering Coefficient\",\n",
    "]\n",
    "\n",
    "# Graph-level table: each node metric averaged over the nodes of a graph (NaNs skipped).\n",
    "df_graphs = df_nodes.groupby([\"sample_id\", \"Graph Type\"], as_index=False)[feature_names].mean()\n",
    "\n",
    "# Short x-axis labels, plus fixed category orders so the plots list\n",
    "# features and graph types the same way.\n",
    "feature_label_map = {\n",
    "    \"Degree Centrality\": \"Degree Centra.\",\n",
    "    \"Closeness Centrality\": \"Closeness Centra.\",\n",
    "    \"Eigenvector Centrality\": \"Eigenvector Centra.\",\n",
    "    \"Clustering Coefficient\": \"Cluster Coeff.\",\n",
    "}\n",
    "feature_order = [feature_label_map[name] for name in feature_names]\n",
    "graph_type_order = [\"Ground truth\", \"Generated\", \"Random\"]\n",
    "\n",
    "# Long format (one row per node x metric) for the boxplots.\n",
    "df_melted_nodes = df_nodes.melt(\n",
    "    id_vars=[\"sample_id\", \"graph_type\", \"Graph Type\", \"node\"],\n",
    "    value_vars=feature_names,\n",
    "    var_name=\"Feature\",\n",
    "    value_name=\"Value\",\n",
    ").assign(\n",
    "    Feature_plot=lambda d: pd.Categorical(\n",
    "        d[\"Feature\"].map(feature_label_map), categories=feature_order, ordered=True\n",
    "    ),\n",
    "    **{\n",
    "        \"Graph Type\": lambda d: pd.Categorical(\n",
    "            d[\"Graph Type\"], categories=graph_type_order, ordered=True\n",
    "        )\n",
    "    },\n",
    ")\n",
    "\n",
    "def plot_node_features(ax, data=None):\n",
    "    \"\"\"Draw boxplots of node-level metrics, grouped by feature and split by graph type.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib.axes.Axes\n",
    "        Axes to draw into; fonts and line widths are sized for a small panel.\n",
    "    data : pandas.DataFrame, optional\n",
    "        Long-format frame with 'Feature_plot', 'Value' and 'Graph Type' columns.\n",
    "        Defaults to the notebook's current ``df_melted_nodes``, looked up at call\n",
    "        time. A ``data=df_melted_nodes`` default would be frozen when this cell\n",
    "        runs and would go stale if that frame is rebuilt later.\n",
    "    \"\"\"\n",
    "    if data is None:\n",
    "        data = df_melted_nodes\n",
    "\n",
    "    label_font = {\"fontsize\": 3.5, \"fontfamily\": \"serif\"}\n",
    "    tick_fontsize = 2.8\n",
    "    legend_font = {\n",
    "        \"fontsize\": 2.8,\n",
    "        \"title_fontsize\": 2.8,\n",
    "        \"prop\": {\"family\": \"serif\", \"size\": 2.8},\n",
    "    }\n",
    "\n",
    "    # Fixed colour per graph type (in graph_type_order), so a type keeps its\n",
    "    # colour even when another type is absent from ``data``. Slicing the palette\n",
    "    # by the number of present types would shift colours onto the wrong types.\n",
    "    type_colors = dict(zip(graph_type_order, [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]))\n",
    "    present_types = set(data[\"Graph Type\"].dropna().unique())\n",
    "    hue_order = [g for g in graph_type_order if g in present_types]\n",
    "\n",
    "    # Outlier markers are hidden (fliersize=0).\n",
    "    sns.boxplot(\n",
    "        data=data,\n",
    "        x=\"Feature_plot\",\n",
    "        y=\"Value\",\n",
    "        hue=\"Graph Type\",\n",
    "        order=feature_order,\n",
    "        hue_order=hue_order,\n",
    "        palette=[type_colors[g] for g in hue_order],\n",
    "        fliersize=0,\n",
    "        linewidth=0.2,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # Rebuild the legend from one handle per graph type, with custom styling.\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    n_types = len(hue_order)\n",
    "    legend = ax.legend(\n",
    "        handles[:n_types],\n",
    "        labels[:n_types],\n",
    "        loc=\"upper right\",\n",
    "        bbox_to_anchor=(0.2, 1.01),\n",
    "        frameon=True,\n",
    "        edgecolor=\"black\",\n",
    "        **legend_font,\n",
    "    )\n",
    "    legend.get_frame().set_linewidth(0.15)\n",
    "\n",
    "    ax.set_ylabel(\"\", fontdict=label_font)\n",
    "    ax.set_xlabel(\"\", fontdict=label_font)\n",
    "\n",
    "    for tick in ax.get_xticklabels():\n",
    "        tick.set_family(\"serif\")\n",
    "    for tick in ax.get_yticklabels():\n",
    "        tick.set_family(\"serif\")\n",
    "\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(0.2)\n",
    "\n",
    "    ax.tick_params(width=0.16, length=1.9)\n",
    "    ax.tick_params(axis=\"x\", labelsize=tick_fontsize, pad=1)\n",
    "    ax.tick_params(axis=\"y\", labelsize=tick_fontsize, pad=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d9ec07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node-level feature boxplots, rendered on their own at screen resolution.\n",
    "fig, ax = plt.subplots(dpi=100)\n",
    "plot_node_features(ax)\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e360cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nodes with degree centrality 1.0, i.e. degree == n - 1 (networkx also assigns\n",
    "# 1.0 to the node of a single-node graph).\n",
    "is_full_degree = np.isclose(df_nodes[\"Degree Centrality\"], 1.0)\n",
    "one_dc = df_nodes.loc[is_full_degree]\n",
    "example_cols = [\"sample_id\", \"graph_type\", \"node\", \"Degree Centrality\"]\n",
    "\n",
    "print(\"How many nodes have degree centrality 1.0 by graph_type?\")\n",
    "print(one_dc.groupby(\"graph_type\").size())\n",
    "\n",
    "print(\"\\nExamples:\")\n",
    "print(one_dc.loc[:, example_cols].head(10))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa4fe309",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# For each graph, find the node(s) with maximal degree centrality and check\n",
    "# whether one of them has the same label as the sample_id (compared as strings).\n",
    "max_deg_rows = []\n",
    "\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "        G_und = G.to_undirected()\n",
    "        row = {\"sample_id\": sample_id, \"graph_type\": gtype}\n",
    "\n",
    "        if G_und.number_of_nodes() == 0:\n",
    "            row.update(max_degree_centrality=np.nan, argmax_nodes=[], matches_graph_id=False)\n",
    "        else:\n",
    "            deg = nx.degree_centrality(G_und)\n",
    "            max_val = max(deg.values())\n",
    "            max_nodes = [node for node, value in deg.items() if value == max_val]\n",
    "            row.update(\n",
    "                max_degree_centrality=max_val,\n",
    "                argmax_nodes=max_nodes,\n",
    "                matches_graph_id=str(sample_id) in {str(node) for node in max_nodes},\n",
    "            )\n",
    "        max_deg_rows.append(row)\n",
    "\n",
    "df_max_deg = pd.DataFrame(max_deg_rows).sort_values([\"graph_type\", \"sample_id\"])\n",
    "print(df_max_deg.head(10))\n",
    "\n",
    "print(\"\\nMatch counts by graph_type:\")\n",
    "match_counts = df_max_deg.groupby(\"graph_type\")[\"matches_graph_id\"].value_counts()\n",
    "print(match_counts.unstack(fill_value=0))\n",
    "\n",
    "print(\"\\nExamples where the max-centrality node is NOT the sample_id:\")\n",
    "print(df_max_deg.loc[~df_max_deg[\"matches_graph_id\"]].head(10))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf5da092",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph size statistics on the undirected view: nodes, edges and mean\n",
    "# degree 2m/n (each edge adds to two degrees); NaN for an empty graph.\n",
    "avgdeg_data = []\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "        G_und = G.to_undirected()\n",
    "        n, m = G_und.number_of_nodes(), G_und.number_of_edges()\n",
    "        avgdeg_data.append({\n",
    "            \"sample_id\": sample_id,\n",
    "            \"graph_type\": gtype,\n",
    "            \"num_nodes\": n,\n",
    "            \"num_edges\": m,\n",
    "            \"avg_degree\": 2 * m / n if n else np.nan,\n",
    "        })\n",
    "\n",
    "df_avgdeg = pd.DataFrame(avgdeg_data)\n",
    "def plot_edges_vs_avg_degree(ax, df_avgdeg):\n",
    "    \"\"\"Scatter each graph's mean degree against its edge count on ``ax``.\n",
    "\n",
    "    Dashed reference curves trace two cases for n = 2 .. max(num_nodes) nodes:\n",
    "    a tree (n - 1 edges, mean degree 2(n - 1)/n; green) and a complete graph\n",
    "    (n(n - 1)/2 edges, mean degree n - 1; red). Only those two curves get a\n",
    "    legend entry. Axis limits are fixed to 1-50 edges and mean degree 1-8.\n",
    "    \"\"\"\n",
    "    sns.scatterplot(\n",
    "        data=df_avgdeg,\n",
    "        x='num_edges',\n",
    "        y='avg_degree',\n",
    "        hue='graph_type',\n",
    "        palette=[\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"],\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.1,\n",
    "        s=2.2,\n",
    "        ax=ax,\n",
    "        legend=False,\n",
    "    )\n",
    "\n",
    "    n_vals = np.arange(2, df_avgdeg['num_nodes'].max() + 1)\n",
    "    (tree_line,) = ax.plot(n_vals - 1, 2 * (n_vals - 1) / n_vals, 'g--',\n",
    "                           label='Tree Graph', linewidth=0.3)\n",
    "    (complete_line,) = ax.plot(n_vals * (n_vals - 1) // 2, n_vals - 1, 'r--',\n",
    "                               label='Complete Graph', linewidth=0.3)\n",
    "\n",
    "    axis_font = {'fontsize': 2, 'fontfamily': 'serif'}\n",
    "    ax.set_xlabel('Number of Edges', labelpad=0.3, **axis_font)\n",
    "    ax.set_ylabel('Mean Degree', labelpad=0.2, **axis_font)\n",
    "    ax.set_xlim(1, 50)\n",
    "    ax.set_ylim(1, 8)\n",
    "    ax.tick_params(labelsize=2.4, pad=1)\n",
    "    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:\n",
    "        tick_label.set_family('serif')\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "        spine.set_linewidth(0.2)\n",
    "    ax.tick_params(width=0.1, length=1.1)\n",
    "\n",
    "    legend = ax.legend(\n",
    "        handles=[tree_line, complete_line],\n",
    "        fontsize=4,\n",
    "        title_fontsize=4,\n",
    "        loc='upper left',\n",
    "        frameon=True,\n",
    "        edgecolor='black',\n",
    "    )\n",
    "    legend.get_frame().set_linewidth(0.06)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf5aa6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): dpi=1800 produces a very large in-memory image; the tiny font\n",
    "# and line sizes in plot_edges_vs_avg_degree appear to be tuned for it.\n",
    "fig, ax = plt.subplots(dpi=1800)\n",
    "plot_edges_vs_avg_degree(ax, df_avgdeg)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c200624e",
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "\n",
    "\n",
    "def plot_edges_vs_avg_degree_joint(df_avgdeg):\n",
    "    \"\"\"Joint plot of mean degree vs. edge count with per-type KDE marginals.\n",
    "\n",
    "    Graph types are coloured with ``custom_palette`` in order of first\n",
    "    appearance in ``df_avgdeg``. Dashed curves mark trees (green) and complete\n",
    "    graphs (red) for n = 2 .. max(num_nodes) nodes. The figure is saved as\n",
    "    avg_degree_vs_edges_jointplotsub.png under the notebook-level base_path\n",
    "    (the composite-figure cell reads it from there) and then shown.\n",
    "    \"\"\"\n",
    "    graph_types = df_avgdeg['graph_type'].unique()\n",
    "    color_map = dict(zip(graph_types, custom_palette))\n",
    "\n",
    "    g = sns.JointGrid(\n",
    "        data=df_avgdeg,\n",
    "        x=\"avg_degree\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        space=0.05\n",
    "    )\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df_avgdeg,\n",
    "        x='avg_degree',\n",
    "        y='num_edges',\n",
    "        hue='graph_type',\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.3,\n",
    "        s=39,\n",
    "        ax=g.ax_joint,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    # Reference curves: a tree on n nodes has n - 1 edges (mean degree 2(n - 1)/n);\n",
    "    # a complete graph has n(n - 1)/2 edges (mean degree n - 1).\n",
    "    n_vals = np.arange(2, df_avgdeg['num_nodes'].max() + 1)\n",
    "    tree_avg_deg = 2 * (n_vals - 1) / n_vals\n",
    "    tree_edges = n_vals - 1\n",
    "    g.ax_joint.plot(tree_avg_deg, tree_edges, 'g--', label='Tree Graph', linewidth=0.9)\n",
    "\n",
    "    complete_avg_deg = n_vals - 1\n",
    "    complete_edges = n_vals * (n_vals - 1) // 2\n",
    "    g.ax_joint.plot(complete_avg_deg, complete_edges, 'r--', label='Complete Graph', linewidth=0.9)\n",
    "\n",
    "    g.ax_joint.set_xlabel('Mean Degree', fontsize=15, fontfamily='serif', labelpad=0.1)\n",
    "    g.ax_joint.set_ylabel('Number of Edges', fontsize=15, fontfamily='serif', labelpad=0.1)\n",
    "    g.ax_joint.set_xlim(1, 8)\n",
    "    g.ax_joint.set_ylim(1, 50)\n",
    "    g.ax_joint.tick_params(labelsize=12, width=0.9, length=1.3)\n",
    "\n",
    "    g.ax_joint.spines['bottom'].set_color('black')\n",
    "    g.ax_joint.spines['left'].set_color('black')\n",
    "    g.ax_joint.spines['bottom'].set_linewidth(0.9)\n",
    "    g.ax_joint.spines['left'].set_linewidth(0.9)\n",
    "    g.ax_joint.spines['top'].set_color('none')\n",
    "    g.ax_joint.spines['right'].set_color('none')\n",
    "\n",
    "    for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # One filled KDE per graph type on each marginal, in the scatter's colours.\n",
    "    for gtype, color in color_map.items():\n",
    "        sub = df_avgdeg[df_avgdeg['graph_type'] == gtype]\n",
    "        sns.kdeplot(sub['avg_degree'], ax=g.ax_marg_x, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "        sns.kdeplot(y=sub['num_edges'], ax=g.ax_marg_y, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "\n",
    "    g.ax_marg_y.set_xlim(0, 0.06)\n",
    "    g.ax_marg_x.set_ylim(0, 3.18)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    # Save with the notebook-level base_path, like the other panel cells. A local\n",
    "    # base_path = Path(\".\") here would shadow it and could write this panel to a\n",
    "    # different directory than the one the composite cell reads from.\n",
    "    g.figure.savefig(base_path / \"avg_degree_vs_edges_jointplotsub.png\", dpi=600, bbox_inches=\"tight\")\n",
    "    # Show inside the function so it runs after the figure has been drawn.\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_edges_vs_avg_degree_joint(df_avgdeg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1b75dbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node and edge counts per graph (undirected view) for the nodes-vs-edges plots;\n",
    "# missing (None) graphs are skipped. Both plotting functions below derive their\n",
    "# own graph-type colour mapping from this frame.\n",
    "edges_nodes_data = []\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "        G_und = G.to_undirected()\n",
    "        edges_nodes_data.append({\n",
    "            \"sample_id\": sample_id,\n",
    "            \"graph_type\": gtype,\n",
    "            \"num_edges\": G_und.number_of_edges(),\n",
    "            \"num_nodes\": G_und.number_of_nodes()\n",
    "        })\n",
    "\n",
    "df_edges_nodes = pd.DataFrame(edges_nodes_data)\n",
    "def plot_nodes_vs_edges(ax, df_edges_nodes):\n",
    "    \"\"\"Scatter each graph's edge count against its node count on ``ax``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib.axes.Axes\n",
    "        Axes to draw into; fonts and line widths are sized for a small panel.\n",
    "    df_edges_nodes : pandas.DataFrame\n",
    "        Needs 'num_nodes', 'num_edges' and 'graph_type' columns. Graph types are\n",
    "        coloured in order of first appearance.\n",
    "\n",
    "    Dashed lines show, for 1-50 nodes, the edge count of a tree (x - 1, green)\n",
    "    and of a complete graph (x(x - 1)/2, red). A legend labels these two lines.\n",
    "    \"\"\"\n",
    "    custom_palette = [\"#76C1FA\", \"#F78FB3\", \"#A8E6CF\"]\n",
    "    graph_types = df_edges_nodes['graph_type'].unique()\n",
    "    color_map = dict(zip(graph_types, custom_palette))\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.1,\n",
    "        s=2.2,\n",
    "        ax=ax,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    x = np.arange(1, 51)\n",
    "    (tree_line,) = ax.plot(x, x - 1, 'g--', label='Tree (min edges)', linewidth=0.3)\n",
    "    (complete_line,) = ax.plot(x, x * (x - 1) // 2, 'r--', label='Complete (max edges)', linewidth=0.3)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_color('black')\n",
    "        spine.set_linewidth(0.2)\n",
    "\n",
    "    ax.tick_params(width=0.1, length=1.1)\n",
    "\n",
    "    ax.set_xlim(0, 50)\n",
    "    ax.set_ylim(0, 100)\n",
    "    ax.set_xlabel(\"Number of Nodes\", fontsize=2, fontfamily='serif', labelpad=0.2)\n",
    "    ax.set_ylabel(\"Number of Edges\", fontsize=2, fontfamily='serif', labelpad=0.008)\n",
    "\n",
    "    ax.tick_params(labelsize=6, pad=1)\n",
    "    for label in ax.get_xticklabels() + ax.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # Legend for the two reference lines only (the scatter uses legend=False),\n",
    "    # styled like plot_edges_vs_avg_degree.\n",
    "    legend = ax.legend(\n",
    "        handles=[tree_line, complete_line],\n",
    "        fontsize=4,\n",
    "        loc='upper left',\n",
    "        frameon=True,\n",
    "        edgecolor='black'\n",
    "    )\n",
    "    legend.get_frame().set_linewidth(0.06)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99abf0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): dpi=1800 gives a very large in-memory image; the tiny label\n",
    "# fonts in plot_nodes_vs_edges appear to assume it.\n",
    "fig, ax = plt.subplots(figsize=(6, 4), dpi=1800)\n",
    "plot_nodes_vs_edges(ax, df_edges_nodes)\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b8a6e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_nodes_vs_edges_joint(df_edges_nodes):\n",
    "    \"\"\"Joint plot of edge count vs. node count with per-type KDE marginals.\n",
    "\n",
    "    Graph types are coloured with ``custom_palette`` in order of first\n",
    "    appearance. Dashed lines mark the edge counts of a tree (x - 1, green) and\n",
    "    a complete graph (x(x - 1)/2, red) for 1-50 nodes; the legend lists only\n",
    "    these two lines. The figure is saved as nodes_vs_edges_jointplotsub.png\n",
    "    under base_path (the composite cell reads it back) and then shown.\n",
    "    \"\"\"\n",
    "    graph_types = df_edges_nodes['graph_type'].unique()\n",
    "    color_map = dict(zip(graph_types, custom_palette))\n",
    "\n",
    "    g = sns.JointGrid(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        space=0.05\n",
    "    )\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=df_edges_nodes,\n",
    "        x=\"num_nodes\",\n",
    "        y=\"num_edges\",\n",
    "        hue=\"graph_type\",\n",
    "        palette=color_map,\n",
    "        alpha=0.99,\n",
    "        edgecolor='white',\n",
    "        marker='o',\n",
    "        linewidth=0.5,\n",
    "        s=40,\n",
    "        ax=g.ax_joint,\n",
    "        legend=False\n",
    "    )\n",
    "\n",
    "    x = np.arange(1, 51)\n",
    "    g.ax_joint.plot(x, x - 1, 'g--', linewidth=0.9)\n",
    "    g.ax_joint.plot(x, x * (x - 1) // 2, 'r--', linewidth=0.9)\n",
    "    g.ax_joint.set_xlim(0, 50)\n",
    "    g.ax_joint.set_ylim(0, 100)\n",
    "    g.ax_joint.set_xlabel(\"Number of Nodes\", fontsize=15, fontfamily='serif', labelpad=0.9)\n",
    "    g.ax_joint.set_ylabel(\"Number of Edges\", fontsize=15, fontfamily='serif', labelpad=0.9)\n",
    "    g.ax_joint.tick_params(labelsize=12, width=0.9, length=1.3)\n",
    "\n",
    "    # Open L-shaped frame: black bottom/left spines, hidden top/right.\n",
    "    for side in ('bottom', 'left'):\n",
    "        g.ax_joint.spines[side].set_color('black')\n",
    "        g.ax_joint.spines[side].set_linewidth(0.9)\n",
    "    for side in ('top', 'right'):\n",
    "        g.ax_joint.spines[side].set_color('none')\n",
    "\n",
    "    for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():\n",
    "        label.set_family('serif')\n",
    "\n",
    "    # KDE marginals, one per graph type, in the same colours as the scatter.\n",
    "    for gtype, color in zip(graph_types, custom_palette):\n",
    "        sub = df_edges_nodes[df_edges_nodes['graph_type'] == gtype]\n",
    "        sns.kdeplot(sub['num_nodes'], ax=g.ax_marg_x, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "        sns.kdeplot(y=sub['num_edges'], ax=g.ax_marg_y, color=color, fill=True, alpha=0.3, linewidth=1)\n",
    "\n",
    "    # Proxy artists, because the reference lines above were plotted without labels.\n",
    "    tree_line = Line2D([0], [0], linestyle='--', color='green', linewidth=0.9, label='Tree graph')\n",
    "    complete_line = Line2D([0], [0], linestyle='--', color='red', linewidth=0.9, label='Complete graph')\n",
    "    g.ax_joint.legend(handles=[tree_line, complete_line], fontsize=10, loc='upper left', frameon=True)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    g.figure.savefig(base_path / \"nodes_vs_edges_jointplotsub.png\", dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_nodes_vs_edges_joint(df_edges_nodes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e1b838",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph-level means of node metrics: degree centrality vs. clustering coefficient.\n",
    "plot_df = df_graphs.dropna(subset=[\"Degree Centrality\", \"Clustering Coefficient\"])\n",
    "\n",
    "# Pin hue order and colours to graph_type_order so each graph type gets the same\n",
    "# colour as in the node-feature boxplot. Without it, seaborn orders hue levels by\n",
    "# first appearance, and df_graphs (grouped, hence sorted by label within each\n",
    "# sample) lists \"Generated\" before \"Ground truth\", which swaps their colours.\n",
    "type_colors = dict(zip(graph_type_order, custom_palette))\n",
    "present_types = set(plot_df[\"Graph Type\"])\n",
    "panel_hue_order = [t for t in graph_type_order if t in present_types]\n",
    "\n",
    "g = sns.jointplot(\n",
    "    data=plot_df,\n",
    "    x=\"Degree Centrality\",\n",
    "    y=\"Clustering Coefficient\",\n",
    "    hue=\"Graph Type\",\n",
    "    hue_order=panel_hue_order,\n",
    "    palette=[type_colors[t] for t in panel_hue_order],\n",
    "    kind=\"scatter\",\n",
    "    alpha=0.9,\n",
    "    edgecolor=\"white\",\n",
    "    marker=\"o\",\n",
    "    linewidth=0.5,\n",
    "    s=52,\n",
    ")\n",
    "\n",
    "# Labels and ticks\n",
    "g.ax_joint.set_xlabel(\"Mean Degree Centrality\", fontsize=15, fontfamily=\"serif\", labelpad=0.9)\n",
    "g.ax_joint.set_ylabel(\"Mean Clustering Coefficient\", fontsize=15, fontfamily=\"serif\", labelpad=0.9)\n",
    "g.ax_joint.tick_params(labelsize=13, width=0.9, length=1.3)\n",
    "\n",
    "g.ax_marg_y.set_xlim(0, 3)\n",
    "\n",
    "# Drop the legend title (the \"Graph Type\" column name).\n",
    "if getattr(g.ax_joint, \"legend_\", None):\n",
    "    g.ax_joint.legend_.set_title(\"\")\n",
    "\n",
    "plt.tight_layout()\n",
    "g.savefig(base_path / \"centrality_vs_clustering1sub.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5ae9b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph degree statistics: mean degree and hub dominance (max / mean degree).\n",
    "engineer_data = []\n",
    "for sample_id, graphs_dict in merged_graphs.items():\n",
    "    for gtype, G in graphs_dict.items():\n",
    "        if G is None:\n",
    "            continue\n",
    "        # Undirected view, as in every other cell. If the stored graphs are\n",
    "        # directed, degree() counts in- plus out-edges, so mutual edges would count\n",
    "        # twice and the mean degree would disagree with avg_degree (2m/n) above.\n",
    "        degrees = [d for _, d in G.to_undirected().degree()]\n",
    "        if degrees:\n",
    "            max_deg = np.max(degrees)\n",
    "            mean_deg = np.mean(degrees)\n",
    "            max_mean_ratio = max_deg / mean_deg if mean_deg > 0 else np.nan\n",
    "        else:\n",
    "            max_deg = mean_deg = max_mean_ratio = np.nan\n",
    "        engineer_data.append({\n",
    "            \"sample_id\": sample_id,\n",
    "            \"graph_type\": gtype,\n",
    "            \"max_degree\": max_deg,\n",
    "            \"Mean Degree\": mean_deg,\n",
    "            \"Max Mean Ratio\": max_mean_ratio\n",
    "        })\n",
    "\n",
    "df_deg = pd.DataFrame(engineer_data)\n",
    "\n",
    "# jointplot is figure-level and creates its own figure. It ignores (with a\n",
    "# warning) any ax argument, so none is passed here.\n",
    "g = sns.jointplot(\n",
    "    data=df_deg,\n",
    "    x=\"Mean Degree\",\n",
    "    y=\"Max Mean Ratio\",\n",
    "    hue=\"graph_type\",\n",
    "    palette=custom_palette,\n",
    "    kind=\"scatter\",\n",
    "    alpha=0.9,\n",
    "    edgecolor='white',\n",
    "    marker='o',\n",
    "    linewidth=0.5,\n",
    "    s=52,\n",
    ")\n",
    "\n",
    "g.ax_marg_x.set_ylim(0, 0.7)\n",
    "g.ax_joint.xaxis.set_major_locator(MultipleLocator(2))\n",
    "\n",
    "g.ax_joint.set_xlabel(\"Mean Degree\", fontsize=15, fontfamily='serif', labelpad=0.9)\n",
    "g.ax_joint.set_ylabel(\"Max Mean Ratio\", fontsize=15, fontfamily='serif', labelpad=0.9)\n",
    "g.ax_joint.tick_params(labelsize=13, width=0.9, length=1.3)\n",
    "# This panel shows no legend; guard in case seaborn did not create one.\n",
    "if g.ax_joint.get_legend() is not None:\n",
    "    g.ax_joint.get_legend().remove()\n",
    "\n",
    "# Lay out before saving, as the other joint-plot cells do, so the PNG reused in\n",
    "# the composite figure matches what is displayed.\n",
    "plt.tight_layout()\n",
    "g.savefig(base_path / \"mean_vs_ratio2sub.png\", dpi=600, bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477ee543",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Composite figure. Axes rectangles are [left, bottom, width, height] in figure\n",
    "# fractions. Several panels lie inside ax_img's rectangle (e.g. ax_centrality);\n",
    "# axes added later are drawn on top of earlier ones.\n",
    "fig = plt.figure(figsize=(5.5, 6), dpi=1800)\n",
    "\n",
    "ax_img, ax_plot, ax_centrality, ax_ratio, ax_avg, ax_node = (\n",
    "    fig.add_axes(rect)\n",
    "    for rect in (\n",
    "        [0.015, 0.11, 0.44, 0.93],     # ax_img: slide image\n",
    "        [0.5, 0.503, 0.42, 0.18],      # ax_plot\n",
    "        [0.014, 0.146, 0.24, 0.47],    # ax_centrality\n",
    "        [0.235, 0.273, 0.249, 0.217],  # ax_ratio\n",
    "        [0.48, 0.232, 0.23, 0.296],    # ax_avg\n",
    "        [0.704, 0.28, 0.255, 0.204],   # ax_node\n",
    "    )\n",
    ")\n",
    "\n",
    "def add_label(ax, label, x=0.02, y=0.98):\n",
    "    \"\"\"Stamp a bold '(label)' tag at axes-fraction position (x, y) on ``ax``.\n",
    "\n",
    "    The text is anchored at its top-left corner, sits on a mostly opaque white\n",
    "    box, and uses zorder=10 to stay above the panel's other artists.\n",
    "    \"\"\"\n",
    "    tag_box = dict(facecolor='white', edgecolor='none', alpha=0.85, pad=1.0)\n",
    "    ax.text(\n",
    "        x,\n",
    "        y,\n",
    "        '({})'.format(label),\n",
    "        transform=ax.transAxes,\n",
    "        ha='left',\n",
    "        va='top',\n",
    "        fontsize=9,\n",
    "        fontweight='bold',\n",
    "        zorder=10,\n",
    "        bbox=tag_box,\n",
    "    )\n",
    "\n",
    "for panel_ax, tag in zip((ax_plot, ax_centrality, ax_ratio, ax_avg, ax_node), \"abcde\"):\n",
    "    add_label(panel_ax, tag)\n",
    "\n",
    "# Left: slide image.\n",
    "img = mpimg.imread(base_path2 / \"Slide3.png\")\n",
    "ax_img.imshow(img)\n",
    "ax_img.axis(\"off\")\n",
    "\n",
    "# (a): node-feature boxplot drawn directly into its axes.\n",
    "plot_node_features(ax_plot)\n",
    "\n",
    "# (b)-(e): PNGs written by the joint-plot cells above, embedded as images.\n",
    "panel_images = [\n",
    "    (ax_centrality, \"centrality_vs_clustering1sub.png\"),\n",
    "    (ax_ratio, \"mean_vs_ratio2sub.png\"),\n",
    "    (ax_avg, \"avg_degree_vs_edges_jointplotsub.png\"),\n",
    "    (ax_node, \"nodes_vs_edges_jointplotsub.png\"),\n",
    "]\n",
    "for panel_ax, filename in panel_images:\n",
    "    panel_ax.imshow(mpimg.imread(base_path / filename))\n",
    "    panel_ax.axis(\"off\")\n",
    "\n",
    "fig.savefig(base_path / \"high_res_output3sub.png\", dpi=600, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4256e3d5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
