{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b6af0ece-a4cc-4e4b-ab66-23598a0ed9c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/llm-risk-control/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "import json\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4bd3cbc0-3952-4508-a7a2-fc058fba9781",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_balanced_dataset(df):\n",
    "    # Balance the dataset\n",
    "    balanced = []\n",
    "    n = df['label'].value_counts().min()\n",
    "    for label, group in df.groupby('label'):\n",
    "        group = group.sample(frac=1)[:n]\n",
    "        balanced.append(group)\n",
    "    \n",
    "    return pd.concat(balanced).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b628c0ca-c1ea-42a4-ad66-a969ed57631b",
   "metadata": {},
   "outputs": [],
   "source": [
    "unprocessed_data_dir = \"./datasets/icl/\"\n",
    "data_output_dir = \"./processed_data/icl/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f6284e8-6ad6-4be6-838b-43d990b97294",
   "metadata": {},
   "source": [
    "# Text-based Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "654ae91b-1c8f-4935-88eb-9dd3c9341357",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AG News (AG_News)\n",
    "# text = \"<Title>: <Description>\"; label = topic name for the 1-based 'Class Index'.\n",
    "df = pd.read_csv(unprocessed_data_dir + 'ag_news/test.csv')\n",
    "df = pd.DataFrame({\n",
    "    'label': df['Class Index'].replace({1: 'world', 2: 'sports', 3: 'business', 4: 'science/technology'}),\n",
    "    'text': [title + \": \" + desc for title, desc in zip(df['Title'], df['Description'])],\n",
    "})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'ag_news.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "bbab5a7b-68ca-40d0-913d-1151106d541d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text Retrieval Conference (TREC)\n",
    "# https://huggingface.co/datasets/CogComp/trec\n",
    "# Coarse question-type ids are replaced by their names to form the label.\n",
    "df = pd.read_csv(unprocessed_data_dir + 'trec/train.csv')\n",
    "df = pd.DataFrame({\n",
    "    'label': df['label-coarse'].replace({\n",
    "        0: 'abbreviation', 1: 'entity', 2: 'description',\n",
    "        3: 'human', 4: 'location', 5: 'numeric',\n",
    "    }),\n",
    "    'text': df['text'],\n",
    "})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'trec.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "f7d85c2c-cf67-4c9a-b4d2-e8afd79cc54b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spam detection (SpamAssassin)\n",
    "# no = not spam, yes = spam\n",
    "df = pd.read_csv(unprocessed_data_dir + 'spam/spam_assassin.csv')\n",
    "df = pd.DataFrame({\n",
    "    'label': df['target'].replace({0: 'no', 1: 'yes'}),\n",
    "    'text': df['text'],\n",
    "})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'spam.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04f9245-5b05-4c46-b8e2-603fde8930af",
   "metadata": {},
   "source": [
    "# Paraphrase & Entailment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "90f6eb3e-56f1-406e-9147-952f86ce2a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Medical Question Pairs (MQP)\n",
    "# The Medical Question Pairs dataset (McCreery et al., 2020) pairs medical questions with\n",
    "# versions of each question rewritten by medical professionals. Each rewrite is labelled\n",
    "# similar (syntactically different but contextually the same) or dissimilar (may look\n",
    "# syntactically similar but is contextually different).\n",
    "# Labels 1: similar, 0: dissimilar. They are matched as the strings '0'/'1' below;\n",
    "# NOTE(review): confirm the mqp_source label dtype is string.\n",
    "ds = load_dataset(\"bigbio/mqp\", \"mqp_source\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)\n",
    "df = pd.DataFrame({\n",
    "    'text': ['Text 1: ' + q1 + ' Text 2: ' + q2 for q1, q2 in zip(df['text_1'], df['text_2'])],\n",
    "    'label': df['label'].replace({'0': 'dissimilar', '1': 'similar'}),\n",
    "})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'mqp.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6b8763a7-601d-41c6-bc03-e82c58eb47ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Microsoft Research Paraphrase Corpus (MRP)\n",
    "# labels: 1 - equivalent, 0 - different\n",
    "ds = load_dataset('glue', 'mrpc', split='train')\n",
    "df = pd.DataFrame(data=ds, columns=ds.features)\n",
    "df = pd.DataFrame({\n",
    "    'label': df['label'].replace({0: 'different', 1: 'equivalent'}),\n",
    "    'text': ['Sentence 1: ' + s1 + ' Sentence 2: ' + s2 for s1, s2 in zip(df['sentence1'], df['sentence2'])],\n",
    "})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'mrp.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "480157da-1923-4fd3-9545-e3e18d7d8707",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    }
   ],
   "source": [
    "# Winograd NLI (WNLI)\n",
    "# SetFit/wnli mirrors GLUE WNLI's label ids: 0 = not_entailment, 1 = entailment\n",
    "# (its label_text column can be used to double-check this mapping).\n",
    "ds = load_dataset(\"SetFit/wnli\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)\n",
    "df['text'] = ['Text 1: ' + t1 + ' Text 2: ' + t2 for t1,t2 in zip(df['text1'], df['text2'])]\n",
    "df['label'] = df['label'].replace({0: 'not entailment', 1: 'is entailment'})\n",
    "df = df[['text', 'label']]\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'wnli.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afccc590-97d5-4a48-924c-fd5a52de82f4",
   "metadata": {},
   "source": [
    "# Sentiment Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b10e19d-2b18-4223-b25a-5a1e498c00f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FinancialPhrasebank\n",
    "# The 'Sentiment' column becomes the label and the 'Text' column the text.\n",
    "df = pd.read_csv(unprocessed_data_dir + 'financial_phrasebank/financial_phrasebank_processed.csv')\n",
    "df = pd.DataFrame({'label': df['Sentiment'], 'text': df['Text']})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'financial_phrasebank.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "42ed1f24-b944-4ff7-b5bc-02cc2ab35f63",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    }
   ],
   "source": [
    "# SST2\n",
    "# SetFit/sst2 ships a label_text column; it is used directly as the label.\n",
    "ds = load_dataset(\"SetFit/sst2\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)\n",
    "df = df[['label_text', 'text']].rename(columns={'label_text': 'label'})\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'sst2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8dc8cec7-ec3a-48b0-9941-31aa8308a71b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TweetEval-Atheism\n",
    "# Stance ids -> labels: 0 = neither, 1 = no, 2 = yes\n",
    "ds = load_dataset(\"cardiffnlp/tweet_eval\", \"stance_atheism\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)[['text', 'label']]\n",
    "df = df.assign(label=df['label'].replace({0: 'neither', 1: 'no', 2: 'yes'}))\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'tweeteval_atheism.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "86bce967-91da-4936-aff0-172ee226f8d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TweetEval-Feminist\n",
    "# Stance ids -> labels: 0 = neither, 1 = no, 2 = yes\n",
    "ds = load_dataset(\"cardiffnlp/tweet_eval\", \"stance_feminist\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)[['text', 'label']]\n",
    "df = df.assign(label=df['label'].replace({0: 'neither', 1: 'no', 2: 'yes'}))\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'tweeteval_feminist.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "307bea5a-101c-4640-9c4a-fad940000264",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TweetEval-Hate\n",
    "# NOTE(review): ids are mapped to 'favor' (0) / 'against' (1), which reads like the stance\n",
    "# labels above rather than not-hate / hate; confirm this labelling is intentional.\n",
    "ds = load_dataset(\"cardiffnlp/tweet_eval\", \"hate\")\n",
    "df = pd.DataFrame(data=ds['train'], columns=ds['train'].features)[['text', 'label']]\n",
    "df = df.assign(label=df['label'].replace({0: 'favor', 1: 'against'}))\n",
    "df = get_balanced_dataset(df)\n",
    "df.to_csv(data_output_dir + 'tweeteval_hate.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "851a7861-21f7-4b0c-be63-fff19b882bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unnatural\n",
    "# The raw CSV already has 'label' and 'text' columns; keep only those and balance.\n",
    "df = get_balanced_dataset(pd.read_csv(unprocessed_data_dir + 'unnatural/unnatural.csv')[['label', 'text']])\n",
    "df.to_csv(data_output_dir + 'unnatural.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4538d5f5-e5b6-45fa-a66d-fe21367f7d11",
   "metadata": {},
   "source": [
    "# BigBench-Hard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a4a999a2-3916-4d91-a0b4-c6052ecfe1e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_balanced_bbh_dataset(data):\n",
    "    \"\"\"Convert a BIG-Bench-Hard task dict into a label-balanced DataFrame.\n",
    "\n",
    "    ``data['examples']`` is a list of dicts: each ``'input'`` becomes the\n",
    "    ``text`` column and each lower-cased ``'target'`` becomes the ``label``\n",
    "    column. The result is passed through ``get_balanced_dataset``.\n",
    "    \"\"\"\n",
    "    examples = data['examples']\n",
    "    frame = pd.DataFrame({\n",
    "        'text': [example['input'] for example in examples],\n",
    "        'label': [example['target'].lower() for example in examples],\n",
    "    })\n",
    "    return get_balanced_dataset(frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "443e326d-d061-49c8-8022-b7a0aee46604",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean\n",
    "# BIG-Bench-Hard task JSON -> label-balanced CSV\n",
    "with open(unprocessed_data_dir + 'boolean/boolean_expressions.json') as fh:\n",
    "    data = json.load(fh)\n",
    "\n",
    "df_balanced = get_balanced_bbh_dataset(data)\n",
    "df_balanced.to_csv(data_output_dir + 'boolean.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f018c70d-9cfc-486d-9b51-6c73b3d6aac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Navigation\n",
    "# BIG-Bench-Hard task JSON -> label-balanced CSV\n",
    "with open(unprocessed_data_dir + 'navigation/navigate.json') as fh:\n",
    "    data = json.load(fh)\n",
    "\n",
    "df_balanced = get_balanced_bbh_dataset(data)\n",
    "df_balanced.to_csv(data_output_dir + 'navigation.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6252097c-96a4-42ec-bcab-83eeec95cc18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sports Understanding\n",
    "# BIG-Bench-Hard task JSON -> label-balanced CSV\n",
    "with open(unprocessed_data_dir + 'sports/sports_understanding.json') as fh:\n",
    "    data = json.load(fh)\n",
    "\n",
    "df_balanced = get_balanced_bbh_dataset(data)\n",
    "df_balanced.to_csv(data_output_dir + 'sports.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c62c4b1e-09e8-4e6c-b0f1-4a6e7cfcd69d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Web of Lies\n",
    "# BIG-Bench-Hard task JSON -> label-balanced CSV\n",
    "with open(unprocessed_data_dir + 'web_of_lies/web_of_lies.json') as fh:\n",
    "    data = json.load(fh)\n",
    "\n",
    "df_balanced = get_balanced_bbh_dataset(data)\n",
    "df_balanced.to_csv(data_output_dir + 'web_of_lies.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "623f9822-d6d1-482f-84ef-825d39903b90",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-risk-control",
   "language": "python",
   "name": "llm-risk-control"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
