{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, re\n",
    "\n",
    "# Redirect cache-aware libraries to the shared cache directory.\n",
    "# This must run BEFORE `datasets` is imported: the library resolves its cache\n",
    "# location from the environment at import time, so a later assignment is ignored.\n",
    "os.environ['XDG_CACHE_HOME'] =  \"/work/hdd/bdkj/audreyh/.cache\"\n",
    "\n",
    "import datasets\n",
    "import numpy as np\n",
    "from importlib import import_module\n",
    "from omegaconf import OmegaConf, DictConfig, ListConfig\n",
    "\n",
    "# Make the project's `code` directory importable (presumably where `tasks.mmlu`\n",
    "# lives - confirm). The membership check keeps the cell idempotent on re-run.\n",
    "ROOT = \"/u/audreyh/workspace/test-code\"\n",
    "CODE_DIR = os.path.join(ROOT, 'code')\n",
    "if CODE_DIR not in sys.path:\n",
    "    sys.path.append(CODE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "completion = \"les out (D). The only remaining option is (B), as arthropods have an open circulatory system with a dorsal tubular heart. The answer is B.\"\n",
    "ANSWER_IDS = [\"The answer is \"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<re.Match object; span=(122, 137), match='The answer is B'>\n"
     ]
    }
   ],
   "source": [
    "# Try each known answer prefix and stop at the first one found in `completion`.\n",
    "match = None  # stays None when ANSWER_IDS is empty or nothing matches\n",
    "for answer_id in ANSWER_IDS:\n",
    "    # The prefix is literal text, so escape it before embedding it in the pattern;\n",
    "    # the capture group picks up the option letter that follows it.\n",
    "    ans_re = re.compile(fr\"{re.escape(answer_id)}([A-D])\")\n",
    "    match = ans_re.search(completion)\n",
    "    print(match)\n",
    "    if match is not None:\n",
    "        # Keep this hit instead of letting a later prefix overwrite it.\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "B\n"
     ]
    }
   ],
   "source": [
    "# Group 1 of the pattern is the captured option letter. `search` returns None\n",
    "# when the completion contains no answer prefix, so guard before reading it.\n",
    "match_str = match.group(1) if match is not None else None\n",
    "print(match_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The answer is ? les out (D). The only remaining option is (B), as arthropods have an open circulatory system with a dorsal tubular heart. The answer is B.\n"
     ]
    }
   ],
   "source": [
    "# Split the completion on the first answer prefix and keep the trailing piece.\n",
    "answer_id = ANSWER_IDS[0]\n",
    "parts = completion.split(answer_id)\n",
    "answer = parts[-1]\n",
    "print(f\"{answer_id} {answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['D', 'B', 'B']\n"
     ]
    }
   ],
   "source": [
    "# Collect every standalone option letter (A-D) mentioned in the completion,\n",
    "# optionally followed by '.' or ')'.\n",
    "option_pattern = re.compile(r'\\b[A-D][.\\)]?\\b', flags=re.MULTILINE | re.DOTALL)\n",
    "letters = option_pattern.findall(completion)\n",
    "print(letters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'name': 'cais/mmlu', 'subset': ['college_chemistry', 'college_mathematics'], 'split': 'test', 'question_field': 'question', 'answer_field': 'answer', 'choices_field': 'choices'}\n"
     ]
    }
   ],
   "source": [
    "# Dataset settings handed to the task loader: which dataset/subsets/split to\n",
    "# read and which record fields hold the question, answer and choices.\n",
    "config = dict(\n",
    "    name=\"cais/mmlu\",\n",
    "    subset=[\"college_chemistry\", \"college_mathematics\"],\n",
    "    split=\"test\",\n",
    "    question_field=\"question\",\n",
    "    answer_field=\"answer\",\n",
    "    choices_field=\"choices\",\n",
    ")\n",
    "\n",
    "# Wrap the plain dict in an OmegaConf config object and show it.\n",
    "cfg = OmegaConf.create(config)\n",
    "print(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the task module by its dotted name and grab its `load_data` entry point.\n",
    "dl_mod = import_module(\"tasks.mmlu\")\n",
    "load_data = getattr(dl_mod, \"load_data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The loader's result is unpacked into exactly three names.\n",
    "# NOTE(review): presumably three parallel sequences, one entry per example - confirm.\n",
    "(questions, answers, choices) = load_data(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The rate, r, of a zero-order chemical reaction A → B can be expressed as which of the following?\n",
      "Let T: R^2 -> R^2 be the linear transformation that maps the point (1, 2) to (2, 3) and the point (-1, 2) to (2, -3). Then T maps the point (2, 1) to\n"
     ]
    }
   ],
   "source": [
    "# Spot-check the first and last loaded questions, one per line.\n",
    "print(questions[0], questions[-1], sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rlhf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
