{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk, Dataset, concatenate_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original nobel_prize_winner dataset size: 72965\n",
      "Upsampling nobel_prize_winner dataset\n",
      "Extra size: 11434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (0/1 shards):   0%|          | 0/84399 [00:00<?, ? examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 84399/84399 [00:01<00:00, 57485.97 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New nobel_prize_winner dataset size: 84399\n",
      "Original verb dataset size: 22931\n",
      "Upsampling verb dataset\n",
      "Extra size: 61468\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 84399/84399 [00:00<00:00, 95022.97 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New verb dataset size: 84399\n",
      "Original physical_object dataset size: 42226\n",
      "Upsampling physical_object dataset\n",
      "Extra size: 42173\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 84399/84399 [00:00<00:00, 92287.13 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New physical_object dataset size: 84399\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "target_dataset_size = len(load_from_disk(\"data/city_train\"))\n",
    "\n",
    "for domain in [\"nobel_prize_winner\", \"verb\", \"physical_object\"]:\n",
    "    original_dataset = load_from_disk(f\"data/{domain}_train\")\n",
    "    print(f\"Original {domain} dataset size: {len(original_dataset)}\")\n",
    "    \n",
    "    if len(original_dataset) < target_dataset_size:\n",
    "        print(f\"Upsampling {domain} dataset\")\n",
    "        extra_size = target_dataset_size - len(original_dataset)\n",
    "        print(f\"Extra size: {extra_size}\")\n",
    "        shuffled_dataset = original_dataset.shuffle(seed=42)\n",
    "        shuffled_dataset = shuffled_dataset.select([i % len(original_dataset) for i in range(extra_size)])\n",
    "        new_dataset = concatenate_datasets([shuffled_dataset, original_dataset])\n",
    "        new_dataset.save_to_disk(f\"data/{domain}_train_large\")\n",
    "        print(f\"New {domain} dataset size: {len(new_dataset)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 421092/421092 [00:00<00:00, 531942.64 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 46629/46629 [00:00<00:00, 502659.33 examples/s]\n"
     ]
    }
   ],
   "source": [
    "all_train_sets = []\n",
    "all_test_sets = []\n",
    "\n",
    "for domain in [\"nobel_prize_winner\", \"verb\", \"physical_object\", \"occupation\"]:\n",
    "    all_train_sets.append(load_from_disk(f\"data/{domain}_train_large\"))\n",
    "    all_test_sets.append(load_from_disk(f\"data/{domain}_test\"))\n",
    "    \n",
    "all_train_sets.append(load_from_disk(\"data/city_train\"))\n",
    "all_test_sets.append(load_from_disk(\"data/city_test\"))\n",
    "\n",
    "train_dataset = concatenate_datasets(all_train_sets)\n",
    "test_dataset = concatenate_datasets(all_test_sets)\n",
    "\n",
    "train_dataset.save_to_disk(\"data/train_large\")\n",
    "\n",
    "# Randomly downsample the test set to 2000 examples\n",
    "test_dataset = test_dataset.shuffle(seed=42)\n",
    "test_dataset = test_dataset.select(range(2000))\n",
    "test_dataset.save_to_disk(\"data/test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
