{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "# Read the key from the environment rather than hard-coding it: assigning an empty\n",
    "# literal here overrides any key the openai client may have picked up from OPENAI_API_KEY.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"\")\n",
    "\n",
    "def load_openai_keys(path='../openai_keys_filter.txt'):\n",
    "    \"\"\"Return the API keys listed in `path`, one per non-blank line.\n",
    "\n",
    "    Only the last whitespace-separated token of each line is kept, so a line may\n",
    "    carry leading tokens before the key. Blank lines are skipped (they used to\n",
    "    raise IndexError).\n",
    "    \"\"\"\n",
    "    keys = []\n",
    "    with open(path, \"r\") as f:\n",
    "        for line in f:\n",
    "            fields = line.split()\n",
    "            if fields:\n",
    "                keys.append(fields[-1])\n",
    "    return keys\n",
    "\n",
    "openai_api_keys = load_openai_keys()\n",
    "\n",
    "def update_key():\n",
    "    \"\"\"Round-robin key rotation.\n",
    "\n",
    "    Activates the key at the front of `openai_api_keys` and moves it to the back\n",
    "    of the pool. Raises IndexError when the pool is empty.\n",
    "    \"\"\"\n",
    "    next_key = openai_api_keys[0]\n",
    "    openai.api_key = next_key\n",
    "    openai_api_keys.append(openai_api_keys.pop(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation vocabulary: ids follow line order in the relations file.\n",
    "rid = 0\n",
    "id2rel = dict()\n",
    "rel2id = dict()\n",
    "rel2sym = dict()\n",
    "\n",
    "with open(\"../symbolic_tree/1.relations\", 'r') as f:\n",
    "    for line in f:\n",
    "        fields = line.split()\n",
    "        if not fields:  # tolerate blank / trailing lines\n",
    "            continue\n",
    "        _, rel = fields\n",
    "        id2rel[rid] = rel\n",
    "        rel2id[rel] = rid\n",
    "        rid += 1\n",
    "\n",
    "# Surface form of each relation, used in statements '<h> is <sym> of <t>'.\n",
    "with open(\"rel2sym.txt\",\"r\") as fr:\n",
    "    for line in fr:\n",
    "        fields = line.split()\n",
    "        if not fields:\n",
    "            continue\n",
    "        rel, sym = fields\n",
    "        # Drop only a trailing 'Of' (the statement template supplies ' of ');\n",
    "        # testing `'Of' in sym` also truncated symbols containing 'Of' elsewhere.\n",
    "        rel2sym[rel] = sym[:-2] if sym.endswith('Of') else sym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    \"\"\"Return logger `name` (the root logger when None) writing only to `filename`.\n",
    "\n",
    "    The file is truncated (mode 'w'). Handlers already attached to the logger are\n",
    "    detached and closed first, so re-running a notebook cell neither duplicates\n",
    "    log lines nor leaks open file handles.\n",
    "\n",
    "    verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING (any other value raises KeyError).\n",
    "    \"\"\"\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Iterate over a copy: removing from logger.handlers while iterating it\n",
    "    # skips every other handler.\n",
    "    for handler in list(logger.handlers):\n",
    "        logger.removeHandler(handler)\n",
    "        handler.close()\n",
    "\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # To also echo to the terminal, attach a logging.StreamHandler with the same formatter.\n",
    "\n",
    "    return logger\n",
    "\n",
    "\n",
    "def read_entity(path, eid, id2ent, ent2id):\n",
    "    \"\"\"Register every new entity named in `path` and return the next free id.\n",
    "\n",
    "    Each line must contain exactly two tokens, '<anything> <entity name>'. Names\n",
    "    not yet in `ent2id` receive consecutive ids starting at `eid`; both maps are\n",
    "    updated in place, and names already present are left untouched.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            _, name = line.strip().split()\n",
    "            if name in ent2id:\n",
    "                continue\n",
    "            ent2id[name] = eid\n",
    "            id2ent[eid] = name\n",
    "            eid += 1\n",
    "    return eid\n",
    "    \n",
    "def read_class(path, cid, ent2class, id2ent, class_text, fid):\n",
    "    \"\"\"Read per-entity gender flags and append one numbered fact per entity.\n",
    "\n",
    "    Line n of `path` holds two flags '<female> <male>' for entity id `cid + n`.\n",
    "    Only the first flag is consulted: '1' maps to rel2sym['female'], anything\n",
    "    else maps to rel2sym['male']. The label is stored in `ent2class`, and a fact\n",
    "    'F<fid>: <name> is <label>.' is appended to `class_text`.\n",
    "\n",
    "    Returns (next cid, extended class_text, next fid).\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            female_flag, _male_flag = line.strip().split()\n",
    "            gender = rel2sym['female'] if female_flag == '1' else rel2sym['male']\n",
    "            name = id2ent[cid]\n",
    "            ent2class[name] = gender\n",
    "            class_text += 'F' + str(fid) + ': ' + name + ' is ' + gender + '.\\n'\n",
    "            fid += 1\n",
    "            cid += 1\n",
    "    return cid, class_text, fid\n",
    "    \n",
    "# def get_related_triplets(h, t, G, entpair2rel):\n",
    "#     input_text = ''\n",
    "#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):\n",
    "#         for edge in path:\n",
    "#             # print(edge)\n",
    "#             if edge in entpair2rel:\n",
    "#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '\n",
    "#             else:\n",
    "#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '\n",
    "#     return input_text\n",
    "# def get_rule_text(rule_heads, rule2text):\n",
    "#     rule_text = ''\n",
    "#     for rule_head in rule_heads:\n",
    "#         rule_text += rule2text[rule_head]\n",
    "#     return rule_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Natural-language rules, numbered L1..L<lid-1> into one prompt-ready string.\n",
    "# After this cell `lid` is one past the last rule number.\n",
    "rules = ''\n",
    "rel2rule = dict()\n",
    "lid = 1\n",
    "with open(\"natural_rules.txt\", 'r') as f1:\n",
    "    for line in f1:\n",
    "        line = line.strip()\n",
    "        if not line:  # tolerate blank / trailing lines\n",
    "            continue\n",
    "        # Split on the first tab only so a rule's own text may contain tabs.\n",
    "        rel, rule = line.split('\\t', 1)\n",
    "        rel2rule[rel] = rule\n",
    "        rules += 'L' + str(lid) + ': ' + rule + '\\n'\n",
    "        lid += 1\n",
    "\n",
    "def dict2str(d):\n",
    "    \"\"\"Serialise `d` as 'key<TAB>value' records, each ended by a carriage return.\"\"\"\n",
    "    return ''.join(str(key) + '\\t' + str(value) + '\\r' for key, value in d.items())\n",
    "def get_negative_samples(triplets, id2ent, id2rel, labels):\n",
    "    \"\"\"Corrupt each positive triplet once to build a matched set of negatives.\n",
    "\n",
    "    For each (h, r, t) in `triplets`, the head or the tail (50/50) is replaced by\n",
    "    a uniformly drawn entity from `id2ent`. Drawing repeats until the corrupted\n",
    "    triplet is neither a positive nor a negative drawn earlier. Every accepted\n",
    "    negative is recorded in `labels` with value 0.\n",
    "\n",
    "    Args:\n",
    "        triplets: list of positive (head, relation, tail) tuples.\n",
    "        id2ent: entity-name lookup; assumes ids are 0..len(id2ent)-1.\n",
    "        id2rel: unused; kept for call compatibility.\n",
    "        labels: dict updated in place with negative -> 0.\n",
    "\n",
    "    Returns:\n",
    "        List with one negative per input triplet, in input order.\n",
    "\n",
    "    NOTE(review): loops forever if some triplet has no valid corruption left\n",
    "    (e.g. a graph with only one or two entities).\n",
    "    \"\"\"\n",
    "    positives = set(triplets)  # O(1) membership; `in triplets` was O(n) per draw\n",
    "    drawn = set()  # prevents the same negative being emitted (and scored) twice\n",
    "    neg_samples = []\n",
    "    last_id = len(id2ent) - 1\n",
    "    for h, r, t in triplets:\n",
    "        while True:\n",
    "            if random.random() < 0.5:\n",
    "                candidate = (id2ent[random.randint(0, last_id)], r, t)  # corrupt head\n",
    "            else:\n",
    "                candidate = (h, r, id2ent[random.randint(0, last_id)])  # corrupt tail\n",
    "            if candidate not in positives and candidate not in drawn:\n",
    "                break\n",
    "        drawn.add(candidate)\n",
    "        neg_samples.append(candidate)\n",
    "        labels[candidate] = 0\n",
    "    return neg_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_f_beta_score(precision, recall, logger):\n",
    "    \"\"\"Log the F-beta score for beta = 0.0, 0.1, ..., 1.0 (11 lines).\n",
    "\n",
    "    F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). When the denominator is 0\n",
    "    (P = R = 0, or beta = 0 with R = 0), the score is logged as 0 instead of\n",
    "    raising ZeroDivisionError. Each line is 'beta: <beta>', a tab, then\n",
    "    'F score: <score>'.\n",
    "    \"\"\"\n",
    "    for step in range(0, 11):\n",
    "        beta = step / 10\n",
    "        denominator = beta * beta * precision + recall\n",
    "        if denominator == 0:\n",
    "            # Same tab separator as the regular branch (this branch used '\\r').\n",
    "            logger.info('beta: ' + str(beta) + '\\tF score: 0')\n",
    "        else:\n",
    "            score = (1 + beta * beta) * precision * recall / denominator\n",
    "            logger.info('beta: ' + str(beta) + '\\tF score: ' + str(score))\n",
    "def output_relation_acc(relation_true_num, relation_num, logger):\n",
    "    \"\"\"Log one line per relation: its name, a tab, then correct / total.\n",
    "\n",
    "    Relations are taken from `relation_true_num`; each must also be a key of\n",
    "    `relation_num` with a non-zero count.\n",
    "    \"\"\"\n",
    "    for rel, n_correct in relation_true_num.items():\n",
    "        logger.info(rel + '\\t' + str(n_correct / relation_num[rel]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import tiktoken\n",
    "# enc = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
    "enc = tiktoken.encoding_for_model(\"gpt-4\")\n",
    "\n",
    "\n",
    "def _chat_with_retry(model, system, user, max_retries=10):\n",
    "    \"\"\"Send one system + user chat turn at temperature 0 and return the reply text.\n",
    "\n",
    "    Any exception is printed and the request retried, up to `max_retries`\n",
    "    attempts in total. Raises RuntimeError when every attempt fails, so a failed\n",
    "    request can never be mistaken for a previous (or missing) response.\n",
    "    \"\"\"\n",
    "    last_error = None\n",
    "    for _ in range(max_retries):\n",
    "        try:\n",
    "            response = openai.ChatCompletion.create(\n",
    "                model=model,\n",
    "                messages=[\n",
    "                    {\"role\": \"system\", \"content\": system},\n",
    "                    {\"role\": \"user\", \"content\": user},\n",
    "                ],\n",
    "                temperature=0,\n",
    "            )\n",
    "            return response['choices'][0]['message']['content']\n",
    "        except Exception as e:\n",
    "            last_error = e\n",
    "            print(e)\n",
    "    raise RuntimeError('chat request failed after ' + str(max_retries) + ' attempts') from last_error\n",
    "\n",
    "\n",
    "def generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts, labels):\n",
    "    \"\"\"Mine six chain-of-thought demonstrations (3 true, 3 false) from the model.\n",
    "\n",
    "    Repeatedly draws a random fact from `predicted_facts`, asks `model` to reason\n",
    "    about it over `basic_facts`, then asks for a True/False verdict. The reasoning\n",
    "    is kept as a demonstration only if:\n",
    "      * no earlier demonstration used the same relation,\n",
    "      * it fits an equal share of the remaining 1200-token budget, and\n",
    "      * the verdict agrees with `labels` and that label still needs examples.\n",
    "\n",
    "    Args:\n",
    "        predicted_facts: list of (head, relation, tail) candidates.\n",
    "        model: OpenAI chat model name.\n",
    "        lid, rules: rule numbering and rule text; unused by the facts-only\n",
    "            prompt below, kept for call compatibility.\n",
    "        fid: next fact number, so the prompt lists facts F1..F{fid-1}.\n",
    "        basic_facts: the numbered fact listing shown to the model.\n",
    "        labels: dict (head, relation, tail) -> 1 (true fact) or 0 (negative).\n",
    "\n",
    "    Returns:\n",
    "        (prompts, prompts_plus): parallel lists of {'Statement', 'Answer'} dicts;\n",
    "        prompts_plus answers are prefixed with \"Let's think step by step.\".\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if a chat request fails 10 times in a row.\n",
    "\n",
    "    NOTE(review): never terminates if fewer than six distinct relations (or fewer\n",
    "    than three correctly answered facts per label) can be found.\n",
    "    \"\"\"\n",
    "    prompts = list()\n",
    "    prompts_plus = list()\n",
    "    relation_list = list()\n",
    "    true_number = 0\n",
    "    false_number = 0\n",
    "    extra_tokens = 1200  # token budget shared by all six rationales\n",
    "    false_words = ['False', 'false', 'Unknown', 'unknown']\n",
    "\n",
    "    while true_number + false_number < 6:\n",
    "        h, r, t = random.sample(predicted_facts, 1)[0]\n",
    "        if r in relation_list:\n",
    "            continue  # at most one demonstration per relation\n",
    "        number = true_number + false_number\n",
    "        statement = h + ' is ' + rel2sym[r] + ' of ' + t + '.'\n",
    "        user_1 = \"I will provide a set of facts F1 to F\" + str(fid - 1) + \". Please select a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the following statement.\\nFacts:\\n\" + basic_facts + \"\\nStatement: \" + statement + \"\\nLet's think step by step.\"\n",
    "        results = _chat_with_retry(\n",
    "            model,\n",
    "            \"Please select a few facts to predict True/False of the following statement.\",\n",
    "            user_1,\n",
    "        )\n",
    "        # Collapse runs of newlines into a single newline.\n",
    "        results = re.sub(r'\\n+', '\\n', results)\n",
    "        # Keep the rationale only if it fits an equal share of the remaining budget.\n",
    "        if len(enc.encode(results)) > extra_tokens / (6 - number):\n",
    "            continue\n",
    "\n",
    "        verdict = _chat_with_retry(\n",
    "            model,\n",
    "            \"Please predict True/False of the following statement.\",\n",
    "            user_1 + '\\n' + results + '\\nTherefore, the answer (True or False) is: ',\n",
    "        )\n",
    "        # The verdict is read from the first sentence of the reply.\n",
    "        last_sentence = verdict.split('.')[0]\n",
    "        says_false = any(word in last_sentence for word in false_words)\n",
    "        says_true = 'True' in last_sentence or 'true' in last_sentence\n",
    "\n",
    "        if labels[(h, r, t)] == 1:\n",
    "            if says_false or not says_true:\n",
    "                print('correctness: ' + 'Incorrect')\n",
    "                if not says_false:\n",
    "                    print(last_sentence)\n",
    "                continue\n",
    "            if true_number == 3:\n",
    "                continue  # enough true demonstrations already\n",
    "            print('correctness: ' + 'Correct')\n",
    "            true_number += 1\n",
    "            print('true label')\n",
    "        else:\n",
    "            if not says_false:\n",
    "                print('correctness: ' + ('Incorrect' if says_true else 'Unknown'))\n",
    "                continue\n",
    "            if false_number == 3:\n",
    "                continue  # enough false demonstrations already\n",
    "            false_number += 1\n",
    "            print('correctness: ' + 'Correct')\n",
    "            print('false label')\n",
    "\n",
    "        prompts.append({'Statement': \"Statement: \" + statement, 'Answer': \"Answer: \" + results})\n",
    "        prompts_plus.append({'Statement': \"Statement: \" + statement,\n",
    "                             'Answer': \"Answer: Let's think step by step. \" + results})\n",
    "        extra_tokens -= len(enc.encode(results))\n",
    "        relation_list.append(r)\n",
    "\n",
    "    return prompts, prompts_plus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "# `log_dir` rather than `dir`, so the builtin dir() is not shadowed for the session.\n",
    "log_dir = 'logs/natural_few_shot_cot_auto_facts'\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)\n",
    "\n",
    "# model = \"gpt-3.5-turbo\"\n",
    "model = \"gpt-4\"\n",
    "logger.info('model: ' + model)\n",
    "\n",
    "# Set once the first full few-shot prompt has been written to the log.\n",
    "record_flag = False\n",
    "\n",
    "def _ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0.0 when the denominator is 0.\"\"\"\n",
    "    return numerator / denominator if denominator else 0.0\n",
    "\n",
    "\n",
    "for i in range(0, 1):\n",
    "    id2ent = dict()\n",
    "    ent2id = dict()\n",
    "    ent2class = dict()\n",
    "    ent2triplets = dict()\n",
    "\n",
    "    eid = 0\n",
    "    cid = 0\n",
    "    fid = 1\n",
    "    class_text = ''\n",
    "    triplets = []\n",
    "    test_triplets = []\n",
    "    labels = dict()\n",
    "    entpair2rel = dict()\n",
    "    basic_facts = ''\n",
    "    statement = ''\n",
    "\n",
    "    # Entities of tree i and their gender facts (the first F<k> lines of the prompt).\n",
    "    path_ent = \"../symbolic_tree/\" + str(i) + \".individuals\"\n",
    "    path_class = \"../symbolic_tree/\" + str(i) + \".classes.data\"\n",
    "    eid = read_entity(path_ent, eid, id2ent, ent2id)\n",
    "    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)\n",
    "    basic_facts += class_text\n",
    "\n",
    "    # Observed relation facts, appended to the prompt's numbered fact list.\n",
    "    path = \"../symbolic_tree/\" + str(i) + \".relations.data\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in f:\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            h_name, r_name, t_name = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]\n",
    "            triplets.append((h_name, r_name, t_name))\n",
    "            entpair2rel[(h_name, t_name)] = rel2sym[r_name]\n",
    "            basic_facts += 'F' + str(fid) + ': ' + h_name + ' is ' + rel2sym[r_name] + ' of ' + t_name + '.\\n'\n",
    "            fid += 1\n",
    "\n",
    "    # Lines flagged '+' in the .inf file are the positive statements to predict.\n",
    "    path = \"../symbolic_tree/\" + str(i) + \".relations.data.inf\"\n",
    "    with open(path, 'r') as f:\n",
    "        for line in tqdm(f):\n",
    "            flag, h, r, t = line.strip().split()\n",
    "            if flag == '+':\n",
    "                fact = (id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])\n",
    "                test_triplets.append(fact)\n",
    "                labels[fact] = 1\n",
    "    # One corrupted (label 0) triplet per positive.\n",
    "    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)\n",
    "\n",
    "    num = 0\n",
    "    true_num = 0\n",
    "    false_num = 0\n",
    "\n",
    "    pos_true = 0\n",
    "    pos_false = 0\n",
    "\n",
    "    neg_true = 0\n",
    "    neg_false = 0\n",
    "\n",
    "    predicted_facts = test_triplets + negative_samples\n",
    "    # random order in predicted_facts\n",
    "    random.shuffle(predicted_facts)\n",
    "    prompts, prompts_plus = generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts, labels)\n",
    "\n",
    "    for triple in tqdm(predicted_facts):\n",
    "        h, r, t = triple\n",
    "        statement = h + ' is ' + rel2sym[r] + ' of ' + t + '. '\n",
    "\n",
    "        # Prompt = fact listing + six mined demonstrations (Q1..A6) + the query (Q7).\n",
    "        message_1 = {\n",
    "            'system': \"Please select a few facts to predict True/False of the following statement.\",\n",
    "            'user': \"I will provide a set of facts F1 to F\" + str(fid - 1) + \". Please select a few facts from F1 to F\" + str(fid - 1) + \" to predict True/False of the following statement.\\nFacts:\\n\" + basic_facts + \"\\nStatement: \" + statement + \"\\nLet's think step by step.\",\n",
    "        }\n",
    "        for k in range(1, 7):\n",
    "            message_1['Q' + str(k)] = prompts[k - 1]['Statement']\n",
    "            message_1['A' + str(k)] = prompts[k - 1]['Answer']\n",
    "        message_1['Q7'] = \"Statement: \" + statement + '\\nAnswer: '\n",
    "\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": message_1['system']},\n",
    "            {\"role\": \"user\", \"content\": message_1['user']},\n",
    "        ]\n",
    "        for k in range(1, 7):\n",
    "            messages.append({\"role\": \"user\", \"content\": message_1['Q' + str(k)]})\n",
    "            messages.append({\"role\": \"assistant\", \"content\": message_1['A' + str(k)]})\n",
    "        messages.append({\"role\": \"user\", \"content\": message_1['Q7']})\n",
    "\n",
    "        # The verdict prompt replays every non-system text of message_1 back to back.\n",
    "        # NOTE(review): no separators between parts; kept so prompts match earlier runs.\n",
    "        last_text = ''.join(text for key, text in message_1.items() if key != 'system')\n",
    "\n",
    "        for attempt in range(10):\n",
    "            try:\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=messages,\n",
    "                    temperature=0,\n",
    "                )\n",
    "                rationale = response['choices'][0]['message']['content']\n",
    "                message_2 = {\n",
    "                    'system': \"Please predict True/False of the following statement.\",\n",
    "                    'user': last_text + '\\n' + rationale + '\\nTherefore, the answer (True or False) is: '\n",
    "                }\n",
    "                response = openai.ChatCompletion.create(\n",
    "                    model=model,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_2['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_2['user']},\n",
    "                    ],\n",
    "                    temperature=0,\n",
    "                )\n",
    "                results = response['choices'][0]['message']['content']\n",
    "                break\n",
    "            except Exception as e:\n",
    "                logger.info(e)\n",
    "        else:\n",
    "            # Every attempt failed: record the gap instead of dropping the fact silently.\n",
    "            logger.warning('skipped after 10 failed requests: ' + statement)\n",
    "            continue\n",
    "\n",
    "        num += 1\n",
    "        if not record_flag:\n",
    "            logger.info('message: \\n' + dict2str(message_1))\n",
    "            record_flag = True\n",
    "\n",
    "        # The verdict is read from the first sentence of the reply.\n",
    "        last_sentence = results.split('.')[0]\n",
    "        false_words = ['indeterminate', 'Indeterminate', 'FALSE', 'Unknown', 'unknown', 'not', 'False', ' no ', \"inconclusive\", \"undefined\", \"invalid\", 'false']\n",
    "        says_false = any(word in last_sentence for word in false_words)\n",
    "        says_true = 'True' in last_sentence or 'true' in last_sentence or 'TRUE' in last_sentence\n",
    "\n",
    "        if labels[(h, r, t)] == 1:\n",
    "            # Positive fact: only an explicit True with no negative word counts.\n",
    "            correct = says_true and not says_false\n",
    "            if correct:\n",
    "                pos_true += 1\n",
    "            else:\n",
    "                pos_false += 1\n",
    "        else:\n",
    "            # Negative sample: anything other than an explicit True counts.\n",
    "            correct = says_false or not says_true\n",
    "            if correct:\n",
    "                neg_true += 1\n",
    "            else:\n",
    "                neg_false += 1\n",
    "        if correct:\n",
    "            true_num += 1\n",
    "        else:\n",
    "            false_num += 1\n",
    "        logger.info('correctness: ' + ('Correct' if correct else 'Incorrect'))\n",
    "        if not (says_false or says_true):\n",
    "            print(last_sentence)\n",
    "\n",
    "        logger.info('statement: ' + statement + '\\tgrounding truth: ' + str(labels[(h, r, t)]) + '\\tprediction: ' + results)\n",
    "\n",
    "    # Metrics; an empty denominator is reported as 0.0 instead of raising.\n",
    "    logger.info(str(i) + ': ' + str(_ratio(true_num, num)))\n",
    "    logger.info('pos_acc: ' + str(_ratio(pos_true, pos_true + pos_false)))\n",
    "    logger.info('neg_acc: ' + str(_ratio(neg_true, neg_true + neg_false)))\n",
    "    logger.info('pos_true:' + str(pos_true))\n",
    "    logger.info('neg_true:' + str(neg_true))\n",
    "    TP = pos_true\n",
    "    FN = pos_false\n",
    "    FP = neg_false\n",
    "    TN = neg_true\n",
    "    precision = _ratio(TP, TP + FP)\n",
    "    recall = _ratio(TP, TP + FN)\n",
    "    logger.info('precision: ' + str(precision))\n",
    "    logger.info('recall: ' + str(recall))\n",
    "    compute_f_beta_score(precision, recall, logger)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
