{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a09a26c1-3c05-446a-97ca-5bcaf2437768",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on the import path (presumably where `utils` lives).\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5328dab1-2dd9-4654-b152-6cea25864098",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "# Take the key from the environment when it is already set; otherwise prompt for it.\n",
    "# Unlike `%env OPENAI_API_KEY=...`, this keeps the key out of the saved notebook\n",
    "# source and out of the cell output (which %env would echo).\n",
    "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
    "    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0c0c3bb4-9383-401f-8199-452954aa5848",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import openai\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from utils import ProgramGenerator, ProgramInterpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ceb8704-99c0-42a8-9f12-1f73cbd5b58f",
   "metadata": {},
   "source": [
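    "## Coherence Score\n",
    "\n",
    "Each generated answer is sent to GPT-4 with a 1-3 rubric for how smoothly and logically its explanation flows (the rubric and the saved column call this the *Cohesion score*). The raw grader reply is kept next to the parsed score."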
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a3535af5-df7b-4bad-ba28-906bcc79287d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_reasoning_quality(query, gen_answer, model=\"gpt-4\", temperature=0.7, max_tokens=1024):\n",
    "    \"\"\"Ask an OpenAI chat model to grade the cohesion of `gen_answer` for `query`.\n",
    "\n",
    "    Uses the legacy `openai.ChatCompletion` interface, which exists only in openai < 1.0.\n",
    "\n",
    "    Args:\n",
    "        query: The question that was asked.\n",
    "        gen_answer: The generated answer to be graded.\n",
    "        model: Name of the chat model used as the grader.\n",
    "        temperature: Sampling temperature. The default 0.7 lets repeated runs\n",
    "            return different scores; lower values (e.g. 0) give more consistent grading.\n",
    "        max_tokens: Upper bound on the length of the grader's reply.\n",
    "\n",
    "    Returns:\n",
    "        str: The grader's reply with surrounding whitespace stripped. The prompt asks\n",
    "        for a `- Cohesion score = <Score>` line, which `extract_scores` parses.\n",
    "    \"\"\"\n",
    "    prompt = (\n",
    "        f\"Question: {query}\\n\"\n",
    "        f\"Answer: {gen_answer}\\n\"\n",
    "        \"Evaluate the answer of the given question:\\n\"\n",
    "        \"Step 1: Coherence Evaluation: Are the transitions between points in the explanation smooth and logical?\\n\"\n",
    "        \"- Cohesion score = 1: Assign if transitions are abrupt or disjointed.\\n\"\n",
    "        \"- Cohesion score = 2: Assign if transitions are generally clear with minor issues.\\n\"\n",
    "        \"- Cohesion score = 3: Assign if the explanation flows seamlessly and logically.\\n\"\n",
    "        \"Step 2: Format the evaluation as follows:\\n\"\n",
    "        \"- Cohesion score = <Score>\\n\"\n",
    "        \"- Explanation: <Explanation for assignment of specific score>\\n\"\n",
    "    )\n",
    "\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "        temperature=temperature,\n",
    "        max_tokens=max_tokens,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "    )\n",
    "    return response['choices'][0]['message']['content'].strip()\n",
    "\n",
    "def extract_scores(result):\n",
    "    \"\"\"Parse the integer cohesion score out of a grader reply.\n",
    "\n",
    "    Args:\n",
    "        result: Text returned by `check_reasoning_quality`.\n",
    "\n",
    "    Returns:\n",
    "        int: The number from the last `Cohesion score = <n>` entry. The last one\n",
    "        is used so that rubric lines echoed earlier in the reply do not win.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the reply contains no `Cohesion score = <n>` entry.\n",
    "    \"\"\"\n",
    "    # Regex rather than line splitting so decorated replies such as\n",
    "    # `**Cohesion score = 2**` or `Cohesion score = 2 (minor issues)` still parse.\n",
    "    matches = re.findall(r\"Cohesion score\\s*=\\s*(\\d+)\", result or \"\", flags=re.IGNORECASE)\n",
    "    if not matches:\n",
    "        raise ValueError(f\"No 'Cohesion score = <n>' entry found in grader output: {result!r}\")\n",
    "    return int(matches[-1])\n",
    "\n",
    "\n",
    "def calculate_reasoning_scores(answer_df, grader=None, score_parser=None):\n",
    "    \"\"\"Grade every (Query, Gen_Answer) row of `answer_df` and collect the results.\n",
    "\n",
    "    If grading or score parsing fails for a row, the error is printed and the row\n",
    "    is still recorded, with None for whatever could not be produced. All four\n",
    "    returned lists therefore have exactly one entry per input row, in order, and\n",
    "    can be combined into one DataFrame.\n",
    "\n",
    "    Args:\n",
    "        answer_df: DataFrame with `Query` and `Gen_Answer` columns.\n",
    "        grader: Callable (query, gen_answer) -> str. Defaults to `check_reasoning_quality`.\n",
    "        score_parser: Callable (grader_output) -> int. Defaults to `extract_scores`.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (queries, answers, cohesion_scores, reasoning_results).\n",
    "    \"\"\"\n",
    "    if grader is None:\n",
    "        grader = check_reasoning_quality\n",
    "    if score_parser is None:\n",
    "        score_parser = extract_scores\n",
    "\n",
    "    queries, answers, cohesion_scores, reasoning_results = [], [], [], []\n",
    "\n",
    "    # Zipping the columns is positional, so any index works (not only 0..n-1).\n",
    "    rows = zip(answer_df[\"Query\"], answer_df[\"Gen_Answer\"])\n",
    "    for i, (query, gen_answer) in enumerate(rows):\n",
    "        result = None\n",
    "        cohesion_score = None\n",
    "        try:\n",
    "            result = grader(query, gen_answer)\n",
    "            cohesion_score = score_parser(result)\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failure and move on to the next row.\n",
    "            print(f\"An error occurred at {i}: {e}\")\n",
    "            print(f\"Query: {query}\")\n",
    "            print(f\"Generated Answer: {gen_answer}\")\n",
    "            print(f\"Reasoning Check Output: {result}\")\n",
    "\n",
    "        # Append unconditionally so the four lists stay aligned row-for-row.\n",
    "        queries.append(query)\n",
    "        answers.append(gen_answer)\n",
    "        reasoning_results.append(result)\n",
    "        cohesion_scores.append(cohesion_score)\n",
    "\n",
    "    return queries, answers, cohesion_scores, reasoning_results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f8f1232-9947-4b3b-9b07-ccca53a446b3",
   "metadata": {},
   "source": [
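    "### STReason\n",
    "\n",
    "Score the answers generated by STReason (`gen_answers/STReasonAnswers.csv`) and save the results to `Results/CoherenceScores/STReason_ReasoningResults.csv`."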
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0692b3-a089-4d2a-babd-5b2dce49f2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "STREASON_ANSWERS_CSV = 'gen_answers/STReasonAnswers.csv'\n",
    "STReason_df = pd.read_csv(STREASON_ANSWERS_CSV)\n",
    "\n",
    "# Makes one grader API call per answer row.\n",
    "(queries, answers,\n",
    " cohesion_scores, reasoning_results) = calculate_reasoning_scores(STReason_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4c687caf-c2d4-43fc-9c64-b15c39c143a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150\n",
      "150\n",
      "150\n",
      "150\n"
     ]
    }
   ],
   "source": [
    "# Every list should have one entry per graded row; the pd.DataFrame built below requires equal lengths.\n",
    "result_lists = (queries, answers, reasoning_results, cohesion_scores)\n",
    "for result_list in result_lists:\n",
    "    print(len(result_list))\n",
    "assert len({len(result_list) for result_list in result_lists}) == 1, \"result lists have different lengths\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "138b8091-73e8-4e21-8daa-4fcca0562d3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_reasoning_df = pd.DataFrame({\n",
    "    'Query': queries,\n",
    "    'Gen Answer': answers,\n",
    "    'Reasoning Result': reasoning_results,\n",
    "    'Cohesion Score': cohesion_scores,\n",
    "})\n",
    "\n",
    "# to_csv does not create missing parent directories, so make sure the output folder exists.\n",
    "output_csv = 'Results/CoherenceScores/STReason_ReasoningResults.csv'\n",
    "os.makedirs(os.path.dirname(output_csv), exist_ok=True)\n",
    "STReason_reasoning_df.to_csv(output_csv, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
