{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "23530ae1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Lark base parser from cache: cache/parsers/python_lalr_parser.pkl\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "from syncode import Syncode\n",
    "\n",
    "# Suppress every Python warning raised from this point on.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "model_name = \"WizardLM/WizardCoder-1B-V1.0\"\n",
    "\n",
    "# Baseline: the original model, generating without a grammar constraint.\n",
    "llm = Syncode(model=model_name, mode=\"original\", max_new_tokens=200)\n",
    "\n",
    "# Syncode-augmented model: grammar_mask mode with the Python grammar.\n",
    "syn_llm = Syncode(\n",
    "    model=model_name,\n",
    "    mode=\"grammar_mask\",\n",
    "    grammar=\"python\",\n",
    "    parse_output_only=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "987e0831",
   "metadata": {},
   "source": [
    "# Syncode vs. standard LLM generation of a Python program\n",
    "For the prompt below, the standard LLM generates a Python program that has an indentation error, while the Syncode LLM generates a correct Python program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "490cddb3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def is_prime(n):\n",
      "    '''Return if prime'''\n",
      "   if n < 2:\n",
      "       return False\n",
      "   for i in range(2, int(n**0.5)+1):\n",
      "       if n % i == 0:\n",
      "           return False\n",
      "   return True\n",
      "IndentationError: unindent does not match any outer indentation level (<string>, line 3)\n"
     ]
    }
   ],
   "source": [
    "# The prompt ends with a 2-space indent right after a docstring line that is\n",
    "# indented by 4 spaces; the model has to continue from that position.\n",
    "partial_code = \"def is_prime(n):\\n    '''Return if prime'''\\n  \"\n",
    "output = partial_code + llm.infer(partial_code)[0]\n",
    "print(output)\n",
    "\n",
    "# The failure is the point of this cell, so catch the error and print it\n",
    "# instead of letting it abort Restart & Run All before the Syncode cell runs.\n",
    "# NOTE: exec runs model-generated code; only do this in a trusted demo environment.\n",
    "try:\n",
    "    exec(output)\n",
    "except SyntaxError as err:  # IndentationError is a subclass of SyntaxError\n",
    "    print(f\"{type(err).__name__}: {err}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "76cd93f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def is_prime(n):\n",
      "    '''Return if prime'''\n",
      "    if n < 2:\n",
      "        return False\n",
      "    for i in range(2, int(n**0.5) + 1):\n",
      "        if n % i == 0:\n",
      "            return False\n",
      "    return True\n"
     ]
    }
   ],
   "source": [
    "# Same prompt as before, this time completed by the grammar-masked model.\n",
    "output = partial_code + syn_llm.infer(partial_code)[0]\n",
    "print(output)\n",
    "# NOTE: exec runs model-generated code; only do this in a trusted demo environment.\n",
    "exec(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
