{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "343c3baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import networkx as nx\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# Define base path as the current folder\n",
    "base_path = Path(\".\")\n",
    "\n",
    "# pickle file\n",
    "merged_pickle_path = base_path / \"merged_graphs.pickle\"\n",
    "with open(merged_pickle_path, \"rb\") as f:\n",
    "    merged_graphs = pickle.load(f)\n",
    "\n",
    "# embeddings\n",
    "focal_embeddings = np.load(base_path / \"focal_embeddings_openai.npy\")  \n",
    "ground_truth_embeddings = np.load(base_path / \"ground_truth_embeddings_openai.npy\") \n",
    "generated_embeddings = np.load(base_path / \"generated_embeddings_openai.npy\")  \n",
    "\n",
    "df_sample = pd.read_csv(base_path / \"SciSciNet_Sample_Journals_subset.csv\")\n",
    "df_ground_truth = pd.read_csv(base_path / \"ground_truth_references.csv\")\n",
    "df_generated = pd.read_csv(base_path / \"generated_references.csv\")\n",
    "\n",
    "# objects to create\n",
    "ground_truth_dict = {}\n",
    "generated_dict = {}\n",
    "random_dict = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4543233",
   "metadata": {},
   "outputs": [],
   "source": [
    "tqdm.pandas(\n",
    "    desc=\"Processing papers\",\n",
    "    total=len(df_sample),\n",
    "    bar_format=\"{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]\",\n",
    ")\n",
    "for i in tqdm(range(len(df_sample))):\n",
    "    # get the paper id\n",
    "    focal_paper_id = df_sample.PaperID[i]\n",
    "\n",
    "    # if focal_paper_id is not in merged graphs, skip\n",
    "    if focal_paper_id not in merged_graphs:\n",
    "        continue\n",
    "\n",
    "    # select focal paper embedding\n",
    "    focal_embedding = focal_embeddings[df_sample.PaperID == focal_paper_id]\n",
    "\n",
    "    # create list of ground truth references ids and remove focal paper id\n",
    "    gt_nodes = list(merged_graphs[focal_paper_id][\"groundtruth_graph\"].nodes())\n",
    "    gt_nodes.remove(focal_paper_id)\n",
    "    gt_indices = df_ground_truth[\n",
    "        (df_ground_truth.Cited_PaperID.isin(gt_nodes))\n",
    "        & (df_ground_truth.Citing_PaperID == focal_paper_id) \n",
    "        # important to also match focal paper id\n",
    "    ].index\n",
    "    gt_embeddings = ground_truth_embeddings[gt_indices]\n",
    "    gt_dict = {k: v for k, v in zip(df_ground_truth.Cited_PaperID[gt_indices], gt_embeddings)}\n",
    "\n",
    "    ground_truth_dict.update(\n",
    "        {\n",
    "            focal_paper_id: {\n",
    "                \"focal_embedding\": focal_embedding,\n",
    "                \"reference_embeddings\": gt_dict,\n",
    "            }\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # create list of generated references ids and remove focal paper id\n",
    "    gen_nodes = list(merged_graphs[focal_paper_id][\"gpt_generated_graph\"].nodes())\n",
    "    gen_nodes.remove(focal_paper_id)\n",
    "    gen_indices = df_generated[\n",
    "        (df_generated.PaperID.isin(gen_nodes))\n",
    "        & (df_generated.id == focal_paper_id)\n",
    "        # important to also match focal paper id\n",
    "    ].index\n",
    "    gen_embeddings = generated_embeddings[gen_indices]\n",
    "    gen_dict = {k: v for k, v in zip(df_generated.PaperID[gen_indices], gen_embeddings)}\n",
    "\n",
    "    generated_dict.update(\n",
    "        {\n",
    "            focal_paper_id: {\n",
    "                \"focal_embedding\": focal_embedding,\n",
    "                \"reference_embeddings\": gen_dict,\n",
    "            }\n",
    "        }\n",
    "    )\n",
    "\n",
    "    random_nodes = list(merged_graphs[focal_paper_id][\"random_graph\"].nodes())\n",
    "    random_nodes.remove(focal_paper_id)\n",
    "    random_indices = df_ground_truth[\n",
    "        df_ground_truth.Cited_PaperID.isin(random_nodes)\n",
    "    ].drop_duplicates(\n",
    "        # remove duplicates; but still same id in the end\n",
    "        subset=[\"Cited_PaperID\"]\n",
    "    ).index\n",
    "    random_embeddings = ground_truth_embeddings[random_indices]\n",
    "    rand_dict = {\n",
    "        k: v for k, v in zip(df_ground_truth.Cited_PaperID[random_indices], random_embeddings)\n",
    "    }\n",
    "\n",
    "    random_dict.update(\n",
    "        {\n",
    "            focal_paper_id: {\n",
    "                \"focal_embedding\": focal_embedding,\n",
    "                \"reference_embeddings\": rand_dict,\n",
    "            }\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a1fc4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist one pickle per reference condition, in a fixed order.\n",
    "outputs = {\n",
    "    'ground_truth_dictionary.pkl': ground_truth_dict,\n",
    "    'generated_dictionary.pkl': generated_dict,\n",
    "    'random_dictionary.pkl': random_dict,\n",
    "}\n",
    "for filename, payload in outputs.items():\n",
    "    with open(filename, 'wb') as handle:\n",
    "        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b03d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "random_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69255b6b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
