{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/data1/anonymous/codes/med-sipf cache/\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data1/anonymous/misc/miniconda3/envs/m1/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# Move the working directory two levels up (presumably to the repository root - TODO confirm).\n",
    "# The path is relative, so re-running this cell without restarting the kernel moves up two more levels.\n",
    "os.chdir(\"../..\")\n",
    "import dotenv\n",
    "\n",
    "\n",
    "# Load variables from a .env file; override=True lets them replace values already set in the environment.\n",
    "dotenv.load_dotenv(override=True)\n",
    "\n",
    "# Sanity check: show the working directory and the HF_HOME value now in effect.\n",
    "print(os.getcwd(), os.getenv(\"HF_HOME\"))\n",
    "\n",
    "import json\n",
    "import time\n",
    "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
    "from functools import partial\n",
    "from hashlib import sha256\n",
    "from pathlib import Path\n",
    "\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/data1/anonymous/codes/med-sipf'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the current working directory (changed by the os.chdir call in the first cell).\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_boxed(s, left=\"\\\\boxed\"):\n",
    "    \"\"\"Strip a LaTeX wrapper of the form ``<left>{<content>}`` from ``s``.\n",
    "\n",
    "    Args:\n",
    "        s: Candidate string, expected to start with ``left`` followed by an\n",
    "            opening brace and to end with a closing brace.\n",
    "        left: LaTeX command that opens the wrapper (without the brace).\n",
    "\n",
    "    Returns:\n",
    "        The text between the outer braces. When that text contains ``=``,\n",
    "        only the part after the last ``=`` is kept, with leading spaces\n",
    "        removed. ``s`` is returned unchanged when it is not a string or is\n",
    "        not wrapped as described.\n",
    "    \"\"\"\n",
    "    opener = left + \"{\"\n",
    "    # Explicit checks instead of ``assert``: assertions are stripped under\n",
    "    # ``python -O``, which would let unwrapped strings be sliced anyway.\n",
    "    if not isinstance(s, str) or not (s.startswith(opener) and s.endswith(\"}\")):\n",
    "        return s\n",
    "    answer = s[len(opener) : -1]\n",
    "    if \"=\" in answer:\n",
    "        answer = answer.split(\"=\")[-1].lstrip(\" \")\n",
    "    return answer\n",
    "\n",
    "\n",
    "def last_boxed_only_string(string, left=\"\\\\boxed\"):\n",
    "    \"\"\"Return the last ``<left>{...}`` span found in ``string``, braces included.\n",
    "\n",
    "    Args:\n",
    "        string: Text to search.\n",
    "        left: LaTeX command to look for. With the default value, the fbox\n",
    "            command is tried as a fallback when no boxed command is present.\n",
    "\n",
    "    Returns:\n",
    "        The substring from the last occurrence of ``left`` through the brace\n",
    "        that balances its first opening brace. ``string`` is returned\n",
    "        unchanged when the command is absent or the braces never balance.\n",
    "    \"\"\"\n",
    "    # Honour ``left``: the search used to be hard-coded to the boxed command,\n",
    "    # so a call asking for another command never looked for it.\n",
    "    idx = string.rfind(left)\n",
    "    if idx < 0 and left == \"\\\\boxed\":\n",
    "        idx = string.rfind(\"\\\\fbox\")\n",
    "    if idx < 0:\n",
    "        return string\n",
    "\n",
    "    # Walk forward from the command, tracking brace depth until it closes.\n",
    "    depth = 0\n",
    "    for i in range(idx, len(string)):\n",
    "        char = string[i]\n",
    "        if char == \"{\":\n",
    "            depth += 1\n",
    "        elif char == \"}\":\n",
    "            depth -= 1\n",
    "            if depth == 0:\n",
    "                return string[idx : i + 1]\n",
    "\n",
    "    # Braces never balanced: fall back to the untouched input.\n",
    "    return string\n",
    "\n",
    "\n",
    "def extract_answer(answer):\n",
    "    \"\"\"Reduce a model response to its final answer text.\n",
    "\n",
    "    The last boxed span is located and unwrapped first; the result is then\n",
    "    passed through the same two helpers again with the text command as the\n",
    "    wrapper.\n",
    "    \"\"\"\n",
    "    boxed_span = last_boxed_only_string(answer)\n",
    "    boxed_content = remove_boxed(boxed_span)\n",
    "    text_span = last_boxed_only_string(boxed_content, \"\\\\text\")\n",
    "    return remove_boxed(text_span, \"\\\\text\")\n",
    "\n",
    "\n",
    "def eval_sample(x, prefix=\"\"):\n",
    "    \"\"\"Grade one dataset row by comparing the model answer with the reference.\n",
    "\n",
    "    Args:\n",
    "        x: Row mapping; the ``distilled_answer_string``, ``answer_letter`` and\n",
    "            ``answer_string`` keys are read.\n",
    "        prefix: String prepended to every key of the returned mapping.\n",
    "\n",
    "    Returns:\n",
    "        Dict with ``{prefix}extracted_answer_string``,\n",
    "        ``{prefix}model_answer_string`` and ``{prefix}correct``. A row is\n",
    "        correct when the extracted answer equals the reference answer text\n",
    "        (case-insensitive, surrounding whitespace ignored) or when its first\n",
    "        non-blank character equals the reference letter. An empty extraction\n",
    "        is graded as wrong.\n",
    "    \"\"\"\n",
    "    model_answer_string = x[\"distilled_answer_string\"]\n",
    "    answer_letter = x[\"answer_letter\"]\n",
    "    answer_string = x[\"answer_string\"]\n",
    "\n",
    "    extracted_answer_string = extract_answer(model_answer_string)\n",
    "\n",
    "    try:\n",
    "        # Normalise once so leading whitespace cannot hide the answer letter.\n",
    "        candidate = extracted_answer_string.strip().lower()\n",
    "        if candidate == answer_string.strip().lower():\n",
    "            correct = True\n",
    "        elif not candidate:\n",
    "            # Nothing was extracted: there is no first character to compare,\n",
    "            # so grade the row wrong instead of raising IndexError.\n",
    "            correct = False\n",
    "        else:\n",
    "            # NOTE(review): a first-character match also accepts free-text\n",
    "            # answers that merely begin with the reference letter.\n",
    "            correct = candidate[0] == answer_letter.strip().lower()\n",
    "    except Exception as e:\n",
    "        # Best-effort grading: malformed rows (e.g. non-string fields) are\n",
    "        # reported and counted as wrong rather than aborting the whole map.\n",
    "        print(f\"Error: {e}\\n{model_answer_string}\\n{extracted_answer_string}\")\n",
    "        correct = False\n",
    "    return {\n",
    "        f\"{prefix}extracted_answer_string\": extracted_answer_string,\n",
    "        f\"{prefix}model_answer_string\": model_answer_string,\n",
    "        f\"{prefix}correct\": correct,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1249/1249 [00:00<00:00, 8980.87 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "path: anonymous/m196k-random_1k-decon_eval-r1\n",
      "correct: 1078\n",
      "total: 1249\n",
      "accuracy: 0.8630904723779024\n",
      "Dataset({\n",
      "    features: ['answer_idx', 'source', 'metadata', 'prompt', 'answer_letter', 'answer_string', 'reasoning', 'distilled_answer_string', 'r1-extracted_answer_string', 'r1-model_answer_string', 'r1-correct'],\n",
      "    num_rows: 1000\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  8.78ba/s]\n",
      "Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.43s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/anonymous/m196k-random_1k-decon_eval-r1-filter_wrong/commit/b44c4d5bd81dfbe072c604a163cc54de8ca0a43e', commit_message='Upload dataset', commit_description='', oid='b44c4d5bd81dfbe072c604a163cc54de8ca0a43e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/anonymous/m196k-random_1k-decon_eval-r1-filter_wrong', endpoint='https://huggingface.co', repo_type='dataset', repo_id='anonymous/m196k-random_1k-decon_eval-r1-filter_wrong'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval_correctness(path, split, prefix):\n",
    "    \"\"\"Load ``split`` of the dataset at ``path`` and grade every row.\n",
    "\n",
    "    Each row goes through ``eval_sample`` with the given ``prefix``. The\n",
    "    source columns are removed, so the returned dataset holds only the\n",
    "    columns produced by ``eval_sample``.\n",
    "    \"\"\"\n",
    "    source_dataset = datasets.load_dataset(path, split=split)\n",
    "    grade_row = partial(eval_sample, prefix=prefix)\n",
    "    return source_dataset.map(\n",
    "        grade_row,\n",
    "        keep_in_memory=True,\n",
    "        remove_columns=source_dataset.column_names,\n",
    "    )\n",
    "\n",
    "# Hub dataset(s) to grade, paired by position with the prefix used for their result columns.\n",
    "path_list = [\"anonymous/m196k-random_1k-decon_eval-r1\"]\n",
    "prefix_list = [\"r1-\"]\n",
    "split=\"train\"\n",
    "\n",
    "# Graded result datasets, one per (path, prefix) pair.\n",
    "eval_correctness_list = []\n",
    "\n",
    "# Grade each dataset and print its correct count, row count and accuracy.\n",
    "for path, prefix in zip(path_list, prefix_list):\n",
    "\n",
    "    mapped_dataset = eval_correctness(path, split, prefix)\n",
    "    eval_correctness_list.append(mapped_dataset)\n",
    "\n",
    "    correct_list=mapped_dataset[f\"{prefix}correct\"]\n",
    "    print(f\"path: {path}\")\n",
    "    print(f\"correct: {sum(correct_list)}\")\n",
    "    print(f\"total: {len(correct_list)}\")\n",
    "    print(f\"accuracy: {sum(correct_list)/len(correct_list)}\")\n",
    "\n",
    "# NOTE(review): only this single path is reloaded below, while eval_correctness_list\n",
    "# holds one entry per path_list item; confirm the merge is still valid if path_list grows.\n",
    "path=\"anonymous/m196k-random_1k-decon_eval-r1\"\n",
    "split=\"train\"\n",
    "\n",
    "# Reload the source rows and attach the graded columns next to them (axis=1 joins columns).\n",
    "raw_dataset = datasets.load_dataset(path,split=split)\n",
    "merged_dataset = datasets.concatenate_datasets([raw_dataset] + eval_correctness_list, axis=1)\n",
    "def filter_false(x, prefix_list=[\"r1-\"]):\n",
    "    \"\"\"Return True when row ``x`` is marked correct under any prefix in ``prefix_list``.\n",
    "\n",
    "    Used as a filter predicate: rows for which every ``{prefix}correct``\n",
    "    value is falsy are dropped.\n",
    "    \"\"\"\n",
    "    return any(x[f\"{prefix}correct\"] for prefix in prefix_list)\n",
    "\n",
    "filtered_dataset = merged_dataset.filter(filter_false)\n",
    "# Cap the export at num_samples rows, clamped to the rows that actually survived\n",
    "# the filter so select() does not raise when fewer rows are available.\n",
    "num_samples = 1000\n",
    "num_available = len(filtered_dataset)\n",
    "if num_available < num_samples:\n",
    "    print(f\"Only {num_available} rows passed the filter; requested {num_samples}.\")\n",
    "filtered_dataset = filtered_dataset.select(range(min(num_samples, num_available)))\n",
    "print(filtered_dataset)\n",
    "# Last expression of the cell: uploads to the Hub and displays the returned commit info.\n",
    "filtered_dataset.push_to_hub(\"anonymous/m196k-random_1k-decon_eval-r1-filter_wrong\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  32%|███▏      | 9266/29183 [00:00<00:01, 10473.03 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  50%|█████     | 14646/29183 [00:01<00:01, 11592.53 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  85%|████████▌ | 24844/29183 [00:02<00:00, 11884.86 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 29183/29183 [00:02<00:00, 11387.34 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "path: anonymous/m196k-random_23k-decon_eval-r1\n",
      "correct: 25397\n",
      "total: 29183\n",
      "accuracy: 0.870266936229997\n",
      "Dataset({\n",
      "    features: ['answer_idx', 'source', 'metadata', 'prompt', 'answer_letter', 'answer_string', 'reasoning', 'distilled_answer_string', 'r1-extracted_answer_string', 'r1-model_answer_string', 'r1-correct'],\n",
      "    num_rows: 23493\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating parquet from Arrow format: 100%|██████████| 24/24 [00:01<00:00, 12.53ba/s]\n",
      "Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.56s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/anonymous/m196k-random_23k-decon_eval-r1-filter_wrong/commit/16777586cac0d25b6d4990d48e19349a5e69e095', commit_message='Upload dataset', commit_description='', oid='16777586cac0d25b6d4990d48e19349a5e69e095', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/anonymous/m196k-random_23k-decon_eval-r1-filter_wrong', endpoint='https://huggingface.co', repo_type='dataset', repo_id='anonymous/m196k-random_23k-decon_eval-r1-filter_wrong'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval_correctness(path, split, prefix):\n",
    "    \"\"\"Load ``split`` of the dataset at ``path`` and grade every row.\n",
    "\n",
    "    Each row goes through ``eval_sample`` with the given ``prefix``. The\n",
    "    source columns are removed, so the returned dataset holds only the\n",
    "    columns produced by ``eval_sample``.\n",
    "    \"\"\"\n",
    "    # NOTE(review): this function is also defined in the previous cell; the\n",
    "    # later definition is the one in effect once this cell has run.\n",
    "    source_dataset = datasets.load_dataset(path, split=split)\n",
    "    grade_row = partial(eval_sample, prefix=prefix)\n",
    "    return source_dataset.map(\n",
    "        grade_row,\n",
    "        keep_in_memory=True,\n",
    "        remove_columns=source_dataset.column_names,\n",
    "    )\n",
    "\n",
    "# Hub dataset(s) to grade, paired by position with the prefix used for their result columns.\n",
    "path_list = [\"anonymous/m196k-random_23k-decon_eval-r1\"]\n",
    "prefix_list = [\"r1-\"]\n",
    "split=\"train\"\n",
    "\n",
    "# Graded result datasets, one per (path, prefix) pair.\n",
    "eval_correctness_list = []\n",
    "\n",
    "# Grade each dataset and print its correct count, row count and accuracy.\n",
    "for path, prefix in zip(path_list, prefix_list):\n",
    "\n",
    "    mapped_dataset = eval_correctness(path, split, prefix)\n",
    "    eval_correctness_list.append(mapped_dataset)\n",
    "\n",
    "    correct_list=mapped_dataset[f\"{prefix}correct\"]\n",
    "    print(f\"path: {path}\")\n",
    "    print(f\"correct: {sum(correct_list)}\")\n",
    "    print(f\"total: {len(correct_list)}\")\n",
    "    print(f\"accuracy: {sum(correct_list)/len(correct_list)}\")\n",
    "\n",
    "# NOTE(review): only this single path is reloaded below, while eval_correctness_list\n",
    "# holds one entry per path_list item; confirm the merge is still valid if path_list grows.\n",
    "path=\"anonymous/m196k-random_23k-decon_eval-r1\"\n",
    "split=\"train\"\n",
    "\n",
    "# Reload the source rows and attach the graded columns next to them (axis=1 joins columns).\n",
    "raw_dataset = datasets.load_dataset(path,split=split)\n",
    "merged_dataset = datasets.concatenate_datasets([raw_dataset] + eval_correctness_list, axis=1)\n",
    "def filter_false(x, prefix_list=[\"r1-\"]):\n",
    "    \"\"\"Return True when row ``x`` is marked correct under any prefix in ``prefix_list``.\n",
    "\n",
    "    Used as a filter predicate: rows for which every ``{prefix}correct``\n",
    "    value is falsy are dropped.\n",
    "    \"\"\"\n",
    "    return any(x[f\"{prefix}correct\"] for prefix in prefix_list)\n",
    "\n",
    "filtered_dataset = merged_dataset.filter(filter_false)\n",
    "# Cap the export at num_samples rows, clamped to the rows that actually survived\n",
    "# the filter so select() does not raise when fewer rows are available.\n",
    "num_samples = 23493\n",
    "num_available = len(filtered_dataset)\n",
    "if num_available < num_samples:\n",
    "    print(f\"Only {num_available} rows passed the filter; requested {num_samples}.\")\n",
    "filtered_dataset = filtered_dataset.select(range(min(num_samples, num_available)))\n",
    "print(filtered_dataset)\n",
    "# Last expression of the cell: uploads to the Hub and displays the returned commit info.\n",
    "filtered_dataset.push_to_hub(\"anonymous/m196k-random_23k-decon_eval-r1-filter_wrong\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "m1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
