{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('../..')\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_from_disk, Dataset\n",
    "import torch\n",
    "from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from datasets import Dataset\n",
    "\n",
    "def _get_causal_size(dataset):\n",
    "    \"\"\"Return the number of examples in `dataset` whose 'attribute_type' is 'causal'.\"\"\"\n",
    "    return sum(1 for example in dataset if example[\"attribute_type\"] == \"causal\")\n",
    "\n",
    "def _split_by_attribute_type(dataset):\n",
    "    \"\"\"Partition `dataset` into causal and non-causal examples.\n",
    "\n",
    "    Returns a `(causal, isolate)` pair of lists. Every example whose\n",
    "    'attribute_type' is not 'causal' goes into the second list, and each\n",
    "    list keeps the order in which the examples were encountered.\n",
    "    \"\"\"\n",
    "    causal_examples, isolate_examples = [], []\n",
    "    for example in dataset:\n",
    "        bucket = causal_examples if example[\"attribute_type\"] == \"causal\" else isolate_examples\n",
    "        bucket.append(example)\n",
    "    return causal_examples, isolate_examples\n",
    "\n",
    "\n",
    "def _enlarge_dataset(dataset, target_causal_size, target_isolate_size, target_attribute, all_attributes=None):\n",
    "    \"\"\"Resample one dataset to the requested causal / isolate sizes.\n",
    "\n",
    "    Causal examples are padded with random duplicates until there are at least\n",
    "    `target_causal_size` of them (they are never dropped). For every attribute\n",
    "    other than `target_attribute`, the isolate examples whose 'attribute' field\n",
    "    equals it are padded with random duplicates, or randomly subsampled without\n",
    "    replacement, to exactly `target_isolate_size`.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of examples with 'attribute_type' and 'attribute' fields.\n",
    "        target_causal_size: minimum number of causal examples in the result.\n",
    "        target_isolate_size: number of isolate examples kept per other attribute.\n",
    "        target_attribute: attribute that is excluded from the isolate groups.\n",
    "        all_attributes: attribute names to iterate over; defaults to the six\n",
    "            city attributes used throughout this notebook.\n",
    "\n",
    "    Returns:\n",
    "        A `Dataset` with the causal examples followed by the isolate examples.\n",
    "\n",
    "    Note:\n",
    "        A group that has no examples cannot be padded and is left empty.\n",
    "        Sampling uses the global `random` state; seed it for reproducible output.\n",
    "    \"\"\"\n",
    "    if all_attributes is None:\n",
    "        all_attributes = [\"Country\", \"Continent\", \"Language\", \"Latitude\", \"Longitude\", \"Timezone\"]\n",
    "    other_attributes = [attribute for attribute in all_attributes if attribute != target_attribute]\n",
    "\n",
    "    def _top_up(examples, target_size):\n",
    "        \"\"\"Pad `examples` with random duplicates of itself up to `target_size` items.\"\"\"\n",
    "        shortfall = target_size - len(examples)\n",
    "        # Nothing to add, or nothing to draw from (random.choice([]) would raise IndexError).\n",
    "        if shortfall <= 0 or not examples:\n",
    "            return examples\n",
    "        return examples + [random.choice(examples) for _ in range(shortfall)]\n",
    "\n",
    "    causal_dataset, isolate_dataset = _split_by_attribute_type(dataset)\n",
    "\n",
    "    # The shortfall used to be computed and then ignored, so `target_causal_size`\n",
    "    # had no effect; actually enlarge the causal set here.\n",
    "    causal_dataset = _top_up(causal_dataset, target_causal_size)\n",
    "\n",
    "    new_isolate_dataset = []\n",
    "    for attribute in other_attributes:\n",
    "        attribute_isolate_dataset = [d for d in isolate_dataset if d[\"attribute\"] == attribute]\n",
    "\n",
    "        if len(attribute_isolate_dataset) < target_isolate_size:\n",
    "            # Too few examples: duplicate random ones until the quota is met.\n",
    "            attribute_isolate_dataset = _top_up(attribute_isolate_dataset, target_isolate_size)\n",
    "        else:\n",
    "            # Enough examples: draw exactly the quota without replacement.\n",
    "            attribute_isolate_dataset = random.sample(attribute_isolate_dataset, target_isolate_size)\n",
    "\n",
    "        new_isolate_dataset.extend(attribute_isolate_dataset)\n",
    "\n",
    "    return Dataset.from_list(causal_dataset + new_isolate_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build one merged dataset per split from the six per-attribute datasets.\n",
    "# Seed the global RNG first: _enlarge_dataset resamples with `random`, so without\n",
    "# a fixed seed every run would save a different dataset.\n",
    "random.seed(42)\n",
    "\n",
    "ATTRIBUTES = [\"Country\", \"Continent\", \"Language\", \"Latitude\", \"Longitude\", \"Timezone\"]\n",
    "\n",
    "for split in ['train', 'test']:\n",
    "\n",
    "    # Load every per-attribute dataset for this split.\n",
    "    data_dict = {\n",
    "        attribute: load_from_disk(f'./data/city_{attribute.lower()}_{split}')\n",
    "        for attribute in ATTRIBUTES\n",
    "    }\n",
    "\n",
    "    # Causal target: the largest causal count among the attributes.\n",
    "    target_dataset_size = max(_get_causal_size(dataset) for dataset in data_dict.values())\n",
    "    # Isolate target per *other* attribute (there are len(ATTRIBUTES) - 1 == 5 of them).\n",
    "    target_isolate_dataset_size = target_dataset_size // (len(ATTRIBUTES) - 1)\n",
    "\n",
    "    merged_dataset = []\n",
    "    for attribute in ATTRIBUTES:\n",
    "        dataset = data_dict[attribute]\n",
    "        enlarged_dataset = _enlarge_dataset(dataset, target_dataset_size, target_isolate_dataset_size, attribute)\n",
    "        merged_dataset.extend(enlarged_dataset)\n",
    "\n",
    "    merged_dataset = Dataset.from_list(merged_dataset)\n",
    "    merged_dataset.save_to_disk(f'./data/city_{split}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Country_causal': 3929,\n",
       " 'Continent_causal': 4161,\n",
       " 'Language_causal': 3151,\n",
       " 'Latitude_causal': 1887,\n",
       " 'Longitude_causal': 1737,\n",
       " 'Timezone_causal': 2636,\n",
       " 'Country_isolate': 3905,\n",
       " 'Continent_isolate': 4178,\n",
       " 'Language_isolate': 3196,\n",
       " 'Latitude_isolate': 1937,\n",
       " 'Longitude_isolate': 1662,\n",
       " 'Timezone_isolate': 2685}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the examples under each '<attribute>_<attribute_type>' label.\n",
    "# NOTE(review): `city_train` is not assigned in any visible cell; it has to\n",
    "# exist in the kernel already -- confirm where it is loaded.\n",
    "stat_dict = {}\n",
    "for example in city_train:\n",
    "    label = \"_\".join((example[\"attribute\"], example[\"attribute_type\"]))\n",
    "    stat_dict[label] = stat_dict.get(label, 0) + 1\n",
    "stat_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Country 6852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 6852/6852 [00:00<00:00, 731188.68 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Continent 7120\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 7120/7120 [00:00<00:00, 756285.47 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Language 6078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 6078/6078 [00:00<00:00, 733526.49 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Longitude 3474\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 3474/3474 [00:00<00:00, 582467.70 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Latitude 3774\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 3774/3774 [00:00<00:00, 613586.45 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Timezone 5272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 5272/5272 [00:00<00:00, 704260.48 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "# Seed the global RNG so the isolate subsample drawn below is the same on every run.\n",
    "random.seed(42)\n",
    "\n",
    "# NOTE(review): `city_train` is not assigned in any visible cell -- confirm where it is loaded.\n",
    "for attribute in [\"Country\", \"Continent\", \"Language\", \"Longitude\", \"Latitude\", \"Timezone\"]:\n",
    "\n",
    "    attribute_dataset = []  # causal examples of this attribute; isolate ones are appended below\n",
    "    isolate_data = []       # isolate examples whose edit_instruction contains the attribute name\n",
    "\n",
    "    for data in city_train:\n",
    "        if data[\"attribute\"] == attribute and data[\"attribute_type\"] == \"causal\":\n",
    "            attribute_dataset.append(data)\n",
    "        elif attribute in data[\"edit_instruction\"] and data[\"attribute_type\"] == \"isolate\":\n",
    "            isolate_data.append(data)\n",
    "\n",
    "    # Cap the isolate examples at the number of causal ones; if there are\n",
    "    # fewer isolate examples than causal ones, keep them all.\n",
    "    num_causal = len(attribute_dataset)\n",
    "    if len(isolate_data) >= num_causal:\n",
    "        isolate_data = random.sample(isolate_data, num_causal)\n",
    "    attribute_dataset.extend(isolate_data)\n",
    "\n",
    "    print(attribute, len(attribute_dataset))\n",
    "    attribute_dataset = Dataset.from_list(attribute_dataset)\n",
    "    attribute_dataset.save_to_disk(f\"./data/ravel_city_{attribute}_train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Country 685\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 685/685 [00:00<00:00, 208846.28 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Continent 703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 703/703 [00:00<00:00, 203099.31 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Language 621\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 621/621 [00:00<00:00, 188143.80 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Longitude 324\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 324/324 [00:00<00:00, 99462.38 examples/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Latitude"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 370\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 370/370 [00:00<00:00, 113583.58 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Timezone 571\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 571/571 [00:00<00:00, 166825.55 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Seed the global RNG so the isolate subsample drawn below is the same on every run.\n",
    "random.seed(42)\n",
    "\n",
    "# NOTE(review): `city_test` is not assigned in any visible cell -- confirm where it is loaded.\n",
    "for attribute in [\"Country\", \"Continent\", \"Language\", \"Longitude\", \"Latitude\", \"Timezone\"]:\n",
    "\n",
    "    attribute_dataset = []  # causal examples of this attribute; isolate ones are appended below\n",
    "    isolate_data = []       # isolate examples whose edit_instruction contains the attribute name\n",
    "\n",
    "    for data in city_test:\n",
    "        if data[\"attribute\"] == attribute and data[\"attribute_type\"] == \"causal\":\n",
    "            attribute_dataset.append(data)\n",
    "        elif attribute in data[\"edit_instruction\"] and data[\"attribute_type\"] == \"isolate\":\n",
    "            isolate_data.append(data)\n",
    "\n",
    "    # Cap the isolate examples at the number of causal ones; if there are\n",
    "    # fewer isolate examples than causal ones, keep them all.\n",
    "    num_causal = len(attribute_dataset)\n",
    "    if len(isolate_data) >= num_causal:\n",
    "        isolate_data = random.sample(isolate_data, num_causal)\n",
    "    attribute_dataset.extend(isolate_data)\n",
    "\n",
    "    print(attribute, len(attribute_dataset))\n",
    "    attribute_dataset = Dataset.from_list(attribute_dataset)\n",
    "    attribute_dataset.save_to_disk(f\"./data/ravel_city_{attribute}_test\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hypernet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
