{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"copy_prompt_builder.py\n",
    "\n",
    "Builds prompts for the *Copy* probing task (UF/UB or NF/NB).\n",
    "\n",
    "Each example line has the form::\n",
    "\n",
    "    <bos> a b c d : a b c d <eos>\n",
    "\n",
    "Where\n",
    "* ``<bos>`` begins the sequence\n",
    "* characters of the *input* are separated by spaces so that every character is\n",
    "  kept as an independent token by the tokenizer\n",
    "* a colon (``:``) ends the copy source and separates *source* from *target*\n",
    "* the *target* is either the *forward* copy (UF/NF) or the *backward* copy\n",
    "  (UB/NB)\n",
    "* ``<eos>`` ends a completed example\n",
    "\n",
    "For the *query* example (the one we want the model to answer) we omit the\n",
    "*target* and ``<eos>`` so the prompt line ends right after the colon.  A full\n",
    "prompt is a block of *k* few‑shot lines followed by the query line.\n",
    "\"\"\"\n",
    "from __future__ import annotations\n",
    "\n",
    "import random\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Sequence\n",
    "\n",
    "import jsonlines\n",
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer\n",
    "# from copying.prompt_constants import *\n",
    "# ---------------- Dataset loading ---------------- #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "!export PYTHONPATH=~/Desktop/Lacoco/len-gen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "@dataclass(frozen=True)\n",
    "class TaskRule:\n",
    "    \"\"\"Immutable bundle of the natural-language rule texts for one copy setting.\"\"\"\n",
    "    # Setting identifier; the entries of ``TASKS`` use the same string as their dict key.\n",
    "    key: str\n",
    "    # Full-sentence rule, substituted for ``{rule_simple}`` in ``TEMPLATES``.\n",
    "    rule_simple: str\n",
    "    # Lower-case clause, substituted for ``{rule_hint}`` in ``TEMPLATES``.\n",
    "    rule_hint: str\n",
    "\n",
    "# Rule wording per copy direction: F = forward copy, B = backward (reversed) copy.\n",
    "# Each entry is ``(rule_simple, rule_hint)``.  The wording does not depend on the\n",
    "# token-uniqueness regime (U = unique, N = non-unique), so both regimes share it.\n",
    "_RULE_TEXT_BY_DIRECTION = {\n",
    "    \"F\": (\n",
    "        \"The output is exactly the same sequence as the input.\",\n",
    "        \"in every example the output repeats the input unchanged.\",\n",
    "    ),\n",
    "    \"B\": (\n",
    "        \"The output is the input sequence written in reverse order.\",\n",
    "        \"in every example the output is the input read backwards.\",\n",
    "    ),\n",
    "}\n",
    "\n",
    "# One TaskRule per setting, in the order UF, UB, NF, NB.\n",
    "TASKS: Dict[str, TaskRule] = {\n",
    "    regime + direction: TaskRule(\n",
    "        regime + direction, *_RULE_TEXT_BY_DIRECTION[direction]\n",
    "    )\n",
    "    for regime in (\"U\", \"N\")\n",
    "    for direction in (\"F\", \"B\")\n",
    "}\n",
    "\n",
    "# ---------------------------------------------------------------------------\n",
    "# Prompt templates (absolutely no imperatives!)\n",
    "# ---------------------------------------------------------------------------\n",
    "\n",
    "TEMPLATES: Dict[str, str] = {\n",
    "    # 1) Bare – pattern exposure only --------------------------------------\n",
    "    \"bare\": \"{examples}\",\n",
    "    # 2) One‑line natural rule --------------------------------------------\n",
    "    \"obey\": \"Here are some samples of the format <bos> input : output <eos> where the rule being followed is - {rule_simple}\\n\\n{examples}\",\n",
    "    # 3) Reverse‑specific hint --------------------------------------------\n",
    "    \"hint\": \"Here are some examples where {rule_hint}\\n\\n{examples}\"\n",
    "}\n",
    "\n",
    "MODELS: Dict[str, str] = {\n",
    "    'llama3_8B': \"/local/common_models/Llama-3.1-8B\", \n",
    "    'llama3_70B': \"/local/common_models/Llama-3.1-70B\",\n",
    "    'qwen2.5_7B': \"/local/common_models/Qwen2.5-7B\",\n",
    "    'qwen2.5_32B': \"/local/common_models/Qwen2.5-32B\",\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "SETTINGS = [\"UF\", \"UB\", \"NF\", \"NB\"]\n",
    "\n",
    "\n",
    "def load_dataset(data_path: Path, task_key: str) -> List[Dict]:\n",
    "    \"\"\"Load *copy* task JSON‑Lines and return a list of dicts.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_path\n",
    "        Path to a JSONL file produced by *CopyDatasetGenerator*.\n",
    "    task_key\n",
    "        One of ``UF``, ``UB``, ``NF``, ``NB`` selecting the *target* column.\n",
    "    \"\"\"\n",
    "    if task_key not in SETTINGS:\n",
    "        raise ValueError(f\"task_key must be one of {SETTINGS!r}\")\n",
    "\n",
    "    records: List[Dict] = []\n",
    "    with jsonlines.open(data_path, \"r\") as reader:\n",
    "        for rec in reader:\n",
    "            records.append({\n",
    "                \"input\": rec[\"input\"],\n",
    "                \"target\": rec[task_key],\n",
    "            })\n",
    "    return records\n",
    "\n",
    "\n",
    "# ---------------- Prompt helpers ---------------- #\n",
    "\n",
    "def _space_chars(text: str) -> str:\n",
    "    \"\"\"Return characters of *text* separated by single spaces.\"\"\"\n",
    "    return \" \".join(text)\n",
    "\n",
    "\n",
    "class PromptBuilderCopy:\n",
    "    \"\"\"Create few‑shot prompts for the Copy task.\n",
    "\n",
    "    Each *example line* is built like::\n",
    "\n",
    "        <bos> a b c d <sep>: a b c d\n",
    "\n",
    "    For the query (target unknown) the part after the colon is omitted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        shots: int,\n",
    "        variant: str,\n",
    "        task_key: str,\n",
    "        test_data_path: Path,\n",
    "        fewshot_data_path: Path | None,\n",
    "        tokenizer: AutoTokenizer,\n",
    "        rng: random.Random | None = None,\n",
    "    ) -> None:\n",
    "        self.shots = shots\n",
    "        self.task_obj = TASKS[task_key]\n",
    "        self.tokenizer = tokenizer\n",
    "        self.rng = rng or random.Random()\n",
    "        self.prompt_type = variant\n",
    "        self.dataset = load_dataset(test_data_path, task_key)\n",
    "        if fewshot_data_path is None:\n",
    "            # allow re‑using the test set as few‑shot pool\n",
    "            self.fewshot_dataset = self.dataset.copy()\n",
    "        else:\n",
    "            self.fewshot_dataset = load_dataset(fewshot_data_path, task_key)\n",
    "\n",
    "    # -------------- Single‑line builder -------------- #\n",
    "\n",
    "    def _make_example_line(self, input_str: str, target: str | None) -> str:\n",
    "        \"\"\"Return a prompt line with (or without) *target*.\"\"\"\n",
    "        src = f\"<bos> {_space_chars(input_str)} :\"\n",
    "        return src if target is None else f\"{src} {_space_chars(target)} <eos>\"\n",
    "\n",
    "    # -------------- Few‑shot block -------------- #\n",
    "\n",
    "    def _few_shot_block(self, curr_record: Dict) -> str:\n",
    "        \"\"\"Sample *k* few‑shot examples (excluding *curr_record*).\"\"\"\n",
    "        pool = [rec for rec in self.fewshot_dataset if rec != curr_record]\n",
    "        examples = self.rng.sample(pool, k=self.shots)\n",
    "        lines = [self._make_example_line(ex[\"input\"], ex[\"target\"]) for ex in examples]\n",
    "        # add the *query* without target\n",
    "        lines.append(self._make_example_line(curr_record[\"input\"], None))\n",
    "        return \"\\n\".join(lines)\n",
    "\n",
    "    # -------------- Public API -------------- #\n",
    "\n",
    "    def build_prompt(self, curr_record: Dict) -> tuple[str, List[int]]:\n",
    "        \"\"\"Return the textual prompt **and** its token‑ids.\"\"\"\n",
    "        # Create the examples along with the input\n",
    "        examples = self._few_shot_block(curr_record)\n",
    "        template_base = TEMPLATES[self.prompt_type].format(\n",
    "            rule_simple=self.task_obj.rule_simple,\n",
    "            rule_hint=self.task_obj.rule_hint,\n",
    "            examples=examples,\n",
    "        )\n",
    "        ids = self.tokenizer.encode(template_base, add_special_tokens=False)\n",
    "        return template_base, ids\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################\n",
    "# Evaluation / prompt generation class\n",
    "############################################\n",
    "\n",
    "class ModelEvaluator:\n",
    "    \"\"\"Handle vLLM based inference\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_id: str,\n",
    "        batch_size: int,\n",
    "        tp: int,\n",
    "        gpu_mem: float,\n",
    "        temperature: float,\n",
    "        seed: int\n",
    "    ) -> None:\n",
    "        self.model_id = model_id\n",
    "        self.batch_size = batch_size\n",
    "        self.temperature = temperature\n",
    "        self.seed = seed\n",
    "\n",
    "        print(f\"→ Loading tokenizer for {model_id} …\", flush=True)\n",
    "        # self.tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)\n",
    "        self.tokenizer.pad_token = self.tokenizer.eos_token\n",
    "\n",
    "        print(\"→ Initialising vLLM engine …\", flush=True)\n",
    "        self.llm = LLM(\n",
    "            model=model_id,\n",
    "            tensor_parallel_size=tp,\n",
    "            gpu_memory_utilization=gpu_mem,\n",
    "            seed=seed,\n",
    "            skip_tokenizer_init=True,\n",
    "        )\n",
    "        self.sampling_params = SamplingParams(\n",
    "            max_tokens=50,\n",
    "            temperature=temperature,\n",
    "            stop=[\"\\n\\n\", \"<eos>\"]\n",
    "        )\n",
    "\n",
    "    # ---------------- Evaluation ---------------- #\n",
    "\n",
    "    def run(\n",
    "        self,\n",
    "        prompt_packs: Sequence[Dict],\n",
    "    ) -> List[Dict]:\n",
    "\n",
    "        results: List[Dict] = []\n",
    "        for i in range(0, len(prompt_packs), self.batch_size):\n",
    "            batch = prompt_packs[i : i + self.batch_size]\n",
    "            outs = self.llm.generate(\n",
    "                prompt_token_ids=[b[\"ids\"] for b in batch],\n",
    "                sampling_params=self.sampling_params,\n",
    "            )\n",
    "            for generated_output, current_entry in zip(outs, batch):\n",
    "                generated_text = self.tokenizer.decode(generated_output.outputs[0].token_ids).strip()\n",
    "                try:\n",
    "                    prediction = generated_text.strip()[0]\n",
    "                except Exception:\n",
    "                    prediction = \"failed\"\n",
    "                results.append({\n",
    "                    \"prediction\": prediction,\n",
    "                    \"full_output\": generated_text,\n",
    "                    \"input\": current_entry[\"rec\"][\"input\"],\n",
    "                    \"target\": current_entry[\"rec\"][\"target\"],\n",
    "                    \"prompt\": current_entry[\"prompt\"]\n",
    "                })\n",
    "        return results\n",
    "\n",
    "    def cleanup(self):\n",
    "        self.llm.llm_engine.__del__()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo builder: 3-shot \"obey\" prompt for the NB (non-unique tokens, backward copy) task.\n",
    "# Alternative tried earlier: variant \"bare\", task \"UB\" with\n",
    "# ../datasets/copying/unique/unique_N1500_L10_seed121.jsonl\n",
    "pb = PromptBuilderCopy(\n",
    "    shots=3,\n",
    "    variant=\"obey\",\n",
    "    task_key=\"NB\",\n",
    "    test_data_path=Path(\"../datasets/copying/nonunique/nonunique_N1500_L10_seed121.jsonl\"),\n",
    "    fewshot_data_path=None,\n",
    "    tokenizer=AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B'),\n",
    "    # Seed the sampler so the few-shot examples are the same after a kernel restart.\n",
    "    rng=random.Random(121),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first record of the loaded test set as the query.\n",
    "record = pb.dataset[0]  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here are some samples of the format <bos> input : output <eos> where the rule being followed is - The output is the input sequence written in reverse order.\n",
      "\n",
      "<bos> H k m W f w h K O m : m O K h w f W m k H <eos>\n",
      "<bos> e Y F r E H B U k e : e k U B H E r F Y e <eos>\n",
      "<bos> r w f L Q O b P e r : r e P b O Q L f w r <eos>\n",
      "<bos> F N h O W O b p k A :\n"
     ]
    }
   ],
   "source": [
    "# Build the few-shot prompt (text and token ids) for the query record and show the text.\n",
    "prompt, ids = pb.build_prompt(record)\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
