{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -e .. datasets sympy numpy matplotlib seaborn -q  # Install dev version of smolagents + some packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark date\n",
    "# - set a concrete date:\n",
    "DATE = \"2024-12-26\"\n",
    "# - or use default: today\n",
    "# DATE = None\n",
    "\n",
    "# Evaluation dataset\n",
    "# - the dataset is gated, so you must first visit its page to request access: https://huggingface.co/datasets/smolagents-benchmark/benchmark-v1\n",
    "EVAL_DATASET = \"smolagents/benchmark-v1\"\n",
    "\n",
    "# Answers dataset: it must be a gated dataset; required to score the answers\n",
    "ANSWERS_DATASET = \"smolagents/answers\"\n",
    "# Whether to push the answers dataset to the Hub\n",
    "PUSH_ANSWERS_DATASET_TO_HUB = True\n",
    "\n",
    "# Results dataset\n",
    "RESULTS_DATASET = \"smolagents/results\"\n",
    "# Whether to push the results dataset to the Hub\n",
    "PUSH_RESULTS_DATASET_TO_HUB = True"
   ]
  },
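  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading the gated evaluation dataset and pushing the answers/results datasets requires Hugging Face Hub authentication. A minimal sketch: log in interactively unless a token is already set in the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Authenticate with the Hugging Face Hub: needed to load the gated evaluation\n",
    "# dataset and to push the answers/results datasets. Skipped if HF_TOKEN is set.\n",
    "if not os.environ.get(\"HF_TOKEN\"):\n",
    "    from huggingface_hub import login\n",
    "\n",
    "    login()"
   ]
  },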
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constants and utilities/tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import string\n",
    "import warnings\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "from datetime import datetime\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def normalize_number_str(number_str: str) -> float:\n",
    "    # we replace these common units and commas to allow\n",
    "    # conversion to float\n",
    "    for char in [\"$\", \"%\", \",\"]:\n",
    "        number_str = number_str.replace(char, \"\")\n",
    "    try:\n",
    "        return float(number_str)\n",
    "    except ValueError:\n",
    "        return float(\"inf\")\n",
    "\n",
    "\n",
    "def split_string(\n",
    "    s: str,\n",
    "    char_list: list[str] = [\",\", \";\"],\n",
    ") -> list[str]:\n",
    "    pattern = f\"[{''.join(char_list)}]\"\n",
    "    return re.split(pattern, s)\n",
    "\n",
    "\n",
    "def is_float(element: any) -> bool:\n",
    "    try:\n",
    "        float(element)\n",
    "        return True\n",
    "    except ValueError:\n",
    "        return False\n",
    "\n",
    "\n",
    "def normalize_str(input_str, remove_punct=True) -> str:\n",
    "    \"\"\"\n",
    "    Normalize a string by:\n",
    "    - Removing all white spaces\n",
    "    - Optionally removing punctuation (if remove_punct is True)\n",
    "    - Converting to lowercase\n",
    "    Parameters:\n",
    "    - input_str: str, the string to normalize\n",
    "    - remove_punct: bool, whether to remove punctuation (default: True)\n",
    "    Returns:\n",
    "    - str, the normalized string\n",
    "    \"\"\"\n",
    "    # Remove all white spaces. Required e.g for seagull vs. sea gull\n",
    "    no_spaces = re.sub(r\"\\s\", \"\", input_str)\n",
    "\n",
    "    # Remove punctuation, if specified.\n",
    "    if remove_punct:\n",
    "        translator = str.maketrans(\"\", \"\", string.punctuation)\n",
    "        return no_spaces.lower().translate(translator)\n",
    "    else:\n",
    "        return no_spaces.lower()\n",
    "\n",
    "\n",
    "def extract_numbers(text: str) -> list[str]:\n",
    "    \"\"\"This pattern matches:\n",
    "    - Optional negative sign\n",
    "    - Numbers with optional comma thousand separators\n",
    "    - Optional decimal points with decimal numbers\n",
    "    \"\"\"\n",
    "    pattern = r\"-?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?\"\n",
    "\n",
    "    return [el.replace(\",\", \"\") for el in re.findall(pattern, text)]\n",
    "\n",
    "\n",
    "def get_question_score_gaia(\n",
    "    model_answer: str,\n",
    "    ground_truth: str,\n",
    ") -> bool:\n",
    "    \"\"\"Scoring function used to score functions from the GAIA benchmark\"\"\"\n",
    "    if is_float(ground_truth):\n",
    "        normalized_answer = normalize_number_str(str(model_answer))\n",
    "        return normalized_answer == float(ground_truth)\n",
    "\n",
    "    elif any(char in ground_truth for char in [\",\", \";\"]):  # if gt is a list\n",
    "        # question with the fish: normalization removes punct\n",
    "        gt_elems = split_string(ground_truth)\n",
    "        ma_elems = split_string(model_answer)\n",
    "\n",
    "        if len(gt_elems) != len(ma_elems):  # check length is the same\n",
    "            warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
    "            return False\n",
    "\n",
    "        comparisons = []\n",
    "        for ma_elem, gt_elem in zip(ma_elems, gt_elems):  # compare each element as float or str\n",
    "            if is_float(gt_elem):\n",
    "                normalized_ma_elem = normalize_number_str(ma_elem)\n",
    "                comparisons.append(normalized_ma_elem == float(gt_elem))\n",
    "            else:\n",
    "                # we do not remove punct since comparisons can include punct\n",
    "                comparisons.append(\n",
    "                    normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
    "                )\n",
    "        return all(comparisons)\n",
    "\n",
    "    else:  # if gt is a str\n",
    "        return normalize_str(model_answer) == normalize_str(ground_truth)\n",
    "\n",
    "\n",
    "def get_correct(row):\n",
    "    if row[\"source\"] == \"MATH\":  # Checks the last number in answer\n",
    "        numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
    "        if len(numbers_answer) == 0:\n",
    "            return False\n",
    "        return np.isclose(float(numbers_answer[-1]), float(row[\"true_answer\"]), rtol=1e-5, atol=1e-7)\n",
    "    else:\n",
    "        return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
    "\n",
    "\n",
    "def score_answers_subset(answers_dataset, answers_subset):\n",
    "    try:\n",
    "        print(answers_dataset, answers_subset)\n",
    "        *model_id, action_type, task = answers_subset.split(\"__\")\n",
    "        model_id = \"/\".join(model_id)\n",
    "        ds = datasets.load_dataset(answers_dataset, answers_subset, split=\"test\")\n",
    "        df = ds.to_pandas()\n",
    "        df[\"correct\"] = df.apply(get_correct, axis=1)\n",
    "        assert df[\"correct\"].notnull().sum() > 30, \"Missing answers\"\n",
    "        acc = df[\"correct\"].mean().item()\n",
    "        result = df.loc[0, [\"model_id\", \"agent_action_type\", \"source\"]].to_dict()\n",
    "        result[\"acc\"] = acc\n",
    "        return result\n",
    "    except Exception as e:\n",
    "        print(f\"Error with {answers_subset}: {e}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def score_answers(\n",
    "    answers_subsets,\n",
    "    answers_dataset=ANSWERS_DATASET,\n",
    "    date=DATE,\n",
    "    push_to_hub_dataset=RESULTS_DATASET if PUSH_RESULTS_DATASET_TO_HUB else None,\n",
    "    set_default=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Score answers from the given dataset subsets.\n",
    "\n",
    "    Parameters:\n",
    "        answers_subsets: List of dataset subsets to score\n",
    "        answers_dataset: Dataset containing the answers\n",
    "        date: Date to use for the config name\n",
    "        push_to_hub_dataset: Dataset ID to push results to, or None to skip pushing\n",
    "        set_default: If True, sets this config as the default config in the Hugging Face Hub dataset.\n",
    "                     This means when users load the dataset without specifying a config,\n",
    "                     this version will be loaded by default.\n",
    "    \"\"\"\n",
    "    if not answers_dataset:\n",
    "        raise ValueError(\"Pass 'answers_dataset' to load the answers from it\")\n",
    "    date = date or datetime.date.today().isoformat()\n",
    "    results = []\n",
    "    with ThreadPoolExecutor(max_workers=16) as exe:\n",
    "        futures = [\n",
    "            exe.submit(score_answers_subset, answers_dataset, answers_subset) for answers_subset in answers_subsets\n",
    "        ]\n",
    "        for f in tqdm(as_completed(futures), total=len(answers_subsets), desc=\"Processing tasks\"):\n",
    "            result = f.result()\n",
    "            if result:\n",
    "                results.append(result)\n",
    "    df = pd.DataFrame(results)\n",
    "\n",
    "    if push_to_hub_dataset:\n",
    "        ds = datasets.Dataset.from_pandas(df)\n",
    "        config = date\n",
    "        ds.push_to_hub(push_to_hub_dataset, config_name=config, commit_message=f\"Upload {config} results\")\n",
    "    return df"
   ]
  },
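  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before scoring real answers, a quick sanity check of the helpers above on hand-made examples (not benchmark data): string normalization, number extraction, the GAIA-style comparison, and the `__`-separated subset naming convention that `score_answers_subset` parses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-made examples (not benchmark data) to sanity-check the scoring helpers\n",
    "assert normalize_str(\"Sea gull!\") == \"seagull\"  # spaces, case and punctuation stripped\n",
    "assert extract_numbers(\"Total: 1,234.5 and -6\") == [\"1234.5\", \"-6\"]  # thousand separators removed\n",
    "assert get_question_score_gaia(\"$1,000\", \"1000\")  # numeric ground truth: units and commas ignored\n",
    "assert get_question_score_gaia(\"a, b\", \"a; b\")  # list ground truth: element-wise comparison\n",
    "assert not get_question_score_gaia(\"seagull\", \"albatross\")  # string ground truth: normalized equality\n",
    "\n",
    "# Answers subsets are named \"<org>__<model>__<action_type>__<task>\"; \"/\" in the\n",
    "# model id maps to \"__\", hence the star-unpacking in score_answers_subset\n",
    "*model_id, action_type, task = \"meta-llama__Llama-3.1-8B-Instruct__code__gaia\".split(\"__\")\n",
    "assert \"/\".join(model_id) == \"meta-llama/Llama-3.1-8B-Instruct\"\n",
    "assert (action_type, task) == (\"code\", \"gaia\")"
   ]
  },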
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Score answers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "# Choose the answers subsets to score:\n",
    "# answers_subsets = [\"meta-llama__Llama-3.1-8B-Instruct__code__gaia\"]\n",
    "# or get all the answers subsets present in the ANSWERS_DATASET\n",
    "answers_subsets = datasets.get_dataset_config_names(ANSWERS_DATASET)\n",
    "print(\"Number of answers_subsets\", len(answers_subsets))\n",
    "print(\"Example of answers_subset\", answers_subsets[0])\n",
    "\n",
    "result_df = score_answers(answers_subsets)\n",
    "result_df[\"acc\"] = (result_df[\"acc\"] * 100).round(2)\n",
    "result_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "pivot_df = result_df.pivot_table(\n",
    "    index=[\"model_id\", \"source\"],\n",
    "    columns=[\"agent_action_type\"],\n",
    "    values=\"acc\",\n",
    "    fill_value=float(\"nan\"),\n",
    ").reset_index()"
   ]
  },
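  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pivot turns the long-format results into one row per `(model_id, source)` pair, with one accuracy column per `agent_action_type`; the plotting cell below assumes `code` and `vanilla` columns. A toy illustration with made-up numbers (not real results):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up rows illustrating the shape the plotting code expects\n",
    "toy_results = pd.DataFrame(\n",
    "    [\n",
    "        {\"model_id\": \"org/model-a\", \"source\": \"GAIA\", \"agent_action_type\": \"code\", \"acc\": 33.3},\n",
    "        {\"model_id\": \"org/model-a\", \"source\": \"GAIA\", \"agent_action_type\": \"vanilla\", \"acc\": 12.5},\n",
    "    ]\n",
    ")\n",
    "toy_results.pivot_table(index=[\"model_id\", \"source\"], columns=[\"agent_action_type\"], values=\"acc\").reset_index()"
   ]
  },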
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(pivot_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.legend_handler import HandlerTuple  # Added import\n",
    "\n",
    "\n",
    "# Assuming pivot_df is your original dataframe\n",
    "models = pivot_df[\"model_id\"].unique()\n",
    "sources = pivot_df[\"source\"].unique()\n",
    "\n",
    "# Create figure and axis\n",
    "plt.style.use(\"seaborn-v0_8-white\")\n",
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "\n",
    "# Set the width of each bar group and positions of the bars\n",
    "width = 0.15  # width of each bar\n",
    "spacing = 0.02  # space between bars within a group\n",
    "group_spacing = 0.2  # space between model groups\n",
    "\n",
    "# Calculate positions for the bars\n",
    "num_sources = len(sources)\n",
    "total_width_per_group = (width + spacing) * num_sources * 2  # *2 for agent and vanilla\n",
    "x = np.arange(len(models)) * (total_width_per_group + group_spacing)\n",
    "\n",
    "# Plot bars for each source\n",
    "for i, source in enumerate(sources):\n",
    "    source_data = pivot_df[pivot_df[\"source\"] == source]\n",
    "    agent_scores = [\n",
    "        source_data[source_data[\"model_id\"] == model][\"code\"].values[0]\n",
    "        if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
    "        else np.nan\n",
    "        for model in models\n",
    "    ]\n",
    "    vanilla_scores = [\n",
    "        source_data[source_data[\"model_id\"] == model][\"vanilla\"].values[0]\n",
    "        if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
    "        else np.nan\n",
    "        for model in models\n",
    "    ]\n",
    "\n",
    "    # Position calculation for each pair of bars\n",
    "    pos = x + i * (width * 2 + spacing)\n",
    "\n",
    "    agent_bars = ax.bar(pos, agent_scores, width, label=f\"{source} (Agent)\", alpha=0.8)\n",
    "    vanilla_bars = ax.bar(\n",
    "        pos + width * 0.6,\n",
    "        vanilla_scores,\n",
    "        width,\n",
    "        hatch=\"////\",\n",
    "        alpha=0.5,\n",
    "        hatch_linewidth=2,\n",
    "        label=f\"{source} (Vanilla)\",\n",
    "        color=\"white\",\n",
    "        edgecolor=agent_bars[0].get_facecolor(),\n",
    "    )\n",
    "\n",
    "# Customize the plot\n",
    "ax.set_ylabel(\"Score\")\n",
    "ax.set_title(\"Model Performance Comparison\")\n",
    "\n",
    "# Set x-axis ticks in the middle of each group\n",
    "group_centers = x + (total_width_per_group - spacing) / 2\n",
    "ax.set_xticks(group_centers)\n",
    "\n",
    "# Wrap long model names to prevent overlap\n",
    "wrapped_labels = [\"\\n\".join(model.split(\"/\")) for model in models]\n",
    "ax.set_xticklabels(wrapped_labels, rotation=0, ha=\"center\")\n",
    "\n",
    "# Modify legend to combine agent and vanilla entries\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "unique_sources = sources\n",
    "legend_elements = [\n",
    "    (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
    "]\n",
    "custom_legend = ax.legend(\n",
    "    [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
    "    [label for _, _, label in legend_elements],\n",
    "    handler_map={tuple: HandlerTuple(ndivide=None)},\n",
    "    bbox_to_anchor=(1.05, 1),\n",
    "    loc=\"upper left\",\n",
    ")\n",
    "\n",
    "ax.yaxis.grid(True, linestyle=\"--\", alpha=0.3)\n",
    "ax.set_ylim(bottom=0)\n",
    "plt.tight_layout()\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "agents",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
