{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from datasets import load_from_disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset_stats(train, test, name):\n",
    "    num_train_examples = len(train)\n",
    "    num_test_examples = len(test)\n",
    "    \n",
    "    num_train_causal_example = len([ex for ex in train if ex['attribute_type'] == \"causal\"])\n",
    "    num_train_iso_example = num_train_examples - num_train_causal_example\n",
    "    \n",
    "    num_test_causal_example = len([ex for ex in test if ex['attribute_type'] == \"causal\"])\n",
    "    num_test_iso_example = num_test_examples - num_test_causal_example\n",
    "    \n",
    "    num_train_entity = len(set([ex['entity'] for ex in train]).union(set([ex['counterfactual_entity'] for ex in train])))\n",
    "    num_test_entity = len(set([ex['entity'] for ex in test]).union(set([ex['counterfactual_entity'] for ex in test])))\n",
    "    print(f\"{name} & {num_train_causal_example}/{num_test_causal_example} & {num_train_iso_example}/{num_test_iso_example} & {num_train_entity}/{num_test_entity} \\\\\\\\\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "city\n",
      "{'Country', 'Language', 'Continent', 'Timezone', 'Latitude', 'Longitude'}\n",
      "nobel_prize_winner\n",
      "{'Country of Birth', 'Award Year', 'Gender', 'Field', 'Birth Year'}\n",
      "occupation\n",
      "{'Work Location', 'Duty', 'Industry'}\n",
      "physical_object\n",
      "{'Color', 'Category', 'Texture', 'Size'}\n",
      "verb\n",
      "{'Past Tense', 'Singular'}\n"
     ]
    }
   ],
   "source": [
    "for domain in [\"city\", \"nobel_prize_winner\", \"occupation\", \"physical_object\", \"verb\"]:\n",
    "    print(domain)\n",
    "    # \"city\" is read from its \"_train\" directory; every other domain from \"_train_large\".\n",
    "    train_suffix = \"train\" if domain == \"city\" else \"train_large\"\n",
    "    trainset = load_from_disk(f\"./data/{domain}_{train_suffix}\")\n",
    "    testset = load_from_disk(f\"./data/{domain}_test\")\n",
    "\n",
    "    # Bucket each split by attribute in one pass rather than re-scanning the\n",
    "    # whole dataset once per attribute.\n",
    "    train_by_attribute, test_by_attribute = {}, {}\n",
    "    for dataset, buckets in ((trainset, train_by_attribute), (testset, test_by_attribute)):\n",
    "        for ex in dataset:\n",
    "            buckets.setdefault(ex[\"attribute\"], []).append(ex)\n",
    "\n",
    "    # Rows follow the attributes seen in the training split. Sort them: plain set\n",
    "    # iteration order for strings changes between kernel sessions, which reshuffled\n",
    "    # the table rows on every rerun.\n",
    "    for attribute in sorted(train_by_attribute):\n",
    "        get_dataset_stats(\n",
    "            train_by_attribute[attribute],\n",
    "            test_by_attribute.get(attribute, []),\n",
    "            attribute,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hypernet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
