{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from syncode import SyncodeLogitsProcessor\n",
    "from syncode import Grammar\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import os\n",
    "\n",
    "# Optional environment overrides: HF_CACHE is the download cache directory (default 'cache/');\n",
    "# HF_ACCESS_TOKEN is the Hugging Face access token (default None).\n",
    "HF_CACHE = os.environ.get('HF_CACHE', 'cache/')\n",
    "HF_ACCESS_TOKEN = os.environ.get('HF_ACCESS_TOKEN')\n",
    "\n",
    "# Use the GPU when one is visible; otherwise fall back to CPU so the notebook still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
    "\n",
    "# NOTE(review): trust_remote_code=True lets code shipped in the model repo run locally;\n",
    "# confirm it is actually required for this checkpoint before keeping it.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Cutting Knowledge Date: December 2023\n",
      "Today Date: 26 Jul 2024\n",
      "\n",
      "<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Write a java function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      " \n",
      "\n",
      "[OUTPUT] public class HelloWorld {\n",
      "    public static void main(String[] args) {\n",
      "        System.out.println(\"Hello World\");\n",
      "    } \n",
      "\n",
      "    public static void printReverse(String str) {\n",
      "        char[] arr = str.toCharArray();\n",
      "        int start = 0;\n",
      "        int end = arr.length - 1;\n",
      "\n",
      "        while (start < end) {\n",
      "            System.out.print(arr[start]);\n",
      "            System.out.print(arr[end]);\n",
      "            start++;\n",
      "            end--;\n",
      "        } \n",
      "        System.out.println();\n",
      "    } \n",
      "}\n"
     ]
    }
   ],
   "source": [
    "grammar_str = \"java\"\n",
    "\n",
    "# Build the grammar and the logits processor that is handed to generate() below.\n",
    "# NOTE(review): parse_output_only=True presumably restricts parsing to the generated text, not the prompt; confirm in the SynCode docs.\n",
    "grammar = Grammar(grammar_str)\n",
    "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
    "\n",
    "# Render the user request through the chat template as plain text.\n",
    "user_request = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
    "messages = [{\"role\": \"user\", \"content\": user_request}]\n",
    "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "print(\"[PROMPT]\", prompt, \"\\n\")\n",
    "\n",
    "syncode_logits_processor.reset()\n",
    "\n",
    "# The rendered prompt already starts with the BOS token (see the printed prompt), so\n",
    "# add_special_tokens=False keeps the tokenizer from prepending a second one.\n",
    "encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)\n",
    "inputs = encoded.input_ids\n",
    "attention_mask = encoded.attention_mask\n",
    "\n",
    "output = model.generate(\n",
    "    inputs,\n",
    "    attention_mask=attention_mask,\n",
    "    max_length=512,\n",
    "    num_return_sequences=1,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    logits_processor=[syncode_logits_processor],\n",
    ")\n",
    "\n",
    "# Slice off the prompt tokens so only the newly generated continuation is decoded.\n",
    "output_str = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n",
    "print(\"[OUTPUT]\", output_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Cutting Knowledge Date: December 2023\n",
      "Today Date: 26 Jul 2024\n",
      "\n",
      "<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Write a python function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      " \n",
      "\n",
      "--------------------------------------------------\n",
      "Parsing failed! Falling back to unconstrained decoding.\n",
      "Exception: Unexpected token Token('NAME', 'simple') at line 3, column 11.\n",
      "Expected one of: \n",
      "\t* __ANON_9\n",
      "\t* __ANON_21\n",
      "\t* AMPERSAND\n",
      "\t* RPAR\n",
      "\t* __ANON_4\n",
      "\t* LESSTHAN\n",
      "\t* IF\n",
      "\t* STAR\n",
      "\t* RSQB\n",
      "\t* __ANON_5\n",
      "\t* __ANON_17\n",
      "\t* LSQB\n",
      "\t* SLASH\n",
      "\t* MINUS\n",
      "\t* VBAR\n",
      "\t* _NL\n",
      "\t* FROM\n",
      "\t* __ANON_20\n",
      "\t* EQUAL\n",
      "\t* __ANON_22\n",
      "\t* __ANON_13\n",
      "\t* OR\n",
      "\t* SEMICOLON\n",
      "\t* PLUS\n",
      "\t* LPAR\n",
      "\t* CIRCUMFLEX\n",
      "\t* FOR\n",
      "\t* __ANON_2\n",
      "\t* NOT\n",
      "\t* AT\n",
      "\t* __ANON_10\n",
      "\t* COMMA\n",
      "\t* __ANON_18\n",
      "\t* COLON\n",
      "\t* MORETHAN\n",
      "\t* AS\n",
      "\t* __ANON_6\n",
      "\t* ELSE\n",
      "\t* __ANON_16\n",
      "\t* __ANON_11\n",
      "\t* DOT\n",
      "\t* IN\n",
      "\t* __ANON_7\n",
      "\t* ASYNC\n",
      "\t* IS\n",
      "\t* RBRACE\n",
      "\t* __ANON_8\n",
      "\t* __ANON_3\n",
      "\t* AND\n",
      "\t* __ANON_15\n",
      "\t* __ANON_19\n",
      "\t* __ANON_14\n",
      "\t* PERCENT\n",
      "\t* __ANON_12\n",
      "\t* __ANON_1\n",
      "\n",
      "Partial code: ### Printing 'Hello World' in Reverse\n",
      "\n",
      "Here is a simple Python\n",
      "Parsed lexical tokens: [Token('_NL', \"### Printing 'Hello World' in Reverse\\n\\n\"), Token('NAME', 'Here'), Token('IS', 'is'), Token('NAME', 'a'), Token('NAME', 'simple')]\n",
      "--------------------------------------------------\n",
      "[OUTPUT] ### Printing 'Hello World' in Reverse\n",
      "\n",
      "Here is a simple Python function that prints 'Hello World' in reverse:\n",
      "\n",
      "```python\n",
      "def print_hello_world_reverse():\n",
      "    \"\"\"\n",
      "    Prints 'Hello World' in reverse.\n",
      "    \"\"\"\n",
      "    print(\"Hello World\")\n",
      "\n",
      "# Example usage:\n",
      "print_hello_world_reverse()\n",
      "```\n",
      "\n",
      "When you run this code, it will output:\n",
      "```\n",
      "olleH dlroW\n",
      "```\n",
      "\n",
      "Alternatively, you can also use slicing to reverse the string:\n",
      "\n",
      "```python\n",
      "def print_hello_world_reverse():\n",
      "    \"\"\"\n",
      "    Prints 'Hello World' in reverse.\n",
      "    \"\"\"\n",
      "    print(\" \".join([\"H\", \"e\", \"l\", \"l\", \"o\", \" \", \"W\", \"o\", \"r\", \"l\", \"d\"])\n",
      "\n",
      "# Example usage:\n",
      "print_hello_world_reverse()\n",
      "```\n",
      "\n",
      "This will output:\n",
      "```\n",
      "olleH dlroW\n",
      "```\n",
      "\n",
      "Note: The `join()` method is used to concatenate the elements of a list into a single string, which is then printed.\n"
     ]
    }
   ],
   "source": [
    "grammar_str = \"python\"\n",
    "\n",
    "# Build the grammar and the logits processor that is handed to generate() below.\n",
    "# NOTE(review): parse_output_only=True presumably restricts parsing to the generated text, not the prompt; confirm in the SynCode docs.\n",
    "grammar = Grammar(grammar_str)\n",
    "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
    "\n",
    "# Render the user request through the chat template as plain text.\n",
    "user_request = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
    "messages = [{\"role\": \"user\", \"content\": user_request}]\n",
    "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "print(\"[PROMPT]\", prompt, \"\\n\")\n",
    "\n",
    "syncode_logits_processor.reset()\n",
    "\n",
    "# The rendered prompt already starts with the BOS token (see the printed prompt), so\n",
    "# add_special_tokens=False keeps the tokenizer from prepending a second one.\n",
    "encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)\n",
    "inputs = encoded.input_ids\n",
    "attention_mask = encoded.attention_mask\n",
    "\n",
    "output = model.generate(\n",
    "    inputs,\n",
    "    attention_mask=attention_mask,\n",
    "    max_length=512,\n",
    "    num_return_sequences=1,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    logits_processor=[syncode_logits_processor],\n",
    ")\n",
    "\n",
    "# Slice off the prompt tokens so only the newly generated continuation is decoded.\n",
    "output_str = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n",
    "print(\"[OUTPUT]\", output_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PROMPT] <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
      "\n",
      "Cutting Knowledge Date: December 2023\n",
      "Today Date: 26 Jul 2024\n",
      "\n",
      "<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
      "\n",
      "Write a go function that prints 'hello world' in reverse.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
      "\n",
      " \n",
      "\n",
      "[OUTPUT] // \n",
      "\n",
      "package main\n",
      "\n",
      "import (\n",
      "    \"fmt\" // Import the fmt package\n",
      ")\n",
      "\n",
      "// Function to print 'hello world' in reverse\n",
      "func printHelloWorld() {\n",
      "    // Declare a variable to hold the string 'hello world'\n",
      "    var s string = \"hello world\" // Define the string\n",
      "    // Use string reverse() to reverse the string\n",
      "    var reversed string = strings.Reverses(s) // Reverse the string\n",
      "    // Print the reversed string\n",
      "    fmt.Println(reversed) // Print the reversed string\n",
      "}\n",
      "\n",
      "func main() {\n",
      "    // Call the function to print 'hello world' in reverse\n",
      "    printHelloWorld() // Call the function\n",
      "} \n",
      "\n",
      "// Note: The string reverse() function in Go returns a string slice, not a string.\n",
      "// If you want to convert the string slice to a string, you can use the string slice's string() method.\n",
      "// Here, we are using the Reverses() function which returns a string slice.\n"
     ]
    }
   ],
   "source": [
    "grammar_str = \"go\"\n",
    "\n",
    "# Build the grammar and the logits processor that is handed to generate() below.\n",
    "# NOTE(review): parse_output_only=True presumably restricts parsing to the generated text, not the prompt; confirm in the SynCode docs.\n",
    "grammar = Grammar(grammar_str)\n",
    "syncode_logits_processor = SyncodeLogitsProcessor(grammar=grammar, tokenizer=tokenizer, parse_output_only=True)\n",
    "\n",
    "# Render the user request through the chat template as plain text.\n",
    "user_request = f\"Write a {grammar_str} function that prints 'hello world' in reverse.\"\n",
    "messages = [{\"role\": \"user\", \"content\": user_request}]\n",
    "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "print(\"[PROMPT]\", prompt, \"\\n\")\n",
    "\n",
    "syncode_logits_processor.reset()\n",
    "\n",
    "# The rendered prompt already starts with the BOS token (see the printed prompt), so\n",
    "# add_special_tokens=False keeps the tokenizer from prepending a second one.\n",
    "encoded = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)\n",
    "inputs = encoded.input_ids\n",
    "attention_mask = encoded.attention_mask\n",
    "\n",
    "output = model.generate(\n",
    "    inputs,\n",
    "    attention_mask=attention_mask,\n",
    "    max_length=512,\n",
    "    num_return_sequences=1,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    logits_processor=[syncode_logits_processor],\n",
    ")\n",
    "\n",
    "# Slice off the prompt tokens so only the newly generated continuation is decoded.\n",
    "output_str = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n",
    "print(\"[OUTPUT]\", output_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
