{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b6b81cc1-c6d8-4ba1-b2c1-f491e3ca43f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(\n",
    "    os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca1b6d8-7ae2-4bdc-aeec-006957a260c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env OPENAI_API_KEY=<Enter you key here>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5851b3f0-f18f-4f8c-9fe6-a5df17449337",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import openai\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from utils import ProgramGenerator, ProgramInterpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "867d2c78-753c-455e-a758-376022e11d63",
   "metadata": {},
   "source": [
    "## Prediction Accuracy ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d89e654c-6e4d-46d0-9f8f-48c714acb1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_predictions(gen_answer):\n",
    "    \n",
    "    # Call the OpenAI ChatCompletion API\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-4\",  # Specify the model version\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": (\n",
    "                f\"Model generated answer: {gen_answer}\\n\"\n",
    "                f\"From the provided Model generated answer, identify and list any explicit or implied numerical predictions in the following format. If no such predictions are mentioned, return an empty list.\\n\"\n",
    "                f\"Model Generated Predictions:[value1, value2, ...]\"\n",
    "            )}\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=1024,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0\n",
    "    )\n",
    "    return response['choices'][0]['message']['content'].strip()\n",
    "\n",
    "\n",
    "def extract_groundtruths(gt_answer):\n",
    "    ground_truths = re.findall(r\"Ground Truths: \\[([^\\]]+)\\]\", gt_answer)\n",
    "    ground_truth_array = [float(num) for num in ground_truths[0].split(', ')]\n",
    "    return ground_truth_array\n",
    "\n",
    "def extract_predictions(answer_df):\n",
    "    model_predictions = []\n",
    "    for i in range(len(answer_df)):\n",
    "        try:\n",
    "            if i>=50 and i<100:\n",
    "                query = answer_df[\"Query\"][i]\n",
    "                ground_truth_answer = answer_df[\"Answer\"][i]\n",
    "                gen_answer = answer_df[\"Gen_Answer\"][i]\n",
    "                    \n",
    "                result = check_predictions(gen_answer)\n",
    "                numbers = re.findall(r\"\\d+\\.\\d+\", result)\n",
    "                predictions = [float(num) for num in numbers]\n",
    "                model_predictions.append(predictions)\n",
    "    \n",
    "        except Exception as e:\n",
    "            print(f\"An error occurred at {i}: {e}\")\n",
    "            print(f\"Query: {query}\")\n",
    "            print(f\"Generated Answer: {gen_answer}\")\n",
    "            print(f\"Ground Truth Answer: {ground_truth_answer}\")\n",
    "    \n",
    "            continue\n",
    "\n",
    "    return model_predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b3caecc-4eb9-4f25-bd43-9ed54df7a48c",
   "metadata": {},
   "source": [
    "### Extract Ground Truths ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "79202bef-6f2a-4f0a-9568-b05ae5740ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_df = pd.read_csv('data/STReasonAnswers.csv')\n",
    "\n",
    "ground_truths = []\n",
    "for i in range(len(STReason_df)):\n",
    "    if i>=50 and i<100:\n",
    "        ground_truth_answer = STReason_df[\"Answer\"][i]\n",
    "        gt = extract_groundtruths(ground_truth_answer)\n",
    "        ground_truths.append(gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ee02965c-0c5f-4c08-acd7-5df9be9b45d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ground_truths)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b88c3b4-8cdf-4f2e-a14c-7a2b385e2fd0",
   "metadata": {},
   "source": [
    "### STReason ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "aac10543-1350-4661-aa9b-882eb6dfdcc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "STReason_predictions = extract_predictions(STReason_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "74c42035-03fa-4715-a984-2099c87c0fe1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(STReason_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a95a199-3e73-4af2-b1b7-4b0893ae6d0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ForecastResults_df = pd.DataFrame({\n",
    "    'Ground Truth': ground_truths,\n",
    "    'STReason': STReason_predictions})\n",
    "\n",
    "ForecastResults_df.to_csv('Results/ForecastResults.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c1f57c6-7d30-43bd-84e2-e138c486d3e4",
   "metadata": {},
   "source": [
    "### Calculate Accuracies ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "37a81493-4b46-495e-b79c-e6c22fead323",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import ast\n",
    "\n",
    "forecast_results = pd.read_csv('Results/ForecastResults.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca3ecd85-0014-49da-9a70-73ae11373eaf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: STReason\n",
      "Mean Absolute Error (MAE): 7.632892782977917\n",
      "Root Mean Squared Error (RMSE) 8.41129525420537\n"
     ]
    }
   ],
   "source": [
    "models = ['STReason']\n",
    "results = []\n",
    "for model in models:\n",
    "    mean_absolute_error = []\n",
    "    root_mean_squared_error = []\n",
    "    \n",
    "    for i in range(len(forecast_results)):\n",
    "        predicted = np.array(ast.literal_eval(forecast_results[model][i]))\n",
    "        actual = np.array(ast.literal_eval(forecast_results['Ground Truth'][i]))\n",
    "        \n",
    "        # Check if predicted array is empty\n",
    "        if predicted.size == 0:\n",
    "            # Fill with zeros of the same length as actual\n",
    "            predicted = np.zeros_like(actual)\n",
    "        elif len(predicted) != len(actual):\n",
    "            if len(predicted)<len(actual):\n",
    "                # If lengths differ and predicted is not empty, extend by repeating the last value\n",
    "                last_value = predicted[-1]\n",
    "                needed_length = len(actual) - len(predicted)\n",
    "                predicted = np.concatenate([predicted, np.full(needed_length, last_value)])\n",
    "            elif len(predicted)>len(actual):\n",
    "                predicted = predicted[:len(actual)]\n",
    "\n",
    "        # Calculate MAE\n",
    "        mae = np.mean(np.abs(predicted - actual))\n",
    "        mean_absolute_error.append(mae)\n",
    "        \n",
    "        # Calculate RMSE\n",
    "        rmse = np.sqrt(np.mean((predicted - actual) ** 2))\n",
    "        root_mean_squared_error.append(rmse)\n",
    "    \n",
    "    print(\"Model:\", model)\n",
    "    print(\"Mean Absolute Error (MAE):\", np.mean(mean_absolute_error))\n",
    "    print(\"Root Mean Squared Error (RMSE)\", np.mean(root_mean_squared_error))\n",
    "\n",
    "    # Store results in a list of dictionaries\n",
    "    results.append({\n",
    "        'Model': model,\n",
    "        'MAE': np.mean(mean_absolute_error),\n",
    "        'RMSE': np.mean(root_mean_squared_error)\n",
    "    })"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
