{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from datasets import DatasetDict, Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the MIND dataset; the cells below read train/, test/ and valid/ news.tsv from it.\n",
    "# Override the location with the MIND_DATA_PATH environment variable; the previous hard-coded\n",
    "# path remains the fallback so existing setups keep working.\n",
    "data_path = os.environ.get(\"MIND_DATA_PATH\", \"/mnt/sdb1/hm-ind-agg/nyc-data/data/mind-org\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names for the header-less news.tsv files, in file order.\n",
    "col_news = [\n",
    "    'NewsId', 'Category', 'SubCat', 'Title',\n",
    "    'Abstract', 'url', 'TitleEnt', 'AbstractEnt',\n",
    "]\n",
    "\n",
    "# Read the training news table and discard every row that has a missing value in any column.\n",
    "train_news_df = pd.read_csv(\n",
    "    f\"{data_path}/train/news.tsv\",\n",
    "    sep='\\t',\n",
    "    header=None,\n",
    "    names=col_news,\n",
    ").dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test and validation news tables: tab-separated, no header row,\n",
    "# rows with any missing value removed.\n",
    "test_news_df = pd.read_csv(\n",
    "    f\"{data_path}/test/news.tsv\", sep='\\t', header=None, names=col_news\n",
    ").dropna()\n",
    "valid_news_df = pd.read_csv(\n",
    "    f\"{data_path}/valid/news.tsv\", sep='\\t', header=None, names=col_news\n",
    ").dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NewsId</th>\n",
       "      <th>Category</th>\n",
       "      <th>SubCat</th>\n",
       "      <th>Title</th>\n",
       "      <th>Abstract</th>\n",
       "      <th>url</th>\n",
       "      <th>TitleEnt</th>\n",
       "      <th>AbstractEnt</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>N88753</td>\n",
       "      <td>lifestyle</td>\n",
       "      <td>lifestyleroyals</td>\n",
       "      <td>The Brands Queen Elizabeth, Prince Charles, an...</td>\n",
       "      <td>Shop the notebooks, jackets, and more that the...</td>\n",
       "      <td>https://assets.msn.com/labs/mind/AAGH0ET.html</td>\n",
       "      <td>[{\"Label\": \"Prince Philip, Duke of Edinburgh\",...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   NewsId   Category           SubCat  \\\n",
       "0  N88753  lifestyle  lifestyleroyals   \n",
       "\n",
       "                                               Title  \\\n",
       "0  The Brands Queen Elizabeth, Prince Charles, an...   \n",
       "\n",
       "                                            Abstract  \\\n",
       "0  Shop the notebooks, jackets, and more that the...   \n",
       "\n",
       "                                             url  \\\n",
       "0  https://assets.msn.com/labs/mind/AAGH0ET.html   \n",
       "\n",
       "                                            TitleEnt AbstractEnt  \n",
       "0  [{\"Label\": \"Prince Philip, Duke of Edinburgh\",...          []  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview one row to check that the column names line up with the data.\n",
    "train_news_df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NewsId</th>\n",
       "      <th>SubCat</th>\n",
       "      <th>Title</th>\n",
       "      <th>Abstract</th>\n",
       "      <th>url</th>\n",
       "      <th>TitleEnt</th>\n",
       "      <th>AbstractEnt</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Category</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>sports</th>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "      <td>29625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>news</th>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "      <td>29363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>finance</th>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "      <td>5777</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>travel</th>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "      <td>4605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>video</th>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "      <td>4562</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>foodanddrink</th>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "      <td>4319</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lifestyle</th>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "      <td>4255</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>weather</th>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "      <td>3820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>health</th>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "      <td>2815</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>autos</th>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "      <td>2756</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              NewsId  SubCat  Title  Abstract    url  TitleEnt  AbstractEnt\n",
       "Category                                                                   \n",
       "sports         29625   29625  29625     29625  29625     29625        29625\n",
       "news           29363   29363  29363     29363  29363     29363        29363\n",
       "finance         5777    5777   5777      5777   5777      5777         5777\n",
       "travel          4605    4605   4605      4605   4605      4605         4605\n",
       "video           4562    4562   4562      4562   4562      4562         4562\n",
       "foodanddrink    4319    4319   4319      4319   4319      4319         4319\n",
       "lifestyle       4255    4255   4255      4255   4255      4255         4255\n",
       "weather         3820    3820   3820      3820   3820      3820         3820\n",
       "health          2815    2815   2815      2815   2815      2815         2815\n",
       "autos           2756    2756   2756      2756   2756      2756         2756"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten largest categories by article count. Every column shows the same number because\n",
    "# count() tallies non-null values and rows with missing values were already dropped.\n",
    "train_news_df.groupby('Category').count().sort_values('NewsId', ascending=False).head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NewsId</th>\n",
       "      <th>Category</th>\n",
       "      <th>SubCat</th>\n",
       "      <th>Title</th>\n",
       "      <th>Abstract</th>\n",
       "      <th>url</th>\n",
       "      <th>TitleEnt</th>\n",
       "      <th>AbstractEnt</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>N88753</td>\n",
       "      <td>lifestyle</td>\n",
       "      <td>lifestyleroyals</td>\n",
       "      <td>The Brands Queen Elizabeth, Prince Charles, an...</td>\n",
       "      <td>Shop the notebooks, jackets, and more that the...</td>\n",
       "      <td>https://assets.msn.com/labs/mind/AAGH0ET.html</td>\n",
       "      <td>[{\"Label\": \"Prince Philip, Duke of Edinburgh\",...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   NewsId   Category           SubCat  \\\n",
       "0  N88753  lifestyle  lifestyleroyals   \n",
       "\n",
       "                                               Title  \\\n",
       "0  The Brands Queen Elizabeth, Prince Charles, an...   \n",
       "\n",
       "                                            Abstract  \\\n",
       "0  Shop the notebooks, jackets, and more that the...   \n",
       "\n",
       "                                             url  \\\n",
       "0  https://assets.msn.com/labs/mind/AAGH0ET.html   \n",
       "\n",
       "                                            TitleEnt AbstractEnt  \n",
       "0  [{\"Label\": \"Prince Philip, Duke of Edinburgh\",...          []  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first row again; in notebook order nothing since the earlier preview has modified train_news_df.\n",
    "train_news_df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Device set to use cuda:0\n",
      "/home/erfanl/anaconda3/envs/llm-up/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:106: UserWarning: `return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# English sentiment model; the demo cells show it emitting NEGATIVE and POSITIVE labels.\n",
    "# Alternative kept for reference: \"lxyuan/distilbert-base-multilingual-cased-sentiments-student\".\n",
    "# NOTE(review): transformers warns that `return_all_scores` is deprecated in favour of `top_k=None`.\n",
    "# It is kept because other cells read the per-label list by position -- confirm the output order\n",
    "# before migrating.\n",
    "sentiment_classifier = pipeline(\n",
    "    model=\"siebert/sentiment-roberta-large-english\",\n",
    "    device=0,  # first CUDA device\n",
    "    truncation=True,\n",
    "    return_all_scores=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'label': 'NEGATIVE', 'score': 0.0010697843972593546},\n",
       "  {'label': 'POSITIVE', 'score': 0.9989301562309265}]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: a clearly positive review should get a POSITIVE score close to 1.\n",
    "sentiment_classifier(\"I love this new movie! It's fantastic and the acting is superb. I would recommend it to everyone.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sentiment(text):\n",
    "    \"\"\"Return the POSITIVE-label score that `sentiment_classifier` gives to `text`.\n",
    "\n",
    "    Args:\n",
    "        text: String to score. `None` is scored as the empty string.\n",
    "\n",
    "    Returns:\n",
    "        The `score` value of the entry whose label is POSITIVE (case-insensitive).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If the classifier output contains no POSITIVE label.\n",
    "    \"\"\"\n",
    "    text = text if text is not None else \"\"\n",
    "    label_scores = sentiment_classifier([{\"text\": text}])[0]\n",
    "    # Pick the score by label instead of by list position: the position of POSITIVE depends on\n",
    "    # the model's label order and on whether the pipeline sorts its output by score.\n",
    "    score_by_label = {entry['label'].upper(): entry['score'] for entry in label_scores}\n",
    "    return score_by_label['POSITIVE']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'label': 'NEGATIVE', 'score': 0.9974976181983948},\n",
       "  {'label': 'POSITIVE', 'score': 0.002502303570508957}]]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same check with the list-of-dicts input form used by the scoring helpers, on a negative example.\n",
    "sentiment_classifier([{\"text\": \"Noo!\"}])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict every split to articles whose Category is 'sports'.\n",
    "train_news_df = train_news_df.loc[train_news_df['Category'].eq('sports')]\n",
    "test_news_df = test_news_df.loc[test_news_df['Category'].eq('sports')]\n",
    "valid_news_df = valid_news_df.loc[valid_news_df['Category'].eq('sports')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the whitespace-separated token count of each abstract to every split.\n",
    "for news_df in (train_news_df, test_news_df, valid_news_df):\n",
    "    news_df['Abstract_len'] = news_df['Abstract'].apply(lambda x: len(str(x).split()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only articles whose abstract has between 20 and 120 tokens (both bounds included).\n",
    "train_news_df = train_news_df[train_news_df['Abstract_len'].between(20, 120)]\n",
    "test_news_df = test_news_df[test_news_df['Abstract_len'].between(20, 120)]\n",
    "valid_news_df = valid_news_df[valid_news_df['Abstract_len'].between(20, 120)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _positive_scores(texts):\n",
    "    \"\"\"Run `sentiment_classifier` on `texts` and return each text's POSITIVE-label score, in input order.\n",
    "\n",
    "    The score is looked up by label (case-insensitive) rather than by list position, so the result\n",
    "    does not depend on the order in which the pipeline lists the labels.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If an output contains no POSITIVE label.\n",
    "    \"\"\"\n",
    "    outputs = sentiment_classifier([{\"text\": text} for text in texts])\n",
    "    positive = []\n",
    "    for label_scores in outputs:\n",
    "        score_by_label = {entry['label'].upper(): entry['score'] for entry in label_scores}\n",
    "        positive.append(score_by_label['POSITIVE'])\n",
    "    return positive\n",
    "\n",
    "\n",
    "def get_sentiments_abstract_batch(examples):\n",
    "    \"\"\"Batched `Dataset.map` function: score `examples['Abstract']`.\n",
    "\n",
    "    Returns:\n",
    "        dict with key 'sentiment_abstract' holding one POSITIVE score per abstract.\n",
    "    \"\"\"\n",
    "    return {'sentiment_abstract': _positive_scores(examples['Abstract'])}\n",
    "\n",
    "\n",
    "def get_sentiments_title_batch(examples):\n",
    "    \"\"\"Batched `Dataset.map` function: score `examples['Title']`.\n",
    "\n",
    "    Returns:\n",
    "        dict with key 'sentiment_title' holding one POSITIVE score per title.\n",
    "    \"\"\"\n",
    "    return {'sentiment_title': _positive_scores(examples['Title'])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each filtered frame in a `datasets.Dataset` so the classifier can be applied with `map`.\n",
    "train_news_ds, test_news_ds, valid_news_ds = (\n",
    "    Dataset.from_pandas(frame)\n",
    "    for frame in (train_news_df, test_news_df, valid_news_df)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d61709d8451f4a938e5ee056f0a960fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/19105 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ce4787eaeb14f97acded870eed3a1a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/19105 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2be5eab20b6c4b678f86106743d919ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/22701 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "82c9a6f140a84e8c8f84584d891609e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/22701 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bacaae19cf074e09b4f09680ee4bb284",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/12744 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3220760783a845898ffe7c913adb95f3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/12744 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# For each split, score the abstracts first and then the titles; both passes hand the\n",
    "# scoring function 20 rows at a time.\n",
    "batch_options = {\"batched\": True, \"batch_size\": 20}\n",
    "\n",
    "train_news_ds = (\n",
    "    train_news_ds\n",
    "    .map(get_sentiments_abstract_batch, **batch_options)\n",
    "    .map(get_sentiments_title_batch, **batch_options)\n",
    ")\n",
    "\n",
    "test_news_ds = (\n",
    "    test_news_ds\n",
    "    .map(get_sentiments_abstract_batch, **batch_options)\n",
    "    .map(get_sentiments_title_batch, **batch_options)\n",
    ")\n",
    "\n",
    "valid_news_ds = (\n",
    "    valid_news_ds\n",
    "    .map(get_sentiments_abstract_batch, **batch_options)\n",
    "    .map(get_sentiments_title_batch, **batch_options)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back to pandas, one split at a time: compute the absolute gap between abstract and title\n",
    "# sentiment, then keep only rows where that gap is below 0.95.\n",
    "train_news_df = train_news_ds.to_pandas()\n",
    "train_news_df['abs_diff'] = (train_news_df['sentiment_abstract'] - train_news_df['sentiment_title']).abs()\n",
    "\n",
    "test_news_df = test_news_ds.to_pandas()\n",
    "test_news_df['abs_diff'] = (test_news_df['sentiment_abstract'] - test_news_df['sentiment_title']).abs()\n",
    "\n",
    "valid_news_df = valid_news_ds.to_pandas()\n",
    "valid_news_df['abs_diff'] = (valid_news_df['sentiment_abstract'] - valid_news_df['sentiment_title']).abs()\n",
    "\n",
    "train_news_df = train_news_df[train_news_df['abs_diff'].lt(0.95)]\n",
    "test_news_df = test_news_df[test_news_df['abs_diff'].lt(0.95)]\n",
    "valid_news_df = valid_news_df[valid_news_df['abs_diff'].lt(0.95)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "42969"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows left across the three splits after the abs_diff filter.\n",
    "valid_news_df.shape[0] + test_news_df.shape[0] + train_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "54550"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows across the three splits before the abs_diff filter (the Dataset objects were not filtered).\n",
    "train_news_ds.num_rows + test_news_ds.num_rows + valid_news_ds.num_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.787699358386801"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of rows that survived the abs_diff filter.\n",
    "(valid_news_df.shape[0] + test_news_df.shape[0] + train_news_df.shape[0]) / (train_news_ds.num_rows + test_news_ds.num_rows + valid_news_ds.num_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(15071, 17903, 9995)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-split row counts after filtering, in the order (train, test, valid).\n",
    "train_news_df.shape[0], test_news_df.shape[0], valid_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target sizes of the subsets carved out in the following cells.\n",
    "sft_count = 10000\n",
    "reward_count = 10000\n",
    "reward_valid_count = 3000\n",
    "ppo_count = 10000\n",
    "\n",
    "# Whatever is not claimed by the four subsets above becomes the final test set.\n",
    "total_count = sum(frame.shape[0] for frame in (train_news_df, test_news_df, valid_news_df))\n",
    "test_count = total_count - (sft_count + reward_count + reward_valid_count + ppo_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out `reward_valid_count` (3000) rows of the validation split; the remaining rows stay in\n",
    "# valid_news_df. The fixed random_state makes the draw reproducible.\n",
    "valid_news_df_3000 = valid_news_df.sample(n=reward_valid_count, random_state=42)\n",
    "valid_news_df = valid_news_df.drop(valid_news_df_3000.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3000, 6995)"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sizes of the held-out validation sample and of the validation rows left over.\n",
    "valid_news_df_3000.shape[0], valid_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "17903"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test rows available before the final test set is sampled from them.\n",
    "test_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9969"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of rows that will be sampled as the final test set.\n",
    "test_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample `test_count` rows of the test split as the final test set; the remaining rows stay in test_news_df.\n",
    "test_news_df_test = test_news_df.sample(n=test_count, random_state=42)\n",
    "test_news_df = test_news_df.drop(test_news_df_test.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(9969, 7934)"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sizes of the final test set and of the test rows left over.\n",
    "test_news_df_test.shape[0], test_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the leftover validation and test rows with the training rows; ignore_index=True rebuilds\n",
    "# the index so that row labels are unique in the pooled frame.\n",
    "# NOTE(review): this assumes the three splits contain distinct articles. If the same NewsId can\n",
    "# occur in more than one split, the pooled frame holds duplicates and may overlap with the\n",
    "# held-out samples -- TODO confirm by checking NewsId uniqueness.\n",
    "train_news_df = pd.concat([train_news_df, valid_news_df, test_news_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentiment_abstract</th>\n",
       "      <th>sentiment_title</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>sentiment_abstract</th>\n",
       "      <td>1.00000</td>\n",
       "      <td>0.95839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sentiment_title</th>\n",
       "      <td>0.95839</td>\n",
       "      <td>1.00000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    sentiment_abstract  sentiment_title\n",
       "sentiment_abstract             1.00000          0.95839\n",
       "sentiment_title                0.95839          1.00000"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Correlation between the abstract and title sentiment scores in the pooled training frame.\n",
    "train_news_df[['sentiment_abstract', 'sentiment_title']].corr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find the 200 most frequent words in the titles\n",
    "# train_news_df['Title_lower'] = train_news_df['Title'].str.lower()\n",
    "# print(train_news_df['Title_lower'].str.split().explode().value_counts().head(200).index.tolist())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the `__index_level_0__` helper column (presumably the old pandas index carried through\n",
    "# the Dataset round-trip -- verify). errors='ignore' keeps the cell safe to re-run: without it a\n",
    "# second execution raises KeyError because the column is already gone.\n",
    "train_news_df = train_news_df.drop(columns='__index_level_0__', errors='ignore')\n",
    "test_news_df_test = test_news_df_test.drop(columns='__index_level_0__', errors='ignore')\n",
    "valid_news_df_3000 = valid_news_df_3000.drop(columns='__index_level_0__', errors='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30000"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the pooled training frame; the next cell draws sft_count and reward_count rows from it\n",
    "# and uses whatever remains as the PPO subset.\n",
    "train_news_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Carve the pooled training rows into three disjoint subsets: SFT first, then reward,\n",
    "# with every row that is left over going to PPO.\n",
    "sft_sample = train_news_df.sample(n=sft_count, random_state=42)\n",
    "remaining = train_news_df.drop(index=sft_sample.index)\n",
    "\n",
    "reward_sample = remaining.sample(n=reward_count, random_state=42)\n",
    "train_news_df = remaining.drop(index=reward_sample.index)\n",
    "\n",
    "ppo_sample = train_news_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NewsId</th>\n",
       "      <th>Category</th>\n",
       "      <th>SubCat</th>\n",
       "      <th>Title</th>\n",
       "      <th>Abstract</th>\n",
       "      <th>url</th>\n",
       "      <th>TitleEnt</th>\n",
       "      <th>AbstractEnt</th>\n",
       "      <th>Abstract_len</th>\n",
       "      <th>sentiment_abstract</th>\n",
       "      <th>sentiment_title</th>\n",
       "      <th>abs_diff</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>N75778</td>\n",
       "      <td>sports</td>\n",
       "      <td>football_nfl</td>\n",
       "      <td>John Dorsey admits talks with Washington, but it \"takes two to tango\"</td>\n",
       "      <td>Team officials in Washington \"emphatically\" denied a rumor of a Trent Williams trade to Cleveland, according to a report Tuesday. A day later, Browns General Manager John Dorsey admitted publicly he has talked to Washington president Bruce Allen. \"We've had a few conversations,\" Dorsey said, via Mary Kay Cabot of the Cleveland Plain Dealer. \"It [more]</td>\n",
       "      <td>https://assets.msn.com/labs/mind/AAISxPW.html</td>\n",
       "      <td>[{\"Label\": \"John Dorsey (American football)\", \"Type\": \"P\", \"WikidataId\": \"Q14950911\", \"Confidence\": 0.995, \"OccurrenceOffsets\": [0], \"SurfaceForms\": [\"John Dorsey\"]}]</td>\n",
       "      <td>[{\"Label\": \"John Dorsey (American football)\", \"Type\": \"P\", \"WikidataId\": \"Q14950911\", \"Confidence\": 0.995, \"OccurrenceOffsets\": [166, 280], \"SurfaceForms\": [\"John Dorsey\", \"Dorsey\"]}, {\"Label\": \"Cleveland Browns\", \"Type\": \"O\", \"WikidataId\": \"Q223527\", \"Confidence\": 1.0, \"OccurrenceOffsets\": [143], \"SurfaceForms\": [\"Browns\"]}, {\"Label\": \"Cleveland\", \"Type\": \"G\", \"WikidataId\": \"Q37320\", \"Confidence\": 0.967, \"OccurrenceOffsets\": [88], \"SurfaceForms\": [\"Cleveland\"]}, {\"Label\": \"The Plain Dealer\", \"Type\": \"M\", \"WikidataId\": \"Q286036\", \"Confidence\": 1.0, \"OccurrenceOffsets\": [319], \"SurfaceForms\": [\"Cleveland Plain Dealer\"]}]</td>\n",
       "      <td>56</td>\n",
       "      <td>0.003814</td>\n",
       "      <td>0.021764</td>\n",
       "      <td>0.01795</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   NewsId Category        SubCat  \\\n",
       "1  N75778   sports  football_nfl   \n",
       "\n",
       "                                                                   Title  \\\n",
       "1  John Dorsey admits talks with Washington, but it \"takes two to tango\"   \n",
       "\n",
       "                                                                                                                                                                                                                                                                                                                                                            Abstract  \\\n",
       "1  Team officials in Washington \"emphatically\" denied a rumor of a Trent Williams trade to Cleveland, according to a report Tuesday. A day later, Browns General Manager John Dorsey admitted publicly he has talked to Washington president Bruce Allen. \"We've had a few conversations,\" Dorsey said, via Mary Kay Cabot of the Cleveland Plain Dealer. \"It [more]   \n",
       "\n",
       "                                             url  \\\n",
       "1  https://assets.msn.com/labs/mind/AAISxPW.html   \n",
       "\n",
       "                                                                                                                                                                 TitleEnt  \\\n",
       "1  [{\"Label\": \"John Dorsey (American football)\", \"Type\": \"P\", \"WikidataId\": \"Q14950911\", \"Confidence\": 0.995, \"OccurrenceOffsets\": [0], \"SurfaceForms\": [\"John Dorsey\"]}]   \n",
       "\n",
       "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           AbstractEnt  \\\n",
       "1  [{\"Label\": \"John Dorsey (American football)\", \"Type\": \"P\", \"WikidataId\": \"Q14950911\", \"Confidence\": 0.995, \"OccurrenceOffsets\": [166, 280], \"SurfaceForms\": [\"John Dorsey\", \"Dorsey\"]}, {\"Label\": \"Cleveland Browns\", \"Type\": \"O\", \"WikidataId\": \"Q223527\", \"Confidence\": 1.0, \"OccurrenceOffsets\": [143], \"SurfaceForms\": [\"Browns\"]}, {\"Label\": \"Cleveland\", \"Type\": \"G\", \"WikidataId\": \"Q37320\", \"Confidence\": 0.967, \"OccurrenceOffsets\": [88], \"SurfaceForms\": [\"Cleveland\"]}, {\"Label\": \"The Plain Dealer\", \"Type\": \"M\", \"WikidataId\": \"Q286036\", \"Confidence\": 1.0, \"OccurrenceOffsets\": [319], \"SurfaceForms\": [\"Cleveland Plain Dealer\"]}]   \n",
       "\n",
       "   Abstract_len  sentiment_abstract  sentiment_title  abs_diff  \n",
       "1            56            0.003814         0.021764   0.01795  "
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Lift the column-width cap so long text columns are shown in full.\n",
    "pd.options.display.max_colwidth = None\n",
    "\n",
    "train_news_df.iloc[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the prepared frames into one DatasetDict, keyed by split name\n",
    "# (sft / reward / ppo / test / reward_valid).\n",
    "split_frames = {\n",
    "    \"sft\": sft_sample,\n",
    "    \"reward\": reward_sample,\n",
    "    \"ppo\": ppo_sample,\n",
    "    \"test\": test_news_df_test,\n",
    "    \"reward_valid\": valid_news_df_3000,\n",
    "}\n",
    "data = DatasetDict(\n",
    "    {split: Dataset.from_pandas(frame) for split, frame in split_frames.items()}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/mnt/sdb1/finetuning/data/mind-sport-abs'"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the dataset config inside a context manager so the file handle is\n",
    "# closed as soon as parsing finishes; JSON is read explicitly as UTF-8\n",
    "# rather than relying on the locale's default encoding.\n",
    "with open(\"../configs/mind-sport-abs-entangled/data.json\", encoding=\"utf-8\") as config_file:\n",
    "    config = json.load(config_file)\n",
    "\n",
    "# Target location built from the config: <base_dir>/data/<data>.\n",
    "output_path = f\"{config['base_dir']}/data/{config['data']}\"\n",
    "output_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc6c09c6f259459caea8b43750a44d84",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6068a95444a4c4ba253a33c2c1f38db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60cd096ed71d4084b0ab2f09c9ef5eb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9c1876f6af2491c8d1c478e463a40b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/9969 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cd3df37a5a541eaa10bd7011bd5de7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/3000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Write every split held in `data` to `output_path`.\n",
    "data.save_to_disk(output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-up",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
