{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/data1/anonymous/codes/med-sipf cache/\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data1/anonymous/misc/miniconda3/envs/anonymous-med_sipf-1/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.chdir(\"../..\")\n",
    "import dotenv\n",
    "\n",
    "\n",
    "dotenv.load_dotenv(override=True)\n",
    "\n",
    "print(os.getcwd(), os.getenv(\"HF_HOME\"))\n",
    "\n",
    "import json\n",
    "import time\n",
    "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
    "from functools import partial\n",
    "from hashlib import sha256\n",
    "from pathlib import Path\n",
    "\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/data1/anonymous/codes/med-sipf'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_boxed(s, left=\"\\\\boxed\"):\n",
    "    original_s = s\n",
    "    left = left + \"{\"\n",
    "    try:\n",
    "        assert s[: len(left)] == left\n",
    "        assert s[-1] == \"}\"\n",
    "        answer = s[len(left) : -1]\n",
    "        if \"=\" in answer:\n",
    "            answer = answer.split(\"=\")[-1].lstrip(\" \")\n",
    "        return answer\n",
    "    except Exception:\n",
    "        return original_s\n",
    "\n",
    "\n",
    "def last_boxed_only_string(string, left=\"\\\\boxed\"):\n",
    "    idx = string.rfind(\"\\\\boxed\")\n",
    "    if idx < 0:\n",
    "        idx = string.rfind(\"\\\\fbox\")\n",
    "        if idx < 0:\n",
    "            return string\n",
    "    i = idx\n",
    "    right_brace_idx = None\n",
    "    num_left_braces_open = 0\n",
    "    while i < len(string):\n",
    "        if string[i] == \"{\":\n",
    "            num_left_braces_open += 1\n",
    "        if string[i] == \"}\":\n",
    "            num_left_braces_open -= 1\n",
    "            if num_left_braces_open == 0:\n",
    "                right_brace_idx = i\n",
    "                break\n",
    "        i += 1\n",
    "\n",
    "    if right_brace_idx is None:\n",
    "        retval = string\n",
    "    else:\n",
    "        retval = string[idx : right_brace_idx + 1]\n",
    "\n",
    "    return retval\n",
    "\n",
    "\n",
    "def extract_answer(answer):\n",
    "    # Try to extract content inside \\boxed{}\n",
    "    answer = remove_boxed(last_boxed_only_string(answer))\n",
    "    return remove_boxed(last_boxed_only_string(answer, \"\\\\text\"), \"\\\\text\")\n",
    "\n",
    "\n",
    "def eval_sample(x, prefix=\"\"):\n",
    "    model_answer_string = x[\"distilled_answer_string\"]\n",
    "    answer_letter = x[\"answer_letter\"]\n",
    "    answer_idx = x[\"answer_idx\"]\n",
    "    answer_string = x[\"answer_string\"]\n",
    "\n",
    "    extracted_answer_string = extract_answer(model_answer_string)\n",
    "\n",
    "    try:\n",
    "        if extracted_answer_string.strip().lower() == answer_string.strip().lower():\n",
    "            correct = True\n",
    "        elif (\n",
    "            extracted_answer_string[0].strip().lower() == answer_letter.strip().lower()\n",
    "        ):\n",
    "            correct = True\n",
    "        else:\n",
    "            correct = False\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\\n{model_answer_string}\\n{extracted_answer_string}\")\n",
    "        correct = False\n",
    "    return {\n",
    "        f\"{prefix}extracted_answer_string\": extracted_answer_string,\n",
    "        f\"{prefix}model_answer_string\": model_answer_string,\n",
    "        f\"{prefix}correct\": correct,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating train split: 100%|██████████| 37815/37815 [00:01<00:00, 28526.09 examples/s]\n",
      "Map:   4%|▍         | 1699/37815 [00:00<00:04, 8538.98 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  26%|██▌       | 9871/37815 [00:01<00:02, 9576.47 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  29%|██▊       | 10835/37815 [00:01<00:02, 9591.76 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  37%|███▋      | 14119/37815 [00:01<00:02, 8302.31 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  47%|████▋     | 17952/37815 [00:02<00:02, 9232.59 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  56%|█████▌    | 21182/37815 [00:02<00:01, 9280.88 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  69%|██████▊   | 25986/37815 [00:02<00:01, 9578.24 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  75%|███████▌  | 28370/37815 [00:03<00:00, 9557.51 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  81%|████████▏ | 30780/37815 [00:03<00:00, 9601.81 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n",
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  93%|█████████▎| 35034/37815 [00:03<00:00, 9486.81 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: string index out of range\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 37815/37815 [00:04<00:00, 9207.54 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "path: anonymous/m196k-dedup-decon-filter_easy-r1\n",
      "correct: 23504\n",
      "total: 37815\n",
      "accuracy: 0.6215522940632024\n"
     ]
    }
   ],
   "source": [
    "def eval_correctness(path, split, prefix):\n",
    "    dataset = datasets.load_dataset(path,split=split)\n",
    "\n",
    "    mapped_dataset = dataset.map(partial(eval_sample, prefix=prefix), keep_in_memory=True, remove_columns=dataset.column_names)\n",
    "    return mapped_dataset\n",
    "\n",
    "path_list = [\"anonymous/m196k-dedup-decon-filter_easy-r1\"]\n",
    "prefix_list = [\"r1-\"]\n",
    "split=\"train\"\n",
    "\n",
    "eval_correctness_list = []\n",
    "\n",
    "for path, prefix in zip(path_list, prefix_list):\n",
    "\n",
    "    mapped_dataset = eval_correctness(path, split, prefix)\n",
    "    eval_correctness_list.append(mapped_dataset)\n",
    "\n",
    "    correct_list=mapped_dataset[f\"{prefix}correct\"]\n",
    "    print(f\"path: {path}\")\n",
    "    print(f\"correct: {sum(correct_list)}\")\n",
    "    print(f\"total: {len(correct_list)}\")\n",
    "    print(f\"accuracy: {sum(correct_list)/len(correct_list)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "path: anonymous/m196k-dedup-decon-filter_easy-r1-250311\n",
    "correct: 23494\n",
    "total: 37786\n",
    "accuracy: 0.6217646747472609\n",
    "\n",
    "path: anonymous/m196k-dedup-decon-filter_easy-r1\n",
    "correct: 23504\n",
    "total: 37815\n",
    "accuracy: 0.6215522940632024\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path=\"anonymous/m196k-dedup-decon-filter_easy-r1\"\n",
    "split=\"train\"\n",
    "\n",
    "raw_dataset = datasets.load_dataset(path,split=split)\n",
    "merged_dataset = datasets.concatenate_datasets([raw_dataset] + eval_correctness_list, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filter: 100%|██████████| 37815/37815 [00:00<00:00, 38039.77 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def filter_false(x, prefix_list=[\"r1-\"]):\n",
    "    for prefix in prefix_list:\n",
    "        if x[f\"{prefix}correct\"]:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "filtered_dataset = merged_dataset.filter(filter_false)\n",
    "filtered_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating parquet from Arrow format: 100%|██████████| 24/24 [00:02<00:00,  8.19ba/s]\n",
      "Uploading the dataset shards: 100%|██████████| 1/1 [00:08<00:00,  8.29s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/anonymous/m196k-dedup-decon-filter_easy-r1-filter_wrong/commit/5102cab42552c317e15aa7bd549195e4f936a5c2', commit_message='Upload dataset', commit_description='', oid='5102cab42552c317e15aa7bd549195e4f936a5c2', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/anonymous/m196k-dedup-decon-filter_easy-r1-filter_wrong', endpoint='https://huggingface.co', repo_type='dataset', repo_id='anonymous/m196k-dedup-decon-filter_easy-r1-filter_wrong'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filtered_dataset.push_to_hub(\"anonymous/m196k-dedup-decon-filter_easy-r1-filter_wrong\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "anonymous-med_sipf-1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
