{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fecb2fb5-b87f-4428-8175-e3a46fe77371",
   "metadata": {},
   "source": [
    "## Tutorial: Defining a new test-time loss and optimizing code.\n",
    "\n",
    "![TextGrad](https://github.com/vinid/data/blob/master/logo_full.png?raw=true)\n",
    "\n",
    "An autograd engine -- for textual gradients!\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zou-group/TextGrad/blob/main/examples/notebooks/Prompt-Optimization.ipynb)\n",
    "[![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)\n",
    "[![Arxiv](https://img.shields.io/badge/arXiv-2406.07496-B31B1B.svg)](https://arxiv.org/abs/2406.07496)\n",
    "[![Documentation Status](https://readthedocs.org/projects/textgrad/badge/?version=latest)](https://textgrad.readthedocs.io/en/latest/?badge=latest)\n",
    "[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/textgrad)](https://pypi.org/project/textgrad/)\n",
    "[![PyPI](https://img.shields.io/pypi/v/textgrad)](https://pypi.org/project/textgrad/)\n",
    "\n",
    "**Objectives:**\n",
    "\n",
    "* This tutorial walks through how to define a simple test-time loss in TextGrad and use it to optimize a variable of interest, here a code snippet.\n",
    "\n",
    "**Requirements:**\n",
    "\n",
    "* You need an OpenAI API key to run this tutorial. Set it in the `OPENAI_API_KEY` environment variable.\n",
    "\n",
    "We first install TextGrad and import what we need, then define some utilities and a set of test cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7add4547-4278-411b-a827-79be521851f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install textgrad # you might need to restart the notebook after installing textgrad\n",
    "\n",
    "import textgrad as tg\n",
    "import random\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bc1891d-cbc0-4e46-a618-2925fe5f122c",
   "metadata": {},
   "source": [
    "### Utilities to run the code, and test cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54c93006-dc0e-4882-bbb6-a248afed404b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll use the utilities below to run a Python function.\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "\n",
    "def run_function_in_interpreter(func_code):\n",
    "    raise Exception(\"This function will run the code returned by GPT-4o. Remove this if you'd like to run the code!\")\n",
    "    interpreter = InteractiveShell.instance()\n",
    "    \n",
    "    interpreter.run_cell(func_code, store_history=False, silent=True)\n",
    "    \n",
    "    func_name = func_code.split(\"def \")[1].split(\"(\")[0].strip()\n",
    "    func = interpreter.user_ns[func_name]\n",
    "    \n",
    "    return func\n",
    "\n",
    "\n",
    "def test_longest_increasing_subsequence(fn):\n",
    "    nums = [10, 22, 9, 33, 21, 50, 41, 60]\n",
    "    assert fn(nums) == 5\n",
    "\n",
    "    nums = [7, 2, 1, 3, 8, 4, 9, 6, 5]\n",
    "    assert fn(nums) == 4\n",
    "\n",
    "    nums = [5, 4, 3, 2, 1]\n",
    "    assert fn(nums) == 1\n",
    "\n",
    "    nums = [1, 2, 3, 4, 5]\n",
    "    assert fn(nums) == 5\n",
    "\n",
    "    nums = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]\n",
    "    assert fn(nums) == 4\n",
    "\n",
    "    nums = [10, 9, 2, 5, 3, 7, 101, 18]\n",
    "    assert fn(nums) == 4\n",
    "\n",
    "    nums = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]\n",
    "    assert fn(nums) == 6\n",
    "\n",
    "    nums = [7, 7, 7, 7, 7, 7, 7]\n",
    "    assert fn(nums) == 1\n",
    "\n",
    "    nums = [20, 25, 47, 35, 56, 68, 98, 101, 212, 301, 415, 500]\n",
    "    assert fn(nums) == 11\n",
    "\n",
    "    nums = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]\n",
    "    assert fn(nums) == 1\n",
    "\n",
    "    print(\"All test cases passed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c6c90b1-24b8-4f78-8cd3-5c01c02dbf99",
   "metadata": {},
   "source": [
    "## Problem: Improving a code snippet.\n",
    "We have a simple problem and an initial solution that is not very fast. We first test this solution and measure its wall-clock time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88256034-6c88-4a81-ac41-7e6f14abc89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem_text = \"\"\"Longest Increasing Subsequence (LIS)\n",
    "\n",
    "Problem Statement:\n",
    "Given a sequence of integers, find the length of the longest subsequence that is strictly increasing. A subsequence is a sequence that can be derived from another sequence by deleting some or no elements without changing the order of the remaining elements.\n",
    "\n",
    "Input:\n",
    "The input consists of a list of integers representing the sequence.\n",
    "\n",
    "Output:\n",
    "The output should be an integer representing the length of the longest increasing subsequence.\"\"\"\n",
    "\n",
    "initial_solution = \"\"\"\n",
    "def longest_increasing_subsequence(nums):\n",
    "    n = len(nums)\n",
    "    dp = [1] * n\n",
    "    \n",
    "    for i in range(1, n):\n",
    "        for j in range(i):\n",
    "            if nums[i] > nums[j]:\n",
    "                dp[i] = max(dp[i], dp[j] + 1)\n",
    "    \n",
    "    max_length = max(dp)\n",
    "    lis = []\n",
    "    \n",
    "    for i in range(n - 1, -1, -1):\n",
    "        if dp[i] == max_length:\n",
    "            lis.append(nums[i])\n",
    "            max_length -= 1\n",
    "    \n",
    "    return len(lis[::-1])\n",
    "\"\"\"\n",
    "\n",
    "# Generate a random test case\n",
    "def generate_random_test_case(size, min_value, max_value):\n",
    "    return [random.randint(min_value, max_value) for _ in range(size)]\n",
    "\n",
    "# Build a large random input to time the function on\n",
    "size = 10000  # Adjust the size as needed\n",
    "min_value = 1\n",
    "max_value = 1000\n",
    "\n",
    "nums = generate_random_test_case(size, min_value, max_value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "331469ff-6106-4f7b-a761-f9d3cf91f2b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Case Size: 10000\n",
      "Longest Increasing Subsequence Length: 185\n",
      "Runtime: 4.23714 seconds\n",
      "All test cases passed!\n"
     ]
    }
   ],
   "source": [
    "longest_increasing_subsequence = run_function_in_interpreter(initial_solution)\n",
    "\n",
    "start_time = time.time()\n",
    "lis = longest_increasing_subsequence(nums)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f\"Test Case Size: {size}\")\n",
    "print(f\"Longest Increasing Subsequence Length: {lis}\")\n",
    "print(f\"Runtime: {end_time - start_time:.5f} seconds\")\n",
    "\n",
    "# Run the fixed test cases\n",
    "test_longest_increasing_subsequence(longest_increasing_subsequence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db85e3b2-fed6-4ca2-a7a4-13183a0a9273",
   "metadata": {},
   "source": [
    "## TextGrad to optimize code!\n",
    "Here, we will optimize the code instance. We first define the variables and instantiate the optimizer, then define our loss function, and finally update the code!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4ded5b66-3250-4a85-b045-625084a5c214",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_engine = tg.get_engine(\"gpt-4o\")\n",
    "tg.set_backward_engine(llm_engine)\n",
    "\n",
    "# Code is the variable of interest we want to optimize -- so requires_grad=True\n",
    "code = tg.Variable(value=initial_solution,\n",
    "                   requires_grad=True,\n",
    "                   role_description=\"code instance to optimize\")\n",
    "\n",
    "# We are not interested in optimizing the problem -- so requires_grad=False\n",
    "problem = tg.Variable(problem_text, \n",
    "                      requires_grad=False, \n",
    "                      role_description=\"the coding problem\")\n",
    "\n",
    "# Tell TGD which variable to update: the code\n",
    "optimizer = tg.TGD(parameters=[code])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "207fe9d3-3103-4abe-b0b7-482b99311ff1",
   "metadata": {},
   "source": [
    "## Defining a loss function with the FormattedLLMCall operation\n",
    "\n",
    "Here, we define a structured loss function. In particular, we want the prompt sent to the evaluator LLM to have the following format:\n",
    "\n",
    "```\n",
    "{instruction}\n",
    "Problem: {problem}\n",
    "Current Code: {code}\n",
    "```\n",
    "\n",
    "`FormattedLLMCall` helps us define loss functions like this while keeping track of the input variables in the computation graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "40b86157-ed34-4a32-9461-dde35755ef72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The system prompt that will guide the behavior of the loss function.\n",
    "loss_system_prompt = \"You are a smart language model that evaluates code snippets. You do not solve problems or propose new code snippets, only evaluate existing solutions critically and give very concise feedback.\"\n",
    "loss_system_prompt = tg.Variable(loss_system_prompt, requires_grad=False, role_description=\"system prompt to the loss function\")\n",
    "\n",
    "# The instruction that will be the prefix\n",
    "instruction = \"\"\"Think about the problem and the code snippet. Does the code solve the problem? What is the runtime complexity?\"\"\"\n",
    "\n",
    "# The format string and setting up the call\n",
    "format_string = \"{instruction}\\nProblem: {{problem}}\\nCurrent Code: {{code}}\"\n",
    "format_string = format_string.format(instruction=instruction)\n",
    "\n",
    "fields = {\"problem\": None, \"code\": None}\n",
    "formatted_llm_call = tg.autograd.FormattedLLMCall(engine=llm_engine,\n",
    "                                                  format_string=format_string,\n",
    "                                                  fields=fields,\n",
    "                                                  system_prompt=loss_system_prompt)\n",
    "\n",
    "# Finally, the loss function\n",
    "def loss_fn(problem: tg.Variable, code: tg.Variable) -> tg.Variable:\n",
    "    inputs = {\"problem\": problem, \"code\": code}\n",
    "    \n",
    "    return formatted_llm_call(inputs=inputs,\n",
    "                              response_role_description=f\"evaluation of the {code.get_role_description()}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "624c9847-633b-4e9c-9722-e433e25593b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Yes, the code correctly solves the problem of finding the length of the longest increasing subsequence (LIS).\n",
      "\n",
      "**Runtime Complexity:**\n",
      "- The outer loop runs `n` times.\n",
      "- The inner loop runs up to `n` times for each iteration of the outer loop.\n",
      "- Therefore, the time complexity is \\(O(n^2)\\).\n",
      "\n",
      "**Space Complexity:**\n",
      "- The space complexity is \\(O(n)\\) due to the `dp` array and the `lis` list.\n",
      "\n",
      "**Additional Notes:**\n",
      "- The code correctly constructs the `dp` array to store the length of the LIS ending at each index.\n",
      "- The reconstruction of the LIS sequence is not necessary for finding the length, but it is correctly implemented.\n",
      "- The final return statement could be simplified to `max(dp)` instead of reconstructing the sequence and then taking its length.\n"
     ]
    }
   ],
   "source": [
    "# Let's do the forward pass for the loss function.\n",
    "loss = loss_fn(problem, code)\n",
    "print(loss.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4afc96ed-d37d-4301-a942-b6c92d9b1aac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"573pt\" height=\"354pt\"\n",
       " viewBox=\"0.00 0.00 572.50 354.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 350)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"lightgrey\" stroke=\"transparent\" points=\"-4,4 -4,-350 568.5,-350 568.5,4 -4,4\"/>\n",
       "<!-- 140403494540304 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>140403494540304</title>\n",
       "<polygon fill=\"lavender\" stroke=\"black\" points=\"374,-206 198,-206 198,0 374,0 374,-206\"/>\n",
       "<text text-anchor=\"start\" x=\"213\" y=\"-193.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Role: </text>\n",
       "<text text-anchor=\"start\" x=\"237\" y=\"-193.6\" font-family=\"Arial\" font-size=\"8.00\"> Evaluation of the code instance to</text>\n",
       "<text text-anchor=\"start\" x=\"271\" y=\"-185.6\" font-family=\"Arial\" font-size=\"8.00\">optimize</text>\n",
       "<text text-anchor=\"start\" x=\"211.5\" y=\"-177.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Value: </text>\n",
       "<text text-anchor=\"start\" x=\"239.5\" y=\"-177.6\" font-family=\"Arial\" font-size=\"8.00\"> Yes, the code correctly solves the</text>\n",
       "<text text-anchor=\"start\" x=\"224\" y=\"-169.6\" font-family=\"Arial\" font-size=\"8.00\">problem of finding the length of the</text>\n",
       "<text text-anchor=\"start\" x=\"218\" y=\"-161.6\" font-family=\"Arial\" font-size=\"8.00\">longest increasing subsequence (LIS).</text>\n",
       "<text text-anchor=\"start\" x=\"213\" y=\"-153.6\" font-family=\"Arial\" font-size=\"8.00\">**Runtime Complexity:** &#45; The outer loop</text>\n",
       "<text text-anchor=\"start\" x=\"215.5\" y=\"-145.6\" font-family=\"Arial\" font-size=\"8.00\">runs `n` times. &#45; The inner loop runs up</text>\n",
       "<text text-anchor=\"start\" x=\"223\" y=\"-137.6\" font-family=\"Arial\" font-size=\"8.00\">to `n` times for each iteration of the</text>\n",
       "<text text-anchor=\"start\" x=\"228.5\" y=\"-129.6\" font-family=\"Arial\" font-size=\"8.00\">outer loop. &#45; Therefore, the time</text>\n",
       "<text text-anchor=\"start\" x=\"228.5\" y=\"-121.6\" font-family=\"Arial\" font-size=\"8.00\">complexity is \\(O(n^2)\\). **Space</text>\n",
       "<text text-anchor=\"start\" x=\"216.5\" y=\"-113.6\" font-family=\"Arial\" font-size=\"8.00\">Complexity:** &#45; The space complexity is</text>\n",
       "<text text-anchor=\"start\" x=\"219.5\" y=\"-105.6\" font-family=\"Arial\" font-size=\"8.00\">\\(O(n)\\) due to the `dp` array and the</text>\n",
       "<text text-anchor=\"start\" x=\"224.5\" y=\"-97.6\" font-family=\"Arial\" font-size=\"8.00\">`lis` list. **Additional Notes:** &#45; The</text>\n",
       "<text text-anchor=\"start\" x=\"216\" y=\"-89.6\" font-family=\"Arial\" font-size=\"8.00\">code correctly constructs the `dp` array</text>\n",
       "<text text-anchor=\"start\" x=\"217.5\" y=\"-81.6\" font-family=\"Arial\" font-size=\"8.00\">to store the length of the LIS ending at</text>\n",
       "<text text-anchor=\"start\" x=\"217\" y=\"-73.6\" font-family=\"Arial\" font-size=\"8.00\">each index. &#45; The reconstruction of the</text>\n",
       "<text text-anchor=\"start\" x=\"225.5\" y=\"-65.6\" font-family=\"Arial\" font-size=\"8.00\">LIS sequence is not necessary for</text>\n",
       "<text text-anchor=\"start\" x=\"224.5\" y=\"-57.6\" font-family=\"Arial\" font-size=\"8.00\">finding the length, but it is correctly</text>\n",
       "<text text-anchor=\"start\" x=\"230.5\" y=\"-49.6\" font-family=\"Arial\" font-size=\"8.00\">implemented. &#45; The final return</text>\n",
       "<text text-anchor=\"start\" x=\"230\" y=\"-41.6\" font-family=\"Arial\" font-size=\"8.00\">statement could be simplified to</text>\n",
       "<text text-anchor=\"start\" x=\"216.5\" y=\"-33.6\" font-family=\"Arial\" font-size=\"8.00\">`max(dp)` instead of reconstructing the</text>\n",
       "<text text-anchor=\"start\" x=\"222\" y=\"-25.6\" font-family=\"Arial\" font-size=\"8.00\">sequence and then taking its length.</text>\n",
       "<text text-anchor=\"start\" x=\"266\" y=\"-17.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Grad Fn: </text>\n",
       "<text text-anchor=\"start\" x=\"303\" y=\"-17.6\" font-family=\"Arial\" font-size=\"8.00\"> </text>\n",
       "<text text-anchor=\"start\" x=\"205\" y=\"-9.6\" font-family=\"Arial\" font-size=\"8.00\">textgrad.autograd.llm_ops.LLMCall.backward</text>\n",
       "</g>\n",
       "<!-- 140403495435408 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>140403495435408</title>\n",
       "<polygon fill=\"lavender\" stroke=\"black\" points=\"184,-346 0,-346 0,-220 184,-220 184,-346\"/>\n",
       "<text text-anchor=\"start\" x=\"43\" y=\"-333.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Role: </text>\n",
       "<text text-anchor=\"start\" x=\"67\" y=\"-333.6\" font-family=\"Arial\" font-size=\"8.00\"> The coding problem</text>\n",
       "<text text-anchor=\"start\" x=\"7\" y=\"-325.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Value: </text>\n",
       "<text text-anchor=\"start\" x=\"35\" y=\"-325.6\" font-family=\"Arial\" font-size=\"8.00\"> Longest Increasing Subsequence (LIS)</text>\n",
       "<text text-anchor=\"start\" x=\"18.5\" y=\"-317.6\" font-family=\"Arial\" font-size=\"8.00\">Problem Statement: Given a sequence of</text>\n",
       "<text text-anchor=\"start\" x=\"25\" y=\"-309.6\" font-family=\"Arial\" font-size=\"8.00\">integers, find the length of the longest</text>\n",
       "<text text-anchor=\"start\" x=\"24\" y=\"-301.6\" font-family=\"Arial\" font-size=\"8.00\">subsequence that is strictly increasing.</text>\n",
       "<text text-anchor=\"start\" x=\"18\" y=\"-293.6\" font-family=\"Arial\" font-size=\"8.00\">A subsequence is a sequence that can be</text>\n",
       "<text text-anchor=\"start\" x=\"30.5\" y=\"-285.6\" font-family=\"Arial\" font-size=\"8.00\">derived from another sequence by</text>\n",
       "<text text-anchor=\"start\" x=\"25.5\" y=\"-277.6\" font-family=\"Arial\" font-size=\"8.00\">deleting some or no elements without</text>\n",
       "<text text-anchor=\"start\" x=\"28\" y=\"-269.6\" font-family=\"Arial\" font-size=\"8.00\">changing the order of the remaining</text>\n",
       "<text text-anchor=\"start\" x=\"22.5\" y=\"-261.6\" font-family=\"Arial\" font-size=\"8.00\">elements. Input: The input consists of a</text>\n",
       "<text text-anchor=\"start\" x=\"37\" y=\"-253.6\" font-family=\"Arial\" font-size=\"8.00\">list of integers representing the</text>\n",
       "<text text-anchor=\"start\" x=\"20\" y=\"-245.6\" font-family=\"Arial\" font-size=\"8.00\">sequence. Output: The output should be</text>\n",
       "<text text-anchor=\"start\" x=\"27\" y=\"-237.6\" font-family=\"Arial\" font-size=\"8.00\">an integer representing the length of</text>\n",
       "<text text-anchor=\"start\" x=\"27\" y=\"-229.6\" font-family=\"Arial\" font-size=\"8.00\">the longest increasing subsequence.</text>\n",
       "</g>\n",
       "<!-- 140403495435408&#45;&gt;140403494540304 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>140403495435408&#45;&gt;140403494540304</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M159.58,-219.99C169.52,-210.87 179.92,-201.33 190.3,-191.81\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"192.85,-194.22 197.85,-184.88 188.11,-189.06 192.85,-194.22\"/>\n",
       "</g>\n",
       "<!-- 140403495364752 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>140403495364752</title>\n",
       "<polygon fill=\"lavender\" stroke=\"black\" points=\"369.5,-330 202.5,-330 202.5,-236 369.5,-236 369.5,-330\"/>\n",
       "<text text-anchor=\"start\" x=\"227\" y=\"-317.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Role: </text>\n",
       "<text text-anchor=\"start\" x=\"251\" y=\"-317.6\" font-family=\"Arial\" font-size=\"8.00\"> Code instance to optimize</text>\n",
       "<text text-anchor=\"start\" x=\"265\" y=\"-309.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Value: </text>\n",
       "<text text-anchor=\"start\" x=\"293\" y=\"-309.6\" font-family=\"Arial\" font-size=\"8.00\"> def</text>\n",
       "<text text-anchor=\"start\" x=\"209.5\" y=\"-301.6\" font-family=\"Arial\" font-size=\"8.00\">longest_increasing_subsequence(nums): n</text>\n",
       "<text text-anchor=\"start\" x=\"231.5\" y=\"-293.6\" font-family=\"Arial\" font-size=\"8.00\">= len(nums) dp = [1] * n for i in</text>\n",
       "<text text-anchor=\"start\" x=\"233\" y=\"-285.6\" font-family=\"Arial\" font-size=\"8.00\">range(1, n): for j in range(i): if</text>\n",
       "<text text-anchor=\"start\" x=\"224\" y=\"-277.6\" font-family=\"Arial\" font-size=\"8.00\">nums[i] &gt; nums[j]: dp[i] = max(dp[i],</text>\n",
       "<text text-anchor=\"start\" x=\"217.5\" y=\"-269.6\" font-family=\"Arial\" font-size=\"8.00\">dp[j] + 1) max_length = max(dp) lis = []</text>\n",
       "<text text-anchor=\"start\" x=\"227\" y=\"-261.6\" font-family=\"Arial\" font-size=\"8.00\">for i in range(n &#45; 1, &#45;1, &#45;1): if dp[i]</text>\n",
       "<text text-anchor=\"start\" x=\"223.5\" y=\"-253.6\" font-family=\"Arial\" font-size=\"8.00\">== max_length: lis.append(nums[i])</text>\n",
       "<text text-anchor=\"start\" x=\"224\" y=\"-245.6\" font-family=\"Arial\" font-size=\"8.00\">max_length &#45;= 1 return len(lis[::&#45;1])</text>\n",
       "</g>\n",
       "<!-- 140403495364752&#45;&gt;140403494540304 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>140403495364752&#45;&gt;140403494540304</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M286,-235.7C286,-229.47 286,-222.87 286,-216.09\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"289.5,-216.03 286,-206.03 282.5,-216.03 289.5,-216.03\"/>\n",
       "</g>\n",
       "<!-- 140403123354192 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>140403123354192</title>\n",
       "<polygon fill=\"lavender\" stroke=\"black\" points=\"564.5,-318 387.5,-318 387.5,-248 564.5,-248 564.5,-318\"/>\n",
       "<text text-anchor=\"start\" x=\"401.5\" y=\"-305.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Role: </text>\n",
       "<text text-anchor=\"start\" x=\"425.5\" y=\"-305.6\" font-family=\"Arial\" font-size=\"8.00\"> System prompt to the loss function</text>\n",
       "<text text-anchor=\"start\" x=\"394.5\" y=\"-297.6\" font-family=\"Arial\" font-weight=\"bold\" font-size=\"8.00\" fill=\"darkblue\">Value: </text>\n",
       "<text text-anchor=\"start\" x=\"422.5\" y=\"-297.6\" font-family=\"Arial\" font-size=\"8.00\"> You are a smart language model that</text>\n",
       "<text text-anchor=\"start\" x=\"412\" y=\"-289.6\" font-family=\"Arial\" font-size=\"8.00\">evaluates code snippets. You do not</text>\n",
       "<text text-anchor=\"start\" x=\"410\" y=\"-281.6\" font-family=\"Arial\" font-size=\"8.00\">solve problems or propose new code</text>\n",
       "<text text-anchor=\"start\" x=\"421.5\" y=\"-273.6\" font-family=\"Arial\" font-size=\"8.00\">snippets, only evaluate existing</text>\n",
       "<text text-anchor=\"start\" x=\"421\" y=\"-265.6\" font-family=\"Arial\" font-size=\"8.00\">solutions critically and give very</text>\n",
       "<text text-anchor=\"start\" x=\"444\" y=\"-257.6\" font-family=\"Arial\" font-size=\"8.00\">concise feedback.</text>\n",
       "</g>\n",
       "<!-- 140403123354192&#45;&gt;140403494540304 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>140403123354192&#45;&gt;140403494540304</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M439.37,-247.68C422.75,-232.11 402.25,-212.91 381.7,-193.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"383.82,-190.85 374.13,-186.57 379.04,-195.96 383.82,-190.85\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fb2264b5e50>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's visualize our computation graph.\n",
    "loss.generate_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d63b1bbd-c638-4e16-8ffe-a4a5ad45b3b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{Variable(value=1. **Simplify the Return Statement:**\n",
      "   - The final return statement reconstructs the LIS sequence and then takes its length, which is unnecessary for finding the length of the LIS. Simplifying this to `return max(dp)` would reduce the complexity and improve readability.\n",
      "\n",
      "2. **Optimize Time Complexity:**\n",
      "   - The current time complexity is \\(O(n^2)\\). This can be improved to \\(O(n \\log n)\\) using a more efficient algorithm, such as binary search with a dynamic array (or a segment tree). This would significantly enhance performance for larger input sizes.\n",
      "\n",
      "3. **Remove Unnecessary List Construction:**\n",
      "   - The `lis` list and its reconstruction are not needed for the problem's requirement of finding the length of the LIS. Removing this part will save space and reduce unnecessary operations.\n",
      "\n",
      "4. **Improve Variable Naming:**\n",
      "   - Use more descriptive variable names to improve code readability. For example, `dp` could be renamed to `lis_lengths` to clearly indicate its purpose.\n",
      "\n",
      "5. **Add Edge Case Handling:**\n",
      "   - Consider adding a check for edge cases, such as an empty input list. This would make the function more robust.\n",
      "\n",
      "6. **Documentation and Comments:**\n",
      "   - Adding comments and documentation to explain the logic and steps of the algorithm would make the code more understandable for future readers.\n",
      "\n",
      "7. **Use Python Built-in Functions:**\n",
      "   - Utilize Python's built-in functions and libraries where applicable to make the code more concise and efficient. For example, using `max()` directly on the `dp` array instead of reconstructing the sequence.\n",
      "\n",
      "By addressing these points, the code can be optimized for both performance and readability, aligning with the objective of improving the given metric., role=feedback to code instance to optimize, grads=)}\n"
     ]
    }
   ],
   "source": [
    "# Let's look at the gradients!\n",
    "loss.backward()\n",
    "print(code.gradients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "caad85e4-b449-495b-a845-3b372e00dfd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's update the code\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "77602cbf-7cd6-4f59-b69e-f7fe7d351731",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Longest Increasing Subsequence Length: 185\n",
      "Runtime: 4.28112 seconds\n",
      "All test cases passed!\n"
     ]
    }
   ],
   "source": [
    "# Check whether the first update already improved the runtime.\n",
    "longest_increasing_subsequence = run_function_in_interpreter(code.value)\n",
    "\n",
    "start_time = time.time()\n",
    "lis = longest_increasing_subsequence(nums)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f\"Longest Increasing Subsequence Length: {lis}\")\n",
    "print(f\"Runtime: {end_time - start_time:.5f} seconds\")\n",
    "\n",
    "test_longest_increasing_subsequence(longest_increasing_subsequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ae584c91-5dc3-4b37-9b88-4a34f7b6f62a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's do one more iteration\n",
    "optimizer.zero_grad()\n",
    "loss = loss_fn(problem, code)\n",
    "loss.backward()\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aeb54d4e-0d9f-4768-9b8b-06ea97657751",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Longest Increasing Subsequence Length: 185\n",
      "Runtime: 0.00383 seconds\n",
      "All test cases passed!\n"
     ]
    }
   ],
   "source": [
    "longest_increasing_subsequence = run_function_in_interpreter(code.value)\n",
    "\n",
    "start_time = time.time()\n",
    "lis = longest_increasing_subsequence(nums)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f\"Longest Increasing Subsequence Length: {lis}\")\n",
    "print(f\"Runtime: {end_time - start_time:.5f} seconds\")\n",
    "\n",
    "test_longest_increasing_subsequence(longest_increasing_subsequence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54efbf37-4c7a-486c-8112-966e6b34e23c",
   "metadata": {},
   "source": [
    "## Optimized code, much faster!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f6115288-cb26-4008-9882-a5af60e49c3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import bisect\n",
      "\n",
      "def longest_increasing_subsequence(nums):\n",
      "    if not nums:\n",
      "        return 0\n",
      "    \n",
      "    tails = []\n",
      "    \n",
      "    for num in nums:\n",
      "        pos = bisect.bisect_left(tails, num)\n",
      "        if pos == len(tails):\n",
      "            tails.append(num)\n",
      "        else:\n",
      "            tails[pos] = num\n",
      "    \n",
      "    return len(tails)\n"
     ]
    }
   ],
   "source": [
    "print(code.value)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
