{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "debe8409",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.utils import JIGSAW_DATASET_DIR_RAW, JIGSAW_DATASET_DIR_PROCESSED\n",
    "path_to_save_raw_dataset = JIGSAW_DATASET_DIR_RAW\n",
    "path_to_save_processed_dataset = JIGSAW_DATASET_DIR_PROCESSED\n",
    "\n",
    "# TODO(review): pin the kaggle version (kaggle==X.Y.Z) so this download step is reproducible.\n",
    "%pip install kaggle\n",
    "\n",
    "# Download the competition archive; requires an API token at ~/.kaggle/kaggle.json.\n",
    "# Paths are quoted so a data directory containing spaces does not break the shell commands.\n",
    "!kaggle competitions download -c jigsaw-unintended-bias-in-toxicity-classification \\\n",
    "    --path \"{path_to_save_raw_dataset}\"\n",
    "\n",
    "# -o overwrites already-extracted files without asking, so re-running this cell cannot hang\n",
    "# on unzip's interactive overwrite prompt; -q keeps the notebook output short.\n",
    "!unzip -o -q -d \"{path_to_save_raw_dataset}\" \"{path_to_save_raw_dataset}/jigsaw-unintended-bias-in-toxicity-classification.zip\"\n",
    "# The archive is no longer needed once extracted.\n",
    "!rm -f \"{path_to_save_raw_dataset}/jigsaw-unintended-bias-in-toxicity-classification.zip\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179cc049",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a Hugging Face DatasetDict from the Kaggle CSV files extracted into the raw directory.\n",
    "from datasets import load_dataset\n",
    "\n",
    "dataset_name = \"google/jigsaw_unintended_bias\"\n",
    "dataset = load_dataset(path=dataset_name, data_dir=path_to_save_raw_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db34159c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the leaderboard splits: test_public_leaderboard -> val, test_private_leaderboard -> test.\n",
    "# Each rename is skipped once applied, so re-running this cell does not raise KeyError.\n",
    "split_renames = {\"test_public_leaderboard\": \"val\", \"test_private_leaderboard\": \"test\"}\n",
    "for old_split, new_split in split_renames.items():\n",
    "    if old_split in dataset:\n",
    "        dataset[new_split] = dataset.pop(old_split)\n",
    "\n",
    "assert {\"train\", \"val\", \"test\"} <= set(dataset), f\"unexpected splits: {list(dataset)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cba292a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filter:   0%|          | 0/1804874 [00:00<?, ? examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32383.08 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32408.55 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32470.92 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:02<00:00, 32585.52 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32112.60 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32608.48 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 31835.57 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:02<00:00, 32522.64 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:54<00:00, 32951.93 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32002.95 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32293.38 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32698.91 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32365.61 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32368.38 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32437.67 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 31679.37 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 31399.31 examples/s]\n",
      "Filter: 100%|██████████| 1804874/1804874 [00:55<00:00, 32790.53 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:03<00:00, 32205.39 examples/s]\n",
      "Filter: 100%|██████████| 97320/97320 [00:02<00:00, 32610.12 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Extract black/white, female/male, christian/muslim/jewish subsets from every split.\n",
    "# A comment belongs to a subgroup when its identity score is strictly greater than the threshold.\n",
    "# NOTE(review): confirm `>` (not `>=`) is the intended membership rule for scores of exactly 0.5.\n",
    "IDENTITY_THRESHOLD = 0.5\n",
    "\n",
    "def filter_by_identity(split_dataset, identity_column, threshold=IDENTITY_THRESHOLD):\n",
    "    \"\"\"Return the rows of `split_dataset` whose `identity_column` score is > `threshold`.\n",
    "\n",
    "    The filter is batched and reads only `identity_column` instead of decoding every\n",
    "    column of every row one example at a time, which should make each of the 21\n",
    "    passes below much cheaper on the ~1.8M-row train split.\n",
    "    \"\"\"\n",
    "    return split_dataset.filter(\n",
    "        lambda scores: [score > threshold for score in scores],\n",
    "        input_columns=identity_column,\n",
    "        batched=True,\n",
    "    )\n",
    "\n",
    "train_dataset_black = filter_by_identity(dataset[\"train\"], \"black\")\n",
    "val_dataset_black = filter_by_identity(dataset[\"val\"], \"black\")\n",
    "test_dataset_black = filter_by_identity(dataset[\"test\"], \"black\")\n",
    "\n",
    "train_dataset_white = filter_by_identity(dataset[\"train\"], \"white\")\n",
    "val_dataset_white = filter_by_identity(dataset[\"val\"], \"white\")\n",
    "test_dataset_white = filter_by_identity(dataset[\"test\"], \"white\")\n",
    "\n",
    "train_dataset_male = filter_by_identity(dataset[\"train\"], \"male\")\n",
    "val_dataset_male = filter_by_identity(dataset[\"val\"], \"male\")\n",
    "test_dataset_male = filter_by_identity(dataset[\"test\"], \"male\")\n",
    "\n",
    "train_dataset_female = filter_by_identity(dataset[\"train\"], \"female\")\n",
    "val_dataset_female = filter_by_identity(dataset[\"val\"], \"female\")\n",
    "test_dataset_female = filter_by_identity(dataset[\"test\"], \"female\")\n",
    "\n",
    "train_dataset_christian = filter_by_identity(dataset[\"train\"], \"christian\")\n",
    "val_dataset_christian = filter_by_identity(dataset[\"val\"], \"christian\")\n",
    "test_dataset_christian = filter_by_identity(dataset[\"test\"], \"christian\")\n",
    "\n",
    "train_dataset_muslim = filter_by_identity(dataset[\"train\"], \"muslim\")\n",
    "val_dataset_muslim = filter_by_identity(dataset[\"val\"], \"muslim\")\n",
    "test_dataset_muslim = filter_by_identity(dataset[\"test\"], \"muslim\")\n",
    "\n",
    "train_dataset_jewish = filter_by_identity(dataset[\"train\"], \"jewish\")\n",
    "val_dataset_jewish = filter_by_identity(dataset[\"val\"], \"jewish\")\n",
    "test_dataset_jewish = filter_by_identity(dataset[\"test\"], \"jewish\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e9196c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the subsets: subset name -> DatasetDict holding that subset's train/val/test splits.\n",
    "# The outer container is a plain dict because DatasetDict expects Dataset values, not nested\n",
    "# dicts; making each subset a real DatasetDict lets it be mapped or saved directly.\n",
    "from datasets import DatasetDict\n",
    "\n",
    "new_dataset = {\n",
    "    \"black\": DatasetDict({\"train\": train_dataset_black, \"val\": val_dataset_black, \"test\": test_dataset_black}),\n",
    "    \"white\": DatasetDict({\"train\": train_dataset_white, \"val\": val_dataset_white, \"test\": test_dataset_white}),\n",
    "    \"female\": DatasetDict({\"train\": train_dataset_female, \"val\": val_dataset_female, \"test\": test_dataset_female}),\n",
    "    \"male\": DatasetDict({\"train\": train_dataset_male, \"val\": val_dataset_male, \"test\": test_dataset_male}),\n",
    "    \"christian\": DatasetDict({\"train\": train_dataset_christian, \"val\": val_dataset_christian, \"test\": test_dataset_christian}),\n",
    "    \"muslim\": DatasetDict({\"train\": train_dataset_muslim, \"val\": val_dataset_muslim, \"test\": test_dataset_muslim}),\n",
    "    \"jewish\": DatasetDict({\"train\": train_dataset_jewish, \"val\": val_dataset_jewish, \"test\": test_dataset_jewish}),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "ffe70071",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:   0%|          | 0/13869 [00:00<?, ? examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 13869/13869 [00:06<00:00, 2165.21 examples/s]\n",
      "Map: 100%|██████████| 699/699 [00:00<00:00, 1671.28 examples/s]\n",
      "Map: 100%|██████████| 696/696 [00:00<00:00, 2121.86 examples/s]\n",
      "Map: 100%|██████████| 23852/23852 [00:10<00:00, 2190.29 examples/s]\n",
      "Map: 100%|██████████| 1204/1204 [00:00<00:00, 2238.36 examples/s]\n",
      "Map: 100%|██████████| 1130/1130 [00:00<00:00, 1971.14 examples/s]\n",
      "Map: 100%|██████████| 50548/50548 [00:22<00:00, 2224.31 examples/s]\n",
      "Map: 100%|██████████| 2392/2392 [00:01<00:00, 2164.83 examples/s]\n",
      "Map: 100%|██████████| 2442/2442 [00:01<00:00, 2018.43 examples/s]\n",
      "Map: 100%|██████████| 40036/40036 [00:17<00:00, 2234.55 examples/s]\n",
      "Map: 100%|██████████| 2053/2053 [00:00<00:00, 2176.74 examples/s]\n",
      "Map: 100%|██████████| 1917/1917 [00:00<00:00, 1973.26 examples/s]\n",
      "Map: 100%|██████████| 35507/35507 [00:15<00:00, 2238.82 examples/s]\n",
      "Map: 100%|██████████| 1828/1828 [00:00<00:00, 2133.99 examples/s]\n",
      "Map: 100%|██████████| 1876/1876 [00:00<00:00, 2112.98 examples/s]\n",
      "Map: 100%|██████████| 19666/19666 [00:08<00:00, 2202.27 examples/s]\n",
      "Map: 100%|██████████| 914/914 [00:00<00:00, 2159.16 examples/s]\n",
      "Map: 100%|██████████| 977/977 [00:00<00:00, 1884.72 examples/s]\n",
      "Map: 100%|██████████| 7239/7239 [00:03<00:00, 2158.50 examples/s]\n",
      "Map: 100%|██████████| 405/405 [00:00<00:00, 1926.25 examples/s]\n",
      "Map: 100%|██████████| 387/387 [00:00<00:00, 1868.28 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Collapse the continuous toxicity score into a single binary label:\n",
    "# label = 1 when target > 0.5, otherwise 0.\n",
    "def convert_to_binary(example):\n",
    "    \"\"\"Add an integer `label` (1 if `target` > 0.5, else 0) to `example` and return it.\"\"\"\n",
    "    is_toxic = example[\"target\"] > 0.5\n",
    "    example[\"label\"] = int(is_toxic)\n",
    "    return example\n",
    "\n",
    "for subset_name, splits in new_dataset.items():\n",
    "    for split_name in splits:\n",
    "        splits[split_name] = splits[split_name].map(convert_to_binary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ac1dc23d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename comment_text -> text and keep only the text and label columns.\n",
    "# The rename is skipped when already applied, so re-running this cell does not fail.\n",
    "for subset_name, splits in new_dataset.items():\n",
    "    for split_name in splits:\n",
    "        split_dataset = splits[split_name]\n",
    "        if \"comment_text\" in split_dataset.column_names:\n",
    "            split_dataset = split_dataset.rename_column(\"comment_text\", \"text\")\n",
    "        splits[split_name] = split_dataset.select_columns([\"text\", \"label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b02f2f23",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (0/1 shards):   0%|          | 0/13869 [00:00<?, ? examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 13869/13869 [00:00<00:00, 50079.38 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 699/699 [00:00<00:00, 22781.49 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 696/696 [00:00<00:00, 23947.59 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 23852/23852 [00:00<00:00, 74927.90 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1204/1204 [00:00<00:00, 40859.45 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1130/1130 [00:00<00:00, 33236.77 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 50548/50548 [00:00<00:00, 75823.16 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 2392/2392 [00:00<00:00, 49096.04 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 2442/2442 [00:00<00:00, 49397.35 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 40036/40036 [00:00<00:00, 80586.26 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 2053/2053 [00:00<00:00, 53107.52 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1917/1917 [00:00<00:00, 53203.82 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 35507/35507 [00:00<00:00, 83203.80 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1828/1828 [00:00<00:00, 49761.73 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 1876/1876 [00:00<00:00, 52672.37 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 19666/19666 [00:00<00:00, 88422.39 examples/s] \n",
      "Saving the dataset (1/1 shards): 100%|██████████| 914/914 [00:00<00:00, 32185.86 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 977/977 [00:00<00:00, 34099.47 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 7239/7239 [00:00<00:00, 77203.24 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 405/405 [00:00<00:00, 16567.28 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 387/387 [00:00<00:00, 15987.04 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Save each identity subset as its own DatasetDict under <processed dir>/<subset name>.\n",
    "import os\n",
    "\n",
    "for subset_name in new_dataset:\n",
    "    output_dir = os.path.join(path_to_save_processed_dataset, subset_name)\n",
    "    dataset_dict = DatasetDict(new_dataset[subset_name])\n",
    "    dataset_dict.save_to_disk(output_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bcos",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
