{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "import torch\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UserID</th>\n",
       "      <th>DocID</th>\n",
       "      <th>News_Body</th>\n",
       "      <th>Top_Keyphrases</th>\n",
       "      <th>model headline</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N79048</td>\n",
       "      <td>Curious just how far your dollar goes in Bellt...</td>\n",
       "      <td>[Fourth Ave, Ave, Elliott Ave, rental, Apartme...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N94939</td>\n",
       "      <td>Italian authorities are searching for Swiss fo...</td>\n",
       "      <td>[Swiss, Italy, Boys, Ismaili, Swiss club, Swis...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N101378</td>\n",
       "      <td>Etiquette standards have changed throughout th...</td>\n",
       "      <td>[kids, longer, manners, anymore, children, tau...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N75328</td>\n",
       "      <td>DALLAS (CBSDFW.COM)   A plane made an emergenc...</td>\n",
       "      <td>[sickness, Texas, Health, Health Presbyterian,...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N51667</td>\n",
       "      <td>All the pitching stats that's fit to print Yes...</td>\n",
       "      <td>[pitching, Astros Batting, Astros, Batting, bu...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15543</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N67982</td>\n",
       "      <td>A 54-year-old man was arrested and charged aft...</td>\n",
       "      <td>[officer, police officer, impersonating office...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15544</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N40427</td>\n",
       "      <td>Facebook announced plans for its own cryptocur...</td>\n",
       "      <td>[Business, Debbie, logo, currency, kind, desig...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15545</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N39260</td>\n",
       "      <td>Medical device company Soliton Inc. (NASDAQ: S...</td>\n",
       "      <td>[device, tattoo, RAP, company RAP, removal, RA...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15546</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N63210</td>\n",
       "      <td>RENTON, Wash. (AP) Shaquem Griffin is thankful...</td>\n",
       "      <td>[Seahawks, Griffin, season, Seattle Seahawks, ...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15547</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N21484</td>\n",
       "      <td>Ronald Vermeulen is seen in a June 13, 2019, b...</td>\n",
       "      <td>[Los Angeles, Riverside Police, Street, Los, R...</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15548 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      UserID    DocID                                          News_Body  \\\n",
       "0      NT100   N79048  Curious just how far your dollar goes in Bellt...   \n",
       "1      NT100   N94939  Italian authorities are searching for Swiss fo...   \n",
       "2      NT100  N101378  Etiquette standards have changed throughout th...   \n",
       "3      NT100   N75328  DALLAS (CBSDFW.COM)   A plane made an emergenc...   \n",
       "4      NT100   N51667  All the pitching stats that's fit to print Yes...   \n",
       "...      ...      ...                                                ...   \n",
       "15543   NT37   N67982  A 54-year-old man was arrested and charged aft...   \n",
       "15544   NT37   N40427  Facebook announced plans for its own cryptocur...   \n",
       "15545   NT37   N39260  Medical device company Soliton Inc. (NASDAQ: S...   \n",
       "15546   NT37   N63210  RENTON, Wash. (AP) Shaquem Griffin is thankful...   \n",
       "15547   NT37   N21484  Ronald Vermeulen is seen in a June 13, 2019, b...   \n",
       "\n",
       "                                          Top_Keyphrases model headline  \n",
       "0      [Fourth Ave, Ave, Elliott Ave, rental, Apartme...                 \n",
       "1      [Swiss, Italy, Boys, Ismaili, Swiss club, Swis...                 \n",
       "2      [kids, longer, manners, anymore, children, tau...                 \n",
       "3      [sickness, Texas, Health, Health Presbyterian,...                 \n",
       "4      [pitching, Astros Batting, Astros, Batting, bu...                 \n",
       "...                                                  ...            ...  \n",
       "15543  [officer, police officer, impersonating office...                 \n",
       "15544  [Business, Debbie, logo, currency, kind, desig...                 \n",
       "15545  [device, tattoo, RAP, company RAP, removal, RA...                 \n",
       "15546  [Seahawks, Griffin, season, Seattle Seahawks, ...                 \n",
       "15547  [Los Angeles, Riverside Police, Street, Los, R...                 \n",
       "\n",
       "[15548 rows x 5 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "file_path='BTier_MEGA_V1_Updated_user_document_keyphrases_10.pkl'\n",
    "with open(file_path, 'rb') as file:\n",
    "    data = pickle.load(file)\n",
    "\n",
    "df=pd.DataFrame(data)\n",
    "\n",
    "df[\"model headline\"] = \"\"\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error during conversion: ChunkedEncodingError(ProtocolError('Response ended prematurely'))\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/bigbird-pegasus-large-bigpatent\")\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(\"google/bigbird-pegasus-large-bigpatent\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_sentences_with_phrases(text,phrases):\n",
    "    sentences = re.split(r'(?<=[.!?])\\s+',text)\n",
    "    matches=[]\n",
    "    for sent in sentences:\n",
    "        for phrase in phrases:\n",
    "            if phrase.lower() in sent.lower():\n",
    "                matches.append(sent.strip())\n",
    "                break\n",
    "\n",
    "    return matches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "results=[]\n",
    "\n",
    "def generate_headlines(df):\n",
    "    for _,row in tqdm(df.iterrows(),total=len(df),desc=\"Generating headlines\"):\n",
    "        userid=row['UserID']\n",
    "        docid=row['DocID'] \n",
    "        text=row['News_Body']\n",
    "        keyphrases=row['Top_Keyphrases']\n",
    "\n",
    "        cues=extract_sentences_with_phrases(text,keyphrases)\n",
    "\n",
    "        if not cues:\n",
    "            headline = None\n",
    "        else:\n",
    "            cue_text=\" \".join(cues)\n",
    "            input_text = f\"{cue_text}\\n\\n{text}\"\n",
    "\n",
    "            inputs = tokenizer(input_text,return_tensors=\"pt\",truncation=True,max_length=4096)\n",
    "            inputs = {k:v.to(device) for k,v in inputs.items()}\n",
    "\n",
    "            summary_ids = model.generate(\n",
    "                inputs['input_ids'],\n",
    "                max_length=50,\n",
    "                num_beams=10,\n",
    "                early_stopping=True\n",
    "            )\n",
    "\n",
    "            headline = tokenizer.decode(summary_ids[0], skip_special_tokens=True)\n",
    "\n",
    "        results.append({\n",
    "            'UserID':userid,\n",
    "            'DocID':docid,\n",
    "            'Headlines': headline\n",
    "        })\n",
    "\n",
    "    return pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating headlines:   0%|          | 0/1000 [00:00<?, ?it/s]Input ids are automatically padded from 925 to 960 to be a multiple of `config.block_size`: 64\n",
      "Generating headlines:   0%|          | 1/1000 [00:03<1:00:10,  3.61s/it]Attention type 'block_sparse' is not possible if sequence_length: 408 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...\n",
      "Generating headlines: 100%|██████████| 1000/1000 [24:16<00:00,  1.46s/it]\n"
     ]
    }
   ],
   "source": [
    "headlines_df_0to1000 = generate_headlines(df.head(1000))\n",
    "headlines_df_0to1000.to_csv(\"headlines_df_0to1000.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>UserID</th>\n",
       "      <th>DocID</th>\n",
       "      <th>Headlines</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N79048</td>\n",
       "      <td>To help you find the shortest path to work, i....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N94939</td>\n",
       "      <td>A player, who plays for a football club, has g...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N101378</td>\n",
       "      <td>A method of teaching manners to a plurality of...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N75328</td>\n",
       "      <td>Decompression sickness results from the reduct...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NT100</td>\n",
       "      <td>N51667</td>\n",
       "      <td>A method of analyzing the pitching staff of a ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>NT12</td>\n",
       "      <td>N25977</td>\n",
       "      <td>In accordance with one or more embodiments of ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>996</th>\n",
       "      <td>NT12</td>\n",
       "      <td>N68272</td>\n",
       "      <td>A method of playing golf by hitting the ball a...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>997</th>\n",
       "      <td>NT12</td>\n",
       "      <td>N101591</td>\n",
       "      <td>A method of constructing a basketball team com...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>998</th>\n",
       "      <td>NT12</td>\n",
       "      <td>N51005</td>\n",
       "      <td>A method to quickly turn loyal supporters into...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>NT37</td>\n",
       "      <td>N25899</td>\n",
       "      <td>In an era where tariffs are all that matters, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1000 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    UserID    DocID                                          Headlines\n",
       "0    NT100   N79048  To help you find the shortest path to work, i....\n",
       "1    NT100   N94939  A player, who plays for a football club, has g...\n",
       "2    NT100  N101378  A method of teaching manners to a plurality of...\n",
       "3    NT100   N75328  Decompression sickness results from the reduct...\n",
       "4    NT100   N51667  A method of analyzing the pitching staff of a ...\n",
       "..     ...      ...                                                ...\n",
       "995   NT12   N25977  In accordance with one or more embodiments of ...\n",
       "996   NT12   N68272  A method of playing golf by hitting the ball a...\n",
       "997   NT12  N101591  A method of constructing a basketball team com...\n",
       "998   NT12   N51005  A method to quickly turn loyal supporters into...\n",
       "999   NT37   N25899  In an era where tariffs are all that matters, ...\n",
       "\n",
       "[1000 rows x 3 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "headlines_df_0to1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating headlines:   5%|▍         | 46/1000 [01:06<29:48,  1.87s/it]"
     ]
    }
   ],
   "source": [
    "headlines_df_1000to2000 = generate_headlines(df[1000:2000])\n",
    "headlines_df_1000to2000.to_csv(\"headlines_df_1000to2000.csv\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
