{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8315bf5",
   "metadata": {},
   "source": [
    "# MATH dataset preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d0f9c9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import json \n",
    "import pandas as pd \n",
    "\n",
    "def last_boxed_only_string(string):\n",
    "    \"\"\"Return the last LaTeX boxed (or fbox) command in ``string``, braces included.\n",
    "\n",
    "    The search starts at the right-most occurrence of the boxed command and\n",
    "    falls back to the right-most fbox command when no boxed one exists. From\n",
    "    there the braces are balanced so nested groups (e.g. fractions) stay\n",
    "    intact.\n",
    "\n",
    "    Args:\n",
    "        string: Text to scan, typically a full worked solution.\n",
    "\n",
    "    Returns:\n",
    "        The substring from the command through its matching closing brace,\n",
    "        or None when no command is found or its braces never balance.\n",
    "    \"\"\"\n",
    "    start = string.rfind(\"\\\\boxed\")\n",
    "    if start < 0:\n",
    "        start = string.rfind(\"\\\\fbox\")\n",
    "        if start < 0:\n",
    "            return None\n",
    "\n",
    "    # Track brace nesting depth; the span ends where the depth returns to zero.\n",
    "    depth = 0\n",
    "    for pos in range(start, len(string)):\n",
    "        if string[pos] == \"{\":\n",
    "            depth += 1\n",
    "        elif string[pos] == \"}\":\n",
    "            depth -= 1\n",
    "            if depth == 0:\n",
    "                return string[start:pos + 1]\n",
    "\n",
    "    # Ran off the end of the text without closing the group.\n",
    "    return None\n",
    "\n",
    "\n",
    "def remove_boxed(s):\n",
    "    \"\"\"Strip the LaTeX box wrapper from ``s`` and return what was inside it.\n",
    "\n",
    "    Both wrappers that ``last_boxed_only_string`` can return are accepted:\n",
    "    the boxed form and the fbox form.\n",
    "\n",
    "    Args:\n",
    "        s: A string consisting of exactly one wrapped group, or None.\n",
    "\n",
    "    Returns:\n",
    "        The text between the outer braces, or None when ``s`` is not a string\n",
    "        or is not wrapped in one of the supported commands.\n",
    "    \"\"\"\n",
    "    # Explicit checks replace the former assert + bare except, which would\n",
    "    # stop validating under ``python -O`` and hid unrelated errors.\n",
    "    if not isinstance(s, str) or not s.endswith(\"}\"):\n",
    "        return None\n",
    "    for left in (\"\\\\boxed{\", \"\\\\fbox{\"):\n",
    "        if s.startswith(left):\n",
    "            return s[len(left):-1]\n",
    "    return None\n",
    "\n",
    "\n",
    "# Every entry of a split directory is later treated as one category folder.\n",
    "math_train_categories = os.listdir(\"data/raw/MATH/train\") \n",
    "math_test_categories = os.listdir(\"data/raw/MATH/test\")\n",
    "\n",
    "\n",
    "def concat_all_data(categories, split):\n",
    "    \"\"\"Load every MATH problem of one split into a single DataFrame.\n",
    "\n",
    "    Reads each ``*.json`` file under ``data/raw/MATH/<split>/<category>`` and\n",
    "    extracts the final answer from the last boxed expression of its solution.\n",
    "\n",
    "    Args:\n",
    "        categories: Names of the category sub-directories to read.\n",
    "        split: Split directory name, e.g. 'train' or 'test'.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``question``, ``true_answer`` (None when no\n",
    "        boxed answer could be extracted), ``solution`` and ``type``.\n",
    "    \"\"\"\n",
    "    questions = []\n",
    "    answers = []\n",
    "    solutions = []\n",
    "    types = []\n",
    "\n",
    "    for c in categories:\n",
    "        base_path = f\"data/raw/MATH/{split}/{c}\"\n",
    "        # The category list comes from os.listdir, so it may contain stray\n",
    "        # files; only real directories can be listed.\n",
    "        if not os.path.isdir(base_path):\n",
    "            continue\n",
    "        for f in os.listdir(base_path):\n",
    "            # Only problem files contribute a row. Previously the appends ran\n",
    "            # for every directory entry, so a non-JSON file duplicated the\n",
    "            # preceding problem (or raised NameError when it came first).\n",
    "            if not f.endswith(\".json\"):\n",
    "                continue\n",
    "            with open(os.path.join(base_path, f), \"r\", encoding=\"utf-8\") as file:\n",
    "                data = json.load(file)\n",
    "            questions.append(data[\"problem\"])\n",
    "            answers.append(remove_boxed(last_boxed_only_string(data[\"solution\"])))\n",
    "            solutions.append(data[\"solution\"])\n",
    "            types.append(data[\"type\"])\n",
    "\n",
    "    return pd.DataFrame({\n",
    "        \"question\": questions,\n",
    "        \"true_answer\": answers,\n",
    "        \"solution\": solutions,\n",
    "        \"type\": types\n",
    "    })\n",
    "    \n",
    "    \n",
    "# Build one DataFrame per split from the raw per-problem JSON files.\n",
    "math_train_df = concat_all_data(math_train_categories, 'train')\n",
    "math_test_df = concat_all_data(math_test_categories, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b10571a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>true_answer</th>\n",
       "      <th>solution</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What is the value of $9^3 + 3(9^2) + 3(9) + 1$?</td>\n",
       "      <td>1000</td>\n",
       "      <td>The given expression is the expansion of $(9+1...</td>\n",
       "      <td>Counting &amp; Probability</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Tom has a red marble, a green marble, a blue m...</td>\n",
       "      <td>7</td>\n",
       "      <td>There are two cases here: either Tom chooses t...</td>\n",
       "      <td>Counting &amp; Probability</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>If $m$ and $n$ are odd integers, how many term...</td>\n",
       "      <td>4</td>\n",
       "      <td>By the binomial theorem, $(m+n)^6$ expands as ...</td>\n",
       "      <td>Counting &amp; Probability</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>If three people are selected at random from a ...</td>\n",
       "      <td>\\frac{17}{24}</td>\n",
       "      <td>We can find the probability that no women are ...</td>\n",
       "      <td>Counting &amp; Probability</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            question    true_answer  \\\n",
       "0    What is the value of $9^3 + 3(9^2) + 3(9) + 1$?           1000   \n",
       "1  Tom has a red marble, a green marble, a blue m...              7   \n",
       "2  If $m$ and $n$ are odd integers, how many term...              4   \n",
       "3  If three people are selected at random from a ...  \\frac{17}{24}   \n",
       "\n",
       "                                            solution                    type  \n",
       "0  The given expression is the expansion of $(9+1...  Counting & Probability  \n",
       "1  There are two cases here: either Tom chooses t...  Counting & Probability  \n",
       "2  By the binomial theorem, $(m+n)^6$ expands as ...  Counting & Probability  \n",
       "3  We can find the probability that no women are ...  Counting & Probability  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first four rows of the training split.\n",
    "math_train_df.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3527bb53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>true_answer</th>\n",
       "      <th>solution</th>\n",
       "      <th>type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>A positive multiple of 45 less than 1000 is ra...</td>\n",
       "      <td>\\frac{1}{11}</td>\n",
       "      <td>The positive multiples of 45 are  \\[45,90,135,...</td>\n",
       "      <td>Number Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Kirsty needs to hire a plumber to fix her hous...</td>\n",
       "      <td>499</td>\n",
       "      <td>For every hour of labor, $242_5=2\\cdot5^2+4\\cd...</td>\n",
       "      <td>Number Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Find the positive base $b$ in which the equati...</td>\n",
       "      <td>6</td>\n",
       "      <td>When we rewrite the above equation with the ba...</td>\n",
       "      <td>Number Theory</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>A bus comes by Jerry's bus stop every 20 minut...</td>\n",
       "      <td>18</td>\n",
       "      <td>Since 20 minutes evenly divides 60 minutes (wh...</td>\n",
       "      <td>Number Theory</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            question   true_answer  \\\n",
       "0  A positive multiple of 45 less than 1000 is ra...  \\frac{1}{11}   \n",
       "1  Kirsty needs to hire a plumber to fix her hous...           499   \n",
       "2  Find the positive base $b$ in which the equati...             6   \n",
       "3  A bus comes by Jerry's bus stop every 20 minut...            18   \n",
       "\n",
       "                                            solution           type  \n",
       "0  The positive multiples of 45 are  \\[45,90,135,...  Number Theory  \n",
       "1  For every hour of labor, $242_5=2\\cdot5^2+4\\cd...  Number Theory  \n",
       "2  When we rewrite the above equation with the ba...  Number Theory  \n",
       "3  Since 20 minutes evenly divides 60 minutes (wh...  Number Theory  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first four rows of the test split.\n",
    "math_test_df.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7e22cb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both splits; index=False keeps the row index out of the CSV.\n",
    "math_train_df.to_csv(\"./data/processed/math_train.csv\", index=False)\n",
    "math_test_df.to_csv(\"./data/processed/math_test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "612464c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "\n",
    "# Read the processed training CSV back from disk (round-trip check, judging by the name).\n",
    "math_data_check = pd.read_csv(\"./data/processed/math_train.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e06f8b4f",
   "metadata": {},
   "source": [
    "# ARC-challenge dataset preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a86c8e73",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/home/chaeyun-jang/.conda/envs/llm_ft/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load the ARC-Challenge configuration of the allenai/ai2_arc dataset.\n",
    "ds = load_dataset(\"allenai/ai2_arc\", \"ARC-Challenge\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dea03cd6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'question', 'choices', 'answerKey'],\n",
       "        num_rows: 1119\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'question', 'choices', 'answerKey'],\n",
       "        num_rows: 1172\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'question', 'choices', 'answerKey'],\n",
       "        num_rows: 299\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the loaded object to see its splits, features and row counts.\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2351fad2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
      "A\n",
      "['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
      "['A', 'B', 'C', 'D']\n"
     ]
    }
   ],
   "source": [
    "# Fetch the first training row once rather than selecting four whole\n",
    "# columns only to read element 0 of each.\n",
    "first_example = ds['train'][0]\n",
    "print(first_example['question'])\n",
    "print(first_example['answerKey'])\n",
    "print(first_example['choices']['text'])\n",
    "print(first_example['choices']['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "93d3322a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def processing_arc(ds):\n",
    "    \"\"\"Flatten one ARC split into prompt / answer columns.\n",
    "\n",
    "    Each question gets its labelled choices appended, one per line, in the\n",
    "    form ``<label>. <text>``.\n",
    "\n",
    "    Args:\n",
    "        ds: Column-indexable split providing ``question``, ``answerKey`` and\n",
    "            ``choices`` (each choice entry holding parallel ``text`` and\n",
    "            ``label`` lists).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with ``question`` (prompt including choices),\n",
    "        ``true_answer_texts`` (text of the gold choice, or None when no choice\n",
    "        label equals the answer key) and ``true_answer`` (the answer key).\n",
    "    \"\"\"\n",
    "    questions = []\n",
    "    true_answers = []\n",
    "    true_answers_texts = []\n",
    "    # Pull each column once and walk them in lockstep instead of re-indexing\n",
    "    # the columns for every row.\n",
    "    for question, answer_key, choices in zip(ds['question'], ds['answerKey'], ds['choices']):\n",
    "        prompt = question\n",
    "        answer_text = None\n",
    "        for t, l in zip(choices['text'], choices['label']):\n",
    "            prompt += f\"\\n{l}. {t}\"\n",
    "            if answer_text is None and l == answer_key:\n",
    "                answer_text = t\n",
    "        questions.append(prompt)\n",
    "        true_answers.append(answer_key)\n",
    "        # Exactly one entry per row keeps the three lists the same length even\n",
    "        # when a key matches no label or matches more than one.\n",
    "        true_answers_texts.append(answer_text)\n",
    "    return pd.DataFrame({\n",
    "        \"question\": questions,\n",
    "        \"true_answer_texts\": true_answers_texts,\n",
    "        \"true_answer\": true_answers,\n",
    "    })\n",
    "\n",
    "# Convert each ARC split into a prompt/answer DataFrame.\n",
    "arc_train_df = processing_arc(ds['train'])\n",
    "arc_val_df = processing_arc(ds['validation'])\n",
    "arc_test_df = processing_arc(ds['test'])    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "09746879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the three ARC splits to CSV without the row index.\n",
    "arc_train_df.to_csv(\"./data/processed/arc_challenge_train.csv\", index=False)\n",
    "arc_val_df.to_csv(\"./data/processed/arc_challenge_val.csv\", index=False)\n",
    "arc_test_df.to_csv(\"./data/processed/arc_challenge_test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea446341",
   "metadata": {},
   "source": [
    "# HellaSwag dataset preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4bd25501",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# NOTE: this rebinds `ds`, which held the ARC dataset in the cells above.\n",
    "ds = load_dataset(\"Rowan/hellaswag\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "03672296",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import datasets\n",
    "\n",
    "def preprocess(text: str) -> str:\n",
    "    \"\"\"Clean one text fragment: drop bracketed tags and squeeze spaces.\n",
    "\n",
    "    The ' [title]' marker becomes a sentence break ('. '), every remaining\n",
    "    bracketed tag is deleted, and runs of spaces collapse to a single space.\n",
    "    \"\"\"\n",
    "    cleaned = text.strip().replace(\" [title]\", \". \")\n",
    "    without_tags = re.sub(r\"\\[.*?\\]\", \"\", cleaned)\n",
    "    collapsed = re.sub(r\" +\", \" \", without_tags)\n",
    "    return collapsed\n",
    "\n",
    "def process_hellaswag(dataset: datasets.Dataset,\n",
    "                      type: str = \"train\") -> datasets.Dataset:\n",
    "    \"\"\"Convert raw rows into multiple-choice prompts (plus gold answers).\n",
    "\n",
    "    Args:\n",
    "        dataset: Split whose rows provide ``activity_label``, ``ctx_a``,\n",
    "            ``ctx_b``, ``endings`` and, unless ``type`` is 'test', ``label``.\n",
    "        type: Pass 'test' to emit only ``question``; any other value also\n",
    "            emits ``answer`` (choice letter) and ``answer_text``. The name\n",
    "            shadows the builtin but is part of the call signature.\n",
    "\n",
    "    Returns:\n",
    "        The mapped dataset with all of the original columns removed.\n",
    "    \"\"\"\n",
    "    base_prompt = \"\"\"Pick the best ending to the context.\n",
    "\n",
    "Context:\n",
    "{context}\n",
    "\n",
    "Choices:\n",
    "A. {c0}  \n",
    "B. {c1}  \n",
    "C. {c2}  \n",
    "D. {c3}\n",
    "\"\"\"\n",
    "    letters = (\"A\", \"B\", \"C\", \"D\")\n",
    "    is_test_split = type == \"test\"\n",
    "\n",
    "    def _build_row(doc):\n",
    "        # Context = activity label, then both context halves (second one capitalised).\n",
    "        context = preprocess(doc[\"activity_label\"] + \": \" + doc[\"ctx_a\"] + \" \" + doc[\"ctx_b\"].capitalize())\n",
    "        endings = [preprocess(candidate) for candidate in doc[\"endings\"]]\n",
    "        row = {\n",
    "            \"question\": base_prompt.format(\n",
    "                context=context,\n",
    "                c0=endings[0],\n",
    "                c1=endings[1],\n",
    "                c2=endings[2],\n",
    "                c3=endings[3],\n",
    "            )\n",
    "        }\n",
    "        if is_test_split:\n",
    "            return row\n",
    "\n",
    "        # Labelled splits: translate the numeric label into a letter and its text.\n",
    "        gold_idx = int(doc[\"label\"])\n",
    "        row[\"answer\"] = letters[gold_idx]\n",
    "        row[\"answer_text\"] = endings[gold_idx]\n",
    "        return row\n",
    "\n",
    "    return dataset.map(_build_row, remove_columns=dataset.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fb09134d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 39905/39905 [00:03<00:00, 11602.87 examples/s]\n",
      "Map: 100%|██████████| 10042/10042 [00:00<00:00, 11849.90 examples/s]\n",
      "Map: 100%|██████████| 10003/10003 [00:00<00:00, 12267.95 examples/s]\n",
      "Creating CSV from Arrow format: 100%|██████████| 40/40 [00:00<00:00, 57.28ba/s]\n",
      "Creating CSV from Arrow format: 100%|██████████| 11/11 [00:00<00:00, 58.17ba/s]\n",
      "Creating CSV from Arrow format: 100%|██████████| 11/11 [00:00<00:00, 89.34ba/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "8253171"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# Train/validation keep the gold answer columns; the test split is processed\n",
    "# with type=\"test\" so only the prompt is emitted.\n",
    "train_ds = process_hellaswag(ds[\"train\"])\n",
    "val_ds = process_hellaswag(ds[\"validation\"])\n",
    "test_ds = process_hellaswag(ds[\"test\"], type=\"test\")\n",
    "\n",
    "# Write each processed split to CSV.\n",
    "train_ds.to_csv(\"./data/processed/hellaswag_train.csv\", index=False)\n",
    "val_ds.to_csv(\"./data/processed/hellaswag_val.csv\", index=False)\n",
    "test_ds.to_csv(\"./data/processed/hellaswag_test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17a44d10",
   "metadata": {},
   "source": [
    "# MMLU dataset preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "015c1898",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd \n",
    "from datasets import Dataset\n",
    "\n",
    "def process_mmlu(data_dir):\n",
    "    \"\"\"Build a prompt/answer dataset from a directory of header-less MMLU CSVs.\n",
    "\n",
    "    Every CSV is read as six columns (question, options A-D, answer). The\n",
    "    subject shown in the prompt is the file name without its extension and\n",
    "    without a trailing ``_test`` / ``_val`` / ``_dev`` split marker.\n",
    "\n",
    "    Args:\n",
    "        data_dir: Directory containing the per-subject CSV files.\n",
    "\n",
    "    Returns:\n",
    "        ``datasets.Dataset`` with ``question`` and ``answer`` columns.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for p in os.listdir(data_dir):\n",
    "        # Ignore stray entries (hidden files, sub-directories, ...).\n",
    "        if not p.lower().endswith('.csv'):\n",
    "            continue\n",
    "        temp_data = pd.read_csv(os.path.join(data_dir, p), header=None)\n",
    "        temp_data.columns = ['question', 'A', 'B', 'C', 'D', 'true_answer']\n",
    "\n",
    "        # Previously only ' test' was removed, so val/dev files produced\n",
    "        # prompts such as 'about abstract algebra val'.\n",
    "        subject = p.split('.')[0]\n",
    "        for suffix in ('_test', '_val', '_dev'):\n",
    "            if subject.endswith(suffix):\n",
    "                subject = subject[:-len(suffix)]\n",
    "                break\n",
    "        subject = subject.replace('_', ' ')\n",
    "\n",
    "        base_p = f\"The following are multiple choice questions (with answers) about {subject}.\\n\"\n",
    "        rows = zip(temp_data['question'], temp_data['A'], temp_data['B'],\n",
    "                   temp_data['C'], temp_data['D'], temp_data['true_answer'])\n",
    "        for q, a, b, c, d, answer in rows:\n",
    "            # Formatting q inside the f-string avoids a TypeError when pandas\n",
    "            # parsed a question cell as a number or NaN.\n",
    "            records.append({\n",
    "                'question': f\"{base_p}{q}\\nA. {a}\\nB. {b}\\nC. {c}\\nD. {d}\\n\",\n",
    "                'answer': answer,\n",
    "            })\n",
    "\n",
    "    return Dataset.from_list(records)\n",
    "\n",
    "# Build one dataset per raw MMLU directory (auxiliary train, val, test).\n",
    "mmlu_train = process_mmlu(\"data/raw/MMLU/auxiliary_train\")\n",
    "mmlu_val = process_mmlu(\"data/raw/MMLU/val\")\n",
    "mmlu_test = process_mmlu(\"data/raw/MMLU/test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "666fa9c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99842 1531 14042\n"
     ]
    }
   ],
   "source": [
    "# Row counts of the train / val / test datasets.\n",
    "print(len(mmlu_train), len(mmlu_val), len(mmlu_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e4db1f29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following are multiple choice questions (with answers) about arc easy.\n",
      "Which factor will most likely cause a person to develop a fever?\n",
      "A. a leg muscle relaxing after exercise\n",
      "B. a bacterial population in the bloodstream\n",
      "C. several viral particles on the skin\n",
      "D. carbohydrates being digested in the stomach\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the first formatted training prompt to eyeball the template.\n",
    "print(mmlu_train['question'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "13bcd87f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Creating CSV from Arrow format: 100%|██████████| 100/100 [00:01<00:00, 50.82ba/s]\n",
      "Creating CSV from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 151.65ba/s]\n",
      "Creating CSV from Arrow format: 100%|██████████| 15/15 [00:00<00:00, 136.07ba/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "7716032"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write the three MMLU datasets to CSV.\n",
    "mmlu_train.to_csv(\"./data/processed/mmlu_train.csv\", index=False)\n",
    "mmlu_val.to_csv(\"./data/processed/mmlu_val.csv\", index=False)\n",
    "mmlu_test.to_csv(\"./data/processed/mmlu_test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3b5126a",
   "metadata": {},
   "source": [
    "## Make training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6f4ed2c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd \n",
    "\n",
    "# llama3.1-8B\n",
    "# Tasks whose answers need parsing: the ten seed runs are concatenated, and\n",
    "# reset_index(drop=False) keeps each run's row number in an `index` column.\n",
    "llama_3_1_gsm_csv = pd.concat([pd.read_csv(f\"./logs/llama3.1/llama8_gsm_seed_{i}/gsm_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "llama_3_1_math_csv = pd.concat([pd.read_csv(f\"./logs/llama3.1/llama8_math_seed_{i}/math_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "llama_3_1_gsm_base = pd.read_csv(\"./logs/llama3.1/llama8_gsm_zs/gsm_train.csv\")\n",
    "# Fixed: this used to load the MMLU zero-shot log (copy-paste from the\n",
    "# mmlu_base line). Path follows the gsm_zs naming -- verify it exists on disk.\n",
    "llama_3_1_math_base = pd.read_csv(\"./logs/llama3.1/llama8_math_zs/math_train.csv\")\n",
    "\n",
    "# Multiple-choice tasks: one DataFrame per seed, kept as a list.\n",
    "llama_3_1_arc_csv = [pd.read_csv(f\"./logs/llama3.1/llama8_arc_seed_{i}/arc_train.csv\") for i in range(10)]\n",
    "llama_3_1_hellaswag_csv = [pd.read_csv(f\"./logs/llama3.1/llama8_hellaswag_seed_{i}/hellaswag_train.csv\") for i in range(10)]\n",
    "llama_3_1_mmlu_csv = [pd.read_csv(f\"./logs/llama3.1/llama8_mmlu_seed_{i}/mmlu_train.csv\") for i in range(10)]\n",
    "llama_3_1_arc_base = pd.read_csv(\"./logs/llama3.1/llama8_arc_train_base/arc_train.csv\")\n",
    "llama_3_1_hellaswag_base = pd.read_csv(\"./logs/llama3.1/llama8_hellaswag_zs/hellaswag_train.csv\")\n",
    "llama_3_1_mmlu_base = pd.read_csv(\"./logs/llama3.1/llama8_mmlu_zs/mmlu_train.csv\")\n",
    "\n",
    "# llama3.2-3B\n",
    "# Tasks whose answers need parsing: the ten seed runs are concatenated, and\n",
    "# reset_index(drop=False) keeps each run's row number in an `index` column.\n",
    "llama_3_2_gsm_csv = pd.concat([pd.read_csv(f\"./logs/llama3.2/llama3_gsm_seed_{i}/gsm_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "llama_3_2_math_csv = pd.concat([pd.read_csv(f\"./logs/llama3.2/llama3_math_seed_{i}/math_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "llama_3_2_gsm_base = pd.read_csv(\"./logs/llama3.2/llama3_gsm_zs/gsm_train.csv\")\n",
    "# Fixed: this used to load the MMLU zero-shot log (copy-paste from the\n",
    "# mmlu_base line). Path follows the gsm_zs naming -- verify it exists on disk.\n",
    "llama_3_2_math_base = pd.read_csv(\"./logs/llama3.2/llama3_math_zs/math_train.csv\")\n",
    "\n",
    "# Multiple-choice tasks: one DataFrame per seed, kept as a list.\n",
    "# NOTE(review): the arc directory prefix is `llama_` while every other\n",
    "# llama3.2 run uses `llama3_` -- confirm against the directory on disk.\n",
    "llama_3_2_arc_csv = [pd.read_csv(f\"./logs/llama3.2/llama_arc_seed_{i}/arc_train.csv\") for i in range(10)]\n",
    "llama_3_2_hellaswag_csv = [pd.read_csv(f\"./logs/llama3.2/llama3_hellaswag_seed_{i}/hellaswag_train.csv\") for i in range(10)]\n",
    "llama_3_2_mmlu_csv = [pd.read_csv(f\"./logs/llama3.2/llama3_mmlu_seed_{i}/mmlu_train.csv\") for i in range(10)]\n",
    "llama_3_2_arc_base = pd.read_csv(\"./logs/llama3.2/llama3_arc_train_base/arc_train.csv\")\n",
    "llama_3_2_hellaswag_base = pd.read_csv(\"./logs/llama3.2/llama3_hellaswag_zs/hellaswag_train.csv\")\n",
    "llama_3_2_mmlu_base = pd.read_csv(\"./logs/llama3.2/llama3_mmlu_zs/mmlu_train.csv\")\n",
    "\n",
    "# qwen3-8B\n",
    "# Task whose answers need parsing: the ten seed runs are concatenated, and\n",
    "# reset_index(drop=False) keeps each run's row number in an `index` column.\n",
    "qwen_3_8_math_csv = pd.concat([pd.read_csv(f\"./logs/qwen3-8/qwen38_math_seed_{i}/math_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "qwen_3_8_math_base = pd.read_csv(\"./logs/qwen3-8/qwen38_math_zs/math_train.csv\")\n",
    "\n",
    "# Multiple-choice tasks: one DataFrame per seed, kept as a list.\n",
    "qwen_3_8_arc_csv = [pd.read_csv(f\"./logs/qwen3-8/qwen38_arc_seed_{i}/arc_train.csv\") for i in range(10)]\n",
    "qwen_3_8_hellaswag_csv = [pd.read_csv(f\"./logs/qwen3-8/qwen38_hellaswag_seed_{i}/hellaswag_train.csv\") for i in range(10)]\n",
    "qwen_3_8_mmlu_csv = [pd.read_csv(f\"./logs/qwen3-8/qwen38_mmlu_seed_{i}/mmlu_train.csv\") for i in range(10)]\n",
    "qwen_3_8_arc_base = pd.read_csv(\"./logs/qwen3-8/qwen38_arc_train_base/arc_train.csv\")\n",
    "qwen_3_8_hellaswag_base = pd.read_csv(\"./logs/qwen3-8/qwen38_hellaswag_zs/hellaswag_train.csv\")\n",
    "qwen_3_8_mmlu_base = pd.read_csv(\"./logs/qwen3-8/qwen38_mmlu_zs/mmlu_train.csv\")\n",
    "\n",
    "# qwen3-4B\n",
    "# Task whose answers need parsing: the ten seed runs are concatenated, and\n",
    "# reset_index(drop=False) keeps each run's row number in an `index` column.\n",
    "qwen_3_4_math_csv = pd.concat([pd.read_csv(f\"./logs/qwen3-4/qwen34_math_seed_{i}/math_train.csv\") for i in range(10)]).reset_index(drop=False)\n",
    "qwen_3_4_math_base = pd.read_csv(\"./logs/qwen3-4/qwen34_math_zs/math_train.csv\")  \n",
    "\n",
    "# Multiple-choice tasks: one DataFrame per seed, kept as a list.\n",
    "qwen_3_4_arc_csv = [pd.read_csv(f\"./logs/qwen3-4/qwen34_arc_seed_{i}/arc_train.csv\") for i in range(10)]\n",
    "qwen_3_4_hellaswag_csv = [pd.read_csv(f\"./logs/qwen3-4/qwen34_hellaswag_seed_{i}/hellaswag_train.csv\") for i in range(10)]\n",
    "qwen_3_4_mmlu_csv = [pd.read_csv(f\"./logs/qwen3-4/qwen34_mmlu_seed_{i}/mmlu_train.csv\") for i in range(10)]\n",
    "qwen_3_4_arc_base = pd.read_csv(\"./logs/qwen3-4/qwen34_arc_train_base/arc_train.csv\")\n",
    "qwen_3_4_hellaswag_base = pd.read_csv(\"./logs/qwen3-4/qwen34_hellaswag_zs/hellaswag_train.csv\")\n",
    "qwen_3_4_mmlu_base = pd.read_csv(\"./logs/qwen3-4/qwen34_mmlu_zs/mmlu_train.csv\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f28ccacd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#llama_3_1_gsm_csv.to_csv(\"./logs/llama3.1/train_gsm.csv\", index=False)\n",
    "#llama_3_1_math_csv.to_csv(\"./logs/llama3.1/train_math.csv\", index=False)\n",
    "\n",
    "#llama_3_2_gsm_csv.to_csv(\"./logs/llama3.2/train_gsm.csv\", index=False)\n",
    "#llama_3_2_math_csv.to_csv(\"./logs/llama3.2/train_math.csv\", index=False)    \n",
    "\n",
    "#qwen_3_8_math_csv.to_csv(\"./logs/qwen3-8/train_math.csv\", index=False)\n",
    "#qwen_3_4_math_csv.to_csv(\"./logs/qwen3-4/train_math.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b988a32b",
   "metadata": {},
   "source": [
    "## Datasets that need no answer parsing (multiple choice)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "744f72f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re \n",
    "from src.eval_utils import (extract_hash_answer, \n",
    "                            extract_xml_answer, \n",
    "                            extract_number_xml_confidence)\n",
    "\n",
    "def extract_first_uppercase_char(s):\n",
    "    \"\"\"Return the first ASCII capital letter (A-Z) in ``s``, or '' if there is none.\"\"\"\n",
    "    for found in re.finditer(r'[A-Z]', s):\n",
    "        # The first hit is all that is needed.\n",
    "        return found.group(0)\n",
    "    return ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "25d2b72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_confidence_labels(seed_frames):\n",
    "    \"\"\"Turn per-seed predictions into one confidence tag per question.\n",
    "\n",
    "    For every row, the predicted letter of each seed run is compared with the\n",
    "    ``true_answer`` of the first run; the share of matching runs, expressed\n",
    "    as a percentage, becomes the label.\n",
    "\n",
    "    Args:\n",
    "        seed_frames: List of DataFrames (one per seed) with ``pred_answer`` and\n",
    "            ``true_answer`` columns. Assumed to be row-aligned -- TODO confirm.\n",
    "\n",
    "    Returns:\n",
    "        List of strings such as '<confidence>70</confidence>', one per row of\n",
    "        the first frame.\n",
    "    \"\"\"\n",
    "    n_seeds = len(seed_frames)\n",
    "    gold = seed_frames[0]['true_answer']\n",
    "    labels = []\n",
    "    for i in range(len(seed_frames[0])):\n",
    "        preds = [extract_first_uppercase_char(extract_xml_answer(frame['pred_answer'][i])) for frame in seed_frames]\n",
    "        n_correct = sum(1 for p in preds if p == gold[i])\n",
    "        # With ten seeds this equals n_correct * 10, exactly as before.\n",
    "        labels.append(\"<confidence>\" + str(100 * n_correct // n_seeds) + \"</confidence>\")\n",
    "    return labels\n",
    "\n",
    "\n",
    "llama_3_1_arc_labels = build_confidence_labels(llama_3_1_arc_csv)\n",
    "llama_3_1_hellaswag_labels = build_confidence_labels(llama_3_1_hellaswag_csv)\n",
    "llama_3_1_mmlu_labels = build_confidence_labels(llama_3_1_mmlu_csv)\n",
    "\n",
    "llama_3_2_arc_labels = build_confidence_labels(llama_3_2_arc_csv)\n",
    "llama_3_2_hellaswag_labels = build_confidence_labels(llama_3_2_hellaswag_csv)\n",
    "llama_3_2_mmlu_labels = build_confidence_labels(llama_3_2_mmlu_csv)\n",
    "\n",
    "qwen_3_8_arc_labels = build_confidence_labels(qwen_3_8_arc_csv)\n",
    "qwen_3_8_hellaswag_labels = build_confidence_labels(qwen_3_8_hellaswag_csv)\n",
    "qwen_3_8_mmlu_labels = build_confidence_labels(qwen_3_8_mmlu_csv)\n",
    "\n",
    "# Fixed: these three were computed from the qwen_3_8 frames (copy-paste), so\n",
    "# the qwen3-4B labels were silently identical to the qwen3-8B ones.\n",
    "qwen_3_4_arc_labels = build_confidence_labels(qwen_3_4_arc_csv)\n",
    "qwen_3_4_hellaswag_labels = build_confidence_labels(qwen_3_4_hellaswag_csv)\n",
    "qwen_3_4_mmlu_labels = build_confidence_labels(qwen_3_4_mmlu_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72b8d84d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from transformers import set_seed\n",
    "\n",
    "# set_seed(42)\n",
    "\n",
    "# llama_3_1_arc_base['conf_label'] = llama_3_1_arc_labels\n",
    "# llama_3_1_hellaswag_base['conf_label'] = llama_3_1_hellaswag_labels\n",
    "# llama_3_1_mmlu_base['conf_label'] = llama_3_1_mmlu_labels\n",
    "\n",
    "# llama_3_2_arc_base['conf_label'] = llama_3_2_arc_labels\n",
    "# llama_3_2_hellaswag_base['conf_label'] = llama_3_2_hellaswag_labels\n",
    "# llama_3_2_mmlu_base['conf_label'] = llama_3_2_mmlu_labels\n",
    "\n",
    "# qwen_3_8_arc_base['conf_label'] = qwen_3_8_arc_labels\n",
    "# qwen_3_8_hellaswag_base['conf_label'] = qwen_3_8_hellaswag_labels\n",
    "# qwen_3_8_mmlu_base['conf_label'] = qwen_3_8_mmlu_labels\n",
    "\n",
    "# qwen_3_4_arc_base['conf_label'] = qwen_3_4_arc_labels\n",
    "# qwen_3_4_hellaswag_base['conf_label'] = qwen_3_4_hellaswag_labels\n",
    "# qwen_3_4_mmlu_base['conf_label'] = qwen_3_4_mmlu_labels \n",
    "\n",
    "# llama_3_1_arc_base = llama_3_1_arc_base.sample(frac=1).reset_index(drop=True)\n",
    "# llama_3_1_hellaswag_base = llama_3_1_hellaswag_base.sample(frac=1).reset_index(drop=True)\n",
    "# llama_3_1_mmlu_base = llama_3_1_mmlu_base.sample(frac=1).reset_index(drop=True)\n",
    "\n",
    "# llama_3_2_arc_base = llama_3_2_arc_base.sample(frac=1).reset_index(drop=True)\n",
    "# llama_3_2_hellaswag_base = llama_3_2_hellaswag_base.sample(frac=1).reset_index(drop=True)\n",
    "# llama_3_2_mmlu_base = llama_3_2_mmlu_base.sample(frac=1).reset_index(drop=True)\n",
    "\n",
    "# qwen_3_8_arc_base = qwen_3_8_arc_base.sample(frac=1).reset_index(drop=True)\n",
    "# qwen_3_8_hellaswag_base = qwen_3_8_hellaswag_base.sample(frac=1).reset_index(drop=True)\n",
    "# qwen_3_8_mmlu_base = qwen_3_8_mmlu_base.sample(frac=1).reset_index(drop=True)\n",
    "\n",
    "# qwen_3_4_arc_base = qwen_3_4_arc_base.sample(frac=1).reset_index(drop=True)\n",
    "# qwen_3_4_hellaswag_base = qwen_3_4_hellaswag_base.sample(frac=1).reset_index(drop=True)\n",
    "# qwen_3_4_mmlu_base = qwen_3_4_mmlu_base.sample(frac=1).reset_index(drop=True)   \n",
    "\n",
    "# llama_3_1_arc_train = llama_3_1_arc_base[:int(len(llama_3_1_arc_base)*0.8)]\n",
    "# llama_3_1_arc_valid = llama_3_1_arc_base[int(len(llama_3_1_arc_base)*0.8):]\n",
    "# llama_3_1_hellaswag_train = llama_3_1_hellaswag_base[:int(len(llama_3_1_hellaswag_base)*0.8)]\n",
    "# llama_3_1_hellaswag_valid = llama_3_1_hellaswag_base[int(len(llama_3_1_hellaswag_base)*0.8):]\n",
    "# llama_3_1_mmlu_train = llama_3_1_mmlu_base[:int(len(llama_3_1_mmlu_base)*0.8)]\n",
    "# llama_3_1_mmlu_valid = llama_3_1_mmlu_base[int(len(llama_3_1_mmlu_base)*0.8):]\n",
    "\n",
    "# llama_3_2_arc_train = llama_3_2_arc_base[:int(len(llama_3_2_arc_base)*0.8)]\n",
    "# llama_3_2_arc_valid = llama_3_2_arc_base[int(len(llama_3_2_arc_base)*0.8):]\n",
    "# llama_3_2_hellaswag_train = llama_3_2_hellaswag_base[:int(len(llama_3_2_hellaswag_base)*0.8)]\n",
    "# llama_3_2_hellaswag_valid = llama_3_2_hellaswag_base[int(len(llama_3_2_hellaswag_base)*0.8):]\n",
    "# llama_3_2_mmlu_train = llama_3_2_mmlu_base[:int(len(llama_3_2_mmlu_base)*0.8)]  \n",
    "# llama_3_2_mmlu_valid = llama_3_2_mmlu_base[int(len(llama_3_2_mmlu_base)*0.8):]\n",
    "\n",
    "# qwen_3_8_arc_train = qwen_3_8_arc_base[:int(len(qwen_3_8_arc_base)*0.8)]\n",
    "# qwen_3_8_arc_valid = qwen_3_8_arc_base[int(len(qwen_3_8_arc_base)*0.8):]\n",
    "# qwen_3_8_hellaswag_train = qwen_3_8_hellaswag_base[:int(len(qwen_3_8_hellaswag_base)*0.8)]\n",
    "# qwen_3_8_hellaswag_valid = qwen_3_8_hellaswag_base[int(len(qwen_3_8_hellaswag_base)*0.8):]\n",
    "# qwen_3_8_mmlu_train = qwen_3_8_mmlu_base[:int(len(qwen_3_8_mmlu_base)*0.8)]\n",
    "# qwen_3_8_mmlu_valid = qwen_3_8_mmlu_base[int(len(qwen_3_8_mmlu_base)*0.8):]\n",
    "\n",
    "# qwen_3_4_arc_train = qwen_3_4_arc_base[:int(len(qwen_3_4_arc_base)*0.8)]\n",
    "# qwen_3_4_arc_valid = qwen_3_4_arc_base[int(len(qwen_3_4_arc_base)*0.8):]\n",
    "# qwen_3_4_hellaswag_train = qwen_3_4_hellaswag_base[:int(len(qwen_3_4_hellaswag_base)*0.8)]\n",
    "# qwen_3_4_hellaswag_valid = qwen_3_4_hellaswag_base[int(len(qwen_3_4_hellaswag_base)*0.8):]\n",
    "# qwen_3_4_mmlu_train = qwen_3_4_mmlu_base[:int(len(qwen_3_4_mmlu_base)*0.8)]\n",
    "# qwen_3_4_mmlu_valid = qwen_3_4_mmlu_base[int(len(qwen_3_4_mmlu_base)*0.8):]\n",
    "\n",
    "# llama_3_1_arc_train.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/arc/train.csv\", index=False)\n",
    "# llama_3_1_hellaswag_train.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/hellaswag/train.csv\", index=False)\n",
    "# llama_3_1_mmlu_train.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/mmlu/train.csv\", index=False)\n",
    "\n",
    "# llama_3_1_arc_valid.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/arc/valid.csv\", index=False)\n",
    "# llama_3_1_hellaswag_valid.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/hellaswag/valid.csv\", index=False)    \n",
    "# llama_3_1_mmlu_valid.to_csv(\"./data/train_data/Llama-3.1-8B-Instruct/mmlu/valid.csv\", index=False)  \n",
    "\n",
    "# llama_3_2_arc_train.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/arc/train.csv\", index=False)\n",
    "# llama_3_2_hellaswag_train.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/hellaswag/train.csv\", index=False)\n",
    "# llama_3_2_mmlu_train.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/mmlu/train.csv\", index=False)\n",
    "\n",
    "# llama_3_2_arc_valid.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/arc/valid.csv\", index=False)\n",
    "# llama_3_2_hellaswag_valid.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/hellaswag/valid.csv\", index=False)\n",
    "# llama_3_2_mmlu_valid.to_csv(\"./data/train_data/Llama-3.2-3B-Instruct/mmlu/valid.csv\", index=False)\n",
    "\n",
    "# qwen_3_8_arc_train.to_csv(\"./data/train_data/Qwen-3-8B/arc/train.csv\", index=False)\n",
    "# qwen_3_8_hellaswag_train.to_csv(\"./data/train_data/Qwen-3-8B/hellaswag/train.csv\", index=False)\n",
    "# qwen_3_8_mmlu_train.to_csv(\"./data/train_data/Qwen-3-8B/mmlu/train.csv\", index=False)\n",
    "\n",
    "# qwen_3_8_arc_valid.to_csv(\"./data/train_data/Qwen-3-8B/arc/valid.csv\", index=False)\n",
    "# qwen_3_8_hellaswag_valid.to_csv(\"./data/train_data/Qwen-3-8B/hellaswag/valid.csv\", index=False) \n",
    "# qwen_3_8_mmlu_valid.to_csv(\"./data/train_data/Qwen-3-8B/mmlu/valid.csv\", index=False)\n",
    "\n",
    "# qwen_3_4_arc_train.to_csv(\"./data/train_data/Qwen-3-4B/arc/train.csv\", index=False)\n",
    "# qwen_3_4_hellaswag_train.to_csv(\"./data/train_data/Qwen-3-4B/hellaswag/train.csv\", index=False)\n",
    "# qwen_3_4_mmlu_train.to_csv(\"./data/train_data/Qwen-3-4B/mmlu/train.csv\", index=False)\n",
    "\n",
    "# qwen_3_4_arc_valid.to_csv(\"./data/train_data/Qwen-3-4B/arc/valid.csv\", index=False)\n",
    "# qwen_3_4_hellaswag_valid.to_csv(\"./data/train_data/Qwen-3-4B/hellaswag/valid.csv\", index=False)\n",
    "# qwen_3_4_mmlu_valid.to_csv(\"./data/train_data/Qwen-3-4B/mmlu/valid.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6d5a904c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "\n",
    "train_data = pd.read_csv(\"/mnt/home/chaeyun-jang/CSFT/data/train_data/Qwen-3-8B/arc/train.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9340abf7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>user\n",
      "This is a conversation between **User** and **Assistant**.\n",
      "The User asks a **multiple-choice question**, and the Assistant provides a solution.\n",
      "Before answering, the Assistant reasons through the problem step-by-step.\n",
      "The reasoning is enclosed within `<think> ... </think>`, and the final answer within `<answer> ... </answer>`.\n",
      "\n",
      "Example:\n",
      "<think>{step-by-step reasoning}</think>\n",
      "<answer>{final answer}</answer>\n",
      "**Only provide the letter (A, B, C, or D) corresponding to the correct choice in the <answer> tag. Do NOT write the full text of the answer.**\n",
      "\n",
      "Now, respond to the following using the **exact same format**:\n",
      "A mouse is homozygous for black fur (BB). The other parent is heterozygous for black fur with a recessive trait for brown fur ( Bb ). If black is the dominant trait, what percentage of the offspring will be brown?\n",
      "A. 100%\n",
      "B. 50%\n",
      "C. 25%\n",
      "D. 0%\n",
      "<|im_end|>\n",
      "<|im_start|>assistant\n",
      "<think>\n",
      "Okay, let's see. The question is about a mouse's fur color inheritance. The first parent is homozygous for black fur, which is BB. The other parent is heterozygous, so their genotype is Bb. Black is dominant over brown.\n",
      "\n",
      "So, to find out the percentage of offspring that will be brown, I need to do a Punnett square. Let me recall how that works. The parents' alleles are combined in all possible ways.\n",
      "\n",
      "The first parent can only contribute a B allele because they are BB. The second parent can contribute either a B or a b allele since they are Bb. So the possible combinations for the offspring are B from the first parent and B from the second parent, making BB, or B from the first and b from the second, making Bb.\n",
      "\n",
      "So the possible genotypes are BB and Bb. Since both of these have at least one dominant B allele, they will both express the black fur phenotype. There's no scenario where the offspring gets two recessive alleles (bb) because the first parent can't contribute a b. Therefore, all offspring will be black, meaning 0% brown. The answer should be D.\n",
      "</think>\n",
      "\n",
      "</think>\n",
      "\n",
      "</think>\n",
      "\n",
      "<answer>D</answer><|im_start|>user\n",
      "\n",
      "Please respond with a score from 0 to 100 in `<confidence> </confidence>` tags. How confident are you in your previous answer? <|im_end|>\n",
      "<|im_start|>assistant\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(train_data['conf_input'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b85f8ea",
   "metadata": {},
   "source": [
    "## Parsed data check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7afbcb08",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2c729e7b",
   "metadata": {},
   "source": [
    "## Special case: Qwen tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2ccb8aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model_utils import create_tokenizer\n",
    "\n",
    "tokenizer = create_tokenizer(\"Qwen/Qwen3-8B\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1794e2ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6023, 847, 829, 374, 13517, 1195, 359, 27, 81929, 29, 16, 15, 15, 522, 81929, 29]\n",
      "[6023, 847, 829, 374, 13517, 1195, 359]\n",
      "[27, 81929, 29, 16, 15, 15, 522, 81929, 29]\n"
     ]
    }
   ],
   "source": [
    "string = \"hi my name is chaeyun<confidence>100</confidence>\"\n",
    "new_string = \"hi my name is chaeyun\"\n",
    "sub_string = \"<confidence>100</confidence>\"\n",
    "\n",
    "print(tokenizer.encode(string, add_special_tokens=False))\n",
    "print(tokenizer.encode(new_string, add_special_tokens=False))\n",
    "print(tokenizer.encode(sub_string, add_special_tokens=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "312929e6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_ft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
