{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "openai.api_key = \"\"\n",
    "\n",
    "\n",
    "def load_openai_keys(path='../openai_keys.txt'):\n",
    "    \"\"\"Read OpenAI API keys, one per line, from `path`.\n",
    "\n",
    "    Each non-blank line may have the form `prefix----key`; only the text after\n",
    "    the last '----' separator is kept. Blank lines are skipped so they do not\n",
    "    end up in the key rotation as empty keys.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: the keys, in file order.\n",
    "    \"\"\"\n",
    "    keys = []\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if not line:\n",
    "                continue  # a blank line would otherwise yield an empty key\n",
    "            keys.append(line.split('----')[-1])\n",
    "    return keys\n",
    "# Pool of API keys, shuffled once so runs do not always start on the same key.\n",
    "openai_api_keys = load_openai_keys()\n",
    "random.shuffle(openai_api_keys)\n",
    "\n",
    "\n",
    "def update_key():\n",
    "    \"\"\"Install the key at the front of the pool and move it to the back (round-robin).\n",
    "\n",
    "    Raises IndexError when the pool is empty.\n",
    "    \"\"\"\n",
    "    # pop(0) takes the same element that `openai_api_keys[0]` would name.\n",
    "    next_key = openai_api_keys.pop(0)\n",
    "    openai.api_key = next_key\n",
    "    openai_api_keys.append(next_key)\n",
    "\n",
    "# Variable names used, in order, when rendering rule templates.\n",
    "letters = list('XYZWVUT')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary: the relations listed in ../symbolic_tree/1.relations,\n",
    "# then `extra_relations`, each given a consecutive integer id.\n",
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()  # NOTE(review): never filled or read in this notebook\n",
    "relation_txt = ''   # NOTE(review): built below but never read in this notebook\n",
    "\n",
    "\n",
    "def _nonblank_lines(path, sep=None):\n",
    "    \"\"\"Yield `line.strip().split(sep)` for every non-blank line of `path`.\n",
    "\n",
    "    Blank (e.g. trailing) lines are skipped; they would otherwise make the\n",
    "    tuple unpacking below fail with a ValueError.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as fh:\n",
    "        for raw in fh:\n",
    "            stripped = raw.strip()\n",
    "            if stripped:\n",
    "                yield stripped.split(sep)\n",
    "\n",
    "\n",
    "infer_rel = list()\n",
    "for _, rel in _nonblank_lines(\"../symbolic_tree/1.relations\"):\n",
    "    infer_rel.append(rel)\n",
    "\n",
    "    relation_txt += rel + ', '\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rel in extra_relations:\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "\n",
    "# rel2sym.txt: \"<relation> <symbol>\" per line.\n",
    "for rel, sym in _nonblank_lines(\"rel2sym.txt\"):\n",
    "    rel2sym[rel] = sym\n",
    "    # rel2sym[rel] = '$' + sym + '$'\n",
    "\n",
    "# latex_rules.txt: \"<key> TAB <LaTeX rule>\" per line. The evaluation loop looks\n",
    "# rules up by relation symbol (rel2sym[...]) and treats them as ground truth.\n",
    "grounding_truth = dict()\n",
    "for rel, rule in _nonblank_lines('latex_rules.txt', '\\t'):\n",
    "    grounding_truth[rel] = rule\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantifier prefix for a rule template, e.g. get_prefix(2) -> \"\\forall X,Y,Z: \"\n",
    "def get_prefix(length):\n",
    "    \"\"\"Return the LaTeX universal-quantifier prefix binding max(length, 0) + 1 variables.\n",
    "\n",
    "    Variables are taken from `letters` in order; IndexError is raised when more\n",
    "    are requested than `letters` holds.\n",
    "    \"\"\"\n",
    "    last = max(length, 0)  # index of the final variable (0 when length <= 0)\n",
    "    names = [letters[k] for k in range(last + 1)]\n",
    "    return '\\\\forall ' + ','.join(names) + ': '\n",
    "\n",
    "# Body length of each rule, keyed by the head relation's symbol (via rel2sym).\n",
    "# rule_tab.txt is tab-separated with the head relation name first; the stored\n",
    "# length is the number of remaining columns minus one (presumably the last\n",
    "# column is not a binary body atom -- TODO confirm against the file format).\n",
    "rule_length = dict()\n",
    "with open('rule_tab.txt', 'r') as rule_file:\n",
    "    for row in rule_file:\n",
    "        head, *rest = row.strip().split('\\t')\n",
    "        rule_length[rel2sym[head]] = len(rest) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Return a logger that writes formatted records to `filename`.\n",
    "\n",
    "    Args:\n",
    "        filename: log file path; opened in \"w\" mode, so it is truncated.\n",
    "        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.\n",
    "        name: passed to logging.getLogger (None selects the root logger).\n",
    "\n",
    "    Returns:\n",
    "        The logger, with any previous handlers removed and closed and a\n",
    "        single FileHandler attached.\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove and close existing handlers. Iterate over a copy: removing from\n",
    "    # `logger.handlers` while iterating it skips every other handler, so\n",
    "    # re-running this cell would otherwise leave stale handlers behind\n",
    "    # (duplicate log lines, open file descriptors).\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Assign ids, starting at `eid`, to entities in `path` not yet in `ent2id`.\n",
    "\n",
    "    Every line must contain exactly two whitespace-separated fields; the second\n",
    "    is the entity name. `id2ent` and `ent2id` are updated in place.\n",
    "\n",
    "    Returns:\n",
    "        int: the next unused id.\n",
    "    \"\"\"\n",
    "    next_id = eid\n",
    "    with open(path, 'r') as handle:\n",
    "        for record in handle:\n",
    "            # split() with no argument already ignores surrounding whitespace.\n",
    "            _, name = record.split()\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            id2ent[next_id] = name\n",
    "            ent2id[name] = next_id\n",
    "            next_id += 1\n",
    "    return next_id\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text, fid):\n",
    "    \"\"\"Load base facts from `path1` and inferred facts from `path2`.\n",
    "\n",
    "    Both files hold four whitespace-separated fields per line:\n",
    "    `flag head_id relation_id tail_id`.\n",
    "\n",
    "    * `path1`: every line is kept (the flag is ignored) and rendered into\n",
    "      `text` twice, as `F<fid>: <symbol>(head, tail)` and as the inverse\n",
    "      `F<fid+1>: <inverse symbol>(tail, head)`, using `rel2sym`.\n",
    "    * `path2`: only lines flagged '+' are kept; they are not rendered.\n",
    "\n",
    "    Returns:\n",
    "        (triplets, entpair2rel, text, fid): the (head, relation, tail) name\n",
    "        triples, a dict from (head, tail) to relation name (later lines win),\n",
    "        the extended `text`, and the next unused fact id.\n",
    "    \"\"\"\n",
    "    triplets = list()\n",
    "    entpair2rel = dict()\n",
    "\n",
    "    with open(path1, 'r') as base_file:\n",
    "        for record in base_file:\n",
    "            _flag, h, r, t = record.strip().split()\n",
    "            head, relation, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            triplets.append((head, relation, tail))\n",
    "            entpair2rel[(head, tail)] = relation\n",
    "\n",
    "            forward = rel2sym[relation] + '(' + head + ', ' + tail + ')'\n",
    "            backward = rel2sym['inverse_' + relation] + '(' + tail + ', ' + head + ')'\n",
    "            text += 'F' + str(fid) + ': ' + forward + '\\n'\n",
    "            text += 'F' + str(fid + 1) + ': ' + backward + '\\n'\n",
    "            fid += 2\n",
    "\n",
    "    with open(path2, 'r') as inferred_file:\n",
    "        for record in inferred_file:\n",
    "            flag, h, r, t = record.strip().split()\n",
    "            if flag != '+':\n",
    "                continue  # ids on non-'+' lines are never looked up\n",
    "            head, relation, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            entpair2rel[(head, tail)] = relation\n",
    "            triplets.append((head, relation, tail))\n",
    "\n",
    "    return triplets, entpair2rel, text, fid\n",
    "\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    \"\"\"Read one gender line per entity from `path` and render it as a fact.\n",
    "\n",
    "    Line k describes entity `id2ent[cid + k]` and holds two whitespace-separated\n",
    "    flags `<female> <male>`; the entity is female when the first flag is '1'\n",
    "    and male otherwise (the second flag is not consulted).\n",
    "\n",
    "    `ent2class` is updated in place with 'female' or 'male' for every entity.\n",
    "\n",
    "    Returns:\n",
    "        (cid, class_text, fid): the id after the last entity read, `class_text`\n",
    "        extended with one `F<fid>: <gender symbol>(<entity>)` line per entity,\n",
    "        and the next unused fact id.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, _male = line.strip().split()\n",
    "            gender = 'female' if female == '1' else 'male'\n",
    "            entity = id2ent[cid]\n",
    "            # Record both genders so ent2class covers every entity, not only males.\n",
    "            ent2class[entity] = gender\n",
    "            class_text += 'F' + str(fid) + ': ' + rel2sym[gender] + '(' + entity + ')\\n'\n",
    "            fid += 1\n",
    "            cid += 1\n",
    "    return cid, class_text, fid\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# def read_rules(path):\n",
    "#     rules = list()\n",
    "#     grounding_truth = dict()\n",
    "#     rel2rules = dict()\n",
    "#     with open(path, 'r') as f:\n",
    "#         for line in f:\n",
    "#             lst = line.strip().split('\\t')\n",
    "#             rules.append(lst)\n",
    "#             if lst[0] not in rel2rules:\n",
    "#                 rel2rules[lst[0]] = list() \n",
    "#             grounding_truth[lst[0]].append(lst) \n",
    "#     return rules, rel2rules\n",
    "\n",
    "def get_relation_facts(triplets, rel):\n",
    "    \"\"\"Render the triplets whose relation maps to symbol `rel` as `G<k>` facts.\n",
    "\n",
    "    Args:\n",
    "        triplets: sequence of (head, relation_name, tail) items.\n",
    "        rel: relation symbol (a value of `rel2sym`) to select.\n",
    "\n",
    "    Returns:\n",
    "        (text, gid): one `G<k>: <rel>(head, tail)` line per match, numbered\n",
    "        from 1, and the next unused G index (1 when nothing matched).\n",
    "    \"\"\"\n",
    "    fact_lines = []\n",
    "    for triple in triplets:\n",
    "        symbol = rel2sym[triple[1]]\n",
    "        if symbol != rel:\n",
    "            continue\n",
    "        index = len(fact_lines) + 1\n",
    "        fact_lines.append('G' + str(index) + ': ' + symbol + '(' + triple[0] + ', ' + triple[2] + ')\\n')\n",
    "    return ''.join(fact_lines), len(fact_lines) + 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one LaTeX rule template per head-relation symbol, e.g. for length 2:\n",
    "#   $\\forall X,Y,Z: ##(X, Y) \\land ##(Y, Z) \\land ++(X) \\rightarrow <head>(X, Z)$\n",
    "# '##' marks a binary body atom and '++' a unary atom on X; the prompt asks the model to fill them in.\n",
    "template = dict()\n",
    "\n",
    "for key, length in rule_length.items():\n",
    "    # `length` binary atoms over consecutive variables: (X, Y), (Y, Z), ...\n",
    "    chain = ''.join(\n",
    "        '##(' + letters[k] + ', ' + letters[k + 1] + ') \\\\land '\n",
    "        for k in range(length)\n",
    "    )\n",
    "    # Last variable of the chain, computed per key so a zero-length rule can\n",
    "    # never pick up a variable left over from a previous key.\n",
    "    last_var = letters[max(length, 0)]\n",
    "    template[key] = (\n",
    "        '$' + get_prefix(length) + chain\n",
    "        + '++(X) \\\\rightarrow ' + key + '(X, ' + last_var + ')$'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dict2str(d):\n",
    "    \"\"\"Join a mapping into `key<TAB>value<CR>` records, in iteration order.\n",
    "\n",
    "    NOTE(review): records end with a carriage return, not a newline, which may\n",
    "    display as overwritten text when the log is viewed in a terminal -- confirm intent.\n",
    "    \"\"\"\n",
    "    return ''.join(str(k) + '\\t' + str(d[k]) + '\\r' for k in d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # based on rule_length to output the template like $\\_(a,b) \\land \\_(b,c) \\land \\_(a)$ \n",
    "# template = dict()\n",
    "\n",
    "# for key in rule_length:\n",
    "#     template[key] = 'If '\n",
    "#     ent_h = ord('A')\n",
    "#     for i in range(rule_length[key]):\n",
    "#         ent_h = ent_h\n",
    "#         ent_t = ent_h + 1\n",
    "#         template[key] += chr(ent_h) + ' is ## of ' + chr(ent_t) + ' and '\n",
    "#         # template[key] += '##(' + chr(ent_h) + ',' + chr(ent_t) + ') \\land '\n",
    "#         ent_h = ent_t \n",
    "#     template[key] += 'A is ++, then A is ' + key + ' of ' + chr(ent_t) + '.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # read rule_latext.txt and get the latex format of each rule\n",
    "# rule_latex = dict()\n",
    "# id = 1\n",
    "# with open('../rule_latex.txt', 'r') as f:\n",
    "#     for line in f:\n",
    "        \n",
    "#         rule_latex[rel2sym[infer_rel[id]]] = line.strip()\n",
    "#         id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate rule induction on each of the ten symbolic trees (0..9).\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/first_order_standard'\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "model = \"gpt-3.5-turbo\"\n",
    "# model = \"gpt-4\"\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "MAX_API_ATTEMPTS = 10  # API calls per relation before that relation is skipped\n",
    "\n",
    "for i in range(0, 10):\n",
    "    # Fresh per-tree state.\n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    class_text = ''\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "\n",
    "    eid = read_entity(path_ent, eid, id2ent, ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)\n",
    "    triplets, entpair2rel, text, fid = read_all_triplets(path_rel1, path_rel2, id2ent, text, fid)\n",
    "\n",
    "    n_answered = 0  # relations for which the API returned a prediction\n",
    "    scores = 0      # predictions that contain the ground-truth rule\n",
    "    for rel_origin in tqdm(infer_rel[1:]):\n",
    "        rel = rel2sym[rel_origin]\n",
    "        # NOTE(review): these G facts are computed but not included in the active prompt.\n",
    "        relation_specific_text, gid = get_relation_facts(triplets, rel)\n",
    "\n",
    "        # A missing template or ground truth is a data problem, not a server error:\n",
    "        # skip it up front instead of burning retries and API calls on it.\n",
    "        if rel not in template or rel not in grounding_truth:\n",
    "            logger.warning('skipping ' + rel + ': no template or grounding truth')\n",
    "            continue\n",
    "\n",
    "        # NOTE(review): the system prompt says 'parent'/'child' but the user prompt says 'parentOf'/'childOf'.\n",
    "        message = {\n",
    "            'system': \"You are a helpful assistant with inductive reasoning abilities. Please generate one single rule to match the template and logically entail the facts. Note that the symbol '##' in the template should be filled with either 'parent' or 'child', while the symbol '++' should be filled with either 'male' or 'female'.\",\n",
    "            'user': \"I will give you the template of logical rule and facts. You only need to output the logical rule that entails all given facts and matches the template.\\nThe facts are: \" + class_text + text + \"\\nThe template is: \" + template[rel] + \"\\nNote that '##' should be replaced with 'parentOf' or 'childOf' while '++' should be replaced with 'male' or 'female'.\\nAfter replacing the special '##' and '++', the logical rule is: \",\n",
    "        }\n",
    "\n",
    "        # Retry only the network call; each attempt rotates to the next API key.\n",
    "        results = None\n",
    "        for attempt in range(MAX_API_ATTEMPTS):\n",
    "            try:\n",
    "                update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message['system']},\n",
    "                        {\"role\": \"user\", \"content\": message['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                break\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "                logger.warning('API attempt ' + str(attempt + 1) + ' failed for ' + rel + ': ' + str(e))\n",
    "        if results is None:\n",
    "            logger.warning('skipping ' + rel + ': no response after ' + str(MAX_API_ATTEMPTS) + ' attempts')\n",
    "            continue\n",
    "\n",
    "        logger.info('message: \\n' + dict2str(message))\n",
    "        logger.info(\"template: \" + template[rel])\n",
    "        logger.info('prediction: ' + results)\n",
    "        logger.info(\"grounding_truth: \" + grounding_truth[rel])\n",
    "\n",
    "        if grounding_truth[rel] in results:\n",
    "            logger.info(\"correct\")\n",
    "            scores += 1\n",
    "        logger.info(\"============================================================\")\n",
    "        n_answered += 1\n",
    "\n",
    "    # Accuracy is over answered relations only; skipped relations are logged above.\n",
    "    if n_answered:\n",
    "        logger.info(\"accuracy: \" + str(scores / n_answered))\n",
    "    else:\n",
    "        logger.warning(\"accuracy: undefined (no relation was answered)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
