{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import logging\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "import time\n",
    "from random import sample\n",
    "\n",
    "import networkx as nx\n",
    "import openai\n",
    "from tqdm import tqdm\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary. Relations from the symbolic-tree schema get ids\n",
    "# 0..k-1 in file order (the first column of the file is ignored); the extra\n",
    "# derived relations and the two gender classes follow consecutively.\n",
    "extra_relations = [\"greatAuntUncleOf\",\"grandparentOf\",\"greatGrandparentOf\",\"auntUncleOf\",\"siblingOf\",\"secondAuntUncleOf\",\"childOf\",\"grandchildOf\",\"greatGrandchildOf\",\"nieceNephewOf\",\"cousinOf\",\"secondCousinOf\",\"firstCousinOnceRemovedOf\", \"male\", \"female\"]\n",
    "\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    schema_relations = []\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        schema_relations.append(rel)\n",
    "\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "for rid, rel in enumerate(schema_relations + extra_relations):\n",
    "    id2rel[rid] = rel\n",
    "    rel2id[rel] = rid\n",
    "rid = len(schema_relations) + len(extra_relations)  # next free relation id\n",
    "\n",
    "# Relation name -> LaTeX symbol wrapped in math mode ($...$). Drop the dollar\n",
    "# signs in the expression below to use the bare symbol instead.\n",
    "with open(\"../rel2sym_ri.txt\",\"r\") as fr:\n",
    "    rel2sym = {rel: '$' + sym + '$' for rel, sym in (line.strip().split() for line in fr)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Return a logger that writes formatted records to `filename` only.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the log file; an existing file is truncated.\n",
    "        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.\n",
    "        name: logger name; None configures the root logger.\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers (e.g. from re-running this cell). Iterate\n",
    "    # over a copy: removing from logger.handlers while iterating over it skips\n",
    "    # every other handler. Close each one so its log file is released.\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Register the entities listed in `path` and return the next free id.\n",
    "\n",
    "    Every line holds two whitespace-separated fields, `<first field> <entity\n",
    "    name>`; the first field is ignored. Names already present in `ent2id`\n",
    "    keep their id, new names get consecutive ids starting at `eid`. `id2ent`\n",
    "    and `ent2id` are updated in place.\n",
    "    \"\"\"\n",
    "    next_id = eid\n",
    "    with open(path, 'r') as entity_file:\n",
    "        for _first, name in (line.strip().split() for line in entity_file):\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            ent2id[name] = next_id\n",
    "            id2ent[next_id] = name\n",
    "            next_id += 1\n",
    "    return next_id\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    \"\"\"Record the gender symbol of consecutive entities listed in `path`.\n",
    "\n",
    "    Line k of the file holds two flags, `<female> <male>`, for the entity\n",
    "    `id2ent[cid + k]`. A female flag of '1' maps that entity to\n",
    "    rel2sym['female']; any other value maps it to rel2sym['male'].\n",
    "    `ent2class` is updated in place; `class_text` and `fid` are returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (cid advanced by the number of lines, class_text, fid)\n",
    "    \"\"\"\n",
    "    next_cid = cid\n",
    "    with open(path, 'r') as class_file:\n",
    "        for line in class_file:\n",
    "            female_flag, _male_flag = line.strip().split()\n",
    "            gender = 'female' if female_flag == '1' else 'male'\n",
    "            ent2class[id2ent[next_cid]] = rel2sym[gender]\n",
    "            next_cid += 1\n",
    "    return next_cid, class_text, fid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "# def get_rule_text(rule_heads, rule2text):\n",
    "#     rule_text = ''\n",
    "#     for rule_head in rule_heads:\n",
    "#         rule_text += rule2text[rule_head]\n",
    "#     return rule_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rule bank: each line of latex_rules.txt is <relation symbol><TAB><LaTeX rule>.\n",
    "# rel2rule maps the symbol to its rule; `rules` lists every rule as L1, L2, ...\n",
    "# Note: this reuses the global `rid`, which ends as (number of rules + 1).\n",
    "rel2rule = dict()\n",
    "numbered_rules = []\n",
    "with open(\"latex_rules.txt\", 'r') as f1:\n",
    "    for rid, line in enumerate(f1, start=1):\n",
    "        rel, rule = line.strip().split('\\t')\n",
    "        rel2rule[rel] = rule\n",
    "        numbered_rules.append('L' + str(rid) + ': ' + rule + '\\n')\n",
    "rid = len(numbered_rules) + 1\n",
    "rules = ''.join(numbered_rules)\n",
    "\n",
    "def dict2str(d):\n",
    "    \"\"\"Render `d` as one tab-separated `key<TAB>value` line per entry.\n",
    "\n",
    "    Each entry ends with a newline rather than a bare carriage return, which\n",
    "    terminals and many log viewers render by overprinting the previous entry.\n",
    "    \"\"\"\n",
    "    return ''.join(str(k) + '\\t' + str(d[k]) + '\\n' for k in d)\n",
    "def get_negative_samples(triplets, id2ent, id2rel, labels, known_triplets=None):\n",
    "    \"\"\"Create one corrupted (negative) triplet per triplet in `triplets`.\n",
    "\n",
    "    For each (h, r, t) the head (probability 0.5) or else the tail is replaced\n",
    "    by an entity drawn uniformly from `id2ent`, retrying until the candidate\n",
    "    is not a positive, not in `known_triplets` and not an already sampled\n",
    "    negative. Every negative is also written to `labels` with value 0.\n",
    "    Loops forever if every corruption of some triplet is excluded (e.g. when\n",
    "    there are very few entities).\n",
    "\n",
    "    Args:\n",
    "        triplets: positive (head, relation, tail) triplets of entity names.\n",
    "        id2ent: dict keyed by 0..len(id2ent)-1 (ids are drawn with randint).\n",
    "        id2rel: unused; kept for call compatibility.\n",
    "        labels: dict updated in place with the negatives.\n",
    "        known_triplets: optional extra true triplets (e.g. observed facts)\n",
    "            that must never be returned as negatives.\n",
    "\n",
    "    Returns:\n",
    "        List of negatives, aligned with `triplets`.\n",
    "    \"\"\"\n",
    "    # A set gives O(1) membership tests instead of scanning a list per draw.\n",
    "    forbidden = set(triplets)\n",
    "    if known_triplets is not None:\n",
    "        forbidden.update(known_triplets)\n",
    "\n",
    "    neg_samples = []\n",
    "    for h, r, t in triplets:\n",
    "        while True:\n",
    "            if random.random() < 0.5:\n",
    "                # sample head\n",
    "                candidate = (id2ent[random.randint(0, len(id2ent) - 1)], r, t)\n",
    "            else:\n",
    "                # sample tail\n",
    "                candidate = (h, r, id2ent[random.randint(0, len(id2ent) - 1)])\n",
    "            if candidate not in forbidden:\n",
    "                break\n",
    "        forbidden.add(candidate)  # never emit the same negative twice\n",
    "        neg_samples.append(candidate)\n",
    "        labels[candidate] = 0\n",
    "    return neg_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_f_beta_score(precision, recall, logger):\n",
    "    \"\"\"Log F-beta for beta = 0.0, 0.1, ..., 1.0 and return the F1 score.\n",
    "\n",
    "    F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). When that denominator\n",
    "    is 0 (e.g. precision == recall == 0) the score is taken to be 0, so there\n",
    "    is never a division by zero and there is always a score to return.\n",
    "\n",
    "    Args:\n",
    "        precision: precision P.\n",
    "        recall: recall R.\n",
    "        logger: receives one tab-separated 'beta: ... F score: ...' record\n",
    "            per beta.\n",
    "\n",
    "    Returns:\n",
    "        The score for beta = 1.0, i.e. F1.\n",
    "    \"\"\"\n",
    "    score = 0\n",
    "    for step in range(0, 11):\n",
    "        beta = step / 10\n",
    "        denominator = beta * beta * precision + recall\n",
    "        if denominator == 0:\n",
    "            score = 0\n",
    "        else:\n",
    "            score = (1 + beta * beta) * precision * recall / denominator\n",
    "        logger.info('beta: ' + str(beta) + '\\tF score: ' + str(score))\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_related_triplets(h, t, G, entpair2rel, ent2class, fid, cutoff=6):\n",
    "    \"\"\"Number the facts that connect `h` to `t`, for use in a prompt.\n",
    "\n",
    "    The first fact is the class of `h`: 'F<fid>: <class>(<h>)'. It is followed\n",
    "    by every edge on every simple path of at most `cutoff` edges between `h`\n",
    "    and `t` in `G`, each emitted once as 'F<k>: <rel>(<u>,<v>)' in the\n",
    "    orientation under which it is stored in `entpair2rel`. Paths are sorted,\n",
    "    so the numbering is deterministic.\n",
    "\n",
    "    Args:\n",
    "        h, t: entity names.\n",
    "        G: undirected graph over entity names (nx.Graph in this notebook).\n",
    "        entpair2rel: dict (u, v) -> relation symbol; every edge of G must be\n",
    "            present in at least one orientation.\n",
    "        ent2class: dict entity -> class symbol.\n",
    "        fid: number given to the first fact.\n",
    "        cutoff: maximum path length in edges.\n",
    "\n",
    "    Returns:\n",
    "        (facts text, one fact per line; next unused fact number)\n",
    "    \"\"\"\n",
    "    selected_facts = 'F' + str(fid) + \": \" + ent2class[h] + '(' + h + ')\\n'\n",
    "    fid += 1\n",
    "    if h not in G or t not in G:\n",
    "        # No connecting path can exist; networkx would raise NodeNotFound\n",
    "        # for a source node that is missing from the graph.\n",
    "        return selected_facts, fid\n",
    "\n",
    "    seen = set()  # (u, rel, v) triplets already emitted\n",
    "    for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=cutoff)):\n",
    "        for u, v in path:\n",
    "            if (u, v) not in entpair2rel:\n",
    "                u, v = v, u  # the fact is stored as (v, u)\n",
    "            rel = entpair2rel[(u, v)]\n",
    "            if (u, rel, v) in seen:\n",
    "                continue\n",
    "            seen.add((u, rel, v))\n",
    "            selected_facts += 'F' + str(fid) + \": \" + rel + '(' + u + ',' + v + ')\\n'\n",
    "            fid += 1\n",
    "    return selected_facts, fid\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_openai_keys():\n",
    "    \"\"\"Return the API keys listed one per line in ../openai_keys.txt.\n",
    "\n",
    "    Surrounding whitespace is stripped and blank lines are skipped, so a\n",
    "    trailing newline in the file does not yield an empty key that\n",
    "    update_key() would later install.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file contains no key.\n",
    "    \"\"\"\n",
    "    with open('../openai_keys.txt', \"r\") as f:\n",
    "        keys = [line.strip() for line in f]\n",
    "    keys = [key for key in keys if key]\n",
    "    if not keys:\n",
    "        raise ValueError('no API key found in ../openai_keys.txt')\n",
    "    return keys\n",
    "\n",
    "\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "\n",
    "def update_key():\n",
    "    \"\"\"Install the next key as openai.api_key (round-robin over all keys).\n",
    "\n",
    "    The installed key is moved to the end of `openai_api_keys`.\n",
    "    \"\"\"\n",
    "    curr_key = openai_api_keys.pop(0)\n",
    "    openai_api_keys.append(curr_key)\n",
    "    openai.api_key = curr_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One log file per run: logs/first_order_subgraph/<timestamp>.log\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "log_dir = 'logs/first_order_subgraph'  # not `dir`: that would shadow the builtin\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "model = \"gpt-3.5-turbo\"\n",
    "# model = \"gpt-4\"\n",
    "logger.info('model: ' + model)  # through the run logger, like every other record\n",
    "\n",
    "# Evaluate the model on every tree in TREE_IDS. Positives are the inferred\n",
    "# facts flagged '+' in <i>.relations.data.inf, negatives are corrupted copies\n",
    "# of them; every prompt, answer and the per-tree metrics go to `logger`.\n",
    "SEED = 0           # seeds negative sampling and shuffling (reproducible runs)\n",
    "MAX_RETRIES = 10   # API attempts per statement before it is skipped\n",
    "TREE_IDS = range(0, 1)\n",
    "\n",
    "\n",
    "def _ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or NaN when the denominator is 0.\"\"\"\n",
    "    return numerator / denominator if denominator else float('nan')\n",
    "\n",
    "\n",
    "random.seed(SEED)\n",
    "F1_score = 0\n",
    "for i in TREE_IDS:\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    triplets = []\n",
    "    test_triplets = []\n",
    "    labels = dict()\n",
    "    entpair2rel = dict()\n",
    "\n",
    "    # Entities (name by id) and their gender class.\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    eid = read_entity(path_ent, 0, id2ent, ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, 0, ent2class, id2ent, '', 1)\n",
    "\n",
    "    # Observed facts: they form the graph that supporting facts are taken from.\n",
    "    path = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "\n",
    "    # Inferred facts flagged '+' are the positive test statements.\n",
    "    path = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in tqdm(f):\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                fact = (id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])\n",
    "                test_triplets.append(fact)\n",
    "                labels[fact] = 1\n",
    "    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)\n",
    "\n",
    "    num = 0        # statements that received an answer\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "    pos_true = 0   # TP\n",
    "    pos_false = 0  # FN\n",
    "    neg_true = 0   # TN\n",
    "    neg_false = 0  # FP\n",
    "    skipped = 0    # statements dropped after MAX_RETRIES failed API calls\n",
    "\n",
    "    predicted_facts = test_triplets + negative_samples\n",
    "    random.shuffle(predicted_facts)\n",
    "    G = nx.Graph()\n",
    "    for tri in triplets:\n",
    "        G.add_edge(tri[0], tri[2])\n",
    "\n",
    "    for triple in tqdm(predicted_facts):\n",
    "        h, r, t = triple\n",
    "        statement = rel2sym[r] + '(' + h + ', ' + t + ')'\n",
    "        selected_facts, fid = get_related_triplets(h, t, G, entpair2rel, ent2class, 1)\n",
    "        message = {\n",
    "            'system': \"You are a helpful assistant with deductive reasoning abilities. \",\n",
    "            'user': \"I will provide a single rule and facts F1 to F\" + str(fid - 1) + \". Please use the rule and select a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the unknown fact.\\nLogical rule:\\n\" + rel2rule[rel2sym[r][1:-1]] + \"\\nFacts:\\n\" + selected_facts + \"\\nUnknown fact: \" + statement + \"\\nThe answer (True or False) is: \"\n",
    "        }\n",
    "        logger.info('message: \\n' + dict2str(message))\n",
    "\n",
    "        # Only the API call is retried; each attempt uses the next key and\n",
    "        # waits with exponential back-off so rate limits can clear.\n",
    "        results = None\n",
    "        for attempt in range(MAX_RETRIES):\n",
    "            try:\n",
    "                update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message['system']},\n",
    "                        {\"role\": \"user\", \"content\": message['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                break\n",
    "            except Exception as e:\n",
    "                logger.info('API error (attempt ' + str(attempt + 1) + '/' + str(MAX_RETRIES) + '): ' + repr(e))\n",
    "                if attempt + 1 < MAX_RETRIES:\n",
    "                    time.sleep(min(2 ** attempt, 30))\n",
    "        if results is None:\n",
    "            skipped += 1\n",
    "            logger.warning('skipped after ' + str(MAX_RETRIES) + ' failed API calls: ' + statement)\n",
    "            continue\n",
    "\n",
    "        num += 1\n",
    "        is_positive = labels[(h, r, t)] == 1\n",
    "        # Only the first sentence is parsed; 'True' is checked first, and\n",
    "        # 'Unknown' counts as predicting False.\n",
    "        ans = results.split('.')[0]\n",
    "        if 'True' in ans:\n",
    "            predicted_true = True\n",
    "        elif 'False' in ans:\n",
    "            predicted_true = False\n",
    "        elif 'Unknown' in ans:\n",
    "            predicted_true = False\n",
    "            print(results)\n",
    "        else:\n",
    "            predicted_true = None\n",
    "\n",
    "        if predicted_true is None:\n",
    "            # Still counted in `num` (lowers accuracy) but in no confusion cell.\n",
    "            logger.warning('unparsable answer: ' + results)\n",
    "        else:\n",
    "            correct = predicted_true == is_positive\n",
    "            if correct:\n",
    "                true_num += 1\n",
    "            else:\n",
    "                false_num += 1\n",
    "            if is_positive:\n",
    "                if correct:\n",
    "                    pos_true += 1\n",
    "                else:\n",
    "                    pos_false += 1\n",
    "            elif correct:\n",
    "                neg_true += 1\n",
    "            else:\n",
    "                neg_false += 1\n",
    "            logger.info('correctness: ' + ('Correct' if correct else 'Incorrect'))\n",
    "\n",
    "        logger.info('triplet: ' + statement + '\\tgrounding truth: ' + str(labels[(h, r, t)]) + '\\tprediction: ' + results)\n",
    "\n",
    "    logger.info(str(i) + ': ' + str(_ratio(true_num, num)))\n",
    "    if skipped:\n",
    "        logger.warning(str(skipped) + ' statements skipped because every API call failed')\n",
    "    logger.info('pos_acc: ' + str(_ratio(pos_true, pos_true + pos_false)))\n",
    "    logger.info('neg_acc: ' + str(_ratio(neg_true, neg_true + neg_false)))\n",
    "    TP = pos_true\n",
    "    FN = pos_false\n",
    "    FP = neg_false\n",
    "    TN = neg_true\n",
    "    precision = _ratio(TP, TP + FP)\n",
    "    recall = _ratio(TP, TP + FN)\n",
    "    logger.info('precision: ' + str(precision))\n",
    "    logger.info('recall: ' + str(recall))\n",
    "    F1_score += compute_f_beta_score(precision, recall, logger)\n",
    "logger.info('F1_scores: ' + str(F1_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics of the last evaluated tree, recomputed from the counters left by the\n",
    "# evaluation loop; the F1 score returned is shown as the cell output.\n",
    "TP = pos_true\n",
    "FN = pos_false\n",
    "FP = neg_false\n",
    "TN = neg_true\n",
    "precision = TP / (TP + FP) if TP + FP else float('nan')\n",
    "recall = TP / (TP + FN) if TP + FN else float('nan')\n",
    "logger.info('precision: ' + str(precision))\n",
    "logger.info('recall: ' + str(recall))\n",
    "compute_f_beta_score(precision, recall, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
