{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4914c6c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "import pickle\n",
    "from transformers import T5Tokenizer, T5EncoderModel\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "283ae549",
   "metadata": {},
   "outputs": [],
   "source": [
    "news_df=pd.read_csv(\"pens_news (2).csv\")\n",
    "news_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ff8b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load T5-base encoder and tokenizer\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n",
    "model = T5EncoderModel.from_pretrained(\"t5-base\")\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0549628",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding(text, tokenizer, model, device):\n",
    "    \"\"\"Encode a single text into T5 embedding using mean pooling\"\"\"\n",
    "    inputs = tokenizer(\n",
    "        text,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=512  # limit for T5-base\n",
    "    ).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        token_embeddings = outputs.last_hidden_state  # [1, seq_len, hidden_dim]\n",
    "        sentence_embedding = token_embeddings.mean(dim=1).squeeze().cpu().numpy()\n",
    "    return sentence_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b07e69a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store embeddings\n",
    "headline_embeddings = {}\n",
    "newsbody_embeddings = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e91d0bfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating headline embeddings...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  3.23it/s]\n"
     ]
    }
   ],
   "source": [
    "print(\"Generating headline embeddings...\")\n",
    "for i, row in tqdm(news_df[:5].iterrows(), total=len(news_df[:5])):\n",
    "    headline_embeddings[row[\"NewsID\"]] = get_embedding(\n",
    "        str(row[\"Headline\"]), tokenizer, model, device\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4ab5cb63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating news body embeddings...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  3.22it/s]\n"
     ]
    }
   ],
   "source": [
    "print(\"Generating news body embeddings...\")\n",
    "for i, row in tqdm(news_df[:5].iterrows(), total=len(news_df[:5])):\n",
    "    newsbody_embeddings[row[\"NewsID\"]] = get_embedding(\n",
    "        str(row[\"NewsBody\"]), tokenizer, model, device\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dd324913",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Embeddings saved: headline_T5.pkl and newsbody_T5.pkl\n"
     ]
    }
   ],
   "source": [
    "# Save as pickle\n",
    "with open(\"headline_T5.pkl\", \"wb\") as f:\n",
    "    pickle.dump(headline_embeddings, f)\n",
    "\n",
    "with open(\"newsbody_T5.pkl\", \"wb\") as f:\n",
    "    pickle.dump(newsbody_embeddings, f)\n",
    "\n",
    "print(\"✅ Embeddings saved: headline_T5.pkl and newsbody_T5.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "24af8d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"headline_T5.pkl\", \"rb\") as f:\n",
    "    headline_embeddings = pickle.load(f)\n",
    "\n",
    "# Load newsbody embeddings\n",
    "with open(\"newsbody_T5.pkl\", \"rb\") as f:\n",
    "    newsbody_embeddings = pickle.load(f)\n",
    "\n",
    "# Convert to DataFrames\n",
    "headline_df = pd.DataFrame(list(headline_embeddings.items()), columns=[\"NewsID\", \"Headline_Embedding\"])\n",
    "newsbody_df = pd.DataFrame(list(newsbody_embeddings.items()), columns=[\"NewsID\", \"NewsBody_Embedding\"])\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "07749777",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NewsID</th>\n",
       "      <th>NewsBody_Embedding</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>N10000</td>\n",
       "      <td>[-0.08127328, 0.019974899, -0.07350032, 0.0337...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>N10001</td>\n",
       "      <td>[-0.16276242, -0.0029857457, 0.010285702, 0.01...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>N10002</td>\n",
       "      <td>[-0.055686325, 0.061170243, -0.31549126, 0.350...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>N10003</td>\n",
       "      <td>[-0.3972332, 0.17620325, -0.16000667, 0.093594...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>N10004</td>\n",
       "      <td>[-0.22680107, -0.033510517, -0.09148933, 0.046...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   NewsID                                 NewsBody_Embedding\n",
       "0  N10000  [-0.08127328, 0.019974899, -0.07350032, 0.0337...\n",
       "1  N10001  [-0.16276242, -0.0029857457, 0.010285702, 0.01...\n",
       "2  N10002  [-0.055686325, 0.061170243, -0.31549126, 0.350...\n",
       "3  N10003  [-0.3972332, 0.17620325, -0.16000667, 0.093594...\n",
       "4  N10004  [-0.22680107, -0.033510517, -0.09148933, 0.046..."
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "newsbody_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c0051b3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type: <class 'numpy.ndarray'>\n",
      "Shape: (768,)\n",
      "Type: <class 'numpy.ndarray'>\n",
      "Shape: (768,)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Check one headline embedding\n",
    "sample_headline = headline_df[\"Headline_Embedding\"].iloc[0]\n",
    "print(\"Type:\", type(sample_headline))\n",
    "print(\"Shape:\", np.array(sample_headline).shape)\n",
    "\n",
    "# Check one newsbody embedding\n",
    "sample_newsbody = newsbody_df[\"NewsBody_Embedding\"].iloc[0]\n",
    "print(\"Type:\", type(sample_newsbody))\n",
    "print(\"Shape:\", np.array(sample_newsbody).shape)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
