{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}\n"
     ]
    }
   ],
   "source": [
    "# Setup for a grammar-constrained generation demo: builds the grammar, model,\n",
    "# tokenizer, prompt, generation arguments and the SynCode logits processor.\n",
    "from syncode import SyncodeLogitsProcessor, Grammar\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed\n",
    "import lark\n",
    "\n",
    "# Sampling is enabled below (do_sample=True), so seed the RNGs up front to\n",
    "# make reruns from a fresh kernel more repeatable.\n",
    "SEED = 0\n",
    "set_seed(SEED)\n",
    "\n",
    "# Grammar: one or more comma-separated items; each item is one of five names\n",
    "# wrapped in a matching pair of single or double quotes.\n",
    "grammar_str = r\"\"\"\n",
    "start: item (\",\" item)* \n",
    "\n",
    "item: \"'\" name \"'\"\n",
    "    | \"\\\"\" name \"\\\"\"\n",
    "\n",
    "name: \"Alice\" \n",
    "    | \"Bob\" \n",
    "    | \"Carol\" \n",
    "    | \"Dave\"\n",
    "    | \"Eve\"\n",
    "\"\"\"\n",
    "\n",
    "model_name = \"deepseek-ai/deepseek-coder-1.3b-base\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\").eval()\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# device_map=\"auto\" lets the loader choose where the weights go, so read the\n",
    "# device back from the model instead of hard-coding \"cuda\". Otherwise the\n",
    "# prompt tensor can land on a different device than the model, or the move\n",
    "# can fail outright on a machine without a GPU.\n",
    "device = model.device\n",
    "\n",
    "# The same grammar text is compiled twice: once for SynCode (used by the\n",
    "# logits processor) and once as a plain Lark parser (used to check the\n",
    "# completions after generation).\n",
    "syncode_grammar = Grammar(grammar_str)\n",
    "parser = lark.Lark(grammar_str)\n",
    "\n",
    "prompt = \"A list of male first names:\\n\"\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
    "\n",
    "# When False, the generation cell passes no logits processor at all.\n",
    "constrain = True\n",
    "\n",
    "args = {\n",
    "    \"max_new_tokens\" : 128,\n",
    "    \"do_sample\" : True,\n",
    "    \"num_beams\" : 2,\n",
    "    \"num_return_sequences\" : 2,\n",
    "    \"pad_token_id\" : tokenizer.eos_token_id,\n",
    "}\n",
    "\n",
    "# num_samples is tied to the beam count so the two cannot drift apart.\n",
    "syncode_logits_processor = SyncodeLogitsProcessor(\n",
    "    grammar=syncode_grammar,\n",
    "    tokenizer=tokenizer,\n",
    "    parse_output_only=True,\n",
    "    num_samples=args[\"num_beams\"],\n",
    "    mode=\"grammar_strict\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"'Alice'\", '\"Alice\"']\n",
      ">>> COMPLETION 0\n",
      "\n",
      "CAN PARSE\n",
      "\n",
      "'Alice'\n",
      "\n",
      "[\"'\", 'A', 'lic', 'e', \"'\", '<｜end▁of▁sentence｜>']\n",
      "\n",
      ">>> COMPLETION 1\n",
      "\n",
      "CAN PARSE\n",
      "\n",
      "\"Alice\"\n",
      "\n",
      "['\"', 'A', 'lic', 'e', '\"', '<｜end▁of▁sentence｜>']\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Generate completions for the prompt, then check each one against the\n",
    "# grammar with the Lark parser.\n",
    "syncode_logits_processor.reset()\n",
    "\n",
    "outputs = model.generate(\n",
    "    inputs,\n",
    "    logits_processor=[syncode_logits_processor] if constrain else None,\n",
    "    **args,\n",
    ")\n",
    "\n",
    "# Drop the leading prompt-length tokens from every returned sequence so that\n",
    "# only the continuation is decoded and validated.\n",
    "prompt_len = len(inputs[0])\n",
    "outputs = [seq[prompt_len:] for seq in outputs]\n",
    "completions = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "# Per-token decoding (special tokens kept) to show how each completion was split.\n",
    "completions_tokens = [[tokenizer.decode(tok) for tok in seq] for seq in outputs]\n",
    "print(completions)\n",
    "\n",
    "\n",
    "for i, (completion, toks) in enumerate(zip(completions, completions_tokens)):\n",
    "    print(f\">>> COMPLETION {i}\\n\")\n",
    "    try:\n",
    "        parser.parse(completion)\n",
    "    except lark.exceptions.LarkError:\n",
    "        # Only Lark's own errors count as a grammar violation; anything else\n",
    "        # (e.g. KeyboardInterrupt) propagates instead of being reported as a\n",
    "        # parse failure.\n",
    "        print(\"CANNOT PARSE\\n\")\n",
    "    else:\n",
    "        print(\"CAN PARSE\\n\")\n",
    "    print(f\"{completion}\\n\")\n",
    "    print(f\"{toks}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
