{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook Purpose and Overview\n",
    "\n",
    "In this notebook, we show the pipeline to load 5 datasets (cora, pubmed, ogbn-arxiv, arxiv-2023 and ogbn-product) and make predictions for node classification tasks via the OpenAI API.\n",
    "\n",
    "Credit: GPT-4 helped with code generation for this notebook as well as for other utility functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import openai\n",
    "from utils.utils import process_and_compare_predictions, load_data, sample_test_nodes, map_arxiv_labels\n",
    "\n",
    "openai.api_key  = os.environ['OPENAI_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define dataset name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Active dataset. Alternatives: \"arxiv_2023\", \"pubmed\", \"arxiv\", \"product\".\n",
    "dataset_name = \"cora\"\n",
    "\n",
    "# Both arxiv variants use the same `source` value; every other dataset uses its own name.\n",
    "source = \"arxiv\" if dataset_name in (\"arxiv\", \"arxiv_2023\") else dataset_name\n",
    "\n",
    "# Label style for arxiv datasets; the other styles are \"identifier\" and \"natural language\".\n",
    "arxiv_style = \"subcategory\"\n",
    "# Set to True to include options in the prompt for arxiv datasets.\n",
    "include_options = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset and raw texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset object together with its raw text, using a fixed seed.\n",
    "data, text = load_data(dataset_name, use_text=True, seed=42)\n",
    "print(data)\n",
    "\n",
    "# Labels are remapped only for arxiv datasets when a style other than \"subcategory\" is selected.\n",
    "if arxiv_style != \"subcategory\" and source == \"arxiv\":\n",
    "    text = map_arxiv_labels(data, text, source, arxiv_style)\n",
    "\n",
    "# Distinct label values; passed to the prediction helper as `options` in later cells.\n",
    "# NOTE(review): iteration order of a set of strings is not stable across interpreter runs;\n",
    "# confirm whether the helper depends on option order.\n",
    "options = {*text[\"label\"]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample test indices\n",
    "\n",
    "The default setting is the full test set for cora and arxiv_2023, and 1000 test nodes for the other datasets. For demonstration purposes, we set the sample size to 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default evaluation size: the full test set for cora / arxiv_2023, 1000 nodes otherwise.\n",
    "if dataset_name in (\"arxiv_2023\", \"cora\"):\n",
    "    default_sample_size = len(data.test_id)\n",
    "else:\n",
    "    default_sample_size = 1000\n",
    "\n",
    "# Demo override so the walkthrough stays cheap; set to None to evaluate with the default size.\n",
    "demo_sample_size = 3\n",
    "sample_size = default_sample_size if demo_sample_size is None else demo_sample_size\n",
    "\n",
    "node_indices = sample_test_nodes(data, text, sample_size, dataset_name)\n",
    "\n",
    "print(f\"{node_indices = }\")\n",
    "\n",
    "# Clamp to what was actually returned so a short sample cannot raise IndexError.\n",
    "idx_list = list(range(min(sample_size, len(node_indices))))\n",
    "\n",
    "node_index_list = [node_indices[idx] for idx in idx_list]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check dataset splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define max number for 1-hop and 2-hop neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caps forwarded to the prediction helper as max_papers_1 / max_papers_2;\n",
    "# the \"product\" dataset uses larger caps than the others.\n",
    "max_papers_1, max_papers_2 = (40, 10) if dataset_name == \"product\" else (20, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Below we test the loaded dataset in the two contexts outlined in the paper.\n",
    "- Rich textual context\n",
    "- Scarce textual context\n",
    "\n",
    "# Rich textual context: \n",
    "\n",
    "For the target node, title and abstract are given for cora, pubmed, ogbn-arxiv and arxiv-2023; product title and product content are given for ogbn-product."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot setting: ego mode, abstracts included, no CoT and no few-shot examples.\n",
    "mode = \"ego\"\n",
    "hop = 1\n",
    "include_abs = True\n",
    "include_label = False  # assigned here but not forwarded to the call below\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(\n",
    "    node_index_list,\n",
    "    data,\n",
    "    text,\n",
    "    dataset_name=dataset_name,\n",
    "    source=source,\n",
    "    mode=mode,\n",
    "    max_papers_1=max_papers_1,\n",
    "    max_papers_2=max_papers_2,\n",
    "    hop=hop,\n",
    "    arxiv_style=arxiv_style,\n",
    "    include_abs=include_abs,\n",
    "    include_options=include_options,\n",
    "    zero_shot_CoT=zero_shot_CoT,\n",
    "    few_shot=few_shot,\n",
    "    options=options,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Few-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot setting: ego mode with abstracts, few-shot examples enabled, no CoT.\n",
    "mode = \"ego\"\n",
    "zero_shot_CoT=False\n",
    "hop=1\n",
    "include_abs=True\n",
    "include_label=False  # assigned here but not forwarded to the call below\n",
    "few_shot=True\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, hop=hop,  arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, options=options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot CoT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot CoT setting: ego mode with abstracts, CoT enabled, no few-shot examples.\n",
    "mode = \"ego\"\n",
    "hop = 1\n",
    "include_abs = True\n",
    "zero_shot_CoT = True\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(\n",
    "    node_index_list,\n",
    "    data,\n",
    "    text,\n",
    "    dataset_name=dataset_name,\n",
    "    source=source,\n",
    "    mode=mode,\n",
    "    max_papers_1=max_papers_1,\n",
    "    max_papers_2=max_papers_2,\n",
    "    hop=hop,\n",
    "    arxiv_style=arxiv_style,\n",
    "    include_abs=include_abs,\n",
    "    include_options=include_options,\n",
    "    zero_shot_CoT=zero_shot_CoT,\n",
    "    few_shot=few_shot,\n",
    "    options=options,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop title+label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-hop neighbors with labels, abstracts included.\n",
    "mode = \"neighbors\"\n",
    "hop=1\n",
    "include_label=True\n",
    "include_abs=True\n",
    "# Set explicitly: on a top-to-bottom run these flags would otherwise be inherited\n",
    "# from the preceding Zero-shot CoT cell, which leaves zero_shot_CoT=True.\n",
    "zero_shot_CoT=False\n",
    "few_shot=False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, options=options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-hop neighbors without labels, abstracts included.\n",
    "mode = \"neighbors\"\n",
    "hop=1\n",
    "include_label=False\n",
    "include_abs=True\n",
    "# Set explicitly: on a top-to-bottom run these flags would otherwise be inherited\n",
    "# from the earlier Zero-shot CoT cell, which leaves zero_shot_CoT=True.\n",
    "zero_shot_CoT=False\n",
    "few_shot=False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2,  hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, options=options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-hop title+label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-hop neighbors with labels, abstracts included.\n",
    "mode = \"neighbors\"\n",
    "hop=2\n",
    "include_label=True\n",
    "include_abs=True\n",
    "# Set explicitly: on a top-to-bottom run these flags would otherwise be inherited\n",
    "# from the earlier Zero-shot CoT cell, which leaves zero_shot_CoT=True.\n",
    "zero_shot_CoT=False\n",
    "few_shot=False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2,  hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, options=options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-hop title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-hop neighbors without labels, abstracts included.\n",
    "mode = \"neighbors\"\n",
    "hop=2\n",
    "include_label=False\n",
    "include_abs=True\n",
    "# Set explicitly: on a top-to-bottom run these flags would otherwise be inherited\n",
    "# from the earlier Zero-shot CoT cell, which leaves zero_shot_CoT=True.\n",
    "zero_shot_CoT=False\n",
    "few_shot=False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2,  hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, options=options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop attention:\n",
    "\n",
    "1-hop attention means attention extraction and prediction over 1-hop neighbors. The attentions for test nodes are given under `\\attention`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-hop attention setting: neighbors mode with abstracts and use_attention enabled.\n",
    "mode = \"neighbors\"\n",
    "hop = 1\n",
    "include_abs = True\n",
    "include_label = False  # assigned here but not forwarded to the call below\n",
    "use_attention = True\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(\n",
    "    node_index_list,\n",
    "    data,\n",
    "    text,\n",
    "    dataset_name=dataset_name,\n",
    "    source=source,\n",
    "    mode=mode,\n",
    "    max_papers_1=max_papers_1,\n",
    "    max_papers_2=max_papers_2,\n",
    "    hop=hop,\n",
    "    arxiv_style=arxiv_style,\n",
    "    include_abs=include_abs,\n",
    "    include_options=include_options,\n",
    "    zero_shot_CoT=zero_shot_CoT,\n",
    "    few_shot=few_shot,\n",
    "    use_attention=use_attention,\n",
    "    options=options,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scarce textual context:\n",
    "\n",
    "Only the title of each node is given.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce-context zero-shot: ego mode with abstracts excluded, no CoT, no few-shot examples.\n",
    "mode = \"ego\"\n",
    "hop = 1\n",
    "include_abs = False\n",
    "include_label = False\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(\n",
    "    node_index_list,\n",
    "    data,\n",
    "    text,\n",
    "    dataset_name=dataset_name,\n",
    "    source=source,\n",
    "    mode=mode,\n",
    "    max_papers_1=max_papers_1,\n",
    "    max_papers_2=max_papers_2,\n",
    "    hop=hop,\n",
    "    include_label=include_label,\n",
    "    arxiv_style=arxiv_style,\n",
    "    include_abs=include_abs,\n",
    "    include_options=include_options,\n",
    "    zero_shot_CoT=zero_shot_CoT,\n",
    "    few_shot=few_shot,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop title+label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce context, 1-hop neighbors with labels.\n",
    "mode = \"neighbors\"\n",
    "hop = 1\n",
    "include_label = True\n",
    "# Set explicitly so this cell does not depend on which earlier cell last assigned these flags\n",
    "# (the rich-context cells leave include_abs=True).\n",
    "include_abs = False\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2,  hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce context, 1-hop neighbors without labels.\n",
    "mode = \"neighbors\"\n",
    "hop = 1\n",
    "include_label = False\n",
    "# Set explicitly so this cell does not depend on which earlier cell last assigned these flags\n",
    "# (the rich-context cells leave include_abs=True).\n",
    "include_abs = False\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs,  hop=hop, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-hop title+label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce context, 2-hop neighbors with labels.\n",
    "mode = \"neighbors\"\n",
    "hop = 2\n",
    "include_label = True\n",
    "# Set explicitly so this cell does not depend on which earlier cell last assigned these flags\n",
    "# (the rich-context cells leave include_abs=True).\n",
    "include_abs = False\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, hop=hop,  include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-hop title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce context, 2-hop neighbors without labels.\n",
    "mode = \"neighbors\"\n",
    "hop = 2\n",
    "include_label = False\n",
    "# Set explicitly so this cell does not depend on which earlier cell last assigned these flags\n",
    "# (the rich-context cells leave include_abs=True).\n",
    "include_abs = False\n",
    "zero_shot_CoT = False\n",
    "few_shot = False\n",
    "\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, include_label=include_label, arxiv_style=arxiv_style, include_abs=include_abs,  hop=hop, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Returned accuracy:\", accuracy)\n",
    "print(\"Returned wrong indexes:\", wrong_indexes_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1-hop attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scarce context, 1-hop attention: neighbors mode, abstracts excluded, use_attention enabled.\n",
    "mode = \"neighbors\"\n",
    "hop=1\n",
    "include_abs=False\n",
    "include_label=False  # assigned here but not forwarded to the call below\n",
    "use_attention=True\n",
    "# Set explicitly so this cell does not depend on which earlier cell last assigned these flags.\n",
    "zero_shot_CoT=False\n",
    "few_shot=False\n",
    "accuracy, wrong_indexes_list = process_and_compare_predictions(node_index_list, data, text, dataset_name=dataset_name, source=source, mode=mode, max_papers_1=max_papers_1, max_papers_2=max_papers_2, hop=hop,  arxiv_style=arxiv_style, include_abs=include_abs, include_options=include_options, zero_shot_CoT=zero_shot_CoT, few_shot=few_shot, use_attention=use_attention, options=options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_nc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
