{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "# prompts = generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts,labels)\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split('----')\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()\n",
    "relation_txt = ''\n",
    "\n",
    "infer_rel = list()\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "\n",
    "        infer_rel.append(rel)\n",
    "\n",
    "        relation_txt += rel + ', '\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rel in extra_relations:\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "    \n",
    "with open(\"rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym_2[rel] = sym\n",
    "        rel2sym[rel] = '$' + sym + '$'\n",
    "# get \"\\forall A,B,C,D: \"\n",
    "def get_prefix(length):\n",
    "    h = ord('A')\n",
    "    text = '\\\\forall '\n",
    "    for _ in range(length):\n",
    "        text += chr(h) + ', '\n",
    "        h += 1\n",
    "    text += chr(h) + ': '\n",
    "    return text\n",
    "\n",
    "# from rule_parent.txt get the length of each rule\n",
    "rule_length = dict()\n",
    "rel2rule_body = dict()\n",
    "with open('rule_tab.txt', 'r') as f:\n",
    "    for line in f:\n",
    "        lst = line.strip().split('\\t')\n",
    "        rel2rule_body[rel2sym[lst[0]]] = lst[1:]\n",
    "        rule_length[rel2sym[lst[0]]] = len(lst[1:]) - 1"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "grounding_truth = dict()\n",
    "\n",
    "with open('rule_symbolic_first_order.txt','r') as f:\n",
    "    for line in f:\n",
    "        rel, rule = line.strip().split('\\t')\n",
    "       \n",
    "        grounding_truth['$'+rel+'$'] = rule\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id ):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, ent = line.strip().split()\n",
    "            if ent not in ent2id:\n",
    "                id2ent[eid] = ent\n",
    "                ent2id[ent] = eid\n",
    "                eid += 1\n",
    "                \n",
    "        return eid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text, fid):\n",
    "    triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    test_triplets = list()\n",
    "    with open(path1,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            triplets.append((id2ent[int(t)], 'inverse_' + id2rel[int(r)], id2ent[int(h)]))\n",
    "\n",
    "            # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "            entpair2rel[(id2ent[int(t)], id2ent[int(h)])] = 'inverse_' + id2rel[int(r)]\n",
    "\n",
    "            # text += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "            # fid += 1\n",
    "            # text += 'F' + str(fid) + ': ' + rel2sym['inverse_' + id2rel[int(r)]] + '(' + id2ent[int(t)] + ', ' + id2ent[int(h)] + ')\\n'\n",
    "            # fid += 1\n",
    "\n",
    "\n",
    "            # text += id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "            # text += id2ent[int(t)] + ' is ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "\n",
    "    with open(path2,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "                # triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # text += id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "                # text += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "                # fid += 1\n",
    "                # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "\n",
    "                # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "    return triplets, test_triplets, entpair2rel, text, fid\n",
    "\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                # ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                ent2class[id2ent[cid]] = 'female'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is ' + rel2sym[\"female\"] + '. '\n",
    "                \n",
    "                # class_text += 'F' + str(fid) + ': ' + rel2sym[\"female\"] + '(' + id2ent[cid] + ')\\n'\n",
    "                # fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                # ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is '+ rel2sym['male'] + '. '\n",
    "                # class_text += 'F' + str(fid) + ': ' + rel2sym[\"male\"] + '(' + id2ent[cid] + ')\\n'\n",
    "                # fid += 1\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is the '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "\n",
    "            cid += 1\n",
    "        return cid, class_text, fid\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# def read_rules(path):\n",
    "#     rules = list()\n",
    "#     grounding_truth = dict()\n",
    "#     rel2rules = dict()\n",
    "#     with open(path, 'r') as f:\n",
    "#         for line in f:\n",
    "#             lst = line.strip().split('\\t')\n",
    "#             rules.append(lst)\n",
    "#             if lst[0] not in rel2rules:\n",
    "#                 rel2rules[lst[0]] = list() \n",
    "#             grounding_truth[lst[0]].append(lst) \n",
    "#     return rules, rel2rules\n",
    "\n",
    "def get_paths_triplets(h, t, G, entpair2rel, rule_body, ent2class, rel2sym, fid):\n",
    "    input_text = ''\n",
    "    input_text += 'F' + str(fid) + ': ' + rel2sym[ent2class[h]] + '(' + h + ')\\n'\n",
    "    # input_text += rel2sym[ent2class[h]] + '(' + h + ')\\n'\n",
    "    \n",
    "    fid += 1\n",
    "    for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=len(rule_body)-1)):\n",
    "        rid = 0\n",
    "        \n",
    "        for edge in path:\n",
    "            if entpair2rel[(edge[0], edge[1])] == rule_body[rid]:\n",
    "                input_text += 'F' + str(fid) + ': ' + rel2sym[rule_body[rid]] + '(' + edge[0] + ', ' + edge[1] + ')\\n'\n",
    "                # input_text += rule_body[rid] + '(' + edge[0] + ', ' + edge[1] + ')\\n'\n",
    "                \n",
    "                fid += 1\n",
    "                rid += 1\n",
    "            # # print(edge)\n",
    "            # if edge in entpair2rel:\n",
    "                \n",
    "            #     input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "            # else:\n",
    "            #     input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "    return input_text, fid\n",
    "\n",
    "def get_relation_facts(test_triplets, rel, G, entpair2rel, rule_body, ent2class, rel2sym, fid, gid):\n",
    "    related_triplets_text = ''\n",
    "    \n",
    "    for tri in test_triplets:\n",
    "        if rel2sym[tri[1]] == rel:\n",
    "            # related_triplets_text += tri[0] + ' is ' + rel2sym[tri[1]] + ' of ' + tri[2] + '. '\n",
    "            related_triplets_text += 'G' + str(gid) + ': ' + rel2sym[tri[1]] + '(' + tri[0] + ', ' + tri[2] + ')\\n'\n",
    "            # related_triplets_text += 'F' + str(fid) + ': ' + rel2sym[tri[1]] + '(' + tri[0] + ', ' + tri[2] + ')\\n'\n",
    "            \n",
    "            # related_triplets_text += rel2sym[tri[1]] + '(' + tri[0] + ', ' + tri[2] + ')\\n'\n",
    "            gid += 1\n",
    "            # fid += 1\n",
    "            text, fid = get_paths_triplets(tri[0], tri[2], G, entpair2rel, rule_body, ent2class, rel2sym, fid)\n",
    "            \n",
    "            related_triplets_text += text\n",
    "    return related_triplets_text, fid, gid\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# based on rule_length to output the template like $\\_(a,b) \\land \\_(b,c) \\land \\_(a)$ \n",
    "\n",
    "\n",
    "template = dict()\n",
    "\n",
    "for key in rule_length:\n",
    "    \n",
    "    template[key] = '$'\n",
    "    template[key] += get_prefix(rule_length[key])\n",
    "    ent_h = ord('A')\n",
    "    for i in range(rule_length[key]):\n",
    "        ent_h = ent_h\n",
    "        ent_t = ent_h + 1\n",
    "        \n",
    "        template[key] += '##(' + chr(ent_h) + ', ' + chr(ent_t) + ') \\land '\n",
    "        ent_h = ent_t \n",
    "    template[key] += '++(A) \\\\rightarrow ' + key[1:-1] + '(A, ' + chr(ent_t) + ')$'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # based on rule_length to output the template like $\\_(a,b) \\land \\_(b,c) \\land \\_(a)$ \n",
    "# template = dict()\n",
    "\n",
    "# for key in rule_length:\n",
    "#     template[key] = 'If '\n",
    "#     ent_h = ord('A')\n",
    "#     for i in range(rule_length[key]):\n",
    "#         ent_h = ent_h\n",
    "#         ent_t = ent_h + 1\n",
    "#         template[key] += chr(ent_h) + ' is ## of ' + chr(ent_t) + ' and '\n",
    "#         # template[key] += '##(' + chr(ent_h) + ',' + chr(ent_t) + ') \\land '\n",
    "#         ent_h = ent_t \n",
    "#     template[key] += 'A is ++, then A is ' + key + ' of ' + chr(ent_t) + '.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # read rule_latext.txt and get the latex format of each rule\n",
    "# rule_latex = dict()\n",
    "# id = 1\n",
    "# with open('../rule_latex.txt', 'r') as f:\n",
    "#     for line in f:\n",
    "        \n",
    "#         rule_latex[rel2sym[infer_rel[id]]] = line.strip()\n",
    "#         id += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "for i in range(0, 1):\n",
    "    \n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "    class_text = ''\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\"+str(i)+\".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\"+str(i)+\".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "\n",
    "    # path_rule = '../rules_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent,eid, id2ent, ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)\n",
    "    \n",
    "    triplets, test_triplets, entpair2rel, text, fid = read_all_triplets(path_rel1, path_rel2, id2ent, text, fid)\n",
    "\n",
    "    # rules, rel2rules = read_rules(path_rule)\n",
    "\n",
    "    # test_questions = random.sample(triplets, int(len(triplets) * 0.2))\n",
    "    # train_questions = triplets.copy()\n",
    "    # for t in test_questions:\n",
    "    #     train_questions.remove(t)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## communicate with LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "dir = 'logs/first_order_filter_cot/'\n",
    "if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "# model = \"gpt-3.5-turbo\"\n",
    "model = \"gpt-4\"\n",
    "logging.info('model: ' + model)\n",
    "\n",
    "record_flag = False\n",
    "\n",
    "\n",
    "id = 0\n",
    "scores = 0\n",
    "\n",
    "G = nx.DiGraph()\n",
    "for tri in triplets:\n",
    "        G.add_edge(tri[0],tri[2])\n",
    "\n",
    "for rel_origin in tqdm(infer_rel[1:]):\n",
    "        # print(rel)\n",
    "        rel = rel2sym[rel_origin]\n",
    "        fid = 1\n",
    "        gid = 1\n",
    "        relation_specific_text, fid, gid = get_relation_facts(test_triplets, rel, G, entpair2rel, rel2rule_body[rel], ent2class, rel2sym, fid, gid)\n",
    "        server_flag = 0\n",
    "        server_error_cnt = 0\n",
    "        \n",
    "        while server_error_cnt < 10:\n",
    "                try:    \n",
    "                        \n",
    "                        # message = {\n",
    "                        #                 'system': \"You are a helpful assistant with inductive reasoning abilities.\",\n",
    "                        #                 'user': \"I will give you a set of facts F1 to F\" + str(fid-1) + \", facts G1 to G\"+ str(gid-1) +\" and a template for a logical rule. Please generate one single rule to match the template and logically entail the facts G1 to G\" + str(gid-1) + \" based on facts F1 to F\" + str(fid) + \".\\nFacts: \" + relation_specific_text + '\\nTemplate: ' + template[rel]\n",
    "                        #                 + \"\\nNote that the symbol '##' in the template should be filled with either 'yufevh' or 'jqabh', while the symbol '++' should be filled with either 'audz' or 'ntoea'.\\nAfter filling in the template, the predicted rule is: \",\n",
    "                        #                 }\n",
    "                        message = {\n",
    "                                'system': \"You are a helpful assistant with inductive reasoning abilities. I will give you a set of facts and a template for a logical rule. Please generate one single rule to match the template and logically entail the facts.\",\n",
    "                                'user': \"I will give you a set of facts and a template for a logical rule. Please generate one single rule to match the template and logically entail the facts.\\nFacts: \" + relation_specific_text + '\\nTemplate: ' + template[rel]\n",
    "                                + \"\\nNote that the symbol '##' in the template should be filled with either 'r1' or 'r45', while the symbol '++' should be filled with either 'r43' or 'r44'.\\nThe predicted rule is: Let's think step by step. \",\n",
    "                                }\n",
    "                        response = openai.ChatCompletion.create(\n",
    "                                model= model,\n",
    "                                messages=[\n",
    "                                        {\"role\": \"system\", \"content\": message['system']},\n",
    "                                        {\"role\": \"user\", \"content\": message['user']},\n",
    "                                ],\n",
    "                                temperature=0,\n",
    "                                )\n",
    "                        results = response['choices'][0]['message']['content']\n",
    "\n",
    "                        if record_flag == False:\n",
    "\n",
    "                                logger.info('message: \\n' + dict2str(message)) \n",
    "                                record_flag = True\n",
    "                        \n",
    "                        logger.info(\"template: \" + template[rel])\n",
    "\n",
    "                        logging.info('prediction' + \": \"+ results)\n",
    "                        logger.info(\"grounding_truth: \" + grounding_truth[rel])\n",
    "                        \n",
    "                        if grounding_truth[rel] in results:\n",
    "                                logger.info(\"correct\")\n",
    "                                scores += 1\n",
    "                        logger.info(\"============================================================\")\n",
    "                        id += 1\n",
    "                        break\n",
    "                        \n",
    "                except Exception as e:\n",
    "                        server_error_cnt += 1\n",
    "                        print(e)\n",
    "\n",
    "logger.info(\"accuracy: \" + str(scores/id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
