{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2cea7be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import networkx as nx\n",
    "from networkx import from_dict_of_lists, subgraph\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np \n",
    "import seaborn as sns \n",
    "import matplotlib.gridspec as gridspec\n",
    "import pickle \n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import random\n",
    "from scipy.sparse import coo_matrix, csr_matrix, csc_matrix\n",
    "import gc\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "001105b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Print versions of pandas, networkx, matplotlib, numpy, and seaborn \n",
    "print(pd.__version__)\n",
    "print(nx.__version__)\n",
    "print(plt.matplotlib.__version__)\n",
    "print(np.__version__)\n",
    "print(sns.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b360ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Settings for the notebook\n",
    "engine='_gpt-4'\n",
    "HUE_ORDER=['in paper', 'non-isolated', 'isolated']\n",
    "HUE_COLOR=[sns.color_palette('colorblind')[2], \n",
    "           sns.color_palette('colorblind')[8], \n",
    "           sns.color_palette('colorblind')[3]]\n",
    "#Font\n",
    "plt.rcParams[\"font.family\"] = \"Arial\"\n",
    "plt.rcParams[\"text.usetex\"] = True\n",
    "\n",
    "#Dictionary for adapting colors \n",
    "color_dic={'b': sns.color_palette('colorblind')[0],\n",
    "                  'g': sns.color_palette('colorblind')[2],\n",
    "                  'y': sns.color_palette('colorblind')[8],\n",
    "                  'r': sns.color_palette('colorblind')[3],\n",
    "                  'k': sns.color_palette('colorblind')[7]\n",
    "                  }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47933e10",
   "metadata": {},
   "source": [
    "# Load and preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b80167c",
   "metadata": {},
   "source": [
    "Load the adjacency matrix of SciSciNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e90ce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Define base path as current folder\n",
    "base_path = Path(\"data/SciSciNet\")\n",
    "\n",
    "# Load the SciSciNet adjacency matrix\n",
    "papers = np.load(base_path / \"SciSciNet_node_list.npy\")\n",
    "rows = np.load(base_path / \"SciSciNet_adjacency_matrix_rows.npy\")\n",
    "cols = np.load(base_path / \"SciSciNet_adjacency_matrix_cols.npy\")\n",
    "\n",
    "m=len(rows)\n",
    "n=len(papers)\n",
    "A=coo_matrix((np.ones(m), (rows, cols)), shape=(n, n), dtype=np.int32)\n",
    "\n",
    "del rows, cols\n",
    "gc.collect()\n",
    "\n",
    "#Convert A to a csr_matrix\n",
    "A=A.tocsr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "275107f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create a dictionary to converse between papers and entry in the adjacency matrix\n",
    "paper_dict={papers[i]: i for i in range(n)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f1ba991",
   "metadata": {},
   "source": [
    "Load the ground truth references"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7056a08e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load the ground_truth_references\n",
    "ground_truth=pd.read_csv('ground_truth_references.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4e6bf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data\n",
    "df_sample = pd.read_csv(\"SciSciNet_Sample_Journals_subset.csv\")\n",
    "df_ground_truth = pd.read_csv(\"ground_truth_references.csv\")\n",
    "df_generated = pd.read_csv(\"generated_references.csv\")\n",
    "df_fields = pd.read_csv(\"SciSciNet_Fields.tsv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af42843",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create column with top_fields for each paper in sample\n",
    "sub_fields = df_fields.loc[\n",
    "    df_fields[\"Field_Type\"] == \"Sub\", [\"FieldID\", \"Field_Name\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "232bc151",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Rename Field_Name to sub_field_name\n",
    "sub_fields = sub_fields.rename(columns={\"Field_Name\": \"sub_field_name\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b79787",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sub_field(row):\n",
    "    if pd.isnull(row[\"FieldID\"]) or pd.isnull(row[\"C_f\"]):\n",
    "        return None\n",
    "\n",
    "    field_ids = row[\"FieldID\"].split(\",\")\n",
    "    cf_strings = row[\"C_f\"].split(\",\")\n",
    "    field_types = row['Field_Type'].split(\",\")\n",
    "\n",
    "    #Remove empty spaces from field_ids, cf_strings, and field_types\n",
    "    field_ids = [field_id.strip() for field_id in field_ids if field_id.strip()]\n",
    "    cf_strings = [cf.strip() for cf in cf_strings if cf.strip()]\n",
    "    field_types = [field_type.strip() for field_type in field_types if field_type.strip()]\n",
    "\n",
    "    #Convert to float\n",
    "    try:\n",
    "        cf_values = [float(cf.strip()) for cf in cf_strings]\n",
    "    except ValueError:\n",
    "        return None  \n",
    "        \n",
    "    if len(field_ids) != len(cf_values):\n",
    "        return None  \n",
    "\n",
    "    #Subset field_ids and cf_values based on field_types=='Sub'\n",
    "    sub_field_ids = [field_id for field_id, field_type in zip(field_ids, field_types) if field_type == 'Sub']\n",
    "    sub_cf_values = [cf_value for cf_value, field_type in zip(cf_values, field_types) if field_type == 'Sub']\n",
    "\n",
    "    if not sub_field_ids or not sub_cf_values:\n",
    "        return None\n",
    "\n",
    "    max_index = sub_cf_values.index(max(sub_cf_values))\n",
    "    sub_field_id = sub_field_ids[max_index]\n",
    "\n",
    "    return sub_field_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "198e953c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sample[\"sub_field\"] = df_sample.apply(get_sub_field, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aa0cda9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Merge df_sample with sub_fields to get the sub_field names\n",
    "#Convert sub_field column to int64\n",
    "df_sample['sub_field'] = df_sample['sub_field'].astype('Int64')\n",
    "sub_fields['FieldID'] = sub_fields['FieldID'].astype('Int64')\n",
    "df_sample = df_sample.merge(sub_fields, left_on='sub_field', right_on='FieldID', how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6dd96d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sample.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2139b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sample.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667ddb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Find out the subfields which have more than 30 papers in the sample\n",
    "subfield_counts = df_sample['sub_field_name'].value_counts()\n",
    "subfields_to_keep = subfield_counts[subfield_counts > 10]\n",
    "print(subfields_to_keep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc68389",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Count the fraction of papers in the sample that are in subfields_to_keep\n",
    "fraction_in_subfields = df_sample['sub_field_name'].isin(subfields_to_keep.index).mean()\n",
    "print(f\"Fraction of papers in subfields with more than 10 papers: {fraction_in_subfields:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c516f9b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Subset df_sample to only include papers in subfields_to_keep\n",
    "df_sample = df_sample[df_sample['sub_field_name'].isin(subfields_to_keep.index)]\n",
    "df_sample.reset_index(drop=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c441a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sample.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f229661",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ground_truth.shape, df_generated.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cab0766",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Subset df_ground_truth and df_generated to only include PaperID which are in df_sample\n",
    "df_ground_truth = df_ground_truth[df_ground_truth['Citing_PaperID'].isin(df_sample['PaperID'])]\n",
    "df_generated = df_generated[df_generated['id'].isin(df_sample['PaperID'])]\n",
    "#Reset index for df_ground_truth and df_generated\n",
    "df_ground_truth.reset_index(drop=True, inplace=True)\n",
    "df_generated.reset_index(drop=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd25111e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ground_truth.shape, df_generated.shape #reduction in references is neglectable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c3df67d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# merge ground truth and generated references with sample data\n",
    "df_ground_truth = df_ground_truth.merge(\n",
    "    df_sample[\n",
    "        [\n",
    "            \"PaperID\",\n",
    "            \"sub_field_name\",\n",
    "            \"Year\",\n",
    "            \"Citation_Count\",\n",
    "            \"Reference_Count\",\n",
    "            \"Team_Size\",\n",
    "        ]\n",
    "    ],\n",
    "    how=\"left\",\n",
    "    left_on=\"Citing_PaperID\",\n",
    "    right_on=\"PaperID\",\n",
    ")\n",
    "\n",
    "df_generated = df_generated.merge(\n",
    "    df_sample[\n",
    "        [\n",
    "            \"PaperID\",\n",
    "            \"sub_field_name\",\n",
    "            \"Year\",\n",
    "            \"Citation_Count\",\n",
    "            \"Reference_Count\",\n",
    "            \"Team_Size\",\n",
    "        ]\n",
    "    ],\n",
    "    how=\"left\",\n",
    "    left_on=\"id\",\n",
    "    right_on=\"PaperID\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a420c907",
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth=df_ground_truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d209b686",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Randomly reshuffle Cited_PaperID while fixing top_field in ground_truth\n",
    "np.random.seed(0)\n",
    "ground_truth['Cited_PaperID_random']=ground_truth.groupby('sub_field_name')['Cited_PaperID'].transform(lambda x: x.sample(frac=1).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9edf66",
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth[['Citing_PaperID', 'Cited_PaperID_random']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1e76b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check how many times Citing_PaperID and Cited_PaperID_random are the same\n",
    "mask=ground_truth['Citing_PaperID']==ground_truth['Cited_PaperID_random']\n",
    "mask.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6c6425",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check per Citing_PaperID how many times Cited_PaperID_random is the same\n",
    "duplicate_df=ground_truth.groupby('Citing_PaperID')['Cited_PaperID_random'].unique().apply(lambda x: len(x))\n",
    "duplicate_df_2=ground_truth.groupby('Citing_PaperID')['Cited_PaperID_random'].apply(lambda x: len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56211770",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Compare duplicate_df and duplicate_df_2 --> still not much of an issue\n",
    "(duplicate_df!=duplicate_df_2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24a804bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create a dictionary with the references for each focal paper\n",
    "full_references_dict={}\n",
    "for id in ground_truth.Citing_PaperID.unique():\n",
    "    mask=ground_truth.Citing_PaperID==id\n",
    "    full_references_dict[id]=ground_truth[mask].Cited_PaperID_random.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dcb160f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Do consistency checks: is the ith row of A equal to the full_references_dict[i]\n",
    "full_references_dict[9673109]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df39b3a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Extract the ith row of paper_dict[9673109]\n",
    "refs=A.getrow(paper_dict[9673109]).nonzero()[1]\n",
    "#Transform the entries back to the paper_ids\n",
    "papers[refs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03141ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create a dictionary of dictionaries where for each key in the full_references_dict, the references of the focal paper form\n",
    "#a dictionary with the corresponding row of A as value, i.e., the references of the references\n",
    "paper_references_dict={}\n",
    "for key in full_references_dict.keys():\n",
    "    #Create the dictionary to store the references of the references\n",
    "    refs_of_ref={}\n",
    "    for ref in full_references_dict[key]:\n",
    "        refs=A.getrow(paper_dict[ref]).nonzero()[1]\n",
    "        refs_of_ref[ref]=papers[refs]\n",
    "    paper_references_dict[key]=refs_of_ref"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c5b8acc",
   "metadata": {},
   "source": [
    "Load the generated references"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6480149a",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated=pd.read_csv('generated_references.csv')\n",
    "#Only keep the rows where id is in df_sample.PaperID\n",
    "generated=generated[generated['id'].isin(df_sample['PaperID'])]\n",
    "#Reset index for generated\n",
    "generated.reset_index(drop=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6edf8ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create a dictionary with the generated references for each focal paper\n",
    "gen_full_references_dict={}\n",
    "for id in generated.id.unique():\n",
    "    mask1=generated.id==id\n",
    "    mask2=generated.Exists==1.0\n",
    "    mask=mask1 & mask2\n",
    "    gen_full_references_dict[id]=generated[mask].PaperID.values.astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f4e73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create a dictionary of dictionaries where for each key in the gen_full_references_dict, the generated references of the focal paper form\n",
    "#a dictionary with the corresponding row of A as value, i.e., the references of the generated references\n",
    "gen_paper_references_dict={}\n",
    "not_in_adjacency_matrix=[]\n",
    "for key in gen_full_references_dict.keys():\n",
    "    #Create the dictionary to store the references of the references\n",
    "    refs_of_ref={}\n",
    "    for ref in gen_full_references_dict[key]:\n",
    "        try:\n",
    "            refs=A.getrow(paper_dict[ref]).nonzero()[1]\n",
    "            refs_of_ref[ref]=papers[refs]\n",
    "        except:\n",
    "            not_in_adjacency_matrix.append(ref)\n",
    "            continue\n",
    "    gen_paper_references_dict[key]=refs_of_ref"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad195383",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Amount of generated references that do not appearin the adjacency matrix, it means they don't have recorded references or citations\n",
    "len(not_in_adjacency_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40f90e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Delete the adjacency matrix to free up memory\n",
    "del A, papers\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb64ef5",
   "metadata": {},
   "source": [
    "## Create the graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a750dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "results=generated.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fe7f8d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Initialize the dataframe for the graph analysis\n",
    "\n",
    "#Add the graph results to the info file, but only keep columns on title and id\n",
    "results_graph=pd.DataFrame()\n",
    "results_graph['id']=np.array(list(gen_paper_references_dict.keys()))\n",
    "\n",
    "#Add columns for the graph invariants we calculate\n",
    "\n",
    "#_true: graph invariants of the graph that contains the ground truth references with the focal paper removed\n",
    "#_pred: graph invariants of the graph that contains generated references with the focal paper removed \n",
    "#_true_pred: graph invariants of the graph that contains both the ground truth and generated references\n",
    "#_true_pred_2: graph invariants of the graph that contains both the ground truth and generated references with isolated generations/predictions removed\n",
    "\n",
    "#SUMMARY STATISTICS \n",
    "\n",
    "#Self-citation flag that indicates if GPT predicted the paper to cite itself \n",
    "results_graph.loc[:,'self_citation']=np.nan\n",
    "\n",
    "#Total number of correct generations/ predictions: we exclude self-citations in the count and count duplicate predictions once\n",
    "results_graph.loc[:,'total_predictions']=np.nan\n",
    "\n",
    "#Total number of ground truth references of the focal paper\n",
    "#We only count references that have a PaperID in the adjacency matrix\n",
    "results_graph.loc[:,'total_references']=np.nan \n",
    "\n",
    "#Total number of predictions/ generated references that appear in the focal paper (green nodes)\n",
    "results_graph.loc[:,'in_paper_predictions']=np.nan\n",
    "\n",
    "#Total number of predictions/ generated references that don't appear in the focal paper and are not isolated (yellow nodes)\n",
    "#Note: yellow nodes can also be only connected within each other (i.e. isolated from the ground truth references)\n",
    "results_graph.loc[:,'wrong_predictions']=np.nan \n",
    "\n",
    "#Total number of isolated predictions/ generated references (red nodes)\n",
    "results_graph.loc[:,'isolated_predictions']=np.nan\n",
    "\n",
    "#Total number of ground truth references (green and grey nodes) that are only connected to original paper \n",
    "#We only count references that have a PaperID in the adjacency matrix\n",
    "results_graph.loc[:,'isolated_references']=np.nan\n",
    "\n",
    "#----------------------------------------------------------\n",
    "\n",
    "#GRAPH INVARIANTS\n",
    "\n",
    "#Average clustering coefficient of the directed graph of predicted/generated vs. ground truth references (we exclude the focal paper)\n",
    "#https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.cluster.clustering.html\n",
    "\n",
    "results_graph.loc[:,'avg_clustering_pred']=np.nan #count_zero=False (don't count isolated nodes)\n",
    "results_graph.loc[:,'avg_clustering_true']=np.nan #count_zero=False (don't count isolated nodes)\n",
    "\n",
    "#Average clustering coefficient of the undirected graph of predicted/generated vs. ground truth references (we exclude the focal paper)\n",
    "results_graph.loc[:,'avg_clustering_pred_undir']=np.nan #count_zero=False (don't count isolated nodes)\n",
    "results_graph.loc[:,'avg_clustering_true_undir']=np.nan #count_zero=False (don't count isolated nodes)\n",
    "\n",
    "#----------------------------------------------------------\n",
    "\n",
    "#RATE PLAUSIBILITY OF PREDICTIONS/ GENERATED REFERENCES\n",
    "#wrong predictions/genereations that are connected to the ground truth references (yellow nodes)\n",
    "\n",
    "#Calculate the number of edges between the non-isolated wrong predictions with the ground truth references, \n",
    "#divided by the total number of possible edges\n",
    "results_graph.loc[:,'edge_density']=np.nan \n",
    "\n",
    "#Calculate fraction of non-isolated wrong predictions that are connected to at least one paper in the paper \n",
    "results_graph.loc[:,'edge_density_bool']=np.nan \n",
    "\n",
    "#Calculate average jaccard coefficient between non-isolated wrong predictions and original focal paper (we exclude the green nodes for now)\n",
    "#https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.link_prediction.jaccard_coefficient.html\n",
    "results_graph.loc[:,'avg_jaccard_coeff']=np.nan \n",
    "\n",
    "#Save also individual Jaccard coefficient of non-isolated yellow nodes \n",
    "results.loc[:,'jaccard_coeff']=np.nan \n",
    "\n",
    "#Calculate edge expansion between non-isolated yellow nodes and their complement in the graph H (excludes isolated predictions and green nodes)\n",
    "#https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.cuts.edge_expansion.html\n",
    "#Intuition: average number of edges between yellow nodes and actual introduction references we have to remove to disconnect the two sets \n",
    "results_graph.loc[:,'edge_expansion']=np.nan \n",
    "\n",
    "#----------------------------------------------------------\n",
    "\n",
    "#DESCRIPTION OF PREDICTIONS/ GENERATED REFERENCES\n",
    "\n",
    "#Description to see where generated reference appears in graph of ground truth references\n",
    "results.loc[:,'description']='not available'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de884f0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create the graphs for all focal papers \n",
    "\n",
    "#Dictionary to store the graphs\n",
    "graphs={}\n",
    "\n",
    "#Option to plot graphs and save them  results_graph.shape[0]\n",
    "PLOT=False \n",
    "\n",
    "for focal_index in tqdm(range(results_graph.shape[0])):\n",
    "    focal_id=results_graph.at[focal_index, 'id']\n",
    "\n",
    "    #Create a dictionary of lists for the ground truth references \n",
    "    dictionary_paper_true={}\n",
    "    #Add the references of the focal paper as a first entry \n",
    "    try:\n",
    "        reference_list=full_references_dict[focal_id]\n",
    "        #Remove nan entries from the reference list \n",
    "        reference_list=[x for x in reference_list if not (isinstance(x, (float, np.float64)) and np.isnan(x))] #note: we can only test np.isnan for a float\n",
    "        dictionary_paper_true[focal_id]=reference_list\n",
    "    except:\n",
    "        #focal_id has no reference list\n",
    "        continue\n",
    "        \n",
    "    #Add all the references as keys plus their respective references \n",
    "      \n",
    "    try:\n",
    "        for index, id in enumerate(list(full_references_dict[focal_id])):\n",
    "            try:\n",
    "                #Extract the reference list of the reference from the paper_references_dict\n",
    "                reference_list=paper_references_dict[focal_id][id]\n",
    "                reference_list=list(reference_list)\n",
    "                #Remove None entries from the reference list \n",
    "                reference_list=[x for x in reference_list if x is not None]\n",
    "                dictionary_paper_true[id]=reference_list\n",
    "            except:\n",
    "                #Add an empty list here because we use .keys() later to retrieve information of all the nodes\n",
    "                #Only add the entry if id is not nan\n",
    "                if not (isinstance(id, (float, np.float64)) and np.isnan(id)):\n",
    "                        dictionary_paper_true[id]=[]\n",
    "                #The reference id of focal_id has no reference list\n",
    "                continue\n",
    "    except:\n",
    "        #focal_id has no information on references\n",
    "        continue\n",
    "        \n",
    "    #Create a dictionary of lists for the generated/predicted references, the focal paper is not stored as a key\n",
    "    #to avoid overwrite issues when we update dictionary_paper_true \n",
    "    dictionary_paper_pred={}\n",
    "    #Add all the generated/ predicted references as keys plus their respective references \n",
    "    #Note: Only add a generated/ predicted reference to the dictionary if it exists \n",
    "    \n",
    "    try:\n",
    "        for index, id in enumerate(list(gen_full_references_dict[focal_id])):\n",
    "            try:\n",
    "                #Extract the reference list of the gen_reference from the gen_paper_references_dict\n",
    "                reference_list=gen_paper_references_dict[focal_id][id]\n",
    "                reference_list=list(reference_list)\n",
    "                #Remove None entries from the reference list \n",
    "                reference_list=[x for x in reference_list if x is not None]\n",
    "                dictionary_paper_pred[id]=reference_list\n",
    "            except:\n",
    "                #Add an empty list here because we use .keys() later to retrieve information of all the nodes\n",
    "                dictionary_paper_pred[id]=[]\n",
    "                #The generated/ predicted reference id of focal has no reference list\n",
    "                continue\n",
    "    except:\n",
    "        #focal_id has no generated/ predicted introduction references.\n",
    "        continue\n",
    "    \n",
    "    #Merge the two dictionaries to create a dictionary for the whole graph of generations/ predictions and true references\n",
    "    dictionary=dictionary_paper_true.copy()\n",
    "    #Since the focal paper is not a key in dictionary_paper_pred we have no overwrite issues here\n",
    "    #UPDATE: There might be issues, if GPT predicted the focal paper to cite itself, we should remove the focal_id\n",
    "    #Check whether GPT predicted the focal paper to cite itself and remove from  \n",
    "    #from dictionary_paper_pred.keys() to avoid overwrite issues \n",
    "    self_citation_flag=(focal_id in dictionary_paper_pred.keys())\n",
    "    if self_citation_flag:\n",
    "        #Remove the focal paper from the dictionary_paper_pred.keys()\n",
    "        dictionary_paper_pred.pop(focal_id)\n",
    "    results_graph.at[focal_index, 'self_citation']=self_citation_flag * 1.0\n",
    "    dictionary.update(dictionary_paper_pred)\n",
    "    \n",
    "    #Calculate how many references predicted/generated by GPT were actually cited from the focal paper\n",
    "    \n",
    "    #Ground truth references of the focal paper \n",
    "    B=set(list(dictionary_paper_true.keys()))\n",
    "    B.remove(focal_id)\n",
    "    #All the predicted/generated references that exist\n",
    "    #Note: We do remove duplicate predictions in this step \n",
    "    C=set(list(dictionary_paper_pred.keys()))\n",
    "    \n",
    "    results_graph.at[focal_index, 'total_predictions']=len(list(C))\n",
    "    results_graph.at[focal_index, 'total_references']=len(list(B))\n",
    "        \n",
    "    #GRAPH CREATION \n",
    "\n",
    "    #List of nodes \n",
    "    nodes=[focal_id]+list(set.union(B,C))\n",
    "\n",
    "    #First create the whole graph of true and generated/ predicted references  \n",
    "    G_all=nx.from_dict_of_lists(dictionary, create_using=nx.DiGraph)\n",
    "    #Subset the whole graph by focal paper and it's ground truth and generated references.\n",
    "    G=subgraph(G_all, nodes)\n",
    "    G_undir=nx.Graph(G)\n",
    "    \n",
    "    #Calculate the subgraph H which does not contain isolated nodes \n",
    "    #--> we will use it for subsequent analysis below\n",
    "\n",
    "    H=G.copy()\n",
    "    H.remove_nodes_from(list(nx.isolates(H)))\n",
    "    H=nx.Graph(H)\n",
    "\n",
    "    #Calculate the graph that corresponds to the ground truth references\n",
    "    G_true_0=nx.from_dict_of_lists(dictionary_paper_true, create_using=nx.DiGraph)\n",
    "    #Subset by the papers that appear in the paper \n",
    "    G_true_0=subgraph(G_true_0, list(dictionary_paper_true.keys()))  #Note: this step might now be redundant\n",
    "    #Remove the focal paper from the node list \n",
    "    G_true=G_true_0.copy() #otherwise we can't remove the node \n",
    "    G_true.remove_nodes_from([focal_id])\n",
    "    G_true_undir=nx.Graph(G_true)\n",
    "    \n",
    "    #Calculate the graph that corresponds to the predicted/generated references that exist\n",
    "    G_pred=nx.from_dict_of_lists(dictionary_paper_pred, create_using=nx.DiGraph)\n",
    "    #Note: we use the set C here since we already removed the original conference paper from dictionary_paper_pred.keys() if necessary\n",
    "    G_pred=subgraph(G_pred, list(C))  #Note: this step might now be redundant\n",
    "    G_pred_undir=nx.Graph(G_pred)\n",
    "    \n",
    "    #Note: We don't have to remove the focal paper from the node list, since it is only \n",
    "    #part if GPT predicted the paper to cite itself (only happens when focal paper is purple in drawings)\n",
    "    \n",
    "    #COLOURING OF NODES IN THE GRAPH DRAWINGS\n",
    "    #Blue (b): focal paper\n",
    "    #Green (g): generated references that appear in the paper\n",
    "    #Yellow (y): generated references that don't appear in the paper but are somehow connected to the ground truth references or other generated references\n",
    "    #Orange (r): isolated generated references\n",
    "    #Grey (k): ground truth references not predicted by GPT\n",
    "\n",
    "    #We calculate the nodelist and color list for the graph drawing below\n",
    "    nodes=[focal_id]\n",
    "    color=['b']\n",
    "    #Correctly predicted citations in the paper\n",
    "    correct=list(B&C)\n",
    "    nodes+=correct\n",
    "    color+='g'*len(correct)\n",
    "    results_graph.at[focal_index, 'in_paper_predictions']=len(correct)\n",
    "    #Wrongly predicted citations that are not isolated \n",
    "    wrong=list((C-B)-set(nx.isolates(G)))\n",
    "    nodes+=wrong\n",
    "    color+='y'*len(wrong)\n",
    "    results_graph.at[focal_index, 'wrong_predictions']=len(wrong)\n",
    "    #Isolated predictions\n",
    "    isolated_predictions=list(set(nx.isolates(G)))\n",
    "    nodes+=isolated_predictions\n",
    "    color+='r'*len(isolated_predictions)\n",
    "    results_graph.at[focal_index, 'isolated_predictions']=len(isolated_predictions)\n",
    "    #References not predicted by GPT 4\n",
    "    not_predict=list(B-C)\n",
    "    nodes+=not_predict\n",
    "    color+='k'*len(not_predict)\n",
    "    \n",
    "    #Calculate isolated paper references \n",
    "    results_graph.at[focal_index, 'isolated_references']=len(set(nx.isolates(G_true)))\n",
    "        \n",
    "    #Graph invariants \n",
    "    \n",
    "    #1st analysis: generations/ predictions vs. ground truth\n",
    "    \n",
    "    # Calculate the average clustering coefficient for directed graphs\n",
    "    # With count_zeros=False only nodes with a non-zero coefficient are averaged, so networkx\n",
    "    # divides by zero (ZeroDivisionError) when there is no such node; NaN is stored in that case.\n",
    "    # Only that error is caught, so other failures (or a KeyboardInterrupt) are no longer hidden.\n",
    "    \n",
    "    try:\n",
    "        avg_clustering_pred=nx.average_clustering(G_pred, count_zeros=False)\n",
    "    except ZeroDivisionError:\n",
    "        avg_clustering_pred=np.nan\n",
    "    results_graph.at[focal_index, 'avg_clustering_pred']=avg_clustering_pred\n",
    "    \n",
    "    try:\n",
    "        avg_clustering_true=nx.average_clustering(G_true, count_zeros=False)\n",
    "    except ZeroDivisionError:\n",
    "        avg_clustering_true=np.nan\n",
    "    results_graph.at[focal_index, 'avg_clustering_true']=avg_clustering_true\n",
    "    \n",
    "    # Calculate the average clustering coefficient for undirected graphs (same NaN fallback)\n",
    "    \n",
    "    try:\n",
    "        avg_clustering_pred_undir=nx.average_clustering(G_pred_undir, count_zeros=False)\n",
    "    except ZeroDivisionError:\n",
    "        avg_clustering_pred_undir=np.nan\n",
    "    results_graph.at[focal_index, 'avg_clustering_pred_undir']=avg_clustering_pred_undir\n",
    "    \n",
    "    try:\n",
    "        avg_clustering_true_undir=nx.average_clustering(G_true_undir, count_zeros=False)\n",
    "    except ZeroDivisionError:\n",
    "        avg_clustering_true_undir=np.nan\n",
    "    results_graph.at[focal_index, 'avg_clustering_true_undir']=avg_clustering_true_undir\n",
    "    \n",
    "    #2nd analysis: rate plausibility of predictions (focus on the non-isolated yellow nodes in wrong, calculated above)\n",
    "    \n",
    "    #Calculate the Jaccard coefficient (in H) between each node in wrong and the focal paper.\n",
    "    #From these values we derive per focal paper\n",
    "    # - avg_jaccard_coeff: the average Jaccard coefficient of the non-isolated yellow nodes\n",
    "    # - edge_density_bool: the fraction of nodes in wrong with a Jaccard coefficient > 0,\n",
    "    #   i.e. that are connected to at least one reference of the focal paper\n",
    "    #Both sums are always divided by len(wrong), also if the try-block failed for some of the nodes.\n",
    "    edge_density_bool=0\n",
    "    avg_jaccard_coeff=0\n",
    "    #Avoid division by zero error\n",
    "    if len(wrong)!=0:\n",
    "        for node in wrong:\n",
    "            try:\n",
    "                jaccard_coeff=list(nx.jaccard_coefficient(H, [(node, focal_id)]))[0][2]\n",
    "                edge_density_bool+=(jaccard_coeff>0)\n",
    "                avg_jaccard_coeff+=jaccard_coeff\n",
    "                #Store the jaccard coefficient also in the results dataframe to get an overall summary statistic and not grouped by prompt\n",
    "                #Look up the row of this (focal paper, generated reference) pair in results;\n",
    "                #it is called index2 because focal_index already refers to results_graph\n",
    "                mask1=results['PaperID']==node \n",
    "                mask2=results['id']==focal_id\n",
    "                index2=results[mask1&mask2].index[0]\n",
    "                results.at[index2, 'jaccard_coeff']=jaccard_coeff\n",
    "            except Exception:\n",
    "                #Best effort per node: report and go on with the next node.\n",
    "                #Exception (instead of a bare except) keeps the loop interruptible with KeyboardInterrupt.\n",
    "                print('Error in calculating Jaccard coefficient for {0}'.format(focal_id))\n",
    "                continue\n",
    "            \n",
    "        edge_density_bool/=len(wrong)\n",
    "        avg_jaccard_coeff/=len(wrong)\n",
    "    \n",
    "        results_graph.at[focal_index, 'edge_density_bool']=edge_density_bool\n",
    "        results_graph.at[focal_index, 'avg_jaccard_coeff']=avg_jaccard_coeff\n",
    "\n",
    "            \n",
    "    #Calculate edge expansion of wrong and their complement in H (note that H is undirected)\n",
    "    #Default of zero, so every figure title below can be formatted; the value is only written to\n",
    "    #results_graph if it can actually be calculated.\n",
    "    edge_expansion=0\n",
    "    if len(wrong)!=0:\n",
    "        try:\n",
    "            edge_expansion=nx.edge_expansion(H, wrong)\n",
    "            results_graph.at[focal_index, 'edge_expansion']=edge_expansion\n",
    "        except Exception:\n",
    "            #Report the failure but do NOT skip the rest of this iteration: the descriptions of the\n",
    "            #generated references, the graph dictionary and the plot are still produced below.\n",
    "            print('Error in calculating edge expansion for {0}'.format(focal_id))\n",
    "\n",
    "\n",
    "    #Add description of generated references to results dataframe \n",
    "\n",
    "    for Id in list(G_pred.nodes()):\n",
    "        mask1=results['id']==focal_id\n",
    "        mask2=results['PaperID']==Id\n",
    "        if Id in correct:\n",
    "            results.loc[mask1&mask2, 'description']='in paper'\n",
    "        elif Id in wrong:\n",
    "            results.loc[mask1&mask2, 'description']='non-isolated'\n",
    "        else:\n",
    "            results.loc[mask1&mask2, 'description']='isolated'\n",
    "\n",
    "    #Adapt colors with the color dictionary\n",
    "    new_color=[color_dic[x] for x in color]\n",
    "    \n",
    "    #Add legend \n",
    "    legend_elements = [\n",
    "        Line2D([0], [0], marker='s', color='w', label='Paper', markerfacecolor=color_dic['b'],markeredgecolor='black', markersize=10), \n",
    "        Line2D([0], [0], marker='s', color='w', label='In paper', markerfacecolor=color_dic['g'],markeredgecolor='black',  markersize=10),\n",
    "        Line2D([0], [0], marker='s', color='w', label='Non-isolated', markerfacecolor=color_dic['y'],markeredgecolor='black',  markersize=10),\n",
    "        Line2D([0], [0], marker='s', color='w', label='Isolated', markerfacecolor=color_dic['r'],markeredgecolor='black',  markersize=10),\n",
    "        Line2D([0], [0], marker='s', color='w', label='Rest intro', markerfacecolor=color_dic['k'],markeredgecolor='black',  markersize=10),\n",
    "    ]\n",
    "\n",
    "\n",
    "    #Make focal paper bigger (150) and rest intro prediction (50) smaller in size \n",
    "    size_dic={'b': 150,\n",
    "                  'g': 100,\n",
    "                  'y': 100,\n",
    "                  'r': 100,\n",
    "                  'k': 50\n",
    "                  }\n",
    "\n",
    "    size=[size_dic[x] for x in color]\n",
    "\n",
    "    if PLOT:\n",
    "        nx.draw(G, \n",
    "            nodelist=nodes, \n",
    "            pos=nx.kamada_kawai_layout(G), \n",
    "            node_color=new_color, \n",
    "            edgecolors='black',\n",
    "            arrowsize=5, \n",
    "            width=0.4, \n",
    "            alpha=0.8, \n",
    "            node_size=size,\n",
    "            node_shape='s'\n",
    "            )\n",
    "    \n",
    "    #Create a dictionary of the respective graph and details how to draw it \n",
    "    dictionary_graph={'nodes': nodes,\n",
    "                      'edges': np.array(G.edges()),\n",
    "                      'node_color': new_color,\n",
    "                      'node_size': size\n",
    "                      }\n",
    "    \n",
    "    #Store the dictionary in the graphs dictionary\n",
    "    graphs[focal_id]=dictionary_graph\n",
    "    \n",
    "    if PLOT:\n",
    "        plt.legend(handles=legend_elements, \n",
    "               loc='best',\n",
    "               frameon=False,\n",
    "               ) \n",
    "        \n",
    "        plt.title('Edge density bool: {0:.2f}, edge expansion: {1:.2f}'.format(edge_density_bool, edge_expansion))\n",
    "        plt.savefig('graphs_figures/'+str(focal_id)+'.jpeg', \n",
    "                    dpi=300, \n",
    "                   bbox_inches='tight') \n",
    "    \n",
    "        plt.show()    \n",
    "\n",
    "#Make sure the output directory exists, so the results of the loop above are not lost when saving\n",
    "Path('graphs_results').mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "#Save the graphs dictionary (focal paper id -> nodes, edges, node colors, node sizes) as a pickle file\n",
    "with open('graphs_results/graphs_random_subfield_v2.pickle', 'wb') as handle:\n",
    "    pickle.dump(graphs, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "#Save results_graph dataframe as csv file\n",
    "results_graph.to_csv('graphs_results/results_graph_random_subfield_v2.csv')\n",
    "\n",
    "#Save results dataframe with added metadata on citation graphs as csv file\n",
    "results.to_csv('graphs_results/results_graph_metadata_random_subfield_v2.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9e37137",
   "metadata": {},
   "source": [
    "## Data for generation of figures "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be4f48b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Read back the two result tables (csv) and the graphs dictionary (pickle) written by the analysis cell above;\n",
    "#everything the figures below need is available from these three files\n",
    "results_graph, results = (\n",
    "    pd.read_csv(csv_path, index_col=0)\n",
    "    for csv_path in (\n",
    "        'graphs_results/results_graph_random_subfield_v2.csv',\n",
    "        'graphs_results/results_graph_metadata_random_subfield_v2.csv',\n",
    "    )\n",
    ")\n",
    "\n",
    "with open('graphs_results/graphs_random_subfield_v2.pickle', 'rb') as handle:\n",
    "    graphs = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad0d2010",
   "metadata": {},
   "source": [
    "# Figures"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54cfe524",
   "metadata": {},
   "source": [
    "Graphs displayed to explain pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772aeecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Print the installed matplotlib version before generating the figures below\n",
    "print(plt.matplotlib.__version__)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7451e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Pipeline illustration: citation graph of a single focal paper with a legend column to its right\n",
    "fig, ax = plt.subplot_mosaic(\n",
    "    [['A']],\n",
    "    figsize=(7.08, 6.69/2),\n",
    "    gridspec_kw={'width_ratios': [1], 'height_ratios': [1]},\n",
    "    dpi=600,\n",
    ")\n",
    "\n",
    "#Ids of all focal papers for which a graph dictionary was stored\n",
    "original_papers=list(graphs.keys())\n",
    "\n",
    "#Rebuild the directed citation graph of the displayed focal paper from its stored nodes and edges\n",
    "paper=2106566585\n",
    "dic=graphs[paper]\n",
    "G=nx.DiGraph()\n",
    "G.add_nodes_from(dic['nodes'])\n",
    "G.add_edges_from(dic['edges'])\n",
    "\n",
    "#Split panel 'A' into a 2x4 grid: columns 0-2 hold the graph, column 3 holds the legend\n",
    "gs = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=ax['A'].get_subplotspec())\n",
    "\n",
    "#Draw the graph; the layout is kept in positions because the following cell reuses it\n",
    "ax_large = fig.add_subplot(gs[0:2, 0:3])\n",
    "positions=nx.kamada_kawai_layout(G)\n",
    "nx.draw(\n",
    "    G,\n",
    "    pos=positions,\n",
    "    ax=ax_large,\n",
    "    nodelist=dic['nodes'],\n",
    "    node_color=dic['node_color'],\n",
    "    node_size=[0.25*stored_size for stored_size in dic['node_size']],\n",
    "    node_shape='s',\n",
    "    edgecolors='black',\n",
    "    arrowsize=5,\n",
    "    width=0.3,\n",
    "    alpha=1,\n",
    ")\n",
    "\n",
    "#Small horizontal margin and no vertical margin to avoid whitespace\n",
    "ax_large.margins(x=0.05, y=0)\n",
    "\n",
    "#Colors of the node categories, taken from the seaborn colorblind palette\n",
    "colorblind_palette=sns.color_palette('colorblind')\n",
    "color_dic={key: colorblind_palette[palette_index]\n",
    "           for key, palette_index in (('b', 0), ('g', 2), ('y', 8), ('r', 3), ('k', 7))}\n",
    "\n",
    "#Legend: one square marker per node category, placed in the right-hand column of the grid\n",
    "ax_legend=fig.add_subplot(gs[0:2, 3])\n",
    "markersize=5\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], marker='s', color='w', label=category_label,\n",
    "           markerfacecolor=color_dic[color_key], markeredgecolor='black',\n",
    "           markersize=markersize, alpha=1)\n",
    "    for category_label, color_key in (\n",
    "        (r'\\textbf{Focal paper}', 'b'),\n",
    "        (r'\\textbf{In paper}', 'g'),\n",
    "        (r'\\textbf{Non-isolated}', 'y'),\n",
    "        (r'\\textbf{Isolated}', 'r'),\n",
    "        (r'\\textbf{Rest}', 'k'),\n",
    "    )\n",
    "]\n",
    "\n",
    "#Frameless legend with font size 8; the legend axis itself is hidden\n",
    "ax_legend.legend(\n",
    "    handles=legend_elements,\n",
    "    loc='right',\n",
    "    frameon=False,\n",
    "    handlelength=0.5,\n",
    "    handletextpad=0.5,\n",
    "    prop={'size': 8},\n",
    ")\n",
    "ax_legend.set_axis_off()\n",
    "\n",
    "#Remove all spines of the figure and the ticks of the enclosing panel 'A'\n",
    "sns.despine(fig, left=True, bottom=True)\n",
    "for set_ticks in (ax['A'].set_xticks, ax['A'].set_yticks):\n",
    "    set_ticks([])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0707689c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#For the graph we display in the main paper, extract the subgraph of the ground truth references plus the focal paper\n",
    "#drawn with the same positions as the original graph\n",
    "#Note: this cell reuses dic, G and positions from the previous cell, so it has to be run directly after it\n",
    "\n",
    "fig, ax = plt.subplot_mosaic(\n",
    "                            [\n",
    "                            ['A']\n",
    "                            ],\n",
    "                            figsize=(7.08,6.69/2.5),\n",
    "                            gridspec_kw = {'width_ratios':[1],'height_ratios': [1]},\n",
    "                            dpi = 600\n",
    "                            )\n",
    "\n",
    "#Create a GridSpec within 'A' (2x4: columns 0-2 for the graph, column 3 for the legend)\n",
    "gs = gridspec.GridSpecFromSubplotSpec(\n",
    "    2, 4,\n",
    "    subplot_spec=ax['A'].get_subplotspec()\n",
    ")\n",
    "\n",
    "#Full color dictionary of the node categories; it is used for the node filter AND for the legend below\n",
    "color_dic={'b': sns.color_palette('colorblind')[0],\n",
    "                  'g': sns.color_palette('colorblind')[2],\n",
    "                  'y': sns.color_palette('colorblind')[8],\n",
    "                  'r': sns.color_palette('colorblind')[3],\n",
    "                  'k': sns.color_palette('colorblind')[7]\n",
    "                  }\n",
    "\n",
    "#Filter for nodes in the actual introduction, i.e. for blue, green and gray nodes \n",
    "mask=[x in [color_dic['b'], color_dic['g'], color_dic['k']] for x in dic['node_color']]\n",
    "\n",
    "#Extract node color of intro nodes\n",
    "node_color_intro=[dic['node_color'][i] for i, x in enumerate(mask) if x]\n",
    "#If color is green, make it grey (all references are drawn alike in this figure)\n",
    "node_color_intro=[color_dic['k'] if x==color_dic['g'] else x for x in node_color_intro]\n",
    "\n",
    "#Extract intro node sizes\n",
    "node_size_intro=[dic['node_size'][i] for i, x in enumerate(mask) if x]\n",
    "#If node size is 100, make it 50\n",
    "node_size_intro=[50 if x==100 else x for x in node_size_intro]\n",
    "\n",
    "#Extract intro nodes\n",
    "nodes_intro=[dic['nodes'][i] for i, x in enumerate(mask) if x]\n",
    "#Extract the positions of the nodes in the introduction\n",
    "positions_intro={k: positions[k] for k in nodes_intro}\n",
    "\n",
    "#Get the subgraph of the introduction nodes and edges\n",
    "H=G.subgraph(nodes_intro)\n",
    "\n",
    "#Plot the central large graph\n",
    "ax_large = fig.add_subplot(gs[0:2, 0:3])  \n",
    "\n",
    "#Draw the introduction nodes and edges \n",
    "nx.draw(H, \n",
    "        nodelist=nodes_intro, \n",
    "        pos=positions_intro, \n",
    "        node_color=node_color_intro, \n",
    "        node_size=node_size_intro, \n",
    "        edgecolors='black',\n",
    "        arrowsize=5, \n",
    "        width=0.3, \n",
    "        alpha=1, \n",
    "        node_shape='s',\n",
    "        ax=ax_large\n",
    "        )\n",
    "\n",
    "#Add the legend in the right-hand column of the grid\n",
    "#color_dic is deliberately NOT reduced to the two legend colors here: rebinding the shared dictionary\n",
    "#to only 'b' and 'k' would raise a KeyError in every cell run afterwards that looks up 'g', 'y' or 'r'\n",
    "\n",
    "ax_legend=fig.add_subplot(gs[0:2, 3])\n",
    "markersize=5\n",
    "legend_elements = [\n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{Focal paper}', markerfacecolor=color_dic['b'], \n",
    "               markeredgecolor='black', \n",
    "               markersize=markersize, alpha=1), \n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{References}', markerfacecolor=color_dic['k'],\n",
    "                markeredgecolor='black', \n",
    "                markersize=markersize, alpha=1),\n",
    "    ]\n",
    "\n",
    "#Set fontsize of legend\n",
    "ax_legend.legend(handles=legend_elements,\n",
    "                loc='right',\n",
    "                frameon=False,\n",
    "                handlelength=0.5,\n",
    "                handletextpad=0.5,  \n",
    "                prop={'size': 8}\n",
    "                )\n",
    "\n",
    "ax_legend.set_axis_off()\n",
    "\n",
    "#Despine figure\n",
    "sns.despine(fig, left=True, bottom=True)\n",
    "#Remove xticks and yticks\n",
    "ax['A'].set_xticks([])\n",
    "ax['A'].set_yticks([])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a07de818",
   "metadata": {},
   "source": [
    "## Main Graph Analysis Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993d6d1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Figure 4 with reference count instead of fields of study \n",
    "fig, ax = plt.subplot_mosaic(\n",
    "                            [\n",
    "                            ['A', 'A', 'A', '.', 'B'],\n",
    "                            ['A', 'A', 'A', '.', '.'],\n",
    "                            ['A','A', 'A', '.', 'C'],\n",
    "                            ['A', 'A', 'A', '.', '.'],\n",
    "                            ['D', '.', 'E', '.', 'F'],\n",
    "                            ['.', '.', '.', '.', '.']\n",
    "                            ],\n",
    "                            figsize=(7.08,6.69),\n",
    "                            gridspec_kw = {'width_ratios':[1, 0.1 , 1, 0.2, 1],'height_ratios': [1, 0, 1, 0.05, 1, 0.3]},\n",
    "                            dpi = 600\n",
    "                            )\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"Arial\"\n",
    "plt.rcParams['legend.title_fontsize'] = 7\n",
    "plt.rcParams['text.usetex'] = True\n",
    "\n",
    "# annotations\n",
    "annotations = {\n",
    "    'A': r'\\textbf{a}',\n",
    "    'B': r'\\textbf{b}',\n",
    "    'C': r'\\textbf{c}',\n",
    "    'D': r'\\textbf{d}',\n",
    "    'E': r'\\textbf{e}',\n",
    "    'F': r'\\textbf{f}',\n",
    "}\n",
    " \n",
    "# Annotate each subplot\n",
    "for key, axis in ax.items():\n",
    "    axis.text(-0.035, 1.10, annotations[key], transform=axis.transAxes,\n",
    "            fontsize=7, fontweight='bold', va='top', ha='left')\n",
    "    \n",
    "#For each axis in ax set the fontsize of the x and y labels to 8\n",
    "for key, axis in ax.items():\n",
    "    axis.tick_params(axis='both', which='major', labelsize=8)\n",
    "    axis.tick_params(axis='both', which='minor', labelsize=8)\n",
    "    axis.set_xlabel(axis.get_xlabel(), fontsize=8)\n",
    "    axis.set_ylabel(axis.get_ylabel(), fontsize=8)\n",
    "\n",
    "#-----------------------------------A-----------------------------------\n",
    "#CITATION GRAPHS \n",
    "\n",
    "#Create a GridSpec within 'A'\n",
    "gs = gridspec.GridSpecFromSubplotSpec(\n",
    "    4, 4,\n",
    "    subplot_spec=ax['A'].get_subplotspec()\n",
    ")\n",
    "\n",
    "#Create subplots within 'A'\n",
    "\n",
    "#Set up the list of original paper ids\n",
    "original_papers=list(graphs.keys())\n",
    "\n",
    "#Choose the focal paper to be displayed and extract its citation graph\n",
    "paper=2151817970\n",
    "dic=graphs[paper]\n",
    "G=nx.DiGraph()\n",
    "G.add_nodes_from(dic['nodes'])\n",
    "G.add_edges_from(dic['edges'])\n",
    "\n",
    "#Plot the central large graph\n",
    "ax_large = fig.add_subplot(gs[1:3, 0:3])  \n",
    "nx.draw(G, \n",
    "        nodelist=dic['nodes'], \n",
    "        pos=nx.kamada_kawai_layout(G), \n",
    "        node_color=dic['node_color'], \n",
    "        edgecolors='black',\n",
    "        arrowsize=5, \n",
    "        width=0.3, \n",
    "        alpha=1, \n",
    "        node_size=[0.25*x for x in dic['node_size']],\n",
    "        node_shape='s',\n",
    "        ax=ax_large\n",
    "        )\n",
    "\n",
    "#Set margins close to 0 to avoid whitespace\n",
    "ax_large.margins(x=0.05, y=0)\n",
    "\n",
    "#Add the legend in the right-hand column of the grid (middle rows)\n",
    "\n",
    "color_dic={'b': sns.color_palette('colorblind')[0],\n",
    "                  'g': sns.color_palette('colorblind')[2],\n",
    "                  'y': sns.color_palette('colorblind')[8],\n",
    "                  'r': sns.color_palette('colorblind')[3],\n",
    "                  'k': sns.color_palette('colorblind')[7]\n",
    "                  }\n",
    "\n",
    "ax_legend=fig.add_subplot(gs[1:3, 3])\n",
    "markersize=5\n",
    "legend_elements = [\n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{Focal paper}', markerfacecolor=color_dic['b'], \n",
    "               markeredgecolor='black', \n",
    "               markersize=markersize, alpha=1), \n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{In paper}', markerfacecolor=color_dic['g'], \n",
    "               markeredgecolor='black', \n",
    "               markersize=markersize, alpha=1),\n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{Non-isolated}', markerfacecolor=color_dic['y'], \n",
    "               markeredgecolor='black', \n",
    "               markersize=markersize, alpha=1),\n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{Isolated}', markerfacecolor=color_dic['r'], \n",
    "               markeredgecolor='black', \n",
    "               markersize=markersize, alpha=1),\n",
    "        Line2D([0], [0], marker='s', color='w', label=r'\\textbf{Rest}', markerfacecolor=color_dic['k'],\n",
    "                markeredgecolor='black', \n",
    "                markersize=markersize, alpha=1),\n",
    "    ]\n",
    "\n",
    "#Set fontsize of legend\n",
    "ax_legend.legend(handles=legend_elements,\n",
    "                loc='right',\n",
    "                frameon=False,\n",
    "                handlelength=0.5,\n",
    "                handletextpad=0.5,  \n",
    "                prop={'size': 8}\n",
    "                )\n",
    "\n",
    "ax_legend.set_axis_off()\n",
    "\n",
    "#Smaller graphs\n",
    "#List of positions for the smaller graphs\n",
    "positions=[(0, [0,2]), (0, [2,4]), (3, [0,2]), (3, [2,4])]\n",
    "\n",
    "#Choose 4 representative papers to display \n",
    "papers_to_display=[\n",
    "                   2149226648,\n",
    "                   2165905899,\n",
    "                   2462155872, \n",
    "                   1875152935\n",
    "                   ]\n",
    "\n",
    "for pos, paper in zip(positions, papers_to_display):\n",
    "    ax_small = fig.add_subplot(gs[pos[0], pos[1][0]:pos[1][1]])\n",
    "    dic=graphs[paper]\n",
    "    G=nx.DiGraph()\n",
    "    G.add_nodes_from(dic['nodes'])\n",
    "    G.add_edges_from(dic['edges'])\n",
    "    nx.draw(G, \n",
    "            nodelist=dic['nodes'], \n",
    "            pos=nx.kamada_kawai_layout(G), \n",
    "            node_color=dic['node_color'], \n",
    "            edgecolors='black',\n",
    "            linewidths=0.5,\n",
    "            arrowsize=3, \n",
    "            width=0.2, \n",
    "            alpha=1, \n",
    "            node_size=[0.075*x for x in dic['node_size']],\n",
    "            node_shape='s',\n",
    "            ax=ax_small\n",
    "            )\n",
    "    #Set margins to 0 to avoid whitespace\n",
    "    ax_small.margins(x=0, y=0)\n",
    "    \n",
    "#Remove plot area of axis A\n",
    "ax['A'].set_axis_off()\n",
    "\n",
    "#-----------------------------------B-----------------------------------\n",
    "#FRACTIONS OF GENERATIONS PER FOCAL PAPER \n",
    "\n",
    "results_graph['isolated_predictions_fraction']=results_graph['isolated_predictions']/results_graph['total_predictions']\n",
    "results_graph['in_paper_predictions_fraction']=(results_graph['in_paper_predictions'])/results_graph['total_predictions']\n",
    "results_graph['non_isolated_predictions_fraction']=results_graph['wrong_predictions']/results_graph['total_predictions']\n",
    "\n",
    "#Convert the added columns above to long format \n",
    "results_graph_2=pd.melt(results_graph, \n",
    "                      id_vars=['id', 'total_predictions'], \n",
    "                      value_vars=['isolated_predictions_fraction', 'in_paper_predictions_fraction', 'non_isolated_predictions_fraction'], \n",
    "                      var_name='prediction_type', \n",
    "                      value_name='fraction')\n",
    "\n",
    "#Rename the prediction types to 'in paper', 'non-isolated', 'isolated'\n",
    "\n",
    "results_graph_2['prediction_type']=results_graph_2['prediction_type'].replace('isolated_predictions_fraction', 'isolated')\n",
    "results_graph_2['prediction_type']=results_graph_2['prediction_type'].replace('in_paper_predictions_fraction', 'in paper')\n",
    "results_graph_2['prediction_type']=results_graph_2['prediction_type'].replace('non_isolated_predictions_fraction', 'non-isolated')\n",
    "\n",
    "sns.boxenplot(data=results_graph_2, \n",
    "             x='prediction_type',\n",
    "             y='fraction',\n",
    "             hue='prediction_type',\n",
    "             order=HUE_ORDER,\n",
    "             hue_order=HUE_ORDER,\n",
    "             palette=HUE_COLOR,\n",
    "             showfliers=True,\n",
    "             edgecolor='black',  \n",
    "             legend=None,\n",
    "             ax=ax['B']\n",
    "             )\n",
    "\n",
    "#Remove x-axis labels\n",
    "ax['B'].set_xticklabels([])\n",
    "ax['B'].set_xlabel('')\n",
    "\n",
    "#Set y-axis label\n",
    "ax['B'].set_ylabel(r'$\\frac{\\mathrm{existing~generations}}{\\mathrm{focal~paper}}$')\n",
    "\n",
    "#Set y-axis ticks\n",
    "ax['B'].set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1], \n",
    "           ['0.0','0.2','0.4', '0.6', '0.8', '1.0']\n",
    "          )\n",
    "\n",
    "#Remove the right and top spines\n",
    "ax['B'].spines['right'].set_visible(False)\n",
    "ax['B'].spines['top'].set_visible(False)\n",
    "\n",
    "\n",
    "#-----------------------------------C-----------------------------------\n",
    "#CITATION COUNT OF GENERATED REFERENCES PER DESCRIPTION CATEGORY\n",
    "\n",
    "mask=results['Exists']==1.0\n",
    "results['Citation Count_log']=results['Citations']+0.1\n",
    "\n",
    "sns.boxenplot(data=results[mask],\n",
    "            y='Citation Count_log',\n",
    "            x='description',\n",
    "            log_scale=True,\n",
    "            hue='description',\n",
    "            order=HUE_ORDER,\n",
    "            hue_order=HUE_ORDER,\n",
    "            palette=HUE_COLOR,\n",
    "            showfliers=True,\n",
    "            edgecolor='black', \n",
    "            legend=None,\n",
    "            ax=ax['C']\n",
    "            )\n",
    " \n",
    "#Remove x-axis labels\n",
    "ax['C'].set_xticklabels([])\n",
    "ax['C'].set_xlabel('')\n",
    "\n",
    "#Set y-axis label\n",
    "ax['C'].set_ylabel('Citation count')\n",
    "\n",
    "#Set y-axis ticks\n",
    "ax['C'].set_yticks([0.1, 1, 10, 100, 1000, 10000, 100000], \n",
    "           ['0',r'$10^0$',r'$10^1$', r'$10^2$', r'$10^3$', r'$10^4$', r'$10^5$']\n",
    "          )\n",
    "\n",
    "#Remove the right and top spines\n",
    "ax['C'].spines['right'].set_visible(False)\n",
    "ax['C'].spines['top'].set_visible(False)\n",
    "\n",
    "#-----------------------------------F-----------------------------------\n",
    "#REFERENCE COUNT OF GENERATED REFERENCES PER DESCRIPTION CATEGORY\n",
    "\n",
    "mask=results['Exists']==1.0\n",
    "results['Reference Count_log']=results['References']+0.1\n",
    "\n",
    "sns.boxenplot(data=results[mask],\n",
    "            y='Reference Count_log',\n",
    "            x='description',\n",
    "            log_scale=True,\n",
    "            hue='description',\n",
    "            order=HUE_ORDER,\n",
    "            hue_order=HUE_ORDER,\n",
    "            palette=HUE_COLOR,\n",
    "            showfliers=True,\n",
    "            edgecolor='black',  \n",
    "            legend=None,\n",
    "            ax=ax['F']\n",
    "            )\n",
    "\n",
    "\n",
    "#Remove x-axis labels\n",
    "ax['F'].set_xlabel('')\n",
    "ax['F'].set_xticklabels([])\n",
    "\n",
    "#Show x-axis ticks with angle and relabel them\n",
    "ax['F'].tick_params(axis='x', \n",
    "                    rotation=45\n",
    "                    )\n",
    "\n",
    "#Rename x-axis ticks so they start with an uppercase letter\n",
    "ax['F'].set_xticklabels(['In paper', 'Non-isolated', 'Isolated'])\n",
    "\n",
    "#Set y-axis label\n",
    "ax['F'].set_ylabel('Reference count')\n",
    "\n",
    "#Set y-axis ticks\n",
    "ax['F'].set_yticks([0.1, 1, 10, 100, 1000], \n",
    "           ['0', r'$10^0$',r'$10^1$', r'$10^2$', r'$10^3$']\n",
    "          )\n",
    "\n",
    "#Remove the right and top spines\n",
    "ax['F'].spines['right'].set_visible(False)\n",
    "ax['F'].spines['top'].set_visible(False)\n",
    "\n",
    "#-----------------------------------D-----------------------------------\n",
    "#Panel D: overlaid density histograms of the average clustering coefficient\n",
    "#of the undirected graphs, for the generations and for the ground truth\n",
    "\n",
    "clustering_label = 'Average clustering coefficient'\n",
    "\n",
    "#Drop the rows without a value, separately for each of the two columns\n",
    "generations = pd.DataFrame(\n",
    "    {clustering_label: results_graph['avg_clustering_pred_undir'].dropna().values}\n",
    ")\n",
    "ground_truth = pd.DataFrame(\n",
    "    {clustering_label: results_graph['avg_clustering_true_undir'].dropna().values}\n",
    ")\n",
    "\n",
    "#Colour assigned to each dataset label\n",
    "color_palette = {\n",
    "    \"Generations\": sns.color_palette('colorblind')[1],\n",
    "    \"Ground truth\": sns.color_palette('colorblind')[0]\n",
    "}\n",
    "\n",
    "#Stack both dataframes into long form, tagging every row with its dataset\n",
    "df_combined = pd.concat(\n",
    "    [\n",
    "        generations.assign(dataset='Generations'),\n",
    "        ground_truth.assign(dataset='Ground truth'),\n",
    "    ],\n",
    "    ignore_index=True,\n",
    ")\n",
    "\n",
    "#Plot; common_norm=False normalises the density of each dataset on its own\n",
    "sns.histplot(\n",
    "    data=df_combined,\n",
    "    x=clustering_label,\n",
    "    hue='dataset',\n",
    "    palette=color_palette,\n",
    "    stat='density',\n",
    "    common_norm=False,\n",
    "    multiple=\"layer\",\n",
    "    alpha=0.5,\n",
    "    edgecolor=None,\n",
    "    linewidth=1,\n",
    "    shrink=0.95,\n",
    "    ax=ax['D'],\n",
    ")\n",
    "\n",
    "#Set xlim, and a ylim that makes room for the legend\n",
    "ax['D'].set(xlim=(0, 1), ylim=(0, 7))\n",
    "\n",
    "#Set y-axis ticks (only up to 4, below the legend area)\n",
    "ax['D'].set_yticks(np.arange(0, 4.1, 1))\n",
    "\n",
    "#Legend: no title and no frame line\n",
    "legend = ax['D'].get_legend()\n",
    "legend.set_title('')\n",
    "legend.get_frame().set_linewidth(0.0)\n",
    "#Replace the legend entries with bold LaTeX labels, in legend order\n",
    "legend_labels = [r'\\textbf{Generations}', r'\\textbf{Ground truth}']\n",
    "for text, bold_label in zip(legend.texts, legend_labels):\n",
    "    text.set_text(bold_label)\n",
    "#Reduce fontsize of legend\n",
    "plt.setp(legend.get_texts(), fontsize='8')\n",
    "\n",
    "#Remove axis on top and right \n",
    "for side in ('right', 'top'):\n",
    "    ax['D'].spines[side].set_visible(False)\n",
    "\n",
    "#-----------------------------------E-----------------------------------\n",
    "#Panel E: histogram of the edge_density bool for the non-isolated generations,\n",
    "#i.e. the fraction of non-isolated generations that are connected to at least\n",
    "#one introduction reference per paper/prompt\n",
    "\n",
    "#Only rows with a value are plotted\n",
    "edge_density_values = results_graph['edge_density_bool'].dropna()\n",
    "sns.histplot(\n",
    "    edge_density_values,\n",
    "    bins=10,\n",
    "    color=sns.color_palette('colorblind')[8],\n",
    "    alpha=1,\n",
    "    edgecolor=None,\n",
    "    shrink=0.95,\n",
    "    label=r'\\textbf{Non-isolated}',\n",
    "    legend=True,\n",
    "    ax=ax['E'],\n",
    ")\n",
    "\n",
    "#Set x- and y-label\n",
    "ax['E'].set(xlabel='Edge density bool', ylabel='Count')\n",
    "\n",
    "#Show legend without frame and with smaller fontsize\n",
    "ax['E'].legend(\n",
    "    frameon=False,\n",
    "    handlelength=1,\n",
    "    handletextpad=0.5,\n",
    "    prop={'size': 8, 'weight': 'bold'},\n",
    ")\n",
    "\n",
    "#Remove axis on top and right \n",
    "for side in ('right', 'top'):\n",
    "    ax['E'].spines[side].set_visible(False)\n",
    "\n",
    "#Inset inside panel E, positioned through a bounding box given in panel E's axes coordinates\n",
    "inset_ax = inset_axes(\n",
    "    ax['E'],\n",
    "    width=\"100%\",\n",
    "    height=\"100%\",\n",
    "    bbox_to_anchor=(0.28, 0.52, 0.5, 0.3),\n",
    "    bbox_transform=ax['E'].transAxes,\n",
    ")\n",
    "\n",
    "#Plot a histogram of the edge expansion per focal paper (rows without a value are dropped)\n",
    "edge_expansion_values = results_graph['edge_expansion'].dropna()\n",
    "sns.histplot(\n",
    "    edge_expansion_values,\n",
    "    binwidth=0.5,\n",
    "    color=sns.color_palette('colorblind')[8],\n",
    "    alpha=1,\n",
    "    edgecolor=None,\n",
    "    shrink=0.9,\n",
    "    label='non-isolated predictions',\n",
    "    legend=None,\n",
    "    ax=inset_ax,\n",
    ")\n",
    "\n",
    "#Remove y-axis label\n",
    "inset_ax.set_ylabel('')\n",
    "\n",
    "#Set x-axis label with smaller fontsize\n",
    "inset_ax.set_xlabel('Edge expansion', fontsize=7)\n",
    "\n",
    "#Set x-axis ticks every 2 units, then restrict the visible x-range\n",
    "inset_ax.set_xticks(np.arange(0, 12, 2))\n",
    "inset_ax.set_xlim(0, 10)\n",
    "#Set y-axis ticks\n",
    "#inset_ax.set_yticks(np.arange(0, 40, 10))\n",
    "#Set fontsize of the x and y tick labels\n",
    "plt.setp(inset_ax.get_xticklabels() + inset_ax.get_yticklabels(), fontsize=7)\n",
    "\n",
    "#Remove axis on top and right \n",
    "for side in ('right', 'top'):\n",
    "    inset_ax.spines[side].set_visible(False)\n",
    "\n",
    "#Save figure as pdf\n",
    "#Create the output folder first: savefig raises FileNotFoundError when it is missing\n",
    "figure_path = Path('graphs_results/fig_graphs_random_subfield_v2.pdf')\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(figure_path, dpi=1000, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcaec1ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Boxen plot of 'fraction' for each prediction type; categories and colours follow HUE_ORDER / HUE_COLOR\n",
    "sns.boxenplot(\n",
    "    data=results_graph_2,\n",
    "    x='prediction_type', y='fraction',\n",
    "    hue='prediction_type',\n",
    "    order=HUE_ORDER, hue_order=HUE_ORDER,\n",
    "    palette=HUE_COLOR,\n",
    "    showfliers=True,\n",
    "    edgecolor='black',\n",
    "    legend=None,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
