{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "def load_openai_keys(path='../openai_keys_filter.txt'):\n",
    "    \"\"\"Read OpenAI API keys from a text file, one key per line.\n",
    "\n",
    "    The key is the last whitespace-separated token on each line, so lines of\n",
    "    the form \"<label> <key>\" are accepted. Blank lines are skipped instead of\n",
    "    raising IndexError.\n",
    "\n",
    "    Args:\n",
    "        path: key file to read; defaults to the original hard-coded location.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: keys in file order.\n",
    "    \"\"\"\n",
    "    keys = []\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            tokens = line.split()\n",
    "            if not tokens:\n",
    "                continue  # blank / whitespace-only line\n",
    "            keys.append(tokens[-1])\n",
    "    return keys\n",
    "# Load every key and shuffle the order; update_key() below rotates through\n",
    "# this list starting from index 0.\n",
    "# NOTE(review): the update_key() calls in the evaluation cell are commented out,\n",
    "# so openai.api_key is never set from this list there -- presumably the key\n",
    "# comes from the environment; confirm.\n",
    "openai_api_keys = load_openai_keys()\n",
    "random.shuffle(openai_api_keys)\n",
    "def update_key():\n",
    "    \"\"\"Activate the key at the front of `openai_api_keys` and rotate it to the back.\n",
    "\n",
    "    Sets `openai.api_key` to the current first key, then moves that key to the\n",
    "    end of the list, so successive calls cycle through all keys (round-robin).\n",
    "    Raises IndexError if the key list is empty.\n",
    "    \"\"\"\n",
    "    next_key = openai_api_keys[0]\n",
    "    openai.api_key = next_key\n",
    "    openai_api_keys.append(openai_api_keys.pop(0))\n",
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Return a logger that writes formatted records to `filename`.\n",
    "\n",
    "    Args:\n",
    "        filename: log file path; opened in \"w\" mode, so an existing file is truncated.\n",
    "        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.\n",
    "        name: logger name; None selects the root logger, so plain\n",
    "            `logging.info(...)` calls are written to the same file.\n",
    "\n",
    "    Returns:\n",
    "        logging.Logger: the logger, with a single FileHandler attached.\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove and close any existing handlers. Iterate over a copy: calling\n",
    "    # removeHandler() while iterating logger.handlers skips every other\n",
    "    # handler, leaving stale handlers (and their open files) attached.\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # To also echo records to the terminal, attach a logging.StreamHandler\n",
    "    # with the same formatter.\n",
    "    return logger\n",
    "\n",
    "def list2str(l):\n",
    "    \"\"\"Serialize a list for logging: str() of each item, each followed by a tab.\n",
    "\n",
    "    The last item is tab-terminated too; an empty list yields ''.\n",
    "    \"\"\"\n",
    "    return ''.join(str(item) + '\\t' for item in l)\n",
    "def list_equal(a, answers):\n",
    "    \"\"\"Return True if `a` equals any candidate in `answers` when both are viewed as sets.\n",
    "\n",
    "    Order and duplicates are ignored, e.g. ['F2', 'F1', 'F1'] matches ['F1', 'F2'].\n",
    "    Returns False when `answers` is empty.\n",
    "    \"\"\"\n",
    "    return any(set(candidate) == set(a) for candidate in answers)\n",
    "def dict2str(d):\n",
    "    \"\"\"Serialize a dict for logging as key, tab, value records, each ending in a carriage return.\n",
    "\n",
    "    NOTE(review): records are separated by a carriage return rather than a\n",
    "    newline, so multi-entry output can overwrite itself when shown in a\n",
    "    terminal; confirm this is intended.\n",
    "    \"\"\"\n",
    "    return ''.join(str(key) + '\\t' + str(d[key]) + '\\r' for key in d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary. Base relations are read from 1.relations (two\n",
    "# whitespace-separated fields per line; the first is ignored) and numbered in\n",
    "# file order; the extra relation names and the two gender classes follow.\n",
    "# NOTE(review): this always reads 1.relations while the evaluation loop uses\n",
    "# dataset 0 -- presumably all datasets share one relation list; confirm.\n",
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "rel2sym_2 = dict()  # currently unused\n",
    "relation_txt = ''\n",
    "\n",
    "infer_rel = list()\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        infer_rel.append(rel)\n",
    "        relation_txt += rel + ', '\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "for rel in extra_relations:\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "    rid += 1\n",
    "\n",
    "# rel2sym.txt maps each relation/class name to the word used in the prompts.\n",
    "# A trailing \"Of\" is dropped because the fact templates already add \" of \"\n",
    "# (\"<head> is <sym> of <tail>.\"). Only a *trailing* \"Of\" is stripped: the old\n",
    "# `'Of' in sym` test also cut two characters off symbols that merely contain \"Of\".\n",
    "with open(\"rel2sym.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym[rel] = sym[:-2] if sym.endswith('Of') else sym\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Natural-language rules: each line of natural_rules.txt is tab-separated; the\n",
    "# first field keys `rh2rules` and the last field is the rule sentence.\n",
    "# `rule_text` lists the sentences as \"L1: ...\", \"L2: ...\" and `lid` ends at\n",
    "# (number of rules + 1).\n",
    "rh2rules = dict()\n",
    "rule_text = ''\n",
    "lid = 1\n",
    "with open('natural_rules.txt', 'r') as f:\n",
    "    for line in f:\n",
    "        fields = line.strip().split('\\t')\n",
    "        rule_head, rule_sentence = fields[0], fields[-1]\n",
    "        rh2rules[rule_head] = rule_sentence\n",
    "        rule_text += 'L' + str(lid) + \": \" + rule_sentence + '\\n'\n",
    "        lid += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Assign consecutive integer ids to the entities listed in `path`.\n",
    "\n",
    "    Each line must have exactly two whitespace-separated fields; the second is\n",
    "    the entity name. Names already in `ent2id` are skipped. `id2ent` and\n",
    "    `ent2id` are updated in place.\n",
    "\n",
    "    Returns:\n",
    "        int: the next unused id (`eid` plus the number of newly added names).\n",
    "    \"\"\"\n",
    "    next_id = eid\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, name = line.strip().split()\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            id2ent[next_id] = name\n",
    "            ent2id[name] = next_id\n",
    "            next_id += 1\n",
    "    return next_id\n",
    "    \n",
    "def get_related_triplets(h, t, G, entpair2rel):\n",
    "    \"\"\"Verbalize every edge on the simple paths (at most 5 edges) from h to t in G.\n",
    "\n",
    "    Paths come from nx.all_simple_edge_paths(G, h, t, cutoff=5), in sorted order.\n",
    "    Each edge is written as \"<a> is <rel> of <b>.\" in the direction stored in\n",
    "    `entpair2rel`: the edge itself if it is a key there, otherwise its reverse.\n",
    "    An edge that lies on several paths is written once per path.\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    for edge_path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "        for edge in edge_path:\n",
    "            if edge in entpair2rel:\n",
    "                subj, obj, rel = edge[0], edge[1], entpair2rel[edge]\n",
    "            else:\n",
    "                subj, obj = edge[1], edge[0]\n",
    "                rel = entpair2rel[(subj, obj)]\n",
    "            lines.append(subj + ' is ' + rel + ' of ' + obj + '.\\n')\n",
    "    return ''.join(lines)\n",
    "\n",
    "def read_all_triplets(path1, path2, id2ent, text):\n",
    "    \"\"\"Load base facts and positive inferred facts, verbalizing both into `text`.\n",
    "\n",
    "    Both files hold lines \"<flag> <head_id> <rel_id> <tail_id>\". Every line of\n",
    "    `path1` is used; from `path2` only lines flagged '+' are used.\n",
    "    Relies on the notebook globals `id2rel` and `rel2sym`.\n",
    "\n",
    "    Returns:\n",
    "        triplets: (head, relation, tail) names for the `path1` facts only.\n",
    "        entpair2rel: (head, tail) -> relation name, for facts from both files.\n",
    "        text: the input `text` plus one \"<head> is <sym> of <tail>.\" line per fact.\n",
    "    \"\"\"\n",
    "    triplets = list()\n",
    "    entpair2rel = dict()\n",
    "\n",
    "    with open(path1, 'r') as f:\n",
    "        for line in f:\n",
    "            _, h_id, r_id, t_id = line.strip().split()\n",
    "            head = id2ent[int(h_id)]\n",
    "            rel = id2rel[int(r_id)]\n",
    "            tail = id2ent[int(t_id)]\n",
    "            triplets.append((head, rel, tail))\n",
    "            entpair2rel[(head, tail)] = rel\n",
    "            text += head + ' is ' + rel2sym[rel] + ' of ' + tail + '.\\n'\n",
    "\n",
    "    with open(path2, 'r') as f:\n",
    "        for line in f:\n",
    "            flag, h_id, r_id, t_id = line.strip().split()\n",
    "            if flag != '+':\n",
    "                continue\n",
    "            rel = id2rel[int(r_id)]\n",
    "            head = id2ent[int(h_id)]\n",
    "            tail = id2ent[int(t_id)]\n",
    "            entpair2rel[(head, tail)] = rel\n",
    "            text += head + ' is ' + rel2sym[rel] + ' of ' + tail + '.\\n'\n",
    "\n",
    "    return triplets, entpair2rel, text\n",
    "\n",
    "\n",
    "def read_class(path, cid, ent2class, id2ent, class_text):\n",
    "    \"\"\"Verbalize the gender of each entity listed in `path`.\n",
    "\n",
    "    Line k (two whitespace-separated fields) describes entity id2ent[cid + k]:\n",
    "    a first field of '1' means female, anything else male. Each entity gets a\n",
    "    line \"<name> is <rel2sym[gender]>.\" appended to `class_text` and its raw\n",
    "    gender ('female' / 'male') stored in `ent2class` (updated in place).\n",
    "    Relies on the notebook global `rel2sym`.\n",
    "\n",
    "    Returns:\n",
    "        (cid, class_text): the next entity id and the extended text.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, _male = line.strip().split()\n",
    "            gender = 'female' if female == '1' else 'male'\n",
    "            name = id2ent[cid]\n",
    "            # Record both genders; previously only male entities were stored, so\n",
    "            # indexing ent2class with a female entity raised KeyError.\n",
    "            ent2class[name] = gender\n",
    "            class_text += name + ' is ' + rel2sym[gender] + '.\\n'\n",
    "            cid += 1\n",
    "    return cid, class_text\n",
    "\n",
    "\n",
    "def read_all_facts(path1, path2, path_class, id2ent, ent2class, cid, text):\n",
    "    \"\"\"Build the numbered fact list shown to the model and the statements to explain.\n",
    "\n",
    "    Relies on the notebook globals `id2rel` and `rel2sym`.\n",
    "\n",
    "    Args:\n",
    "        path1: base facts, lines \"<flag> <head_id> <rel_id> <tail_id>\"; each becomes\n",
    "            \"F<k>: <head> is <sym> of <tail>.\" in `text`.\n",
    "        path2: inferred facts in the same format; lines flagged '+' become test\n",
    "            statements (they are not added to `text`).\n",
    "        path_class: one line per entity, starting at id `cid`; a first field of\n",
    "            '1' means female, anything else male. Each becomes\n",
    "            \"F<k>: <name> is <sym>.\" in `text` and is stored in `ent2class`.\n",
    "        id2ent: entity id -> name.\n",
    "        ent2class: updated in place with name -> gender symbol.\n",
    "        cid: id of the entity described by the first line of `path_class`.\n",
    "        text: prompt text to extend.\n",
    "\n",
    "    Returns:\n",
    "        (triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, f_id):\n",
    "        triplets holds (head, relation name, tail) for base and '+' facts;\n",
    "        test_triplets holds (head, relation symbol, tail) for '+' facts;\n",
    "        entpair2rel maps (head, tail) -> symbol and edges lists (head, tail),\n",
    "        both for base facts only; tri2number maps (head, symbol, tail) and\n",
    "        (name, 'gender', symbol) to their 'F<k>' id; cid and f_id are the next\n",
    "        unused entity id and fact number.\n",
    "    \"\"\"\n",
    "    f_id = 1\n",
    "    triplets = list()\n",
    "    test_triplets = list()\n",
    "    entpair2rel = dict()\n",
    "    edges = list()\n",
    "    tri2number = dict()\n",
    "\n",
    "    # Base facts: numbered from F1 and written into the prompt text.\n",
    "    with open(path1, 'r') as f:\n",
    "        for line in f:\n",
    "            _, h_id, r_id, t_id = line.strip().split()\n",
    "            head = id2ent[int(h_id)]\n",
    "            rel = id2rel[int(r_id)]\n",
    "            tail = id2ent[int(t_id)]\n",
    "            sym = rel2sym[rel]\n",
    "            fact_no = 'F' + str(f_id)\n",
    "            triplets.append((head, rel, tail))\n",
    "            edges.append((head, tail))\n",
    "            entpair2rel[(head, tail)] = sym\n",
    "            text += fact_no + \": \" + head + ' is ' + sym + ' of ' + tail + '.\\n'\n",
    "            tri2number[(head, sym, tail)] = fact_no\n",
    "            f_id += 1\n",
    "\n",
    "    # Inferred facts: '+' lines only, kept as statements to be explained.\n",
    "    with open(path2, 'r') as f:\n",
    "        for line in f:\n",
    "            flag, h_id, r_id, t_id = line.strip().split()\n",
    "            if flag != '+':\n",
    "                continue\n",
    "            head = id2ent[int(h_id)]\n",
    "            rel = id2rel[int(r_id)]\n",
    "            sym = rel2sym[rel]\n",
    "            tail = id2ent[int(t_id)]\n",
    "            test_triplets.append((head, sym, tail))\n",
    "            triplets.append((head, rel, tail))\n",
    "\n",
    "    # Gender facts continue the numbering after the base facts.\n",
    "    with open(path_class, 'r') as f:\n",
    "        for line in f:\n",
    "            female, _male = line.strip().split()\n",
    "            gender = rel2sym['female'] if female == '1' else rel2sym['male']\n",
    "            name = id2ent[cid]\n",
    "            fact_no = 'F' + str(f_id)\n",
    "            ent2class[name] = gender\n",
    "            text += fact_no + \": \" + name + ' is ' + gender + '.\\n'\n",
    "            tri2number[(name, 'gender', gender)] = fact_no\n",
    "            f_id += 1\n",
    "            cid += 1\n",
    "\n",
    "    return triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, f_id\n",
    "    \n",
    "def _match_rule_path(path, rule, entpair2rel, tri2number):\n",
    "    \"\"\"Return the fact ids supporting each hop of `path`, or None if a hop breaks `rule`.\n",
    "\n",
    "    Hop i is the node pair (path[i], path[i+1]). It matches when that ordered pair\n",
    "    is a fact whose relation symbol equals rule[i], or when rule[i] is\n",
    "    'inverse_parent' and the reversed pair (path[i+1], path[i]) is a fact.\n",
    "    NOTE(review): the reversed pair is accepted without checking its relation,\n",
    "    and rule symbols come from rel2sym -- confirm 'inverse_parent' is the symbol\n",
    "    read_rules actually produces for that relation.\n",
    "    \"\"\"\n",
    "    path_number = list()\n",
    "    for i in range(len(path) - 1):\n",
    "        src, dst = path[i], path[i + 1]\n",
    "        if entpair2rel.get((src, dst)) == rule[i]:\n",
    "            path_number.append(tri2number[(src, rule[i], dst)])\n",
    "        elif rule[i] == 'inverse_parent' and (dst, src) in entpair2rel:\n",
    "            path_number.append(tri2number[(dst, entpair2rel[(dst, src)], src)])\n",
    "        else:\n",
    "            # A single non-matching hop rejects the whole path. (Previously a forward\n",
    "            # fact with the wrong relation was silently skipped, so paths that did not\n",
    "            # follow the rule still produced partial ground-truth explanations.)\n",
    "            return None\n",
    "    return path_number\n",
    "\n",
    "\n",
    "def get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number):\n",
    "    \"\"\"Compute the ground-truth explanations for each test triplet.\n",
    "\n",
    "    For a test triplet (h, r, t) with rule body rule = rel2rules[r], every simple\n",
    "    path from h to t of at most len(rule) hops in the undirected fact graph that\n",
    "    matches the rule hop by hop (see _match_rule_path), when h's gender\n",
    "    ent2class[h] equals rule[-1], yields one explanation: the fact ids of its\n",
    "    hops followed by h's gender fact id.\n",
    "\n",
    "    `rule2number` is accepted for interface compatibility but not used.\n",
    "\n",
    "    Returns:\n",
    "        dict: (h, r, t) -> list of explanations, each a list of fact ids ('F<k>').\n",
    "    \"\"\"\n",
    "    # Undirected, so a path may traverse a fact against its stored direction.\n",
    "    G = nx.Graph()\n",
    "    G.add_edges_from(edges)\n",
    "\n",
    "    fact2explain = dict()\n",
    "    for h, r, t in test_triplets:\n",
    "        rule = rel2rules[r]\n",
    "        all_paths = list()\n",
    "        for path in nx.all_simple_paths(G, source=h, target=t, cutoff=len(rule)):\n",
    "            path_number = _match_rule_path(path, rule, entpair2rel, tri2number)\n",
    "            if path_number and ent2class[h] == rule[-1]:\n",
    "                path_number.append(tri2number[(h, 'gender', ent2class[h])])\n",
    "                all_paths.append(path_number)\n",
    "        fact2explain[(h, r, t)] = all_paths\n",
    "    return fact2explain\n",
    "\n",
    "\n",
    "def read_rules(path, rel2sym):\n",
    "    \"\"\"Parse tab-separated rules, translating every field through `rel2sym`.\n",
    "\n",
    "    Line k holds the rule head followed by its body relations, all tab-separated.\n",
    "\n",
    "    Returns:\n",
    "        rel2rules: head symbol -> list of body symbols.\n",
    "        rule2number: head symbol -> 'L<k>'.\n",
    "        l_id: number of lines read + 1.\n",
    "    \"\"\"\n",
    "    rel2rules = dict()\n",
    "    rule2number = dict()\n",
    "    l_id = 1\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            head, *body = [rel2sym[token] for token in line.strip().split('\\t')]\n",
    "            rel2rules[head] = body\n",
    "            rule2number[head] = 'L' + str(l_id)\n",
    "            l_id += 1\n",
    "    return rel2rules, rule2number, l_id\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot examples: a text file of blank-line-separated blocks, each with a\n",
    "# \"Statement: ...\" line followed by \"Answer: ...\".\n",
    "\n",
    "def read_few_shot_prompts(path):\n",
    "    \"\"\"Parse the few-shot examples in `path`.\n",
    "\n",
    "    The file is split on blank lines; whitespace-only blocks (e.g. from a blank\n",
    "    line at the end of the file) are skipped. In each block the statement is the\n",
    "    rest of the \"Statement: \" line and the answer is everything after\n",
    "    \"Answer: \", including later lines.\n",
    "\n",
    "    Returns:\n",
    "        prompts: list of {'Statement': 'Statement: ...', 'Answer': 'Answer: ...'}.\n",
    "        prompts_plus: the same examples with \"Let's think step by step. \" placed\n",
    "            before each answer.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: a non-empty block has no \"Statement: \" line or no \"Answer: \".\n",
    "    \"\"\"\n",
    "    prompts = list()\n",
    "    prompts_plus = list()\n",
    "    with open(path, 'r') as f:\n",
    "        blocks = f.read().split('\\n\\n')\n",
    "    for block in blocks:\n",
    "        if not block.strip():\n",
    "            continue\n",
    "        # Statement: the text after \"Statement: \" up to the end of that line.\n",
    "        statement_match = re.search(r'Statement: (.*?)\\n', block)\n",
    "        # Answer: everything after \"Answer: \", newlines included.\n",
    "        answer_match = re.search(r'Answer: (.*)', block, flags=re.DOTALL)\n",
    "        if statement_match is None or answer_match is None:\n",
    "            raise ValueError('malformed few-shot block in %s: %r' % (path, block))\n",
    "        statement = statement_match.group(1)\n",
    "        answer = answer_match.group(1)\n",
    "        prompts.append({'Statement': \"Statement: \" + statement,\n",
    "                        'Answer': \"Answer: \" + answer})\n",
    "        prompts_plus.append({'Statement': \"Statement: \" + statement,\n",
    "                             'Answer': \"Answer: Let's think step by step. \" + answer})\n",
    "    print(prompts)\n",
    "    return prompts, prompts_plus\n",
    "# Load the few-shot examples; the evaluation cell uses prompts[0]..prompts[4]\n",
    "# (prompts_plus is not used by the evaluation cell).\n",
    "prompts, prompts_plus = read_few_shot_prompts('few_shot_prompts_f.txt')\n",
    "len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation: for every positive inferred fact, ask the model to select the\n",
    "# facts that explain it, then compare the selected fact ids with the ground truth.\n",
    "import time  # backoff between API retries\n",
    "\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/natural_few_shot_f'  # named log_dir so the builtin dir() is not shadowed\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "model = 'gpt-4'\n",
    "# model = \"gpt-3.5-turbo\"\n",
    "MAX_RETRIES = 10  # API attempts per statement before it is scored as incorrect\n",
    "N_SHOTS = 5       # few-shot examples taken from `prompts`\n",
    "\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "for i in range(0, 1):\n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "    class_text = ''\n",
    "    text = ''\n",
    "\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_rel1 = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    path_rel2 = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    path_rule = 'rule_tab.txt'\n",
    "\n",
    "    eid = read_entity(path_ent, eid, id2ent, ent2id)\n",
    "    triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, fid = read_all_facts(path_rel1, path_rel2, path_class, id2ent, ent2class, cid, text)\n",
    "    rel2rules, rule2number, lid = read_rules(path_rule, rel2sym)\n",
    "    fact2explain = get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number)\n",
    "\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "    num = 0\n",
    "    record_flag = False  # the full prompts are logged once, for the first answered statement\n",
    "\n",
    "    for triple in tqdm(test_triplets):\n",
    "        h, r, t = triple\n",
    "        text_pred = h + ' is ' + r + ' of ' + t + '.'\n",
    "\n",
    "        # Step 1: facts F1..F(fid-1), N_SHOTS worked examples, then the statement.\n",
    "        message_1 = {\n",
    "            'system': \"You are a helpful assistant with abductive reasoning abilities. Please select a few facts to explain the following statement. \",\n",
    "            'user': \"I will provide a set of facts F1 to F\" + str(fid - 1) + \". Please select a few facts from F1 to F\" + str(fid - 1) + \" to explain the following statement. \"\n",
    "            + \"\\nFacts:\\n\" + text,\n",
    "        }\n",
    "        for k in range(1, N_SHOTS + 1):\n",
    "            message_1['Q' + str(k)] = prompts[k - 1]['Statement']\n",
    "            message_1['A' + str(k)] = prompts[k - 1]['Answer']\n",
    "        message_1['Q7'] = \"Statement: \" + text_pred + '\\nAnswer: '\n",
    "\n",
    "        chat_1 = [\n",
    "            {\"role\": \"system\", \"content\": message_1['system']},\n",
    "            {\"role\": \"user\", \"content\": message_1['user']},\n",
    "        ]\n",
    "        for k in range(1, N_SHOTS + 1):\n",
    "            chat_1.append({\"role\": \"user\", \"content\": message_1['Q' + str(k)]})\n",
    "            chat_1.append({\"role\": \"assistant\", \"content\": message_1['A' + str(k)]})\n",
    "        chat_1.append({\"role\": \"user\", \"content\": message_1['Q7']})\n",
    "        # Step-1 transcript as one string, replayed in the step-2 prompt.\n",
    "        last_text = message_1['user'] + ''.join(\n",
    "            message_1['Q' + str(k)] + message_1['A' + str(k)] for k in range(1, N_SHOTS + 1)\n",
    "        ) + message_1['Q7']\n",
    "\n",
    "        results = None\n",
    "        server_error_cnt = 0\n",
    "        while server_error_cnt < MAX_RETRIES:\n",
    "            try:\n",
    "                # update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=chat_1,\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results_1 = response['choices'][0]['message']['content']\n",
    "                # Step 2: ask only for the ids of the facts selected in step 1.\n",
    "                message_2 = {\n",
    "                    'system': \"Please output the numbers (e.g. L1, F1, F2) of the selected logical rule and facts. \",\n",
    "                    'user': last_text + '\\n' + results_1 + \"\\nTherefore, the selected facts are: \",\n",
    "                }\n",
    "                # update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_2['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_2['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                break\n",
    "            except Exception as e:\n",
    "                server_error_cnt += 1\n",
    "                print(e)\n",
    "                if server_error_cnt < MAX_RETRIES:\n",
    "                    time.sleep(min(2 ** server_error_cnt, 60))\n",
    "\n",
    "        # Count each statement exactly once; `num` used to be incremented both\n",
    "        # inside the retry loop and here, which halved the reported accuracy.\n",
    "        num += 1\n",
    "        if results is None:\n",
    "            # Every attempt failed: score as incorrect instead of re-using a\n",
    "            # stale `response` left over from an earlier statement.\n",
    "            logger.warning('no response after %d attempts for statement: %s', MAX_RETRIES, text_pred)\n",
    "            false_num += 1\n",
    "            continue\n",
    "\n",
    "        if not record_flag:\n",
    "            logger.info('message_1: \\n' + dict2str(message_1))\n",
    "            logger.info('message_2: \\n' + dict2str(message_2))\n",
    "            record_flag = True\n",
    "\n",
    "        number_list = re.findall(r'[FL]\\d+', results)\n",
    "        logger.info(\"statement: \" + text_pred)\n",
    "        logger.info('LLM: %s', number_list)\n",
    "        logger.info(\"grounding_truth: %s\", list2str(fact2explain[(h, r, t)]))\n",
    "        if list_equal(number_list, fact2explain[(h, r, t)]):\n",
    "            logger.info(\"correct\")\n",
    "            true_num += 1\n",
    "        else:\n",
    "            false_num += 1\n",
    "\n",
    "    if num:\n",
    "        logger.info(\"accuracy: \" + str(true_num / num))\n",
    "    else:\n",
    "        logger.info(\"accuracy: n/a (no test statements)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
