{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6a5254d4-3b40-46cf-80ca-4ac7cdb5679b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(\n",
    "    os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3872927a-c96e-4e2c-9538-9f3f146ba4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env OPENAI_API_KEY=<Enter you key here>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "478e5362-1f1a-4ca8-9411-3d0739256990",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import openai\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from utils import ProgramGenerator, ProgramInterpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "179bf29f-0d7e-496e-be61-04dd38d8677b",
   "metadata": {},
   "source": [
    "## Constraint Score ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07b29182-53ff-404d-85c6-d3fdeef9723d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_constraints(query, gen_answer, constraints):\n",
    "    try:\n",
    "\n",
    "        # Call the OpenAI ChatCompletion API\n",
    "        response = openai.ChatCompletion.create(\n",
    "            model=\"gpt-4\",  \n",
    "            messages=[\n",
    "                {\"role\": \"user\", \"content\": (\n",
    "                    f\"Consider the following Question, Constraint and Answer:\\n\"\n",
    "                    f\"Question: {query}\\n\"\n",
    "                    f\"Constraint: {constraints}\\n\"\n",
    "                    f\"Answer: {gen_answer}\\n\"\n",
    "                    f\"Conduct the following analysis:\\n\"\n",
    "                    f\"Step 1: Assess if the answer align with the constraint:\\n\"\n",
    "                    f\" - [Constraint] Is met/not met because [reason based on the answer part].\\n\"\n",
    "                    f\"Step 2: Summarize the findings:\\n\"\n",
    "                    f\" - Constraint is [met/not met].\\n\"\n",
    "                    f\"Final Answer: Respond 'True' if constraint is met, respond 'False' otherwise.\"\n",
    "                )}\n",
    "            ],\n",
    "            temperature=0.7,\n",
    "            max_tokens=512,\n",
    "            top_p=1,\n",
    "            frequency_penalty=0,\n",
    "            presence_penalty=0\n",
    "        )\n",
    "        return response['choices'][0]['message']['content'].strip()\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return False\n",
    "\n",
    "def extract_final_answer(response_text):\n",
    "    try:\n",
    "        lines = response_text.split('\\n')\n",
    "        for line in lines:\n",
    "            if line.startswith(\"Final Answer:\"):\n",
    "                final_answer = line.split(\":\")[-1].strip()\n",
    "                return final_answer\n",
    "    except Exception as e:\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return None\n",
    "\n",
    "def calculate_constraint_score(answer_df):\n",
    "    constraint_scores = []\n",
    "    for i in range(len(answer_df)):\n",
    "        try:\n",
    "            query = answer_df[\"Query\"][i]\n",
    "            gen_answer = answer_df[\"Gen_Answer\"][i]\n",
    "            program_str = answer_df[\"Program\"][i].replace(\"'\", '\"')\n",
    "            data = json.loads(program_str)\n",
    "            \n",
    "            functions = []\n",
    "            for step in data:\n",
    "                function = step['function']\n",
    "                functions.append(function)\n",
    "    \n",
    "                if function == 'LOAD_SPATIOTEMPORAL_DATA':\n",
    "                    parameters = step['parameters']\n",
    "                    for key, value in parameters.items():\n",
    "                        if key == \"constraints\" and value != \"None\":\n",
    "                            constraint = \"The analysis corresponds to \" + str(value)\n",
    "                            constraint_check = check_constraints(query,gen_answer,constraint)\n",
    "                            final_answer = extract_final_answer(constraint_check)\n",
    "\n",
    "                            if final_answer == \"True\":\n",
    "                                constraint_score = 1\n",
    "                            else:\n",
    "                                constraint_score = 0\n",
    "                        else:\n",
    "                            constraint_score = 1\n",
    "    \n",
    "    \n",
    "                elif function == 'FORECAST':\n",
    "                    parameters = step['parameters']\n",
    "                    for key, value in parameters.items():\n",
    "                        if key == \"constraint_val\":\n",
    "                            constraint = \"The answer explicitly inlcudes all predictions and all should be less than threshold value of \" + str(value)\n",
    "                            constraint_check = check_constraints(query,gen_answer,constraint)\n",
    "                            final_answer = extract_final_answer(constraint_check)\n",
    "\n",
    "                            if final_answer == \"True\":\n",
    "                                constraint_score = 1\n",
    "                            else:\n",
    "                                constraint_score = 0\n",
    "    \n",
    "            if constraint_score == 0:\n",
    "                print(\"Index: \",i)\n",
    "                print(\"Query: \",answer_df[\"Query\"][i])\n",
    "                print(\"Gen_Answer: \",answer_df[\"Gen_Answer\"][i])\n",
    "                print(\"Constraint Check: \",constraint_check)\n",
    "    \n",
    "            constraint_scores.append(constraint_score)\n",
    "    \n",
    "        except Exception as e:\n",
    "            print(f\"Error processing query at {i}: {str(e)}\")\n",
    "            continue\n",
    "\n",
    "    return constraint_scores\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ddd394f-6a63-40e1-8cea-da0fedad5d73",
   "metadata": {},
   "source": [
    "### STReason ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c3f7751b-d1f4-47ee-9fbf-7cf167ff6ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_df = pd.read_csv('gen_answers/STReasonAnswerscsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3a609c06-24e1-4137-bd0a-5bd1faf9f93f",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_scores = calculate_constraint_score(STReason_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b9e4e077-c3ce-400d-b11a-044785709ea5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(constraint_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "16461a52-078f-4436-b609-0ae4f37ec212",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_constraint_score = sum(constraint_scores)/len(STReason_df)\n",
    "final_constraint_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "913ef3c2-54a5-43bb-bed7-a8c925caf2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_constraint_df = pd.DataFrame({\n",
    "    'Constraint Score': constraint_scores})\n",
    "STReason_constraint_df.to_csv('Results/ConstraintScores/STReason_ConstraintResults.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
