{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "# prompts = generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts,labels)\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split('----')\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # add inverse relations\n",
    "# rid = 0\n",
    "# id2rel = dict()\n",
    "# rel2id = dict()\n",
    "\n",
    "\n",
    "# with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "#     for line in f:\n",
    "#         _, rel = line.strip().split()\n",
    "#         id2rel[rid] = rel\n",
    "#         rel2id[rel] = rid\n",
    "#         rid += 1\n",
    "\n",
    "# extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "# for rel in extra_relations:\n",
    "#     id2rel[rid] = rel\n",
    "#     rel2id[rel] = rid\n",
    "#     rid += 1\n",
    "\n",
    "# wrid = 1\n",
    "# with open(\"rel2sym_ri.txt\", 'w') as fw:\n",
    "#     for rel in id2rel.values():\n",
    "#         fw.writelines(rel + '\\tr' + str(wrid) + '\\n')\n",
    "#         wrid += 1\n",
    "#     for rel in id2rel.values():\n",
    "#         fw.writelines( \"inverse_\" + rel + '\\tr' + str(wrid) +  '\\n')\n",
    "#         wrid += 1   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rel2sym = dict()\n",
    "# with open(\"rel2sym_ri.txt\",\"r\") as fr:\n",
    "#     for line in fr:\n",
    "#         rel, sym = line.strip().split()\n",
    "#         rel2sym[rel] = sym \n",
    "\n",
    "# def prefix_add(length):\n",
    "#     # \"\\forall A \\forall B\"\n",
    "#     prefix = \"$\\\\forall A\"\n",
    "#     # tail = chr(ord('A') + length-2)\n",
    "    \n",
    "#     for i in range(length-2):\n",
    "#         prefix += \",\" + chr(ord('B') + i)\n",
    "#     prefix += ': '\n",
    "#     # prefix += chr(ord('A') + length - 1)\n",
    "#     return prefix\n",
    "\n",
    "# def replace_symbolic(predicate, rel2sym):\n",
    "#     for i in range(len(list(rel2sym.keys()))-1, -1, -1):\n",
    "#         rel = list(rel2sym.keys())[i]\n",
    "#         if rel in predicate:\n",
    "#             predicate = predicate.replace(rel, rel2sym[rel])\n",
    "#             return predicate\n",
    "        \n",
    "# import re\n",
    "# def raw_to_latex(rfile, wfile):\n",
    "#     with open(wfile, 'w') as fw:\n",
    "#         with open(rfile, 'r') as f:\n",
    "#             for line in f:\n",
    "#                 rule = ''\n",
    "#                 lst = line.strip().split('\\t')\n",
    "#                 length = len(lst)\n",
    "#                 rule = prefix_add(length)\n",
    "#                 # if 'Z' in line:\n",
    "#                 #     rule = '$\\\\forall A \\\\forall B \\\\exists C ('\n",
    "#                 # else:\n",
    "#                 #     rule = '$\\\\forall A \\\\forall Y ('\n",
    "#                 head = ord('A')\n",
    "#                 tail = head + 1\n",
    "#                 for rel in lst[1:-1]:\n",
    "#                     # replace predicate with others\n",
    "                    \n",
    "#                     rel = replace_symbolic(rel, rel2sym)\n",
    "#                     # rule +=  rel + ' \\\\land '\n",
    "#                     rule += rel + '(' + chr(head) + ', ' + chr(tail) + ') \\\\land '\n",
    "#                     head = tail\n",
    "#                     tail = head + 1\n",
    "#                 rel = replace_symbolic(lst[-1], rel2sym)\n",
    "#                 rh = replace_symbolic(lst[0], rel2sym)\n",
    "#                 rule += rel + '(A) \\\\rightarrow ' + rh + '(A,' + chr(tail-1) + ')$'\n",
    "                   \n",
    "#                 fw.writelines(rel2sym[lst[0]] + '\\t' + rule + '\\n')\n",
    "# raw_to_latex('rule_tab.txt', 'latex_rules.txt')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()\n",
    "relation_txt = ''\n",
    "\n",
    "infer_rel = list()\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "\n",
    "        infer_rel.append(rel)\n",
    "\n",
    "        relation_txt += rel + ', '\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rel in extra_relations:\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "    \n",
    "with open(\"rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym_2[rel] = sym\n",
    "        rel2sym[rel] = '$' + sym + '$'\n",
    "\n",
    "grounding_truth = dict()\n",
    "\n",
    "with open('rule_symbolic_first_order.txt','r') as f:\n",
    "    for line in f:\n",
    "        rel, rule = line.strip().split('\\t')\n",
    "       \n",
    "        grounding_truth['$' + rel + '$'] = rule\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get \"\\forall A,B,C,D: \"\n",
    "def get_prefix(length):\n",
    "    h = ord('A')\n",
    "    text = '\\\\forall '\n",
    "    for _ in range(length):\n",
    "        text += chr(h) + ', '\n",
    "        h += 1\n",
    "    text += chr(h) + ': '\n",
    "    return text\n",
    "\n",
    "# from rule_parent.txt get the length of each rule\n",
    "rule_length = dict()\n",
    "with open('rule_tab.txt', 'r') as f:\n",
    "    for line in f:\n",
    "        lst = line.strip().split('\\t')\n",
    "        rule_length[rel2sym[lst[0]]] = len(lst[1:]) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id ):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, ent = line.strip().split()\n",
    "            if ent not in ent2id:\n",
    "                id2ent[eid] = ent\n",
    "                ent2id[ent] = eid\n",
    "                eid += 1\n",
    "                \n",
    "        return eid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text, fid):\n",
    "    triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    with open(path1,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "\n",
    "            text += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "            fid += 1\n",
    "            text += 'F' + str(fid) + ': ' + rel2sym['inverse_' + id2rel[int(r)]] + '(' + id2ent[int(t)] + ', ' + id2ent[int(h)] + ')\\n'\n",
    "            fid += 1\n",
    "\n",
    "\n",
    "            # text += id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "            # text += id2ent[int(t)] + ' is ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "\n",
    "    with open(path2,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                # triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "                triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # text += id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "                # text += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "                # fid += 1\n",
    "                # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "\n",
    "                # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "    return triplets, entpair2rel, text, fid\n",
    "\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                # ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                # ent2class[id2ent[cid]] = 'female'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is ' + rel2sym[\"female\"] + '. '\n",
    "                \n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym[\"female\"] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                # ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is '+ rel2sym['male'] + '. '\n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym[\"male\"] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is the '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "\n",
    "            cid += 1\n",
    "        return cid, class_text, fid\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# def read_rules(path):\n",
    "#     rules = list()\n",
    "#     grounding_truth = dict()\n",
    "#     rel2rules = dict()\n",
    "#     with open(path, 'r') as f:\n",
    "#         for line in f:\n",
    "#             lst = line.strip().split('\\t')\n",
    "#             rules.append(lst)\n",
    "#             if lst[0] not in rel2rules:\n",
    "#                 rel2rules[lst[0]] = list() \n",
    "#             grounding_truth[lst[0]].append(lst) \n",
    "#     return rules, rel2rules\n",
    "\n",
    "def get_relation_facts(triplets, rel):\n",
    "    related_triplets_text = ''\n",
    "    gid = 1\n",
    "    for tri in triplets:\n",
    "        if rel2sym[tri[1]] == rel:\n",
    "            # related_triplets_text += tri[0] + ' is ' + rel2sym[tri[1]] + ' of ' + tri[2] + '. '\n",
    "            related_triplets_text += 'G' + str(gid) + ': ' + rel2sym[tri[1]] + '(' + tri[0] + ', ' + tri[2] + ')\\n'\n",
    "            gid += 1\n",
    "\n",
    "    return related_triplets_text, gid\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# based on rule_length to output the template like $\\_(a,b) \\land \\_(b,c) \\land \\_(a)$ \n",
    "\n",
    "\n",
    "template = dict()\n",
    "\n",
    "for key in rule_length:\n",
    "    \n",
    "    template[key] = '$'\n",
    "    template[key] += get_prefix(rule_length[key])\n",
    "    ent_h = ord('A')\n",
    "    for i in range(rule_length[key]):\n",
    "        ent_h = ent_h\n",
    "        ent_t = ent_h + 1\n",
    "        \n",
    "        template[key] += '##(' + chr(ent_h) + ', ' + chr(ent_t) + ') \\land '\n",
    "        ent_h = ent_t \n",
    "    template[key] += '++(A) \\\\rightarrow ' + key[1:-1] + '(A, ' + chr(ent_t) + ')$'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # based on rule_length to output the template like $\\_(a,b) \\land \\_(b,c) \\land \\_(a)$ \n",
    "# template = dict()\n",
    "\n",
    "# for key in rule_length:\n",
    "#     template[key] = 'If '\n",
    "#     ent_h = ord('A')\n",
    "#     for i in range(rule_length[key]):\n",
    "#         ent_h = ent_h\n",
    "#         ent_t = ent_h + 1\n",
    "#         template[key] += chr(ent_h) + ' is ## of ' + chr(ent_t) + ' and '\n",
    "#         # template[key] += '##(' + chr(ent_h) + ',' + chr(ent_t) + ') \\land '\n",
    "#         ent_h = ent_t \n",
    "#     template[key] += 'A is ++, then A is ' + key + ' of ' + chr(ent_t) + '.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # read rule_latext.txt and get the latex format of each rule\n",
    "# rule_latex = dict()\n",
    "# id = 1\n",
    "# with open('../rule_latex.txt', 'r') as f:\n",
    "#     for line in f:\n",
    "        \n",
    "#         rule_latex[rel2sym[infer_rel[id]]] = line.strip()\n",
    "#         id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "dir = 'logs/first_order_standard'\n",
    "if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "# model = \"gpt-3.5-turbo\"\n",
    "model = \"gpt-4\"\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "for i in range(0, 10):\n",
    "    \n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "    class_text = ''\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\"+str(i)+\".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\"+str(i)+\".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "\n",
    "    # path_rule = '../rules_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent,eid, id2ent,ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)\n",
    "    \n",
    "    triplets, entpair2rel, text, fid = read_all_triplets(path_rel1, path_rel2, id2ent, text, fid)\n",
    "\n",
    "    \n",
    "    record_flag = False\n",
    "    id = 0\n",
    "    scores = 0\n",
    "    for rel_origin in tqdm(infer_rel[1:]):\n",
    "        # print(rel)\n",
    "        rel = rel2sym[rel_origin]\n",
    "        results = template.copy()\n",
    "        for i in range(rule_length[rel]):\n",
    "                a = random.sample(['r1', 'r45'], 1)[0]\n",
    "                results = results.replace('##', a, 1)\n",
    "        # randomly select 'r43' and 'r44' to fill in the template\n",
    "        b = random.sample(['r43', 'r44'], 1)[0]\n",
    "        results = results.replace('++', b , 1)\n",
    "        logger.info(\"template: \" + template[rel])\n",
    "        logger.info('prediction: '+ results)\n",
    "        logger.info(\"grounding_truth: \" + grounding_truth[rel])\n",
    "        \n",
    "        if grounding_truth[rel] in results:\n",
    "                logger.info(\"correct\")\n",
    "                scores += 1\n",
    "        logger.info(\"============================================================\")\n",
    "        id += 1\n",
    "\n",
    "           \n",
    "        \n",
    "                        \n",
    "                            \n",
    "                \n",
    "\n",
    "    logger.info(\"accuracy: \" + str(scores/id))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
