{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "def load_openai_keys(path='../openai_keys.txt'):\n",
    "    \"\"\"Read OpenAI API keys from a text file, one key per line.\n",
    "\n",
    "    Only the last whitespace-separated token of each line is kept, so any\n",
    "    leading tokens on a line are ignored. Blank lines are skipped (they used\n",
    "    to raise IndexError).\n",
    "\n",
    "    Args:\n",
    "        path: location of the key file; defaults to '../openai_keys.txt'.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: the keys in file order.\n",
    "    \"\"\"\n",
    "    keys = []\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            tokens = line.split()\n",
    "            if not tokens:  # blank or whitespace-only line\n",
    "                continue\n",
    "            keys.append(tokens[-1])\n",
    "    return keys\n",
    "# Key pool consumed by update_key(), which rotates through it.\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "\n",
    "def list_equal(a, answers):\n",
    "    \"\"\"Return True if `a` equals any candidate in `answers` as a set.\n",
    "\n",
    "    Order and duplicates are ignored on both sides, e.g. ['F1', 'F1', 'F3']\n",
    "    matches the candidate ['F3', 'F1']. `answers` is a collection of\n",
    "    candidates (a list of lists), not a single candidate.\n",
    "    \"\"\"\n",
    "    return any(set(a) == set(candidate) for candidate in answers)\n",
    "\n",
    "def update_key():\n",
    "    \"\"\"Activate the next API key, round-robin.\n",
    "\n",
    "    Takes the first key of the global `openai_api_keys`, moves it to the end\n",
    "    of the list and installs it as `openai.api_key`.\n",
    "    \"\"\"\n",
    "    next_key = openai_api_keys.pop(0)\n",
    "    openai_api_keys.append(next_key)\n",
    "    openai.api_key = next_key\n",
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Return a logger that writes formatted records to `filename`.\n",
    "\n",
    "    Any handlers already attached to the logger are removed and closed first,\n",
    "    so calling this again (e.g. re-running a notebook cell) leaves exactly one\n",
    "    file handler.\n",
    "\n",
    "    Args:\n",
    "        filename: log file path; it is truncated (mode \"w\").\n",
    "        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.\n",
    "        name: logger name; None selects the root logger, so plain\n",
    "            `logging.info(...)` calls are written to the same file.\n",
    "\n",
    "    Returns:\n",
    "        logging.Logger: the configured logger.\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers. Iterate over a copy: removing from\n",
    "    # `logger.handlers` while iterating it skips every other handler. Closing\n",
    "    # them releases the file a previous call opened.\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # To also echo records to the terminal, attach a StreamHandler here:\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "# Serialize a list into a single string for logging.\n",
    "def list2str(l):\n",
    "    \"\"\"Join str(item) for each item of `l`, each followed by a carriage return.\"\"\"\n",
    "    return ''.join(str(item) + '\\r' for item in l)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabularies.\n",
    "#   id2rel / rel2id : integer id <-> relation name; ids follow the line order of\n",
    "#                     ../symbolic_tree/1.relations, then extra_relations.\n",
    "#   infer_rel       : relation names from 1.relations, in file order.\n",
    "#   relation_txt    : those names joined by ', ' (with a trailing ', ').\n",
    "#   rel2sym_2       : relation name -> symbol, read from ../rel2sym_ri.txt.\n",
    "#   rel2sym         : the same symbols wrapped in '$...$'.\n",
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()\n",
    "relation_txt = ''\n",
    "\n",
    "infer_rel = list()\n",
    "# NOTE(review): relations are read from tree 1 while the data cell below reads\n",
    "# tree 0 -- presumably every tree shares the same relation list; confirm.\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        # Two tokens per line; only the second (the relation name) is kept.\n",
    "        _, rel = line.strip().split()\n",
    "\n",
    "        infer_rel.append(rel)\n",
    "\n",
    "        relation_txt += rel + ', '\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "# Additional relation names plus the 'male'/'female' classes; their ids\n",
    "# continue after the last relation read from the file.\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rel in extra_relations:\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "    \n",
    "# Each line: \"<relation> <symbol>\".\n",
    "with open(\"../rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym_2[rel] = sym\n",
    "        rel2sym[rel] = '$' + sym + '$'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Map each rule-head symbol to its LaTeX rule and build a numbered rule list.\n",
    "# Line k (1-based) of latex_rules.txt is paired with infer_rel[k], i.e. the\n",
    "# rules start at the SECOND relation of 1.relations.\n",
    "# NOTE(review): confirm infer_rel[0] has no rule (otherwise this is off by\n",
    "# one) and that 'latex_rules.txt' is the intended file -- the LaTeX-export\n",
    "# cell further down writes 'rule_latex.txt'.\n",
    "rh2rules = dict()\n",
    "rule_text = ''\n",
    "with open('latex_rules.txt','r') as f:\n",
    "    # `rule_id` rather than `id`: a global named `id` shadows the builtin id()\n",
    "    # for the rest of the kernel session.\n",
    "    for rule_id, line in enumerate(f, start=1):\n",
    "        rh2rules[rel2sym[infer_rel[rule_id]]] = line.strip().split('\\t')[-1]\n",
    "        rule_text += 'L' + str(rule_id) + \": \" + line\n",
    "# print(rule_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Register the entities listed in `path` under consecutive integer ids.\n",
    "\n",
    "    Each line must hold exactly two tokens; the second is the entity name.\n",
    "    Names already present in `ent2id` are skipped. `id2ent` and `ent2id` are\n",
    "    updated in place.\n",
    "\n",
    "    Returns:\n",
    "        int: the next free id, i.e. `eid` plus the number of new entities.\n",
    "    \"\"\"\n",
    "    next_id = eid\n",
    "    with open(path, 'r') as entity_file:\n",
    "        for raw in entity_file:\n",
    "            _, name = raw.split()\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            id2ent[next_id] = name\n",
    "            ent2id[name] = next_id\n",
    "            next_id += 1\n",
    "    return next_id\n",
    "    \n",
    "def get_related_triplets(h, t, G, entpair2rel):\n",
    "    \"\"\"Describe the edges on all simple paths between `h` and `t` (cutoff 5).\n",
    "\n",
    "    Paths from nx.all_simple_edge_paths are visited in sorted order and each\n",
    "    edge becomes a line \"rel(a,b)\". The edge is written as stored in\n",
    "    `entpair2rel`: (u, v) when that key exists, otherwise (v, u); an edge known\n",
    "    in neither orientation raises KeyError.\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    for edge_path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "        for edge in edge_path:\n",
    "            if edge in entpair2rel:\n",
    "                src, dst = edge[0], edge[1]\n",
    "            else:\n",
    "                src, dst = edge[1], edge[0]\n",
    "            lines.append(entpair2rel[(src, dst)] + '(' + src + ',' + dst + ')\\n')\n",
    "    return ''.join(lines)\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text):\n",
    "    \"\"\"Load base facts from `path1` and positive inferred facts from `path2`.\n",
    "\n",
    "    Both files hold lines \"<flag> <head_id> <rel_id> <tail_id>\". Every line of\n",
    "    `path1` is used; from `path2` only lines flagged '+'. Ids are decoded with\n",
    "    `id2ent` and the global `id2rel`.\n",
    "\n",
    "    Returns:\n",
    "        triplets: (head, relation, tail) names of the `path1` facts only.\n",
    "        entpair2rel: (head, tail) -> relation name for every used fact; a\n",
    "            later fact overwrites an earlier one with the same pair.\n",
    "        text: the input `text` plus one \"sym(head,tail)\" line per used fact,\n",
    "            with sym = rel2sym[relation].\n",
    "    \"\"\"\n",
    "    triplets = []\n",
    "    entpair2rel = {}\n",
    "\n",
    "    with open(path1, 'r') as base_file:\n",
    "        for row in base_file:\n",
    "            _flag, h_id, r_id, t_id = row.strip().split()\n",
    "            head, relation, tail = id2ent[int(h_id)], id2rel[int(r_id)], id2ent[int(t_id)]\n",
    "            triplets.append((head, relation, tail))\n",
    "            entpair2rel[(head, tail)] = relation\n",
    "            text += rel2sym[relation] + '(' + head + ',' + tail + ')\\n'\n",
    "\n",
    "    with open(path2, 'r') as inferred_file:\n",
    "        for row in inferred_file:\n",
    "            flag, h_id, r_id, t_id = row.strip().split()\n",
    "            if flag != '+':\n",
    "                continue\n",
    "            head, relation, tail = id2ent[int(h_id)], id2rel[int(r_id)], id2ent[int(t_id)]\n",
    "            entpair2rel[(head, tail)] = relation\n",
    "            text += rel2sym[relation] + '(' + head + ',' + tail + ')\\n'\n",
    "\n",
    "    return triplets, entpair2rel, text\n",
    "\n",
    "\n",
    "def read_class(path, cid, ent2class, id2ent, class_text):\n",
    "    \"\"\"Read gender flags for consecutive entities and render them as facts.\n",
    "\n",
    "    Line k of `path` holds \"<female_flag> <male_flag>\" for entity\n",
    "    id2ent[cid + k]. A female flag of '1' means female, anything else male\n",
    "    (the male flag is not consulted). For every entity, `ent2class` is set\n",
    "    to 'female' or 'male' and a line \"sym(entity)\" is appended to the text,\n",
    "    where sym = rel2sym['female'] or rel2sym['male'].\n",
    "\n",
    "    Returns:\n",
    "        (next cid, updated class_text)\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female_flag, _male_flag = line.strip().split()\n",
    "            gender = 'female' if female_flag == '1' else 'male'\n",
    "            entity = id2ent[cid]\n",
    "            # Record females too: previously only males were stored, so\n",
    "            # ent2class lookups of female entities raised KeyError.\n",
    "            ent2class[entity] = gender\n",
    "            class_text += rel2sym[gender] + '(' + entity + ')\\n'\n",
    "            cid += 1\n",
    "    return cid, class_text\n",
    "\n",
    "\n",
    "def read_all_facts(path1, path2, path_class, id2ent, ent2class, cid, text ):\n",
    "    \"\"\"Load known facts, test queries and gender classes, numbering the facts.\n",
    "\n",
    "    File formats (whitespace-separated):\n",
    "        path1, path2: \"<flag> <head_id> <rel_id> <tail_id>\"\n",
    "        path_class:   \"<female_flag> <male_flag>\"; line k describes entity\n",
    "                      id2ent[cid + k].\n",
    "\n",
    "    Every path1 fact and every path_class line gets a fact id \"F1\", \"F2\", ...\n",
    "    in reading order; the ids prefix the lines appended to `text` and are the\n",
    "    values of `tri2number`. path2 facts are neither numbered nor added to\n",
    "    `text`. Uses the globals `id2rel` and `rel2sym`.\n",
    "\n",
    "    Returns:\n",
    "        triplets: (head, relation name, tail) for path1 facts and for path2\n",
    "            facts flagged '+'.\n",
    "        test_triplets: (head, rel2sym symbol, tail) for path2 facts flagged '+'.\n",
    "        entpair2rel: (head, tail) -> rel2sym symbol, path1 facts only.\n",
    "        cid: `cid` advanced by the number of path_class lines.\n",
    "        text: input `text` plus \"F<k>: sym(head,tail)\" / \"F<k>: sym(entity)\" lines.\n",
    "        edges: (head, tail) pairs of the path1 facts.\n",
    "        tri2number: (head, symbol, tail) -> \"F<k>\" for path1 facts and\n",
    "            (entity, 'gender', gender symbol) -> \"F<k>\" for class lines.\n",
    "\n",
    "    Side effect: fills `ent2class` with entity -> rel2sym['female'] or\n",
    "    rel2sym['male'].\n",
    "    \"\"\"\n",
    "    f_id = 1\n",
    "    triplets = list()\n",
    "    test_triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    edges = list()\n",
    "    tri2number = dict()\n",
    "\n",
    "    # Known facts: numbered, listed in `text`, and used as graph edges.\n",
    "    with open(path1,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            edges.append((id2ent[int(h)], id2ent[int(t)]))\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            text += 'F' + str(f_id) + \": \" + rel2sym[id2rel[(int(r))]] + '(' + id2ent[int(h)] + ',' + id2ent[int(t)] + ')\\n'\n",
    "            \n",
    "            tri2number[(id2ent[int(h)], rel2sym[id2rel[int(r)]], id2ent[int(t)])] = 'F' + str(f_id)\n",
    "            f_id += 1\n",
    "\n",
    "    # Inferred facts: '+' lines become test queries; they get no fact id and\n",
    "    # are not shown in `text`.\n",
    "    with open(path2,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                test_triplets.append((id2ent[int(h)], rel2sym[id2rel[int(r)]], id2ent[int(t)]))\n",
    "                \n",
    "                triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "\n",
    "    # Gender classes: numbering continues after the last path1 fact. A female\n",
    "    # flag of '1' means female; any other value is treated as male (the\n",
    "    # second column is read but never used).\n",
    "    with open(path_class, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                text += 'F' + str(f_id) + ': ' + rel2sym['female'] + '(' + id2ent[cid] + ')\\n'\n",
    "                tri2number[(id2ent[cid], 'gender', rel2sym['female'])] = 'F' + str(f_id)\n",
    "            else:\n",
    "                ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                text += 'F' + str(f_id) + ': ' + rel2sym['male'] + '(' + id2ent[cid] + ')\\n'\n",
    "                tri2number[(id2ent[cid], 'gender', rel2sym['male'])] = 'F' + str(f_id)\n",
    "            f_id += 1\n",
    "            cid += 1\n",
    "    return triplets, test_triplets, entpair2rel, cid, text, edges, tri2number\n",
    "    \n",
    "def get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules,  tri2number, rule2number):\n",
    "    \"\"\"Compute the ground-truth fact explanations of every test triple.\n",
    "\n",
    "    A rule rel2rules[r] is a list [b_1, ..., b_k, g]: k relation symbols\n",
    "    followed by the gender symbol g that the head entity must have. A\n",
    "    grounding of (h, r, t) is a simple path h = v_0, v_1, ..., v_k = t in the\n",
    "    undirected fact graph whose i-th step (v_{i-1}, v_i) is a known fact with\n",
    "    relation b_i, plus the gender fact of h. A step whose rule symbol is\n",
    "    '$r45$' may instead use a fact stored in the opposite direction.\n",
    "\n",
    "    Args:\n",
    "        test_triplets: (h, r, t) triples, r being a rule-head symbol.\n",
    "        edges: (head, tail) pairs of the known facts.\n",
    "        entpair2rel: (head, tail) -> relation symbol of the known facts.\n",
    "        ent2class: entity -> gender symbol.\n",
    "        rel2rules: rule-head symbol -> rule list as described above.\n",
    "        tri2number: (head, symbol, tail) and (entity, 'gender', symbol) -> fact id.\n",
    "        rule2number: rule-head symbol -> rule id.\n",
    "\n",
    "    Returns:\n",
    "        fact2explain: (h, r, t) -> list of groundings, each a list of fact ids\n",
    "            (path facts in order, then the gender fact of h); [] if none.\n",
    "        fact2rule: (h, r, t) -> rule id.\n",
    "    \"\"\"\n",
    "\n",
    "    def match_rule(path, rule):\n",
    "        \"\"\"Return the fact ids grounding the relation atoms of `rule` along `path`, or None.\"\"\"\n",
    "        body = rule[:-1]  # the last element is the head entity's gender\n",
    "        # k relation atoms need a path of exactly k edges; shorter paths used\n",
    "        # to be accepted with only part of the rule grounded.\n",
    "        if len(path) - 1 != len(body):\n",
    "            return None\n",
    "        path_number = list()\n",
    "        for i in range(len(body)):\n",
    "            u, v = path[i], path[i + 1]\n",
    "            if (u, v) in entpair2rel:\n",
    "                rel = entpair2rel[(u, v)]\n",
    "                # A forward fact with the wrong relation rejects the path; it\n",
    "                # used to be skipped silently, accepting non-matching paths.\n",
    "                if rel != body[i]:\n",
    "                    return None\n",
    "                path_number.append(tri2number[(u, rel, v)])\n",
    "            elif body[i] == '$r45$':\n",
    "                # NOTE(review): presumably '$r45$' is an inverse relation, so the\n",
    "                # reverse fact is used; its own relation is not checked -- confirm.\n",
    "                path_number.append(tri2number[(v, entpair2rel[(v, u)], u)])\n",
    "            else:\n",
    "                return None\n",
    "        return path_number\n",
    "\n",
    "    # Undirected graph, so paths may traverse facts in either direction.\n",
    "    G = nx.Graph()\n",
    "    G.add_edges_from(edges)\n",
    "\n",
    "    fact2explain = dict()\n",
    "    fact2rule = dict()\n",
    "    for h, r, t in test_triplets:\n",
    "        rule = rel2rules[r]\n",
    "        all_paths = list()\n",
    "        # Longer paths can never match the rule body, so cap the search there.\n",
    "        for path in nx.all_simple_paths(G, source=h, target=t, cutoff=len(rule) - 1):\n",
    "            path_number = match_rule(path, rule)\n",
    "            if path_number is not None and ent2class[h] == rule[-1]:\n",
    "                all_paths.append(path_number + [tri2number[(h, 'gender', ent2class[h])]])\n",
    "        fact2explain[(h, r, t)] = all_paths\n",
    "        fact2rule[(h, r, t)] = rule2number[r]\n",
    "\n",
    "    return fact2explain, fact2rule\n",
    "\n",
    "\n",
    "def read_rules(path, rel2sym):\n",
    "    \"\"\"Load tab-separated rules, translating every token through `rel2sym`.\n",
    "\n",
    "    Line k (1-based) is \"head<TAB>atom_1<TAB>...\" and yields\n",
    "    rel2rules[sym(head)] = [sym(atom_1), ...] and\n",
    "    rule2number[sym(head)] = 'L<k>'. A repeated head keeps the last line.\n",
    "\n",
    "    Returns:\n",
    "        (rel2rules, rule2number)\n",
    "    \"\"\"\n",
    "    rel2rules, rule2number = {}, {}\n",
    "    with open(path, 'r') as f:\n",
    "        for line_no, line in enumerate(f, start=1):\n",
    "            head, *body = [rel2sym[token] for token in line.strip().split('\\t')]\n",
    "            rel2rules[head] = body\n",
    "            rule2number[head] = 'L' + str(line_no)\n",
    "    return rel2rules, rule2number\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of atoms in each rule body of rule_tab.txt (every column after the\n",
    "# head, including the last column, which get_explain_grounding_truth treats\n",
    "# as the head entity's gender), keyed by the head's rel2sym symbol.\n",
    "rule_length = {}\n",
    "with open('rule_tab.txt', 'r') as f:\n",
    "    for raw_line in f:\n",
    "        head, *body = raw_line.strip().split('\\t')\n",
    "        rule_length[rel2sym[head]] = len(body)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load tree i: entities, facts, test queries and gender classes, then the\n",
    "# rules, and compute the ground-truth explanations. Only tree 0 is used.\n",
    "for i in range(0, 1):\n",
    "    \n",
    "    eid = 0\n",
    "    cid = 0  # read_all_facts maps class line k to id2ent[cid + k]\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()  # NOTE(review): not used anywhere in this notebook\n",
    "    class_text = ''  # NOTE(review): only used by the commented-out read_class call\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\"+str(i)+\".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\"+str(i)+\".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "\n",
    "    # Same file that rule_length was computed from.\n",
    "    path_rule = 'rule_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent,eid, id2ent,ent2id)\n",
    "    # cid, class_text = read_class(path_class, cid, ent2class, id2ent, class_text)\n",
    "    \n",
    "    # triplets, entpair2rel, text = read_all_triplets(path_rel1, path_rel2, id2ent, text)\n",
    "\n",
    "    \n",
    "    # `text` becomes the numbered fact list sent to the LLM; the '+' lines of\n",
    "    # the .inf file become test_triplets.\n",
    "    triplets, test_triplets, entpair2rel, cid, text, edges, tri2number = read_all_facts(path_rel1, path_rel2, path_class, id2ent, ent2class, cid, text)\n",
    "    \n",
    "    rel2rules, rule2number = read_rules(path_rule, rel2sym)\n",
    "    fact2explain, fact2rule= get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number)\n",
    "    # test_questions = random.sample(triplets, int(len(triplets) * 0.2))\n",
    "    # train_questions = triplets.copy()\n",
    "    # for t in test_questions:\n",
    "    #     train_questions.remove(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth explanations (lists of fact ids) for every test triple\n",
    "fact2explain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth explanations of the first test triple\n",
    "fact2explain[test_triplets[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dict2str(d):\n",
    "    \"\"\"Render each entry of `d` as key, a tab, value, then a carriage return.\"\"\"\n",
    "    return ''.join(str(key) + '\\t' + str(d[key]) + '\\r' for key in d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the LLM to pick the facts that ground each test triple's rule and compare\n",
    "# the fact ids it returns with the ground truth in fact2explain.\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/first_order_standard/'  # was `dir`, which shadowed the builtin\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "model = \"gpt-3.5-turbo\"\n",
    "# model = \"gpt-4\"\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "MAX_RETRIES = 10  # API attempts per triple before it is skipped\n",
    "\n",
    "record_flag = False  # the full prompt is logged only once\n",
    "\n",
    "true_num = 0\n",
    "false_num = 0\n",
    "num = 0\n",
    "skipped = []  # triples whose API call failed MAX_RETRIES times\n",
    "\n",
    "for triple in tqdm(test_triplets):\n",
    "    h, r, t = triple\n",
    "    text_pred = r + '(' + h + ', ' + t + ')'\n",
    "    message = {\n",
    "        'system': \"You are a helpful assistant. I will give you some logical rules and facts. Please select some facts to explain the following statement. \",\n",
    "        'user': \"I will give you some logical rules and facts.\\nThe facts are:\\n\" + text + \"\\nThe logical rule is:\\n\" + rh2rules[r] + \"\\nPlease select \" + str(rule_length[r]) + \" facts that matches the logical rule. The selected facts and the given rule can entail the following statement.\\nThe statement is: \" + text_pred + \"\\nThe selected \" + str(rule_length[r]) +  \" facts are: \",\n",
    "    }\n",
    "\n",
    "    # Retry only the API request. Scoring used to sit inside the retry loop, so\n",
    "    # a scoring error (e.g. a triple without ground truth) re-sent the request\n",
    "    # up to 10 times and incremented `num` on every attempt.\n",
    "    for _ in range(MAX_RETRIES):\n",
    "        try:\n",
    "            update_key()\n",
    "            response = openai.ChatCompletion.create(\n",
    "                model=model,  # was hard-coded, ignoring the `model` setting above\n",
    "                messages=[\n",
    "                    {\"role\": \"system\", \"content\": message['system']},\n",
    "                    {\"role\": \"user\", \"content\": message['user']},\n",
    "                ],\n",
    "                temperature=0,\n",
    "            )\n",
    "            results = response['choices'][0]['message']['content']\n",
    "            break\n",
    "        except Exception as e:  # network / rate-limit / server errors\n",
    "            print(e)\n",
    "    else:\n",
    "        logger.warning('skipped after ' + str(MAX_RETRIES) + ' failed API calls: ' + text_pred)\n",
    "        skipped.append(triple)\n",
    "        continue\n",
    "\n",
    "    if not record_flag:\n",
    "        logger.info('message: \\n' + dict2str(message))\n",
    "        record_flag = True\n",
    "\n",
    "    num += 1\n",
    "    number_list = re.findall(r'[FL]\\d+', results)\n",
    "    # A triple can have no grounding; log it as empty instead of raising.\n",
    "    ground_truth = fact2explain.get((h, r, t), [])\n",
    "    first_truth = ', '.join(ground_truth[0]) if ground_truth else ''\n",
    "    logger.info('LLM: ' + (', '.join(number_list)) + '\\t' + \"grounding_truth: \" + first_truth + '\\t' + \"results: \" + results)\n",
    "    if list_equal(number_list, ground_truth):\n",
    "        print(\"correct\")\n",
    "        true_num += 1\n",
    "    else:\n",
    "        false_num += 1\n",
    "\n",
    "if skipped:\n",
    "    print(len(skipped), \"triples skipped after repeated API failures\")\n",
    "print(\"accuracy: \", true_num / num if num else float('nan'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check: ground-truth explanations of one specific test triple\n",
    "fact2explain[('Lorenz', '$r2$', 'Lea')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The second argument is a list of candidate answers; order and duplicates are\n",
    "# ignored, so this is True. Passing ['F1', 'F3'] directly made list_equal\n",
    "# compare against set('F1') and set('F3') (sets of characters) -> False.\n",
    "list_equal(['F1', 'F1', 'F3'], [['F1', 'F3']])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For each triplet, find its grounded rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: unfinished -- meant to find the grounded rules of each triplet; for now\n",
    "# it only reads each triplet's relation and does nothing with it.\n",
    "for triplet in triplets:\n",
    "    rel = triplet[1]\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "# Toy example: enumerate the simple paths of the chain graph 1-2-...-8 and\n",
    "# keep those accepted by a predicate.\n",
    "G = nx.Graph()\n",
    "G.add_edges_from((k, k + 1) for k in range(1, 8))\n",
    "\n",
    "\n",
    "def logical_rules(path):\n",
    "    \"\"\"Accept paths from node 1 to node 8 that visit at least 4 nodes.\"\"\"\n",
    "    right_endpoints = path[0] == 1 and path[-1] == 8\n",
    "    return right_endpoints and len(path) >= 4\n",
    "\n",
    "\n",
    "all_paths = []\n",
    "for candidate in nx.all_simple_paths(G, source=1, target=8):\n",
    "    print(\"path\", candidate)\n",
    "    if logical_rules(candidate):\n",
    "        all_paths.append(candidate)\n",
    "\n",
    "print(all_paths)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add inverse relations to the KG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extend rel2sym.txt with an inverse relation ('inverse_<rel>') for every\n",
    "# relation, each mapped to a fresh random symbol of 4-8 distinct lowercase\n",
    "# letters, and save the result to rel2sym_inverse.txt.\n",
    "# NOTE(review): this rebinds the global `rel2sym` (symbols without '$'),\n",
    "# replacing the mapping the evaluation cells above were built from.\n",
    "rel2sym = dict()\n",
    "sym2rel = dict()\n",
    "import random\n",
    "from random import sample\n",
    "\n",
    "SYMBOL_SEED = 0  # fixed seed so rel2sym_inverse.txt is reproducible\n",
    "\n",
    "with open('rel2sym.txt', 'r') as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        rel, sym = line.split('\\t')\n",
    "        rel2sym[rel] = sym\n",
    "        sym2rel[sym] = rel\n",
    "\n",
    "\n",
    "def _random_symbol(rng):\n",
    "    \"\"\"Return 4-8 distinct lowercase letters in random order.\"\"\"\n",
    "    return ''.join(rng.sample('abcdefghijklmnopqrstuvwxyz', rng.randint(4, 8)))\n",
    "\n",
    "\n",
    "symbol_rng = random.Random(SYMBOL_SEED)\n",
    "for t in list(rel2sym.keys()):\n",
    "    t_prime = 'inverse_' + t\n",
    "    sym = _random_symbol(symbol_rng)\n",
    "    # Resample on collision: a reused symbol would make sym2rel silently\n",
    "    # forget one of the two relations.\n",
    "    while sym in sym2rel:\n",
    "        sym = _random_symbol(symbol_rng)\n",
    "    rel2sym[t_prime] = sym\n",
    "    sym2rel[sym] = t_prime\n",
    "\n",
    "# write rel2sym_inverse.txt to file\n",
    "with open('rel2sym_inverse.txt', 'w') as f:\n",
    "    for t in rel2sym:\n",
    "        f.write(t + '\\t' + rel2sym[t] + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation names currently in rel2sym (after the cell above: including the inverse_* entries)\n",
    "rel2sym.keys()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Translate the logical rules into LaTeX and save them to \"rule_latex.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate each rule of rule_parents.txt into LaTeX and save it to \"rule_latex.txt\".\n",
    "# input:  sisterOf<TAB>inverse_parentOf<TAB>parentOf<TAB>female\n",
    "# output: $S1(A,B) \\land S2(B,C) \\land S3(A) \\rightarrow S0(A,C)$\n",
    "#         where S0..S3 are the rel2sym_2 symbols of the four columns.\n",
    "# NOTE(review): the rule-loading cell near the top reads 'latex_rules.txt', not\n",
    "# this file -- confirm the intended file name.\n",
    "with open(\"rule_latex.txt\", 'w') as fw:\n",
    "    with open(\"rule_parents.txt\", 'r') as f:\n",
    "        for line in f:\n",
    "            lst = line.strip().split('\\t')\n",
    "            h = ord('A')\n",
    "            t = h  # keeps `t` defined for this line even if the rule body is empty\n",
    "            atoms = []\n",
    "            for rel in lst[1:-1]:\n",
    "                t = h + 1\n",
    "                atoms.append(rel2sym_2[rel] + '(' + chr(h) + ',' + chr(t) + ')')\n",
    "                h = t\n",
    "            # The last column becomes a unary atom on A (e.g. the gender).\n",
    "            atoms.append(rel2sym_2[lst[-1]] + '(A)')\n",
    "            # Backslashes are escaped: the old ' \\land ' relied on Python keeping\n",
    "            # the invalid escape '\\l' as-is, which emits a warning.\n",
    "            fw.write('$' + ' \\\\land '.join(atoms) + ' \\\\rightarrow ' + rel2sym_2[lst[0]] + '(A,' + chr(t) + ')$\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate each rule of rule_parents.txt into a natural-language sentence and\n",
    "# save it to \"rule_language.txt\", e.g.\n",
    "#   If A is the $S1$ of B and B is the $S2$ of C and A is the $S3$, then A is the $S0$ of C.\n",
    "with open(\"rule_language.txt\", 'w') as fw:\n",
    "    with open(\"rule_parents.txt\", 'r') as f:\n",
    "        for line in f:\n",
    "            lst = line.strip().split('\\t')\n",
    "            fw.write('If ')\n",
    "            h = ord('A')\n",
    "            t = h  # keeps `t` defined for this line even if the rule body is empty\n",
    "            for rel in lst[1:-1]:\n",
    "                t = h + 1\n",
    "                fw.write(chr(h) + ' is the $' + rel2sym_2[rel] + '$ of ' + chr(t) + ' and ')\n",
    "                h = t\n",
    "            fw.write('A is the $' + rel2sym_2[lst[-1]] + '$, then A is the $' + rel2sym_2[lst[0]] + '$ of ' + chr(t) + '.\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Natural-language rule templates keyed by rule-head symbol: '##' marks a\n",
    "# relation slot and '++' the gender slot.\n",
    "# rule_length counts every body column including the gender atom, so a rule\n",
    "# has rule_length[key] - 1 relation clauses (the old loop emitted one extra).\n",
    "template = dict()\n",
    "\n",
    "for key in rule_length:\n",
    "    n_relations = rule_length[key] - 1\n",
    "    clauses = ''.join(\n",
    "        chr(ord('A') + k) + ' is the ## of ' + chr(ord('A') + k + 1) + ' and '\n",
    "        for k in range(n_relations)\n",
    "    )\n",
    "    last_entity = chr(ord('A') + max(n_relations, 0))\n",
    "    template[key] = 'If ' + clauses + 'A is the ++, then A is the ' + key + ' of ' + last_entity + '.'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Translate the family relations into symbols and save them to a new file (rule_symbolic_tab.txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewrite rule_parents.txt with every column replaced by its rel2sym symbol\n",
    "# (tab-separated, one rule per line) into rule_symbolic_tab.txt. Uses whatever\n",
    "# `rel2sym` is current -- presumably the un-wrapped one from the inverse-relation cell.\n",
    "with open(\"rule_symbolic_tab.txt\", 'w') as fw:\n",
    "    with open(\"rule_parents.txt\", 'r') as f:\n",
    "        for line in f:\n",
    "            *leading, final = line.strip().split('\\t')\n",
    "            for name in leading:\n",
    "                fw.write(rel2sym[name] + '\\t')\n",
    "            fw.write(rel2sym[final] + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regex demo: match 'F' or 'L' followed by digits (fact / rule ids)\n",
    "import re\n",
    "a = 'F5,L1,F45'\n",
    "b = re.findall(r'[FL]\\d+', a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert a list to a set (duplicates collapse)\n",
    "a = ['F1', 'F1', 'F3']\n",
    "b = set(a)\n",
    "c = set(['F1', 'F3'])\n",
    "# Check whether the two sets are equal\n",
    "\n",
    "b == c\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [1,2,3]\n",
    "b = [3,2,1]\n",
    "# list_equal's second argument is a list of candidates, so wrap `b`; passing\n",
    "# `b` itself made it call set(3) and raise TypeError.\n",
    "list_equal(a, [b])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
