{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71e73838",
   "metadata": {},
   "source": [
    "# utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "668e64b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "import itertools\n",
    "import os\n",
    "from copy import deepcopy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def build_dicts(entities):\n",
    "    \"\"\"Index the distinct entries of ``entities`` in first-seen order.\n",
    "\n",
    "    Args:\n",
    "        entities: iterable of hashable tokens (e.g. \"<e_12>\"); repeats are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (ind2entity, entity2ind): ``ind2entity`` is the list of distinct tokens,\n",
    "        ``entity2ind`` maps each token to its position in that list.\n",
    "    \"\"\"\n",
    "    entity2ind = dict()\n",
    "    ind2entity = []\n",
    "    for entity in entities:\n",
    "        # Test membership on the dict (constant time); scanning the list for\n",
    "        # every token made the build quadratic in the vocabulary size.\n",
    "        if entity not in entity2ind:\n",
    "            entity2ind[entity] = len(ind2entity)\n",
    "            ind2entity.append(entity)\n",
    "    return ind2entity, entity2ind\n",
    "\n",
    "def choose(arr, ratio_or_count):\n",
    "    \"\"\"Randomly pick elements of ``arr`` without replacement.\n",
    "\n",
    "    Args:\n",
    "        arr: sequence to sample from.\n",
    "        ratio_or_count: a float is read as a fraction of ``len(arr)`` (rounded\n",
    "            to the nearest count); an int is read as an absolute count. NumPy\n",
    "            float/integer scalars are accepted as well. A float is ALWAYS a\n",
    "            ratio, so pass ``round(x)`` when ``x`` is a count stored as a float.\n",
    "\n",
    "    Returns:\n",
    "        ``arr`` itself (not a copy) when the requested count is at least\n",
    "        ``len(arr)``; otherwise a new list with the sampled elements in random order.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``ratio_or_count`` is neither a float nor an int.\n",
    "    \"\"\"\n",
    "    if isinstance(ratio_or_count, (float, np.floating)):\n",
    "        num = round(float(ratio_or_count) * len(arr))\n",
    "    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):\n",
    "        num = int(ratio_or_count)\n",
    "    else:\n",
    "        # An explicit exception survives `python -O`, unlike `assert False`.\n",
    "        raise TypeError(\n",
    "            \"ratio_or_count must be a float ratio or an int count, got {}\".format(\n",
    "                type(ratio_or_count).__name__\n",
    "            )\n",
    "        )\n",
    "    if num >= len(arr):\n",
    "        return arr\n",
    "    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()\n",
    "    return [arr[i] for i in rand_inds]\n",
    "    \n",
    "def split(arr, ratio_or_count):\n",
    "    \"\"\"Randomly partition ``arr`` into a selected part and the remainder.\n",
    "\n",
    "    Args:\n",
    "        arr: sequence to partition.\n",
    "        ratio_or_count: size of the selected part; a float is read as a\n",
    "            fraction of ``len(arr)`` (rounded), an int as an absolute count.\n",
    "            NumPy float/integer scalars are accepted as well.\n",
    "\n",
    "    Returns:\n",
    "        ``[selected, rest]``; both lists keep the relative order of ``arr``.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``ratio_or_count`` is neither a float nor an int.\n",
    "        ValueError: raised by ``np.random.choice`` when the count is negative\n",
    "            or larger than ``len(arr)``.\n",
    "    \"\"\"\n",
    "    if isinstance(ratio_or_count, (float, np.floating)):\n",
    "        num = round(float(ratio_or_count) * len(arr))\n",
    "    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):\n",
    "        num = int(ratio_or_count)\n",
    "    else:\n",
    "        # An explicit exception survives `python -O`, unlike `assert False`.\n",
    "        raise TypeError(\n",
    "            \"ratio_or_count must be a float ratio or an int count, got {}\".format(\n",
    "                type(ratio_or_count).__name__\n",
    "            )\n",
    "        )\n",
    "    # A set makes each membership test constant time; looking indices up in a\n",
    "    # list cost len(arr) * num comparisons. The single pass below is fast\n",
    "    # enough that it needs no progress bar.\n",
    "    rand_inds = set(np.random.choice(len(arr), num, replace=False).tolist())\n",
    "    train, test = [], []\n",
    "    for i, item in enumerate(arr):\n",
    "        if i in rand_inds:\n",
    "            train.append(item)\n",
    "        else:\n",
    "            test.append(item)\n",
    "    return [train, test]\n",
    "\n",
    "def form_items(c, t):\n",
    "    \"\"\"\n",
    "    Turn query tokens and an answer token into one dataset record.\n",
    "\n",
    "    c: ordered string tokens of the query; they are concatenated as-is\n",
    "    t: answer token; it follows the query in the target and is closed by \"</a>\"\n",
    "\n",
    "    example result:\n",
    "        {\n",
    "            \"input_text\": \"<e_736><r_120>\",\n",
    "            \"target_text\": \"<e_736><r_120><e_1544></a>\",\n",
    "        }\n",
    "    \"\"\"\n",
    "    prompt = \"\".join(c)\n",
    "    return {\n",
    "        \"input_text\": prompt,\n",
    "        \"target_text\": prompt + t + \"</a>\",\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fcedc07",
   "metadata": {},
   "source": [
    "# Base Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8d4fd9c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_base_dataset(num_entities, num_relations, out_degree=20, ood_ratio=0.05, test_ii_ratio=0.05):\n",
    "    \"\"\"Build a random knowledge graph and bucket its 2-hop compositions.\n",
    "\n",
    "    Each entity receives ``out_degree`` facts with distinct relations and\n",
    "    uniformly random tails (a tail may equal its head). The atomic facts are\n",
    "    split at random into OOD and ID facts, and every 2-hop chain\n",
    "    (h, r1, b), (b, r2, t) is bucketed by the split each of its hops belongs to.\n",
    "\n",
    "    Args:\n",
    "        num_entities: number of entity tokens \"<e_i>\".\n",
    "        num_relations: number of relation tokens \"<r_i>\"; must be at least\n",
    "            ``out_degree`` because relations are drawn without replacement.\n",
    "        out_degree: outgoing facts per entity.\n",
    "        ood_ratio: fraction of atomic facts assigned to the OOD split.\n",
    "        test_ii_ratio: probability that an ID+ID chain is held out as Test-II\n",
    "            instead of being used for training.\n",
    "\n",
    "    Returns:\n",
    "        (entities, relations, id_atomic_facts, ood_atomic_facts,\n",
    "         train_2hop_ii, test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo);\n",
    "        every fact list holds form_items() records.\n",
    "    \"\"\"\n",
    "    # Tokens are unique by construction, so the lists double as index -> token maps.\n",
    "    entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "    relations = [\"<r_{}>\".format(i) for i in range(num_relations)]\n",
    "\n",
    "    atomic_dict = {ent: [] for ent in entities}   # {h1: [(r1, t1), (r2, t2), ...], ...}\n",
    "    atomics = []   # [(h1,r1,t1), (h2,r2,t2), ...]\n",
    "\n",
    "    for h in tqdm(entities):\n",
    "        # for each subject entity, randomly select some outgoing relations to some random object entity\n",
    "        selected_rows = np.random.choice(num_relations, size=out_degree, replace=False).tolist()\n",
    "        for row_idx in selected_rows:\n",
    "            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)\n",
    "            r, t = relations[row_idx], entities[col_idx]  # h and t might be same here\n",
    "            atomics.append((h, r, t))\n",
    "            atomic_dict[h].append((r, t))\n",
    "\n",
    "    # split ID/OOD at random\n",
    "    ood_list, id_list = split(atomics, round(len(atomics) * ood_ratio))\n",
    "    OOD_facts, ID_facts = set(ood_list), set(id_list)\n",
    "\n",
    "    # Build the records from the lists, not the sets: every (h, r) pair is\n",
    "    # unique so nothing is lost, and the order no longer depends on string hashing.\n",
    "    id_atomic_facts = [form_items([h, r], t) for (h, r, t) in id_list]\n",
    "    ood_atomic_facts = [form_items([h, r], t) for (h, r, t) in ood_list]\n",
    "\n",
    "    train_2hop_ii, test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo = [], [], [], [], []\n",
    "\n",
    "    for ent in tqdm(entities, desc=\"2-hop: \"):\n",
    "        for (r1, b) in atomic_dict[ent]:\n",
    "            # Every atomic fact is either ID or OOD, so one flag per hop is enough.\n",
    "            first_in_id = (ent, r1, b) in ID_facts\n",
    "            for (r2, t) in atomic_dict[b]:\n",
    "                second_in_id = (b, r2, t) in ID_facts\n",
    "                item = form_items([ent, r1, r2], t)\n",
    "                if first_in_id and second_in_id:\n",
    "                    # only ID+ID chains are trainable; a random share is held out\n",
    "                    if np.random.uniform() > test_ii_ratio:\n",
    "                        train_2hop_ii.append(item)\n",
    "                    else:\n",
    "                        test_2hop_ii.append(item)\n",
    "                elif first_in_id:\n",
    "                    test_2hop_io.append(item)\n",
    "                elif second_in_id:\n",
    "                    test_2hop_oi.append(item)\n",
    "                else:\n",
    "                    test_2hop_oo.append(item)\n",
    "\n",
    "    return (\n",
    "        entities, relations,  # vocab\n",
    "        id_atomic_facts, ood_atomic_facts,\n",
    "\n",
    "        # 2-hop\n",
    "        train_2hop_ii,  # train\n",
    "        test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo,  # test\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5af1d81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph size for the base configuration: number of entity / relation tokens.\n",
    "NUM_ENTITY_IN = 2000\n",
    "NUM_RELATION = 200\n",
    "\n",
    "# Build the random graph and its 2-hop splits; out_degree is left at\n",
    "# build_base_dataset's default (20 facts per entity).\n",
    "# NOTE(review): no RNG seed is set in the visible cells, so each run yields a\n",
    "# different dataset - confirm whether that is intended.\n",
    "(\n",
    "    entities, relations,  # vocab\n",
    "    id_atomic_facts, ood_atomic_facts,\n",
    "    \n",
    "    # 2-hop\n",
    "    train_2hop_ii,  # train\n",
    "    test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo,  # test\n",
    ") = build_base_dataset(NUM_ENTITY_IN, NUM_RELATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6fb49399",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary: entity tokens, relation tokens, then the special tokens\n",
    "# (actually only \"</a>\" is used, as <eos>).\n",
    "vocab = entities + relations + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "# For predict_during_training: cap every probe group at test_size items\n",
    "test_size = 3000\n",
    "# atomic\n",
    "test_id_atomic = choose(id_atomic_facts, test_size)\n",
    "test_ood_atomic = choose(ood_atomic_facts, test_size)\n",
    "# 2-hop\n",
    "test_2hop_ii = choose(test_2hop_ii, test_size)\n",
    "test_2hop_io = choose(test_2hop_io, test_size)\n",
    "test_2hop_oi = choose(test_2hop_oi, test_size)\n",
    "test_2hop_oo = choose(test_2hop_oo, test_size)\n",
    "\n",
    "train_atomics = id_atomic_facts + ood_atomic_facts\n",
    "\n",
    "phi = 7.2  # Train-II / ID Triples\n",
    "# choose() treats a float as a ratio of the list length, so the target count\n",
    "# has to be an int: without round() the float product selects the whole list\n",
    "# and phi has no effect on the training set.\n",
    "train_2hop_ii = choose(train_2hop_ii, round(phi * len(id_atomic_facts)))\n",
    "\n",
    "dataset_name = \"base_configuration.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "os.makedirs(\"../data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "# Probe set for predict_during_training: every item is copied and tagged with\n",
    "# the name of the group it was drawn from.\n",
    "probe_groups = [\n",
    "    (\"ID Triples\", test_id_atomic),\n",
    "    (\"OOD Triples\", test_ood_atomic),\n",
    "    (\"Train-II\", choose(train_2hop_ii, test_size)),\n",
    "    (\"Test-II\", test_2hop_ii),\n",
    "    (\"Test-IO\", test_2hop_io),\n",
    "    (\"Test-OI\", test_2hop_oi),\n",
    "    (\"Test-OO\", test_2hop_oo),\n",
    "]\n",
    "probes = []\n",
    "for probe_type, group in probe_groups:\n",
    "    for item in group:\n",
    "        probes.append(deepcopy(item))\n",
    "        probes[-1][\"type\"] = probe_type\n",
    "\n",
    "with open(\"../data/{}/train.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(train_atomics + train_2hop_ii, f)\n",
    "\n",
    "# evaluate_during_training\n",
    "with open(\"../data/{}/valid.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(test_2hop_oo, f)\n",
    "\n",
    "# predict_during_training\n",
    "with open(\"../data/{}/test.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(probes, f)\n",
    "\n",
    "# add vocab\n",
    "with open(\"../data/{}/vocab.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(vocab, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec8beb4f",
   "metadata": {},
   "source": [
    "# 3-hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "77585c5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_3_hop(num_entities, num_relations, out_degree=10, ood_ratio=0.2, test_iii_ratio=0.05):\n",
    "    \"\"\"Build a random knowledge graph and bucket its 3-hop compositions.\n",
    "\n",
    "    Each entity receives ``out_degree`` facts with distinct relations and\n",
    "    uniformly random tails (a tail may equal its head). The atomic facts are\n",
    "    split at random into OOD and ID facts, and every 3-hop chain is bucketed\n",
    "    by the ID/OOD pattern of its three hops.\n",
    "\n",
    "    Args:\n",
    "        num_entities: number of entity tokens \"<e_i>\".\n",
    "        num_relations: number of relation tokens \"<r_i>\"; must be at least\n",
    "            ``out_degree`` because relations are drawn without replacement.\n",
    "        out_degree: outgoing facts per entity.\n",
    "        ood_ratio: fraction of atomic facts assigned to the OOD split\n",
    "            (this ratio can't be too low in the 3-hop dataset).\n",
    "        test_iii_ratio: probability that an all-ID chain is held out as\n",
    "            Test-III instead of being used for training.\n",
    "\n",
    "    Returns:\n",
    "        (entities, relations, id_atomic_facts, ood_atomic_facts,\n",
    "         train_3hop_iii,\n",
    "         test_3hop_iii, test_3hop_iio, test_3hop_ioi, test_3hop_ioo,\n",
    "         test_3hop_oii, test_3hop_oio, test_3hop_ooi, test_3hop_ooo);\n",
    "        every fact list holds form_items() records.\n",
    "    \"\"\"\n",
    "    # Tokens are unique by construction, so the lists double as index -> token maps.\n",
    "    entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "    relations = [\"<r_{}>\".format(i) for i in range(num_relations)]\n",
    "\n",
    "    atomic_dict = {ent: [] for ent in entities}   # {h1: [(r1, t1), (r2, t2), ...], ...}\n",
    "    atomics = []   # [(h1,r1,t1), (h2,r2,t2), ...]\n",
    "\n",
    "    for h in tqdm(entities):\n",
    "        # for each subject entity, randomly select some outgoing relations to some random object entity\n",
    "        selected_rows = np.random.choice(num_relations, size=out_degree, replace=False).tolist()\n",
    "        for row_idx in selected_rows:\n",
    "            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)\n",
    "            r, t = relations[row_idx], entities[col_idx]  # h and t might be same here\n",
    "            atomics.append((h, r, t))\n",
    "            atomic_dict[h].append((r, t))\n",
    "\n",
    "    # split ID/OOD at random\n",
    "    ood_list, id_list = split(atomics, round(len(atomics) * ood_ratio))\n",
    "    ID_facts = set(id_list)\n",
    "\n",
    "    # Build the records from the lists, not from sets: every (h, r) pair is\n",
    "    # unique so nothing is lost, and the order no longer depends on string hashing.\n",
    "    id_atomic_facts = [form_items([h, r], t) for (h, r, t) in id_list]\n",
    "    ood_atomic_facts = [form_items([h, r], t) for (h, r, t) in ood_list]\n",
    "\n",
    "    # One test bucket per (hop1, hop2, hop3) pattern, True = ID, False = OOD.\n",
    "    # product() yields them as iii, iio, ioi, ioo, oii, oio, ooi, ooo.\n",
    "    patterns = list(itertools.product((True, False), repeat=3))\n",
    "    test_buckets = {pattern: [] for pattern in patterns}\n",
    "    train_3hop_iii = []\n",
    "\n",
    "    # With 1000 entities and out_degree 10 there are 1000000 chains; at the\n",
    "    # default ratios that is about 486400 train-iii and 25600 test-iii items,\n",
    "    # 128000 for each pattern with one OOD hop, 32000 with two, and 8000 for ooo.\n",
    "    for ent in tqdm(entities, desc=\"3-hop: \"):\n",
    "        for (r1, b1) in atomic_dict[ent]:\n",
    "            # Every atomic fact is either ID or OOD, so one flag per hop is enough.\n",
    "            hop1_in_id = (ent, r1, b1) in ID_facts\n",
    "            for (r2, b2) in atomic_dict[b1]:\n",
    "                hop2_in_id = (b1, r2, b2) in ID_facts\n",
    "                for (r3, t) in atomic_dict[b2]:\n",
    "                    pattern = (hop1_in_id, hop2_in_id, (b2, r3, t) in ID_facts)\n",
    "                    item = form_items([ent, r1, r2, r3], t)\n",
    "                    # Only all-ID chains are trainable; the random draw is made\n",
    "                    # for those chains alone and holds a small share out as test-iii.\n",
    "                    if all(pattern) and np.random.uniform() > test_iii_ratio:\n",
    "                        train_3hop_iii.append(item)\n",
    "                    else:\n",
    "                        test_buckets[pattern].append(item)\n",
    "\n",
    "    return (\n",
    "        entities, relations,  # vocab\n",
    "        id_atomic_facts, ood_atomic_facts,\n",
    "\n",
    "        # 3-hop\n",
    "        train_3hop_iii,  # train\n",
    "        *(test_buckets[pattern] for pattern in patterns),  # test: iii, iio, ioi, ioo, oii, oio, ooi, ooo\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7259c86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph size for the 3-hop dataset: number of entity / relation tokens.\n",
    "NUM_ENTITY_IN = 1000\n",
    "NUM_RELATION = 100\n",
    "\n",
    "# Build the random graph and its 3-hop splits; out_degree is left at\n",
    "# build_3_hop's default (10 facts per entity).\n",
    "# NOTE(review): no RNG seed is set in the visible cells, so each run yields a\n",
    "# different dataset - confirm whether that is intended.\n",
    "(\n",
    "    entities, relations,  # vocab\n",
    "    id_atomic_facts, ood_atomic_facts,\n",
    "\n",
    "    # 3-hop\n",
    "    train_3hop_iii, # train\n",
    "    test_3hop_iii, test_3hop_iio, test_3hop_ioi, test_3hop_ioo,\n",
    "    test_3hop_oii, test_3hop_oio, test_3hop_ooi, test_3hop_ooo,  # test\n",
    ") = build_3_hop(NUM_ENTITY_IN, NUM_RELATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bada93b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab size: 1106\n"
     ]
    }
   ],
   "source": [
    "# Vocabulary = entity tokens + relation tokens + special tokens\n",
    "# (of the special tokens only \"</a>\" is really used, as <eos>)\n",
    "vocab = entities + relations + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "print(\"vocab size:\", len(vocab))\n",
    "\n",
    "# For predict_during_training: every probe group is capped at test_size items\n",
    "test_size = 1000\n",
    "# atomic\n",
    "test_id_atomic = choose(id_atomic_facts, test_size)\n",
    "test_ood_atomic = choose(ood_atomic_facts, test_size)\n",
    "\n",
    "# 3-hop\n",
    "test_3hop_iii = choose(test_3hop_iii, test_size)\n",
    "test_3hop_iio = choose(test_3hop_iio, test_size)\n",
    "test_3hop_ioi = choose(test_3hop_ioi, test_size)\n",
    "test_3hop_ioo = choose(test_3hop_ioo, test_size)\n",
    "test_3hop_oii = choose(test_3hop_oii, test_size)\n",
    "test_3hop_oio = choose(test_3hop_oio, test_size)\n",
    "test_3hop_ooi = choose(test_3hop_ooi, test_size)\n",
    "test_3hop_ooo = choose(test_3hop_ooo, test_size)\n",
    "\n",
    "train_atomics = id_atomic_facts + ood_atomic_facts\n",
    "\n",
    "phi = 12  # Train-III / ID Triples\n",
    "dataset_name = \"3hop.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "os.makedirs(\"../data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "# NOTE(review): the count below scales with len(train_atomics) (ID + OOD facts),\n",
    "# while the phi comment mentions ID triples only - confirm which is intended.\n",
    "train_3hop_iii_cut = choose(train_3hop_iii, round(phi * len(train_atomics)))\n",
    "\n",
    "# Probe groups in output order: (type label, records of that group)\n",
    "labelled_groups = (\n",
    "    (\"ID Triples\", test_id_atomic),\n",
    "    (\"OOD Triples\", test_ood_atomic),\n",
    "    (\"Train-III\", choose(train_3hop_iii_cut, test_size)),\n",
    "    (\"Test-III\", test_3hop_iii),\n",
    "    (\"Test-IIO\", test_3hop_iio),\n",
    "    (\"Test-IOI\", test_3hop_ioi),\n",
    "    (\"Test-IOO\", test_3hop_ioo),\n",
    "    (\"Test-OII\", test_3hop_oii),\n",
    "    (\"Test-OIO\", test_3hop_oio),\n",
    "    (\"Test-OOI\", test_3hop_ooi),\n",
    "    (\"Test-OOO\", test_3hop_ooo),\n",
    ")\n",
    "probes = []\n",
    "for label, group in labelled_groups:\n",
    "    for record in group:\n",
    "        tagged = deepcopy(record)\n",
    "        tagged[\"type\"] = label\n",
    "        probes.append(tagged)\n",
    "\n",
    "# train.json: training set; valid.json: evaluate_during_training;\n",
    "# test.json: predict_during_training; vocab.json: vocabulary\n",
    "outputs = (\n",
    "    (\"../data/{}/train.json\", train_atomics + train_3hop_iii_cut),\n",
    "    (\"../data/{}/valid.json\", test_3hop_ooo),\n",
    "    (\"../data/{}/test.json\", probes),\n",
    "    (\"../data/{}/vocab.json\", vocab),\n",
    ")\n",
    "for path_template, payload in outputs:\n",
    "    with open(path_template.format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "        json.dump(payload, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f070b54",
   "metadata": {},
   "source": [
    "# Base Configuration with Second-Hop Ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8a8e9130",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_second_hop_ablation(num_entities, num_relations, out_degree=20, ood_ratio=0.05,\n",
    "                              second_hop_restricted_ratio=0.05, test_ii_ratio=0.005):\n",
    "    \"\"\"Build the base graph with a share of ID facts never used as a training second hop.\n",
    "\n",
    "    Each entity receives ``out_degree`` facts with distinct relations and\n",
    "    uniformly random tails (a tail may equal its head). The atomic facts are\n",
    "    split at random into OOD and ID facts, and a random share of the ID facts\n",
    "    is marked \"second-hop restricted\": a 2-hop chain whose second hop is\n",
    "    restricted goes to a dedicated test set. Chains touching an OOD fact are\n",
    "    dropped altogether.\n",
    "\n",
    "    Args:\n",
    "        num_entities: number of entity tokens \"<e_i>\".\n",
    "        num_relations: number of relation tokens \"<r_i>\"; must be at least\n",
    "            ``out_degree`` because relations are drawn without replacement.\n",
    "        out_degree: outgoing facts per entity.\n",
    "        ood_ratio: fraction of atomic facts assigned to the OOD split.\n",
    "        second_hop_restricted_ratio: fraction of ID facts marked as restricted.\n",
    "        test_ii_ratio: probability that an unrestricted ID+ID chain is held\n",
    "            out as Test-II instead of being used for training.\n",
    "\n",
    "    Returns:\n",
    "        (entities, relations, id_atomic_facts, ood_atomic_facts,\n",
    "         train_2hop_ii, test_2hop_ii, test_2hop_second_hop_restricted);\n",
    "        every fact list holds form_items() records.\n",
    "    \"\"\"\n",
    "    # Tokens are unique by construction, so the lists double as index -> token maps.\n",
    "    entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "    relations = [\"<r_{}>\".format(i) for i in range(num_relations)]\n",
    "\n",
    "    atomic_dict = {ent: [] for ent in entities}   # {h1: [(r1, t1), (r2, t2), ...], ...}\n",
    "    atomics = []   # [(h1,r1,t1), (h2,r2,t2), ...]\n",
    "\n",
    "    for h in tqdm(entities):\n",
    "        # for each subject entity, randomly select some outgoing relations to some random object entity\n",
    "        selected_rows = np.random.choice(num_relations, size=out_degree, replace=False).tolist()\n",
    "        for row_idx in selected_rows:\n",
    "            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)\n",
    "            r, t = relations[row_idx], entities[col_idx]  # h and t might be same here\n",
    "            atomics.append((h, r, t))\n",
    "            atomic_dict[h].append((r, t))\n",
    "\n",
    "    # split ID/OOD at random\n",
    "    ood_list, id_list = split(atomics, round(len(atomics) * ood_ratio))\n",
    "    OOD_facts = set(ood_list)\n",
    "\n",
    "    # mark a random share of the ID facts as second-hop restricted\n",
    "    restricted_list, _ = split(id_list, round(len(id_list) * second_hop_restricted_ratio))\n",
    "    second_hop_restricted_id = set(restricted_list)\n",
    "\n",
    "    # Build the records from the lists, not from sets: every (h, r) pair is\n",
    "    # unique so nothing is lost, and the order no longer depends on string hashing.\n",
    "    id_atomic_facts = [form_items([h, r], t) for (h, r, t) in id_list]\n",
    "    ood_atomic_facts = [form_items([h, r], t) for (h, r, t) in ood_list]\n",
    "\n",
    "    # We don't involve OOD here\n",
    "    train_2hop_ii, test_2hop_ii, test_2hop_second_hop_restricted = [], [], []\n",
    "\n",
    "    for ent in tqdm(entities, desc=\"2-hop: \"):\n",
    "        for (r1, b) in atomic_dict[ent]:\n",
    "            if (ent, r1, b) in OOD_facts:\n",
    "                continue  # an OOD first hop rules out every chain through it\n",
    "            for (r2, t) in atomic_dict[b]:\n",
    "                second_hop = (b, r2, t)\n",
    "                if second_hop in OOD_facts:\n",
    "                    continue\n",
    "\n",
    "                item = form_items([ent, r1, r2], t)\n",
    "                if second_hop in second_hop_restricted_id:\n",
    "                    # first hop: any ID fact; second hop: restricted ID fact\n",
    "                    test_2hop_second_hop_restricted.append(item)\n",
    "                elif np.random.uniform() > test_ii_ratio:\n",
    "                    # first hop: any ID fact; second hop: unrestricted ID fact\n",
    "                    train_2hop_ii.append(item)\n",
    "                else:\n",
    "                    test_2hop_ii.append(item)\n",
    "\n",
    "    return (\n",
    "        entities, relations,  # vocab\n",
    "        id_atomic_facts, ood_atomic_facts,\n",
    "\n",
    "        # 2-hop\n",
    "        train_2hop_ii,  # train\n",
    "        test_2hop_ii, test_2hop_second_hop_restricted,  # test\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "825f8c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph size for the second-hop ablation: number of entity / relation tokens.\n",
    "NUM_ENTITY_IN = 2000\n",
    "NUM_RELATION = 200\n",
    "\n",
    "# Build the random graph and its 2-hop splits; out_degree is left at\n",
    "# build_second_hop_ablation's default (20 facts per entity).\n",
    "# NOTE(review): no RNG seed is set in the visible cells, so each run yields a\n",
    "# different dataset - confirm whether that is intended.\n",
    "(\n",
    "    entities, relations,  # vocab\n",
    "    id_atomic_facts, ood_atomic_facts,\n",
    "    \n",
    "    # 2-hop\n",
    "    train_2hop_ii,  # train\n",
    "    test_2hop_ii, test_2hop_second_hop_restricted,  # test\n",
    ") = build_second_hop_ablation(NUM_ENTITY_IN, NUM_RELATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d6e0f3c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary: entity tokens, relation tokens, then the special tokens\n",
    "# (actually only \"</a>\" is used, as <eos>).\n",
    "vocab = entities + relations + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "# For predict_during_training: cap every probe group at test_size items\n",
    "test_size = 3000\n",
    "# atomic\n",
    "test_id_atomic = choose(id_atomic_facts, test_size)\n",
    "test_ood_atomic = choose(ood_atomic_facts, test_size)\n",
    "# 2-hop\n",
    "test_2hop_ii = choose(test_2hop_ii, test_size)\n",
    "test_2hop_second_hop_restricted = choose(test_2hop_second_hop_restricted, test_size)\n",
    "\n",
    "train_atomics = id_atomic_facts + ood_atomic_facts\n",
    "\n",
    "phi = 7.2  # Train-II / ID Triples\n",
    "# choose() treats a float as a ratio of the list length, so the target count\n",
    "# has to be an int: without round() the float product selects the whole list\n",
    "# and phi has no effect on the training set.\n",
    "train_2hop_ii = choose(train_2hop_ii, round(phi * len(id_atomic_facts)))\n",
    "\n",
    "dataset_name = \"second_hop_ablation_configuration.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "os.makedirs(\"../data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "# Probe set for predict_during_training: every item is copied and tagged with\n",
    "# the name of the group it was drawn from.\n",
    "probe_groups = [\n",
    "    (\"ID Triples\", test_id_atomic),\n",
    "    (\"OOD Triples\", test_ood_atomic),\n",
    "    (\"Train-II\", choose(train_2hop_ii, test_size)),\n",
    "    (\"Test-II\", test_2hop_ii),\n",
    "    (\"Test-II-SR\", test_2hop_second_hop_restricted),\n",
    "]\n",
    "probes = []\n",
    "for probe_type, group in probe_groups:\n",
    "    for item in group:\n",
    "        probes.append(deepcopy(item))\n",
    "        probes[-1][\"type\"] = probe_type\n",
    "\n",
    "with open(\"../data/{}/train.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(train_atomics + train_2hop_ii, f)\n",
    "\n",
    "# evaluate_during_training\n",
    "# build_second_hop_ablation returns no OO split, so a test_2hop_oo variable\n",
    "# here could only be a leftover from another configuration's cells (or missing\n",
    "# on a fresh kernel). The restricted second-hop probes are written instead.\n",
    "# NOTE(review): confirm this is the intended validation set.\n",
    "with open(\"../data/{}/valid.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(test_2hop_second_hop_restricted, f)\n",
    "\n",
    "# predict_during_training\n",
    "with open(\"../data/{}/test.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(probes, f)\n",
    "\n",
    "# add vocab\n",
    "with open(\"../data/{}/vocab.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(vocab, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "897e2597",
   "metadata": {},
   "source": [
    "# Unanchored OOD Configuration & Decoding Preference Configuration\n",
    "&emsp; without ID Triples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eba2359",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph size for the without-ID configuration: number of entity / relation tokens.\n",
    "NUM_ENTITY_IN = 2000\n",
    "NUM_RELATION = 200\n",
    "\n",
    "# Reuse the build_base_dataset functions\n",
    "# This call draws a new random graph; it does not reuse the graph built by an\n",
    "# earlier build_base_dataset call.\n",
    "# NOTE(review): no RNG seed is set in the visible cells, so each run yields a\n",
    "# different dataset - confirm whether that is intended.\n",
    "(\n",
    "    entities, relations,  # vocab\n",
    "    id_atomic_facts, ood_atomic_facts,\n",
    "    \n",
    "    # 2-hop\n",
    "    train_2hop_ii,  # train\n",
    "    test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo,  # test\n",
    ") = build_base_dataset(NUM_ENTITY_IN, NUM_RELATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "67376197",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary: entity tokens, relation tokens, then the special tokens\n",
    "# (actually only \"</a>\" is used, as <eos>).\n",
    "vocab = entities + relations + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "# For predict_during_training: cap every probe group at test_size items\n",
    "test_size = 3000\n",
    "# atomic\n",
    "test_id_atomic = choose(id_atomic_facts, test_size)\n",
    "test_ood_atomic = choose(ood_atomic_facts, test_size)\n",
    "# 2-hop\n",
    "test_2hop_ii = choose(test_2hop_ii, test_size)\n",
    "test_2hop_io = choose(test_2hop_io, test_size)\n",
    "test_2hop_oi = choose(test_2hop_oi, test_size)\n",
    "test_2hop_oo = choose(test_2hop_oo, test_size)\n",
    "\n",
    "# ensure that ID triples do not appear in the training set\n",
    "train_atomics = ood_atomic_facts\n",
    "\n",
    "phi = 7.2  # Train-II / ID Triples\n",
    "# choose() treats a float as a ratio of the list length, so the target count\n",
    "# has to be an int: without round() the float product selects the whole list\n",
    "# and phi has no effect on the training set.\n",
    "train_2hop_ii = choose(train_2hop_ii, round(phi * len(id_atomic_facts)))\n",
    "\n",
    "dataset_name = \"without_id.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "os.makedirs(\"../data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "# Probe set for predict_during_training: every item is copied and tagged with\n",
    "# the name of the group it was drawn from.\n",
    "probe_groups = [\n",
    "    (\"ID Triples\", test_id_atomic),\n",
    "    (\"OOD Triples\", test_ood_atomic),\n",
    "    (\"Train-II\", choose(train_2hop_ii, test_size)),\n",
    "    (\"Test-II\", test_2hop_ii),\n",
    "    (\"Test-IO\", test_2hop_io),\n",
    "    (\"Test-OI\", test_2hop_oi),\n",
    "    (\"Test-OO\", test_2hop_oo),\n",
    "]\n",
    "probes = []\n",
    "for probe_type, group in probe_groups:\n",
    "    for item in group:\n",
    "        probes.append(deepcopy(item))\n",
    "        probes[-1][\"type\"] = probe_type\n",
    "\n",
    "with open(\"../data/{}/train.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(train_atomics + train_2hop_ii, f)\n",
    "\n",
    "# evaluate_during_training\n",
    "with open(\"../data/{}/valid.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(test_2hop_oo, f)\n",
    "\n",
    "# predict_during_training\n",
    "with open(\"../data/{}/test.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(probes, f)\n",
    "\n",
    "# add vocab\n",
    "with open(\"../data/{}/vocab.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(vocab, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a755d4d5",
   "metadata": {},
   "source": [
    "# Held-out ID Triple Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0cc2bbae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_held_out_id(num_entities, num_relations, out_degree=20):\n",
    " \n",
    "    entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "    ind2entity, entity2ind = build_dicts(entities)\n",
    "\n",
    "    relations = [\"<r_{}>\".format(i) for i in range(num_relations)]\n",
    "    ind2relation, relation2ind = build_dicts(relations)\n",
    "\n",
    "    atomic_dict = dict()   # {h1: [(r1, t1), (r2, t2), ...], ...}\n",
    "    atomic_facts = []   # [{\"input_text\": \"...\", \"target_text\": \"...\"}, ...]\n",
    "    atomics = []   # [(h1,r1,t1), (h2,r2,t2), ...]\n",
    "\n",
    "    for i in tqdm(range(num_entities)):\n",
    "        # for each subject entity, randomly select some outgoing relations to some random object entity\n",
    "        num_rows = out_degree\n",
    "        selected_rows = np.random.choice(num_relations, size=num_rows, replace=False).tolist()\n",
    "        for row_idx in selected_rows:\n",
    "            col_idx = np.random.randint(num_entities)  # pick some random tail entity for each selected (h,r)\n",
    "            h,r,t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]  # h and t might be same here\n",
    "            atomic_facts.append(form_items([h, r], t))\n",
    "            atomics.append((h,r,t))\n",
    "            if h not in atomic_dict:\n",
    "                atomic_dict[h] = []\n",
    "            atomic_dict[h].append((r, t))\n",
    "    \n",
    "    # split ID/OOD\n",
    "    OOD_ratio = 0.05\n",
    "    OOD_facts, ID_facts = split(atomics, round(len(atomics)*OOD_ratio))  # randomly\n",
    "    OOD_facts = set(OOD_facts)\n",
    "\n",
    "    held_out_ratio = 0.05\n",
    "    held_out_id, retained_id = split(ID_facts, round(len(ID_facts) * held_out_ratio))\n",
    "    held_out_id, retained_id = set(held_out_id), set(retained_id)\n",
    "\n",
    "    ID_facts = set(ID_facts)\n",
    "\n",
    "    held_out_id_atomic_facts = [form_items([h, r], t) for (h,r,t) in held_out_id]\n",
    "    retained_id_atomic_facts = [form_items([h, r], t) for (h,r,t) in retained_id]\n",
    "    ood_atomic_facts = [form_items([h, r], t) for (h,r,t) in OOD_facts]\n",
    "\n",
    "    train_2hop_ii, test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo = [], [], [], [], []\n",
    "    \n",
    "    for ent in tqdm(entities, desc=\"2-hop: \"):\n",
    "        for (r1, b) in atomic_dict[ent]:\n",
    "            for (r2, t) in atomic_dict[b]:\n",
    "                if (ent, r1, b) in ID_facts and (b, r2, t) in ID_facts:\n",
    "                    if np.random.uniform() > 0.05:\n",
    "                        train_2hop_ii.append(form_items([ent, r1, r2], t))\n",
    "                    else:\n",
    "                        test_2hop_ii.append(form_items([ent, r1, r2], t))\n",
    "                \n",
    "                elif (ent, r1, b) in ID_facts and (b, r2, t) in OOD_facts:\n",
    "                    test_2hop_io.append(form_items([ent, r1, r2], t))\n",
    "                \n",
    "                elif (ent, r1, b) in OOD_facts and (b, r2, t) in ID_facts:\n",
    "                    test_2hop_oi.append(form_items([ent, r1, r2], t))\n",
    "                \n",
    "                elif (ent, r1, b) in OOD_facts and (b, r2, t) in OOD_facts:\n",
    "                    test_2hop_oo.append(form_items([ent, r1, r2], t))\n",
    "\n",
    "    return (\n",
    "        entities, relations,  # vocab\n",
    "        held_out_id_atomic_facts, retained_id_atomic_facts, ood_atomic_facts,\n",
    "        \n",
    "        # 2-hop\n",
    "        train_2hop_ii,  # train\n",
    "        test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo,  # test\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a214fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_ENTITY_IN = 2000\n",
    "NUM_RELATION = 200\n",
    "\n",
    "(\n",
    "    entities, relations,  # vocab\n",
    "    held_out_id_atomic_facts, retained_id_atomic_facts, ood_atomic_facts,\n",
    "    \n",
    "    # 2-hop\n",
    "    train_2hop_ii,  # train\n",
    "    test_2hop_ii, test_2hop_io, test_2hop_oi, test_2hop_oo,  # test\n",
    ") = build_held_out_id(NUM_ENTITY_IN, NUM_RELATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ab950a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = []\n",
    "vocab = vocab + entities + relations\n",
    "# special tokens\n",
    "# actually only \"</a>\" used as <eos>\n",
    "vocab = vocab + [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "\n",
    "# For predict_during_training: cut test group\n",
    "test_size = 3000\n",
    "# atomic\n",
    "test_held_out_id_atomic = choose(held_out_id_atomic_facts, test_size)\n",
    "test_retained_id_atomic = choose(retained_id_atomic_facts, test_size)\n",
    "test_ood_atomic = choose(ood_atomic_facts, test_size)\n",
    "# 2-hop\n",
    "test_2hop_ii = choose(test_2hop_ii, test_size)\n",
    "test_2hop_io = choose(test_2hop_io, test_size)\n",
    "test_2hop_oi = choose(test_2hop_oi, test_size)\n",
    "test_2hop_oo = choose(test_2hop_oo, test_size)\n",
    "\n",
    "# ensure that held_out_id_atomic_facts do not appear in the training set\n",
    "train_atomics = retained_id_atomic_facts + ood_atomic_facts\n",
    "\n",
    "phi = 7.2  # Train-II / ID Triples\n",
    "train_2hop_ii = choose(train_2hop_ii, phi * len(retained_id_atomic_facts + held_out_id_atomic_facts))\n",
    "\n",
    "dataset_name = \"held_out_id_configuration.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "os.makedirs(\"../data/{}\".format(dataset_name), exist_ok=True)\n",
    "\n",
    "probes = []\n",
    "\n",
    "for item in test_held_out_id_atomic:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1][\"type\"] = \"Held-out ID Triples\"\n",
    "\n",
    "for item in test_retained_id_atomic:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1][\"type\"] = \"Retained ID Triples\"\n",
    "\n",
    "for item in test_ood_atomic:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1][\"type\"] = \"OOD Triples\"\n",
    "\n",
    "for item in choose(train_2hop_ii, test_size):\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1]['type'] = 'Train-II'\n",
    "\n",
    "for item in test_2hop_ii:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1][\"type\"] = \"Test-II\"\n",
    "\n",
    "for item in test_2hop_io:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1]['type'] = 'Test-IO'\n",
    "\n",
    "for item in test_2hop_oi:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1]['type'] = 'Test-OI'\n",
    "\n",
    "for item in test_2hop_oo:\n",
    "    probes.append(deepcopy(item))\n",
    "    probes[-1][\"type\"] = \"Test-OO\"\n",
    "\n",
    "with open(\"../data/{}/train.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(train_atomics + train_2hop_ii, f)\n",
    "\n",
    "# evaluate_during_training\n",
    "with open(\"../data/{}/valid.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(test_2hop_oo, f)\n",
    "\n",
    "# predict_during_training\n",
    "with open(\"../data/{}/test.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(probes, f)\n",
    "\n",
    "# add vocab\n",
    "with open(\"../data/{}/vocab.json\".format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "    json.dump(vocab, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
