{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### import and network test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3.1+cu118\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import ast\n",
    "import torch\n",
    "print(torch.__version__)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "Compose a poem that explains the concept of recursion in programming.\n",
      "output:\n",
      "In the realm of code, where logic reigns supreme,\n",
      "There lies a concept, not always what it seems.\n",
      "Recursion it's called, a powerful tool,\n",
      "A loop that spins endlessly, breaking all rules.\n",
      "\n",
      "Like mirrors reflecting, each other's reflection,\n",
      "A function calls itself, without hesitation.\n",
      "It dives deeper and deeper, into the abyss,\n",
      "Until a base case finally brings it bliss.\n",
      "\n",
      "An elegant solution, to problems complex,\n",
      "Recursion unravels, with finesse and finesse.\n",
      "Through layers of repetition, it paves a way,\n",
      "Repeating its steps, until it can say:\n",
      "\n",
      "\"I've reached the end, my journey complete,\n",
      "Recursion, my friend, you've never tasted defeat.\"\n",
      "So let us embrace, this recursive dance,\n",
      "And marvel at the beauty, of its infinite expanse.\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "from openai import OpenAI\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "client = OpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "client_async = AsyncOpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "def LLM_response(prompt, model, max_tokens=100):\n",
    "    if type(prompt) == str:\n",
    "      prompt = \" \".join(prompt)\n",
    "      input_messages = [\n",
    "          {\"role\": \"user\", \"content\": prompt}\n",
    "      ]\n",
    "    elif type(prompt) == list:\n",
    "      input_messages = prompt\n",
    "    else:\n",
    "      raise ValueError(\"prompt must be a string or a list of strings\")\n",
    "    \n",
    "    completion = client.chat.completions.create(\n",
    "       model=model,\n",
    "       messages=input_messages\n",
    "    )\n",
    "    return completion.choices[0].message.content\n",
    "\n",
    "messages_test_1=[\n",
    "  {\"role\": \"system\", \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},\n",
    "  {\"role\": \"user\", \"content\": \"Compose a poem that explains the concept of recursion in programming.\"}\n",
    "]\n",
    "messages_test_2 = \"Compose a poem that explains the concept of recursion in programming.\"\n",
    "\n",
    "print(\"input:\")\n",
    "print(messages_test_2)\n",
    "print(\"output:\")\n",
    "print(LLM_response(messages_test_2, \"gpt-3.5-turbo\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prompt test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### LLM-TM prompt test (task decompose version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
      "\n",
      "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single subtask.\n",
      "\n",
      "For each decomposed subtask, your output must strictly follow the format below:\n",
      "\n",
      "{\n",
      "  \"step_1\": {\n",
      "    \"step_task_description\": str,\n",
      "    \"input_format\": [[type, shape or null]],\n",
      "    \"output_format\": [[type, shape or null]],\n",
      "    \"test_case_generation_advise\": [str]\n",
      "  },\n",
      "  \"step_2\": {\n",
      "    \"step_task_description\": str,\n",
      "    \"input_format\": [[type, shape or null]],\n",
      "    \"output_format\": [[type, shape or null]],\n",
      "    \"test_case_generation_advise\": [str]\n",
      "  },\n",
      "  ...\n",
      "}\n",
      "\n",
      "Here are additional detailed explanations of each field:\n",
      "\n",
      "- **step_X**: The key represents the subtask name, it should be replaced by the actual name of the subtask (e.g., \"merge_arrays\", \"calculate_median\").\n",
      "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
      "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
      "  - The first element indicates the data type (e.g., list, dict, NumPy array, torch.Tensor).\n",
      "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
      "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`.\n",
      "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
      "\n",
      "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
      "\n",
      "\"Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
      "\n",
      "The overall run time complexity should be O(log (m+n)).\n",
      "\n",
      "Example 1:\n",
      "\n",
      "Input: nums1 = [1,3], nums2 = [2]\n",
      "Output: 2.00000\n",
      "Explanation: merged array = [1,2,3] and median is 2.\n",
      "Example 2:\n",
      "\n",
      "Input: nums1 = [1,2], nums2 = [3,4]\n",
      "Output: 2.50000\n",
      "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
      " \n",
      "\n",
      "Constraints:\n",
      "\n",
      "nums1.length == m\n",
      "nums2.length == n\n",
      "0 <= m <= 1000\n",
      "0 <= n <= 1000\n",
      "1 <= m + n <= 2000\n",
      "-106 <= nums1[i], nums2[i] <= 106\"\n",
      "\n",
      "Please provide your structured decomposition according to the instructions above.\n",
      "\n",
      "output:\n",
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays nums1 and nums2 and return the median of the merged array.\",\n",
      "        \"input_format\": [[\"list\", null], [\"list\", null]],\n",
      "        \"output_format\": [[\"float\", null]],\n",
      "        \"test_case_generation_advise\": [\"Test edge cases such as empty arrays, arrays with equal length, arrays with different lengths, etc.\"]\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# task decompose prompt test\n",
    "\n",
    "TASK_DESCRIPTION = \"\"\"Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
    "\n",
    "The overall run time complexity should be O(log (m+n)).\n",
    "\n",
    "Example 1:\n",
    "\n",
    "Input: nums1 = [1,3], nums2 = [2]\n",
    "Output: 2.00000\n",
    "Explanation: merged array = [1,2,3] and median is 2.\n",
    "Example 2:\n",
    "\n",
    "Input: nums1 = [1,2], nums2 = [3,4]\n",
    "Output: 2.50000\n",
    "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
    " \n",
    "\n",
    "Constraints:\n",
    "\n",
    "nums1.length == m\n",
    "nums2.length == n\n",
    "0 <= m <= 1000\n",
    "0 <= n <= 1000\n",
    "1 <= m + n <= 2000\n",
    "-106 <= nums1[i], nums2[i] <= 106\"\"\"\n",
    "\n",
    "task_decompose_system_prompt = \"You are an expert code architect specializing in decomposing complex programming tasks into modular components. Given a code generation task description, your role is to systematically break it down into atomic sub-tasks, each representing a distinct function/class that contributes to solving the global task. Ensure each component is self-contained, testable, and adheres to strict input/output specifications.\"\n",
    "\n",
    "task_decompose_prompt = \"\"\"You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
    "\n",
    "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single subtask.\n",
    "\n",
    "For each decomposed subtask, your output must strictly follow the format below:\n",
    "\n",
    "{\n",
    "  \"step_1\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  \"step_2\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  ...\n",
    "}\n",
    "\n",
    "Here are additional detailed explanations of each field:\n",
    "\n",
    "- **step_X**: The key represents the subtask name, it should be replaced by the actual name of the subtask (e.g., \"merge_arrays\", \"calculate_median\").\n",
    "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
    "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
    "  - The first element indicates the data type (e.g., list, dict, NumPy array, torch.Tensor).\n",
    "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
    "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
    "\n",
    "\"{{TASK_DESCRIPTION}}\"\n",
    "\n",
    "Please provide your structured decomposition according to the instructions above.\n",
    "\"\"\"\n",
    "\n",
    "task_decompose_prompt = task_decompose_prompt.replace(\"{{TASK_DESCRIPTION}}\", TASK_DESCRIPTION)\n",
    "LLM_output = LLM_response(task_decompose_prompt, \"gpt-3.5-turbo\")\n",
    "print(\"input:\")\n",
    "print(task_decompose_prompt)\n",
    "print(\"output:\")\n",
    "print(LLM_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays nums1 and nums2 and return the median of the merged array.\",\n",
      "        \"input_format\": [[\"list\", null], [\"list\", null]],\n",
      "        \"output_format\": [[\"float\", null]],\n",
      "        \"test_case_generation_advise\": [\"Test edge cases such as empty arrays, arrays with equal length, arrays with different lengths, etc.\"]\n",
      "    }\n",
      "}\n",
      "{'merge_arrays': {'step_task_description': 'Merge two sorted arrays nums1 and nums2 and return the median of the merged array.', 'input_format': [['list', None], ['list', None]], 'output_format': [['float', None]], 'test_case_generation_advise': ['Test edge cases such as empty arrays, arrays with equal length, arrays with different lengths, etc.']}}\n"
     ]
    }
   ],
   "source": [
    "# task decompose extract test\n",
    "\n",
    "import json\n",
    "\n",
    "def extract_json_data(json_string):\n",
    "    try:\n",
    "        data = json.loads(json_string)\n",
    "        return data\n",
    "    except json.JSONDecodeError as e:\n",
    "        print(f\"JSON解析错误：{e}\")\n",
    "        return None\n",
    "    \n",
    "# test_temp = \"\"\"{\n",
    "#     \"step_1\": {\n",
    "#         \"step_task_description\": \"Merge two sorted arrays nums1 and nums2\",\n",
    "#         \"input_format\": [[\"List\", null], [\"List\", null]],\n",
    "#         \"output_format\": [[\"List\", null]],\n",
    "#         \"test_case_generation_advise\": [\"Include test cases where arrays have different lengths and values\"]\n",
    "#     },\n",
    "#     \"step_2\": {\n",
    "#         \"step_task_description\": \"Calculate the median of the merged array\",\n",
    "#         \"input_format\": [[\"List\", null]],\n",
    "#         \"output_format\": [[\"float\", null]],\n",
    "#         \"test_case_generation_advise\": [\"Consider edge cases like empty arrays or arrays with even length\"]\n",
    "#     }\n",
    "# }\"\"\"\n",
    "\n",
    "print(LLM_output)\n",
    "extract_test = extract_json_data(LLM_output)\n",
    "print(extract_test)\n",
    "# print(extract_test[\"step_1\"][\"test_case_generation_advise\"])  # 输出对应字段"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### LLM-TM prompt test (with overall plan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
      "\n",
      "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.\n",
      "\n",
      "Your output must strictly follow the format below:\n",
      "\n",
      "<components>\n",
      "{\n",
      "  \"component_1\": {\n",
      "    \"step_task_description\": str,\n",
      "    \"input_format\": [[type, shape or null]],\n",
      "    \"output_format\": [[type, shape or null]],\n",
      "    \"work_flow\": [str],\n",
      "    \"test_case_generation_advise\": [str]\n",
      "  },\n",
      "  \"component_2\": {\n",
      "    \"step_task_description\": str,\n",
      "    \"input_format\": [[type, shape or null]],\n",
      "    \"output_format\": [[type, shape or null]],\n",
      "    \"work_flow\": [str],\n",
      "    \"test_case_generation_advise\": [str]\n",
      "  },\n",
      "  ...\n",
      "}\n",
      "<End>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "  \"input_format\": [[type, shape or null]],\n",
      "  \"output_format\": [[type, shape or null]],\n",
      "  \"components\": [str],\n",
      "  \"plan\": [str],\n",
      "  \"test_case_generation_advise\": [str]\n",
      "}\n",
      "<End>\n",
      "\n",
      "Here are additional detailed explanations of each field:\n",
      "\n",
      "For <components>:\n",
      "- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., \"merge_arrays\", \"calculate_median\").\n",
      "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
      "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
      "  - The first element indicates the data type (e.g., list, dict, NumPy array, torch.Tensor).\n",
      "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
      "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`.\n",
      "- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.\n",
      "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
      "\n",
      "For <overall_plan>:\n",
      "- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.\n",
      "- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.\n",
      "- **components**: List the components in the order.\n",
      "- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.\n",
      "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
      "\n",
      "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
      "\n",
      "\"Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
      "\n",
      "The overall run time complexity should be O(log (m+n)).\n",
      "\n",
      "Example 1:\n",
      "\n",
      "Input: nums1 = [1,3], nums2 = [2]\n",
      "Output: 2.00000\n",
      "Explanation: merged array = [1,2,3] and median is 2.\n",
      "Example 2:\n",
      "\n",
      "Input: nums1 = [1,2], nums2 = [3,4]\n",
      "Output: 2.50000\n",
      "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
      " \n",
      "\n",
      "Constraints:\n",
      "\n",
      "nums1.length == m\n",
      "nums2.length == n\n",
      "0 <= m <= 1000\n",
      "0 <= n <= 1000\n",
      "1 <= m + n <= 2000\n",
      "-106 <= nums1[i], nums2[i] <= 106\"\n",
      "\n",
      "Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.\n",
      "\n",
      "Please provide your structured decomposition according to the instructions above.\n",
      "\n",
      "output:\n",
      "<components>\n",
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays and find the median\",\n",
      "        \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Merge the two input arrays into a single sorted array\",\n",
      "                      \"2. Find the median of the merged array\"],\n",
      "        \"test_case_generation_advise\": [\"Generate test cases where the arrays have different sizes\",\n",
      "                                        \"Include edge cases where the median is at the beginning, middle, or end of the merged array\"]\n",
      "    },\n",
      "    \"calculate_median\": {\n",
      "        \"step_task_description\": \"Calculate the median of a sorted array\",\n",
      "        \"input_format\": [[\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Determine the size of the array\",\n",
      "                      \"2. Check if the size is even or odd\",\n",
      "                      \"3. Calculate the median based on the size\"],\n",
      "        \"test_case_generation_advise\": [\"Include test cases for arrays with even and odd lengths\",\n",
      "                                        \"Cover edge cases where the median is a single value or an average of two values\"]\n",
      "    }\n",
      "}\n",
      "<End>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "    \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "    \"output_format\": [[\"Float\", null]],\n",
      "    \"components\": [\"merge_arrays\", \"calculate_median\"],\n",
      "    \"plan\": [\"1. Merge the two arrays using component 'merge_arrays'\",\n",
      "             \"2. Calculate the median of the merged array using component 'calculate_median'\"],\n",
      "    \"test_case_generation_advise\": [\"Combine test cases for merging and calculating median\",\n",
      "                                    \"Ensure test coverage for different array sizes and median positions\"]\n",
      "}\n",
      "<End>\n"
     ]
    }
   ],
   "source": [
    "# task decompose prompt test\n",
    "\n",
    "TASK_DESCRIPTION = \"\"\"Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
    "\n",
    "The overall run time complexity should be O(log (m+n)).\n",
    "\n",
    "Example 1:\n",
    "\n",
    "Input: nums1 = [1,3], nums2 = [2]\n",
    "Output: 2.00000\n",
    "Explanation: merged array = [1,2,3] and median is 2.\n",
    "Example 2:\n",
    "\n",
    "Input: nums1 = [1,2], nums2 = [3,4]\n",
    "Output: 2.50000\n",
    "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
    " \n",
    "\n",
    "Constraints:\n",
    "\n",
    "nums1.length == m\n",
    "nums2.length == n\n",
    "0 <= m <= 1000\n",
    "0 <= n <= 1000\n",
    "1 <= m + n <= 2000\n",
    "-106 <= nums1[i], nums2[i] <= 106\"\"\"\n",
    "\n",
    "task_decompose_prompt = \"\"\"You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
    "\n",
    "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.\n",
    "\n",
    "Your output must strictly follow the format below:\n",
    "\n",
    "<components>\n",
    "{\n",
    "  \"component_1\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  \"component_2\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  ...\n",
    "}\n",
    "<End>\n",
    "\n",
    "<overall_plan>\n",
    "{\n",
    "  \"input_format\": [[type, shape or null]],\n",
    "  \"output_format\": [[type, shape or null]],\n",
    "  \"components\": [str],\n",
    "  \"plan\": [str],\n",
    "  \"test_case_generation_advise\": [str]\n",
    "}\n",
    "<End>\n",
    "\n",
    "Here are additional detailed explanations of each field:\n",
    "\n",
    "For <components>:\n",
    "- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., \"merge_arrays\", \"calculate_median\").\n",
    "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
    "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
    "  - The first element indicates the data type (e.g., list, dict, NumPy array, torch.Tensor).\n",
    "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
    "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`.\n",
    "- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "For <overall_plan>:\n",
    "- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.\n",
    "- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.\n",
    "- **components**: List the components in the order.\n",
    "- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
    "\n",
    "\"{{TASK_DESCRIPTION}}\"\n",
    "\n",
    "Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.\n",
    "\n",
    "Please provide your structured decomposition according to the instructions above.\n",
    "\"\"\"\n",
    "\n",
    "task_decompose_prompt = task_decompose_prompt.replace(\"{{TASK_DESCRIPTION}}\", TASK_DESCRIPTION)\n",
    "LLM_output = LLM_response(task_decompose_prompt, \"gpt-3.5-turbo\")\n",
    "print(\"input:\")\n",
    "print(task_decompose_prompt)\n",
    "print(\"output:\")\n",
    "print(LLM_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<components>\n",
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays and find the median\",\n",
      "        \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Merge the two input arrays into a single sorted array\",\n",
      "                      \"2. Find the median of the merged array\"],\n",
      "        \"test_case_generation_advise\": [\"Generate test cases where the arrays have different sizes\",\n",
      "                                        \"Include edge cases where the median is at the beginning, middle, or end of the merged array\"]\n",
      "    },\n",
      "    \"calculate_median\": {\n",
      "        \"step_task_description\": \"Calculate the median of a sorted array\",\n",
      "        \"input_format\": [[\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Determine the size of the array\",\n",
      "                      \"2. Check if the size is even or odd\",\n",
      "                      \"3. Calculate the median based on the size\"],\n",
      "        \"test_case_generation_advise\": [\"Include test cases for arrays with even and odd lengths\",\n",
      "                                        \"Cover edge cases where the median is a single value or an average of two values\"]\n",
      "    }\n",
      "}\n",
      "<End>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "    \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "    \"output_format\": [[\"Float\", null]],\n",
      "    \"components\": [\"merge_arrays\", \"calculate_median\"],\n",
      "    \"plan\": [\"1. Merge the two arrays using component 'merge_arrays'\",\n",
      "             \"2. Calculate the median of the merged array using component 'calculate_median'\"],\n",
      "    \"test_case_generation_advise\": [\"Combine test cases for merging and calculating median\",\n",
      "                                    \"Ensure test coverage for different array sizes and median positions\"]\n",
      "}\n",
      "<End>\n",
      "extract results:\n",
      "{'components': {'merge_arrays': {'step_task_description': 'Merge two sorted arrays and find the median', 'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Merge the two input arrays into a single sorted array', '2. Find the median of the merged array'], 'test_case_generation_advise': ['Generate test cases where the arrays have different sizes', 'Include edge cases where the median is at the beginning, middle, or end of the merged array']}, 'calculate_median': {'step_task_description': 'Calculate the median of a sorted array', 'input_format': [['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Determine the size of the array', '2. Check if the size is even or odd', '3. Calculate the median based on the size'], 'test_case_generation_advise': ['Include test cases for arrays with even and odd lengths', 'Cover edge cases where the median is a single value or an average of two values']}}, 'overall_plan': {'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'components': ['merge_arrays', 'calculate_median'], 'plan': [\"1. Merge the two arrays using component 'merge_arrays'\", \"2. Calculate the median of the merged array using component 'calculate_median'\"], 'test_case_generation_advise': ['Combine test cases for merging and calculating median', 'Ensure test coverage for different array sizes and median positions']}}\n",
      "dict_keys(['components', 'overall_plan'])\n",
      "components:\n",
      "{'merge_arrays': {'step_task_description': 'Merge two sorted arrays and find the median', 'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Merge the two input arrays into a single sorted array', '2. Find the median of the merged array'], 'test_case_generation_advise': ['Generate test cases where the arrays have different sizes', 'Include edge cases where the median is at the beginning, middle, or end of the merged array']}, 'calculate_median': {'step_task_description': 'Calculate the median of a sorted array', 'input_format': [['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Determine the size of the array', '2. Check if the size is even or odd', '3. Calculate the median based on the size'], 'test_case_generation_advise': ['Include test cases for arrays with even and odd lengths', 'Cover edge cases where the median is a single value or an average of two values']}}\n",
      "overall_plan:\n",
      "{'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'components': ['merge_arrays', 'calculate_median'], 'plan': [\"1. Merge the two arrays using component 'merge_arrays'\", \"2. Calculate the median of the merged array using component 'calculate_median'\"], 'test_case_generation_advise': ['Combine test cases for merging and calculating median', 'Ensure test coverage for different array sizes and median positions']}\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "import json\n",
    "\n",
    "def extract_multiple_jsons(input_str):\n",
    "    # 定义匹配JSON块的正则表达式（惰性匹配）\n",
    "    pattern = r'<(components|overall_plan)>(.*?)<End>'\n",
    "    \n",
    "    # 查找所有匹配的块\n",
    "    matches = re.findall(pattern, input_str, re.DOTALL)\n",
    "    \n",
    "    result = {}\n",
    "    for block_name, content in matches:\n",
    "        try:\n",
    "            # 清理内容：移除前后空格和换行符\n",
    "            cleaned_content = content.strip()\n",
    "            \n",
    "            # 修复常见JSON格式错误（如末尾逗号）\n",
    "            cleaned_content = re.sub(r',\\s*}', '}', cleaned_content)\n",
    "            cleaned_content = re.sub(r',\\s*\\]', ']', cleaned_content)\n",
    "            \n",
    "            # 解析JSON\n",
    "            parsed_data = json.loads(cleaned_content)\n",
    "            result[block_name] = parsed_data\n",
    "        except json.JSONDecodeError as e:\n",
    "            print(f\"解析错误：{block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}\")\n",
    "            result[block_name] = None\n",
    "    return result\n",
    "\n",
    "print(LLM_output)\n",
    "plan_extract_test = extract_multiple_jsons(LLM_output)\n",
    "print(\"extract results:\")\n",
    "print(plan_extract_test)\n",
    "print(plan_extract_test.keys())\n",
    "print(\"components:\")\n",
    "print(plan_extract_test[\"components\"])\n",
    "print(\"overall_plan:\")\n",
    "print(plan_extract_test[\"overall_plan\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### LLM-TM prompt test (edition)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.\n",
      "\n",
      "Carefully review the previous components and overall plan, then:\n",
      "1. Preserve correct/valid elements that don't conflict with the feedback\n",
      "2. Make targeted modifications based on the user's specific advice\n",
      "3. Ensure consistency between components and overall plan\n",
      "4. Verify input/output formats and workflow logic\n",
      "5. Check for any introduced errors during modification\n",
      "\n",
      "The input consists of three elements:\n",
      "- Original Task Description: \"Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
      "\n",
      "The overall run time complexity should be O(log (m+n)).\n",
      "\n",
      "Example 1:\n",
      "\n",
      "Input: nums1 = [1,3], nums2 = [2]\n",
      "Output: 2.00000\n",
      "Explanation: merged array = [1,2,3] and median is 2.\n",
      "Example 2:\n",
      "\n",
      "Input: nums1 = [1,2], nums2 = [3,4]\n",
      "Output: 2.50000\n",
      "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
      " \n",
      "\n",
      "Constraints:\n",
      "\n",
      "nums1.length == m\n",
      "nums2.length == n\n",
      "0 <= m <= 1000\n",
      "0 <= n <= 1000\n",
      "1 <= m + n <= 2000\n",
      "-106 <= nums1[i], nums2[i] <= 106\"\n",
      "- Previous Decomposition Output: \n",
      "<components>\n",
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays and find the median\",\n",
      "        \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Merge the two input arrays into a single sorted array\",\n",
      "                      \"2. Find the median of the merged array\"],\n",
      "        \"test_case_generation_advise\": [\"Generate test cases where the arrays have different sizes\",\n",
      "                                        \"Include edge cases where the median is at the beginning, middle, or end of the merged array\"]\n",
      "    },\n",
      "    \"calculate_median\": {\n",
      "        \"step_task_description\": \"Calculate the median of a sorted array\",\n",
      "        \"input_format\": [[\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Determine the size of the array\",\n",
      "                      \"2. Check if the size is even or odd\",\n",
      "                      \"3. Calculate the median based on the size\"],\n",
      "        \"test_case_generation_advise\": [\"Include test cases for arrays with even and odd lengths\",\n",
      "                                        \"Cover edge cases where the median is a single value or an average of two values\"]\n",
      "    }\n",
      "}\n",
      "<End>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "    \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "    \"output_format\": [[\"Float\", null]],\n",
      "    \"components\": [\"merge_arrays\", \"calculate_median\"],\n",
      "    \"plan\": [\"1. Merge the two arrays using component 'merge_arrays'\",\n",
      "             \"2. Calculate the median of the merged array using component 'calculate_median'\"],\n",
      "    \"test_case_generation_advise\": [\"Combine test cases for merging and calculating median\",\n",
      "                                    \"Ensure test coverage for different array sizes and median positions\"]\n",
      "}\n",
      "<End>\n",
      "- User Feedback: \"From now on, I want to get the largest number in the merged array\"\n",
      "\n",
      "Your output must STRICTLY follow the original format with these sections:\n",
      "<components>...<End>\n",
      "<overall_plan>...<End>\n",
      "\n",
      "Follow these guidelines:\n",
      "- Explicitly address all points in the user feedback\n",
      "- Clearly document any changes made from previous version\n",
      "- Preserve JSON structure and formatting requirements\n",
      "- If feedback contradicts original requirements, prioritize feedback\n",
      "\n",
      "Again, user feedback is: \"From now on, I want to get the largest number in the merged array\"\n",
      "\n",
      "Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback.\n",
      "output:\n",
      "<components>\n",
      "{\n",
      "    \"merge_arrays\": {\n",
      "        \"step_task_description\": \"Merge two sorted arrays and find the largest number\",\n",
      "        \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Merge the two input arrays into a single sorted array\",\n",
      "                      \"2. Find the largest number in the merged array\"],\n",
      "        \"test_case_generation_advise\": [\"Generate test cases where the arrays have different sizes\",\n",
      "                                       \"Include edge cases where the largest number is at the beginning, middle, or end of the merged array\"]\n",
      "    },\n",
      "    \"calculate_largest_number\": {\n",
      "        \"step_task_description\": \"Calculate the largest number in an array\",\n",
      "        \"input_format\": [[\"List\", null]],\n",
      "        \"output_format\": [[\"Float\", null]],\n",
      "        \"work_flow\": [\"1. Determine the size of the array\",\n",
      "                      \"2. Find the largest number in the array\"],\n",
      "        \"test_case_generation_advise\": [\"Include test cases for arrays with different sizes\",\n",
      "                                       \"Cover edge cases where the largest number is a single value or the average of two values\"]\n",
      "    }\n",
      "}\n",
      "<End>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "    \"input_format\": [[\"List\", null], [\"List\", null]],\n",
      "    \"output_format\": [[\"Float\", null]],\n",
      "    \"components\": [\"merge_arrays\", \"calculate_largest_number\"],\n",
      "    \"plan\": [\"1. Merge the two arrays using component 'merge_arrays'\",\n",
      "               \"2. Calculate the largest number in the merged array using component 'calculate_largest_number'\"],\n",
      "    \"test_case_generation_advise\": [\"Combine test cases for merging and calculating the largest number\",\n",
      "                                   \"Ensure test coverage for different array sizes and positions of the largest number\"]\n",
      "}\n",
      "<End>\n",
      "extract results:\n",
      "{'components': {'merge_arrays': {'step_task_description': 'Merge two sorted arrays and find the largest number', 'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Merge the two input arrays into a single sorted array', '2. Find the largest number in the merged array'], 'test_case_generation_advise': ['Generate test cases where the arrays have different sizes', 'Include edge cases where the largest number is at the beginning, middle, or end of the merged array']}, 'calculate_largest_number': {'step_task_description': 'Calculate the largest number in an array', 'input_format': [['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Determine the size of the array', '2. Find the largest number in the array'], 'test_case_generation_advise': ['Include test cases for arrays with different sizes', 'Cover edge cases where the largest number is a single value or the average of two values']}}, 'overall_plan': {'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'components': ['merge_arrays', 'calculate_largest_number'], 'plan': [\"1. Merge the two arrays using component 'merge_arrays'\", \"2. Calculate the largest number in the merged array using component 'calculate_largest_number'\"], 'test_case_generation_advise': ['Combine test cases for merging and calculating the largest number', 'Ensure test coverage for different array sizes and positions of the largest number']}}\n",
      "dict_keys(['components', 'overall_plan'])\n",
      "components:\n",
      "{'merge_arrays': {'step_task_description': 'Merge two sorted arrays and find the largest number', 'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Merge the two input arrays into a single sorted array', '2. Find the largest number in the merged array'], 'test_case_generation_advise': ['Generate test cases where the arrays have different sizes', 'Include edge cases where the largest number is at the beginning, middle, or end of the merged array']}, 'calculate_largest_number': {'step_task_description': 'Calculate the largest number in an array', 'input_format': [['List', None]], 'output_format': [['Float', None]], 'work_flow': ['1. Determine the size of the array', '2. Find the largest number in the array'], 'test_case_generation_advise': ['Include test cases for arrays with different sizes', 'Cover edge cases where the largest number is a single value or the average of two values']}}\n",
      "overall_plan:\n",
      "{'input_format': [['List', None], ['List', None]], 'output_format': [['Float', None]], 'components': ['merge_arrays', 'calculate_largest_number'], 'plan': [\"1. Merge the two arrays using component 'merge_arrays'\", \"2. Calculate the largest number in the merged array using component 'calculate_largest_number'\"], 'test_case_generation_advise': ['Combine test cases for merging and calculating the largest number', 'Ensure test coverage for different array sizes and positions of the largest number']}\n"
     ]
    }
   ],
   "source": [
    "plan_refinement_prompt = \"\"\"You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.\n",
    "\n",
    "Carefully review the previous components and overall plan, then:\n",
    "1. Preserve correct/valid elements that don't conflict with the feedback\n",
    "2. Make targeted modifications based on the user's specific advice\n",
    "3. Ensure consistency between components and overall plan\n",
    "4. Verify input/output formats and workflow logic\n",
    "5. Check for any introduced errors during modification\n",
    "\n",
    "The input consists of three elements:\n",
    "- Original Task Description: \"{{TASK_DESCRIPTION}}\"\n",
    "- Previous Decomposition Output: \n",
    "{{PREVIOUS_OUTPUT}}\n",
    "- User Feedback: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Your output must STRICTLY follow the original format with these sections:\n",
    "<components>...<End>\n",
    "<overall_plan>...<End>\n",
    "\n",
    "Follow these guidelines:\n",
    "- Explicitly address all points in the user feedback\n",
    "- Clearly document any changes made from previous version\n",
    "- Preserve JSON structure and formatting requirements\n",
    "- If feedback contradicts original requirements, prioritize feedback\n",
    "\n",
    "Again, user feedback is: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback.\"\"\"\n",
    "\n",
    "previous_LLM_output = \"\"\"<components>\n",
    "{\n",
    "    \"merge_arrays\": {\n",
    "        \"step_task_description\": \"Merge two sorted arrays and find the median\",\n",
    "        \"input_format\": [[\"List\", null], [\"List\", null]],\n",
    "        \"output_format\": [[\"Float\", null]],\n",
    "        \"work_flow\": [\"1. Merge the two input arrays into a single sorted array\",\n",
    "                      \"2. Find the median of the merged array\"],\n",
    "        \"test_case_generation_advise\": [\"Generate test cases where the arrays have different sizes\",\n",
    "                                        \"Include edge cases where the median is at the beginning, middle, or end of the merged array\"]\n",
    "    },\n",
    "    \"calculate_median\": {\n",
    "        \"step_task_description\": \"Calculate the median of a sorted array\",\n",
    "        \"input_format\": [[\"List\", null]],\n",
    "        \"output_format\": [[\"Float\", null]],\n",
    "        \"work_flow\": [\"1. Determine the size of the array\",\n",
    "                      \"2. Check if the size is even or odd\",\n",
    "                      \"3. Calculate the median based on the size\"],\n",
    "        \"test_case_generation_advise\": [\"Include test cases for arrays with even and odd lengths\",\n",
    "                                        \"Cover edge cases where the median is a single value or an average of two values\"]\n",
    "    }\n",
    "}\n",
    "<End>\n",
    "\n",
    "<overall_plan>\n",
    "{\n",
    "    \"input_format\": [[\"List\", null], [\"List\", null]],\n",
    "    \"output_format\": [[\"Float\", null]],\n",
    "    \"components\": [\"merge_arrays\", \"calculate_median\"],\n",
    "    \"plan\": [\"1. Merge the two arrays using component 'merge_arrays'\",\n",
    "             \"2. Calculate the median of the merged array using component 'calculate_median'\"],\n",
    "    \"test_case_generation_advise\": [\"Combine test cases for merging and calculating median\",\n",
    "                                    \"Ensure test coverage for different array sizes and median positions\"]\n",
    "}\n",
    "<End>\"\"\"\n",
    "\n",
    "user_advice = \"From now on, I want to get the largest number in the merged array\"\n",
    "\n",
    "plan_refinement_prompt = plan_refinement_prompt.replace(\"{{TASK_DESCRIPTION}}\", TASK_DESCRIPTION)\n",
    "plan_refinement_prompt = plan_refinement_prompt.replace(\"{{PREVIOUS_OUTPUT}}\", previous_LLM_output)\n",
    "plan_refinement_prompt = plan_refinement_prompt.replace(\"{{USER_ADVICE}}\", user_advice)\n",
    "\n",
    "LLM_output = LLM_response(plan_refinement_prompt, \"gpt-3.5-turbo\")\n",
    "print(\"input:\")\n",
    "print(plan_refinement_prompt)\n",
    "print(\"output:\")\n",
    "print(LLM_output)\n",
    "\n",
    "plan_refinement_extract_test = extract_multiple_jsons(LLM_output)\n",
    "print(\"extract results:\")\n",
    "print(plan_refinement_extract_test)\n",
    "print(plan_refinement_extract_test.keys())\n",
    "print(\"components:\")\n",
    "print(plan_refinement_extract_test[\"components\"])\n",
    "print(\"overall_plan:\")\n",
    "print(plan_refinement_extract_test[\"overall_plan\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Test case generation test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
      "\n",
      "### Input Specifications:\n",
      "- **Task Description**:\n",
      "Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
      "\n",
      "The overall run time complexity should be O(log (m+n)).\n",
      "\n",
      "Example 1:\n",
      "\n",
      "Input: nums1 = [1,3], nums2 = [2]\n",
      "Output: 2.00000\n",
      "Explanation: merged array = [1,2,3] and median is 2.\n",
      "Example 2:\n",
      "\n",
      "Input: nums1 = [1,2], nums2 = [3,4]\n",
      "Output: 2.50000\n",
      "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
      " \n",
      "\n",
      "Constraints:\n",
      "\n",
      "nums1.length == m\n",
      "nums2.length == n\n",
      "0 <= m <= 1000\n",
      "0 <= n <= 1000\n",
      "1 <= m + n <= 2000\n",
      "-106 <= nums1[i], nums2[i] <= 106\n",
      "- **Input Format**: \n",
      "- Argument 1: list (no fixed shape)\n",
      "- Argument 2: list (no fixed shape)\n",
      "- **Output Format**: \n",
      "- Output 1: float (no fixed shape)\n",
      "- **Components Used**: merge_arrays\n",
      "- **Plan**: \n",
      "1. Merge the two sorted arrays using the 'merge_arrays' component.\n",
      "2. Calculate the median of the merged sorted array.\n",
      "3. Return the median value as output.\n",
      "- **Test Case Advise**: \n",
      "- Test with arrays of different sizes.\n",
      "- Test with negative values in arrays.\n",
      "- Test with arrays containing only one element.\n",
      "\n",
      "### Requirements:\n",
      "1. **Test Function Structure**:\n",
      "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):`).\n",
      "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
      "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
      "\n",
      "2. **Test Types** (use one of these for `test_type`):\n",
      "   - `correctness`: Validate output against expected results for specific inputs.\n",
      "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
      "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
      "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
      "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
      "\n",
      "3. **Test Case Diversity**:\n",
      "   - Cover all provided advisories.\n",
      "   - Include at least one test per advisory and one for each test type where applicable.\n",
      "\n",
      "### Output Format:\n",
      "Return a JSON dictionary with test cases in this structure:\n",
      "{\n",
      "  \"test_case_1\": {\n",
      "    \"purpose\": \"Briefly describe the test's purpose...\",\n",
      "    \"test_function\": \"Test function string...\",\n",
      "    \"test_type\": \"correctness|edge_case|runtime|component_check|error_handling\"\n",
      "  },\n",
      "  ...\n",
      "}\n",
      "\n",
      "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\n"
     ]
    }
   ],
   "source": [
    "def create_test_prompt(task_descr_str, task_spec, use_example=True, bulk=True):\n",
    "    \"\"\"\n",
    "    Generates a prompt (or list of prompts) for test case generation based on task specifications.\n",
    "    \n",
    "    Parameters:\n",
    "    - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.\n",
    "    - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.\n",
    "    \n",
    "    Returns:\n",
    "    - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).\n",
    "    \"\"\"\n",
    "    \n",
    "    # Helper function to generate the prompt text from a modified task specification\n",
    "    def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = \"\"):\n",
    "        advisory_list = \"\\n\".join([f\"- {advise}\" for advise in advisories])\n",
    "        prompt = f\"\"\"You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
    "\n",
    "### Input Specifications:\n",
    "- **Task Description**:\n",
    "{task_descr_str}\n",
    "- **Input Format**: \n",
    "{input_descr_str}\n",
    "- **Output Format**: \n",
    "{output_descr_str}\n",
    "- **Components Used**: {components_str}\n",
    "- **Plan**: \n",
    "{plan_str}\n",
    "- **Test Case Advise**: \n",
    "{advisory_list}\n",
    "\n",
    "### Requirements:\n",
    "1. **Test Function Structure**:\n",
    "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):`).\n",
    "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
    "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
    "\n",
    "2. **Test Types** (use one of these for `test_type`):\n",
    "   - `correctness`: Validate output against expected results for specific inputs.\n",
    "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
    "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
    "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
    "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
    "\n",
    "3. **Test Case Diversity**:\n",
    "   - Cover all provided advisories.\n",
    "   - Include at least one test per advisory and one for each test type where applicable.\n",
    "\n",
    "### Output Format:\n",
    "Return a JSON dictionary with test cases in this structure:\n",
    "{{\n",
    "  \"test_case_1\": {{\n",
    "    \"purpose\": \"Briefly describe the test's purpose...\",\n",
    "    \"test_function\": \"Test function string...\",\n",
    "    \"test_type\": \"correctness|edge_case|runtime|component_check|error_handling\"\n",
    "  }},\n",
    "  ...\n",
    "}}\n",
    "{example_text}\n",
    "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\"\"\"\n",
    "        return prompt\n",
    "\n",
    "    if use_example:\n",
    "      examples_text = \"\"\"\n",
    "### Example:\n",
    "{{\n",
    "  \"test_case_1\": {{\n",
    "    \"purpose\": \"Test with arrays of different sizes.\",\n",
    "    \"test_function\": \"def test_case(func):\\\\n    arr1 = [1, 3, 5]\\\\n    arr2 = [2, 4]\\\\n    merged = sorted(arr1 + arr2)\\\\n    expected = (merged[2] + merged[1]) / 2\\\\n    return func(arr1, arr2) == expected\",\n",
    "    \"test_type\": \"correctness\"\n",
    "  }},\n",
    "  \"test_case_2\": {{\n",
    "    \"purpose\": \"Check if 'merge_arrays' component is used.\",\n",
    "    \"test_function\": \"def test_case(func):\\\\n    import inspect\\\\n    source = inspect.getsource(func)\\\\n    return 'merge_arrays(' in source\",\n",
    "    \"test_type\": \"component_check\"\n",
    "  }}\n",
    "}}\n",
    "\"\"\"\n",
    "    else:\n",
    "      examples_text = \"\"\n",
    "\n",
    "    # Process input_format into a descriptive string\n",
    "    input_descr = []\n",
    "    for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):\n",
    "        shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "        input_descr.append(f\"- Argument {idx}: {dtype} ({shape_info})\")\n",
    "    input_descr_str = \"\\n\".join(input_descr)\n",
    "\n",
    "    # Process output_format into a descriptive string\n",
    "    output_descr = []\n",
    "    for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):\n",
    "        shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "        output_descr.append(f\"- Output {idx}: {dtype} ({shape_info})\")\n",
    "    output_descr_str = \"\\n\".join(output_descr)\n",
    "\n",
    "    # Process components and plan\n",
    "    components_str = \", \".join(task_spec['components'])\n",
    "    plan_str = \"\\n\".join(task_spec['plan'])\n",
    "\n",
    "    if bulk:\n",
    "        # Generate a single prompt with all advisories\n",
    "        advisories = task_spec['test_case_generation_advise']\n",
    "        return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)\n",
    "    else:\n",
    "        # Generate a list of prompts, each with a single advisory\n",
    "        prompts = []\n",
    "        for advise in task_spec['test_case_generation_advise']:\n",
    "            single_advisory = [advise]\n",
    "            prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)\n",
    "            prompts.append(prompt)\n",
    "        return prompts\n",
    "    \n",
    "test_plan_dict = {'input_format': [['list', None], ['list', None]], 'output_format': [['float', None]], 'components': ['merge_arrays'], 'plan': [\"1. Merge the two sorted arrays using the 'merge_arrays' component.\", '2. Calculate the median of the merged sorted array.', '3. Return the median value as output.'], 'test_case_generation_advise': ['Test with arrays of different sizes.', 'Test with negative values in arrays.', 'Test with arrays containing only one element.']}\n",
    "\n",
    "test_case_generation_prompt = create_test_prompt(TASK_DESCRIPTION, test_plan_dict, use_example=False, bulk=True)\n",
    "print(\"input:\")\n",
    "print(test_case_generation_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output:\n",
      "{\n",
      "    \"test_case_1\": {\n",
      "        \"purpose\": \"Test the function with input arrays of different sizes to validate the median calculation.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [1, 3, 5]\\\\n    nums2 = [2, 4, 6, 8]\\\\n    return func(nums1, nums2) == 4.0\",\n",
      "        \"test_type\": \"correctness\"\n",
      "    },\n",
      "    \"test_case_2\": {\n",
      "        \"purpose\": \"Test the function with negative values in input arrays to ensure correct calculation of the median.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [-3, -2, -1]\\\\n    nums2 = [-5, -4]\\\\n    return func(nums1, nums2) == -3.0\",\n",
      "        \"test_type\": \"correctness\"\n",
      "    },\n",
      "    \"test_case_3\": {\n",
      "        \"purpose\": \"Test the function with arrays containing only one element to verify handling of edge cases.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [2]\\\\n    nums2 = [3]\\\\n    return func(nums1, nums2) == 2.5\",\n",
      "        \"test_type\": \"edge_case\"\n",
      "    },\n",
      "    \"test_case_4\": {\n",
      "        \"purpose\": \"Test the function to ensure it meets the required time complexity.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    import time\\\\n    nums1 = [i for i in range(1000)]\\\\n    nums2 = [i for i in range(500, 1500)]\\\\n    start_time = time.time()\\\\n    func(nums1, nums2)\\\\n    end_time = time.time()\\\\n    return end_time - start_time < 0.001\",\n",
      "        \"test_type\": \"runtime\"\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# test case generation extract test\n",
    "LLM_output = LLM_response(test_case_generation_prompt, \"gpt-3.5-turbo\")\n",
    "print(\"output:\")\n",
    "print(LLM_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"test_case_1\": {\n",
      "        \"purpose\": \"Test the function with input arrays of different sizes to validate the median calculation.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [1, 3, 5]\\\\n    nums2 = [2, 4, 6, 8]\\\\n    return func(nums1, nums2) == 4.0\",\n",
      "        \"test_type\": \"correctness\"\n",
      "    },\n",
      "    \"test_case_2\": {\n",
      "        \"purpose\": \"Test the function with negative values in input arrays to ensure correct calculation of the median.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [-3, -2, -1]\\\\n    nums2 = [-5, -4]\\\\n    return func(nums1, nums2) == -3.0\",\n",
      "        \"test_type\": \"correctness\"\n",
      "    },\n",
      "    \"test_case_3\": {\n",
      "        \"purpose\": \"Test the function with arrays containing only one element to verify handling of edge cases.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    nums1 = [2]\\\\n    nums2 = [3]\\\\n    return func(nums1, nums2) == 2.5\",\n",
      "        \"test_type\": \"edge_case\"\n",
      "    },\n",
      "    \"test_case_4\": {\n",
      "        \"purpose\": \"Test the function to ensure it meets the required time complexity.\",\n",
      "        \"test_function\": \"def test_case(func):\\\\n    import time\\\\n    nums1 = [i for i in range(1000)]\\\\n    nums2 = [i for i in range(500, 1500)]\\\\n    start_time = time.time()\\\\n    func(nums1, nums2)\\\\n    end_time = time.time()\\\\n    return end_time - start_time < 0.001\",\n",
      "        \"test_type\": \"runtime\"\n",
      "    }\n",
      "}\n",
      "\n",
      "extract results:\n",
      "{'test_case_1': {'purpose': 'Test the function with input arrays of different sizes to validate the median calculation.', 'test_function': 'def test_case(func):\\\\n    nums1 = [1, 3, 5]\\\\n    nums2 = [2, 4, 6, 8]\\\\n    return func(nums1, nums2) == 4.0', 'test_type': 'correctness'}, 'test_case_2': {'purpose': 'Test the function with negative values in input arrays to ensure correct calculation of the median.', 'test_function': 'def test_case(func):\\\\n    nums1 = [-3, -2, -1]\\\\n    nums2 = [-5, -4]\\\\n    return func(nums1, nums2) == -3.0', 'test_type': 'correctness'}, 'test_case_3': {'purpose': 'Test the function with arrays containing only one element to verify handling of edge cases.', 'test_function': 'def test_case(func):\\\\n    nums1 = [2]\\\\n    nums2 = [3]\\\\n    return func(nums1, nums2) == 2.5', 'test_type': 'edge_case'}, 'test_case_4': {'purpose': 'Test the function to ensure it meets the required time complexity.', 'test_function': 'def test_case(func):\\\\n    import time\\\\n    nums1 = [i for i in range(1000)]\\\\n    nums2 = [i for i in range(500, 1500)]\\\\n    start_time = time.time()\\\\n    func(nums1, nums2)\\\\n    end_time = time.time()\\\\n    return end_time - start_time < 0.001', 'test_type': 'runtime'}}\n",
      "\n",
      "test function examples:\n",
      "test_case_1: \n",
      "def test_case(func):\n",
      "    nums1 = [1, 3, 5]\n",
      "    nums2 = [2, 4, 6, 8]\n",
      "    return func(nums1, nums2) == 4.0\n",
      "test_case_2: \n",
      "def test_case(func):\n",
      "    nums1 = [-3, -2, -1]\n",
      "    nums2 = [-5, -4]\n",
      "    return func(nums1, nums2) == -3.0\n",
      "test_case_3: \n",
      "def test_case(func):\n",
      "    nums1 = [2]\n",
      "    nums2 = [3]\n",
      "    return func(nums1, nums2) == 2.5\n",
      "test_case_4: \n",
      "def test_case(func):\n",
      "    import time\n",
      "    nums1 = [i for i in range(1000)]\n",
      "    nums2 = [i for i in range(500, 1500)]\n",
      "    start_time = time.time()\n",
      "    func(nums1, nums2)\n",
      "    end_time = time.time()\n",
      "    return end_time - start_time < 0.001\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "def extract_test_cases(llm_output):\n",
    "    \"\"\"\n",
    "    Extract test cases from LLM's JSON output.\n",
    "\n",
    "    The JSON may be bare, wrapped in a Markdown code fence, and/or surrounded by\n",
    "    explanatory prose; the first form that parses is used.\n",
    "\n",
    "    Args:\n",
    "        llm_output (str/dict): Raw text output from LLM containing JSON, or a dictionary\n",
    "            (a dictionary is returned unchanged, without validation)\n",
    "\n",
    "    Returns:\n",
    "        dict: Parsed test cases in dictionary format; every value is a dict holding\n",
    "        at least 'purpose', 'test_function' and 'test_type'\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no JSON can be parsed, the top level is not a JSON object,\n",
    "            or a test case is not a dict with the required fields\n",
    "    \"\"\"\n",
    "    # Handle if input is already a dictionary\n",
    "    if isinstance(llm_output, dict):\n",
    "        return llm_output\n",
    "\n",
    "    text = llm_output.strip()\n",
    "\n",
    "    # Candidate payloads, tried in order: the text as-is, the body of the first\n",
    "    # code fence (with or without a language tag), and the outermost {...} span.\n",
    "    # The last two cover replies where the model adds prose around the JSON.\n",
    "    candidates = [text]\n",
    "    fence_match = re.search(r'```[A-Za-z]*\\s*(.*?)\\s*```', text, flags=re.DOTALL)\n",
    "    if fence_match:\n",
    "        candidates.append(fence_match.group(1))\n",
    "    start, end = text.find('{'), text.rfind('}')\n",
    "    if start != -1 and end > start:\n",
    "        candidates.append(text[start:end + 1])\n",
    "\n",
    "    # Parse JSON\n",
    "    parse_error = None\n",
    "    for candidate in candidates:\n",
    "        try:\n",
    "            test_cases = json.loads(candidate)\n",
    "            break\n",
    "        except json.JSONDecodeError as e:\n",
    "            parse_error = e\n",
    "    else:\n",
    "        # No candidate parsed; report the error from the last (most trimmed) one.\n",
    "        raise ValueError(f\"Failed to parse JSON: {parse_error}\") from parse_error\n",
    "\n",
    "    # Validate structure\n",
    "    if not isinstance(test_cases, dict):\n",
    "        raise ValueError(\n",
    "            f\"Expected a JSON object of test cases, got {type(test_cases).__name__}\"\n",
    "        )\n",
    "    required_fields = ('purpose', 'test_function', 'test_type')\n",
    "    for key, value in test_cases.items():\n",
    "        # A string value would pass a bare `k in value` check by substring match,\n",
    "        # so insist on a dict before looking for the fields.\n",
    "        if not isinstance(value, dict) or not all(k in value for k in required_fields):\n",
    "            raise ValueError(f\"Invalid test case structure in key: {key}\")\n",
    "\n",
    "    return test_cases\n",
    "\n",
    "# Smoke test: parse the raw reply held in LLM_output and print each stage.\n",
    "print(LLM_output)\n",
    "test_case_extract_test = extract_test_cases(LLM_output)\n",
    "print(\"\\nextract results:\")\n",
    "print(test_case_extract_test)\n",
    "print(\"\\ntest function examples:\")\n",
    "for key, value in test_case_extract_test.items():\n",
    "    # The parsed test_function still contains literal backslash-n sequences (the reply\n",
    "    # double-escaped its newlines), so convert them to real newlines before printing.\n",
    "    test_function_i = value['test_function'].replace(\"\\\\n\", \"\\n\")\n",
    "    print(f\"{key}: \\n{test_function_i}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Code generation prompt test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _format_io_section(header, item_label, io_format):\n",
    "    \"\"\"Render an iterable of (dtype, shape) pairs as a bulleted prompt section.\"\"\"\n",
    "    lines = []\n",
    "    for idx, (dtype, shape) in enumerate(io_format, 1):\n",
    "        shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "        lines.append(f\"- {item_label} {idx}: {dtype} with {shape_str}\")\n",
    "    return f\"{header}:\\n\" + \"\\n\".join(lines)\n",
    "\n",
    "\n",
    "def _unescape_newlines(code):\n",
    "    \"\"\"Turn literal backslash-n sequences into real newlines when the text has no real ones.\"\"\"\n",
    "    # Only rewrite text that is entirely on one line; code that already has real\n",
    "    # newlines may legitimately contain a backslash-n inside a string literal.\n",
    "    if isinstance(code, str) and \"\\n\" not in code and \"\\\\n\" in code:\n",
    "        return code.replace(\"\\\\n\", \"\\n\")\n",
    "    return code\n",
    "\n",
    "\n",
    "def create_code_generation_prompt(\n",
    "    components,\n",
    "    overall_plan,\n",
    "    task_description=None,\n",
    "    test_cases=None,\n",
    "    history=None,\n",
    "    next_code_line=False,\n",
    "    output_planning=False,\n",
    "    use_example=False,\n",
    "    use_task_description=False,\n",
    "    use_system_prompt=True,\n",
    "    more_comments=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Build the text prompt that asks the LLM to generate code for a planned task.\n",
    "\n",
    "    Sections are emitted in this order, each only when its input is present/enabled:\n",
    "    role, task description, components, overall plan, test cases, previous attempts,\n",
    "    instructions (always).\n",
    "\n",
    "    Args:\n",
    "        components (dict): {component_name: details}. Each details dict is read for\n",
    "            'step_task_description', 'input_format', 'output_format', 'work_flow' and\n",
    "            'test_case_generation_advise'. The two formats are iterables of\n",
    "            (dtype, shape) pairs; a shape of None is rendered as \"no fixed shape\".\n",
    "        overall_plan (dict): Read for 'input_format', 'output_format', 'components',\n",
    "            'plan' and 'test_case_generation_advise'.\n",
    "        task_description (str, optional): Included only when use_task_description is True.\n",
    "        test_cases (dict, optional): {name: {'purpose', 'test_type', 'test_function'}};\n",
    "            included only when use_example is True. A test function stored on a single\n",
    "            line with literal backslash-n sequences is shown with real newlines.\n",
    "        history (dict, optional): {name: {'score', 'generated_code', 'generation_plan'}};\n",
    "            'generation_plan' is iterated as a sequence of steps.\n",
    "        next_code_line (bool): Ask for only the next line/snippet instead of complete code.\n",
    "        output_planning (bool): Also ask for a <Planning> section after the code.\n",
    "        use_example (bool): Include the test cases section.\n",
    "        use_task_description (bool): Include the task description section.\n",
    "        use_system_prompt (bool): Prepend the role description.\n",
    "        more_comments (bool): Ask the model to comment its code heavily.\n",
    "\n",
    "    Returns:\n",
    "        str: The prompt, with all parts joined by newlines.\n",
    "    \"\"\"\n",
    "    prompt_parts = []\n",
    "\n",
    "    system_prompt = \"You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined.\"\n",
    "\n",
    "    # Add System Prompt if enabled\n",
    "    if use_system_prompt:\n",
    "        prompt_parts.append(f\"=== Role ===\\n{system_prompt}\\n\")\n",
    "\n",
    "    # Add Task Description if enabled\n",
    "    if use_task_description and task_description:\n",
    "        prompt_parts.append(f\"=== Task Description ===\\n{task_description}\\n\")\n",
    "\n",
    "    # Add Components Section\n",
    "    if components:\n",
    "        prompt_parts.append(\"=== Components ===\")\n",
    "        for comp_name, comp_details in components.items():\n",
    "            prompt_parts.extend([\n",
    "                f\"\\n**Component: {comp_name}**\",\n",
    "                f\"Step Task Description: {comp_details['step_task_description']}\",\n",
    "                _format_io_section(\"Input Format\", \"Argument\", comp_details[\"input_format\"]),\n",
    "                _format_io_section(\"Output Format\", \"Output\", comp_details[\"output_format\"]),\n",
    "                \"Workflow Steps:\",\n",
    "                *[f\"- {step}\" for step in comp_details[\"work_flow\"]],\n",
    "                \"Test Case Generation Advice:\",\n",
    "                *[f\"- {advice}\" for advice in comp_details[\"test_case_generation_advise\"]],\n",
    "                \"\\n\",\n",
    "            ])\n",
    "\n",
    "    # Add Overall Plan Section\n",
    "    if overall_plan:\n",
    "        prompt_parts.append(\"\\n=== Overall Plan ===\")\n",
    "        prompt_parts.extend([\n",
    "            _format_io_section(\"Input Format\", \"Argument\", overall_plan[\"input_format\"]),\n",
    "            _format_io_section(\"Output Format\", \"Output\", overall_plan[\"output_format\"]),\n",
    "            f\"Components Order: {', '.join(overall_plan['components'])}\",\n",
    "            \"Plan Steps:\",\n",
    "            *[f\"- {step}\" for step in overall_plan[\"plan\"]],\n",
    "            \"Overall Test Case Advice:\",\n",
    "            *[f\"- {advice}\" for advice in overall_plan[\"test_case_generation_advise\"]],\n",
    "            \"\\n\",\n",
    "        ])\n",
    "\n",
    "    # Add Test Cases if enabled and available\n",
    "    if use_example and test_cases:\n",
    "        prompt_parts.append(\"\\n=== Test Cases ===\")\n",
    "        for case_name, case_details in test_cases.items():\n",
    "            # Show the test as readable multi-line code rather than one escaped line.\n",
    "            test_function = _unescape_newlines(case_details['test_function'])\n",
    "            prompt_parts.extend([\n",
    "                f\"\\n**Test Case: {case_name}**\",\n",
    "                f\"Purpose: {case_details['purpose']}\",\n",
    "                f\"Type: {case_details['test_type']}\",\n",
    "                f\"Test Function:\\n{test_function}\",\n",
    "                \"\\n\",\n",
    "            ])\n",
    "\n",
    "    # Add History if available\n",
    "    if history:\n",
    "        prompt_parts.append(\"\\n=== Previous Generation Attempts ===\")\n",
    "        for gen_name, gen_details in history.items():\n",
    "            prompt_parts.extend([\n",
    "                f\"\\n**Generation: {gen_name}**\",\n",
    "                f\"Score: {gen_details['score']}\",\n",
    "                \"Generated Code:\",\n",
    "                gen_details[\"generated_code\"],\n",
    "                \"Generation Plan:\",\n",
    "                *[f\"- {step}\" for step in gen_details[\"generation_plan\"]],\n",
    "                \"\\n\",\n",
    "            ])\n",
    "\n",
    "    # Build Instructions\n",
    "    instructions = [\"\\n=== Instructions ===\"]\n",
    "    if next_code_line:\n",
    "        instructions.append(\"Generate ONLY the next line or a small code snippet required to proceed.\")\n",
    "    else:\n",
    "        instructions.append(\"Generate the COMPLETE code based on the components and plan above.\")\n",
    "\n",
    "    if more_comments:\n",
    "        instructions.append(\"Please add as much comments as possible to your code to explain the logic and any critical steps.\")\n",
    "\n",
    "    # The <Code> ... <End> skeleton is requested in both response layouts.\n",
    "    instructions.append(\"Structure your response as follows:\")\n",
    "    instructions.extend([\"<Code>\", \"Your code here\", \"<End>\"])\n",
    "    if output_planning:\n",
    "        instructions.append(\"<Planning>\")\n",
    "        if next_code_line:\n",
    "            instructions.append(\"A concise summary of what this specific code part accomplishes.\")\n",
    "        else:\n",
    "            instructions.append(\"A detailed step-by-step explanation of the code's workflow.\")\n",
    "        instructions.append(\"<End>\")\n",
    "        instructions.append(\"Provide the code with the same indicator and structure as shown in Instructions.\")\n",
    "    else:\n",
    "        instructions.append(\"Provide the code WITHOUT any additional explanations, and DO use the same indicator and structure as shown in Instructions.\")\n",
    "\n",
    "    prompt_parts.append(\"\\n\".join(instructions))\n",
    "\n",
    "    return \"\\n\".join(prompt_parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:\n",
      "=== Role ===\n",
      "You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined.\n",
      "\n",
      "=== Task Description ===\n",
      "Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.\n",
      "\n",
      "The overall run time complexity should be O(log (m+n)).\n",
      "\n",
      "Example 1:\n",
      "\n",
      "Input: nums1 = [1,3], nums2 = [2]\n",
      "Output: 2.00000\n",
      "Explanation: merged array = [1,2,3] and median is 2.\n",
      "Example 2:\n",
      "\n",
      "Input: nums1 = [1,2], nums2 = [3,4]\n",
      "Output: 2.50000\n",
      "Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.\n",
      " \n",
      "\n",
      "Constraints:\n",
      "\n",
      "nums1.length == m\n",
      "nums2.length == n\n",
      "0 <= m <= 1000\n",
      "0 <= n <= 1000\n",
      "1 <= m + n <= 2000\n",
      "-106 <= nums1[i], nums2[i] <= 106\n",
      "\n",
      "=== Components ===\n",
      "\n",
      "**Component: merge_arrays**\n",
      "Step Task Description: Merge two sorted arrays nums1 and nums2 and return the median.\n",
      "Input Format:\n",
      "- Argument 1: list with no fixed shape\n",
      "- Argument 2: list with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: float with no fixed shape\n",
      "Workflow Steps:\n",
      "- 1. Merge the two arrays into a single sorted array.\n",
      "- 2. Calculate the median of the merged array.\n",
      "Test Case Generation Advice:\n",
      "- Include test cases for edge cases like empty arrays, arrays with same numbers, uneven number of elements in arrays.\n",
      "\n",
      "\n",
      "\n",
      "=== Overall Plan ===\n",
      "Input Format:\n",
      "- Argument 1: list with no fixed shape\n",
      "- Argument 2: list with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: float with no fixed shape\n",
      "Components Order: merge_arrays\n",
      "Plan Steps:\n",
      "- 1. Call the merge_arrays component with two input arrays nums1 and nums2.\n",
      "- 2. Receive the merged array containing all elements from both arrays.\n",
      "- 3. Calculate the median of the merged array.\n",
      "- 4. Return the calculated median as the output.\n",
      "Overall Test Case Advice:\n",
      "- Generate test cases covering arrays with different lengths, arrays with overlapping values, and extreme boundary values.\n",
      "\n",
      "\n",
      "\n",
      "=== Test Cases ===\n",
      "\n",
      "**Test Case: test_case_1**\n",
      "Purpose: Test the function with input arrays of different sizes to validate the median calculation.\n",
      "Type: correctness\n",
      "Test Function:\n",
      "def test_case(func):\\n    nums1 = [1, 3, 5]\\n    nums2 = [2, 4, 6, 8]\\n    return func(nums1, nums2) == 4.0\n",
      "\n",
      "\n",
      "\n",
      "**Test Case: test_case_2**\n",
      "Purpose: Test the function with negative values in input arrays to ensure correct calculation of the median.\n",
      "Type: correctness\n",
      "Test Function:\n",
      "def test_case(func):\\n    nums1 = [-3, -2, -1]\\n    nums2 = [-5, -4]\\n    return func(nums1, nums2) == -3.0\n",
      "\n",
      "\n",
      "\n",
      "**Test Case: test_case_3**\n",
      "Purpose: Test the function with arrays containing only one element to verify handling of edge cases.\n",
      "Type: edge_case\n",
      "Test Function:\n",
      "def test_case(func):\\n    nums1 = [2]\\n    nums2 = [3]\\n    return func(nums1, nums2) == 2.5\n",
      "\n",
      "\n",
      "\n",
      "**Test Case: test_case_4**\n",
      "Purpose: Test the function to ensure it meets the required time complexity.\n",
      "Type: runtime\n",
      "Test Function:\n",
      "def test_case(func):\\n    import time\\n    nums1 = [i for i in range(1000)]\\n    nums2 = [i for i in range(500, 1500)]\\n    start_time = time.time()\\n    func(nums1, nums2)\\n    end_time = time.time()\\n    return end_time - start_time < 0.001\n",
      "\n",
      "\n",
      "\n",
      "=== Instructions ===\n",
      "Generate the COMPLETE code based on the components and plan above.\n",
      "Please add as much comments as possible to your code to explain the logic and any critical steps.\n",
      "Structure your response as follows:\n",
      "<Code>\n",
      "Your code here\n",
      "<End>\n",
      "<Planning>\n",
      "A detailed step-by-step explanation of the code's workflow.\n",
      "<End>\n",
      "Provide the code with the same indicator and structure as shown in Instructions.\n"
     ]
    }
   ],
   "source": [
    "# Code generation prompt test: build the prompt with the task description, test cases,\n",
    "# planning output and extra comments all enabled (no history is passed), then print it.\n",
    "# NOTE(review): plan_extract_test and TASK_DESCRIPTION are not defined in this cell -\n",
    "# presumably they come from earlier cells; test_case_extract_test comes from the\n",
    "# test-case extraction cell above.\n",
    "code_generation_prompt = create_code_generation_prompt(plan_extract_test[\"components\"], plan_extract_test[\"overall_plan\"], task_description=TASK_DESCRIPTION, test_cases=test_case_extract_test, next_code_line=False, output_planning=True,use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True)\n",
    "\n",
    "print(\"input:\")\n",
    "print(code_generation_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output:\n",
      "<Code>\n",
      "def merge_arrays(nums1, nums2):\n",
      "    # Merge the two sorted arrays into a single sorted array\n",
      "    merged_array = sorted(nums1 + nums2)\n",
      "    \n",
      "    # Calculate the median of the merged array\n",
      "    total_length = len(merged_array)\n",
      "    if total_length % 2 == 0:\n",
      "        median = (merged_array[total_length // 2 - 1] + merged_array[total_length // 2]) / 2\n",
      "    else:\n",
      "        median = merged_array[total_length // 2]\n",
      "    \n",
      "    return median\n",
      "\n",
      "<End>\n",
      "\n",
      "<Planning>\n",
      "The code above defines a function merge_arrays that takes two sorted arrays as input and returns the median of the merged sorted array.\n",
      "1. Merge the two input arrays nums1 and nums2 by concatenating them and sorting the resulting array.\n",
      "2. Calculate the median of the merged array by checking if the total length is even or odd:\n",
      "   - If it's even, calculate the average of the middle two elements.\n",
      "   - If it's odd, take the middle element as the median.\n",
      "3. Return the calculated median as the output of the function.\n",
      "</End>\n"
     ]
    }
   ],
   "source": [
    "# Code generation: send the prompt to the model and print its raw reply.\n",
    "# This overwrites LLM_output from the test-case generation cell, so cells that read\n",
    "# LLM_output see whichever of the two was run last.\n",
    "LLM_output = LLM_response(code_generation_prompt, \"gpt-3.5-turbo\")\n",
    "print(\"output:\")\n",
    "print(LLM_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "extract results:\n",
      "{'code': 'def merge_arrays(nums1, nums2):\\n    merged = sorted(nums1 + nums2)\\n    n = len(merged)\\n    if n % 2 == 0:\\n        return (merged[n//2 - 1] + merged[n//2]) / 2\\n    else:\\n        return merged[n//2]', 'plan': 'The code takes two sorted arrays, merges them into a single sorted array, and calculates the median of the merged array.\\n1. Merge the two arrays into a single sorted array.\\n2. Calculate the median of the merged array.\\n\\nExplanation:\\n- Merge the two sorted arrays [1, 2] and [3, 4] into [1, 2, 3, 4].\\n- Since the length of the merged array is even, the median is calculated as (2 + 3) / 2 = 2.5.'}\n",
      "\n",
      "code:\n",
      "def merge_arrays(nums1, nums2):\n",
      "    merged = sorted(nums1 + nums2)\n",
      "    n = len(merged)\n",
      "    if n % 2 == 0:\n",
      "        return (merged[n//2 - 1] + merged[n//2]) / 2\n",
      "    else:\n",
      "        return merged[n//2]\n",
      "\n",
      "plan:\n",
      "The code takes two sorted arrays, merges them into a single sorted array, and calculates the median of the merged array.\n",
      "1. Merge the two arrays into a single sorted array.\n",
      "2. Calculate the median of the merged array.\n",
      "\n",
      "Explanation:\n",
      "- Merge the two sorted arrays [1, 2] and [3, 4] into [1, 2, 3, 4].\n",
      "- Since the length of the merged array is even, the median is calculated as (2 + 3) / 2 = 2.5.\n"
     ]
    }
   ],
   "source": [
    "# code generation extract test\n",
    "import re\n",
    "\n",
    "def extract_code(llm_output):\n",
    "    \"\"\"Extracts code and planning sections from LLM output.\n",
    "\n",
    "    A section is opened by <Code> or <Planning> and is expected to be closed by <End>.\n",
    "    Models do not always follow that exactly, so a section is also considered closed by\n",
    "    </End> or by its own closing tag (</Code> / </Planning>).\n",
    "\n",
    "    Args:\n",
    "        llm_output (str): Raw text returned by the model.\n",
    "\n",
    "    Returns:\n",
    "        dict: {\"code\": ..., \"plan\": ...}. Each value is the stripped section text, or\n",
    "        None when that section is missing (or when llm_output is not a string).\n",
    "    \"\"\"\n",
    "    result = {\"code\": None, \"plan\": None}\n",
    "    if not isinstance(llm_output, str):\n",
    "        return result\n",
    "\n",
    "    for key, tag in ((\"code\", \"Code\"), (\"plan\", \"Planning\")):\n",
    "        # Non-greedy body, ended by the first accepted closing marker after the opener.\n",
    "        section_match = re.search(rf'<{tag}>(.*?)(?:</?End>|</{tag}>)', llm_output, re.DOTALL)\n",
    "        if section_match:\n",
    "            result[key] = section_match.group(1).strip()\n",
    "\n",
    "    return result\n",
    "\n",
    "# Smoke test: split the raw reply held in LLM_output into its code and plan parts\n",
    "# and print the combined dict followed by each part on its own.\n",
    "code_extract_test = extract_code(LLM_output)\n",
    "print(\"\\nextract results:\")\n",
    "print(code_extract_test)\n",
    "print(\"\\ncode:\")\n",
    "print(code_extract_test[\"code\"])\n",
    "print(\"\\nplan:\")\n",
    "print(code_extract_test[\"plan\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### code refinement test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_code_refinement_prompt(\n",
    "    components,\n",
    "    overall_plan,\n",
    "    task_description=None,\n",
    "    test_cases=None,\n",
    "    history=None,\n",
    "    next_code_line=False,\n",
    "    output_planning=False,\n",
    "    use_example=False,\n",
    "    use_task_description=False,\n",
    "    use_system_prompt=True,\n",
    "    more_comments=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Build the prompt that asks the LLM to revise earlier code in light of feedback.\n",
    "\n",
    "    Sections appear in this order, each only when its input is present/enabled:\n",
    "    role, feedback analysis, component specifications, implementation history,\n",
    "    test case status, refinement requirements (always).\n",
    "\n",
    "    Args:\n",
    "        components (dict): {component_name: details}; each details dict is read for\n",
    "            'step_task_description', 'input_format', 'output_format' and 'work_flow'\n",
    "            (only the first three workflow steps are shown). A falsy shape is shown as 'any'.\n",
    "        overall_plan: Not referenced by this function.\n",
    "        task_description (str, optional): Feedback text; shown only when\n",
    "            use_task_description is True.\n",
    "        test_cases (dict, optional): {name: details}; 'passed' and 'failure_reason' are\n",
    "            read when present. Shown only when use_example is True.\n",
    "        history (dict, optional): {version: {'score', 'generated_code'}} with an optional\n",
    "            'feedback_issues' list.\n",
    "        next_code_line (bool): Ask for only the next modification instead of a full revision.\n",
    "        output_planning (bool): Ask for the analysis/code/validation response layout.\n",
    "        use_example (bool): Include the test case status section.\n",
    "        use_task_description (bool): Include the feedback analysis section.\n",
    "        use_system_prompt (bool): Prepend the role description.\n",
    "        more_comments (bool): Ask for comments explaining how feedback was addressed.\n",
    "\n",
    "    Returns:\n",
    "        str: The prompt, with all parts joined by newlines.\n",
    "    \"\"\"\n",
    "    role_text = (\n",
    "        \"You are a code refinement specialist designed to improve existing implementations based on specific feedback. \"\n",
    "        \"Analyze the provided feedback, identify areas for improvement, and modify the code while strictly maintaining \"\n",
    "        \"the required input/output formats and component specifications.\"\n",
    "    )\n",
    "\n",
    "    sections = []\n",
    "\n",
    "    if use_system_prompt:\n",
    "        sections.append(f\"=== Role ===\\n{role_text}\\n\")\n",
    "\n",
    "    # The feedback travels in the task_description argument.\n",
    "    if use_task_description and task_description:\n",
    "        sections.append(f\"=== Feedback Analysis ===\\n{task_description}\\n\")\n",
    "\n",
    "    # Component specifications, kept as a reference while modifying the code.\n",
    "    if components:\n",
    "        sections.append(\"=== Component Specifications ===\")\n",
    "        for name, spec in components.items():\n",
    "            spec_lines = [f\"\\n**Component: {name}**\", f\"Purpose: {spec['step_task_description']}\"]\n",
    "            spec_lines.append(\"Input Requirements:\")\n",
    "            for dtype, shape in spec[\"input_format\"]:\n",
    "                spec_lines.append(f\"- {dtype} with shape={shape if shape else 'any'}\")\n",
    "            spec_lines.append(\"Output Requirements:\")\n",
    "            for dtype, shape in spec[\"output_format\"]:\n",
    "                spec_lines.append(f\"- {dtype} with shape={shape if shape else 'any'}\")\n",
    "            spec_lines.append(\"Key Workflow Steps:\")\n",
    "            for step in spec[\"work_flow\"][:3]:  # only the first three steps are shown\n",
    "                spec_lines.append(f\"- {step}\")\n",
    "            spec_lines.append(\"\\n\")\n",
    "            sections += spec_lines\n",
    "\n",
    "    # Earlier versions together with the issues found in them.\n",
    "    if history:\n",
    "        sections.append(\"\\n=== Implementation History ===\")\n",
    "        for version, details in history.items():\n",
    "            version_lines = [f\"\\n**Version: {version}**\", f\"Assessment: Score {details['score']}/10\"]\n",
    "            version_lines.append(\"Identified Issues:\")\n",
    "            version_lines.extend(f\"- {issue}\" for issue in details.get('feedback_issues', []))\n",
    "            version_lines.append(\"Code Snapshot:\")\n",
    "            version_lines.append(f\"```python\\n{details['generated_code']}\\n```\")\n",
    "            version_lines.append(\"\\n\")\n",
    "            sections += version_lines\n",
    "\n",
    "    # Pass/fail status of each test case, when requested.\n",
    "    if use_example and test_cases:\n",
    "        sections.append(\"\\n=== Test Case Status ===\")\n",
    "        for case_name, case_details in test_cases.items():\n",
    "            status = 'PASS' if case_details.get('passed', False) else 'FAIL'\n",
    "            reason = case_details.get('failure_reason', 'N/A')\n",
    "            sections += [\n",
    "                f\"\\n**Test: {case_name}**\",\n",
    "                f\"Status: {status}\",\n",
    "                f\"Failure Context: {reason}\",\n",
    "                \"\\n\",\n",
    "            ]\n",
    "\n",
    "    # Closing requirements block; always present.\n",
    "    requirements = [\n",
    "        \"\\n=== Refinement Requirements ===\",\n",
    "        \"Generate a revised implementation that:\",\n",
    "        \"- Addresses all identified issues from the feedback analysis\",\n",
    "        \"- Maintains strict compliance with component specifications\",\n",
    "        \"- Preserves existing functionality that passed validation\",\n",
    "        (\n",
    "            \"\\nFocus ONLY on the next critical modification needed in the code flow.\"\n",
    "            if next_code_line\n",
    "            else \"\\nProvide a complete revised implementation with clear markers for changes.\"\n",
    "        ),\n",
    "    ]\n",
    "\n",
    "    if more_comments:\n",
    "        requirements += [\n",
    "            \"Include detailed comments explaining:\",\n",
    "            \"- How feedback items were addressed\",\n",
    "            \"- Any trade-offs considered in the modifications\",\n",
    "        ]\n",
    "\n",
    "    if output_planning:\n",
    "        requirements += [\n",
    "            \"\\nStructure your response with:\",\n",
    "            \"<Analysis> - Summary of key changes from feedback\",\n",
    "            \"<Code> - Revised implementation with change markers\",\n",
    "            \"<Validation> - Test cases to verify fixes\",\n",
    "            \"<End>\",\n",
    "        ]\n",
    "    else:\n",
    "        requirements += [\n",
    "            \"\\nStructure your response as:\",\n",
    "            \"<Code> - Revised implementation with brief change comments\",\n",
    "            \"<End>\",\n",
    "        ]\n",
    "\n",
    "    sections.append(\"\\n\".join(requirements))\n",
    "    return \"\\n\".join(sections)\n",
    "\n",
    "# code refinement prompt test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### code testing framework"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running tests: 100%|██████████| 4/4 [00:00<00:00, 4002.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Results:\n",
      "{'code_1': {'test_case_1': True, 'test_case_3': True, 'test_case_2': True, 'test_case_4': True}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import concurrent.futures\n",
    "from tqdm import tqdm\n",
    "\n",
    "def run_test(func_obj, test_func):\n",
    "    \"\"\"\n",
    "    Execute one test against one candidate.\n",
    "\n",
    "    Calls test_func(func_obj) and hands back whatever it returns (normally True/False).\n",
    "    Any Exception raised along the way counts as a failed test and yields False.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        outcome = test_func(func_obj)\n",
    "    except Exception:\n",
    "        outcome = False\n",
    "    return outcome\n",
    "\n",
    "def compile_code(code_str):\n",
    "    \"\"\"\n",
    "    Compiles code from a string and returns the first callable object found.\n",
    "    If no callable is found or an error occurs, returns None.\n",
    "\n",
    "    The code runs in one fresh namespace used as both its globals and its locals, so\n",
    "    the returned function can see the other top-level names of the same snippet\n",
    "    (module-level imports, helper functions, itself for recursion). With separate\n",
    "    globals/locals dicts those lookups fail with NameError when the function is called.\n",
    "\n",
    "    WARNING: this calls exec() on the given text. Only use it on code you are\n",
    "    prepared to run in this process; model-generated code is untrusted input.\n",
    "    \"\"\"\n",
    "    namespace = {}\n",
    "    try:\n",
    "        exec(code_str, namespace)\n",
    "    except Exception:\n",
    "        return None\n",
    "    for name, obj in namespace.items():\n",
    "        # exec() adds a __builtins__ entry to the globals dict; it is not user code.\n",
    "        if name == \"__builtins__\":\n",
    "            continue\n",
    "        if callable(obj):\n",
    "            return obj\n",
    "    return None\n",
    "\n",
    "def run_all_tests(functions, test_cases, max_workers=10):\n",
    "    \"\"\"\n",
    "    Runs each test case against each function concurrently with an upper limit on concurrency.\n",
    "\n",
    "    A (function, test) pair is recorded as False when either side fails to compile or\n",
    "    when the test raises; otherwise the value returned by the test is recorded.\n",
    "\n",
    "    :param functions: dict of {function_id: function_code_string}\n",
    "    :param test_cases: dict of {test_case_id: test_code_string}\n",
    "    :param max_workers: maximum number of concurrent tests\n",
    "    :return: dict of {function_id: {test_case_id: test_result}}, with both levels in\n",
    "        the order of the input dicts\n",
    "    \"\"\"\n",
    "    # Compile functions and test cases\n",
    "    compiled_functions = {fid: compile_code(code) for fid, code in functions.items()}\n",
    "    compiled_tests = {tid: compile_code(code) for tid, code in test_cases.items()}\n",
    "\n",
    "    # Pre-fill every slot with False in input order. This fixes the key order of the\n",
    "    # result (futures finish in arbitrary order) and doubles as the failure value for\n",
    "    # pairs that cannot be run because one side did not compile.\n",
    "    results = {fid: {tid: False for tid in compiled_tests} for fid in compiled_functions}\n",
    "\n",
    "    total_tests = len(compiled_functions) * len(compiled_tests)\n",
    "    # The progress bar is a context manager so it is closed even if something raises.\n",
    "    with tqdm(total=total_tests, desc=\"Running tests\") as pbar:\n",
    "        # Use ThreadPoolExecutor to run tests concurrently\n",
    "        # NOTE(review): tests that measure wall-clock time share the process with the\n",
    "        # other running tests, so tight timing thresholds may be noisy - confirm.\n",
    "        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "            futures = {}\n",
    "            for func_id, func_obj in compiled_functions.items():\n",
    "                for test_id, test_func in compiled_tests.items():\n",
    "                    # If either the function or test did not compile, keep the False default\n",
    "                    if func_obj is None or test_func is None:\n",
    "                        pbar.update(1)\n",
    "                        continue\n",
    "                    # Submit the test for execution\n",
    "                    future = executor.submit(run_test, func_obj, test_func)\n",
    "                    futures[future] = (func_id, test_id)\n",
    "\n",
    "            # As each future completes, record the result\n",
    "            for future in concurrent.futures.as_completed(futures):\n",
    "                func_id, test_id = futures[future]\n",
    "                try:\n",
    "                    result = future.result()\n",
    "                except Exception:\n",
    "                    result = False\n",
    "                results[func_id][test_id] = result\n",
    "                pbar.update(1)\n",
    "    return results\n",
    "\n",
    "\n",
    "test_cases = {\n",
    "    \"test_case_1\": \"\"\"def test_case(func):\n",
    "nums1 = [1, 3, 5]\n",
    "nums2 = [2, 4, 6, 8]\n",
    "return func(nums1, nums2) == 4.0\"\"\",\n",
    "    \"test_case_2\": \"\"\"def test_case(func):\n",
    "nums1 = [-3, -2, -1]\n",
    "nums2 = [-5, -4]\n",
    "return func(nums1, nums2) == -3.0\"\"\",\n",
    "    \"test_case_3\": \"\"\"def test_case(func):\n",
    "nums1 = [2]\n",
    "nums2 = [3]\n",
    "return func(nums1, nums2) == 2.5\"\"\",\n",
    "    \"test_case_4\": \"\"\"def test_case(func):\n",
    "import time\n",
    "nums1 = [i for i in range(1000)]\n",
    "nums2 = [i for i in range(500, 1500)]\n",
    "start_time = time.time()\n",
    "func(nums1, nums2)\n",
    "end_time = time.time()\n",
    "return end_time - start_time < 0.001\"\"\",\n",
    "}\n",
    "\n",
    "functions_to_test = {\n",
    "    'code_1': \"\"\"def merge_arrays(nums1, nums2):\n",
    "merged = sorted(nums1 + nums2)\n",
    "n = len(merged)\n",
    "if n % 2 == 0:\n",
    "    return (merged[n//2 - 1] + merged[n//2]) / 2\n",
    "else:\n",
    "    return merged[n//2]\"\"\"\n",
    "}\n",
    "\n",
    "results = run_all_tests(functions_to_test, test_cases, max_workers=5)\n",
    "print(\"Test Results:\")\n",
    "print(results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### pass rate predictor test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.data import Data, Batch\n",
    "from torch_geometric.nn import GATConv, global_add_pool\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from tree_sitter import Language, Parser\n",
    "\n",
    "# 1. 生产级配置（可通过YAML文件加载）\n",
    "class RewardConfig:\n",
    "    # 结构解析参数\n",
    "    ast_embed_dim = 128       # AST节点嵌入维度\n",
    "    max_ast_nodes = 512       # 最大AST节点数\n",
    "    \n",
    "    # 模型架构参数\n",
    "    gat_heads = 4            # GAT注意力头数\n",
    "    text_model = \"microsoft/codebert-base\"  # 预训练文本模型\n",
    "    fusion_dim = 256          # 多模态融合维度\n",
    "    \n",
    "    # 训练参数\n",
    "    batch_size = 16           # 考虑图数据的显存占用\n",
    "    grad_clip = 1.0           # 梯度裁剪阈值\n",
    "    mixed_precision = True    # 混合精度训练\n",
    "\n",
    "# 2. 增强型结构特征提取器\n",
    "class CodeStructureParser:\n",
    "    def __init__(self, lang='python'):\n",
    "        # 初始化多语言解析器\n",
    "        self.parser = Parser()\n",
    "        self.parser.set_language(Language('build/python.so', lang))\n",
    "        \n",
    "        # AST节点类型映射\n",
    "        self.node_types = {\n",
    "            'function_definition': 1,\n",
    "            'call': 2,\n",
    "            'identifier': 3,\n",
    "            # 扩展其他节点类型...\n",
    "        }\n",
    "\n",
    "    def parse_to_graph(self, code):\n",
    "        \"\"\"将代码解析为图结构数据\"\"\"\n",
    "        tree = self.parser.parse(bytes(code, 'utf-8'))\n",
    "        \n",
    "        # 提取AST节点特征和边关系\n",
    "        nodes, edges = [], []\n",
    "        self._traverse(tree.root_node, nodes, edges)\n",
    "        \n",
    "        # 构建PyG数据对象\n",
    "        node_tensor = torch.tensor(nodes, dtype=torch.long)\n",
    "        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "        \n",
    "        return Data(x=node_tensor, edge_index=edge_index)\n",
    "\n",
    "    def _traverse(self, node, nodes, edges, parent_id=-1):\n",
    "        \"\"\"递归遍历AST构建图结构\"\"\"\n",
    "        current_id = len(nodes)\n",
    "        nodes.append(self.node_types.get(node.type, 0))  # 节点类型编码\n",
    "        \n",
    "        if parent_id != -1:  # 添加父子边\n",
    "            edges.append([parent_id, current_id])\n",
    "        \n",
    "        for child in node.children:\n",
    "            self._traverse(child, nodes, edges, current_id)\n",
    "\n",
    "# 3. 双模态融合模型架构\n",
    "class CodeRewardModel(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        \n",
    "        # 文本语义分支（冻结底层）\n",
    "        self.text_encoder = AutoModel.from_pretrained(config.text_model)\n",
    "        for param in self.text_encoder.parameters():\n",
    "            param.requires_grad = False  # 冻结预训练层\n",
    "        self.text_proj = nn.Linear(768, 128)  # 降维适配\n",
    "        \n",
    "        # 结构特征分支\n",
    "        self.ast_embed = nn.Embedding(100, config.ast_embed_dim)\n",
    "        self.gat1 = GATConv(config.ast_embed_dim, 64, heads=config.gat_heads)\n",
    "        self.gat2 = GATConv(64*config.gat_heads, 32, heads=1)\n",
    "        \n",
    "        # 多模态融合\n",
    "        self.fusion = nn.Sequential(\n",
    "            nn.Linear(128+32, config.fusion_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.LayerNorm(config.fusion_dim)\n",
    "        )\n",
    "        \n",
    "        # 奖励预测头\n",
    "        self.scorer = nn.Sequential(\n",
    "            nn.Linear(config.fusion_dim, 64),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(64, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, text_inputs, graph_data):\n",
    "        # 文本语义编码\n",
    "        text_feats = self.text_encoder(**text_inputs).last_hidden_state[:, 0, :]\n",
    "        text_feats = self.text_proj(text_feats)\n",
    "        \n",
    "        # 图结构编码\n",
    "        x = self.ast_embed(graph_data.x)\n",
    "        x = self.gat1(x, graph_data.edge_index)\n",
    "        x = F.leaky_relu(x)\n",
    "        x = self.gat2(x, graph_data.edge_index)\n",
    "        graph_feats = global_add_pool(x, graph_data.batch)  # 图级特征\n",
    "        \n",
    "        # 特征融合\n",
    "        fused = self.fusion(torch.cat([text_feats, graph_feats], dim=-1))\n",
    "        return self.scorer(fused)\n",
    "\n",
    "class RewardTrainingFramework:\n",
    "    def __init__(self, config=RewardConfig()):\n",
    "        self.config = config\n",
    "        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        \n",
    "        # 初始化组件\n",
    "        self.parser = CodeStructureParser()\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model)\n",
    "        self.model = CodeRewardModel(config).to(self.device)\n",
    "        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-4)\n",
    "        self.scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)\n",
    "\n",
    "    def process_batch(self, codes):\n",
    "        graph_list, text_inputs = [], []\n",
    "        \n",
    "        # 并行处理结构解析和文本编码\n",
    "        for code in codes:\n",
    "            # 结构解析\n",
    "            graph = self.parser.parse_to_graph(code)\n",
    "            graph_list.append(graph)\n",
    "            \n",
    "            # 文本处理\n",
    "            text = self.tokenizer(\n",
    "                code, \n",
    "                padding='max_length', \n",
    "                truncation=True, \n",
    "                max_length=512,\n",
    "                return_tensors='pt'\n",
    "            )\n",
    "            text_inputs.append(text)\n",
    "        \n",
    "        # 批量图数据组装\n",
    "        graph_batch = Batch.from_data_list(graph_list).to(self.device)\n",
    "        \n",
    "        # 文本数据组装\n",
    "        text_batch = {\n",
    "            'input_ids': torch.cat([x['input_ids'] for x in text_inputs], dim=0).to(self.device),\n",
    "            'attention_mask': torch.cat([x['attention_mask'] for x in text_inputs], dim=0).to(self.device)\n",
    "        }\n",
    "        \n",
    "        return text_batch, graph_batch\n",
    "\n",
    "    def train_step(self, codes, scores):\n",
    "        \"\"\"混合精度训练步骤\"\"\"\n",
    "        self.model.train()\n",
    "        text_batch, graph_batch = self.process_batch(codes)\n",
    "        scores = torch.tensor(scores).float().to(self.device)\n",
    "        \n",
    "        with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):\n",
    "            preds = self.model(text_batch, graph_batch).squeeze()\n",
    "            loss = F.mse_loss(preds, scores)\n",
    "        \n",
    "        self.optimizer.zero_grad()\n",
    "        self.scaler.scale(loss).backward()\n",
    "        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)\n",
    "        self.scaler.step(self.optimizer)\n",
    "        self.scaler.update()\n",
    "        \n",
    "        return loss.item()\n",
    "\n",
    "    def predict(self, codes):\n",
    "        self.model.eval()\n",
    "        text_batch, graph_batch = self.process_batch(codes)\n",
    "        \n",
    "        with torch.no_grad(), torch.cuda.amp.autocast():\n",
    "            return self.model(text_batch, graph_batch).cpu().numpy()\n",
    "\n",
    "# 5. 训练流程示例\n",
    "if __name__ == \"__main__\":\n",
    "    framework = RewardTrainingFramework()\n",
    "    \n",
    "    # 模拟训练数据（生产环境应从数据库加载）\n",
    "    train_data = [\n",
    "        (\"def add(a, b): return a + b\", 0.92),\n",
    "        (\"for i in range(10): print(i)\", 0.65),\n",
    "        # 更多样本...\n",
    "    ]\n",
    "    \n",
    "    # 生产级训练循环\n",
    "    for epoch in range(10):\n",
    "        total_loss = 0\n",
    "        codes, scores = zip(*train_data)  # 实际应分批次加载\n",
    "        \n",
    "        # 模拟数据分批次\n",
    "        for i in tqdm(range(0, len(codes), framework.config.batch_size)):\n",
    "            batch_codes = codes[i:i+framework.config.batch_size]\n",
    "            batch_scores = scores[i:i+framework.config.batch_size]\n",
    "            \n",
    "            loss = framework.train_step(batch_codes, batch_scores)\n",
    "            total_loss += loss\n",
    "        \n",
    "        print(f\"Epoch {epoch} Loss: {total_loss/len(train_data):.4f}\")\n",
    "    \n",
    "    # 保存生产模型\n",
    "    torch.save(framework.model.state_dict(), \"reward_model.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
