{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Environment setup: make the project root the working directory and put the\n",
    "# project's lmql directory on the import path.\n",
    "# %reload_ext loads the extension when it is not loaded yet and reloads it\n",
    "# otherwise, so re-running this cell does not print an \"already loaded\" notice.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# PROJECT_ROOT stays in the kernel namespace, so a re-run after the chdir below\n",
    "# does not climb one directory further up.\n",
    "if globals().get(\"PROJECT_ROOT\") is None:\n",
    "    PROJECT_ROOT = os.path.join(os.getcwd(), \"..\")\n",
    "os.chdir(PROJECT_ROOT)\n",
    "\n",
    "# Append only once: re-running the cell must not pile up duplicate entries.\n",
    "if f\"{PROJECT_ROOT}/lmql\" not in sys.path:\n",
    "    sys.path.append(f\"{PROJECT_ROOT}/lmql\")\n",
    "\n",
    "import lmql\n",
    "import asyncio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# Name of the generated task file (without extension) to load.\n",
    "task = \"boolean_expressions\"\n",
    "path = os.path.join(\"generated_tasks\", task + \".json\")\n",
    "\n",
    "# JSON is UTF-8 by specification; do not depend on the platform's locale encoding.\n",
    "with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# The file only needs to stay open while it is being parsed.\n",
    "instances = data[\"instances\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def strip_indent(text):\n",
    "    \"\"\"Remove the leading whitespace shared by every non-blank line of `text`.\n",
    "\n",
    "    Blank and whitespace-only lines do not influence the common indentation.\n",
    "    Lines are split with `str.splitlines()` and re-joined with a newline\n",
    "    character, so a trailing newline in `text` is not preserved.\n",
    "\n",
    "    Args:\n",
    "        text: The (possibly multi-line) string to dedent.\n",
    "\n",
    "    Returns:\n",
    "        `text` with the common indentation removed from each line.\n",
    "    \"\"\"\n",
    "    lines = text.splitlines()\n",
    "    # Flush-left lines contribute an empty indent, so a single unindented line\n",
    "    # means nothing is stripped anywhere.\n",
    "    indents = [l[:len(l) - len(l.lstrip())] for l in lines if l.strip()]\n",
    "    # No non-blank lines at all -> nothing to strip.\n",
    "    common_indent = min(indents, key=len) if indents else \"\"\n",
    "    # Every non-blank line starts with at least this much whitespace; blank\n",
    "    # lines shorter than the common width simply become \"\".\n",
    "    return \"\\n\".join(l[len(common_indent):] for l in lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BEAM(dclib_decoder=\"argmax\", max_len=120)\n",
      "    \"To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\n\"\n",
      "    \"To eveluate the logical expression (True) you split it into multiple subexpressions. The first one is [subexpression] and evaluates to [subeval].[control_var]\" \n",
      "    while control_var.startswith(\" The\"):\n",
      "        \" one is [subexpression] and evaluates to [subeval].[control_var]\"\n",
      "    \" the full expression (True) evaluates to [answer]\"\n",
      "FROM \n",
      "    \"openai/text-ada-001\"\n",
      "WHERE\n",
      "    control_var in [\" The next\", \"Therefore,\"] and subeval in [\"True\", \"False\"] and answer in [\"True\", \"False\"]\n",
      "Decoding using DCLib decoders.\n",
      "get return value <lmql.runtime.prompt_interpreter.HypothesisHead object at 0x7f19f3e74100>\n"
     ]
    }
   ],
   "source": [
    "from lmql.runtime.output_writer import PrintingDebuggerOutputWriter\n",
    "\n",
    "async def run(code):\n",
    "    \"\"\"Write `code` to `__temp.lmql` and run that file with `lmql.run_file`.\n",
    "\n",
    "    The file name is resolved against the current working directory. The\n",
    "    awaited result of `lmql.run_file` is returned unchanged.\n",
    "    \"\"\"\n",
    "    query_path = os.path.abspath(\"__temp.lmql\")\n",
    "    with open(query_path, \"w\") as query_file:\n",
    "        query_file.write(code)\n",
    "\n",
    "    # Debugger output writer with its `clear` and `print_output` flags disabled.\n",
    "    quiet_writer = PrintingDebuggerOutputWriter()\n",
    "    quiet_writer.clear = False\n",
    "    quiet_writer.print_output = False\n",
    "\n",
    "    result = await lmql.run_file(query_path, output_writer=quiet_writer)\n",
    "    return result\n",
    "\n",
    "# Worked examples that can be prepended to a query as few-shot prompts.\n",
    "# Each entry ends in a literal backslash followed by 'n' (not a newline\n",
    "# character): the entry is wrapped in double quotes and pasted into the query\n",
    "# text as a string literal, so the escape has to survive as source text.\n",
    "few_shot_samples = [\n",
    "    \"To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\\\n\",\n",
    "    \"To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\\\n\"\n",
    "]\n",
    "\n",
    "def lstrip_spaces(l):\n",
    "    \"\"\"Drop leading spaces from `l` while keeping any leading tabs.\n",
    "\n",
    "    Only the leading run of spaces and tabs is affected; everything from the\n",
    "    first other character onward is returned unchanged.\n",
    "    \"\"\"\n",
    "    rest = l.lstrip(\" \\t\")\n",
    "    leading = l[:len(l) - len(rest)]\n",
    "    # Within the leading run, spaces vanish and tabs stay in order.\n",
    "    return leading.replace(\" \", \"\") + rest\n",
    "\n",
    "def indent(l):\n",
    "    \"\"\"Return `l` with four spaces inserted after every newline.\n",
    "\n",
    "    The first line itself is not indented.\n",
    "    \"\"\"\n",
    "    return \"\\n    \".join(l.split(\"\\n\"))\n",
    "\n",
    "async def eval(model, decoder, shots = 0, **kwargs):\n",
    "    \"\"\"Build the LMQL query for the first task instance and run it.\n",
    "\n",
    "    NOTE: this name shadows the builtin `eval` for the rest of the notebook.\n",
    "\n",
    "    Args:\n",
    "        model: Model identifier placed in the query's FROM clause.\n",
    "        decoder: Value passed as `dclib_decoder` in the query's BEAM(...) header.\n",
    "        shots: Number of entries of `few_shot_samples` to prepend to the prompt.\n",
    "        **kwargs: Extra `key=value` arguments appended to the BEAM(...) header\n",
    "            in sorted order.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `run` returns for the generated query, or None when\n",
    "        `instances` is empty.\n",
    "    \"\"\"\n",
    "    # Only the first instance is evaluated.\n",
    "    instance = next(iter(instances.values()), None)\n",
    "    if instance is None:\n",
    "        return None\n",
    "\n",
    "    # Normalise template indentation: leading spaces are dropped, tabs become 4 spaces.\n",
    "    template = \"\\n\".join([lstrip_spaces(l) for l in instance[\"template\"].split(\"\\n\")]).replace(\"\\t\", \"    \")\n",
    "    condition = instance[\"condition\"]\n",
    "\n",
    "    if shots > 0:\n",
    "        template = \"\\n\".join(f'\"{s}\"' for s in few_shot_samples[:shots]) + \"\\n\" + template\n",
    "\n",
    "    additional_args = sorted([f\"{k}={v}\" for k,v in kwargs.items()])\n",
    "    additional_args = \", \".join(additional_args)\n",
    "    if len(additional_args) > 0:\n",
    "        additional_args = \", \" + additional_args\n",
    "\n",
    "    # The query text starts at column 0 on purpose; only the template body is indented.\n",
    "    query = f\"\"\"\n",
    "BEAM(dclib_decoder=\"{decoder}\"{additional_args})\n",
    "    {indent(template)}\n",
    "FROM \n",
    "    \"{model}\"\n",
    "WHERE\n",
    "    {condition}\n",
    "    \"\"\".strip()\n",
    "\n",
    "    print(query)\n",
    "    return await run(query)\n",
    "\n",
    "import asyncio\n",
    "# Handle to the current asyncio event loop.\n",
    "loop = asyncio.get_event_loop()\n",
    "# One-shot run with the argmax decoder; the notebook allows top-level `await`.\n",
    "res = await eval(model=\"openai/text-ada-001\", decoder=\"argmax\", shots=1, max_len=120)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With Quotes and Few Shot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BEAM (dclib_decoder=\"argmax\")\n",
    "   \"To evaluate the logical expression '(True and False)' you split it into multiple subexpressions. The first subexpression is 'True and False' and evaluates to 'False'. To conclude, the full expression '(True and False)' evaluates to 'False'.\\n\"\n",
    "   \"To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\n\"\n",
    "   \"To evaluate the logical expression '(not False or False and True)' you split it into multiple subexpressions. The first subexpression is '[subexpression] and evaluates to '[subeval]'.[control_var]\"\n",
    "   while control_var.startswith(\" The next\"):\n",
    "      \" one is '[subexpression] and evaluates to '[subeval]'.[control_var]\"\n",
    "   \"the full expression '(not False or False and True)' evaluates to '[answer]'.\"\n",
    "FROM\n",
    "   \"openai/text-ada-001\"\n",
    "WHERE\n",
    "   control_var in [\" The next\", \" To conclude, \"] and subeval in [\"True\", \"False\"] and answer in [\"True\", \"False\"] and STOPS_AT (subexpression, \"'\") and not \"\\n\" in subexpression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BEAM(dclib_decoder=\"var\", openai_chunksize=16, num_beams=3, sample=False)\n",
    "   \"To evaluate the logical expression '(False and not False)' you split it into multiple subexpressions. The first subexpression is 'not False' and evaluates to 'True'. The next one is 'False and True' and evaluates to 'False'. To conclude, the full expression '(False and not False)' evaluates to 'False'.\\n\"\n",
    "   \"To evaluate the logical expression 'not False or False and True' you split it into multiple subexpressions. The first subexpression is '[subexpression] and evaluates to '[subeval]'.[control_var]\"\n",
    "   n_eval_steps = 0\n",
    "   while control_var.startswith(\" The next\") and n_eval_steps < 2:\n",
    "      \" one is '[subexpression] and evaluates to '[subeval]'.[control_var]\"\n",
    "      n_eval_steps += 1\n",
    "   \"the full expression 'not False or False and True' evaluates to '[answer]'.\"\n",
    "FROM\n",
    "   \"openai/text-ada-001\"\n",
    "WHERE\n",
    "   control_var in [\" The next\", \" To conclude, \"] and subeval in [\"True\", \"False\"] and answer in [\"True\", \"False\"] and STOPS_AT (subexpression, \"'\") and not \"\\n\" in subexpression"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lmql",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1a3f742538928d7fe17d54779274ecff8afc8007fdeca0464d7b2ce3865992ba"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
