{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np  # NOTE(review): numpy is not used anywhere in this notebook\n",
    "import pandas as pd\n",
    "import random\n",
    "import json\n",
    "# Seed Python's `random` module: every random draw below (replacement answers in\n",
    "# get_data, the subsample and the shuffle) comes from it, so the generated dataset\n",
    "# is reproducible when the notebook is run top to bottom on the same CSVs.\n",
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the CosmosQA train and validation splits\n",
    "df_train, df_val = map(pd.read_csv, ('data/train.csv', 'data/valid.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Index(['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3',\n",
       "        'label'],\n",
       "       dtype='object'),\n",
       " Index(['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3',\n",
       "        'label'],\n",
       "       dtype='object'))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: both splits share the same columns (id, context, question, answer0-answer3, label)\n",
    "df_train.columns, df_val.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "id          3Q9SPIIRWJKVQ8244310E8TUS6YWAC##34V1S5K3GTZMDU...\n",
       "context     Good Old War and person L : I saw both of thes...\n",
       "question    In the future , will this person go to see oth...\n",
       "answer0                           None of the above choices .\n",
       "answer1     This person likes music and likes to see the s...\n",
       "answer2     This person only likes Good Old War and Person...\n",
       "answer3     Other Bands is not on tour and this person can...\n",
       "label                                                       1\n",
       "Name: 0, dtype: object"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect one raw example; `label` is the position of the correct option among answer0..answer3\n",
    "df_train.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool of replacement candidates: every answer option from both splits except the\n",
    "# \"none of the above choices\" option. get_data draws from this list by position, so\n",
    "# the order (train before val, answer0..answer3 within a row) must stay fixed for the\n",
    "# seeded output to be reproducible.\n",
    "all_choices = []\n",
    "for split_df in (df_train, df_val):\n",
    "    for row in split_df.itertuples(index=False):\n",
    "        for ans in (row.answer0, row.answer1, row.answer2, row.answer3):\n",
    "            if \"none of the above choices\" not in str(ans).lower():\n",
    "                all_choices.append(ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "91728"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# size of the replacement-candidate pool\n",
    "len(all_choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data_keep, new_data_to_sample = [], []\n",
    "\n",
    "\n",
    "def get_data(df):\n",
    "    \"\"\"Convert CosmosQA rows of `df` into 6-option multiple-choice examples.\n",
    "\n",
    "    Each example gets choices A-D (the row's answer0..answer3, with any\n",
    "    \"none of the above choices\" distractor replaced), E = \"I don't know\",\n",
    "    F = \"None of the above\", and `answer` = the letter of the labelled option.\n",
    "    Rows whose correct answer is \"none of the above choices\" are skipped.\n",
    "\n",
    "    Works through module-level state (kept so the cells below are unchanged):\n",
    "    - draws replacement answers from the global `all_choices` using the\n",
    "      global `random` RNG, so results depend on the seed and call order;\n",
    "    - appends untouched examples to `new_data_keep` and examples with a\n",
    "      replaced distractor to `new_data_to_sample`. Calling it twice on the\n",
    "      same frame therefore appends duplicates.\n",
    "\n",
    "    NOTE(review): a missing value in the CSV would come through as float NaN\n",
    "    and be written by json.dump as non-standard `NaN` -- confirm the CSVs have\n",
    "    no missing contexts/questions/answers.\n",
    "    \"\"\"\n",
    "    letters = [\"A\", \"B\", \"C\", \"D\", \"E\", \"F\"]\n",
    "    for row in df.itertuples(index=False):\n",
    "        chs = [row.answer0, row.answer1, row.answer2, row.answer3]\n",
    "\n",
    "        # ignore questions whose answer is \"none of the above choices\"\n",
    "        if \"none of the above choices\" in str(chs[row.label]).lower():\n",
    "            continue\n",
    "\n",
    "        # Since we always add our own \"None of the above\" option, replace the dataset's\n",
    "        # \"none of the above choices\" distractor with a random pooled answer that is\n",
    "        # not already one of this question's options.\n",
    "        replaced = False\n",
    "        for j in range(4):\n",
    "            if \"none of the above choices\" in str(chs[j]).lower():\n",
    "                while True:\n",
    "                    # same single randbelow draw as random.sample(all_choices, 1)[0],\n",
    "                    # so the generated dataset is unchanged\n",
    "                    rp = random.choice(all_choices)\n",
    "                    if rp not in chs:\n",
    "                        chs[j] = rp\n",
    "                        replaced = True\n",
    "                        break\n",
    "\n",
    "        chs.extend([\"I don't know\", \"None of the above\"])\n",
    "        ex = {\n",
    "            'source': 'CosmosQA',\n",
    "            'task': 'commonsense-based reading comprehension',\n",
    "            'context': row.context,\n",
    "            'question': row.question,\n",
    "            'choices': dict(zip(letters, chs)),\n",
    "            'answer': letters[row.label],\n",
    "        }\n",
    "\n",
    "        if replaced:\n",
    "            new_data_to_sample.append(ex)\n",
    "        else:\n",
    "            new_data_keep.append(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6189 17577\n",
      "6987 19505\n"
     ]
    }
   ],
   "source": [
    "# Reset the accumulators so re-running this cell does not double-count examples.\n",
    "# (Re-running still continues the RNG stream; Restart & Run All for the exact dataset.)\n",
    "new_data_keep.clear()\n",
    "new_data_to_sample.clear()\n",
    "get_data(df_train)\n",
    "print(len(new_data_keep), len(new_data_to_sample))\n",
    "get_data(df_val)\n",
    "# counts are cumulative over train + val\n",
    "print(len(new_data_keep), len(new_data_to_sample))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a 10k set: keep every example in new_data_keep and fill the remainder by\n",
    "# sampling (without replacement) from the examples whose distractor was replaced.\n",
    "num_sample = 10000 - len(new_data_keep)\n",
    "if not 0 <= num_sample <= len(new_data_to_sample):\n",
    "    raise ValueError(\n",
    "        f\"cannot build 10000 examples: {len(new_data_keep)} kept, \"\n",
    "        f\"{len(new_data_to_sample)} available to sample\"\n",
    "    )\n",
    "sampled_data = random.sample(new_data_to_sample, num_sample)  # already a new list\n",
    "sampled_data.extend(new_data_keep)\n",
    "random.shuffle(sampled_data)\n",
    "len(sampled_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# assign sequential ids in the final (shuffled) order; mutates the example dicts in place\n",
    "for new_id, example in enumerate(sampled_data):\n",
    "    example['id'] = new_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write the final dataset as a single JSON array of example dicts\n",
    "output_path = \"data/cosmosqa_10k.json\"\n",
    "with open(output_path, 'w') as out_file:\n",
    "    json.dump(sampled_data, out_file, indent=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
