{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "741992ed-9816-4664-8342-e877246d8a89",
   "metadata": {},
   "source": [
    "### Python packages used in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5799dd7-2916-49d2-84eb-a5c717e1e1bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install datasets\n",
    "# !pip install sentencepiece\n",
    "# !pip install transformers\n",
    "# !pip install sentence-transformers\n",
    "# !pip install openai\n",
    "import os\n",
    "import pickle\n",
    "import re\n",
    "\n",
    "import datasets\n",
    "import openai\n",
    "import pandas as pd\n",
    "import torch\n",
    "from IPython.display import clear_output\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "from transformers import BertTokenizer, TFBertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fb96aa-e650-499a-ad91-6f13bc29f062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the key from the environment so it is never committed with the notebook.\n",
    "# NOTE(review): export OPENAI_API_KEY before running the GPT-3 section.\n",
    "openai.api_key = os.environ.get('OPENAI_API_KEY')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "956982ef-0c77-4406-90ed-0ee61315da10",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77cb3825-154b-4209-998e-a40a28ff07b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper metadata for the peer_review_score_hIndex task (later cells use 'doc_id' and 'abstract').\n",
    "dataset_papers = datasets.load_dataset('allenai/scirepeval', 'peer_review_score_hIndex')\n",
    "# load_dataset returns a DatasetDict keyed by split name, but later cells need one flat\n",
    "# pandas DataFrame with a 0..n-1 index, so stack every split into a single frame.\n",
    "df = pd.concat([split.to_pandas() for split in dataset_papers.values()], ignore_index=True)\n",
    "assert {'doc_id', 'abstract'} <= set(df.columns), 'expected doc_id and abstract columns'\n",
    "\n",
    "# Train/test splits of the peer_review_score task, written to CSV.\n",
    "dataset_test = datasets.load_dataset('allenai/scirepeval_test', 'peer_review_score')\n",
    "(pd.DataFrame(dataset_test['test'])).to_csv('../10_Data/df_test.csv')\n",
    "(pd.DataFrame(dataset_test['train'])).to_csv('../10_Data/df_train.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8495c2b-8d89-47ed-9b48-de516619dbea",
   "metadata": {},
   "source": [
    "### Remove sentences that contain a URL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dfa4904-e8e9-45d4-9836-c6f3682cd2e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular expression matching a URL token: 'http' followed by any non-whitespace run\n",
    "url_pattern = re.compile(r'http\\S+')\n",
    "\n",
    "def remove_url_sentence(text):\n",
    "    \"\"\"Return `text` with every sentence that contains a URL removed.\n",
    "\n",
    "    Sentences are split on '. ' (a simple heuristic) and the remaining\n",
    "    sentences are re-joined with the same separator. Non-string values\n",
    "    (e.g. a missing abstract) are returned unchanged.\n",
    "    \"\"\"\n",
    "    if not isinstance(text, str):\n",
    "        return text\n",
    "    # Drop all URL sentences, not just the first, so the result is truly URL-free.\n",
    "    kept = [sentence for sentence in text.split('. ') if not url_pattern.search(sentence)]\n",
    "    return '. '.join(kept)\n",
    "\n",
    "# Apply the function to each abstract; the new column is the text fed to the encoders below.\n",
    "df['abstract (wo URL)'] = df['abstract'].apply(remove_url_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cbd6adc-408c-4cf7-a72e-de26da95a046",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the frame, including the URL-free abstract column.\n",
    "removeurl_path = '../10_Data/df_removeURL.csv'\n",
    "df.to_csv(removeurl_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa2c4979-8f52-459a-9d14-59f0bc95adae",
   "metadata": {},
   "source": [
    "## Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c806633-b1dc-4e16-bce6-fbeb8177c050",
   "metadata": {},
   "source": [
    "### BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70553549-dd05-4ae8-befe-20f2ba3bafb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer and TensorFlow encoder are both loaded from the same pre-trained checkpoint.\n",
    "bert_checkpoint = 'bert-base-uncased'\n",
    "tokenizer, model = (\n",
    "    BertTokenizer.from_pretrained(bert_checkpoint),\n",
    "    TFBertModel.from_pretrained(bert_checkpoint),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f927808f-3384-4c3f-930d-56411a3ffd34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BERT: one vector per abstract, taken from the [CLS] position of the last hidden layer.\n",
    "MAX_TOKENS = 512  # bert-base position limit; longer inputs would otherwise fail to encode\n",
    "docID_list = []\n",
    "Enc_list_CLS = []\n",
    "error_list = []  # positional row indices that could not be encoded\n",
    "for i, (doc_id, text) in enumerate(zip(df['doc_id'], df['abstract (wo URL)'])):\n",
    "    try:\n",
    "        # Tokenize (adding [CLS]/[SEP]) and truncate to the model's maximum length.\n",
    "        input_ids = tokenizer.encode(\n",
    "            text, add_special_tokens=True, truncation=True,\n",
    "            max_length=MAX_TOKENS, return_tensors='tf')\n",
    "        last_hidden_states = model(input_ids).last_hidden_state\n",
    "    except Exception:  # e.g. a missing/non-string abstract; Ctrl-C still interrupts the loop\n",
    "        error_list.append(i)\n",
    "        continue\n",
    "    docID_list.append(doc_id)\n",
    "    Enc_list_CLS.append(last_hidden_states[:, 0, :].numpy().reshape(-1))\n",
    "\n",
    "print(f'BERT: encoded {len(docID_list)} abstracts, {len(error_list)} failed')\n",
    "df_enc = pd.DataFrame(Enc_list_CLS, index=docID_list)\n",
    "df_enc.to_csv('../10_Data/df_enc_BERT.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35cf695b-208d-4949-9cb1-d1d68fdc5d62",
   "metadata": {},
   "source": [
    "### SciBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6da0287f-954e-4bc9-a67d-8e3517ef3d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SciBERT (scientific vocabulary, uncased): tokenizer and PyTorch encoder from one checkpoint.\n",
    "scibert_checkpoint = \"allenai/scibert_scivocab_uncased\"\n",
    "tokenizer, model = (\n",
    "    AutoTokenizer.from_pretrained(scibert_checkpoint),\n",
    "    AutoModel.from_pretrained(scibert_checkpoint),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e414e3cc-40ca-4add-b4af-11e8138a8bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SciBERT: one vector per abstract, taken from the [CLS] position of the last hidden layer.\n",
    "MAX_TOKENS = 512  # BERT-base position limit; truncate longer abstracts instead of dropping them\n",
    "docID_list = []\n",
    "Enc_list_CLS = []\n",
    "error_list = []  # positional row indices that could not be encoded\n",
    "for i, (doc_id, text) in enumerate(zip(df['doc_id'], df['abstract (wo URL)'])):\n",
    "    try:\n",
    "        encoded_input = tokenizer(text, truncation=True, max_length=MAX_TOKENS, return_tensors='pt')\n",
    "        # Inference only: no autograd graph, so the output can be converted with .numpy().\n",
    "        with torch.no_grad():\n",
    "            embeddings = model(input_ids=encoded_input['input_ids'],\n",
    "                               attention_mask=encoded_input['attention_mask'])\n",
    "    except Exception:  # e.g. a missing/non-string abstract; Ctrl-C still interrupts the loop\n",
    "        error_list.append(i)\n",
    "        continue\n",
    "    last_hidden_states = embeddings.last_hidden_state\n",
    "    docID_list.append(doc_id)\n",
    "    Enc_list_CLS.append(last_hidden_states[:, 0, :].numpy().reshape(-1))\n",
    "\n",
    "print(f'SciBERT: encoded {len(docID_list)} abstracts, {len(error_list)} failed')\n",
    "df_enc = pd.DataFrame(Enc_list_CLS, index=docID_list)\n",
    "df_enc.to_csv('../10_Data/df_enc_SciBERT.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "865ec606-edf9-4156-9e79-ebbee60c235f",
   "metadata": {},
   "source": [
    "### GPT-3 (OpenAI `text-embedding-ada-002`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5979fc75-41c7-4e2f-b485-112eb1a86230",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding(text, engine=\"text-embedding-ada-002\"):\n",
    "    \"\"\"Return the OpenAI embedding vector for `text` computed by model `engine`.\n",
    "\n",
    "    Newlines in `text` are replaced by spaces before the request is sent.\n",
    "    NOTE(review): `openai.Embedding.create` is the pre-1.0 openai SDK interface;\n",
    "    pin `openai<1` or port this call to the 1.x `embeddings.create` API.\n",
    "    \"\"\"\n",
    "    flattened = \" \".join(text.split(\"\\n\"))\n",
    "    response = openai.Embedding.create(input=[flattened], model=engine)\n",
    "    first_result = response['data'][0]\n",
    "    return first_result['embedding']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef4e4d7-7423-4ebb-81a3-184155134d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT-3 (text-embedding-ada-002): one API request per abstract.\n",
    "docID_list = []\n",
    "Enc_list = []\n",
    "error_list = []  # positional row indices whose request failed\n",
    "for i, (doc_id, text) in enumerate(zip(df['doc_id'], df['abstract (wo URL)'])):\n",
    "    try:\n",
    "        embedding = get_embedding(text)\n",
    "    except Exception:  # e.g. a missing abstract or an API error; keep the vectors obtained so far\n",
    "        error_list.append(i)\n",
    "        continue\n",
    "    Enc_list.append(embedding)\n",
    "    docID_list.append(doc_id)\n",
    "\n",
    "print(f'GPT-3: encoded {len(docID_list)} abstracts, {len(error_list)} failed')\n",
    "df_enc = pd.DataFrame(Enc_list, index=docID_list)\n",
    "df_enc.to_csv('../10_Data/df_enc_GPT-3.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "693f31d7-706b-42c7-b305-f686235c839b",
   "metadata": {},
   "source": [
    "### T5 (Sentence-T5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afcdd253-b9f2-4fb1-88d7-0feea250b6e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence-T5 (base) encodes the whole list of abstracts in a single call.\n",
    "abstracts = df['abstract (wo URL)'].values.tolist()\n",
    "model = SentenceTransformer('sentence-transformers/sentence-t5-base')\n",
    "embeddings = model.encode(abstracts)\n",
    "df_enc = pd.DataFrame(embeddings, index=df['doc_id'])\n",
    "df_enc.to_csv('../10_Data/df_enc_T5.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
