{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "from random import sample\n",
    "import networkx as nx\n",
    "import re, logging\n",
    "import openai, datetime, os\n",
    "openai.api_key = \"\"\n",
    "import json\n",
    "def load_openai_keys():\n",
    "    keys = []\n",
    "    with open('../openai_keys_filter.txt', \"r\") as f:\n",
    "        for line in f:\n",
    "            key = line.strip().split()\n",
    "            keys.append(key[-1])\n",
    "    return keys\n",
    "openai_api_keys = load_openai_keys()\n",
    "random.shuffle(openai_api_keys)\n",
    "def update_key():\n",
    "    curr_key = openai_api_keys[0]\n",
    "    openai.api_key = curr_key\n",
    "    openai_api_keys.remove(curr_key)\n",
    "    openai_api_keys.append(curr_key)\n",
    "\n",
    "def get_logger(filename, verbosity=1, name=None):\n",
    "    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}\n",
    "    formatter = logging.Formatter(\n",
    "        \"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s\"\n",
    "    )\n",
    "    logger = logging.getLogger(name)\n",
    "    logger.setLevel(level_dict[verbosity])\n",
    "\n",
    "    # Remove any existing handlers\n",
    "    for handler in logger.handlers:\n",
    "        logger.removeHandler(handler)\n",
    "    # Output to file\n",
    "    fh = logging.FileHandler(filename, \"w\")\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "\n",
    "    # # Output to terminal\n",
    "    # sh = logging.StreamHandler()\n",
    "    # sh.setFormatter(formatter)\n",
    "    # logger.addHandler(sh)\n",
    "\n",
    "    return logger\n",
    "\n",
    "# multiple-list save to logging\n",
    "def list2str(l):\n",
    "    s = ''\n",
    "    for i in l:\n",
    "        s += str(i) + '\\t'\n",
    "    return s\n",
    "def list_equal(a, answers):\n",
    "    # if set(a) == set(b):\n",
    "    #     return True\n",
    "    # else:\n",
    "    #     return False\n",
    "    for b in answers:\n",
    "        # if set(a) contain set(b)\n",
    "        if set(b) == set(a):\n",
    "            return True\n",
    "        # if set(a) == set(b):\n",
    "        #     return True\n",
    "    return False\n",
    "def dict2str(d):\n",
    "    s = ''\n",
    "    for k in d:\n",
    "        s += str(k) + '\\t' + str(d[k]) + '\\r'\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read jsonl file\n",
    "def read_jsonl(file):\n",
    "    data = []\n",
    "    with open(file, 'r') as f:\n",
    "        for line in f:\n",
    "            data.append(json.loads(line))\n",
    "    return data\n",
    "raw_data = read_jsonl('../OWA/depth-1/meta-test_symbolic.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filter_data = []\n",
    "for data in raw_data:\n",
    "    fid = 1\n",
    "    lid = 1\n",
    "    facts = ''\n",
    "    rules = ''\n",
    "    \n",
    "    for tri_name in data['triples'].keys():\n",
    "        tri = data['triples'][tri_name]\n",
    "        # facts += 'F' + str(fid) + ': ' + tri['sym_text'][:-2] + '.\\n'\n",
    "        facts += tri['text'] + '\\n'\n",
    "\n",
    "        fid += 1\n",
    "    for rule_name in data['rules'].keys():\n",
    "        rule = data['rules'][rule_name]\n",
    "        # rules += 'L' + str(lid) + ': ' + rule['sym_text'][:-2] + '.\\n'\n",
    "        rules += rule['text'] + '\\n'\n",
    "\n",
    "        lid += 1\n",
    "    for q_name in data['questions'].keys():\n",
    "        question = data['questions'][q_name]\n",
    "        if question['QDep'] >= 1:\n",
    "            if question['answer'] == True or question['answer'] == False:\n",
    "                item = dict()\n",
    "                item['question'] = question['question'] \n",
    "                item['answer'] = question['answer']\n",
    "                item['facts'] = facts\n",
    "                item['rules'] = rules\n",
    "                filter_data.append(item)\n",
    "            \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(filter_data[2]['facts'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read few_shot_prompts.txt to get the prompts\n",
    "\n",
    "def read_few_shot_prompts(path):\n",
    "    prompts = list()\n",
    "    prompts_plus = list()\n",
    "    with open(path, 'r') as f:\n",
    "        blocks = f.read().split('\\n\\n')\n",
    "        for block in blocks:\n",
    "            # 正则表达式提取statement以后，answer之前的内容\n",
    "            question, answer = block.split('Reasoning: ')\n",
    "\n",
    "            \n",
    "            d = {}\n",
    "            d['Statement'] = question.strip()\n",
    "            # answer = re.findall(r'Answer: (.*)', block, flags=re.DOTALL)[0]\n",
    "            d['Answer'] = \"Reasoning: \" + answer\n",
    "\n",
    "            d_plus = {}\n",
    "            d_plus['Statement'] = question.strip()\n",
    "            d_plus['Answer'] = \"Reasoning: Let's think step by step. \" + answer\n",
    "            \n",
    "            prompts.append(d)\n",
    "            prompts_plus.append(d_plus)\n",
    "            \n",
    "    print(prompts)\n",
    "    return prompts, prompts_plus\n",
    "prompts, prompts_plus = read_few_shot_prompts('few_shot_prompts.txt')\n",
    "len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')\n",
    "dir = 'logs/natural_few_shot_cot'\n",
    "if not os.path.exists(dir):\n",
    "        os.makedirs(dir)\n",
    "logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)\n",
    "# model = 'gpt-4'\n",
    "model = \"gpt-3.5-turbo\"\n",
    "\n",
    "logging.info('model: ' + model)\n",
    "\n",
    "\n",
    "\n",
    "num = 0\n",
    "false_num = 0\n",
    "true_num = 0\n",
    "\n",
    "pos_true = 0\n",
    "pos_false = 0\n",
    "neg_true = 0\n",
    "neg_false = 0\n",
    "\n",
    "for item in tqdm(filter_data):\n",
    "    message_1 = {\n",
    "\n",
    "        'system': \"You are a helpful assistant with deductive reasoning abilities. Given a set of rules and facts, you have to reason whether a statement is true or false. \",\n",
    "        'user': \"Given a set of rules and facts, you have to reason whether a statement is true or false. \",\n",
    "        'Q1' : prompts[0]['Statement'],\n",
    "        'A1' : prompts[0]['Answer'],\n",
    "        'Q2' : prompts[1]['Statement'],\n",
    "        'A2' : prompts[1]['Answer'],\n",
    "        'Q3' : prompts[2]['Statement'],\n",
    "        'A3' : prompts[2]['Answer'],\n",
    "        'Q4' : prompts[3]['Statement'],\n",
    "        'A4' : prompts[3]['Answer'],\n",
    "        'Q5' : prompts[4]['Statement'],\n",
    "        'A5' : prompts[4]['Answer'],\n",
    "        'Q6' : \"Here are some facts and rules: \\n\" + item['facts'] + item['rules'] + \"Does it imply that the statement \\\"\" + item['question'] + \"\\\" is True?\\nLet's think step by step. \"\n",
    "\n",
    "    }\n",
    "\n",
    "    server_error_cnt = 0\n",
    "\n",
    "    while server_error_cnt<10:\n",
    "        try:\n",
    "            \n",
    "            update_key()\n",
    "            \n",
    "            response = openai.ChatCompletion.create(\n",
    "                model=model,\n",
    "                messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_1['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['user']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q1']},\n",
    "                        {\"role\": \"assistant\", \"content\": message_1['A1']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q2']},\n",
    "                        {\"role\": \"assistant\", \"content\": message_1['A2']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q3']},                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               \n",
    "                        {\"role\": \"assistant\", \"content\": message_1['A3']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q4']},\n",
    "                        {\"role\": \"assistant\", \"content\": message_1['A4']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q5']},\n",
    "                        {\"role\": \"assistant\", \"content\": message_1['A5']},\n",
    "                        {\"role\": \"user\", \"content\": message_1['Q6']},\n",
    "                    ],\n",
    "                temperature=0,\n",
    "\n",
    "            )\n",
    "            results = response['choices'][0]['message']['content']\n",
    "            last_text = message_1['user'] + message_1['Q1'] + message_1['A1'] + message_1['Q2'] + message_1['A2'] + message_1['Q3'] + message_1['A3'] + message_1['Q4'] + message_1['A4'] + message_1['Q5'] + message_1['A5'] + message_1['Q6']\n",
    "            update_key()\n",
    "            message_2 = {\n",
    "                    'system': \"Please predict True/False of the following statement.\",\n",
    "                    'user': last_text + '\\n' + results + '\\nTherefore, the answer (True or False) is: '\n",
    "                }\n",
    "            response = openai.ChatCompletion.create(\n",
    "                model= model,\n",
    "                messages=[\n",
    "                        {\"role\": \"system\", \"content\": message_2['system']},\n",
    "                        {\"role\": \"user\", \"content\": message_2['user']},\n",
    "                ],\n",
    "                temperature=0,\n",
    "                \n",
    "            )\n",
    "            results = response['choices'][0]['message']['content']\n",
    "            break\n",
    "        except Exception as e:\n",
    "            server_error_cnt += 1\n",
    "            print(e)\n",
    "\n",
    "    \n",
    "\n",
    "    logger.info('message: \\n' + dict2str(message_1)) \n",
    "\n",
    "    num += 1\n",
    "    ans = results.split('.')[0]\n",
    "    if item['answer'] == True:\n",
    "        if 'True' in ans:\n",
    "            true_num += 1\n",
    "            pos_true += 1\n",
    "            logger.info('correctness: ' + 'Correct')\n",
    "        elif 'False' in ans:\n",
    "            false_num += 1\n",
    "            pos_false += 1\n",
    "            logger.info('correctness: ' + 'Incorrect')\n",
    "    else:\n",
    "        if 'True' in ans :\n",
    "            false_num += 1\n",
    "            neg_false += 1\n",
    "            logger.info('correctness: ' + 'Incorrect')\n",
    "        elif 'False' in ans:\n",
    "            true_num += 1\n",
    "            neg_true += 1\n",
    "            logger.info('correctness: ' + 'Correct')\n",
    "\n",
    "    logger.info('grounding truth: ' + str(item['answer']) + '\\tprediction: ' + results)\n",
    "    \n",
    "\n",
    "logger.info(\"accuracy: \" + str( true_num / num ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "family-tree-data-gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f632ed4e6e58e2f900bc6d8cc82f324645872b4d5da347d3e02478818028ba78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
