{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable autoreload (mode 2) so edited project modules are re-imported before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "# Add the directory two levels up to the import path (presumably the repository root,\n",
    "# which is what lets the `src.` package import in the next cell resolve).\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_from_disk\n",
    "import torch\n",
    "from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset\n",
    "\n",
    "from transformers import AutoTokenizer, LlamaForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s]\n"
     ]
    }
   ],
   "source": [
    "# Tokenizer: pad on the left and reuse the EOS token (and its id) as the padding token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/scr-ssd/sjd24/llama3-8b\")\n",
    "tokenizer.padding_side = \"left\"\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "# Model: load the weights from the same checkpoint directory in bfloat16, then move them to CUDA.\n",
    "model = LlamaForCausalLM.from_pretrained(\n",
    "    \"/scr-ssd/sjd24/llama3-8b\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    ").cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Country in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:53, 23.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.66765; filtered out 6647 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "835it [00:35, 23.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9979030929379166; filtered out 28 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "833it [00:35, 23.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9985741088180112; filtered out 19 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 13306/13306 [00:00<00:00, 128939.22 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Country in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:10, 23.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.663; filtered out 1348 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "166it [00:07, 23.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9939668174962293; filtered out 16 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "165it [00:06, 23.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9946889226100152; filtered out 14 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 2622/2622 [00:00<00:00, 112314.15 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Continent in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:53, 23.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.67855; filtered out 6429 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "849it [00:36, 23.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.998526269250608; filtered out 20 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "847it [00:35, 23.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9983765035790717; filtered out 22 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 13529/13529 [00:00<00:00, 128630.00 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Continent in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:10, 23.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6785; filtered out 1286 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "170it [00:07, 23.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.994473102431835; filtered out 15 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "169it [00:07, 23.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9948128936643201; filtered out 14 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 2685/2685 [00:00<00:00, 111550.64 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Language in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:54, 22.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6004; filtered out 7992 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "751it [00:31, 23.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9961692205196535; filtered out 46 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "748it [00:31, 23.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.99849523491055; filtered out 18 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 11944/11944 [00:00<00:00, 131895.73 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Language in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:11, 22.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6095; filtered out 1562 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "153it [00:06, 22.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9967186218211649; filtered out 8 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "152it [00:06, 22.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9934156378600824; filtered out 16 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 2414/2414 [00:00<00:00, 109001.60 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Latitude in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:52, 23.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.50515; filtered out 9897 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "632it [00:26, 23.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9966346629713946; filtered out 34 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "630it [00:26, 23.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9985102790743867; filtered out 15 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 10054/10054 [00:00<00:00, 122917.08 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Latitude in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:10, 23.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.503; filtered out 1988 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "126it [00:05, 23.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9885685884691849; filtered out 23 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "125it [00:05, 23.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9964806435394671; filtered out 7 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 1982/1982 [00:00<00:00, 106912.79 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Longitude in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:53, 23.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.488; filtered out 10240 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "610it [00:25, 23.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9962090163934426; filtered out 37 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "608it [00:25, 23.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9978401727861771; filtered out 21 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 9702/9702 [00:00<00:00, 129376.37 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Longitude in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:10, 23.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.48525; filtered out 2059 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "122it [00:05, 23.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.994332818134982; filtered out 11 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "121it [00:05, 23.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.994300518134715; filtered out 11 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 1919/1919 [00:00<00:00, 107061.31 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Timezone in split train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1250it [00:51, 24.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.57055; filtered out 8589 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "714it [00:29, 24.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9964946104635878; filtered out 40 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "711it [00:29, 24.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9982411397414476; filtered out 20 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 11351/11351 [00:00<00:00, 129042.04 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating data for Timezone in split test\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "250it [00:10, 24.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.58575; filtered out 1657 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "147it [00:06, 24.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.990183525394793; filtered out 23 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "145it [00:06, 23.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9991379310344828; filtered out 2 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 2318/2318 [00:00<00:00, 107248.48 examples/s]\n"
     ]
    }
   ],
   "source": [
    "all_attributes = [\"Country\", \"Continent\", \"Language\", \"Latitude\", \"Longitude\", \"Timezone\"]\n",
    "\n",
    "# Build one filtered dataset per (attribute, split) pair and save each of them to disk.\n",
    "for attribute in all_attributes:\n",
    "    for split, size in ((\"train\", 20000), (\"test\", 4000)):\n",
    "        print(f\"Generating data for {attribute} in split {split}\")\n",
    "\n",
    "        # The current attribute is the target; every other attribute is passed as isolate_attributes.\n",
    "        all_other_attributes = [other for other in all_attributes if other != attribute]\n",
    "\n",
    "        dataset = generate_ravel_dataset(\n",
    "            size,\n",
    "            root_path=\"/home/ubuntu/HyperDAS/data/ravel\",\n",
    "            target_attributes=[attribute],\n",
    "            isolate_attributes=all_other_attributes,\n",
    "            template_split=split,\n",
    "            entity_split=\"both\",\n",
    "        )\n",
    "\n",
    "        # Three consecutive filtering passes, each one fed the output of the previous pass.\n",
    "        for filter_pass in range(3):\n",
    "            dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)\n",
    "\n",
    "        dataset.save_to_disk(\n",
    "            f\"/home/ubuntu/HyperDAS/experiments/ravel/data/city_{attribute.lower()}_{split}\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "625it [00:26, 23.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6655; filtered out 3345 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "416it [00:17, 23.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9959429000751315; filtered out 27 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "415it [00:17, 23.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.998038624019312; filtered out 13 examples\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Sanity-check sample: Country is the target attribute; the other five are passed as isolate_attributes.\n",
    "test_dataset = generate_ravel_dataset(\n",
    "    10000,\n",
    "    root_path=\"/home/ubuntu/HyperDAS/data/ravel\",\n",
    "    template_split=\"test\",\n",
    "    entity_split=\"both\",\n",
    "    target_attributes=[\"Country\"],\n",
    "    isolate_attributes=[\"Continent\", \"Language\", \"Latitude\", \"Longitude\", \"Timezone\"],\n",
    ")\n",
    "\n",
    "# Apply the same filter three times in a row, feeding each pass the result of the previous one.\n",
    "for filter_pass in range(3):\n",
    "    test_dataset = filter_dataset(model, tokenizer, test_dataset, batch_size=16)\n",
    "\n",
    "# test_dataset.save_to_disk(\"./data/ravel_sanity_check_test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6615, 3923, 0)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Examples whose attribute is \"Country\", and the subset of those whose attribute_type is \"isolate\".\n",
    "all_country = [example for example in test_dataset if example[\"attribute\"] == \"Country\"]\n",
    "all_country_with_isolation = [\n",
    "    example for example in all_country if example[\"attribute_type\"] == \"isolate\"\n",
    "]\n",
    "\n",
    "# Displayed as (all examples, Country examples, Country examples of type \"isolate\").\n",
    "len(test_dataset), len(all_country), len(all_country_with_isolation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Country': 3923,\n",
       " 'Continent': 829,\n",
       " 'Language': 629,\n",
       " 'Latitude': 339,\n",
       " 'Longitude': 319,\n",
       " 'Timezone': 576}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tally how many examples test_dataset holds for each value of \"attribute\".\n",
    "attr_dict = {}\n",
    "\n",
    "for example in test_dataset:\n",
    "    attribute_name = example[\"attribute\"]\n",
    "    # A missing key counts as 0, so the first occurrence of an attribute sets it to 1.\n",
    "    attr_dict[attribute_name] = attr_dict.get(attribute_name, 0) + 1\n",
    "\n",
    "attr_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hypernet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
