{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SynCode Logits Processor\n",
    "In this notebook, we will use only the `SyncodeLogitsProcessor` with a HuggingFace model to enable grammar-guided decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from syncode import Grammar, SyncodeLogitsProcessor\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model_name = \"microsoft/phi-2\"\n",
    "\n",
    "# Load the weights in bfloat16, put the model in eval mode, and move it to the chosen device.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval().to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grammar for a date written as a month name, a single space, and a day number from 1 to 31.\n",
    "grammar_str = \"\"\" start: month \" \" day \n",
    "              \n",
    "              day: /[1-9]/ | /[1-2][0-9]/ | /3[0-1]/\n",
    "              \n",
    "              month: \"January\" | \"February\" | \"March\" | \"April\" | \"May\" | \"June\" | \"July\" | \"August\" | \"September\" | \"October\" | \"November\" | \"December\"\n",
    "\"\"\"\n",
    "date_grammar = Grammar(grammar_str)\n",
    "\n",
    "# Build the SynCode logits processor that constrains decoding to the grammar above.\n",
    "# NOTE(review): parse_output_only=True presumably limits parsing to the generated text,\n",
    "# excluding the prompt -- confirm against the SynCode documentation.\n",
    "syncode_logits_processor = SyncodeLogitsProcessor(\n",
    "    grammar=date_grammar,\n",
    "    tokenizer=tokenizer,\n",
    "    parse_output_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "July 4\n"
     ]
    }
   ],
   "source": [
    "prompt = \"When is the US independence day?\"\n",
    "\n",
    "# Reset the logits processor before starting a new generation.\n",
    "syncode_logits_processor.reset()\n",
    "\n",
    "# Keep the full encoding so the attention mask can be passed to generate() explicitly:\n",
    "# it cannot be inferred from the input ids when pad_token_id is set to the EOS token id.\n",
    "encoded_prompt = tokenizer(prompt, return_tensors='pt').to(device)\n",
    "inputs = encoded_prompt.input_ids\n",
    "\n",
    "output = model.generate(\n",
    "      inputs,\n",
    "      attention_mask=encoded_prompt.attention_mask,\n",
    "      max_length=100,\n",
    "      num_return_sequences=1,\n",
    "      pad_token_id=tokenizer.eos_token_id,\n",
    "      logits_processor=[syncode_logits_processor]\n",
    "      )\n",
    "\n",
    "# The returned sequence starts with the prompt tokens; drop them and decode only the completion.\n",
    "output_str = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n",
    "print(output_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
