{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d6a4f97f-c032-4b57-849c-c942280b3578",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install openai==0.28.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f689dead-9ddd-4fab-af37-818bfdb97c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(\n",
    "    os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63cc098f-f16f-457d-842f-efcbb5e674ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env OPENAI_API_KEY=<Enter you key here>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "48f8af2c-fe99-4b66-9cd9-b939e47d689d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import openai\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from utils import ProgramGenerator, ProgramInterpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9460c840-fd66-46d6-a252-8b3a5428f885",
   "metadata": {},
   "source": [
    "## Factuality Score ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "64ebd953-213f-45c6-a331-d681ca44c09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_component_presence_general(ground_truth_answer_components, ground_truth_answer, gen_answer):\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-4\",  \n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": (\n",
    "                f\"Consider the following details to verify if specific 'Ground Truth Answer Components' with defined values from 'Ground Truth Answer' are accurately reflected in the 'Model generated answer'.\\n\"\n",
    "                f\"Ground Truth Answer Components: {ground_truth_answer_components}\\n\"\n",
    "                f\"Ground Truth Answer: {ground_truth_answer}\\n\"\n",
    "                f\"Model generated answer: {gen_answer}\\n\"\n",
    "                f\"Conduct the following analysis:\\n\"\n",
    "                f\"Step 1: Component Presence Check: For each component in 'Ground Truth Answer Components', check if it is explicitly mentioned in the 'Model generated answer' with the values or details as given in 'Ground Truth Answer'.\"\n",
    "                f\"Record 'present with correct details' for components found with matching details/values and 'absent or incorrect details' for those that are not present or have incorrect details.\\n\"\n",
    "                f\"Step 2: Correct number of components: Count the number of components identified as 'present with correct details'. \\n\"\n",
    "                f\"Final Answer: Respond with the Correct number of components. This response should be a numeric value only.\"\n",
    "            )}\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=1024,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0\n",
    "    )\n",
    "    return response['choices'][0]['message']['content'].strip()\n",
    "\n",
    "\n",
    "def extract_final_answer(response_text):\n",
    "    lines = response_text.split('\\n')\n",
    "    for line in lines:\n",
    "        if line.startswith(\"Final Answer:\"):\n",
    "            final_answer = line.split(\":\")[-1].strip()\n",
    "            final_answer = int(final_answer)\n",
    "        elif line.strip().isdigit():\n",
    "            final_answer = int(line.strip())\n",
    "    \n",
    "    return final_answer\n",
    "\n",
    "def calculate_factuality_score(answer_df):\n",
    "    factuality_scores = []   \n",
    "        \n",
    "    for i in range(len(answer_df)):\n",
    "        try:\n",
    "            query = answer_df[\"Query\"][i]\n",
    "            ground_truth_answer = answer_df[\"Answer\"][i]\n",
    "            gen_answer = answer_df[\"Gen_Answer\"][i]\n",
    "            program_str = answer_df[\"Program\"][i].replace(\"'\", '\"')\n",
    "            gt_ans_components = answer_df[\"Ans_Components\"][i]\n",
    "\n",
    "            result = check_component_presence_general(gt_ans_components, ground_truth_answer, gen_answer)  \n",
    "            num_correct_components = extract_final_answer(result)\n",
    "            factuality_score = num_correct_components/len(gt_ans_components)\n",
    "            factuality_scores.append(factuality_score)\n",
    "    \n",
    "        except Exception as e:\n",
    "            print(f\"An error occurred at {i}: {e}\")\n",
    "            print(f\"Query: {query}\")\n",
    "            print(f\"Generated Answer: {gen_answer}\")\n",
    "            print(f\"Factuality Check Output: {result}\")\n",
    "    \n",
    "            continue\n",
    "    return factuality_scores\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "987389f0-c7ea-43bd-aa0e-02a59ed6d0f6",
   "metadata": {},
   "source": [
    "### STReason ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ac76bc-b0e9-4f0b-afb8-07f5e92a2610",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_df = pd.read_csv('gen_answers/STReasonAnswers.csv')\n",
    "streason_factuality_scores = calculate_factuality_score(STReason_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8d96e6b5-8d8a-46f2-a454-71aaa340679b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(streason_factuality_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5f5753e8-4ec0-43a8-a00c-97266c8d4b3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "84.4444444444444"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "streason_percent_factuality = [x * 100 for x in streason_factuality_scores]\n",
    "final_factuality_score = sum(streason_percent_factuality)/len(streason_factuality_scores)\n",
    "final_factuality_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9815c84e-7a6a-45bd-b6f2-0cfcde31c866",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_factuality_df = pd.DataFrame({\n",
    "    'Factuality Score': streason_percent_factuality})\n",
    "STReason_factuality_df.to_csv('Results/FactualityScores/STReason_FactualityResults.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
