{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "def load_openai_keys(path='../openai_keys_filter.txt'):\n",
    "    \"\"\"Read OpenAI API keys from a whitespace-delimited text file.\n",
    "\n",
    "    The key is the last whitespace-separated token on each line, so any leading\n",
    "    fields on a line are ignored. Blank lines are skipped.\n",
    "\n",
    "    Args:\n",
    "        path: file to read; defaults to the original hard-coded location.\n",
    "\n",
    "    Returns:\n",
    "        list of key strings in file order.\n",
    "    \"\"\"\n",
    "    keys = []\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            fields = line.strip().split()\n",
    "            if not fields:\n",
    "                # Blank line: previously raised IndexError on fields[-1].\n",
    "                continue\n",
    "            keys.append(fields[-1])\n",
    "    return keys\n",
    "# Load the key pool once and shuffle its order.\n",
    "openai_api_keys = load_openai_keys()\n",
    "random.shuffle(openai_api_keys)\n",
    "def update_key():\n",
    "    \"\"\"Round-robin rotation over ``openai_api_keys``.\n",
    "\n",
    "    The head key becomes ``openai.api_key`` and is moved to the tail, so\n",
    "    successive calls cycle through the pool. IndexError if the pool is empty.\n",
    "    \"\"\"\n",
    "    next_key = openai_api_keys.pop(0)\n",
    "    openai_api_keys.append(next_key)\n",
    "    openai.api_key = next_key\n",
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Create (or reset) a logger that writes to ``filename``.\n",
    "\n",
    "    Args:\n",
    "        filename: log file path; opened in \"w\" mode, so an existing file is truncated.\n",
    "        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.\n",
    "        name: logger name; ``None`` selects the root logger.\n",
    "\n",
    "    Returns:\n",
    "        the logger, with its previous handlers removed and one FileHandler attached.\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers (e.g. left over from re-running this cell).\n",
    "    # Iterate over a copy: removing from logger.handlers while iterating it skips\n",
    "    # every other handler. Close each one so its file is released.\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # To also echo to the terminal, add a logging.StreamHandler with the same formatter.\n",
    "    return logger\n",
    "\n",
    "# Flatten a list into one tab-separated string for logging.\n",
    "def list2str(l):\n",
    "    \"\"\"Concatenate ``str(item)`` for every item of ``l``, each followed by a tab character.\"\"\"\n",
    "    return ''.join(str(item) + '\\t' for item in l)\n",
    "def list_equal(a, answers):\n",
    "    \"\"\"Return True if ``a`` equals any candidate in ``answers`` when both are viewed as sets.\n",
    "\n",
    "    Order and duplicates are ignored; an empty ``answers`` gives False.\n",
    "    \"\"\"\n",
    "    return any(set(candidate) == set(a) for candidate in answers)\n",
    "def dict2str(d):\n",
    "    \"\"\"Render ``d`` as ``key<TAB>value`` entries, each ended by a carriage return.\"\"\"\n",
    "    return ''.join(str(key) + '\\t' + str(d[key]) + '\\r' for key in d)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary. Ids are assigned first to the relations listed in\n",
    "# ../symbolic_tree/1.relations (file order), then to extra_relations.\n",
    "# rel2sym_2 maps a relation name to its symbol from rel2sym_ri.txt and\n",
    "# rel2sym maps it to the same symbol wrapped in '$...$'.\n",
    "# NOTE(review): the experiment cell reads tree-0 data files; presumably the\n",
    "# relation list is identical for every tree — confirm.\n",
    "id2rel, rel2id = dict(), dict()\n",
    "rel2sym, rel2sym_2 = dict(), dict()\n",
    "infer_rel = list()\n",
    "relation_txt = ''\n",
    "\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        infer_rel.append(rel)\n",
    "        relation_txt += rel + ', '\n",
    "\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rid, rel in enumerate(infer_rel + extra_relations):\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "rid = len(id2rel)  # next free relation id\n",
    "\n",
    "with open(\"rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym_2[rel] = sym\n",
    "        rel2sym[rel] = f'${sym}$'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rh2rules: rule head (first tab-separated column of latex_rules.txt) -> rule text\n",
    "# (last column). rule_text lists every rule as 'L<n>: <rule>', one per line.\n",
    "rh2rules = dict()\n",
    "rule_text = ''\n",
    "lid = 1\n",
    "with open('latex_rules.txt','r') as f:\n",
    "    for line in f:\n",
    "        columns = line.strip().split('\\t')\n",
    "        rh2rules[columns[0]] = columns[-1]\n",
    "        rule_text += f\"L{lid}: {columns[-1]}\\n\"\n",
    "        lid += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Assign ids to entity names read from ``path`` that are not yet known.\n",
    "\n",
    "    Every line must split into exactly two whitespace-separated fields; the\n",
    "    second is the entity name. New names get consecutive ids starting at ``eid``\n",
    "    and are recorded in both ``id2ent`` and ``ent2id`` (updated in place).\n",
    "\n",
    "    Returns:\n",
    "        the next unused id.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as handle:\n",
    "        for raw in handle:\n",
    "            _, name = raw.strip().split()\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            id2ent[eid] = name\n",
    "            ent2id[name] = eid\n",
    "            eid += 1\n",
    "    return eid\n",
    "\n",
    "def get_related_triplets(h, t, G, entpair2rel):\n",
    "    \"\"\"Describe the edges on every simple path from ``h`` to ``t`` of at most 5 edges.\n",
    "\n",
    "    Paths from ``nx.all_simple_edge_paths`` are visited in sorted order. Each edge\n",
    "    is written as ``rel(head,tail)`` in the direction stored in ``entpair2rel``\n",
    "    (its own orientation if present, else the reversed pair). An edge shared by\n",
    "    several paths is written once per path.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "        for edge in path:\n",
    "            if edge in entpair2rel:\n",
    "                head, tail = edge[0], edge[1]\n",
    "            else:\n",
    "                head, tail = edge[1], edge[0]\n",
    "            pieces.append(entpair2rel[(head, tail)] + '(' + head + ',' + tail + ')\\n')\n",
    "    return ''.join(pieces)\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text):\n",
    "    \"\"\"Load triplets from ``path1`` plus the ``'+'``-flagged lines of ``path2``.\n",
    "\n",
    "    Both files hold ``flag head_id rel_id tail_id`` lines. Each used line maps\n",
    "    ``(head, tail)`` to the relation name in ``entpair2rel`` (later lines win) and\n",
    "    appends ``rel2sym[rel] + '(head,tail)'`` plus a newline to ``text``. Only\n",
    "    ``path1`` lines are returned as ``(head, rel, tail)`` triplets.\n",
    "\n",
    "    Returns:\n",
    "        (triplets, entpair2rel, text)\n",
    "    \"\"\"\n",
    "    triplets = []\n",
    "    entpair2rel = {}\n",
    "\n",
    "    with open(path1, 'r') as f:\n",
    "        for line in f:\n",
    "            _, h, r, t = line.strip().split()\n",
    "            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            triplets.append((head, rel, tail))\n",
    "            entpair2rel[(head, tail)] = rel\n",
    "            text += rel2sym[rel] + '(' + head + ',' + tail + ')\\n'\n",
    "\n",
    "    with open(path2, 'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag != '+':\n",
    "                continue\n",
    "            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            entpair2rel[(head, tail)] = rel\n",
    "            text += rel2sym[rel] + '(' + head + ',' + tail + ')\\n'\n",
    "\n",
    "    return triplets, entpair2rel, text\n",
    "\n",
    "\n",
    "def read_class(path, cid, ent2class, id2ent, class_text):\n",
    "    \"\"\"Read gender flags for consecutive entity ids starting at ``cid``.\n",
    "\n",
    "    Each line holds two flags ``female male``; a ``'1'`` in the first column marks\n",
    "    the entity female, anything else male. For entity ``name = id2ent[cid]``,\n",
    "    ``ent2class[name]`` is set to ``'female'`` / ``'male'`` and\n",
    "    ``rel2sym[gender] + '(' + name + ')'`` plus a newline is appended to ``class_text``.\n",
    "\n",
    "    Returns:\n",
    "        (next unused cid, updated class_text)\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male = line.strip().split()\n",
    "            name = id2ent[cid]\n",
    "            gender = 'female' if female == '1' else 'male'\n",
    "            # Record the class for both genders; previously only males were added to\n",
    "            # ent2class, so female entities were missing from it.\n",
    "            ent2class[name] = gender\n",
    "            class_text += rel2sym[gender] + '(' + name + ')\\n'\n",
    "            cid += 1\n",
    "    return cid, class_text\n",
    "\n",
    "\n",
    "def read_all_facts(path1, path2, path_class, id2ent, ent2class, cid, text):\n",
    "    \"\"\"Build the numbered fact list (F1, F2, ...) shared by prompts and ground truth.\n",
    "\n",
    "    Inputs:\n",
    "        path1: ``flag head_id rel_id tail_id`` lines; each is a numbered base fact.\n",
    "        path2: same format; only ``'+'`` lines are kept, as test statements.\n",
    "        path_class: ``female male`` flag pairs; line k describes ``id2ent[cid + k]``.\n",
    "\n",
    "    Relations are stored by their ``rel2sym`` symbol. Base facts are numbered first,\n",
    "    then one gender fact per class-file line, continuing the same counter.\n",
    "    ``ent2class`` is updated in place with each entity's gender symbol.\n",
    "\n",
    "    Returns:\n",
    "        triplets: ``(head, rel_name, tail)`` per base fact.\n",
    "        test_triplets: ``(head, rel_symbol, tail)`` per positive path2 line.\n",
    "        entpair2rel: ``(head, tail) -> rel_symbol`` for base facts (later lines win).\n",
    "        cid: next unused entity id.\n",
    "        text: input text plus one ``F<n>: ...`` line per fact.\n",
    "        edges: ``(head, tail)`` of every base fact.\n",
    "        tri2number: ``(head, rel_symbol, tail)`` or ``(name, 'gender', symbol)`` -> ``'F<n>'``.\n",
    "        f_id: next unused fact number.\n",
    "    \"\"\"\n",
    "    f_id = 1\n",
    "    triplets, test_triplets, edges = [], [], []\n",
    "    entpair2rel, tri2number = {}, {}\n",
    "\n",
    "    # Base facts: F1 .. Fk\n",
    "    with open(path1, 'r') as f:\n",
    "        for line in f:\n",
    "            _, h, r, t = line.strip().split()\n",
    "            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            sym = rel2sym[rel]\n",
    "            label = 'F' + str(f_id)\n",
    "            triplets.append((head, rel, tail))\n",
    "            edges.append((head, tail))\n",
    "            entpair2rel[(head, tail)] = sym\n",
    "            text += label + \": \" + sym + '(' + head + ', ' + tail + ')\\n'\n",
    "            tri2number[(head, sym, tail)] = label\n",
    "            f_id += 1\n",
    "\n",
    "    # Positive inferred facts are the statements to explain; they get no F-number.\n",
    "    with open(path2, 'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                test_triplets.append((id2ent[int(h)], rel2sym[id2rel[int(r)]], id2ent[int(t)]))\n",
    "\n",
    "    # Gender facts continue the numbering, one per class-file line.\n",
    "    with open(path_class, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male = line.strip().split()\n",
    "            name = id2ent[cid]\n",
    "            gender_sym = rel2sym['female' if female == '1' else 'male']\n",
    "            label = 'F' + str(f_id)\n",
    "            ent2class[name] = gender_sym\n",
    "            text += label + ': ' + gender_sym + '(' + name + ')\\n'\n",
    "            tri2number[(name, 'gender', gender_sym)] = label\n",
    "            f_id += 1\n",
    "            cid += 1\n",
    "\n",
    "    return triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, f_id\n",
    "\n",
    "def get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules,  tri2number, rule2number):\n",
    "    \"\"\"Compute ground-truth explanations for every test statement.\n",
    "\n",
    "    For a statement ``(h, r, t)``, ``rule = rel2rules[r]`` lists the body relation\n",
    "    symbols followed, as its last element, by the gender symbol ``h`` must have\n",
    "    (compared against ``ent2class[h]``). A simple path ``h -> ... -> t`` in the\n",
    "    undirected graph over ``edges`` explains the statement when it has exactly one\n",
    "    edge per body relation, every edge matches (see ``logical_rules``) and the\n",
    "    gender matches.\n",
    "\n",
    "    Returns:\n",
    "        dict ``(h, r, t) -> list of explanations``; each explanation is the fact\n",
    "        labels along the path, then h's gender fact label, then ``rule2number[r]``.\n",
    "        An empty list means no explanation was found.\n",
    "    \"\"\"\n",
    "\n",
    "    def logical_rules(entpair2rel, path, rule):\n",
    "        \"\"\"Fact labels grounding the body of ``rule`` along ``path``, or None.\"\"\"\n",
    "        # rule[-1] is the head's gender, so the body has len(rule) - 1 relations and\n",
    "        # the path needs exactly that many edges. Shorter or longer paths used to be\n",
    "        # accepted when their edges partially matched.\n",
    "        n_edges = len(path) - 1\n",
    "        if n_edges != len(rule) - 1:\n",
    "            return None\n",
    "        path_number = list()\n",
    "        for i in range(n_edges):\n",
    "            a, b = path[i], path[i + 1]\n",
    "            if (a, b) in entpair2rel and entpair2rel[(a, b)] == rule[i]:\n",
    "                path_number.append(tri2number[(a, entpair2rel[(a, b)], b)])\n",
    "            elif '$r45$' == rule[i] and (b, a) in entpair2rel:\n",
    "                # '$r45$' is grounded by the reversed fact (b -> a), whatever relation it\n",
    "                # stores. NOTE(review): presumably '$r45$' is an inverse relation symbol;\n",
    "                # confirm against rel2sym_ri.txt.\n",
    "                path_number.append(tri2number[(b, entpair2rel[(b, a)], a)])\n",
    "            else:\n",
    "                # Mismatch: previously a stored edge with the wrong relation was\n",
    "                # silently skipped and the path could still be accepted.\n",
    "                return None\n",
    "        return path_number\n",
    "\n",
    "    # Undirected graph over the base facts; logical_rules restores edge direction.\n",
    "    G = nx.Graph()\n",
    "    G.add_edges_from(edges)\n",
    "\n",
    "    fact2explain = dict()\n",
    "    for h, r, t in test_triplets:\n",
    "        rule = rel2rules[r]\n",
    "        all_paths = list()\n",
    "        # Paths whose length differs from the rule body are rejected by logical_rules.\n",
    "        for path in nx.all_simple_paths(G, source=h, target=t, cutoff=len(rule)):\n",
    "            path_number = logical_rules(entpair2rel, path, rule)\n",
    "            if path_number and ent2class[h] == rule[-1]:\n",
    "                path_number.append(tri2number[(h, 'gender', ent2class[h])])\n",
    "                path_number.append(rule2number[r])\n",
    "                all_paths.append(path_number)\n",
    "        fact2explain[(h, r, t)] = all_paths\n",
    "\n",
    "    return fact2explain\n",
    "\n",
    "\n",
    "def read_rules(path, rel2sym):\n",
    "    \"\"\"Parse a tab-separated rule file into symbol-level rules numbered L1, L2, ...\n",
    "\n",
    "    Every field on a line is translated through ``rel2sym``; the first field is\n",
    "    the rule head and the rest are kept, in order, as its rule list.\n",
    "\n",
    "    Returns:\n",
    "        rel2rules: head symbol -> list of the remaining symbols.\n",
    "        rule2number: head symbol -> ``'L<n>'`` (n is the 1-based line number).\n",
    "        l_id: the next unused rule number.\n",
    "    \"\"\"\n",
    "    rel2rules, rule2number = {}, {}\n",
    "    l_id = 1\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            head, *body = [rel2sym[field] for field in line.strip().split('\\t')]\n",
    "            rel2rules[head] = body\n",
    "            rule2number[head] = 'L' + str(l_id)\n",
    "            l_id += 1\n",
    "    return rel2rules, rule2number, l_id\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_related_triplets_t(h, t, G, entpair2rel, ent2class, fid):\n",
    "    \"\"\"Number the facts linking ``h`` to ``t`` for a prompt, starting at ``F<fid>``.\n",
    "\n",
    "    The first fact is the head's class, ``ent2class[h](h)``. Then each edge on the\n",
    "    simple paths from ``h`` to ``t`` with at most 6 edges (visited in sorted order)\n",
    "    is written as ``rel(head,tail)`` in the direction stored in ``entpair2rel``;\n",
    "    each distinct (head, rel, tail) triple is listed only once.\n",
    "\n",
    "    Returns:\n",
    "        (facts text, next unused fact number)\n",
    "    \"\"\"\n",
    "    lines = ['F' + str(fid) + \": \" + ent2class[h] + '(' + h + ')\\n']\n",
    "    fid += 1\n",
    "    seen = set()\n",
    "    for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=6)):\n",
    "        for edge in path:\n",
    "            if edge in entpair2rel:\n",
    "                head, tail = edge[0], edge[1]\n",
    "            else:\n",
    "                head, tail = edge[1], edge[0]\n",
    "            triple = (head, entpair2rel[(head, tail)], tail)\n",
    "            if triple in seen:\n",
    "                continue\n",
    "            seen.add(triple)\n",
    "            lines.append('F' + str(fid) + \": \" + triple[1] + '(' + head + ',' + tail + ')\\n')\n",
    "            fid += 1\n",
    "    return ''.join(lines), fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Abductive-explanation experiment: for each positive test statement the LLM sees\n",
    "# the statement's rule and the facts on paths between its two entities, and must\n",
    "# name the rule and facts that explain it. Answers are scored against the ground\n",
    "# truth from get_explain_grounding_truth and the accuracy is logged.\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/logic_standard_f+r_subgraph'\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "# model = 'gpt-4'\n",
    "model = \"gpt-3.5-turbo\"\n",
    "MAX_RETRIES = 10  # API attempts per statement\n",
    "\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "\n",
    "def map_prompt_labels(selected_facts, rule_label, tri2number):\n",
    "    \"\"\"Translate the labels used in one prompt into ground-truth labels.\n",
    "\n",
    "    Prompts number their facts locally (F1, F2, ...) and show the single rule as L1,\n",
    "    whereas the ground truth uses the global fact numbers in ``tri2number`` and the\n",
    "    rule number ``rule_label``. Fact lines are parsed from the ``F<k>: sym(a,b)`` /\n",
    "    ``F<k>: sym(a)`` format produced by get_related_triplets_t.\n",
    "    \"\"\"\n",
    "    label_map = {'L1': rule_label}\n",
    "    for fact_line in selected_facts.strip().split('\\n'):\n",
    "        label, _, body = fact_line.partition(': ')\n",
    "        pred, _, args = body.partition('(')\n",
    "        args = args.rstrip(')').split(',')\n",
    "        key = (args[0], 'gender', pred) if len(args) == 1 else (args[0], pred, args[1])\n",
    "        if key in tri2number:\n",
    "            label_map[label] = tri2number[key]\n",
    "    return label_map\n",
    "\n",
    "\n",
    "for i in range(0, 1):\n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    path_rule = 'rule_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent, eid, id2ent, ent2id)\n",
    "    triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, fid = read_all_facts(\n",
    "        path_rel1, path_rel2, path_class, id2ent, ent2class, cid, text)\n",
    "    rel2rules, rule2number, lid = read_rules(path_rule, rel2sym)\n",
    "    fact2explain = get_explain_grounding_truth(\n",
    "        test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number)\n",
    "\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "    num = 0\n",
    "    record_flag = False\n",
    "\n",
    "    G = nx.Graph()\n",
    "    for tri in triplets:\n",
    "        G.add_edge(tri[0], tri[2])\n",
    "\n",
    "    for h, r, t in tqdm(test_triplets):\n",
    "        text_pred = r + '(' + h + ', ' + t + ')'\n",
    "        # Number facts per prompt from F1 so the \"F1 to F<n>\" wording is accurate\n",
    "        # (previously the counter carried over from read_all_facts and from earlier\n",
    "        # statements), then map the local labels back to ground-truth labels.\n",
    "        selected_facts, fid = get_related_triplets_t(h, t, G, entpair2rel, ent2class, 1)\n",
    "        label_map = map_prompt_labels(selected_facts, rule2number[r], tri2number)\n",
    "\n",
    "        message = {\n",
    "            'system': \"You are a helpful assistant with abductive reasoning abilities. Please select one single logical rule and a few facts to explain the following statement. \",\n",
    "            'user': \"I will provide one single logical rules and facts F1 to F\" + str(fid - 1) + \". Please select a few facts from F1 to F\" + str(fid - 1) + \" to explain the following statement. \\nRules:\\nL1: \" + rh2rules[r[1:-1]]\n",
    "            + \"\\nFacts:\\n\" + selected_facts + \"\\nStatement: \" + text_pred + \"\\nAnswer with the numbers of the selected rule and facts. The selected rule and facts are: \",\n",
    "        }\n",
    "\n",
    "        # Retry API errors, rotating keys; response stays None if every attempt fails.\n",
    "        response = None\n",
    "        for _attempt in range(MAX_RETRIES):\n",
    "            try:\n",
    "                update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message['system']},\n",
    "                        {\"role\": \"user\", \"content\": message['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                break\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "\n",
    "        if not record_flag:\n",
    "            logger.info('message: \\n' + dict2str(message))\n",
    "            record_flag = True\n",
    "\n",
    "        num += 1\n",
    "        logger.info(\"statement: \" + text_pred)\n",
    "        if response is None:\n",
    "            # Previously a failed call reused the previous statement's response\n",
    "            # (or raised NameError on the first statement).\n",
    "            logger.info(\"results: <no response after %d attempts>\", MAX_RETRIES)\n",
    "            false_num += 1\n",
    "            continue\n",
    "\n",
    "        results = response['choices'][0]['message']['content']\n",
    "        number_list = re.findall(r'[FL]\\d+', results)\n",
    "        # Labels not shown in the prompt get a prefix so they cannot collide with\n",
    "        # a ground-truth label.\n",
    "        mapped_list = [label_map.get(label, 'unknown_' + label) for label in number_list]\n",
    "        logger.info(\"results: %s\", results)\n",
    "        logger.info('LLM: %s', number_list)\n",
    "        logger.info('LLM (ground-truth labels): %s', mapped_list)\n",
    "        logger.info(\"grounding_truth: %s\", list2str(fact2explain[(h, r, t)]))\n",
    "\n",
    "        if list_equal(mapped_list, fact2explain[(h, r, t)]):\n",
    "            logger.info(\"correct\")\n",
    "            true_num += 1\n",
    "        else:\n",
    "            false_num += 1\n",
    "\n",
    "    logger.info(\"accuracy: \" + (str(true_num / num) if num else \"n/a (no test statements)\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
