{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ad711007",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import analysis\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9531359b",
   "metadata": {},
   "outputs": [],
   "source": [
    "LOG_PATH = Path('..', 'log').with_suffix('.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2822f826",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = analysis.get_all_files_matching(\n",
    "    model_name='/public/hf/models/meta-llama/Meta-Llama-3.1-8B-Instruct/',\n",
    "    split='answerable',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc2f5c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "p = Path('..', 'runs', f'{Path(files[\"filename\"][0]).name}')\n",
    "with p.open('r') as f:\n",
    "    df = pd.read_json(f, lines=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed354373",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = analysis.remove_no_answer_rows(df=df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b3f730cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['model_answer'] = df['messages'].apply(analysis.get_model_final_answer)\n",
    "df['correct'] = df.apply(analysis.check_final_answer, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c1999e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>answer</th>\n",
       "      <th>model_answer</th>\n",
       "      <th>correct</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.222426e+01</td>\n",
       "      <td>32.2242574604399</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.728116e+00</td>\n",
       "      <td>4.728116333076132</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.796131e+10</td>\n",
       "      <td>37961311304.52</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.414000e+01</td>\n",
       "      <td>14.14</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2.340627e+00</td>\n",
       "      <td>2.3406274391271746</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>5.392222e+00</td>\n",
       "      <td>5.39</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>5.954529e+01</td>\n",
       "      <td>55.071119960256986</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>2.800364e+00</td>\n",
       "      <td>2.8003636363636364</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>1.000000e+02</td>\n",
       "      <td>100</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>6.732560e+01</td>\n",
       "      <td>67.3256</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>78 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          answer        model_answer  correct\n",
       "0   3.222426e+01    32.2242574604399     True\n",
       "1   4.728116e+00   4.728116333076132    False\n",
       "2   3.796131e+10      37961311304.52     True\n",
       "4   1.414000e+01               14.14     True\n",
       "6   2.340627e+00  2.3406274391271746     True\n",
       "..           ...                 ...      ...\n",
       "95  5.392222e+00                5.39    False\n",
       "96  5.954529e+01  55.071119960256986    False\n",
       "97  2.800364e+00  2.8003636363636364    False\n",
       "98  1.000000e+02                 100    False\n",
       "99  6.732560e+01             67.3256     True\n",
       "\n",
       "[78 rows x 3 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[['answer', 'model_answer', 'correct']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9584bf3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['gold_tools'] = df['actions'].apply(lambda actions: analysis.get_gold_tool_calls(actions))\n",
    "df['model_tools'] = df['messages'].apply(lambda messages: analysis.get_model_tool_calls(messages))\n",
    "df['precision'] = df.apply(lambda row: analysis.precision(row['gold_tools'], row['model_tools']), axis=1)\n",
    "df['recall'] = df.apply(lambda row: analysis.recall(row['gold_tools'], row['model_tools']), axis=1)\n",
    "df['accuracy'] = df.apply(lambda row: analysis.accuracy(row['gold_tools'], row['model_tools']), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "44b56a5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.437500</td>\n",
       "      <td>0.777778</td>\n",
       "      <td>0.388889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.857143</td>\n",
       "      <td>0.666667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.388889</td>\n",
       "      <td>0.777778</td>\n",
       "      <td>0.350000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.458333</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.423077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.791667</td>\n",
       "      <td>0.863636</td>\n",
       "      <td>0.703704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>0.733333</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.647059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>0.360000</td>\n",
       "      <td>0.692308</td>\n",
       "      <td>0.310345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>0.812500</td>\n",
       "      <td>0.866667</td>\n",
       "      <td>0.722222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>0.733333</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.647059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>0.636364</td>\n",
       "      <td>0.777778</td>\n",
       "      <td>0.538462</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>78 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    precision    recall  accuracy\n",
       "0    0.437500  0.777778  0.388889\n",
       "1    0.750000  0.857143  0.666667\n",
       "2    0.388889  0.777778  0.350000\n",
       "4    0.458333  0.846154  0.423077\n",
       "6    0.791667  0.863636  0.703704\n",
       "..        ...       ...       ...\n",
       "95   0.733333  0.846154  0.647059\n",
       "96   0.360000  0.692308  0.310345\n",
       "97   0.812500  0.866667  0.722222\n",
       "98   0.733333  0.846154  0.647059\n",
       "99   0.636364  0.777778  0.538462\n",
       "\n",
       "[78 rows x 3 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[['precision', 'recall', 'accuracy']]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
