{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bce0c83e",
   "metadata": {},
   "source": [
    "## Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8218cd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import datasets\n",
    "# Load the \"commonsense_qa\" dataset through the `datasets` library.\n",
    "ds = datasets.load_dataset(\"commonsense_qa\")\n",
    "# Alternative dataset, currently disabled:\n",
    "#ds = datasets.load_dataset(\"openbookqa\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7fcab59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the loaded dataset object together with its top-level keys.\n",
    "ds, ds.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08bea599",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes of the train and validation splits, followed by the training example at index 1.\n",
    "len(ds['train']), len(ds['validation']), ds['train'][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a97592",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `answerKey` field of the training example at index 1.\n",
    "ds['train'][1]['answerKey']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "26290084",
   "metadata": {},
   "source": [
    "## Preprocessing Q+choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b70a5aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_multichoices(label, text):\n",
    "    \"\"\"Render answer options as one string such as 'A.foo B.bar'.\n",
    "\n",
    "    Each entry of `label` is joined with the entry of `text` at the same\n",
    "    index as 'label.text'. The pairs are separated by a single space, and\n",
    "    spaces at either end of the final string are stripped.\n",
    "    \"\"\"\n",
    "    pairs = [tag + '.' + text[pos] for pos, tag in enumerate(label)]\n",
    "    return ' '.join(pairs).strip(' ')\n",
    "\n",
    "\n",
    "# Re-shape every split of `ds` into a list of flat dicts in which the answer\n",
    "# options are rendered into a single string by get_multichoices.\n",
    "ds_new = {\n",
    "    split: [\n",
    "        {\n",
    "            'id': record['id'],\n",
    "            # 'question': record['question_stem'],  (alternative field, disabled)\n",
    "            'question': record['question'],\n",
    "            'choices': get_multichoices(\n",
    "                record['choices']['label'], record['choices']['text']\n",
    "            ),\n",
    "            'answerKey': record['answerKey'],\n",
    "        }\n",
    "        for record in ds[split]\n",
    "    ]\n",
    "    for split in ds\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be4a32c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a1bbebc6",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c67225d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import json\n",
    "import os\n",
    "\n",
    "\n",
    "# A proxy is currently required to reach the API.\n",
    "# Fill in the proxy URLs below. A value left empty is not applied, so any\n",
    "# HTTP_PROXY / HTTPS_PROXY already present in the environment is kept.\n",
    "HTTP_PROXY_URL = \"\"\n",
    "HTTPS_PROXY_URL = \"\"\n",
    "if HTTP_PROXY_URL:\n",
    "    os.environ[\"HTTP_PROXY\"] = HTTP_PROXY_URL\n",
    "if HTTPS_PROXY_URL:\n",
    "    os.environ[\"HTTPS_PROXY\"] = HTTPS_PROXY_URL\n",
    "\n",
    "# Take the key from the OPENAI_API_KEY environment variable so that it is\n",
    "# never stored in the notebook; falls back to the empty string when unset.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76eec927",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc5714bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Query the model on every example of one split and append each example,\n",
    "# together with the model's reply, to a JSON-lines file.\n",
    "key = 'validation'  # the only split that is queried\n",
    "# Number of leading examples to skip. Raise it to resume an interrupted run\n",
    "# without querying again the examples already present in the output file.\n",
    "skip_first = 0\n",
    "valid_answers = ['A','B','C','D','E']\n",
    "\n",
    "cnt = 0\n",
    "if key in ds_new:\n",
    "    with open('chatgpt_kg_csqa_morefewshots16_'+key+'.txt', 'a+') as g:\n",
    "        for ex in ds_new[key]:\n",
    "            cnt += 1\n",
    "            if cnt <= skip_first:\n",
    "                continue\n",
    "            print(cnt)\n",
    "            #if ex['id'] in test_ids:\n",
    "            rsp = openai.ChatCompletion.create(\n",
    "                model=\"gpt-4\",\n",
    "                messages=[\n",
    "                    {\"role\": \"system\", \"content\": \"Please choose answer from [A, B, C, D, E] and output one capital letter only.\"},\n",
    "                    # Few-shot demonstrations, currently disabled:\n",
    "#                                 {\"role\": \"user\", \"content\": \"Question: When is hotter?\" +'\\n'+\"A.desert B.night C.morning D.noon\"},\n",
    "#                                 #{\"role\": \"user\", \"content\": \"A.desert B.night C.morning D.noon\"},\n",
    "#                                 {\"role\": \"assistant\", \"content\": \" Answer: A. The desert is hotter than the other options listed because it receives direct sunlight and has little to no vegetation to provide shade or cool the air. During the day, the sun beats down on the desert, causing temperatures to rise significantly. At night, the lack of vegetation and moisture in the air means that heat is quickly lost, causing temperatures to drop rapidly. In the morning and at noon, the sun is still high in the sky, but temperatures may be slightly cooler due to the presence of some shade and cooler air from the night before.\"},\n",
    "\n",
    "#                                 {\"role\": \"user\", \"content\": \"Question: Some farmers make a very good living because the crops they grow are sold to lots of people for\" +'\\n'+\"A.feeding to Mars B.sustenance C.gravity D.metal\"},\n",
    "#                                 #{\"role\": \"user\", \"content\": \"A.feeding to Mars B.sustenance C.gravity D.metal\"},\n",
    "#                                 {\"role\": \"assistant\", \"content\": \" Answer: B. Sustenance refers to food or other necessities of life. Farmers grow crops that provide sustenance to people, which means they are able to sell their crops to lots of people who need food to survive. This allows them to make a good living from their farming activities. Feeding to Mars, gravity, and metal are not related to the crops grown by farmers and are therefore not the correct answer.\"},\n",
    "\n",
    "#                                 {\"role\": \"user\", \"content\": \"Question: Someone in a room without any lights on, but that has mirrors in it, can still have light in that room if\"+'\\n'+\"A.lights are turned on in another house B.a joined room is dark C.the room is completely dark D.a joined room is lit\"},\n",
    "#                                 #{\"role\": \"user\", \"content\": \"A.lights are turned on in another house B.a joined room is dark C.the room is completely dark D.a joined room is lit\"},\n",
    "#                                 {\"role\": \"assistant\", \"content\": \" Answer: D. The mirrors in the room can reflect light from a joined room that is lit, allowing some light to enter the dark room. This is known as indirect lighting. So, if a joined room is lit, the mirrors in the dark room can reflect that light and provide some illumination.\"},\n",
    "\n",
    "                    {\"role\": \"user\", \"content\": 'Question: ' + ex['question'] +'\\n'+ex['choices']},\n",
    "\n",
    "                    # Alternative trailing messages, currently disabled:\n",
    "                    #{\"role\": \"user\", \"content\": Answer:']},\n",
    "                    #{\"role\": \"user\", \"content\": 'Explain and then answer.'},\n",
    "                    #{\"role\": \"user\", \"content\": 'Answer and then explain.'},\n",
    "                ],\n",
    "                temperature=0.0,\n",
    "                top_p=1,\n",
    "            )\n",
    "            llm_output = rsp['choices'][0]['message']['content']\n",
    "\n",
    "            # The reply is stored on the example itself, so ds_new is modified in place.\n",
    "            ex['llm_ans'] = llm_output\n",
    "            g.write(json.dumps(ex)+'\\n')\n",
    "            # Flush after every example so that finished lines can be inspected\n",
    "            # while the run is still going and are not lost if the kernel dies.\n",
    "            g.flush()\n",
    "            # Echo any reply that is not exactly one of the option letters.\n",
    "            if llm_output not in valid_answers:\n",
    "                print(llm_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d67a51",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d4a18119",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ccf970",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Question ids listed in the in-house split file, one id per line.\n",
    "inhouse = []\n",
    "with open('./Data/commen/output/candidate_generate/inhouse_split_qids.txt','r') as f:\n",
    "    for line in f:\n",
    "        inhouse.append(line.strip('\\n'))\n",
    "\n",
    "# Set copy used only for fast membership tests; `inhouse` itself stays a list.\n",
    "inhouse_lookup = set(inhouse)\n",
    "\n",
    "# Ids from train_rand_split.jsonl that do not appear in the in-house list.\n",
    "test_ids = []\n",
    "with open('./Data/commen/output/candidate_generate/train_rand_split.jsonl','r') as f:\n",
    "    for line in f:\n",
    "        record = json.loads(line)\n",
    "        if record[\"id\"] not in inhouse_lookup:\n",
    "            test_ids.append(record[\"id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98408ef9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a9984c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For use after prompting: score the replies saved in a JSON-lines result file.\n",
    "# The path defaults to the file that the Test section writes for the\n",
    "# validation split; change it to score a different run.\n",
    "result_path = 'chatgpt_kg_csqa_morefewshots16_validation.txt'\n",
    "# cnt: examples read, correct: matching answers, t: replies with no valid letter.\n",
    "cnt, correct, t = 0, 0, 0\n",
    "with open(result_path, 'r') as f:\n",
    "    for line in f:\n",
    "        cnt += 1\n",
    "        rec = json.loads(line)\n",
    "        # Only the first character of the reply is scored. Slicing with [:1]\n",
    "        # instead of indexing with [0] also copes with an empty reply.\n",
    "        first_char = rec['llm_ans'][:1]\n",
    "        if first_char == rec[\"answerKey\"]:\n",
    "            correct += 1\n",
    "        if first_char not in ['A','B','C','D','E']:\n",
    "            print(rec[\"question\"])\n",
    "            # Show the gold option as 'X.first_word'. Splitting on the label plus\n",
    "            # its dot avoids cutting at the same capital letter inside an\n",
    "            # earlier option's text.\n",
    "            gold_marker = rec[\"answerKey\"] + '.'\n",
    "            print('ans========', (gold_marker+rec[\"choices\"].split(gold_marker)[1].strip(' ')).split(' ')[0])\n",
    "            print('llm_ans====', rec['llm_ans'])\n",
    "            print('\\n')\n",
    "            t += 1\n",
    "    # Guard the accuracy against an empty result file.\n",
    "    print(correct, cnt, correct/cnt if cnt else 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3748fc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python dglke",
   "language": "python",
   "name": "dglke"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
