{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "\n",
    "\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys_filter.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split()\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "# extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "# for rel in extra_relations:\n",
    "#     id2rel[rid] = rel\n",
    "#     rel2id[rel] = rid\n",
    "#     rid += 1\n",
    "\n",
    "with open(\"rel2sym.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "      \n",
    "        rel2sym[rel] = sym "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id ):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, ent = line.strip().split()\n",
    "            if ent not in ent2id:\n",
    "                id2ent[eid] = ent\n",
    "                ent2id[ent] = eid\n",
    "                eid += 1\n",
    "                \n",
    "        return eid\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                # ent2class[id2ent[cid]] = 'female'\n",
    "\n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym['female'] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                # ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym['male'] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "\n",
    "            cid += 1\n",
    "        return cid, class_text, fid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "# def get_rule_text(rule_heads, rule2text):\n",
    "#     rule_text = ''\n",
    "#     for rule_head in rule_heads:\n",
    "#         rule_text += rule2text[rule_head]\n",
    "#     return rule_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rules = ''\n",
    "rel2rule = dict()\n",
    "lid = 1\n",
    "with open(\"latex_rules.txt\", 'r') as f1:\n",
    "    for line in f1:\n",
    "        rel, rule = line.strip().split('\\t')\n",
    "        rel2rule[rel] = rule\n",
    "        rules += 'L' + str(lid) + ': ' + rule + '\\n'\n",
    "        lid += 1\n",
    "\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s\n",
    "def get_negative_samples(triplets, id2ent, id2rel, labels):\n",
    "    neg_samples = []\n",
    "    # random sample head or tail\n",
    "    \n",
    "    \n",
    "    for i in range(len(triplets)):\n",
    "        \n",
    "        while 1:\n",
    "            if random.random() < 0.5:\n",
    "                # sample head\n",
    "                h = random.randint(0, len(id2ent) - 1)\n",
    "                t = triplets[i][2]\n",
    "                r = triplets[i][1]\n",
    "                if (id2ent[int(h)], r, t) not in triplets:\n",
    "                    neg_samples.append((id2ent[int(h)], r, t))\n",
    "                    labels[(id2ent[int(h)], r, t)] = 0\n",
    "                    break\n",
    "            else:\n",
    "                h = triplets[i][0]\n",
    "                t = random.randint(0, len(id2ent) - 1)\n",
    "                r = triplets[i][1]\n",
    "                if (h, r, id2ent[int(t)]) not in triplets:\n",
    "                    neg_samples.append((h, r, id2ent[int(t)]))\n",
    "                    labels[(h, r, id2ent[int(t)])]=0\n",
    "                    break\n",
    "\n",
    "    return neg_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute F beta[0,0.1:1] score\n",
    "def compute_f_beta_score(precision, recall, logger):\n",
    "    for beta in range(0, 11):\n",
    "        beta = beta / 10\n",
    "        if precision == 0 and recall == 0:\n",
    "            logger.info('beta: '+ str(beta) + '\\rF score: 0')\n",
    "            \n",
    "        else:\n",
    "            score = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)\n",
    "            logger.info('beta: '+ str(beta) + '\\tF score: ' + str(score))\n",
    "            print('beta: '+ str(beta) + '\\tF score: ' + str(score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "dir = 'logs/first_order_zero_shot_cot'\n",
    "if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "model = \"gpt-3.5-turbo\"\n",
    "# model = \"gpt-4\"\n",
    "logging.info('model: ' + model)\n",
    "\n",
    "record_flag = False\n",
    "\n",
    "for i in range(0, 10):\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "\n",
    "    eid = 0\n",
    "\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    class_text = ''\n",
    "    triplets = []\n",
    "    test_triplets = []\n",
    "    labels = dict()\n",
    "    entpair2rel = dict() \n",
    "    basic_facts = ''\n",
    "    statement = ''\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    eid = read_entity(path_ent,eid, id2ent,ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class,id2ent, class_text, fid)\n",
    "    # print(i)\n",
    "    path = \"../symbolic_tree/\"+str(i)+\".relations.data\"\n",
    "    basic_facts += class_text\n",
    "    with open(path,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            basic_facts += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "            fid += 1\n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "            # text += (id2rel[int(r)] + '(' + id2ent[int(h)], id2ent[int(t)] +')')\n",
    "    path = \"../symbolic_tree/\"+str(i)+\".relations.data.inf\"\n",
    "\n",
    "    with open(path,'r') as f:\n",
    "        for line in tqdm(f):\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                labels[(id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])] = 1\n",
    "    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)\n",
    "            # if flag == '-':\n",
    "            #     test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            #     labels.append(0)\n",
    "\n",
    "    num = 0\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "\n",
    "    pos_true = 0\n",
    "    pos_false = 0\n",
    "\n",
    "    neg_true = 0\n",
    "    neg_false = 0\n",
    "\n",
    "    predicted_facts = test_triplets + negative_samples\n",
    "    # random order in predicted_facts\n",
    "    random.shuffle(predicted_facts)\n",
    "    for triple in tqdm(predicted_facts):\n",
    "        h, r, t = triple\n",
    "        statement = rel2sym[r] + '(' + h + ', ' + t + ')'\n",
    "        \n",
    "        server_error_cnt = 0\n",
    "        while server_error_cnt<10:\n",
    "            try:\n",
    "                update_key()\n",
    "                # message_1 = {\n",
    "                #     'system': \"You are a helpful assistant with deductive reasoning abilities.\",\n",
    "                #     'user': \"I will provide a set of logical rules L1 to L\" + str(lid - 1) + \" and facts F1 to F\" + str(fid - 1) + \". Please select one single logical rule from L1 to L\" + str(lid - 1) + \" and a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the following statement using deductive reasoning.\\nLogical rules:\\n\" + rules + \"\\nFacts:\\n\" + basic_facts + \"\\nStatement: \" + statement + \"\\nAnswer with only True or False? Let's think step by step.\"\n",
    "                \n",
    "                # }\n",
    "                update_key()\n",
    "                message_1 = {\n",
    "                    'system': \"Please select one single logical rule and a few facts to predict True/False of the following statement.\",\n",
    "                    'user': \"I will provide a set of logical rules L1 to L\" + str(lid - 1) + \" and facts F1 to F\" + str(fid - 1) + \". Please select one single logical rule from L1 to L\" + str(lid - 1) + \" and a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the following statement.\\nLogical rules:\\n\" + rules + \"\\nFacts:\\n\" + basic_facts + \"\\nStatement: \" + statement + \"\\nLet's think step by step.\"\n",
    "                \n",
    "                }\n",
    "                response = openai.ChatCompletion.create(\n",
    "                model= model,\n",
    "                messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_1['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['user']},\n",
    "                ],\n",
    "                temperature=0,\n",
    "                # max_tokens = 2096,\n",
    "                )\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                message_2 = {\n",
    "                    'system': \"Please predict True/False of the following statement.\",\n",
    "                    'user': message_1['user'] + '\\n' + results + '\\nTherefore, the answer (True or False) is: '\n",
    "                }\n",
    "                response = openai.ChatCompletion.create(\n",
    "                model= model,\n",
    "                messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_2['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_2['user']},\n",
    "                ],\n",
    "                temperature=0,\n",
    "                # max_tokens = 2096,\n",
    "                )\n",
    "                if record_flag == False:\n",
    "                    logger.info('message_1: \\n' + dict2str(message_1))\n",
    "                    logger.info('message_2: \\n' + dict2str(message_2))\n",
    "\n",
    "                    record_flag = True\n",
    "\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                num += 1\n",
    "\n",
    "                if len(results.split('.')) >= 1:\n",
    "                    last_sentence = results.split('.')[0]\n",
    "                elif len(results.split('\\n')) >= 1:\n",
    "                    last_sentence = results.split('\\n')[0]\n",
    "                else:\n",
    "                    last_sentence = results\n",
    "                    logger.info('output: ' + results)\n",
    "                false_words = ['indeterminate', 'FALSE', 'Unknown', 'unknown', 'not', 'False', ' no ', \"inconclusive\", \"undefined\", \"invalid\", 'false']\n",
    "\n",
    "                # if last sentence contain one of false words\n",
    "                \n",
    "                if labels[(h, r, t)] == 1:\n",
    "                    if any(word in last_sentence for word in false_words):\n",
    "                        false_num += 1\n",
    "                        pos_false += 1\n",
    "                        logger.info('correctness: ' + 'Incorrect')\n",
    "                    elif 'True' in last_sentence or 'true' in last_sentence or \"TRUE\" in last_sentence:\n",
    "                        true_num += 1\n",
    "                        pos_true += 1\n",
    "                        logger.info('correctness: ' + 'Correct')\n",
    "\n",
    "                    else:\n",
    "                        # false_num += 1\n",
    "                        # pos_false += 1\n",
    "                        # logger.info('correctness: ' + 'Incorrect')\n",
    "\n",
    "                        print(last_sentence)\n",
    "                else:\n",
    "                    if any(word in last_sentence for word in false_words):\n",
    "                        true_num += 1\n",
    "                        neg_true += 1\n",
    "                        logger.info('correctness: ' + 'Correct')\n",
    "                        \n",
    "                    elif 'True' in last_sentence or 'true' in last_sentence or \"TRUE\" in last_sentence:\n",
    "                        false_num += 1\n",
    "                        neg_false += 1\n",
    "                        logger.info('correctness: ' + 'Incorrect')\n",
    "\n",
    "                    else:\n",
    "                        # true_num += 1\n",
    "                        # neg_true += 1\n",
    "                        # logger.info('correctness: ' + 'Correct')\n",
    "\n",
    "                        print(last_sentence)\n",
    "                \n",
    "                \n",
    "                logger.info('statement: ' + statement + '\\tgrounding truth: ' + str(labels[(h, r, t)]) + '\\tprediction: ' + results )\n",
    "                \n",
    "                \n",
    "                break\n",
    "\n",
    "            except Exception as e:\n",
    "                server_error_cnt += 1\n",
    "                logger.info(e)\n",
    "    logger.info(str(i) + ': ' + str(true_num / num))\n",
    "    logger.info('pos_acc: ' + str(pos_true / (pos_true + pos_false)))\n",
    "    logger.info('neg_acc: ' + str(neg_true / (neg_true + neg_false)))\n",
    "    TP = pos_true\n",
    "    FN = pos_false\n",
    "    FP = neg_false\n",
    "    TN = neg_true\n",
    "    logger.info('precision: ' + str(TP / (TP + FP)))\n",
    "    logger.info('recall: ' + str(TP / (TP + FN)))\n",
    "    compute_f_beta_score(TP / (TP + FP), TP / (TP + FN), logger)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 正则表达式提取倒数第二行\n",
    "import re\n",
    "def get_last_line(text):\n",
    "    # last_line = text.strip().split('\\n')[-1]\n",
    "    second_last_line = re.findall('.+', text)[0]\n",
    "    # last_sentence = re.findall(r\"(.*\\n){2}$\", text)\n",
    "    return last_sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = ''' \n",
    "Now, we can apply L3 with A=Alina and B=Elena to get motherOf(Alina, Elena). Therefore, the statement \"motherOf(Alina, Elena)\" is true. \n",
    "\n",
    "The logical rule used is L3 and the relevant facts are F36, F46, F49.\n",
    "\n",
    "'''\n",
    "get_last_line(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
