{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'60.60'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Change working directory to the root of the project\n",
    "os.chdir(\"..\")\n",
    "\n",
    "table = \"\"\"\n",
    "| Year | Competition | Venue | Position | Event | Notes |\n",
    "| 2000 | World Junior Championships | Santiago, Chile | 1st | Discus throw | 59.51 m |\n",
    "| 2003 | All-Africa Games | Abuja, Nigeria | 5th | Shot put | 17.76 m |\n",
    "| 2003 | All-Africa Games | Abuja, Nigeria | 2nd | Discus throw | 62.86 m |\n",
    "| 2004 | African Championships | Brazzaville, Republic of the Congo | 2nd | Discus throw | 63.50 m |\n",
    "| 2004 | Olympic Games | Athens, Greece | 8th | Discus throw | 62.58 m |\n",
    "| 2006 | Commonwealth Games | Melbourne, Australia | 7th | Shot put | 18.44 m |\n",
    "| 2006 | Commonwealth Games | Melbourne, Australia | 4th | Discus throw | 60.99 m |\n",
    "| 2007 | All-Africa Games | Algiers, Algeria | 3rd | Discus throw | 57.79 m |\n",
    "| 2008 | African Championships | Addis Ababa, Ethiopia | 2nd | Discus throw | 56.98 m |\n",
    "\"\"\"\n",
    "\n",
    "title = \"Hannes Hopley\"\n",
    "question = \"What is the average distance of the discus throw for Hannes Hopley? (in meters, rounded to 2 decimal places)\"\n",
    "\n",
    "answer = (59.51 + 62.86 + 63.50 + 62.58 + 60.99 + 57.79 + 56.98) / 7\n",
    "answer = f\"{answer:.2f}\"\n",
    "\n",
    "answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Year</th>\n",
       "      <th>Competition</th>\n",
       "      <th>Venue</th>\n",
       "      <th>Position</th>\n",
       "      <th>Event</th>\n",
       "      <th>Notes</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2000</td>\n",
       "      <td>World Junior Championships</td>\n",
       "      <td>Santiago, Chile</td>\n",
       "      <td>1st</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>59.51 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2003</td>\n",
       "      <td>All-Africa Games</td>\n",
       "      <td>Abuja, Nigeria</td>\n",
       "      <td>5th</td>\n",
       "      <td>Shot put</td>\n",
       "      <td>17.76 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2003</td>\n",
       "      <td>All-Africa Games</td>\n",
       "      <td>Abuja, Nigeria</td>\n",
       "      <td>2nd</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>62.86 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2004</td>\n",
       "      <td>African Championships</td>\n",
       "      <td>Brazzaville, Republic of the Congo</td>\n",
       "      <td>2nd</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>63.50 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2004</td>\n",
       "      <td>Olympic Games</td>\n",
       "      <td>Athens, Greece</td>\n",
       "      <td>8th</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>62.58 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>2006</td>\n",
       "      <td>Commonwealth Games</td>\n",
       "      <td>Melbourne, Australia</td>\n",
       "      <td>7th</td>\n",
       "      <td>Shot put</td>\n",
       "      <td>18.44 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2006</td>\n",
       "      <td>Commonwealth Games</td>\n",
       "      <td>Melbourne, Australia</td>\n",
       "      <td>4th</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>60.99 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>2007</td>\n",
       "      <td>All-Africa Games</td>\n",
       "      <td>Algiers, Algeria</td>\n",
       "      <td>3rd</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>57.79 m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2008</td>\n",
       "      <td>African Championships</td>\n",
       "      <td>Addis Ababa, Ethiopia</td>\n",
       "      <td>2nd</td>\n",
       "      <td>Discus throw</td>\n",
       "      <td>56.98 m</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Year                 Competition                               Venue  \\\n",
       "0  2000  World Junior Championships                     Santiago, Chile   \n",
       "1  2003            All-Africa Games                      Abuja, Nigeria   \n",
       "2  2003            All-Africa Games                      Abuja, Nigeria   \n",
       "3  2004       African Championships  Brazzaville, Republic of the Congo   \n",
       "4  2004               Olympic Games                      Athens, Greece   \n",
       "5  2006          Commonwealth Games                Melbourne, Australia   \n",
       "6  2006          Commonwealth Games                Melbourne, Australia   \n",
       "7  2007            All-Africa Games                    Algiers, Algeria   \n",
       "8  2008       African Championships               Addis Ababa, Ethiopia   \n",
       "\n",
       "  Position         Event    Notes  \n",
       "0      1st  Discus throw  59.51 m  \n",
       "1      5th      Shot put  17.76 m  \n",
       "2      2nd  Discus throw  62.86 m  \n",
       "3      2nd  Discus throw  63.50 m  \n",
       "4      8th  Discus throw  62.58 m  \n",
       "5      7th      Shot put  18.44 m  \n",
       "6      4th  Discus throw  60.99 m  \n",
       "7      3rd  Discus throw  57.79 m  \n",
       "8      2nd  Discus throw  56.98 m  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from utils.execute import markdown_to_df\n",
    "\n",
    "df = markdown_to_df(table)\n",
    "\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/tianyang/anaconda3/envs/table/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2023-12-29 01:28:22,287\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "2023-12-29 01:28:22,369\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    }
   ],
   "source": [
    "from agent import Model\n",
    "\n",
    "# set openai api key\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
    "\n",
    "# openai model\n",
    "model = Model(model_name=\"gpt-3.5-turbo-0613\", provider=\"openai\")\n",
    "long_model = Model(model_name=\"gpt-3.5-turbo-16k-0613\", provider=\"openai\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agent import TableAgent\n",
    "\n",
    "agent = TableAgent(\n",
    "    table=df,\n",
    "    prompt_type=\"wtq\", # since we are trying to answer a question, we use wtq\n",
    "    model=model,\n",
    "    long_model=long_model,\n",
    "    temperature=0.0, # no randomness\n",
    "    use_full_table=True, # do not omit any rows\n",
    "    print_process=True, # print the process\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "You are working with a pandas dataframe in Python. The name of the dataframe is `df`. Your task is to use `python_repl_ast` to answer the question posed to you.\n",
      "\n",
      "Tool description:\n",
      "- `python_repl_ast`: A Python shell. Use this to execute python commands. Input should be a valid python command. When using this tool, sometimes the output is abbreviated - ensure it does not appear abbreviated before using it in your answer.\n",
      "\n",
      "Guidelines:\n",
      "- **Aggregated Rows**: Be cautious of rows that aggregate data such as 'total', 'sum', or 'average'. Ensure these rows do not influence your results inappropriately.\n",
      "- **Data Verification**: Before concluding the final answer, always verify that your observations align with the original table and question.\n",
      "\n",
      "Strictly follow the given format to respond:\n",
      "\n",
      "Question: the input question you must answer\n",
      "Thought: you should always think about what to do to interact with `python_repl_ast`\n",
      "Action: can **ONLY** be `python_repl_ast`\n",
      "Action Input: the input code to the action\n",
      "Observation: the result of the action\n",
      "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
      "Thought: after verifying the table, observations, and the question, I am confident in the final answer\n",
      "Final Answer: the final answer to the original input question (AnswerName1, AnswerName2...)\n",
      "\n",
      "Notes for final answer:\n",
      "- Ensure the final answer format is only \"Final Answer: AnswerName1, AnswerName2...\" form, no other form. \n",
      "- Ensure the final answer is a number or entity names, as short as possible, without any explanation.\n",
      "- Ensure to have a concluding thought that verifies the table, observations and the question before giving the final answer.\n",
      "\n",
      "You are provided with a table regarding \"Hannes Hopley\". This is the result of `print(df.to_markdown())`:\n",
      "\n",
      "|    |   Year | Competition                | Venue                              | Position   | Event        | Notes   |\n",
      "|---:|-------:|:---------------------------|:-----------------------------------|:-----------|:-------------|:--------|\n",
      "|  0 |   2000 | World Junior Championships | Santiago, Chile                    | 1st        | Discus throw | 59.51 m |\n",
      "|  1 |   2003 | All-Africa Games           | Abuja, Nigeria                     | 5th        | Shot put     | 17.76 m |\n",
      "|  2 |   2003 | All-Africa Games           | Abuja, Nigeria                     | 2nd        | Discus throw | 62.86 m |\n",
      "|  3 |   2004 | African Championships      | Brazzaville, Republic of the Congo | 2nd        | Discus throw | 63.50 m |\n",
      "|  4 |   2004 | Olympic Games              | Athens, Greece                     | 8th        | Discus throw | 62.58 m |\n",
      "|  5 |   2006 | Commonwealth Games         | Melbourne, Australia               | 7th        | Shot put     | 18.44 m |\n",
      "|  6 |   2006 | Commonwealth Games         | Melbourne, Australia               | 4th        | Discus throw | 60.99 m |\n",
      "|  7 |   2007 | All-Africa Games           | Algiers, Algeria                   | 3rd        | Discus throw | 57.79 m |\n",
      "|  8 |   2008 | African Championships      | Addis Ababa, Ethiopia              | 2nd        | Discus throw | 56.98 m |\n",
      "\n",
      "**Note**: All cells in the table should be considered as `object` data type, regardless of their appearance.\n",
      "\n",
      "Begin!\n",
      "Question: What is the average distance of the discus throw for Hannes Hopley? (in meters, rounded to 2 decimal places)\n",
      "Thought: To find the average distance of the discus throw for Hannes Hopley, I need to filter the dataframe for rows where the Event is \"Discus throw\" and then calculate the average of the \"Notes\" column.\n",
      "\n",
      "Action: Filter the dataframe and calculate the average distance.\n",
      "\n",
      "Action Input: `df[df['Event'] == 'Discus throw']['Notes'].astype(float).mean()`\n",
      "\n",
      "Observation: ValueError: could not convert string to float: '59.51 m'\n",
      "\n",
      "Thought: It seems that the \"Notes\" column contains values with the unit \"m\" attached to them. I need to remove the \"m\" from the values before converting them to float.\n",
      "\n",
      "Action: Remove the \"m\" from the values in the \"Notes\" column and calculate the average distance.\n",
      "\n",
      "Action Input: `df[df['Event'] == 'Discus throw']['Notes'].str.replace(' m', '').astype(float).mean()`\n",
      "\n",
      "Observation: 60.60142857142858\n",
      "\n",
      "Final Answer: 60.60"
     ]
    }
   ],
   "source": [
    "# now we can run the agent\n",
    "text, _ = agent.run(question=question, title=title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'60.60'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from utils.eval import extract_answer, eval_ex_match\n",
    "\n",
    "# extract the answer from the text\n",
    "pred = extract_answer(text)\n",
    "\n",
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_ex_match(pred, answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "table",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
