{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02e7fcfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Runtime environment, set before torch/transformers are imported in the next cell.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",    # number GPUs by PCI bus id\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0,1,2,3\",     # expose GPUs 0-3 to this process\n",
    "    \"TOKENIZERS_PARALLELISM\": \"false\",   # turn off HF tokenizers' internal parallelism\n",
    "    \"NCCL_P2P_DISABLE\": \"1\",              # disable NCCL peer-to-peer transport\n",
    "    \"NCCL_IB_DISABLE\": \"1\",               # disable NCCL InfiniBand transport\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d44a069",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import re\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import transformers\n",
    "from tqdm import tqdm\n",
    "from transformers import pipeline\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "# Only surface transformers errors, not warnings/info messages.\n",
    "transformers.logging.set_verbosity_error()\n",
    "\n",
    "# Seed the torch and numpy RNGs so runs are reproducible.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "746a5fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Never hardcode credentials in a notebook: read the token from the HF_TOKEN\n",
    "# environment variable, and prompt for it only when that variable is unset.\n",
    "login(token=os.environ.get(\"HF_TOKEN\") or getpass(\"Hugging Face token: \"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3159f743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `device` is not referenced by the later cells; each pipeline picks its own device.\n",
    "device = torch.device(\"cuda:1\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62eeef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 1\n",
    "dataset_list = ['Irony', 'Metaphor']\n",
    "\n",
    "df = pd.read_csv(f'./data/pragmatic/{dataset_list[idx]}.csv')\n",
    "# `options` and `answer_keys` are stored as Python-literal lists in the CSV.\n",
    "df['options'] = [ast.literal_eval(data) for data in df['options']]\n",
    "df['answer_keys'] = [ast.literal_eval(data) for data in df['answer_keys']]\n",
    "\n",
    "# Every row needs a 'correct' key that points at one of its options; fail here with the\n",
    "# offending row positions instead of a bare ValueError or a silently unanswerable item.\n",
    "bad_rows = [i for i, (opts, keys) in enumerate(zip(df['options'], df['answer_keys']))\n",
    "            if 'correct' not in keys or keys.index('correct') >= len(opts)]\n",
    "assert not bad_rows, f\"rows with no usable 'correct' answer key: {bad_rows[:10]}\"\n",
    "\n",
    "# 1-based answer number, matching the 1..N option numbering used in the prompt.\n",
    "df['answer'] = [data.index('correct') + 1 for data in df['answer_keys']]\n",
    "\n",
    "instruction = Path(f\"./data/pragmatic/{dataset_list[idx]}Instructions.txt\").read_text(encoding=\"utf-8\").strip()\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e927eac4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset (replaces the `df` / `instruction` built by the previous cell when both are run)\n",
    "dataset_list = ['simile', 'metaphor', 'idiom']\n",
    "idx = 2\n",
    "\n",
    "df = pd.read_csv(f'./data/pragmatic/{dataset_list[idx]}_test_for_new_method.csv')\n",
    "df['options'] = df['options'].map(ast.literal_eval)\n",
    "# Shift the stored answer index up by one to match the 1..N option numbering in the prompt.\n",
    "df['answer'] = df['answer'] + 1\n",
    "\n",
    "# No instruction text is used for this dataset.\n",
    "instruction = \"\"\n",
    "\n",
    "print(f\"Dataset loaded: {len(df)} rows\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8b3d3bd",
   "metadata": {},
   "source": [
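    "- To evaluate a different LLM, delete the surrounding `'''` quotes in that model's cell and wrap the currently active model cell in `'''` instead, so that only one `pipe` is created; then re-run the cells below it."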
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d15d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: IBM Granite 3.1 1B-A400M Instruct. To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell.\n",
    "# NOTE(review): pad_token_id=128001 is the same value as in the Llama cells; confirm it is a valid id for this tokenizer.\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\",\n",
    "    model=\"ibm-granite/granite-3.1-1b-a400m-instruct\",\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    pad_token_id=128001)\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2bccd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Llama 3.2 1B Instruct. To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell.\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    pad_token_id=128001\n",
    ")\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba309509",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Active model. device_map=\"auto\" lets transformers place the weights on the available devices.\n",
    "pipe = pipeline(\n",
    "    task=\"text-generation\",\n",
    "    model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.float16,\n",
    "    pad_token_id=128001,  # presumably Llama-3's <|end_of_text|> id reused for padding; confirm\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c1d897",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Mistral 7B Instruct v0.3. To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell.\n",
    "# Pinned to a single GPU (device=\"cuda:1\") instead of device_map=\"auto\".\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"mistralai/Mistral-7B-Instruct-v0.3\",\n",
    "    device=\"cuda:1\",\n",
    "    torch_dtype=torch.float16)\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40ea05db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Gemma 2B IT. To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell.\n",
    "# NOTE(review): if this chat template rejects the system role, build_prompt_text's fallback merges it into the user turn.\n",
    "'''\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=\"google/gemma-2b-it\",\n",
    "    device=\"cuda:1\",\n",
    "    torch_dtype=torch.float16\n",
    ")\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd85c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Qwen3 4B Instruct (2507). To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell. Pinned to device=\"cuda:3\".\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"Qwen/Qwen3-4B-Instruct-2507\",\n",
    "    device=\"cuda:3\",\n",
    "    torch_dtype=torch.float16\n",
    ")\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38cf9e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Qwen2.5 1.5B Instruct. To use it, delete the ''' quotes\n",
    "# below and quote out the currently active pipeline cell. Pinned to device=\"cuda:3\".\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"Qwen/Qwen2.5-1.5B-Instruct\",\n",
    "    device=\"cuda:3\",\n",
    "    torch_dtype=torch.float16\n",
    ")\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bfeb766",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: Phi-3 mini 4k Instruct. To use it, delete the ''' quotes\n",
    "# around the call and quote out the currently active pipeline cell. Pinned to device=\"cuda:3\".\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"microsoft/Phi-3-mini-4k-instruct\",\n",
    "    device=\"cuda:3\",\n",
    "    torch_dtype=torch.float16\n",
    ")'''\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4000a084",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled model option: SmolLM2 360M Instruct. To use it, delete the ''' quotes\n",
    "# around the call and quote out the currently active pipeline cell. Pinned to device=\"cuda:3\".\n",
    "'''\n",
    "pipe = pipeline(\"text-generation\", \n",
    "    model=\"HuggingFaceTB/SmolLM2-360M-Instruct\",\n",
    "    device=\"cuda:3\",\n",
    "    torch_dtype=torch.float16\n",
    ")'''\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c61f664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grading instructions sent as the system turn: one entry per line, ending with a trailing newline.\n",
    "system_prompt = \"\\n\".join([\n",
    "    \"You are a careful multiple-choice grader for short story comprehension with figurative language.\",\n",
    "    \"\",\n",
    "    \"You will receive:\",\n",
    "    \"• One instruction line.\",\n",
    "    \"• A short story and a question.\",\n",
    "    \"• Answer options labeled 1..N (N varies by item).\",\n",
    "    \"\",\n",
    "    \"Rules:\",\n",
    "    \"• Use only the given text and commonsense.\",\n",
    "    \"• Prefer figurative/pragmatic meaning over literal.\",\n",
    "    \"• Do NOT invent options or rely on external facts.\",\n",
    "    \"\",\n",
    "    \"Output:\",\n",
    "    \"• Output ONLY the chosen option number (1..N). One line, no spaces, no words, no punctuation.\",\n",
    "    \"\",  # empty last item yields the trailing newline\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1163d2ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_answer(text, n=5):\n",
    "    \"\"\"Return the first option number in 1..n mentioned in `text`, as a string.\n",
    "\n",
    "    Whole numbers are tried first (multi-digit aware, so n >= 10 works); if none\n",
    "    is in range, the words \"one\"..\"ten\" are tried, but only those <= n.\n",
    "    Returns \"-1\" when nothing usable is found.\n",
    "    \"\"\"\n",
    "    for m in re.finditer(r\"\\b([0-9]+)\\b\", text):\n",
    "        value = int(m.group(1))\n",
    "        if 1 <= value <= n:\n",
    "            return str(value)\n",
    "    words = {\"one\": 1, \"two\": 2, \"three\": 3, \"four\": 4, \"five\": 5,\n",
    "             \"six\": 6, \"seven\": 7, \"eight\": 8, \"nine\": 9, \"ten\": 10}\n",
    "    # Only spelled-out numbers that are valid options for this item.\n",
    "    allowed = [w for w, v in words.items() if v <= n]\n",
    "    if allowed:\n",
    "        m = re.search(r\"\\b(\" + \"|\".join(allowed) + r\")\\b\", text.lower())\n",
    "        if m:\n",
    "            return str(words[m.group(1)])\n",
    "    return \"-1\"\n",
    "\n",
    "def gather_eos_ids(tok):\n",
    "    \"\"\"Collect the token ids that should end generation for this tokenizer.\n",
    "\n",
    "    Starts from `tok.eos_token_id` and adds the ids of common end-of-turn markers.\n",
    "    Markers missing from the vocabulary are skipped: `convert_tokens_to_ids`\n",
    "    reports them as None, -1 or (on tokenizers with an unk token) the unk id,\n",
    "    and <unk> must not be treated as a stop token.\n",
    "\n",
    "    Returns a sorted list of unique ids (possibly empty).\n",
    "    \"\"\"\n",
    "    unk_id = getattr(tok, \"unk_token_id\", None)\n",
    "    ids = set()\n",
    "    if tok.eos_token_id is not None:\n",
    "        ids.add(tok.eos_token_id)\n",
    "    for t in [\"<|im_end|>\", \"<|eot_id|>\", \"<|end_of_text|>\", \"</s>\"]:\n",
    "        i = tok.convert_tokens_to_ids(t)\n",
    "        if i is None or i == -1:\n",
    "            continue\n",
    "        if unk_id is not None and i == unk_id:\n",
    "            # Out-of-vocabulary marker mapped to <unk>; not a real end token.\n",
    "            continue\n",
    "        ids.add(i)\n",
    "    return sorted(ids)\n",
    "\n",
    "def number_token_ids(tok, N):\n",
    "    \"\"\"Token ids allowed as the first generated token when answering with 1..N.\n",
    "\n",
    "    Each number is encoded both bare and with a leading space (the two can\n",
    "    encode differently), and the LAST token id of each encoding is kept.\n",
    "    NOTE(review): for N >= 10 on tokenizers that split digits, the last id of\n",
    "    \"10\" is the token for \"0\", not a complete option number.\n",
    "    \"\"\"\n",
    "    variants = [prefix + str(k) for k in range(1, N + 1) for prefix in (\"\", \" \")]\n",
    "    encodings = (tok(v, add_special_tokens=False).input_ids for v in variants)\n",
    "    return sorted({enc[-1] for enc in encodings if enc})\n",
    "\n",
    "def newline_last_id(tok):\n",
    "    \"\"\"Id of the final token in the encoding of a newline, or None if it encodes to nothing.\"\"\"\n",
    "    newline_ids = tok(\"\\n\", add_special_tokens=False).input_ids\n",
    "    if not newline_ids:\n",
    "        return None\n",
    "    return newline_ids[-1]\n",
    "\n",
    "def build_prompt_text(tok, system_prompt: str, user_prompt: str) -> str:\n",
    "    \"\"\"Render a system + user exchange into a generation-ready prompt string.\n",
    "\n",
    "    Uses the tokenizer's chat template. If the template rejects the system role\n",
    "    (error text contains \"System role not supported\"), the system prompt is\n",
    "    merged into a single user turn instead; any other error is re-raised.\n",
    "    \"\"\"\n",
    "    def render(messages):\n",
    "        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "    try:\n",
    "        return render([\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ])\n",
    "    except Exception as err:\n",
    "        if \"System role not supported\" not in str(err):\n",
    "            raise\n",
    "        merged = system_prompt.strip() + \"\\n\\n\" + user_prompt.strip()\n",
    "        return render([{\"role\": \"user\", \"content\": merged}])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c36c6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the tokenizer and model already loaded by the active `pipe`.\n",
    "# NOTE(review): `model` is not used by the evaluation loop.\n",
    "tok, model = pipe.tokenizer, pipe.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ba1d076",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "acc_list, debug_list = [], []\n",
    "\n",
    "# Decoding settings that depend only on the tokenizer: compute once, not per row.\n",
    "nl_id = newline_last_id(tok)\n",
    "eos_ids = gather_eos_ids(tok)\n",
    "eos_arg = eos_ids if eos_ids else None\n",
    "pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id\n",
    "\n",
    "# Token sequences generate() must never produce (passed as bad_words_ids).\n",
    "bw = tok([\"answer\", \"Answer\", \"option\", \"Option\", \"correct\", \"Correct\"],\n",
    "         add_special_tokens=False)[\"input_ids\"]\n",
    "bad_words_ids = [w for w in bw if len(w) > 0]\n",
    "\n",
    "digit_ids_by_n = {}  # option count N -> allowed first-token ids, built once per distinct N\n",
    "\n",
    "for data in tqdm(df.itertuples(), total=len(df)):\n",
    "    scenarios = data.scenarios\n",
    "    options = data.options  # list[str]\n",
    "    N = len(options)\n",
    "\n",
    "    opts_block = \"\\n\".join(f\"{i}. {opt}\" for i, opt in enumerate(options, start=1))\n",
    "    user_prompt = f\"\"\"{instruction}\n",
    "\n",
    "{scenarios}\n",
    "\n",
    "{opts_block}\"\"\".strip()\n",
    "    prompt_text = build_prompt_text(tok, system_prompt, user_prompt)\n",
    "\n",
    "    # Prompt length in tokens, used to recognise the first generated position.\n",
    "    # Assumes the pipeline tokenizes the string prompt with the same default\n",
    "    # settings as this call; TODO confirm for the installed transformers version.\n",
    "    start_len = tok(prompt_text, return_tensors=\"pt\").input_ids.shape[-1]\n",
    "\n",
    "    if N not in digit_ids_by_n:\n",
    "        digit_ids_by_n[N] = number_token_ids(tok, N)\n",
    "    digit_ids = digit_ids_by_n[N]\n",
    "\n",
    "    def allow(batch_id, ids):\n",
    "        \"\"\"Constrained decoding: an option-number token first, then only EOS/newline.\"\"\"\n",
    "        cur_len = ids.shape[-1] if hasattr(ids, \"shape\") else len(ids)\n",
    "        if cur_len == start_len:\n",
    "            return digit_ids\n",
    "        allowed = list(eos_ids)\n",
    "        if nl_id is not None:\n",
    "            allowed.append(nl_id)\n",
    "        return allowed\n",
    "\n",
    "    # Greedy decoding. temperature/top_p are not passed: they only apply when sampling.\n",
    "    res = pipe(\n",
    "        prompt_text,\n",
    "        do_sample=False,\n",
    "        max_new_tokens=2,\n",
    "        eos_token_id=eos_arg,\n",
    "        pad_token_id=pad_id,\n",
    "        prefix_allowed_tokens_fn=allow,\n",
    "        bad_words_ids=bad_words_ids,\n",
    "        return_full_text=False,\n",
    "    )\n",
    "\n",
    "    output = res[0][\"generated_text\"].strip()\n",
    "    # Same parser as the helper cell, instead of a second inline regex.\n",
    "    answer = extract_answer(output, n=N)\n",
    "\n",
    "    acc_list.append(1 if answer == str(data.answer) else 0)\n",
    "    # Raw prediction for inspection; -1 rather than a crash if the output is not a plain number.\n",
    "    debug_list.append(int(output) if output.isdecimal() else -1)\n",
    "end = time.time()\n",
    "acc = sum(acc_list) / len(acc_list) if acc_list else float(\"nan\")\n",
    "print(f\"acc : {acc:.4f}\")\n",
    "print(f'Latency : {end-start:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e695899",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
