{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"RUC-AIBOX/STILL-3-Preview-RL-Data\")\n",
    "\n",
    "\n",
    "def _with_problem_field(x):\n",
    "    # Expose the question text under \"problem\"; \"answer\" is passed through unchanged.\n",
    "    return {\"problem\": x[\"question\"], \"answer\": x[\"answer\"]}\n",
    "\n",
    "\n",
    "# Apply the per-row mapping to the train split of the dataset.\n",
    "still_ds = ds['train'].map(_with_problem_field)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15965"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the listed JSON files from ../train and merge them into one list.\n",
    "import os\n",
    "import json\n",
    "\n",
    "prior_data = []\n",
    "for file_name in ['aime.json', 'amc.json', 'math.json', 'omni_math.json']:\n",
    "    # Resolve the file name to an absolute path inside ../train.\n",
    "    file_path = os.path.abspath(os.path.join(\"../train\", file_name))\n",
    "    # Read as UTF-8 explicitly instead of relying on the locale default encoding.\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        # Presumably each file holds a JSON array of problem records - TODO confirm.\n",
    "        prior_data.extend(json.load(f))\n",
    "\n",
    "len(prior_data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rllm.utils import RAG\n",
    "\n",
    "# Collect the \"problem\" text of every prior record and hand the list to RAG as its docs.\n",
    "prior_problem_texts = [record[\"problem\"] for record in prior_data]\n",
    "rag_searcher = RAG(docs=prior_problem_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:   3%|▎         | 1021/29925 [00:07<03:33, 135.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:   7%|▋         | 2015/29925 [00:14<03:27, 134.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  10%|█         | 3023/29925 [00:22<03:18, 135.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  13%|█▎        | 4017/29925 [00:29<03:11, 135.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  17%|█▋        | 5025/29925 [00:37<03:03, 135.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  20%|██        | 6019/29925 [00:44<02:55, 135.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  23%|██▎       | 7027/29925 [00:51<02:48, 135.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  27%|██▋       | 8021/29925 [00:59<02:41, 135.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  30%|███       | 9015/29925 [01:06<02:33, 136.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  33%|███▎      | 10023/29925 [01:14<02:25, 136.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  37%|███▋      | 11017/29925 [01:21<02:18, 136.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  40%|████      | 12025/29925 [01:28<02:10, 137.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  44%|████▎     | 13019/29925 [01:35<02:03, 136.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  47%|████▋     | 14027/29925 [01:43<01:55, 137.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  50%|█████     | 15021/29925 [01:50<01:48, 137.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  54%|█████▎    | 16015/29925 [01:57<01:41, 136.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  57%|█████▋    | 17023/29925 [02:05<01:33, 137.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  60%|██████    | 18017/29925 [02:12<01:26, 138.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  64%|██████▎   | 19028/29925 [02:19<01:18, 139.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  67%|██████▋   | 20016/29925 [02:26<01:11, 139.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  70%|███████   | 21027/29925 [02:34<01:03, 139.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  74%|███████▎  | 22024/29925 [02:41<00:56, 138.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  77%|███████▋  | 23016/29925 [02:48<00:49, 139.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  80%|████████  | 24015/29925 [02:55<00:41, 140.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  84%|████████▎ | 25020/29925 [03:02<00:34, 140.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  87%|████████▋ | 26024/29925 [03:09<00:27, 141.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  90%|█████████ | 27014/29925 [03:16<00:20, 141.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  94%|█████████▎| 28016/29925 [03:24<00:13, 140.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data:  97%|█████████▋| 29021/29925 [03:31<00:06, 140.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering stil data: 100%|██████████| 29925/29925 [03:37<00:00, 137.53it/s]\n"
     ]
    }
   ],
   "source": [
    "# Keep only the STILL problems whose best match in the prior data scores at or below the threshold\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Score above which a STILL problem is counted as already present in the prior data.\n",
    "# NOTE(review): presumably a similarity score where higher means closer - confirm against RAG.top_k.\n",
    "DUPLICATE_SCORE_THRESHOLD = 0.90\n",
    "\n",
    "filter_problems = []\n",
    "num_problems = 0  # number of problems dropped because their top match exceeded the threshold\n",
    "\n",
    "# tqdm already reports progress, so no manual counter/print is needed.\n",
    "for d in tqdm(still_ds, desc=\"Filtering stil data\", total=len(still_ds)):\n",
    "    matches = rag_searcher.top_k(d[\"problem\"], k=1)\n",
    "    # An empty result means there is nothing to compare against, so the problem is kept.\n",
    "    if matches and matches[0][\"score\"] > DUPLICATE_SCORE_THRESHOLD:\n",
    "        num_problems += 1\n",
    "    else:\n",
    "        filter_problems.append(d)\n",
    "\n",
    "# Save final list as json\n",
    "with open(\"still.json\", \"w\") as f:\n",
    "    json.dump(filter_problems, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26026"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Remove the \"question\" and \"messages\" keys from every kept record before saving.\n",
    "# pop(..., None) instead of del, so re-running this cell does not raise KeyError\n",
    "# on records whose keys were already removed.\n",
    "for d in filter_problems:\n",
    "    d.pop(\"question\", None)\n",
    "    d.pop('messages', None)\n",
    "\n",
    "# Save final list as json\n",
    "with open(\"still.json\", \"w\") as f:\n",
    "    json.dump(filter_problems, f, indent=2)\n",
    "len(filter_problems)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
