{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys_filter.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split()\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "# Load the key pool once and shuffle it in place so each run starts the rotation at a random key.\n",
    "openai_api_keys = load_openai_keys()\n",
    "random.shuffle(openai_api_keys)\n",
    "\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)\n",
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "# multiple-list save to logging\n",
    "def list2str(l):\n",
    "    s = ''\n",
    "    for i in l:\n",
    "        s += str(i) + '\\t'\n",
    "    return s\n",
    "def list_equal(a, answers):\n",
    "    # if set(a) == set(b):\n",
    "    #     return True\n",
    "    # else:\n",
    "    #     return False\n",
    "    for b in answers:\n",
    "        # if set(a) contain set(b)\n",
    "        if set(b) == set(a):\n",
    "            return True\n",
    "        # if set(a) == set(b):\n",
    "        #     return True\n",
    "    return False\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary: ids follow the order of the relations file, then the derived relations below.\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()\n",
    "infer_rel = list()\n",
    "\n",
    "# Each line holds two whitespace-separated fields; the second one is the relation name.\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        infer_rel.append(rel)\n",
    "\n",
    "# Comma-separated listing of the relations read from the file (keeps a trailing ', ').\n",
    "relation_txt = ''.join(name + ', ' for name in infer_rel)\n",
    "\n",
    "extra_relations = [\n",
    "    \"greatAuntUncleOf\",\n",
    "    \"grandparentOf\",\n",
    "    \"greatGrandparentOf\",\n",
    "    \"auntUncleOf\",\n",
    "    \"siblingOf\",\n",
    "    \"secondAuntUncleOf\",\n",
    "    \"childOf\",\n",
    "    \"grandchildOf\",\n",
    "    \"greatGrandchildOf\",\n",
    "    \"nieceNephewOf\",\n",
    "    \"cousinOf\",\n",
    "    \"secondCousinOf\",\n",
    "    \"firstCousinOnceRemovedOf\",\n",
    "    \"male\",\n",
    "    \"female\",\n",
    "]\n",
    "\n",
    "for rid, rel in enumerate(infer_rel + extra_relations):\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "rid += 1  # leave rid at the next free id, i.e. the number of registered relations\n",
    "\n",
    "# Map each relation name to its bare symbol and to the same symbol wrapped in '$'.\n",
    "with open(\"rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym[rel] = '$' + sym + '$'\n",
    "        rel2sym_2[rel] = sym\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load the tab-separated rules file and number the rules L1, L2, ... for the prompt.\n",
    "rh2rules = dict()\n",
    "rule_text = ''\n",
    "lid = 1\n",
    "with open('natural_rules.txt','r') as f:\n",
    "    for line in f:\n",
    "        fields = line.strip().split('\\t')\n",
    "        # The first field is used as the key, the last field as the rule text.\n",
    "        rule_key, rule_body = fields[0], fields[-1]\n",
    "        rh2rules[rule_key] = rule_body\n",
    "        rule_text += 'L' + str(lid) + \": \" + rule_body + '\\n'\n",
    "        lid += 1\n",
    "# print(rule_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_entity(path, eid, id2ent, ent2id ):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, ent = line.strip().split()\n",
    "            if ent not in ent2id:\n",
    "                id2ent[eid] = ent\n",
    "                ent2id[ent] = eid\n",
    "                eid += 1\n",
    "                \n",
    "        return eid\n",
    "    \n",
    "def get_related_triplets(h, t, G, entpair2rel):\n",
    "    \"\"\"Render every edge on the simple paths (at most 5 edges) between ``h`` and ``t``.\n",
    "\n",
    "    Each edge becomes one ``relation(head, tail)`` line. An edge that is not a\n",
    "    key of ``entpair2rel`` is looked up, and printed, with its endpoints swapped.\n",
    "    \"\"\"\n",
    "    rendered = []\n",
    "    for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "        for edge in path:\n",
    "            key = edge if edge in entpair2rel else (edge[1], edge[0])\n",
    "            rendered.append(entpair2rel[key] + '(' + key[0] + ', ' + key[1] + ')\\n')\n",
    "    return ''.join(rendered)\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text):\n",
    "    triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    with open(path1,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "            # text += id2ent[int(h)] + ' is the ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "            text += rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "            # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "\n",
    "    with open(path2,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                # triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "                # triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # text += id2ent[int(h)] + ' is the ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "                text += rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ',' + id2ent[int(t)] + ')\\n'\n",
    "                # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "\n",
    "                # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "\n",
    "    return triplets, entpair2rel, text\n",
    "\n",
    "\n",
    "def read_class(path, cid, ent2class, id2ent, class_text):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                # ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                # ent2class[id2ent[cid]] = 'female'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is the ' + rel2sym[\"female\"] + '. '\n",
    "                class_text += rel2sym['female'] + '(' + id2ent[cid] + ')\\n'\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                # ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                # class_text += id2ent[cid] + ' is the '+ rel2sym['male'] + '. '\n",
    "                class_text += rel2sym['male'] + '(' + id2ent[cid] + ')\\n'\n",
    "                # class_text += id2ent[cid] + ' is the '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "\n",
    "            cid += 1\n",
    "        return cid, class_text\n",
    "\n",
    "\n",
    "def read_all_facts(path1, path2, path_class, id2ent, ent2class, cid, text):\n",
    "    f_id = 1\n",
    "    triplets = list()\n",
    "    test_triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    edges = list()\n",
    "    tri2number = dict()\n",
    "\n",
    "    with open(path1,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            edges.append((id2ent[int(h)], id2ent[int(t)]))\n",
    "            # edges.append((id2ent[int(t)], id2ent[int(h)]))\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            # entpair2rel[(id2ent[int(t)], id2ent[int(h)])] = rel2sym['inverse_' + id2rel[int(r)]]\n",
    "            # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "            text += 'F' + str(f_id) + \": \" + id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '.\\n'\n",
    "            # text += 'F' + str(f_id) + \": \" + rel2sym[id2rel[(int(r))]] + '(' + id2ent[int(h)] + ',' + id2ent[int(t)] + ')\\n'\n",
    "            # text += id2ent[int(h)] + ' is the ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '.\\n'\n",
    "            \n",
    "            tri2number[(id2ent[int(h)], rel2sym[id2rel[int(r)]], id2ent[int(t)])] = 'F' + str(f_id)\n",
    "            f_id += 1\n",
    "            # text += 'F' + str(f_id) + \": \" + id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '.\\n'\n",
    "            # # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '.\\n'\n",
    "            \n",
    "            # tri2number[(id2ent[int(t)], rel2sym['inverse_' + id2rel[int(r)]], id2ent[int(h)])] = 'F' + str(f_id)\n",
    "            \n",
    "            # f_id += 1\n",
    "            \n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "    with open(path2,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                # triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                test_triplets.append((id2ent[int(h)], rel2sym[id2rel[int(r)]], id2ent[int(t)]))\n",
    "                # entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = id2rel[int(r)]\n",
    "                \n",
    "                triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                # text += 'F' + str(f_id) + \": \" + id2ent[int(h)] + ' is the ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '. '\n",
    "                # f_id += 1\n",
    "                \n",
    "                # text += id2ent[int(t)] + ' is the ' + rel2sym['inverse_' + id2rel[int(r)]] + ' of ' + id2ent[int(h)] + '. '\n",
    "\n",
    "                # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "    with open(path_class, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                # ent2class[id2ent[cid]] = 'female'\n",
    "        \n",
    "                text += 'F' + str(f_id) + \": \" + id2ent[cid] + ' is ' + rel2sym[\"female\"] + '.\\n'\n",
    "                # text += 'F' + str(f_id) + ': ' + rel2sym['female'] + '(' + id2ent[cid] + ')\\n'\n",
    "                # text += id2ent[cid] + ' is the ' + rel2sym[\"female\"] + '.\\n'\n",
    "\n",
    "                tri2number[(id2ent[cid], 'gender', rel2sym['female'])] = 'F' + str(f_id)\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                # ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                text += 'F' + str(f_id) + \": \" + id2ent[cid] + ' is '+ rel2sym['male'] + '.\\n'\n",
    "                # text += 'F' + str(f_id) + ': ' + rel2sym['male'] + '(' + id2ent[cid] + ')\\n'\n",
    "\n",
    "                # text += id2ent[cid] + ' is the '+ rel2sym['male'] + '.\\n'\n",
    "\n",
    "                tri2number[(id2ent[cid], 'gender', rel2sym['male'])] = 'F' + str(f_id)\n",
    "                \n",
    "                # class_text += id2ent[cid] + ' is the '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "            f_id += 1\n",
    "            cid += 1\n",
    "            \n",
    "    return triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, f_id\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "def get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules,  tri2number, rule2number):\n",
    "    \"\"\"Compute the ground-truth explanations (fact and rule labels) for every test triple.\n",
    "\n",
    "    For a triple ``(h, r, t)`` the rule ``rel2rules[r]`` is looked up and all simple\n",
    "    paths from ``h`` to ``t`` with at most ``len(rule)`` edges are enumerated in the\n",
    "    undirected graph built from ``edges``. A path is kept when every edge agrees\n",
    "    with the corresponding rule step and the class of ``h`` equals the last rule\n",
    "    element. The explanation for a kept path is the list of fact labels along the\n",
    "    path, then the label of the gender fact of ``h``, then the rule label.\n",
    "\n",
    "    Args:\n",
    "        test_triplets: Iterable of ``(head, relation symbol, tail)`` to explain.\n",
    "        edges: ``(head, tail)`` pairs used to build the graph.\n",
    "        entpair2rel: Maps ``(head, tail)`` to the relation symbol of that fact.\n",
    "        ent2class: Maps an entity to its class symbol.\n",
    "        rel2rules: Maps a relation symbol to its rule (a list of symbols).\n",
    "        tri2number: Maps a fact triple to its ``F<n>`` label.\n",
    "        rule2number: Maps a relation symbol to its ``L<n>`` label.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``(h, r, t) -> list of explanations`` (possibly empty).\n",
    "    \"\"\"\n",
    "\n",
    "    def logical_rules(entpair2rel, path, rule):\n",
    "        \"\"\"Return the fact labels along ``path`` when it follows ``rule``, else ``None``.\"\"\"\n",
    "        path_number = list()\n",
    "        for i in range(len(path)-1):\n",
    "            forward = (path[i], path[i+1])\n",
    "            if forward in entpair2rel:\n",
    "                # A stored edge whose relation differs from this rule step means the\n",
    "                # path does not follow the rule. It used to be skipped silently, which\n",
    "                # let partially matching paths through as explanations.\n",
    "                if entpair2rel[forward] != rule[i]:\n",
    "                    return None\n",
    "                path_number.append(tri2number[(path[i], entpair2rel[forward], path[i+1])])\n",
    "            elif '$r45$' == rule[i]:\n",
    "                # The fact is only stored as (path[i+1], path[i]); it is accepted when the\n",
    "                # rule step is '$r45$' (presumably the inverse-relation symbol - TODO confirm).\n",
    "                path_number.append(tri2number[(path[i+1], entpair2rel[(path[i+1], path[i])], path[i])])\n",
    "            else:\n",
    "                return None\n",
    "        return path_number\n",
    "\n",
    "    # Undirected graph, so facts can be traversed against their stored direction.\n",
    "    G = nx.Graph()\n",
    "    G.add_edges_from(edges)\n",
    "\n",
    "    fact2explain = dict()\n",
    "    for tri in test_triplets:\n",
    "        h, r, t = tri[0], tri[1], tri[2]\n",
    "        rule = rel2rules[r]\n",
    "\n",
    "        all_paths = list()\n",
    "        for path in nx.all_simple_paths(G, source=h, target=t, cutoff=len(rule)):\n",
    "            path_number = logical_rules(entpair2rel, path, rule)\n",
    "            # The last rule element is compared with the class of the head entity.\n",
    "            if path_number and ent2class[h] == rule[-1]:\n",
    "                path_number.append(tri2number[(h,'gender',ent2class[h])])\n",
    "                path_number.append(rule2number[r])\n",
    "                all_paths.append(path_number)\n",
    "        fact2explain[(h, r, t)] = all_paths\n",
    "\n",
    "    return fact2explain\n",
    "\n",
    "\n",
    "def read_rules(path, rel2sym):\n",
    "    rel2rules = dict()\n",
    "    rule2number = dict()\n",
    "    l_id = 1\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            lst = line.strip().split('\\t')\n",
    "            # replace symbol \n",
    "            new_lst = list()\n",
    "            for l in lst:\n",
    "                new_lst.append(rel2sym[l])\n",
    "            \n",
    "            rel2rules[new_lst[0]] = new_lst[1:]\n",
    "            rule2number[new_lst[0]] = 'L' + str(l_id)\n",
    "            l_id += 1\n",
    "    return rel2rules, rule2number, l_id\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read few_shot_prompts.txt to get the prompts\n",
    "\n",
    "def read_few_shot_prompts(path):\n",
    "    prompts = list()\n",
    "    prompts_plus = list()\n",
    "    with open(path, 'r') as f:\n",
    "        blocks = f.read().split('\\n\\n')\n",
    "        for block in blocks:\n",
    "            # 正则表达式提取statement以后，answer之前的内容\n",
    "            statement = re.findall(r'Statement: (.*?)\\n', block)[0]\n",
    "\n",
    "            \n",
    "            d = {}\n",
    "            d['Statement'] = \"Statement: \" + statement\n",
    "            answer = re.findall(r'Answer: (.*)', block, flags=re.DOTALL)[0]\n",
    "            d['Answer'] = \"Answer: \" + answer\n",
    "            prompts.append(d)\n",
    "\n",
    "            d_plus = {}\n",
    "            d_plus['Statement'] = \"Statement: \" + statement\n",
    "            d_plus['Answer'] = \"Answer: Let's think step by step. \" + answer\n",
    "            prompts_plus.append(d_plus)\n",
    "            \n",
    "            \n",
    "            \n",
    "    print(prompts)\n",
    "    return prompts, prompts_plus\n",
    "# Parse the few-shot examples once; the trailing expression displays how many were loaded.\n",
    "prompts, prompts_plus = read_few_shot_prompts('few_shot_prompts_nat.txt')\n",
    "len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot explanation experiment: for every inferred statement, ask the model to pick one\n",
    "# rule and the supporting facts, then compare its choice with the ground-truth explanations.\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/natural_few_shot_f+r'  # formerly `dir`, which shadowed the builtin\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "# model = 'gpt-4'\n",
    "model = \"gpt-3.5-turbo\"\n",
    "MAX_API_RETRIES = 10  # attempts per statement before it is scored as wrong\n",
    "N_SHOTS = 5  # few-shot examples taken from the front of `prompts`\n",
    "\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "for i in range(0, 1):\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    path_rule = 'rule_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent, 0, id2ent, ent2id)\n",
    "    triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, fid = read_all_facts(\n",
    "        path_rel1, path_rel2, path_class, id2ent, ent2class, 0, '')\n",
    "    rel2rules, rule2number, lid = read_rules(path_rule, rel2sym)\n",
    "    fact2explain = get_explain_grounding_truth(\n",
    "        test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number)\n",
    "\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "    num = 0\n",
    "    record_flag = False\n",
    "\n",
    "    # Instructions, rules, facts and few-shot examples are identical for every statement,\n",
    "    # so they are built once per dataset instead of once per triple.\n",
    "    message_base = {\n",
    "        'system': \"You are a helpful assistant with abductive reasoning abilities. Please select one single logical rule and a few facts to explain the following statement. \",\n",
    "        'user': \"I will provide a set of logical rules L1 to L\" + str(lid - 1) + \" and facts F1 to F\" + str(fid - 1)\n",
    "                + \". Please select one single logical rule from L1 to L\" + str(lid - 1) + \" and a few facts from F1 to F\" + str(fid - 1)\n",
    "                + \" to explain the following statement. \\nRules:\\n\" + rule_text\n",
    "                + \"\\nFacts:\\n\" + text,\n",
    "    }\n",
    "    shot_messages = [\n",
    "        {\"role\": \"system\", \"content\": message_base['system']},\n",
    "        {\"role\": \"user\", \"content\": message_base['user']},\n",
    "    ]\n",
    "    for shot_no, shot in enumerate(prompts[:N_SHOTS], start=1):\n",
    "        message_base['Q' + str(shot_no)] = shot['Statement']\n",
    "        message_base['A' + str(shot_no)] = shot['Answer']\n",
    "        shot_messages.append({\"role\": \"user\", \"content\": shot['Statement']})\n",
    "        shot_messages.append({\"role\": \"assistant\", \"content\": shot['Answer']})\n",
    "\n",
    "    for triple in tqdm(test_triplets):\n",
    "        h, r, t = triple\n",
    "        text_pred = h + ' is ' + r + ' of ' + t + '.'\n",
    "\n",
    "        # message_1 is kept as a dict (same keys as before) because it is written to the log.\n",
    "        message_1 = dict(message_base)\n",
    "        message_1['Q7'] = \"Statement: \" + text_pred + '\\nAnswer: '\n",
    "        chat_messages = shot_messages + [{\"role\": \"user\", \"content\": message_1['Q7']}]\n",
    "        # Everything except the system prompt, concatenated without separators (as before).\n",
    "        last_text = ''.join(m['content'] for m in chat_messages[1:])\n",
    "\n",
    "        # Reset per statement so a failed request can never reuse the previous answer.\n",
    "        results_1 = ''\n",
    "        results_2 = ''\n",
    "        message_2 = None\n",
    "        answered = False\n",
    "\n",
    "        for attempt in range(1, MAX_API_RETRIES + 1):\n",
    "            try:\n",
    "                # Step 1: free-form explanation.\n",
    "                update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=chat_messages,\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results_1 = response['choices'][0]['message']['content']\n",
    "\n",
    "                # Step 2: ask for the bare rule / fact numbers of that explanation.\n",
    "                update_key()\n",
    "                message_2 = {\n",
    "                    'system': \"Please output the numbers (e.g. L1, F1, F2) of the selected logical rule and facts. \",\n",
    "                    'user': last_text + '\\n' + results_1 + \"\\nTherefore, the selected logical rule and facts are: \",\n",
    "                }\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_2['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_2['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results_2 = response['choices'][0]['message']['content']\n",
    "                answered = True\n",
    "                break\n",
    "            except Exception as e:\n",
    "                logger.warning(\"request failed (attempt %d/%d): %s\", attempt, MAX_API_RETRIES, e)\n",
    "                print(e)\n",
    "\n",
    "        if not answered:\n",
    "            # Score the statement as wrong instead of crashing or reusing stale text.\n",
    "            logger.warning(\"no answer after %d attempts: %s\", MAX_API_RETRIES, text_pred)\n",
    "            results_1 = ''\n",
    "            results_2 = ''\n",
    "\n",
    "        # Log the full prompts once, for the first statement that produced a second-step prompt.\n",
    "        if not record_flag and message_2 is not None:\n",
    "            logger.info('message_1: \\n' + dict2str(message_1))\n",
    "            logger.info('message_2: \\n' + dict2str(message_2))\n",
    "            record_flag = True\n",
    "\n",
    "        # Count each statement exactly once; it used to be counted twice on success,\n",
    "        # which halved the reported accuracy.\n",
    "        num += 1\n",
    "        number_list = re.findall(r'[FL]\\d+', results_2)\n",
    "        logger.info(\"statement: \" + text_pred)\n",
    "        logger.info(\"results_1: %s\", results_1)\n",
    "        logger.info(\"results_2: %s\", results_2)\n",
    "        logger.info('LLM: %s', number_list)\n",
    "        logger.info(\"grounding_truth: %s\", list2str(fact2explain[(h, r, t)]))\n",
    "\n",
    "        if list_equal(number_list, fact2explain[(h, r, t)]):\n",
    "            logger.info(\"correct\")\n",
    "            true_num += 1\n",
    "        else:\n",
    "            false_num += 1\n",
    "\n",
    "    if num:\n",
    "        logger.info(\"accuracy: \" + str(true_num / num))\n",
    "    else:\n",
    "        logger.warning(\"no test statements found; accuracy is undefined\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
