{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "openai.api_key = \"\"\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split()\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        _, rel = line.strip().split()\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "with open(\"../rel2sym_ri.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        rel, sym = line.strip().split()\n",
    "        rel2sym[rel] = '$' + sym + '$'\n",
    "        # rel2sym[rel] = sym "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id ):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, ent = line.strip().split()\n",
    "            if ent not in ent2id:\n",
    "                id2ent[eid] = ent\n",
    "                ent2id[ent] = eid\n",
    "                eid += 1\n",
    "                \n",
    "        return eid\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female, male  = line.strip().split()\n",
    "            if female == '1':\n",
    "                ent2class[id2ent[cid]] = rel2sym['female']\n",
    "                # ent2class[id2ent[cid]] = 'female'\n",
    "\n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym['female'] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a ' + \"female\" + '. '\n",
    "\n",
    "                # class_text += ('female'+'(' + id2ent[cid] + ')')\n",
    "            else:\n",
    "                ent2class[id2ent[cid]] = rel2sym['male']\n",
    "                # ent2class[id2ent[cid]] = 'male'\n",
    "\n",
    "                class_text += 'F' + str(fid) + ': ' + rel2sym['male'] + '(' + id2ent[cid] + ')\\n'\n",
    "                fid += 1\n",
    "                # class_text += id2ent[cid] + ' is a '+ 'male' + '. '\n",
    "\n",
    "                # class_text += ('male'+'(' + id2ent[cid] + ')')\n",
    "\n",
    "            cid += 1\n",
    "        return cid, class_text, fid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "# def get_rule_text(rule_heads, rule2text):\n",
    "#     rule_text = ''\n",
    "#     for rule_head in rule_heads:\n",
    "#         rule_text += rule2text[rule_head]\n",
    "#     return rule_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rules = ''\n",
    "rel2rule = dict()\n",
    "lid = 1\n",
    "with open(\"latex_rules.txt\", 'r') as f1:\n",
    "    for line in f1:\n",
    "        rel, rule = line.strip().split('\\t')\n",
    "        rel2rule[rel] = rule\n",
    "        rules += 'L' + str(lid) + ': ' + rule + '\\n'\n",
    "        lid += 1\n",
    "\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s\n",
    "def get_negative_samples(triplets, id2ent, id2rel, labels):\n",
    "    neg_samples = []\n",
    "    # random sample head or tail\n",
    "    \n",
    "    \n",
    "    for i in range(len(triplets)):\n",
    "        \n",
    "        while 1:\n",
    "            if random.random() < 0.5:\n",
    "                # sample head\n",
    "                h = random.randint(0, len(id2ent) - 1)\n",
    "                t = triplets[i][2]\n",
    "                r = triplets[i][1]\n",
    "                if (id2ent[int(h)], r, t) not in triplets:\n",
    "                    neg_samples.append((id2ent[int(h)], r, t))\n",
    "                    labels[(id2ent[int(h)], r, t)] = 0\n",
    "                    break\n",
    "            else:\n",
    "                h = triplets[i][0]\n",
    "                t = random.randint(0, len(id2ent) - 1)\n",
    "                r = triplets[i][1]\n",
    "                if (h, r, id2ent[int(t)]) not in triplets:\n",
    "                    neg_samples.append((h, r, id2ent[int(t)]))\n",
    "                    labels[(h, r, id2ent[int(t)])]=0\n",
    "                    break\n",
    "\n",
    "    return neg_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute F beta[0,0.1:1] score\n",
    "def compute_f_beta_score(precision, recall, logger):\n",
    "    for beta in range(0, 11):\n",
    "        beta = beta / 10\n",
    "        if precision == 0 and recall == 0:\n",
    "            logger.info('beta: '+ str(beta) + '\\rF score: 0')\n",
    "            \n",
    "        else:\n",
    "            score = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)\n",
    "            logger.info('beta: '+ str(beta) + '\\tF score: ' + str(score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "dir = 'logs/first_order_standard'\n",
    "if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "model = \"gpt-3.5-turbo\"\n",
    "# model = \"gpt-4\"\n",
    "logging.info('model: ' + model)\n",
    "\n",
    "record_flag = False\n",
    "\n",
    "for i in range(0, 1):\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "\n",
    "    eid = 0\n",
    "\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    class_text = ''\n",
    "    triplets = []\n",
    "    test_triplets = []\n",
    "    labels = dict()\n",
    "    entpair2rel = dict() \n",
    "    basic_facts = ''\n",
    "    statement = ''\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    eid = read_entity(path_ent,eid, id2ent,ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class,id2ent, class_text, fid)\n",
    "    # print(i)\n",
    "    path = \"../symbolic_tree/\"+str(i)+\".relations.data\"\n",
    "    basic_facts += class_text\n",
    "    with open(path,'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]\n",
    "            basic_facts += 'F' + str(fid) + ': ' + rel2sym[id2rel[int(r)]] + '(' + id2ent[int(h)] + ', ' + id2ent[int(t)] + ')\\n'\n",
    "            fid += 1\n",
    "            # text += id2ent[int(h)] + ' is the ' + id2rel[int(r)] + ' of ' + id2ent[int(t)] + '. '\n",
    "\n",
    "            # text += (id2rel[int(r)] + '(' + id2ent[int(h)], id2ent[int(t)] +')')\n",
    "    path = \"../symbolic_tree/\"+str(i)+\".relations.data.inf\"\n",
    "\n",
    "    with open(path,'r') as f:\n",
    "        for line in tqdm(f):\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "                labels[(id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])] = 1\n",
    "    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)\n",
    "            # if flag == '-':\n",
    "            #     test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))\n",
    "            #     labels.append(0)\n",
    "\n",
    "    num = 0\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "\n",
    "    pos_true = 0\n",
    "    pos_false = 0\n",
    "\n",
    "    neg_true = 0\n",
    "    neg_false = 0\n",
    "\n",
    "    predicted_facts = test_triplets + negative_samples\n",
    "    # random order in predicted_facts\n",
    "    random.shuffle(predicted_facts)\n",
    "    for triple in tqdm(predicted_facts):\n",
    "        h, r, t = triple\n",
    "        statement = rel2sym[r] + '(' + h + ', ' + t + ')'\n",
    "        \n",
    "        message = {\n",
    "                    'system': \"You are a helpful assistant with deductive reasoning abilities. \",\n",
    "                    'user': \"I will provide a set of logical rules and facts. Please select one single logical rule from L1 to L\" + str(lid - 1) + \" and a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the unknown fact using deductive reasoning.\\nLogical rules:\\n\" + rules + \"\\nFacts:\\n\" + basic_facts + \"\\nUnknown fact: \" + statement + \"\\nThe answer (True or False) is: \"\n",
    "                }\n",
    "        # message = {\n",
    "        #             'system': \"Please select one single logical rule and a few facts to predict True/False of the unknown fact. \",\n",
    "        #             'user': \"I will provide a set of logical rules L1 to L\" + str(lid - 1) + \" and facts F1 to F\" + str(fid - 1) + \". Please select one single logical rule from L1 to L\" + str(lid - 1) + \" and a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the unknown fact.\\nLogical rules:\\n\" + rules + \"\\nFacts:\\n\" + basic_facts + \"\\nUnknown fact: \" + statement + \"\\nThe answer (True or False) is: \"\n",
    "        #         }\n",
    "\n",
    "        server_error_cnt = 0\n",
    "        while server_error_cnt<10:\n",
    "            try:\n",
    "                update_key()\n",
    "                response = openai.ChatCompletion.create(\n",
    "                model= model,\n",
    "                messages=[\n",
    "                        {\"role\": \"system\", \"content\": message['system']},\n",
    "                        {\"role\": \"user\", \"content\": message['user']},\n",
    "                ],\n",
    "                temperature=0,\n",
    "                )\n",
    "\n",
    "                if record_flag == False:\n",
    "                    logger.info('message: \\n' + dict2str(message))\n",
    "                    record_flag = True\n",
    "\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                num += 1\n",
    "                \n",
    "                ans = results.split('.')[0]\n",
    "                if labels[(h, r, t)] == 1:\n",
    "                    if 'True' in ans:\n",
    "                        true_num += 1\n",
    "                        pos_true += 1\n",
    "                        logger.info('correctness: ' + 'Correct')\n",
    "                    elif 'False' in ans:\n",
    "                        false_num += 1\n",
    "                        pos_false += 1\n",
    "                        logger.info('correctness: ' + 'Incorrect')\n",
    "                    elif 'Unknown' in ans:\n",
    "                        false_num += 1\n",
    "                        pos_false += 1\n",
    "                        logger.info('correctness: ' + 'Incorrect')\n",
    "                        print(results)\n",
    "                else:\n",
    "                    if 'True' in ans :\n",
    "                        false_num += 1\n",
    "                        neg_false += 1\n",
    "                        logger.info('correctness: ' + 'Incorrect')\n",
    "                    elif 'False' in ans:\n",
    "                        true_num += 1\n",
    "                        neg_true += 1\n",
    "                        logger.info('correctness: ' + 'Correct')\n",
    "                    elif 'Unknown' in ans:\n",
    "                        true_num += 1\n",
    "                        neg_true += 1\n",
    "                        logger.info('correctness: ' + 'Correct')\n",
    "                        print(results)\n",
    "                \n",
    "                \n",
    "                logger.info('triplet: ' + statement + '\\tgrounding truth: ' + str(labels[(h, r, t)]) + '\\tprediction: ' + results )\n",
    "\n",
    "                break\n",
    "\n",
    "            except Exception as e:\n",
    "                server_error_cnt += 1\n",
    "                logger.info(e)\n",
    "    logger.info(str(i) + ': ' + str(true_num / num))\n",
    "    logger.info('pos_acc: ' + str(pos_true / (pos_true + pos_false)))\n",
    "    logger.info('neg_acc: ' + str(neg_true / (neg_true + neg_false)))\n",
    "    TP = pos_true\n",
    "    FN = pos_false\n",
    "    FP = neg_false\n",
    "    TN = neg_true\n",
    "    logger.info('precision: ' + str(TP / (TP + FP)))\n",
    "    logger.info('recall: ' + str(TP / (TP + FN)))\n",
    "    compute_f_beta_score(TP / (TP + FP), TP / (TP + FN), logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(str(i) + ': ' + str(true_num / num))\n",
    "logger.info('pos_acc: ' + str(pos_true / (pos_true + pos_false)))\n",
    "logger.info('neg_acc: ' + str(neg_true / (neg_true + neg_false)))\n",
    "TP = pos_true\n",
    "FN = pos_false\n",
    "FP = neg_false\n",
    "TN = neg_true\n",
    "logger.info('precision: ' + str(TP / (TP + FP)))\n",
    "logger.info('recall: ' + str(TP / (TP + FN)))\n",
    "compute_f_beta_score(TP / (TP + FP), TP / (TP + FN), logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_num / num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
