{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fae8857a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from syncode import Syncode"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "775fcbf7",
   "metadata": {},
   "source": [
    "# 1. Define the grammar\n",
    "Syncode lets users define a grammar in a simple EBNF syntax adapted from Lark. Grammars for some common programming languages are provided in the `syncode/grammars` directory. A grammar can also be passed directly as a string of rules, as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "14ee1dc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.14it/s]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Lark base parser from cache: cache/parsers/custom_lalr_parser.pkl\n"
     ]
    }
   ],
   "source": [
    "grammar = \"\"\"\n",
    "        start: expr\n",
    "\n",
    "        ?expr: term\n",
    "            | expr \"+\" term\n",
    "            | expr \"-\" term\n",
    "            | expr \"*\" term\n",
    "            | expr \"/\" term\n",
    "            | expr \"=\" term\n",
    "\n",
    "        ?term: DEC_NUMBER | \"(\" expr \")\"\n",
    "\n",
    "        DEC_NUMBER: /0|[1-9]\\d*/i\n",
    "        \n",
    "        %ignore \" \"\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b4d4a0f",
   "metadata": {},
   "source": [
    "# 2. Load the Hugging Face models\n",
    "Syncode uses the Hugging Face transformers library to load models. Simply provide the Hugging Face ID of the model to use. The `mode` argument selects the generation mode: `original` generates without grammar constraints, and `grammar_mask` constrains generation to the given grammar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9ba4d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"microsoft/phi-2\"\n",
    "\n",
    "# Load the unconstrained original model\n",
    "llm = Syncode(model=model_name, mode='original', max_new_tokens=20)\n",
    "\n",
    "# Load the Syncode augmented model\n",
    "syn_llm = Syncode(model=model_name, mode='grammar_mask', grammar=grammar)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97c09f77",
   "metadata": {},
   "source": [
    "# 3. Comparing the outputs of Syncode with standard generation\n",
    "We compare the outputs of Syncode with standard generation. Constraining generation with the arithmetic expression grammar keeps the LLM from producing free-form text, so it outputs only arithmetic expressions. This is useful for generating syntactically correct code, and in settings where the LLM output is passed automatically to another tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "8ba78740",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM output:\n",
      "'The answer is 4.\\n<|im_end|>\\n\\n<|im_start'\n",
      "\n",
      "Syncode augmented LLM output:\n",
      "2+2=4 \n"
     ]
    }
   ],
   "source": [
    "output = llm.infer(\"What is 2+2?\")\n",
    "print(f\"LLM output:\\n{repr(output)}\\n\")\n",
    "\n",
    "output = syn_llm.infer(\"What is 2+2?\")\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "d1b61fda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM output:\n",
      "'The answer is 56.\\n<|im_end|>\\n\\n<|im_start'\n",
      "\n",
      "Syncode augmented LLM output:\n",
      "7 * 8 = 56 \n"
     ]
    }
   ],
   "source": [
    "output = llm.infer('What is 7 multiplied by 8?')\n",
    "print(f\"LLM output:\\n{repr(output)}\\n\")\n",
    "\n",
    "output = syn_llm.infer('What is 7 multiplied by 8?')\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "24a12ace",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM output:\n",
      "'The square root of 64 is 8.\\n<|im_end|>\\n'\n",
      "\n",
      "Syncode augmented LLM output:\n",
      "8 \n"
     ]
    }
   ],
   "source": [
    "output = llm.infer('What is square root of 64?')\n",
    "print(f\"LLM output:\\n{repr(output)}\\n\")\n",
    "\n",
    "output = syn_llm.infer('What is square root of 64?')\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex-env",
   "language": "python",
   "name": "codex"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
