{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
      "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.75it/s]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating DFA mask store for LlamaTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/LlamaTokenizerFast/grammar_strict_5884830165_32000.pkl.\n",
      "Ignore whitespace tokens is False\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:02<00:00,  8.67it/s]\n"
     ]
    }
   ],
   "source": [
    "from syncode.infer import Syncode\n",
    "\n",
    "model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
    "\n",
    "grammar_without_ws = \"\"\" \n",
    "                start: \"assert\" \"(\" expression \")\" \";\"\n",
    "                                \n",
    "                expression: \"name\"          \n",
    "        \"\"\"\n",
    "\n",
    "syn_llm = Syncode(model=model_name, grammar=grammar_without_ws, parse_output_only=True, new_mask_store=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are not running the flash-attention implementation, expect numerical differences.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Syncode augmented LLM output:\n",
      "assert(name);\n",
      "\n"
     ]
    }
   ],
   "source": [
    "p = \"You are an expert in writing print statements in Python. Follow the user query and generate a Python print statement without any comments or explanations. User query: In a python print statement print a string name.\"\n",
    "\n",
    "output = syn_llm.infer(p)[0]\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
      "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.72it/s]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating DFA mask store for LlamaTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/LlamaTokenizerFast/grammar_strict_5906785652_32000.pkl.\n",
      "Ignore whitespace tokens is True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 21/21 [00:03<00:00,  6.77it/s]\n"
     ]
    }
   ],
   "source": [
    "model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
    "\n",
    "grammar_with_ws = \"\"\" \n",
    "                start: \"assert\" \"(\" expression \")\" \";\"\n",
    "                                \n",
    "                expression: \"name\"  \n",
    "\n",
    "                %import common.WS \n",
    "                %ignore WS \n",
    "        \"\"\"\n",
    "syn_llm = Syncode(model=model_name, grammar=grammar_with_ws, parse_output_only=True, new_mask_store=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Syncode augmented LLM output:\n",
      "\n",
      "\n",
      "assert(name)\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "p = \"You are an expert in writing print statements in Python. Follow the user query and generate a Python print statement without any comments or explanations. User query: In a python print statement print a string name.\"\n",
    "\n",
    "output = syn_llm.infer(p)[0]\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\\n\")\n",
    "\n",
    "# Although the output ends without \";\", it is syntactically valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shubham/anaconda3/envs/codex/lib/python3.11/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating DFA mask store for CodeGenTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/CodeGenTokenizerFast/grammar_strict_5884830165_50257.pkl.\n",
      "Ignore whitespace tokens is False\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:03<00:00,  5.40it/s]\n"
     ]
    }
   ],
   "source": [
    "from syncode.infer import Syncode\n",
    "\n",
    "model_name = \"Salesforce/codegen-350M-multi\"\n",
    "\n",
    "grammar_without_ws = \"\"\" \n",
    "                start: \"assert\" \"(\" expression \")\" \";\"\n",
    "                                \n",
    "                expression: \"name\"          \n",
    "        \"\"\"\n",
    "syn_llm = Syncode(model=model_name, grammar=grammar_without_ws, parse_output_only=True, new_mask_store=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Syncode augmented LLM output:\n",
      "assert(name);\n",
      "\n"
     ]
    }
   ],
   "source": [
    "p = \"You are an expert in writing print statements in Python. Follow the user query and generate a Python print statement without any comments or explanations. User query: In a python print statement print a string name.\"\n",
    "\n",
    "output = syn_llm.infer(p)[0]\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating DFA mask store for CodeGenTokenizerFast and custom, may take more than 10 minutes. Caching at /home/shubham/syncode/cache/mask_stores/CodeGenTokenizerFast/grammar_strict_5906785652_50257.pkl.\n",
      "Ignore whitespace tokens is True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 21/21 [00:04<00:00,  4.25it/s]\n"
     ]
    }
   ],
   "source": [
    "from syncode.infer import Syncode\n",
    "\n",
    "model_name = \"Salesforce/codegen-350M-multi\"\n",
    "\n",
    "grammar_with_ws = \"\"\" \n",
    "                start: \"assert\" \"(\" expression \")\" \";\"\n",
    "                                \n",
    "                expression: \"name\"  \n",
    "\n",
    "                %import common.WS \n",
    "                %ignore WS \n",
    "        \"\"\"\n",
    "syn_llm = Syncode(model=model_name, grammar=grammar_with_ws, parse_output_only=True, new_mask_store=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Syncode augmented LLM output:\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "        \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "p = \"You are an expert in writing print statements in Python. Follow the user query and generate a Python print statement without any comments or explanations. User query: In a python print statement print a string name.\"\n",
    "\n",
    "output = syn_llm.infer(p)[0]\n",
    "print(f\"Syncode augmented LLM output:\\n{output}\\n\")\n",
    "\n",
    "# Although the output is empty it is syntactically valid"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
