{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 16:59:10,435 - asyncio - ERROR - Task exception was never retrieved\n",
      "future: <Task finished name='Task-9' coro=<LLMTM.get_test_cases_async() done, defined at C:\\Users\\Zihang Zeng\\AppData\\Local\\Temp\\ipykernel_85188\\3063914536.py:510> exception=ValueError('Failed to extract test cases, current llm_output:\\n', '<correctness>\\n<planning>\\nIn this test function, I will validate the output of the function that performs the operations on the input array `nums`. I will create a few test cases to ensure correctness, including sorted and unsorted arrays, to confirm that the function produces the expected number of operations correctly. The expected results will be determined based on the provided examples along with some additional cases that cover different scenarios.\\n</planning>\\n<code>\\ndef test_case(func):\\n    # Test case 1: Basic unsorted array\\n    nums1 = [5, 2, 3, 1]\\n    expected_output1 = 2  # Explanation from the prompt\\n    result1 = func(nums1)\\n    if result1 != expected_output1:\\n        return False\\n\\n    # Test case 2: Already sorted array\\n    nums2 = [1, 2, 2]\\n    expected_output2 = 0  # Already non-decreasing\\n    result2 = func(nums2)\\n    if result2 != expected_output2:\\n        return False\\n\\n    # Test case 3: Small unsorted array\\n    nums3 = [3, 1]\\n    expected_output3 = 1  # (3, 1) -> 4\\n    result3 = func(nums3)\\n    if result3 != expected_output3:\\n        return False\\n\\n    # Test case 4: All elements the same\\n    nums4 = [4, 4, 4, 4]\\n    expected_output4 = 0  # Already non-decreasing\\n    result4 = func(nums4)\\n    if result4 != expected_output4:\\n        return False\\n\\n    # Test case 5: Single element array\\n    nums5 = [42]\\n    expected_output5 = 0  # Already non-decreasing\\n    result5 = func(nums5)\\n    if result5 != expected_output5:\\n        return False\\n\\n    # Test case 6: Empty array\\n    nums6 = []\\n    expected_output6 = 0  # Already non-decreasing\\n    result6 = func(nums6)\\n    if result6 != expected_output6:\\n        return 
False\\n\\n    # Additional test case: two elements reversed\\n    nums7 = [2, 1]\\n    expected_output7 = 1  # (2, 1) -> 3\\n    result7 = func(nums7)\\n    if result7 != expected_output7:\\n        return False\\n\\n    return True\\n</code>')>\n",
      "Traceback (most recent call last):\n",
      "  File \"C:\\Users\\Zihang Zeng\\AppData\\Local\\Temp\\ipykernel_85188\\3063914536.py\", line 531, in get_test_cases_async\n",
      "    raise ValueError(\"Failed to extract test cases, current llm_output:\\n\", llm_output)\n",
      "ValueError: ('Failed to extract test cases, current llm_output:\\n', '<correctness>\\n<planning>\\nIn this test function, I will validate the output of the function that performs the operations on the input array `nums`. I will create a few test cases to ensure correctness, including sorted and unsorted arrays, to confirm that the function produces the expected number of operations correctly. The expected results will be determined based on the provided examples along with some additional cases that cover different scenarios.\\n</planning>\\n<code>\\ndef test_case(func):\\n    # Test case 1: Basic unsorted array\\n    nums1 = [5, 2, 3, 1]\\n    expected_output1 = 2  # Explanation from the prompt\\n    result1 = func(nums1)\\n    if result1 != expected_output1:\\n        return False\\n\\n    # Test case 2: Already sorted array\\n    nums2 = [1, 2, 2]\\n    expected_output2 = 0  # Already non-decreasing\\n    result2 = func(nums2)\\n    if result2 != expected_output2:\\n        return False\\n\\n    # Test case 3: Small unsorted array\\n    nums3 = [3, 1]\\n    expected_output3 = 1  # (3, 1) -> 4\\n    result3 = func(nums3)\\n    if result3 != expected_output3:\\n        return False\\n\\n    # Test case 4: All elements the same\\n    nums4 = [4, 4, 4, 4]\\n    expected_output4 = 0  # Already non-decreasing\\n    result4 = func(nums4)\\n    if result4 != expected_output4:\\n        return False\\n\\n    # Test case 5: Single element array\\n    nums5 = [42]\\n    expected_output5 = 0  # Already non-decreasing\\n    result5 = func(nums5)\\n    if result5 != expected_output5:\\n        return False\\n\\n    # Test case 6: Empty array\\n    nums6 = []\\n    expected_output6 = 0  # Already non-decreasing\\n    result6 = func(nums6)\\n    if result6 != expected_output6:\\n        return False\\n\\n    # Additional test case: two elements reversed\\n    nums7 = [2, 1]\\n    expected_output7 = 1  # (2, 1) -> 3\\n    result7 = func(nums7)\\n    if result7 != 
expected_output7:\\n        return False\\n\\n    return True\\n</code>')\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "import json\n",
    "import asyncio\n",
    "import random\n",
    "import os\n",
    "import logging\n",
    "import concurrent.futures\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "from tqdm.asyncio import tqdm_asyncio\n",
    "\n",
    "from openai import OpenAI\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "# Module-level logger; no handler is attached to it here, the handlers go on the root logger below.\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "# Record layout shared by both handlers.\n",
    "formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "\n",
    "# File handler: one timestamped log file per run under ./log, capturing DEBUG and above.\n",
    "log_dir = Path(\"log\")\n",
    "log_dir.mkdir(exist_ok=True)\n",
    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "log_filename = f\"{timestamp}.log\"\n",
    "log_path = log_dir / log_filename\n",
    "# Explicit UTF-8: log messages contain non-ASCII text and raw LLM output, which the\n",
    "# platform default encoding cannot always represent.\n",
    "file_handler = logging.FileHandler(log_path, encoding=\"utf-8\")\n",
    "file_handler.setLevel(logging.DEBUG)\n",
    "file_handler.setFormatter(formatter)\n",
    "\n",
    "# Console handler: INFO and above only.\n",
    "console_handler = logging.StreamHandler()\n",
    "console_handler.setLevel(logging.INFO)\n",
    "console_handler.setFormatter(formatter)\n",
    "\n",
    "# Attach both handlers to the root logger.\n",
    "# NOTE(review): both handlers already carry `formatter`, so `datefmt` presumably has no\n",
    "# effect here - confirm, or move it into the Formatter above.\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    "    handlers=[\n",
    "        file_handler,   # file output\n",
    "        console_handler # console output\n",
    "    ]\n",
    ")\n",
    "\n",
    "class LLMModel():\n",
    "    def __init__(self, api_key, model=\"gpt-3.5-turbo\"):\n",
    "        if api_key is None:\n",
    "            self.api_key = \"sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg\"\n",
    "        else:\n",
    "            self.api_key = api_key\n",
    "        self.model = model\n",
    "        self.client = OpenAI(api_key=self.api_key)\n",
    "        self.client_async = AsyncOpenAI(api_key=self.api_key)\n",
    "    \n",
    "    def LLM_response(self, prompt, gen_kwargs={}, model=None):\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "\n",
    "        if type(prompt) == str:\n",
    "            input_messages = [\n",
    "                {\"role\": \"user\", \"content\": prompt}\n",
    "                ]\n",
    "        elif type(prompt) == list:\n",
    "            input_messages = prompt\n",
    "        else:\n",
    "            logging.error(\"prompt must be a string or a list of messages, current type: \", type(prompt))\n",
    "            raise ValueError(\"prompt must be a string or a list of messages\")\n",
    "        \n",
    "        completion = self.client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=input_messages,\n",
    "            **gen_kwargs\n",
    "            )\n",
    "\n",
    "        return completion.choices[0].message.content\n",
    "    \n",
    "    async def LLM_response_async(self, prompt, gen_kwargs={}, model=None):\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "\n",
    "        if type(prompt) == str:\n",
    "            prompt = \" \".join(prompt)\n",
    "            input_messages = [\n",
    "                {\"role\": \"user\", \"content\": prompt}\n",
    "                ]\n",
    "        elif type(prompt) == list:\n",
    "            input_messages = prompt\n",
    "        else:\n",
    "            logging.error(\"prompt must be a string or a list of messages, current type: \", type(prompt))\n",
    "            raise ValueError(\"prompt must be a string or a list of messages\")\n",
    "        \n",
    "        completion = await self.client_async.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=input_messages,\n",
    "            **gen_kwargs\n",
    "            )\n",
    "        return completion.choices[0].message.content\n",
    "\n",
    "class LLMTM():\n",
    "    def __init__(self, task_description, LLM_model):\n",
    "        self.LLM_model = LLM_model\n",
    "        self.task_description = task_description\n",
    "\n",
    "    def create_plan_prompt(self, task_description=None):\n",
    "\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        task_decompose_prompt = \"\"\"You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
    "\n",
    "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.\n",
    "\n",
    "Your output must strictly follow the format below:\n",
    "\n",
    "<components>\n",
    "{\n",
    "  \"component_1\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  \"component_2\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[\"type\", shape or null]],\n",
    "    \"output_format\": [[\"type\", shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  ...\n",
    "}\n",
    "</components>\n",
    "\n",
    "<overall_plan>\n",
    "{\n",
    "  \"input_format\": [[\"type\", shape or null]],\n",
    "  \"output_format\": [[\"type\", shape or null]],\n",
    "  \"components\": [str],\n",
    "  \"plan\": [str],\n",
    "  \"test_case_generation_advise\": [str]\n",
    "}\n",
    "</overall_plan>\n",
    "\n",
    "Here are additional detailed explanations of each field:\n",
    "\n",
    "For <components>:\n",
    "- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., \"merge_arrays\", \"calculate_median\").\n",
    "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
    "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
    "  - The first element indicates the data type (e.g., \"list\", \"dict\", NumPy array, torch.Tensor). DO make sure the data type is a string.\n",
    "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
    "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`, note that it has to be a list of lists.\n",
    "- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "For <overall_plan>:\n",
    "- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.\n",
    "- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.\n",
    "- **components**: List the components in the order.\n",
    "- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
    "\n",
    "\"{{TASK_DESCRIPTION}}\"\n",
    "\n",
    "Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.\n",
    "\n",
    "Please provide your structured decomposition according to the instructions above.\n",
    "\"\"\"\n",
    "        task_decompose_prompt = task_decompose_prompt.replace(\"{{TASK_DESCRIPTION}}\", task_description)\n",
    "        return task_decompose_prompt\n",
    "    \n",
    "    def create_plan_refinement_prompt(self, user_feedback, previous_output, task_description=None):\n",
    "\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        plan_refinement_prompt = \"\"\"You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.\n",
    "\n",
    "Carefully review the previous components and overall plan, then:\n",
    "1. Preserve correct/valid elements that don't conflict with the feedback\n",
    "2. Make targeted modifications based on the user's specific advice\n",
    "3. Ensure consistency between components and overall plan\n",
    "4. Verify input/output formats and workflow logic\n",
    "5. Check for any introduced errors during modification\n",
    "\n",
    "The input consists of three elements:\n",
    "- Original Task Description: \"{{TASK_DESCRIPTION}}\"\n",
    "- Previous Decomposition Output: \n",
    "{{PREVIOUS_OUTPUT}}\n",
    "- User Feedback: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Your output must STRICTLY follow the original format with these sections:\n",
    "<components>...</components>\n",
    "<overall_plan>...</overall_plan>\n",
    "\n",
    "Follow these guidelines:\n",
    "- Explicitly address all points in the user feedback\n",
    "- Clearly document any changes made from previous version\n",
    "- Preserve JSON structure and formatting requirements\n",
    "- If feedback contradicts original requirements, prioritize feedback\n",
    "\n",
    "Again, user feedback is: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback.\"\"\"\n",
    "\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{TASK_DESCRIPTION}}\", task_description)\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{USER_ADVICE}}\", user_feedback)\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{PREVIOUS_OUTPUT}}\", previous_output)\n",
    "        return plan_refinement_prompt\n",
    "\n",
    "    def extract_plan(self, input_str):\n",
    "        # Updated regex pattern to match <tag>...</tag> format\n",
    "        pattern = r'<(components|overall_plan)>(.*?)</\\1>'\n",
    "        \n",
    "        # Find all matches, allowing multiline content\n",
    "        matches = re.findall(pattern, input_str, re.DOTALL)\n",
    "        \n",
    "        result = {}\n",
    "        for block_name, content in matches:\n",
    "            try:\n",
    "                # Strip whitespace\n",
    "                cleaned_content = content.strip()\n",
    "                \n",
    "                # Fix trailing commas\n",
    "                cleaned_content = re.sub(r',\\s*}', '}', cleaned_content)\n",
    "                cleaned_content = re.sub(r',\\s*\\]', ']', cleaned_content)\n",
    "                \n",
    "                # Parse JSON\n",
    "                parsed_data = json.loads(cleaned_content)\n",
    "                result[block_name] = parsed_data\n",
    "            except json.JSONDecodeError as e:\n",
    "                logging.warning(f\"JSON解析错误: {block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}\")\n",
    "                # print(f\"解析错误：{block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}\")\n",
    "                result[block_name] = None\n",
    "                return False\n",
    "        return result\n",
    "\n",
    "    def get_plan(self, task_description=None, gen_kwargs={}, max_retry=3):\n",
    "        retry_num=0\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        prompt = self.create_plan_prompt(task_description)\n",
    "        while retry_num <= max_retry:\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            extract_plan = self.extract_plan(llm_output)\n",
    "            if extract_plan:\n",
    "                break\n",
    "            else:\n",
    "                retry_num += 1\n",
    "                # print(f\"Failed to extract plan, retrying ({retry_num})...\")\n",
    "                logging.warning(f\"Failed to extract plan, retrying ({retry_num})...\")\n",
    "        if extract_plan is False:\n",
    "            # print(\"Failed to extract plan, current llm_output:\\n\", llm_output)\n",
    "            logging.error(f\"Failed to extract plan, current llm_output:\\n{llm_output}\")\n",
    "            raise ValueError(\"Failed to extract plan, current llm_output:\\n\", llm_output)\n",
    "        return extract_plan, llm_output\n",
    "    \n",
    "    async def get_plan_async(self, num_plan, task_description=None, gen_kwargs={}):\n",
    "        \"\"\"\n",
    "        Requests `num_plan` plans for the same prompt and parses each response.\n",
    "\n",
    "        Parameters:\n",
    "        - num_plan (int): Number of LLM requests to issue.\n",
    "        - task_description (str or None): Task to plan; self.task_description is used when None.\n",
    "        - gen_kwargs (dict): Passed through unchanged to LLM_model.LLM_response_async.\n",
    "\n",
    "        Returns:\n",
    "        - list: One extract_plan result per response; no retry is attempted here, so entries\n",
    "          may be whatever extract_plan returns for an unparsable response.\n",
    "        \"\"\"\n",
    "        description = self.task_description if task_description is None else task_description\n",
    "        prompt = self.create_plan_prompt(description)\n",
    "\n",
    "        # Build one request coroutine per plan, then await them together via tqdm_asyncio.gather.\n",
    "        pending = []\n",
    "        for _ in range(num_plan):\n",
    "            pending.append(self.LLM_model.LLM_response_async(prompt, gen_kwargs))\n",
    "        raw_outputs = await tqdm_asyncio.gather(*pending)\n",
    "\n",
    "        return list(map(self.extract_plan, raw_outputs))\n",
    "    \n",
    "    def refine_plan(self, user_feedback, previous_output, task_description=None, gen_kwargs={}, max_retry=3):\n",
    "        retry_num=0\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)\n",
    "        while retry_num <= max_retry:\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            extract_plan = self.extract_plan(llm_output)\n",
    "            if extract_plan:\n",
    "                break\n",
    "            retry_num += 1\n",
    "        if extract_plan is False:\n",
    "            raise ValueError(\"Failed to extract plan, current llm_output:\\n\", llm_output)\n",
    "        return extract_plan, llm_output\n",
    "    \n",
    "    async def refine_multi_plan(self, num_plan, user_feedback, previous_output, task_description=None, gen_kwargs={}):\n",
    "        \"\"\"\n",
    "        Requests `num_plan` refined plans for the same feedback and parses each response.\n",
    "\n",
    "        Parameters:\n",
    "        - num_plan (int): Number of LLM requests to issue.\n",
    "        - user_feedback (str): Feedback inserted into the refinement prompt.\n",
    "        - previous_output (str): Previous decomposition text inserted into the refinement prompt.\n",
    "        - task_description (str or None): Original task; self.task_description is used when None.\n",
    "        - gen_kwargs (dict): Passed through unchanged to LLM_model.LLM_response_async.\n",
    "\n",
    "        Returns:\n",
    "        - list: One extract_plan result per response, without any retry.\n",
    "        \"\"\"\n",
    "        description = self.task_description if task_description is None else task_description\n",
    "        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, description)\n",
    "\n",
    "        # Build one request coroutine per plan, then await them together via tqdm_asyncio.gather.\n",
    "        pending = []\n",
    "        for _ in range(num_plan):\n",
    "            pending.append(self.LLM_model.LLM_response_async(prompt, gen_kwargs))\n",
    "        raw_outputs = await tqdm_asyncio.gather(*pending)\n",
    "\n",
    "        return list(map(self.extract_plan, raw_outputs))\n",
    "    \n",
    "    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):\n",
    "        \"\"\"\n",
    "        Generates a prompt (or list of prompts) for test case generation based on task specifications.\n",
    "        \n",
    "        Parameters:\n",
    "        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.\n",
    "        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.\n",
    "        \n",
    "        Returns:\n",
    "        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).\n",
    "        \"\"\"\n",
    "        \n",
    "        # Helper function to generate the prompt text from a modified task specification\n",
    "        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = \"\"):\n",
    "            advisory_list = \"\\n\".join([f\"- {advise}\" for advise in advisories])\n",
    "            prompt = f\"\"\"You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
    "\n",
    "### Input Specifications:\n",
    "- **Task Description**:\n",
    "{task_descr_str}\n",
    "- **Input Format**: \n",
    "{input_descr_str}\n",
    "- **Output Format**: \n",
    "{output_descr_str}\n",
    "- **Components Used**: {components_str}\n",
    "- **Plan**: \n",
    "{plan_str}\n",
    "- **Test Case Advise**: \n",
    "{advisory_list}\n",
    "\n",
    "### Requirements:\n",
    "1. **Test Function Structure**:\n",
    "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):...`).\n",
    "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
    "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
    "\n",
    "2. **Test Types** (use one of these for indicating the test_type):\n",
    "   - `correctness`: Validate output against expected results for specific inputs.\n",
    "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
    "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
    "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
    "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
    "\n",
    "3. **Test Case Diversity**:\n",
    "   - Cover all provided advisories.\n",
    "   - Include at least one test per advisory and one for each test type where applicable.\n",
    "\n",
    "### Output Format:\n",
    "For each test case, you need to firstly define the Test Types to indicate what type of test case you are going to create and then give the reasoning and explanation of the test case. After that, generate the test function based on the your reasoning.\n",
    "\n",
    "For each test function, return with following structure:\n",
    "\n",
    "<Type>\n",
    "Pick one of correctness|edge_case|runtime|component_check|error_handling\n",
    "</Type>\n",
    "<Planning>\n",
    "Introduce how would you design the test function. Specify the purpose of the test function and the reasoning behind it. Explain step by step why your test case is correct and what is the expected output.\n",
    "</Planning>\n",
    "<Code>\n",
    "def test_case(func):\n",
    "    # Your test function code here\n",
    "</Code>\n",
    "\n",
    "If you are going to create multiple test cases, please separate them with <separator> tag.\n",
    "\n",
    "{example_text}\n",
    "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\n",
    "You MUST strictly follow the output format and structure. The generated test functions MUST be runnable function that use another python function as its parameter.\"\"\"\n",
    "            return prompt\n",
    "\n",
    "        if use_example:\n",
    "            examples_text = \"\"\n",
    "        else:\n",
    "            examples_text = \"\"\n",
    "\n",
    "        # Process input_format into a descriptive string\n",
    "        input_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_descr.append(f\"- Argument {idx}: {dtype} ({shape_info})\")\n",
    "        input_descr_str = \"\\n\".join(input_descr)\n",
    "\n",
    "        # Process output_format into a descriptive string\n",
    "        output_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_descr.append(f\"- Output {idx}: {dtype} ({shape_info})\")\n",
    "        output_descr_str = \"\\n\".join(output_descr)\n",
    "\n",
    "        # Process components and plan\n",
    "        components_str = \", \".join(task_spec['components'])\n",
    "        plan_str = \"\\n\".join(task_spec['plan'])\n",
    "\n",
    "        if bulk:\n",
    "            # Generate a single prompt with all advisories\n",
    "            advisories = task_spec['test_case_generation_advise']\n",
    "            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)\n",
    "        else:\n",
    "            # Generate a list of prompts, each with a single advisory\n",
    "            prompts = []\n",
    "            for advise in task_spec['test_case_generation_advise']:\n",
    "                single_advisory = [advise]\n",
    "                prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)\n",
    "                prompts.append(prompt)\n",
    "            return prompts\n",
    "\n",
    "    def extract_test_cases(self, output_text):\n",
    "        \"\"\"\n",
    "        Extracts test cases from LLM output text with flexible tag handling.\n",
    "        Supports case-insensitive tags, missing <Type> tags, and multi-separators.\n",
    "\n",
    "        Parameters:\n",
    "        - output_text (str): Raw LLM output containing one or more test-case sections.\n",
    "\n",
    "        Returns:\n",
    "        - dict or bool: Mapping 'test_case_{idx}' -> {'test_type', 'purpose', 'test_function'},\n",
    "          where idx is the 1-based position of the text block the case came from (blocks that\n",
    "          yield no case leave gaps in the numbering), or False when no test case was found.\n",
    "        \"\"\"\n",
    "        import re\n",
    "        test_cases = {}\n",
    "\n",
    "        def preprocess_text(text):\n",
    "            # Placeholder that stands in for newlines inside code blocks; chosen so that it\n",
    "            # should not collide with ordinary text\n",
    "            placeholder = \"###NL###\"\n",
    "            \n",
    "            # Replacement callback: swap every newline inside the matched code block for the placeholder\n",
    "            def repl_code(match):\n",
    "                block = match.group(0)\n",
    "                return block.replace(\"\\n\", placeholder)\n",
    "            \n",
    "            # Apply to <code>...</code> blocks (case-insensitive, spanning multiple lines)\n",
    "            text = re.sub(r'(<\\s*code\\s*>.*?</\\s*code\\s*>)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)\n",
    "            # Apply to ```python ... ``` fenced blocks\n",
    "            text = re.sub(r'(```python.*?```)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)\n",
    "            \n",
    "            # Any other block type that needs protecting can be handled the same way here\n",
    "            return text, placeholder\n",
    "\n",
    "        # Pre-process: hide newlines inside code blocks so blank lines in code do not split a block\n",
    "        modified_text, placeholder = preprocess_text(output_text)\n",
    "        \n",
    "        # Split into blocks on <separator> tags or on blank lines\n",
    "        # NOTE(review): only code blocks are protected above, so a blank line between the\n",
    "        # type / planning / code sections of one test case (or inside the planning text)\n",
    "        # also splits them into separate blocks - confirm this is intended.\n",
    "        split_pattern = r'(?:<\\s*/\\s*separator\\s*>|<\\s*separator\\s*>|<\\s*separator\\s*/>|\\n\\s*\\n\\s*)'\n",
    "        test_case_blocks = re.split(split_pattern, modified_text, flags=re.IGNORECASE)\n",
    "        test_case_blocks = [b.strip() for b in test_case_blocks if b.strip()]\n",
    "        \n",
    "        # Restore the hidden newlines in each block\n",
    "        test_case_blocks = [b.replace(placeholder, \"\\n\") for b in test_case_blocks]\n",
    "\n",
    "        # print(f\"Split into {len(test_case_blocks)} blocks\")\n",
    "        for idx, block in enumerate(test_case_blocks, 1):\n",
    "            # 1. Extract test_type\n",
    "            test_type = None\n",
    "            \n",
    "            # Case 1: explicit <type>value</type>\n",
    "            type_match = re.search(\n",
    "                r'<\\s*type\\s*>(.*?)<\\s*/\\s*type\\s*>', \n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            if type_match:\n",
    "                test_type = type_match.group(1).strip()\n",
    "            else:\n",
    "                # Case 2: fall back to the first tag whose name is not one of the known structural\n",
    "                # tags (e.g. <correctness>); the name is lower-cased\n",
    "                # NOTE(review): this scans the whole block, code included, so any '<name ...>'\n",
    "                # looking span in the code could be picked up as the type - verify.\n",
    "                known_tags = {'type', 'planning', 'code', 'reasoning', 'test_function', 'separator'}\n",
    "                for tag_match in re.finditer(r'<\\s*([^\\s>/]+)\\s*.*?>', block, re.IGNORECASE):\n",
    "                    tag_name = tag_match.group(1).lower()\n",
    "                    if tag_name not in known_tags:\n",
    "                        test_type = tag_name\n",
    "                        break  # take the first tag that is not in known_tags\n",
    "                \n",
    "            if not test_type:  # skip the block when no test_type was found\n",
    "                continue\n",
    "            \n",
    "            # 2. Extract reasoning (accepts both <planning> and <reasoning>); empty string if absent\n",
    "            reasoning_match = re.search(\n",
    "                r'<\\s*(?:reasoning|planning)\\s*>(.*?)<\\s*/\\s*(?:reasoning|planning)\\s*>',\n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            reasoning = reasoning_match.group(1).strip() if reasoning_match else \"\"\n",
    "            \n",
    "            # 3. Extract test_function (priority: <test_function> tag > <code> tag > standalone fenced block)\n",
    "            test_func = None\n",
    "            \n",
    "            # Check for a <test_function> tag\n",
    "            test_func_match = re.search(\n",
    "                r'<\\s*test_function\\s*>(.*?)<\\s*/\\s*test_function\\s*>',\n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            if test_func_match:\n",
    "                content = test_func_match.group(1).strip()\n",
    "                # A ```python fence inside the tag is unwrapped; otherwise the tag body is used as is\n",
    "                code_block = re.search(r'```python\\s*(.*?)\\s*```', content, re.DOTALL)\n",
    "                test_func = code_block.group(1).strip() if code_block else content\n",
    "            else:\n",
    "                # Check for a <code> tag\n",
    "                code_match = re.search(\n",
    "                    r'<\\s*code\\s*>(.*?)<\\s*/\\s*code\\s*>',\n",
    "                    block,\n",
    "                    re.IGNORECASE | re.DOTALL\n",
    "                )\n",
    "                if code_match:\n",
    "                    content = code_match.group(1).strip()\n",
    "                    # Same unwrapping of an inner ```python fence as above\n",
    "                    code_block = re.search(r'```python\\s*(.*?)\\s*```', content, re.DOTALL)\n",
    "                    test_func = code_block.group(1).strip() if code_block else content\n",
    "                else:\n",
    "                    # Check for a standalone fenced block (```python ... ```)\n",
    "                    code_block = re.search(r'```python\\s*(.*?)\\s*```', block, re.DOTALL)\n",
    "                    if code_block:\n",
    "                        test_func = code_block.group(1).strip()\n",
    "            \n",
    "            # Only blocks that provide both a type and a function body become test cases\n",
    "            if test_type and test_func:\n",
    "                test_cases[f'test_case_{idx}'] = {\n",
    "                    'test_type': test_type,\n",
    "                    'purpose': reasoning,\n",
    "                    'test_function': test_func\n",
    "                }\n",
    "        \n",
    "        if not test_cases:\n",
    "            # Return False when no test case could be extracted\n",
    "            return False\n",
    "\n",
    "        return test_cases\n",
    "    \n",
    "    def get_test_cases(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Ask the LLM for test cases, retrying until they can be extracted.\n",
    "\n",
    "        Args:\n",
    "            task_spec: Task specification forwarded to `create_test_prompt`.\n",
    "            task_description: Optional override; defaults to `self.task_description`.\n",
    "            use_example: Forwarded to `create_test_prompt`.\n",
    "            bulk: Forwarded to `create_test_prompt`.\n",
    "            gen_kwargs: Generation options forwarded to `LLM_model.LLM_response`.\n",
    "                A fresh empty dict is used when omitted.\n",
    "            max_retry: Number of additional attempts after the first one.\n",
    "\n",
    "        Returns:\n",
    "            The truthy result of `extract_test_cases` for the first usable reply.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt yielded extractable test cases.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            # A fresh dict per call: a `{}` default would be shared between calls.\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)\n",
    "\n",
    "        # Bound up front so that max_retry < 0 ends in ValueError, not UnboundLocalError.\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            test_cases = self.extract_test_cases(llm_output)\n",
    "            if test_cases:\n",
    "                return test_cases\n",
    "            logging.warning(f\"Failed to extract test cases, retrying ({attempt})...\")\n",
    "\n",
    "        logging.error(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "        # Single formatted message; two positional args made the message render as a tuple.\n",
    "        raise ValueError(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "    \n",
    "    async def get_test_cases_async(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Asynchronously ask the LLM for test cases, retrying until they can be extracted.\n",
    "\n",
    "        Args:\n",
    "            task_spec: Task specification forwarded to `create_test_prompt`.\n",
    "            task_description: Optional override; defaults to `self.task_description`.\n",
    "            use_example: Forwarded to `create_test_prompt`.\n",
    "            bulk: Forwarded to `create_test_prompt`.\n",
    "            gen_kwargs: Generation options forwarded to `LLM_model.LLM_response_async`.\n",
    "                A fresh empty dict is used when omitted.\n",
    "            max_retry: Number of additional attempts after the first one.\n",
    "\n",
    "        Returns:\n",
    "            The truthy result of `extract_test_cases` for the first usable reply.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt yielded extractable test cases.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            # A fresh dict per call: a `{}` default would be shared between calls.\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)\n",
    "\n",
    "        # Bound up front so that max_retry < 0 ends in ValueError, not UnboundLocalError.\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)\n",
    "            test_cases = self.extract_test_cases(llm_output)\n",
    "            if test_cases:\n",
    "                return test_cases\n",
    "            logging.warning(f\"Failed to extract test cases, retrying ({attempt})..., current llm_output:\\n{llm_output}\")\n",
    "\n",
    "        logging.error(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "        # Single formatted message; two positional args made the message render as a tuple.\n",
    "        raise ValueError(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "    \n",
    "class LLMCG():\n",
    "    \"\"\"Builds code-generation prompts, queries the LLM and parses its reply.\n",
    "\n",
    "    Attributes:\n",
    "        task_description: Default task description used when a call supplies none.\n",
    "        LLM_model: Object providing `LLM_response(prompt, gen_kwargs)` and\n",
    "            `LLM_response_async(prompt, gen_kwargs)`.\n",
    "    \"\"\"\n",
    "\n",
    "    # Upper bound on how many test cases are embedded into a single prompt.\n",
    "    MAX_PROMPT_TEST_CASES = 3\n",
    "\n",
    "    def __init__(self, task_description, LLM_model):\n",
    "        self.task_description = task_description\n",
    "        self.LLM_model = LLM_model\n",
    "\n",
    "    @staticmethod\n",
    "    def _format_io_section(header, item_label, io_format):\n",
    "        \"\"\"Render an iterable of (dtype, shape) pairs as a bulleted prompt section.\"\"\"\n",
    "        bullet_lines = []\n",
    "        for idx, (dtype, shape) in enumerate(io_format, 1):\n",
    "            shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "            bullet_lines.append(f\"- {item_label} {idx}: {dtype} with {shape_str}\")\n",
    "        return f\"{header}:\\n\" + \"\\n\".join(bullet_lines)\n",
    "\n",
    "    @staticmethod\n",
    "    def _strip_code_fence(text):\n",
    "        \"\"\"Return `text` without a surrounding markdown code fence, if it is exactly one fence.\"\"\"\n",
    "        fenced = re.match(r\"^```[A-Za-z0-9_+-]*[ \\t]*\\n(.*?)\\n?[ \\t]*```$\", text, re.DOTALL)\n",
    "        return fenced.group(1).strip() if fenced else text\n",
    "\n",
    "    def create_code_generation_prompt(\n",
    "        self,\n",
    "        extracted_plan,\n",
    "        user_feedback=None,\n",
    "        task_description=None,\n",
    "        test_cases=None,\n",
    "        history=None,\n",
    "        next_code_line=False,\n",
    "        output_planning=False,\n",
    "        use_example=False,\n",
    "        use_task_description=False,\n",
    "        use_system_prompt=True,\n",
    "        more_comments=False,\n",
    "        ):\n",
    "        \"\"\"Assemble the code-generation (or refinement) prompt as one string.\n",
    "\n",
    "        Args:\n",
    "            extracted_plan: Mapping with the keys \"components\" and \"overall_plan\".\n",
    "            user_feedback: When truthy, switches to the refinement role and adds\n",
    "                the feedback and refinement-requirement sections.\n",
    "            task_description: Optional override; falls back to `self.task_description`.\n",
    "            test_cases: Mapping of test-case name to a dict with the keys\n",
    "                \"purpose\", \"test_type\" and \"test_function\".\n",
    "            history: Mapping of generation name to a dict with the keys\n",
    "                \"score\", \"generated_code\" and \"generation_plan\".\n",
    "            next_code_line: Ask for only the next snippet instead of complete code.\n",
    "            output_planning: Request <Planning> and <Main Function Name> sections too.\n",
    "            use_example: Include up to `MAX_PROMPT_TEST_CASES` test cases.\n",
    "            use_task_description: Include the task description section.\n",
    "            use_system_prompt: Include the role section.\n",
    "            more_comments: Ask the model to comment its code heavily.\n",
    "\n",
    "        Returns:\n",
    "            The prompt sections joined with newlines.\n",
    "        \"\"\"\n",
    "        components = extracted_plan[\"components\"]\n",
    "        overall_plan = extracted_plan[\"overall_plan\"]\n",
    "\n",
    "        prompt_parts = []\n",
    "\n",
    "        if user_feedback:\n",
    "            system_prompt = \"You are a code refinement specialist designed to improve existing implementations based on specific feedback. Analyze the provided feedback, identify areas for improvement, and modify the code while strictly maintaining the required input/output formats and component specifications.\"\n",
    "        else:\n",
    "            system_prompt = \"You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined.\"\n",
    "\n",
    "        # Role section\n",
    "        if use_system_prompt:\n",
    "            prompt_parts.append(f\"=== Role ===\\n{system_prompt}\\n\")\n",
    "\n",
    "        # Task description section\n",
    "        if not task_description:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        if use_task_description:\n",
    "            prompt_parts.append(f\"=== Task Description ===\\n{task_description}\\n\")\n",
    "\n",
    "        # Components section\n",
    "        if components:\n",
    "            prompt_parts.append(\"=== Components ===\")\n",
    "            for comp_name, comp_details in components.items():\n",
    "                prompt_parts.extend([\n",
    "                    f\"\\n**Component: {comp_name}**\",\n",
    "                    f\"Step Task Description: {comp_details['step_task_description']}\",\n",
    "                    self._format_io_section(\"Input Format\", \"Argument\", comp_details[\"input_format\"]),\n",
    "                    self._format_io_section(\"Output Format\", \"Output\", comp_details[\"output_format\"]),\n",
    "                    \"Workflow Steps:\",\n",
    "                    *[f\"- {step}\" for step in comp_details[\"work_flow\"]],\n",
    "                    \"Test Case Generation Advice:\",\n",
    "                    *[f\"- {advice}\" for advice in comp_details[\"test_case_generation_advise\"]],\n",
    "                    \"\\n\",\n",
    "                ])\n",
    "\n",
    "        # Overall plan section\n",
    "        if overall_plan:\n",
    "            prompt_parts.append(\"\\n=== Overall Plan ===\")\n",
    "            prompt_parts.extend([\n",
    "                self._format_io_section(\"Input Format\", \"Argument\", overall_plan[\"input_format\"]),\n",
    "                self._format_io_section(\"Output Format\", \"Output\", overall_plan[\"output_format\"]),\n",
    "                f\"Components Order: {', '.join(overall_plan['components'])}\",\n",
    "                \"Plan Steps:\",\n",
    "                *[f\"- {step}\" for step in overall_plan[\"plan\"]],\n",
    "                \"Overall Test Case Advice:\",\n",
    "                *[f\"- {advice}\" for advice in overall_plan[\"test_case_generation_advise\"]],\n",
    "                \"\\n\",\n",
    "            ])\n",
    "\n",
    "        if user_feedback:\n",
    "            prompt_parts.append(\"\\n=== User Feedback ===\")\n",
    "            prompt_parts.append(user_feedback)\n",
    "\n",
    "        # Test cases section (only the first MAX_PROMPT_TEST_CASES entries are shown)\n",
    "        if use_example and test_cases:\n",
    "            prompt_parts.append(\"\\n=== Test Cases ===\")\n",
    "            shown_cases = list(test_cases.items())[:self.MAX_PROMPT_TEST_CASES]\n",
    "            for case_name, case_details in shown_cases:\n",
    "                prompt_parts.extend([\n",
    "                    f\"\\n**Test Case: {case_name}**\",\n",
    "                    f\"Purpose: {case_details['purpose']}\",\n",
    "                    f\"Type: {case_details['test_type']}\",\n",
    "                    f\"Test Function:\\n{case_details['test_function']}\",\n",
    "                    \"\\n\",\n",
    "                ])\n",
    "\n",
    "        # Previous attempts section\n",
    "        if history:\n",
    "            prompt_parts.append(\"\\n=== Previous Generation Attempts ===\")\n",
    "            for gen_name, gen_details in history.items():\n",
    "                prompt_parts.extend([\n",
    "                    f\"\\n**Generation: {gen_name}**\",\n",
    "                    f\"Score: {gen_details['score']}\",\n",
    "                    \"Generated Code:\",\n",
    "                    gen_details[\"generated_code\"],\n",
    "                    \"Generation Plan:\",\n",
    "                    *[f\"- {step}\" for step in gen_details[\"generation_plan\"]],\n",
    "                    \"\\n\",\n",
    "                ])\n",
    "\n",
    "        # Refinement requirements (feedback mode only)\n",
    "        if user_feedback:\n",
    "            refine_instructions = [\n",
    "                \"\\n=== Refinement Requirements ===\",\n",
    "                \"Generate a revised implementation that:\",\n",
    "                \"- Addresses all identified issues from the feedback analysis\",\n",
    "                \"- Maintains strict compliance with component specifications\",\n",
    "                \"- Preserves existing functionality that passed validation\",\n",
    "            ]\n",
    "            prompt_parts.append(\"\\n\".join(refine_instructions))\n",
    "\n",
    "        # Instructions and the expected response layout\n",
    "        instructions = [\"\\n=== Instructions ===\"]\n",
    "        if next_code_line:\n",
    "            instructions.append(\"Generate ONLY the next line or a small code snippet required to proceed.\")\n",
    "        else:\n",
    "            instructions.append(\"Generate the COMPLETE code based on the components and plan above.\")\n",
    "        instructions.append(\"DO MAKE SURE the complete code is a runnable function, all components are correctly integrated with in this function.\")\n",
    "        instructions.append(\"The complete function should take the input arguments as specified in the overall plan and return the output as specified.\")\n",
    "\n",
    "        if more_comments:\n",
    "            instructions.append(\"Please add as much comments as possible to your code to explain the logic and any critical steps.\")\n",
    "\n",
    "        if output_planning:\n",
    "            if next_code_line:\n",
    "                planning_hint = \"A concise summary of what this specific code part accomplishes.\"\n",
    "            else:\n",
    "                planning_hint = \"A detailed step-by-step explanation of the code's workflow.\"\n",
    "            instructions.extend([\n",
    "                \"Structure your response as follows:\",\n",
    "                \"<Code>\",\n",
    "                \"Your code here. DO make sure the output is a single function that integrates all components.\",\n",
    "                \"</Code>\",\n",
    "                \"<Planning>\",\n",
    "                planning_hint,\n",
    "                \"</Planning>\",\n",
    "                \"<Main Function Name>\",\n",
    "                \"The name of the main function that integrates all components.\",\n",
    "                \"</Main Function Name>\",\n",
    "                \"Provide the code with the same indicator and structure as shown in Instructions. DO NOT return any test cases or example usages in your code!\",\n",
    "            ])\n",
    "        else:\n",
    "            instructions.extend([\n",
    "                \"Structure your response as follows:\",\n",
    "                \"<Code>\",\n",
    "                \"Your code here\",\n",
    "                \"</Code>\",\n",
    "                \"Provide the code WITHOUT any additional explanations, and DO use the same indicator and structure as shown in Instructions.\",\n",
    "            ])\n",
    "\n",
    "        prompt_parts.append(\"\\n\".join(instructions))\n",
    "\n",
    "        return \"\\n\".join(prompt_parts)\n",
    "    \n",
    "    def extract_code(self, llm_output):\n",
    "        \"\"\"Extract the code, planning and main-function-name sections from an LLM reply.\n",
    "\n",
    "        Returns:\n",
    "            Dict with the keys \"code\", \"plan\" and \"main_function_name\"; a value is\n",
    "            None when the corresponding section is absent.\n",
    "        \"\"\"\n",
    "        result = {\"code\": None, \"plan\": None, \"main_function_name\": None}\n",
    "        \n",
    "        # Preferred source: the <Code> section.\n",
    "        code_match = re.search(r'<Code>(.*?)(?:</Code>|<End>)', llm_output, re.DOTALL)\n",
    "        if code_match:\n",
    "            # Models often put a markdown fence inside <Code>; keeping the fence\n",
    "            # would make the extracted text fail to compile.\n",
    "            result[\"code\"] = self._strip_code_fence(code_match.group(1).strip())\n",
    "        else:\n",
    "            # Fallback: the first markdown code fence anywhere in the reply.\n",
    "            code_block_match = re.search(r'```(?:python)?\\s*(.*?)```', llm_output, re.DOTALL)\n",
    "            if code_block_match:\n",
    "                result[\"code\"] = code_block_match.group(1).strip()\n",
    "        \n",
    "        plan_match = re.search(r'<Planning>(.*?)(?:</Planning>|<End>)', llm_output, re.DOTALL)\n",
    "        if plan_match:\n",
    "            result[\"plan\"] = plan_match.group(1).strip()\n",
    "\n",
    "        main_func_match = re.search(r'<Main Function Name>(.*?)(?:</Main Function Name>|<End>)', llm_output, re.DOTALL)\n",
    "        if main_func_match:\n",
    "            result[\"main_function_name\"] = main_func_match.group(1).strip()\n",
    "        \n",
    "        return result\n",
    "    \n",
    "    def get_code(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Generate code with the LLM, retrying until a code section can be extracted.\n",
    "\n",
    "        The prompt options are forwarded to `create_code_generation_prompt`; user\n",
    "        feedback is read from `extracted_plan.get('user_feedback')`.\n",
    "\n",
    "        Args:\n",
    "            gen_kwargs: Generation options forwarded to `LLM_model.LLM_response`.\n",
    "                A fresh empty dict is used when omitted.\n",
    "            max_retry: Number of additional attempts after the first one.\n",
    "\n",
    "        Returns:\n",
    "            The dict produced by `extract_code`, with a non-None \"code\" entry.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt yielded extractable code.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            # A fresh dict per call: a `{}` default would be shared between calls.\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_code_generation_prompt(\n",
    "            extracted_plan,\n",
    "            user_feedback=extracted_plan.get('user_feedback'),\n",
    "            task_description=task_description,\n",
    "            test_cases=test_cases,\n",
    "            history=history,\n",
    "            next_code_line=next_code_line,\n",
    "            output_planning=output_planning,\n",
    "            use_example=use_example,\n",
    "            use_task_description=use_task_description,\n",
    "            use_system_prompt=use_system_prompt,\n",
    "            more_comments=more_comments,\n",
    "        )\n",
    "\n",
    "        # Bound up front so that max_retry < 0 ends in ValueError, not UnboundLocalError.\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            code_output = self.extract_code(llm_output)\n",
    "            if code_output[\"code\"] is not None:\n",
    "                return code_output\n",
    "            logging.warning(f\"Failed to extract code, retrying ({attempt})...\")\n",
    "            logging.warning(f\"Current llm_output:\\n{llm_output}\")\n",
    "\n",
    "        logging.error(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "        # Single formatted message; two positional args made the message render as a tuple.\n",
    "        raise ValueError(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "    \n",
    "    async def get_code_async(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Asynchronous counterpart of `get_code`, using `LLM_model.LLM_response_async`.\n",
    "\n",
    "        Args:\n",
    "            gen_kwargs: Generation options forwarded to `LLM_model.LLM_response_async`.\n",
    "                A fresh empty dict is used when omitted.\n",
    "            max_retry: Number of additional attempts after the first one.\n",
    "\n",
    "        Returns:\n",
    "            The dict produced by `extract_code`, with a non-None \"code\" entry.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt yielded extractable code.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            # A fresh dict per call: a `{}` default would be shared between calls.\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_code_generation_prompt(\n",
    "            extracted_plan,\n",
    "            user_feedback=extracted_plan.get('user_feedback'),\n",
    "            task_description=task_description,\n",
    "            test_cases=test_cases,\n",
    "            history=history,\n",
    "            next_code_line=next_code_line,\n",
    "            output_planning=output_planning,\n",
    "            use_example=use_example,\n",
    "            use_task_description=use_task_description,\n",
    "            use_system_prompt=use_system_prompt,\n",
    "            more_comments=more_comments,\n",
    "        )\n",
    "\n",
    "        # Bound up front so that max_retry < 0 ends in ValueError, not UnboundLocalError.\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)\n",
    "            code_output = self.extract_code(llm_output)\n",
    "            if code_output[\"code\"] is not None:\n",
    "                return code_output\n",
    "            logging.warning(f\"Failed to extract code, retrying ({attempt})...\")\n",
    "            logging.warning(f\"Current llm_output:\\n{llm_output}\")\n",
    "\n",
    "        logging.error(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "        # Single formatted message; two positional args made the message render as a tuple.\n",
    "        raise ValueError(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "\n",
    "class CodeRunner():\n",
    "    \"\"\"Compiles generated code strings and runs test functions against them.\n",
    "\n",
    "    Attributes:\n",
    "        max_workers: Default thread-pool size used by `run_all_tests`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_workers=5):\n",
    "        self.max_workers = max_workers\n",
    "\n",
    "    def run_test(self, func_obj, test_func):\n",
    "        \"\"\"Return `test_func(func_obj)`; any exception counts as a failed test (False).\"\"\"\n",
    "        try:\n",
    "            return test_func(func_obj)\n",
    "        except Exception as e:\n",
    "            logging.debug(f\"Test raised an exception: {e!r}\")\n",
    "            return False\n",
    "    \n",
    "    def compile_code(self, code_str, main_function_name=None):\n",
    "        \"\"\"Execute `code_str` and return a callable defined by it.\n",
    "\n",
    "        Args:\n",
    "            code_str: Python source to execute.\n",
    "            main_function_name: When given, the object bound to this name is\n",
    "                returned (None if it is missing or not callable).\n",
    "\n",
    "        Returns:\n",
    "            The selected callable, or None when execution fails or no callable\n",
    "            exists. Without `main_function_name`, the first function defined by\n",
    "            the code itself is preferred over imported callables; if there is\n",
    "            none, the first callable in the namespace is returned.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            namespace = {}\n",
    "            # NOTE(security): exec runs arbitrary code in this process; code_str must be trusted or sandboxed.\n",
    "            exec(code_str, namespace)  # One dict serves as both globals and locals\n",
    "        except Exception as e:\n",
    "            print(f\"Compilation Error: {str(e)}, code_str:\\n {code_str}\")\n",
    "            return None\n",
    "\n",
    "        if main_function_name is not None:\n",
    "            func = namespace.get(main_function_name)\n",
    "            return func if callable(func) else None\n",
    "\n",
    "        callables = [obj for obj in namespace.values() if callable(obj)]\n",
    "        for obj in callables:\n",
    "            # Functions defined by the executed code use `namespace` as their globals;\n",
    "            # imported names such as `Counter` do not, so they are skipped here.\n",
    "            if getattr(obj, \"__globals__\", None) is namespace:\n",
    "                return obj\n",
    "        return callables[0] if callables else None\n",
    "    \n",
    "    def run_all_tests(self, functions, test_cases, max_workers=None):\n",
    "        \"\"\"Run every test case against every function and collect the results.\n",
    "\n",
    "        Args:\n",
    "            functions: Mapping of function id to a dict with the key 'code' and,\n",
    "                optionally, 'main_function_name'.\n",
    "            test_cases: Mapping of test id to a dict with the key 'test_function'.\n",
    "            max_workers: Thread-pool size; defaults to `self.max_workers`.\n",
    "\n",
    "        Returns:\n",
    "            Tuple `(fun_results, test_results)`: `fun_results[func_id][test_id]`\n",
    "            and `test_results[test_id][func_id]` hold the same outcome. Pairs in\n",
    "            which either side failed to compile are recorded as False.\n",
    "        \"\"\"\n",
    "        if max_workers is None:\n",
    "            max_workers = self.max_workers\n",
    "\n",
    "        # Compile the candidate functions (honouring an explicit main function name)\n",
    "        compiled_functions = {\n",
    "            fid: self.compile_code(\n",
    "                code_info['code'],\n",
    "                main_function_name=code_info.get('main_function_name')\n",
    "            )\n",
    "            for fid, code_info in functions.items()\n",
    "        }\n",
    "        \n",
    "        # Compile the test functions\n",
    "        compiled_tests = {\n",
    "            tid: self.compile_code(code_info['test_function'])\n",
    "            for tid, code_info in test_cases.items()\n",
    "        }\n",
    "\n",
    "        # Result tables, indexed both ways\n",
    "        fun_results = {fid: {} for fid in functions}\n",
    "        test_results = {tid: {} for tid in test_cases}\n",
    "\n",
    "        total_tests = len(compiled_functions) * len(compiled_tests)\n",
    "        \n",
    "        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "            # The context manager closes the progress bar even if an exception escapes.\n",
    "            with tqdm(total=total_tests, desc=\"Running tests\") as pbar:\n",
    "                futures = {}\n",
    "\n",
    "                # Submit one task per (function, test) pair\n",
    "                for func_id, func_obj in compiled_functions.items():\n",
    "                    for test_id, test_func in compiled_tests.items():\n",
    "                        # A side that failed to compile fails the pair immediately\n",
    "                        if func_obj is None or test_func is None:\n",
    "                            fun_results[func_id][test_id] = False\n",
    "                            test_results[test_id][func_id] = False\n",
    "                            pbar.update(1)\n",
    "                            continue\n",
    "\n",
    "                        future = executor.submit(self.run_test, func_obj, test_func)\n",
    "                        futures[future] = (func_id, test_id)\n",
    "\n",
    "                # Collect results as they finish\n",
    "                for future in concurrent.futures.as_completed(futures):\n",
    "                    func_id, test_id = futures[future]\n",
    "                    try:\n",
    "                        result = future.result()\n",
    "                    except Exception:\n",
    "                        result = False\n",
    "                    fun_results[func_id][test_id] = result\n",
    "                    test_results[test_id][func_id] = result\n",
    "                    pbar.update(1)\n",
    "        \n",
    "        return fun_results, test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:11:13,697 - root - INFO - Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import Data, Batch\n",
    "from torch_geometric.nn import GATConv, GraphConv, global_max_pool, global_mean_pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.utils.data import Dataset, Subset\n",
    "from torch_geometric.loader import DataLoader\n",
    "import logging\n",
    "\n",
    "# Device configuration: use CUDA when torch reports it as available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "logging.info(f'Using device: {device}')\n",
    "\n",
    "class CodeGraphDataset(Dataset):\n",
    "    \"\"\"Dataset that turns code strings into AST graphs labelled with a scaled score.\n",
    "\n",
    "    Attributes:\n",
    "        dataframe: Copy of the input frame with a fresh 0..n-1 index; the\n",
    "            columns 'code' and 'score' are read.\n",
    "        scaler: Scaler applied to the score; fitted here only when none is passed in.\n",
    "        node_type_vocab: Mapping from AST node class name to integer id.\n",
    "        invalid_count: Number of `__getitem__` calls whose code failed to parse.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataframe, scaler=None):\n",
    "        self.invalid_count = 0\n",
    "        self.dataframe = dataframe.reset_index(drop=True)\n",
    "        if scaler is None:\n",
    "            # No scaler supplied (training set): fit a new one on this frame's scores.\n",
    "            self.scaler = MinMaxScaler()\n",
    "            self.scaler.fit(self.dataframe['score'].values.reshape(-1, 1))\n",
    "        else:\n",
    "            self.scaler = scaler\n",
    "        logging.info('Score values scaled using MinMaxScaler.')\n",
    "        # Build a vocabulary for AST node types\n",
    "        self.node_type_vocab = self.build_node_type_vocab()\n",
    "        logging.info(f'Built node type vocabulary with size: {len(self.node_type_vocab)}')\n",
    "\n",
    "    def build_node_type_vocab(self):\n",
    "        \"\"\"Map every AST node class name found in the 'code' column to an id.\n",
    "\n",
    "        Rows that cannot be parsed are logged and skipped. Ids follow the\n",
    "        alphabetical order of the class names.\n",
    "        \"\"\"\n",
    "        node_types = set()\n",
    "        for idx, code in enumerate(self.dataframe['code']):\n",
    "            try:\n",
    "                tree = ast.parse(code)\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"Error parsing code at index {idx}: {e}\")\n",
    "                continue\n",
    "            node_types.update(type(node).__name__ for node in ast.walk(tree))\n",
    "        return {name: type_id for type_id, name in enumerate(sorted(node_types))}\n",
    "\n",
    "    def ast_to_graph(self, code):\n",
    "        \"\"\"Convert source code into a graph of its AST.\n",
    "\n",
    "        Nodes are numbered in depth-first pre-order; each edge points from a\n",
    "        parent to its child. A node type missing from the vocabulary gets the\n",
    "        id `len(self.node_type_vocab)`.\n",
    "\n",
    "        Returns:\n",
    "            A `Data` object with `x` (one type id per node) and `edge_index`,\n",
    "            or None when the code cannot be parsed.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "        except Exception as e:\n",
    "            logging.warning(f\"Error parsing code: {e}\")\n",
    "            return None\n",
    "\n",
    "        node_features = []\n",
    "        edges = []\n",
    "        unknown_type_id = len(self.node_type_vocab)\n",
    "\n",
    "        # Explicit stack instead of recursion, so deeply nested code cannot hit\n",
    "        # Python's recursion limit. Children are pushed in reverse so they are\n",
    "        # popped in source order, which keeps the pre-order numbering.\n",
    "        stack = [(tree, None)]\n",
    "        while stack:\n",
    "            node, parent_id = stack.pop()\n",
    "            current_id = len(node_features)\n",
    "            node_features.append([self.node_type_vocab.get(type(node).__name__, unknown_type_id)])\n",
    "            if parent_id is not None:\n",
    "                edges.append((parent_id, current_id))\n",
    "            children = list(ast.iter_child_nodes(node))\n",
    "            stack.extend((child, current_id) for child in reversed(children))\n",
    "\n",
    "        # Convert edges to a tensor of shape (2, num_edges)\n",
    "        if edges:\n",
    "            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "        else:\n",
    "            edge_index = torch.empty((2, 0), dtype=torch.long)\n",
    "\n",
    "        # Convert node features to a tensor\n",
    "        x = torch.tensor(node_features, dtype=torch.long)\n",
    "\n",
    "        return Data(x=x, edge_index=edge_index)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataframe)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the graph for row `idx` with `y` set to the scaled score.\n",
    "\n",
    "        Returns None, and increments `invalid_count`, when the row's code cannot\n",
    "        be parsed, so callers can filter with `if data is not None`. An\n",
    "        out-of-range `idx` raises IndexError from `iloc`, which is what ends\n",
    "        plain `for item in dataset` iteration.\n",
    "        \"\"\"\n",
    "        row = self.dataframe.iloc[idx]\n",
    "        code = row['code']\n",
    "        score = row['score']\n",
    "\n",
    "        graph = self.ast_to_graph(code)\n",
    "        if graph is None:\n",
    "            self.invalid_count += 1\n",
    "            logging.debug(f\"Skipping index {idx} due to parsing error.\")\n",
    "            return None\n",
    "\n",
    "        # Normalize score using the scaler\n",
    "        score_normalized = self.scaler.transform([[score]]).flatten()[0]\n",
    "\n",
    "        graph.y = torch.tensor([score_normalized], dtype=torch.float)\n",
    "        return graph\n",
    "    \n",
    "    \n",
    "class GNNModel(nn.Module):\n",
    "    \"\"\"Two-layer GAT over node-type embeddings with a small regression head.\n",
    "\n",
    "    Args:\n",
    "        num_node_types: Size of the node-type vocabulary; one extra embedding\n",
    "            row is allocated beyond it.\n",
    "        embed_dim: Width of the node-type embedding.\n",
    "        hidden_dim: Width of both GAT layers and of the first linear layer.\n",
    "        scaler: Stored on the model as `self.scaler`; not used in `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128, scaler=None):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(num_node_types + 1, embed_dim)\n",
    "        self.conv1 = GATConv(embed_dim, hidden_dim)\n",
    "        self.conv2 = GATConv(hidden_dim, hidden_dim)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # input is max-pool and mean-pool concatenated\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)\n",
    "        self.scaler = scaler\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return one prediction per graph in `data` as a 1-D tensor.\"\"\"\n",
    "        x, edge_index, batch = data.x, data.edge_index, data.batch\n",
    "        # reshape(-1) rather than squeeze(): a bare squeeze() turns a single-node\n",
    "        # input into a 0-d tensor and loses the node dimension.\n",
    "        x = self.embedding(x.reshape(-1))\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = self.dropout(F.relu(x))\n",
    "        x = self.conv2(x, edge_index)\n",
    "        x = self.dropout(F.relu(x))\n",
    "        x = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        # Drop only the trailing feature axis so a batch of one graph stays 1-D.\n",
    "        return x.squeeze(-1)\n",
    "\n",
    "\n",
    "class PassRatePredictor():\n",
    "    def __init__(self, ini_data=None, model=None):\n",
    "        \"\"\"Set up the predictor's data store, scaler and training flag.\n",
    "\n",
    "        Args:\n",
    "            ini_data: Optional initial data; presumably a DataFrame with the\n",
    "                columns 'code' and 'score' (TODO confirm against callers).\n",
    "                When omitted, an empty frame with those columns is created.\n",
    "            model: Optional model stored as `self.model`.\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        if ini_data is None:\n",
    "            # Start with an empty data set\n",
    "            self.data = pd.DataFrame(columns=[\"code\", \"score\"])\n",
    "        else:\n",
    "            # Previously self.data was left undefined on this path, so the\n",
    "            # first add_data() call raised AttributeError.\n",
    "            self.data = ini_data\n",
    "        self.scaler = MinMaxScaler()\n",
    "        self.trained = False\n",
    "\n",
    "    def add_data(self, new_data, use_pass_rate=False):\n",
    "        \"\"\"Append new samples to `self.data`, skipping code that is already stored.\n",
    "\n",
    "        Args:\n",
    "            new_data: Either a dict of records (outer keys are dropped, each\n",
    "                value is one row) or a DataFrame with a 'code' column.\n",
    "            use_pass_rate: For dict input, use the 'pass_rate' field as 'score'.\n",
    "                Ignored for DataFrame input.\n",
    "        \"\"\"\n",
    "        if isinstance(new_data, dict):\n",
    "            new_data = pd.DataFrame.from_dict(new_data, orient='index').reset_index(drop=True)\n",
    "            # Keep only the 'code' and 'score' columns\n",
    "            if use_pass_rate:\n",
    "                new_data = new_data[['code', 'pass_rate']].rename(columns={'pass_rate': 'score'})\n",
    "            else:\n",
    "                new_data = new_data[['code', 'score']]\n",
    "\n",
    "        # getattr: self.data may be absent when the instance was not fully initialised.\n",
    "        existing = getattr(self, 'data', None)\n",
    "        if existing is None or len(existing) == 0:\n",
    "            # Nothing stored yet: adopt the new rows directly rather than\n",
    "            # concatenating with an empty frame, which would degrade dtypes.\n",
    "            self.data = new_data.reset_index(drop=True)\n",
    "            return\n",
    "\n",
    "        # Drop rows whose code is already stored\n",
    "        new_data = new_data[~new_data['code'].isin(existing['code'])]\n",
    "        self.data = pd.concat([existing, new_data], ignore_index=True)\n",
    "\n",
    "    def predict_score(self, new_code_samples, model=None, scaler=None):\n",
    "        \"\"\"Predict scores for new code samples and map them back to the original scale.\n",
    "\n",
    "        Args:\n",
    "            new_code_samples: Sized collection of code strings.\n",
    "            model: Model to use; defaults to `self.model`.\n",
    "            scaler: Fitted scaler to use; defaults to `self.scaler`.\n",
    "\n",
    "        Returns:\n",
    "            1-D array of de-normalised predictions, empty when no sample yields\n",
    "            a graph. NOTE(review): samples removed by `clean_score_data` or by\n",
    "            parse failures produce no entry, so the result can be shorter than\n",
    "            the input - confirm callers handle this.\n",
    "        \"\"\"\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "        if scaler is None:\n",
    "            scaler = self.scaler\n",
    "\n",
    "        # Wrap the new samples in a DataFrame\n",
    "        new_df = pd.DataFrame({\n",
    "            \"code\": new_code_samples,\n",
    "            \"score\": [\"0s\"] * len(new_code_samples)  # placeholder value\n",
    "        })\n",
    "        \n",
    "        df_clean, _ = self.clean_score_data(new_df)\n",
    "\n",
    "        # Build the dataset, dropping samples that could not be turned into graphs\n",
    "        dataset = CodeGraphDataset(df_clean, scaler=scaler)\n",
    "        loader = DataLoader(\n",
    "            [data for data in dataset if data is not None],\n",
    "            batch_size=32\n",
    "        )\n",
    "        \n",
    "        # Run inference on whichever device the model's parameters live on\n",
    "        first_param = next(model.parameters(), None)\n",
    "        model.eval()\n",
    "        preds = []\n",
    "        with torch.no_grad():\n",
    "            for batch in loader:\n",
    "                if first_param is not None:\n",
    "                    batch = batch.to(first_param.device)\n",
    "                pred = model(batch)\n",
    "                # reshape(-1): a batch of one graph may yield a 0-d tensor, which\n",
    "                # cannot be iterated by extend().\n",
    "                preds.extend(pred.reshape(-1).cpu().numpy())\n",
    "        \n",
    "        if not preds:\n",
    "            # inverse_transform rejects an array with zero samples\n",
    "            return np.empty(0)\n",
    "\n",
    "        # Undo the normalisation\n",
    "        pred_score = scaler.inverse_transform(np.array(preds).reshape(-1, 1)).flatten()\n",
    "        return pred_score\n",
    "    \n",
    "    def test_model(self, model, dataframe, train_scaler=None):\n",
    "        # 使用训练集的scaler（假设已经通过train_model传递）\n",
    "        df_clean, _ = self.clean_score_data(dataframe)\n",
    "        if train_scaler is None:\n",
    "            train_scaler = MinMaxScaler().fit(dataframe['score'].values.reshape(-1, 1))\n",
    "        test_dataset = CodeGraphDataset(df_clean, scaler=train_scaler)\n",
    "        test_loader = DataLoader(\n",
    "            [data for data in test_dataset if data is not None],\n",
    "            batch_size=32\n",
    "        )\n",
    "        \n",
    "        criterion = torch.nn.MSELoss()\n",
    "        model.eval()\n",
    "        test_loss = []\n",
    "        all_preds = []\n",
    "        all_labels = []\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            for batch in test_loader:\n",
    "                pred = model(batch)\n",
    "                loss = criterion(pred, batch.y)\n",
    "                test_loss.append(loss.item())\n",
    "                all_preds.extend(pred.cpu().numpy())\n",
    "                all_labels.extend(batch.y.cpu().numpy())\n",
    "        \n",
    "        # 反归一化预测值和真实值\n",
    "        preds = test_dataset.scaler.inverse_transform(np.array(all_preds).reshape(-1, 1)).flatten()\n",
    "        labels = test_dataset.scaler.inverse_transform(np.array(all_labels).reshape(-1, 1)).flatten()\n",
    "        \n",
    "        # 计算指标\n",
    "        mae = np.mean(np.abs(preds - labels))\n",
    "        rmse = np.sqrt(np.mean((preds - labels)**2))\n",
    "        print(f\"Test MAE: {mae:.4f}, Test RMSE: {rmse:.4f}\")\n",
    "        return {\"mae\": mae, \"rmse\": rmse}\n",
    "    \n",
    "    def train_model(self, dataframe=None, epochs=50, batch_size=32, lr=0.001):\n",
    "        if dataframe is None:\n",
    "            dataframe = self.data\n",
    "\n",
    "        # 清洗数据\n",
    "        df_preprocessed = self.preprocess_data(dataframe)\n",
    "        if df_preprocessed.empty:\n",
    "            raise ValueError(\"无有效数据可供训练\")\n",
    "\n",
    "        # 划分训练集和验证集\n",
    "        train_df, val_df = train_test_split(df_preprocessed, test_size=0.2, random_state=42)\n",
    "        \n",
    "        # 初始化数据集和DataLoader（训练集拟合scaler）\n",
    "        train_scaler = MinMaxScaler().fit(train_df['score'].values.reshape(-1, 1))\n",
    "        self.scaler = train_scaler\n",
    "        train_dataset = CodeGraphDataset(train_df, scaler=train_scaler)\n",
    "        val_dataset = CodeGraphDataset(val_df, scaler=train_scaler)  # 使用训练集的scaler\n",
    "        \n",
    "        # 过滤无效样本并创建DataLoader\n",
    "        train_loader = DataLoader(\n",
    "            [data for data in train_dataset if data is not None],\n",
    "            batch_size=batch_size,\n",
    "            shuffle=True\n",
    "        )\n",
    "        val_loader = DataLoader(\n",
    "            [data for data in val_dataset if data is not None],\n",
    "            batch_size=batch_size\n",
    "        )\n",
    "        \n",
    "        # 初始化模型和优化器\n",
    "        model = GNNModel(\n",
    "            num_node_types=len(train_dataset.node_type_vocab) + 1,  # +1 for unknown\n",
    "            embed_dim=64,\n",
    "            hidden_dim=128,\n",
    "            scaler=train_scaler\n",
    "        )\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "        criterion = torch.nn.MSELoss()  # 均方误差损失\n",
    "        \n",
    "        # 训练循环\n",
    "        best_val_loss = float('inf')\n",
    "        for epoch in range(epochs):\n",
    "            model.train()\n",
    "            train_loss = []\n",
    "            for batch in train_loader:\n",
    "                optimizer.zero_grad()\n",
    "                pred = model(batch)\n",
    "                loss = criterion(pred, batch.y)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                train_loss.append(loss.item())\n",
    "            \n",
    "            # 验证集评估\n",
    "            model.eval()\n",
    "            val_loss = []\n",
    "            with torch.no_grad():\n",
    "                for batch in val_loader:\n",
    "                    pred = model(batch)\n",
    "                    loss = criterion(pred, batch.y)\n",
    "                    val_loss.append(loss.item())\n",
    "            \n",
    "            # 打印日志\n",
    "            avg_train_loss = np.mean(train_loss)\n",
    "            avg_val_loss = np.mean(val_loss)\n",
    "            print(f\"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}\")\n",
    "            \n",
    "            # 保存最佳模型\n",
    "            if avg_val_loss < best_val_loss:\n",
    "                best_val_loss = avg_val_loss\n",
    "                torch.save(model.state_dict(), \"best_gnn_model.pth\")\n",
    "        \n",
    "        self.model = model\n",
    "        return model\n",
    "    \n",
    "    def filter_invalid_ast(self, df):\n",
    "        valid_indices = []\n",
    "        invalid_indices = []\n",
    "        \n",
    "        for idx, code in enumerate(df['code']):\n",
    "            try:\n",
    "                ast.parse(code)\n",
    "                valid_indices.append(idx)\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"索引 {idx} 的代码无法解析AST: {e}\")\n",
    "                invalid_indices.append(idx)\n",
    "        \n",
    "        # 保留有效样本\n",
    "        df_valid = df.iloc[valid_indices].reset_index(drop=True)\n",
    "        return df_valid, invalid_indices\n",
    "\n",
    "    def clean_score_data(self, df):\n",
    "        cleaned_scores = []\n",
    "        invalid_indices = []\n",
    "        \n",
    "        for idx, row in df.iterrows():\n",
    "            value = row['score']\n",
    "            try:\n",
    "                if isinstance(value, str):\n",
    "                    # 移除空格，转换为小写\n",
    "                    cleaned_str = value.strip().lower()\n",
    "                    # 提取数值和单位（正则匹配数值部分）\n",
    "                    num_match = re.match(r\"^(\\d+\\.?\\d*)\\s*([a-z]*)?\", cleaned_str)\n",
    "                    if not num_match:\n",
    "                        raise ValueError(\"无法提取数值\")\n",
    "                    num = float(num_match.group(1))\n",
    "                    unit = num_match.group(2) or 's'  # 默认单位是秒\n",
    "                    # 根据单位转换为秒\n",
    "                    if unit in {'s', 'sec', 'second', ''}:\n",
    "                        converted = num\n",
    "                    elif unit in {'ms', 'msec', 'millisecond'}:\n",
    "                        converted = num / 1000\n",
    "                    elif unit in {'m', 'min', 'minute'}:\n",
    "                        converted = num * 60\n",
    "                    elif unit in {'h', 'hour'}:\n",
    "                        converted = num * 3600\n",
    "                    else:\n",
    "                        logging.warning(f\"索引 {idx} 的未知单位 '{unit}'，假设为秒\")\n",
    "                        converted = num\n",
    "                    cleaned_scores.append(converted)\n",
    "                else:\n",
    "                    # 处理数值类型（int/float）\n",
    "                    cleaned_scores.append(float(value))\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"索引 {idx} 的score值 '{value}' 处理失败: {e}\")\n",
    "                invalid_indices.append(idx)\n",
    "                cleaned_scores.append(None)\n",
    "        \n",
    "        # 替换原列并删除无效行\n",
    "        df_clean = df.copy()\n",
    "        df_clean['score'] = cleaned_scores\n",
    "        df_clean = df_clean.dropna(subset=['score']).reset_index(drop=True)\n",
    "        return df_clean, invalid_indices\n",
    "\n",
    "    def preprocess_data(self, df):\n",
    "        # Step 1: 过滤无法解析AST的样本\n",
    "        df_ast_valid, ast_invalid = self.filter_invalid_ast(df)\n",
    "        logging.info(f\"过滤 {len(ast_invalid)} 个无效AST样本\")\n",
    "        \n",
    "        # Step 2: 清洗score字段\n",
    "        df_clean, score_invalid = self.clean_score_data(df_ast_valid)\n",
    "        logging.info(f\"过滤 {len(score_invalid)} 个无效score样本\")\n",
    "        \n",
    "        return df_clean\n",
    "\n",
    "###############################################################\n",
    "# # debug\n",
    "\n",
    "# # 1. 加载数据\n",
    "# df = pd.read_csv(r\"E:\\python_project_new\\AI4SLCDP\\leetcode_data\\leetcode Median of Two Sorted Arrays.csv\")\n",
    "\n",
    "# # 将\"runtime\" 列改为\"score\"\n",
    "# df.rename(columns={\"runtime\": \"score\"}, inplace=True)\n",
    "# print(df.head())\n",
    "\n",
    "# pass_rate_predictor = PassRatePredictor()\n",
    "# pass_rate_predictor.add_data(df)\n",
    "\n",
    "# model = pass_rate_predictor.train_model(epochs=100)\n",
    "\n",
    "# # 3. 测试模型\n",
    "# # test_df = pd.read_csv(\"test_data.csv\")\n",
    "# # test_metrics = test_model(model, test_df)\n",
    "\n",
    "# # 4. 预测新样本\n",
    "# new_samples = [\n",
    "#     \"def square(x):\\n    return x ** 2\",\n",
    "#     \"def div(a, b):\\n    return a / b\"\n",
    "# ]\n",
    "\n",
    "# predictions = pass_rate_predictor.predict_score(new_samples)\n",
    "# print(f\"Predicted scores: {predictions}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import re\n",
    "import tempfile\n",
    "import os\n",
    "import json\n",
    "\n",
    "def pylint_code_score(code):\n",
    "    try:\n",
    "        # 创建临时文件保存代码\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        # 执行 Pylint 分析\n",
    "        result = subprocess.run(\n",
    "            [\"pylint\", \"--output-format=text\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        output = result.stdout\n",
    "        # print(output)\n",
    "        # 删除临时文件\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        # 提取评分（如 \"rated at 7.50/10\"）\n",
    "        match = re.search(r\"rated at (\\d+\\.?\\d*)/10\", output)\n",
    "        return float(match.group(1)) if match else -1\n",
    "    \n",
    "    except Exception as e:\n",
    "        print(f\"Pylint 分析失败: {e}\")\n",
    "        return -1\n",
    "\n",
    "def radon_mi_code_score(code: str) -> float:\n",
    "    try:\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        result = subprocess.run(\n",
    "            [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        data = json.loads(result.stdout)\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        if data and isinstance(data, dict):\n",
    "            file_key = list(data.keys())[0]  # 获取临时文件的键名\n",
    "            return data[file_key][\"mi\"] / 10\n",
    "        return -1\n",
    "    except Exception as e:\n",
    "        print(f\"Radon 分析失败: {e}\")\n",
    "        return -1\n",
    "    \n",
    "# 示例：直接分析代码字符串\n",
    "code = \"\"\"\n",
    "import subprocess\n",
    "import re\n",
    "import tempfile\n",
    "import os\n",
    "import json\n",
    "def radon_mi_code_score(code: str) -> float:\n",
    "    try:\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        result = subprocess.run(\n",
    "            [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        data = json.loads(result.stdout)\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        if data and isinstance(data, dict):\n",
    "            file_key = list(data.keys())[0]  # 获取临时文件的键名\n",
    "            return data[file_key][\"mi\"]\n",
    "        return -1\n",
    "    except Exception as e:\n",
    "        print(f\"Radon 分析失败: {e}\")\n",
    "        return -1\n",
    "\"\"\"\n",
    "\n",
    "## debug\n",
    "\n",
    "# pylint_score = pylint_code_score(code)\n",
    "# radon_score = radon_mi_code_score(code)\n",
    "\n",
    "# print(f\"Pylint 质量评分: {pylint_score}/10\")\n",
    "# print(f\"Radon 维护指数: {radon_score}/10\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import json\n",
    "import copy\n",
    "import tempfile\n",
    "import subprocess\n",
    "import concurrent.futures\n",
    "from tqdm import tqdm\n",
    "\n",
    "class Evaluator():\n",
    "    def __init__(self, pass_rate_predictor=None):\n",
    "        self.pass_rate_predictor = pass_rate_predictor\n",
    "\n",
    "    def calculate_pass_rate_score(self, test_results, test_weights):\n",
    "        total_weight = sum(test_weights.values())\n",
    "        if total_weight == 0:\n",
    "            return 0.0\n",
    "        \n",
    "        passed_weight = sum(weight for test_id, weight in test_weights.items() \n",
    "                           if test_results.get(test_id, False))\n",
    "        return passed_weight / total_weight\n",
    "\n",
    "    def calculate_batch_scores(self, code_data):\n",
    "        items = list(code_data.items())\n",
    "        code_ids = [k for k, _ in items]\n",
    "        code_entries = [v for _, v in items]\n",
    "        full_score_dict = {}\n",
    "\n",
    "        # 计算pass_rate_score（快速计算，无需并行）\n",
    "        pass_rate_scores = {\n",
    "            code_id: self.calculate_pass_rate_score(entry[\"test_results\"], entry[\"test_weights\"])\n",
    "            for code_id, entry in code_data.items()\n",
    "        }\n",
    "\n",
    "        # 批量预测score\n",
    "        code_strs = [entry[\"code\"] for entry in code_entries]\n",
    "        prediction_scores = [0.0] * len(code_strs)\n",
    "        if self.pass_rate_predictor is not None and self.pass_rate_predictor.model is not None:\n",
    "            try:\n",
    "                prediction_scores = self.pass_rate_predictor.predict_score(code_strs)\n",
    "                print(\"###############################################################\")\n",
    "                print(f\"Prediction scores: {prediction_scores}\")\n",
    "                print(\"###############################################################\")\n",
    "            except Exception as e:\n",
    "                print(code_strs)\n",
    "                raise e\n",
    "\n",
    "        # 并行计算静态分析分数\n",
    "        with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "            static_scores = list(tqdm(\n",
    "                executor.map(self._compute_static_scores, code_strs),\n",
    "                total=len(code_strs),\n",
    "                desc=\"Analyzing codes\"\n",
    "            ))\n",
    "\n",
    "        # 组合最终分数\n",
    "        final_scores = {}\n",
    "        for i, code_id in enumerate(code_ids):\n",
    "            final_scores[code_id] = (\n",
    "                0.7 * pass_rate_scores[code_id] +\n",
    "                0.1 * prediction_scores[i] +\n",
    "                0.1 * static_scores[i][0] +\n",
    "                0.1 * static_scores[i][1]\n",
    "            )\n",
    "            full_score_dict[code_id] = {\n",
    "                \"pass_rate_score\": pass_rate_scores[code_id],\n",
    "                \"prediction_score\": prediction_scores[i],\n",
    "                \"pylint_score\": static_scores[i][0],\n",
    "                \"radon_score\": static_scores[i][1]\n",
    "            }\n",
    "\n",
    "        return final_scores, full_score_dict\n",
    "\n",
    "    def _compute_static_scores(self, code_str):\n",
    "        return (\n",
    "            self.pylint_code_score(code_str),\n",
    "            self.radon_mi_code_score(code_str)\n",
    "        )\n",
    "\n",
    "    def pylint_code_score(self, code):\n",
    "        try:\n",
    "            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "                tmp.write(code)\n",
    "                tmp_path = tmp.name\n",
    "            \n",
    "            result = subprocess.run(\n",
    "                [\"pylint\", \"--output-format=text\", tmp_path],\n",
    "                capture_output=True,\n",
    "                text=True,\n",
    "                check=False\n",
    "            )\n",
    "            os.unlink(tmp_path)\n",
    "            \n",
    "            match = re.search(r\"rated at (\\d+\\.?\\d*)/10\", result.stdout)\n",
    "            return float(match.group(1)) if match else -1\n",
    "        \n",
    "        except Exception as e:\n",
    "            print(f\"Pylint analysis failed: {e}\")\n",
    "            return -1\n",
    "\n",
    "    def radon_mi_code_score(self, code):\n",
    "        try:\n",
    "            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "                tmp.write(code)\n",
    "                tmp_path = tmp.name\n",
    "            \n",
    "            result = subprocess.run(\n",
    "                [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "                capture_output=True,\n",
    "                text=True,\n",
    "                check=False\n",
    "            )\n",
    "            os.unlink(tmp_path)\n",
    "            \n",
    "            data = json.loads(result.stdout)\n",
    "            if data and isinstance(data, dict):\n",
    "                return list(data.values())[0][\"mi\"] / 10\n",
    "            return -1\n",
    "        except Exception as e:\n",
    "            print(f\"Radon analysis failed: {e}\")\n",
    "            return -1\n",
    "\n",
    "class LCDP():\n",
    "    def __init__(self, api_key, model=\"gpt-3.5-turbo\"):\n",
    "        self.llm_model = LLMModel(api_key, model)\n",
    "        self.code_runner = CodeRunner()\n",
    "        self.pass_rate_predictor = PassRatePredictor()\n",
    "        self.evaluator = Evaluator(self.pass_rate_predictor)\n",
    "        self.task_description = None\n",
    "        self.current_plan = None\n",
    "        self.test_weights = {}\n",
    "        self.test_cases = {}\n",
    "\n",
    "    async def run(self, task_description, max_iterations=3, example_dataset=None,\n",
    "                 num_plans=3, num_tests=5, num_codes=5, refine_rounds=3, use_pass_rate_for_train=False):\n",
    "        self.task_description = task_description\n",
    "        \n",
    "        # Initialize LLM Task Manager\n",
    "        self.llmtm = LLMTM(task_description, self.llm_model)\n",
    "        self.llmcg = LLMCG(task_description, self.llm_model)\n",
    "        \n",
    "        # Phase 1: Plan Generation and Refinement\n",
    "        # print(\"########################################################################\")\n",
    "        # print(\"### Phase 1: Plan Generation and Refinement\")\n",
    "        logging.info(\"########################################################################\")\n",
    "        logging.info(\"### Phase 1: Plan Generation and Refinement\")\n",
    "        plan, plan_raw = self.llmtm.get_plan()\n",
    "        self.current_plan = await self._plan_refinement_loop(self.llmtm, plan_raw, refine_rounds)\n",
    "        self.current_plan = self._plan_format_refinement(self.current_plan)\n",
    "        \n",
    "        # Phase 2: Test Case Generation and Weighting\n",
    "        # print(\"\\n########################################################################\")\n",
    "        # print(\"### Phase 2: Test Case Generation and Weighting\")\n",
    "        logging.info(\"\\n########################################################################\")\n",
    "        logging.info(\"### Phase 2: Test Case Generation and Weighting\")\n",
    "        # self.test_cases = await self._generate_tests(self.llmtm, num_tests)\n",
    "        # debug\n",
    "        self.test_cases = await self._generate_tests_async(self.llmtm, num_tests, use_example=False)\n",
    "        self.test_cases = self._filter_test_cases(self.test_cases)\n",
    "\n",
    "        # print(\"Calculating test weights...\")\n",
    "        logging.info(\"Calculating test weights...\")\n",
    "        self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset)\n",
    "        \n",
    "        # Phase 3: Iterative Code Generation\n",
    "        # print(\"\\n########################################################################\")\n",
    "        # print(\"### Phase 3: Iterative Code Generation\")\n",
    "        logging.info(\"\\n########################################################################\")\n",
    "        logging.info(\"### Phase 3: Iterative Code Generation\")\n",
    "        best_codes = {}\n",
    "        for iteration in range(max_iterations):\n",
    "            # print(f\"\\n=== Iteration {iteration+1}/{max_iterations} ===\")\n",
    "            logging.info(f\"\\n=== Iteration {iteration+1}/{max_iterations} ===\")\n",
    "            \n",
    "            # Generate new codes\n",
    "            # new_codes = await self._generate_codes(num_codes)\n",
    "            new_codes = await self._generate_codes_async(num_codes)\n",
    "            \n",
    "            # Evaluate codes\n",
    "            logging.info(\"Evaluating codes...\")\n",
    "            scored_codes, filtered_test_result = self._evaluate_codes(new_codes)\n",
    "            # remove the test cases that are not in the filtered_test_result\n",
    "            self.test_cases = {k: v for k, v in self.test_cases.items() if k in list(filtered_test_result.keys())}\n",
    "\n",
    "            logging.info(\"training pass_rate_predictor...\")\n",
    "            self.pass_rate_predictor.add_data(scored_codes, use_pass_rate=use_pass_rate_for_train)\n",
    "            self.pass_rate_predictor.train_model(epochs=50, batch_size=32, lr=0.001)\n",
    "            \n",
    "            # Update best codes\n",
    "            best_codes.update(self._select_top_codes(scored_codes, top_k=3))\n",
    "            \n",
    "            # User feedback\n",
    "            if not await self._get_user_feedback(best_codes):\n",
    "                self.current_plan['user_feedback'] = \"Based on previous outputs, please improve the code quality.\"\n",
    "        \n",
    "        return best_codes\n",
    "\n",
    "    async def _plan_refinement_loop(self, llmtm, initial_plan_raw, max_rounds):\n",
    "        current_plan_raw = initial_plan_raw\n",
    "        current_plan = llmtm.extract_plan(current_plan_raw)\n",
    "        for _ in range(max_rounds):\n",
    "            # Show current plan\n",
    "            # print(\"Current Plan:\\n\", self.plan_json_to_str(current_plan[\"overall_plan\"]))\n",
    "            logging.info(\"Current Plan:\\n\" + self.plan_json_to_str(current_plan[\"overall_plan\"]))\n",
    "            \n",
    "            # Get user feedback\n",
    "            if input(\"Refine plan? (y/n): \").lower() != 'y':\n",
    "                logging.info(\"Skipping plan refinement.\")\n",
    "                break\n",
    "            \n",
    "            feedback = input(\"Enter refinement feedback: \")\n",
    "            logging.info(f\"User feedback: {feedback}\")\n",
    "            current_plan, current_plan_raw = llmtm.refine_plan(feedback, current_plan_raw)\n",
    "        \n",
    "        return llmtm.extract_plan(current_plan_raw)\n",
    "\n",
    "    def plan_json_to_str(self, plan):\n",
    "        # Process Input Format\n",
    "        input_fmt = plan[\"input_format\"]\n",
    "        input_lines = []\n",
    "        for idx, (dtype, shape) in enumerate(input_fmt, 1):\n",
    "            shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_lines.append(f\"Argument {idx}: {dtype} with {shape_str}\")\n",
    "        input_section = \"Input Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in input_lines])\n",
    "\n",
    "        # Process Output Format\n",
    "        output_fmt = plan[\"output_format\"]\n",
    "        output_lines = []\n",
    "        for idx, (dtype, shape) in enumerate(output_fmt, 1):\n",
    "            shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_lines.append(f\"Output {idx}: {dtype} with {shape_str}\")\n",
    "        output_section = \"Output Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in output_lines])\n",
    "\n",
    "        # Build Overall Plan Details\n",
    "        plan_part = [\n",
    "            \"=== Current Plan ===\",\n",
    "            input_section,\n",
    "            output_section,\n",
    "            f\"Components Order: {', '.join(plan['components'])}\",\n",
    "            \"Plan Steps:\",\n",
    "            *[f\"- {step}\" for step in plan[\"plan\"]],\n",
    "            \"Overall Test Case Advice:\",\n",
    "            *[f\"- {advice}\" for advice in plan[\"test_case_generation_advise\"]],\n",
    "            \"\\n\",\n",
    "        ]\n",
    "\n",
    "        return \"\\n\".join(plan_part)\n",
    "\n",
    "    def _plan_format_refinement(self, plan_dict):\n",
    "        \"\"\"Refines the input and output formats in the plan to be lists of lists.\"\"\"\n",
    "        \n",
    "        # Create a deep copy to avoid modifying the original input\n",
    "        refined_plan = copy.deepcopy(plan_dict)\n",
    "        \n",
    "        def refine_format(formats):\n",
    "            \"\"\"Ensure each format field is a list of lists.\"\"\"\n",
    "            if isinstance(formats, list):\n",
    "                # Check if all elements are lists\n",
    "                if not all(isinstance(elem, list) for elem in formats):\n",
    "                    return [formats]\n",
    "            else:\n",
    "                # If it's not a list, wrap it into a list (though input is expected to be a list)\n",
    "                return [formats]\n",
    "            return formats\n",
    "        \n",
    "        # Process each component in 'components'\n",
    "        for component in refined_plan[\"components\"].values():\n",
    "            for key in [\"input_format\", \"output_format\"]:\n",
    "                if key in component:\n",
    "                    component[key] = refine_format(component[key])\n",
    "        \n",
    "        # Process 'overall_plan'\n",
    "        overall_plan = refined_plan.get(\"overall_plan\")\n",
    "        if overall_plan:\n",
    "            for key in [\"input_format\", \"output_format\"]:\n",
    "                if key in overall_plan:\n",
    "                    overall_plan[key] = refine_format(overall_plan[key])\n",
    "        \n",
    "        return refined_plan\n",
    "\n",
    "    async def _generate_tests(self, llmtm, num_tests):\n",
    "        test_cases = {}\n",
    "        for _ in range(num_tests):\n",
    "            test = llmtm.get_test_cases(self.current_plan['overall_plan'])\n",
    "            test_cases.update(test)\n",
    "        return test_cases\n",
    "    \n",
    "    async def _generate_tests_async(self, llmtm, num_tests, use_example=True):\n",
    "        test_cases = {}\n",
    "        task_list = [llmtm.get_test_cases_async(self.current_plan['overall_plan'], use_example=use_example) for _ in range(num_tests)]\n",
    "        \n",
    "        for task in tqdm_asyncio.as_completed(task_list, total=num_tests, desc=\"Generating async tests\"):\n",
    "            test = await task\n",
    "            for key, value in test.items():\n",
    "                # 生成唯一键名逻辑\n",
    "                new_key = key\n",
    "                suffix = 1\n",
    "                while new_key in test_cases:\n",
    "                    new_key = f\"{key}_{suffix}\"\n",
    "                    suffix += 1\n",
    "                test_cases[new_key] = value\n",
    "                \n",
    "        return test_cases\n",
    "\n",
    "    def _filter_test_cases(self, dataset):\n",
    "        print(dataset)\n",
    "        runnable_entries = {}\n",
    "        for code_id, attributes in dataset.items():\n",
    "            test_code = attributes.get(\"test_function\", \"\")\n",
    "            try:\n",
    "                # Attempt to compile the code string to check for syntax errors.\n",
    "                compile(test_code, \"<string>\", \"exec\")\n",
    "                # If no exception is raised, consider the code as runnable.\n",
    "                runnable_entries[code_id] = attributes\n",
    "            except Exception as error:\n",
    "                # If an exception is raised, skip this entry.\n",
    "                continue\n",
    "        return runnable_entries\n",
    "            \n",
    "\n",
    "    def _calculate_test_weights(self, test_cases, example_dataset):\n",
    "        if not example_dataset:\n",
    "            return {tid: 1.0 for tid in test_cases}\n",
    "        \n",
    "        # Run example dataset through tests\n",
    "        _, test_results = self.code_runner.run_all_tests(example_dataset, test_cases)\n",
    "        \n",
    "        # Calculate weights\n",
    "        weights = {}\n",
    "        for tid, results in test_results.items():\n",
    "            pass_rate = sum(results.values()) / len(results)\n",
    "            weights[tid] = 1 - abs(pass_rate - 0.5)  # Weight tests that discriminate\n",
    "        return weights\n",
    "\n",
    "    async def _generate_codes(self, num_codes):\n",
    "        codes = {}\n",
    "        for _ in range(num_codes):\n",
    "            code = self.llmcg.get_code(\n",
    "                extracted_plan=self.current_plan,\n",
    "                test_cases=self.test_cases,\n",
    "            )\n",
    "            codes[f\"code_{len(codes)}\"] = code\n",
    "        return codes\n",
    "    \n",
    "    async def _generate_codes_async(self, num_codes):\n",
    "        \"\"\"Generate ``num_codes`` candidate codes concurrently.\n",
    "\n",
    "        One ``self.llmcg.get_code_async`` awaitable is created per candidate and\n",
    "        the results are collected through ``tqdm_asyncio.as_completed``, which\n",
    "        also renders the progress bar.\n",
    "\n",
    "        Returns:\n",
    "            dict: ``code_0``, ``code_1``, ... -> generated code, numbered in the\n",
    "            order in which ``as_completed`` yields the results.\n",
    "        \"\"\"\n",
    "        pending = [\n",
    "            self.llmcg.get_code_async(\n",
    "                extracted_plan=self.current_plan,\n",
    "                test_cases=self.test_cases,\n",
    "            )\n",
    "            for _ in range(num_codes)\n",
    "        ]\n",
    "        progress = tqdm_asyncio.as_completed(pending, total=num_codes, desc=\"Generating async codes\")\n",
    "        codes = {}\n",
    "        for position, finished in enumerate(progress):\n",
    "            codes[f\"code_{position}\"] = await finished\n",
    "        return codes\n",
    "\n",
    "    def transform_test_perspective(self, test_results):\n",
    "        \"\"\"Pivot results from a per-test view into a per-code view.\n",
    "\n",
    "        Args:\n",
    "            test_results: ``{test_case_id: {code_id: result}}``.\n",
    "\n",
    "        Returns:\n",
    "            dict: ``{code_id: {test_case_id: result}}`` holding the same results.\n",
    "        \"\"\"\n",
    "        by_code = {}\n",
    "        for test_case_id, code_results in test_results.items():\n",
    "            for code_id, outcome in code_results.items():\n",
    "                # Create the per-code bucket on first sight, then record the outcome.\n",
    "                by_code.setdefault(code_id, {})[test_case_id] = outcome\n",
    "        return by_code\n",
    "\n",
    "    def _filter_test_cases_by_pass_rate(self, test_results, threshold=0.05):\n",
    "        \"\"\"Drop test cases that (almost) no code passes.\n",
    "\n",
    "        A test case is kept only when it has at least one recorded result and\n",
    "        its pass rate is strictly greater than ``threshold``.\n",
    "\n",
    "        Side effects:\n",
    "            ``self.test_cases`` is narrowed to the kept test-case ids, and the\n",
    "            number of remaining test cases is logged.\n",
    "\n",
    "        Args:\n",
    "            test_results: ``{test_case_id: {code_id: passed}}``.\n",
    "            threshold: Exclusive lower bound on the pass rate.\n",
    "\n",
    "        Returns:\n",
    "            tuple: ``(per_code_results, per_test_results)`` restricted to the\n",
    "            kept test cases; the first element is the second one pivoted with\n",
    "            ``self.transform_test_perspective``.\n",
    "        \"\"\"\n",
    "        total_test_cases = len(test_results)\n",
    "        kept_results = {\n",
    "            test_case_id: results\n",
    "            for test_case_id, results in test_results.items()\n",
    "            if len(results) != 0 and sum(results.values()) / len(results) > threshold\n",
    "        }\n",
    "\n",
    "        self.test_cases = {\n",
    "            test_case_id: test_case\n",
    "            for test_case_id, test_case in self.test_cases.items()\n",
    "            if test_case_id in kept_results\n",
    "        }\n",
    "        logging.info(f\"Filtered test cases: {len(self.test_cases)} out of {total_test_cases}\")\n",
    "\n",
    "        return self.transform_test_perspective(kept_results), kept_results\n",
    "\n",
    "    def _evaluate_codes(self, codes):\n",
    "        \"\"\"Run all candidate codes against the current tests and score them.\n",
    "\n",
    "        Steps:\n",
    "            1. Execute every code against ``self.test_cases``.\n",
    "            2. Discard test cases whose pass rate is not above 0.05 (this also\n",
    "               narrows ``self.test_cases``).\n",
    "            3. Score the codes in one batch with ``self.evaluator``.\n",
    "\n",
    "        Args:\n",
    "            codes: ``{code_id: {'code', 'plan', 'main_function_name', ...}}``.\n",
    "\n",
    "        Returns:\n",
    "            tuple: ``(scored_codes, filtered_test_results)``. Each scored entry\n",
    "            carries the code, its plan and main function name, the overall\n",
    "            score, the four partial scores and the per-test results.\n",
    "        \"\"\"\n",
    "        # Only the per-test view is needed; the per-code view is rebuilt after filtering.\n",
    "        _, raw_test_results = self.code_runner.run_all_tests(codes, self.test_cases)\n",
    "\n",
    "        per_code_results, kept_test_results = self._filter_test_cases_by_pass_rate(raw_test_results, threshold=0.05)\n",
    "\n",
    "        scoring_input = {\n",
    "            code_id: {\n",
    "                'code': codes[code_id]['code'],\n",
    "                'test_results': results,\n",
    "                'test_weights': self.test_weights,\n",
    "            }\n",
    "            for code_id, results in per_code_results.items()\n",
    "        }\n",
    "        # Calculate scores\n",
    "        overall_scores, score_details = self.evaluator.calculate_batch_scores(scoring_input)\n",
    "\n",
    "        # Combine scores with code data\n",
    "        partial_score_names = ('pass_rate_score', 'prediction_score', 'pylint_score', 'radon_score')\n",
    "        scored_codes = {}\n",
    "        for code_id, entry in codes.items():\n",
    "            record = {\n",
    "                'code': entry['code'],\n",
    "                'plan': entry['plan'],\n",
    "                'main_function_name': entry['main_function_name'],\n",
    "                'score': overall_scores[code_id],\n",
    "            }\n",
    "            for score_name in partial_score_names:\n",
    "                record[score_name] = score_details[code_id][score_name]\n",
    "            record['test_case_results'] = per_code_results[code_id]\n",
    "            scored_codes[code_id] = record\n",
    "        return scored_codes, kept_test_results\n",
    "\n",
    "    def _select_top_codes(self, scored_codes, top_k=3):\n",
    "        \"\"\"Return the ``top_k`` best-scoring entries of ``scored_codes``.\n",
    "\n",
    "        Args:\n",
    "            scored_codes: ``{code_id: {'score': number, ...}}``.\n",
    "            top_k: Maximum number of entries to keep.\n",
    "\n",
    "        Returns:\n",
    "            dict: The kept entries ordered from highest to lowest score; entries\n",
    "            with equal scores keep their original relative order.\n",
    "        \"\"\"\n",
    "        def score_of(item):\n",
    "            return item[1]['score']\n",
    "\n",
    "        ranked = sorted(scored_codes.items(), key=score_of, reverse=True)\n",
    "        return dict(ranked[:top_k])\n",
    "\n",
    "    async def _get_user_feedback(self, top_codes):\n",
    "        \"\"\"Log the top codes and optionally read feedback from the console.\n",
    "\n",
    "        For each entry the score, the plan and the first 500 characters of the\n",
    "        code are logged. The user is then asked whether to give feedback; the\n",
    "        answer counts as yes only when it lower-cases to ``y``.\n",
    "\n",
    "        Args:\n",
    "            top_codes: ``{code_id: {'score', 'plan', 'code', ...}}``.\n",
    "\n",
    "        Returns:\n",
    "            bool: True when feedback was entered (it is stored under\n",
    "            ``self.current_plan['user_feedback']`` for the next generation\n",
    "            cycle), False otherwise.\n",
    "        \"\"\"\n",
    "        logging.info(\"\\nTop Performing Codes:\")\n",
    "        for code_id, details in top_codes.items():\n",
    "            logging.info(f\"{code_id} [Score: {details['score']:.2f}]:\")\n",
    "            logging.info(\"Code workflow:\")\n",
    "            logging.info(details['plan'])\n",
    "            logging.info(\"Partial Code:\")\n",
    "            logging.info(details['code'][:500] + \"...\\n\")\n",
    "\n",
    "        wants_feedback = input(\"Provide feedback? (y/n): \").lower() == 'y'\n",
    "        if not wants_feedback:\n",
    "            return False\n",
    "\n",
    "        feedback = input(\"Enter your feedback: \")\n",
    "        logging.info(f\"User feedback: {feedback}\")\n",
    "        # Store feedback for next generation cycle\n",
    "        self.current_plan['user_feedback'] = feedback\n",
    "        return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:16:14,612 - root - INFO - ########################################################################\n",
      "2025-04-08 17:16:14,613 - root - INFO - ### Phase 1: Plan Generation and Refinement\n",
      "2025-04-08 17:16:35,141 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "2025-04-08 17:16:35,173 - root - INFO - Current Plan:\n",
      "=== Current Plan ===\n",
      "Input Format:\n",
      "- Argument 1: list with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: int with no fixed shape\n",
      "Components Order: find_min_adjacent_pair, replace_adjacent_pair_with_sum, count_operations_to_make_non_decreasing\n",
      "Plan Steps:\n",
      "- Accept an input list of integers.\n",
      "- Invoke count_operations_to_make_non_decreasing to initiate the process of finding pairs and replacing them until the array is in non-decreasing order.\n",
      "- During each iteration within count_operations_to_make_non_decreasing, use find_min_adjacent_pair to locate the leftmost adjacent pair with the minimum sum.\n",
      "- Invoke replace_adjacent_pair_with_sum to replace the identified pair in the array.\n",
      "- Continue this process while counting the number of operations until the array fulfills the non-decreasing condition.\n",
      "- Return the count of operations as the final output.\n",
      "Overall Test Case Advice:\n",
      "- Develop test cases that include arrays of varied sizes, both sorted and unsorted, to verify correct operation counts.\n",
      "- Include edge cases emphasizing negative numbers, duplicates, and already sorted arrays.\n",
      "- Performance tests should be conducted with large input arrays to ensure efficiency.\n",
      "- Test scenarios where no operations are needed to confirm the function can handle this gracefully.\n",
      "\n",
      "\n",
      "2025-04-08 17:16:47,286 - root - INFO - Skipping plan refinement.\n",
      "2025-04-08 17:16:47,287 - root - INFO - \n",
      "########################################################################\n",
      "2025-04-08 17:16:47,288 - root - INFO - ### Phase 2: Test Case Generation and Weighting\n",
      "Generating async tests:   0%|          | 0/5 [00:00<?, ?it/s]2025-04-08 17:16:55,953 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  20%|██        | 1/5 [00:08<00:34,  8.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共分出 5 个块\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:16:58,495 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  40%|████      | 2/5 [00:11<00:15,  5.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共分出 5 个块\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:16:59,053 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  60%|██████    | 3/5 [00:11<00:06,  3.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共分出 6 个块\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:17:01,241 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  80%|████████  | 4/5 [00:13<00:02,  2.69s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共分出 5 个块\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:17:01,684 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests: 100%|██████████| 5/5 [00:14<00:00,  2.89s/it]\n",
      "2025-04-08 17:17:01,749 - root - INFO - Calculating test weights...\n",
      "2025-04-08 17:17:01,749 - root - INFO - \n",
      "########################################################################\n",
      "2025-04-08 17:17:01,750 - root - INFO - ### Phase 3: Iterative Code Generation\n",
      "2025-04-08 17:17:01,751 - root - INFO - \n",
      "=== Iteration 1/2 ===\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "共分出 5 个块\n",
      "{'test_case_1': {'test_type': 'correctness', 'purpose': \"To design this test case, I'll use an input that consists of an already sorted array. The function should recognize that no operations are needed to make this array non-decreasing. The expected output is 0 because the array [1, 2, 2] already meets the non-decreasing condition.\", 'test_function': 'def test_case(func):\\n    result = func([1, 2, 2])\\n    return result == 0'}, 'test_case_2': {'test_type': 'edge_case', 'purpose': 'This case will test the function with an empty list. Testing an empty array is essential to check whether the function can handle and return appropriate values with no input data to process. The expected output is 0 since no operations are needed, as there are no elements to compare.', 'test_function': 'def test_case(func):\\n    result = func([])\\n    return result == 0'}, 'test_case_3': {'test_type': 'runtime', 'purpose': 'In this test case, I will provide a large input array with random integers to check the performance of the function. The main goal is to ensure that the function runs within a reasonable time frame, so I will measure the execution time of the function for this input. The expected output will not be checked, but rather we will ensure that it finishes running without errors within a predefined time limit.', 'test_function': 'import time\\n\\ndef test_case(func):\\n    input_data = list(range(10000, 0, -1))  # descend from 10000 to 1\\n    start_time = time.time()\\n    func(input_data)\\n    elapsed_time = time.time() - start_time\\n    return elapsed_time < 1  # should run in under 1 second'}, 'test_case_4': {'test_type': 'component_check', 'purpose': \"In this case, I want to verify if the function makes use of the specified components. I will ensure that it checks for adjacent pairs while implementing the logic for the main task. 
I'll be looking at the structure of the function to see if it uses the correct methods for finding and replacing adjacent pairs. The expected output will relate to whether the code logically uses these components meaningfully.\", 'test_function': 'def test_case(func):\\n    code_str = inspect.getsource(func)\\n    return \"find_min_adjacent_pair\" in code_str and \"replace_adjacent_pair_with_sum\" in code_str'}, 'test_case_5': {'test_type': 'error_handling', 'purpose': 'This test case is designed to verify that the function can handle invalid inputs gracefully. Specifically, I will pass a non-list input such as a string and check for a raised exception. The expected outcome is that the function should raise an appropriate error without crashing.', 'test_function': 'def test_case(func):\\n    try:\\n        func(\"invalid input\")\\n        return False  # exception was not raised\\n    except TypeError:\\n        return True  # correctly raised TypeError'}, 'test_case_1_1': {'test_type': 'correctness', 'purpose': 'The purpose of this test function is to validate the correctness of the implementation of the counting operation needed to make an array non-decreasing. I will test the function with an unsorted array as input to check if the function correctly counts the number of replacement operations required. In this case, I expect the output to match the predetermined correct value based on the rules described in the task. 
For the input [5, 2, 3, 1], the expected output is 2 because it requires two operations to reach a non-decreasing state.', 'test_function': 'def test_case(func):\\n    # Given input\\n    nums = [5, 2, 3, 1]\\n    # Expected output\\n    expected_output = 2\\n    # Run the function\\n    result = func(nums)\\n    # Check if the output matches the expected output\\n    return result == expected_output'}, 'test_case_2_1': {'test_type': 'edge_case', 'purpose': \"This test function is designed to validate the function's behavior with an edge case where the input array is already non-decreasing. Since the input is [1, 2, 2], I expect the output to be 0 because no operations are needed to make the array non-decreasing.\", 'test_function': 'def test_case(func):\\n    # Given input\\n    nums = [1, 2, 2]\\n    # Expected output\\n    expected_output = 0\\n    # Run the function\\n    result = func(nums)\\n    # Check if the output matches the expected output\\n    return result == expected_output'}, 'test_case_3_1': {'test_type': 'runtime', 'purpose': 'In this test function, I will evaluate the runtime performance of the function when provided with a large array. The purpose of this test is to ensure that the function can handle large inputs efficiently. I will measure the execution time and assert that it is below a predefined threshold (for example, 1 second).', 'test_function': 'import time\\n\\ndef test_case(func):\\n    import random\\n    # Generate a large input list\\n    nums = random.sample(range(-10000, 10000), 10000)  # random unique integers\\n    start_time = time.time()  # Start the timer\\n    func(nums)  # Run the function\\n    runtime = time.time() - start_time  # Calculate elapsed time\\n    return runtime < 1  # Ensure runtime is below 1 second'}, 'test_case_4_1': {'test_type': 'component_check', 'purpose': \"This test is to ensure that the function uses the specified components correctly. 
I will utilize string inspection to check if the function makes use of a given component name such as 'find_min_adjacent_pair', as this indicates proper functionality according to the specification.\", 'test_function': \"def test_case(func):\\n    # Retrieve the source code of the function\\n    import inspect\\n    source_code = inspect.getsource(func)\\n    # Verify if certain components are mentioned in the source code\\n    return 'find_min_adjacent_pair' in source_code and 'replace_adjacent_pair_with_sum' in source_code\"}, 'test_case_5_1': {'test_type': 'error_handling', 'purpose': 'This test will check if the function handles invalid inputs gracefully. Specifically, I will pass an empty list as input and verify that an exception is raised, indicating that the function cannot process such input.', 'test_function': 'def test_case(func):\\n    try:\\n        func([])  # Passing an empty list\\n        return False  # If no exception is raised, the test fails\\n    except Exception:\\n        return True  # An exception is expected here, indicating proper error handling'}, 'test_case_1_2': {'test_type': 'correctness', 'purpose': 'I will create a test case to validate the correctness of the function that counts the number of operations needed to make the array non-decreasing. The input for this test case will be an array that is unsorted and requires multiple operations to convert it into a non-decreasing sequence. I will specifically use the input `[5, 2, 3, 1]` as the example provided in the task description. According to the explanation, it should require 2 operations to achieve the desired output. 
The expected output is `2`, and I will run the function with this input to assert that it returns the expected result.', 'test_function': 'def test_case(func):  \\n    input_data = [5, 2, 3, 1]  \\n    expected_output = 2  \\n    result = func(input_data)  \\n    return result == expected_output'}, 'test_case_2_2': {'test_type': 'edge_case', 'purpose': 'In this test case, I will check how the function handles an already sorted array. The input will be `[1, 2, 2]`, which does not require any operations to become non-decreasing. The expected output is `0`, which means the function should handle this edge case correctly by returning zero operations needed.', 'test_function': 'def test_case(func):  \\n    input_data = [1, 2, 2]  \\n    expected_output = 0  \\n    result = func(input_data)  \\n    return result == expected_output'}, 'test_case_3_2': {'test_type': 'correctness', 'purpose': 'I will create a test case to verify that the function can handle an array with negative numbers. The input will be `[-1, -2, -3]`, which should require multiple operations to make it non-decreasing. The expected number of operations in this case is `3`, as each operation will combine two adjacent negative numbers. This verifies if the function correctly processes negative values.', 'test_function': 'def test_case(func):  \\n    input_data = [-1, -2, -3]  \\n    expected_output = 3  \\n    result = func(input_data)  \\n    return result == expected_output'}, 'test_case_6': {'test_type': 'error_handling', 'purpose': 'I will create a test case to check how the function handles invalid inputs, specifically an empty list. The expected outcome for this case is that it should raise an exception or handle it gracefully without breaking the program flow. 
I will implement this test and check for exceptions during execution.', 'test_function': 'def test_case(func):  \\n    input_data = []  \\n    try:  \\n        func(input_data)  \\n        return False  # If no exception is raised, return False  \\n    except Exception:  \\n        return True  # If an exception is raised, return True'}, 'test_case_1_3': {'test_type': 'correctness', 'purpose': 'We will create a test case that verifies the correctness of the function when given an array that requires multiple operations to reach a non-decreasing order. Specifically, the input array `[5, 2, 3, 1]` is chosen because it has several elements that need to be combined to form larger numbers in order to create a sorted array. The expected output is `2`, as explained in the original task description, making this a solid case for correctness testing.', 'test_function': \"def test_case(func):\\n    # The input array that needs to be tested\\n    nums = [5, 2, 3, 1]\\n    # The expected output based on the problem's description\\n    expected_output = 2\\n    \\n    # Call the function with the test input\\n    result = func(nums)\\n    \\n    # Check if the result matches the expected output\\n    return result == expected_output\"}, 'test_case_2_3': {'test_type': 'edge_case', 'purpose': 'This test case will focus on the edge case of an already sorted array. We will input the array `[1, 2, 2]`, which is already in non-decreasing order. As such, we expect the function to return `0`, indicating no operations are needed. 
This test checks whether the function can correctly handle cases where the input does not require any modifications.', 'test_function': 'def test_case(func):\\n    # The input array, which is already sorted\\n    nums = [1, 2, 2]\\n    # The expected output since no operations are needed\\n    expected_output = 0\\n    \\n    # Call the function with the test input\\n    result = func(nums)\\n    \\n    # Check if the result matches the expected output\\n    return result == expected_output'}, 'test_case_3_3': {'test_type': 'runtime', 'purpose': 'This test function aims to evaluate the performance of the function when given a large array. By creating a large array with a pattern that guarantees it is unsorted, we can check if the function executes efficiently within a reasonable time frame. We will measure the execution time and ensure it does not exceed a defined threshold, such as 1 second.', 'test_function': \"import time\\n\\ndef test_case(func):\\n    # Generate a large test array with a length of 10000\\n    nums = [i for i in range(10000, 0, -1)]  # Creates a descending array\\n    # Start the timer\\n    start_time = time.time()\\n    \\n    # Call the function with the test input\\n    result = func(nums)\\n    \\n    # Check the execution time\\n    execution_time = time.time() - start_time\\n    \\n    # Set a threshold for execution time (1 second)\\n    threshold = 1.0\\n    \\n    # We don't check the output but ensure it runs within the time limit\\n    return execution_time < threshold\"}, 'test_case_4_2': {'test_type': 'component_check', 'purpose': 'This test case will check whether the function utilizes the required components (`find_min_adjacent_pair`, `replace_adjacent_pair_with_sum`, `count_operations_to_make_non_decreasing`). 
We will inspect the function as a string and confirm that these specific component names are used within the implementation.', 'test_function': 'def test_case(func):\\n    # Convert the function to a string for inspection\\n    func_source = inspect.getsource(func)\\n    \\n    # Check for the presence of required component names\\n    components = [\\n        \"find_min_adjacent_pair\",\\n        \"replace_adjacent_pair_with_sum\",\\n        \"count_operations_to_make_non_decreasing\"\\n    ]\\n    \\n    # Check if all components are used\\n    return all(component in func_source for component in components)'}, 'test_case_5_2': {'test_type': 'error_handling', 'purpose': 'In this test case, we will evaluate how the function handles invalid input, such as passing a non-list type (e.g., an integer). We expect the function to raise a TypeError or some form of exception, indicating that the input type is incorrect. This verifies that error handling is working as expected for invalid inputs.', 'test_function': 'def test_case(func):\\n    # Use an invalid input type (integer instead of list)\\n    invalid_input = 123\\n    \\n    try:\\n        # Attempt to call the function with invalid input\\n        func(invalid_input)\\n        return False  # If no error is raised, the test fails\\n    except TypeError:\\n        return True  # If a TypeError is raised, the test passes\\n    except Exception:\\n        return False  # If any other error occurs, the test also fails'}, 'test_case_1_4': {'test_type': 'correctness', 'purpose': 'I will design a test function to validate the correctness of the function that counts the number of operations needed to make an array non-decreasing. The function will take a specific input array and compare the output from the target function with the expected output. For instance, given the input [5, 2, 3, 1], the expected output is 2 since two operations are required to make the array non-decreasing. 
The test will confirm that the output matches the expected result.', 'test_function': 'def test_case(func):\\n    # Define the input for the test\\n    nums = [5, 2, 3, 1]\\n    # Expected output is 2 for this input\\n    expected_output = 2\\n    # Call the function with the test input\\n    result = func(nums)\\n    # Return whether the test passed or failed\\n    return result == expected_output'}, 'test_case_2_4': {'test_type': 'edge_case', 'purpose': 'I will create a test function to verify how the target function behaves with edge cases, including an empty array. An empty list should return 0 since no operations are needed to make it non-decreasing as there are no elements present. This test will ensure that the function handles empty input gracefully.', 'test_function': 'def test_case(func):\\n    # Test with an empty array\\n    nums = []\\n    # Expected output is 0 for an empty array\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_3_4': {'test_type': 'runtime', 'purpose': 'I will create a test function that measures the execution time of the target function for a large input array to ensure it performs efficiently. For this test, I will use a large sorted array of size 10000. The expected output should reflect a scenario where no operations are needed, thus the output should be 0. 
This test ensures that the function meets performance requirements.', 'test_function': 'import time\\n\\ndef test_case(func):\\n    # Create a large sorted array\\n    nums = list(range(10000))  # Sorted array of size 10000\\n    # Expected output is 0, as the array is already non-decreasing\\n    expected_output = 0\\n    \\n    # Measure the execution time\\n    start_time = time.time()\\n    result = func(nums)\\n    end_time = time.time()\\n\\n    execution_time = end_time - start_time\\n    # Return whether the test passed and execution time is below a threshold\\n    return result == expected_output and execution_time < 1  # threshold of 1 second'}, 'test_case_4_3': {'test_type': 'component_check', 'purpose': \"I will design a test function to check if the target function uses the necessary components such as finding the minimum adjacent pair and replacing it. This will be done by inspecting the function's bytecode to see if specific operations are present. This test ensures that the function is implemented according to the specifications provided.\", 'test_function': 'import dis\\n\\ndef test_case(func):\\n    # Get the bytecode of the function\\n    bytecode = dis.Bytecode(func)\\n    # Check for the required components in the bytecode\\n    min_pair_found = any(\"find_min_adjacent_pair\" in instruction.opname for instruction in bytecode)\\n    replace_called = any(\"replace_adjacent_pair_with_sum\" in instruction.opname for instruction in bytecode)\\n    \\n    # Ensure both required functions are called\\n    return min_pair_found and replace_called'}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating async codes:   0%|          | 0/5 [00:00<?, ?it/s]2025-04-08 17:17:09,201 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "2025-04-08 17:17:09,260 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  20%|██        | 1/5 [00:07<00:30,  7.53s/it]2025-04-08 17:17:10,183 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  60%|██████    | 3/5 [00:08<00:04,  2.31s/it]2025-04-08 17:17:11,191 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  80%|████████  | 4/5 [00:09<00:01,  1.86s/it]2025-04-08 17:17:16,174 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes: 100%|██████████| 5/5 [00:14<00:00,  2.89s/it]\n",
      "2025-04-08 17:17:16,210 - root - INFO - Evaluating codes...\n",
      "Running tests: 100%|██████████| 115/115 [00:43<00:00,  2.66it/s]\n",
      "2025-04-08 17:17:59,469 - root - INFO - Filtered test cases: 18 out of 23\n",
      "Analyzing codes: 100%|██████████| 5/5 [00:05<00:00,  1.15s/it]\n",
      "2025-04-08 17:18:05,219 - root - INFO - training pass_rate_predictor...\n",
      "C:\\Users\\Zihang Zeng\\AppData\\Local\\Temp\\ipykernel_85188\\3803597484.py:166: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
      "  self.data = pd.concat([self.data, new_data], ignore_index=True)\n",
      "2025-04-08 17:18:05,224 - root - INFO - 过滤 0 个无效AST样本\n",
      "2025-04-08 17:18:05,227 - root - INFO - 过滤 0 个无效score样本\n",
      "2025-04-08 17:18:05,230 - root - INFO - Score values scaled using MinMaxScaler.\n",
      "2025-04-08 17:18:05,233 - root - INFO - Built node type vocabulary with size: 30\n",
      "2025-04-08 17:18:05,234 - root - INFO - Score values scaled using MinMaxScaler.\n",
      "2025-04-08 17:18:05,235 - root - INFO - Built node type vocabulary with size: 29\n",
      "d:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\loss.py:535: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.mse_loss(input, target, reduction=self.reduction)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 | Train Loss: 0.3211 | Val Loss: 1.1932\n",
      "Epoch 2/50 | Train Loss: 0.5186 | Val Loss: 1.0525\n",
      "Epoch 3/50 | Train Loss: 0.2629 | Val Loss: 0.8407\n",
      "Epoch 4/50 | Train Loss: 0.1515 | Val Loss: 0.6955\n",
      "Epoch 5/50 | Train Loss: 0.2626 | Val Loss: 0.6534\n",
      "Epoch 6/50 | Train Loss: 0.2452 | Val Loss: 0.6859\n",
      "Epoch 7/50 | Train Loss: 0.2297 | Val Loss: 0.7550\n",
      "Epoch 8/50 | Train Loss: 0.1570 | Val Loss: 0.8366\n",
      "Epoch 9/50 | Train Loss: 0.1908 | Val Loss: 0.9071\n",
      "Epoch 10/50 | Train Loss: 0.1811 | Val Loss: 0.9406\n",
      "Epoch 11/50 | Train Loss: 0.2532 | Val Loss: 0.9348\n",
      "Epoch 12/50 | Train Loss: 0.2731 | Val Loss: 0.8847\n",
      "Epoch 13/50 | Train Loss: 0.1629 | Val Loss: 0.8202\n",
      "Epoch 14/50 | Train Loss: 0.1501 | Val Loss: 0.7706\n",
      "Epoch 15/50 | Train Loss: 0.1516 | Val Loss: 0.7424\n",
      "Epoch 16/50 | Train Loss: 0.1803 | Val Loss: 0.7368\n",
      "Epoch 17/50 | Train Loss: 0.1569 | Val Loss: 0.7511\n",
      "Epoch 18/50 | Train Loss: 0.1778 | Val Loss: 0.7802\n",
      "Epoch 19/50 | Train Loss: 0.1632 | Val Loss: 0.8147\n",
      "Epoch 20/50 | Train Loss: 0.1402 | Val Loss: 0.8540\n",
      "Epoch 21/50 | Train Loss: 0.1283 | Val Loss: 0.8912\n",
      "Epoch 22/50 | Train Loss: 0.1900 | Val Loss: 0.9019\n",
      "Epoch 23/50 | Train Loss: 0.1511 | Val Loss: 0.8890\n",
      "Epoch 24/50 | Train Loss: 0.1446 | Val Loss: 0.8598\n",
      "Epoch 25/50 | Train Loss: 0.0974 | Val Loss: 0.8381\n",
      "Epoch 26/50 | Train Loss: 0.1183 | Val Loss: 0.8236\n",
      "Epoch 27/50 | Train Loss: 0.1373 | Val Loss: 0.8188\n",
      "Epoch 28/50 | Train Loss: 0.1407 | Val Loss: 0.8299\n",
      "Epoch 29/50 | Train Loss: 0.1106 | Val Loss: 0.8563\n",
      "Epoch 30/50 | Train Loss: 0.1331 | Val Loss: 0.8727\n",
      "Epoch 31/50 | Train Loss: 0.1067 | Val Loss: 0.8888\n",
      "Epoch 32/50 | Train Loss: 0.0794 | Val Loss: 0.9141\n",
      "Epoch 33/50 | Train Loss: 0.1446 | Val Loss: 0.9204\n",
      "Epoch 34/50 | Train Loss: 0.0845 | Val Loss: 0.9163\n",
      "Epoch 35/50 | Train Loss: 0.0834 | Val Loss: 0.9206\n",
      "Epoch 36/50 | Train Loss: 0.0859 | Val Loss: 0.9333\n",
      "Epoch 37/50 | Train Loss: 0.0889 | Val Loss: 0.9480\n",
      "Epoch 38/50 | Train Loss: 0.0644 | Val Loss: 0.9631\n",
      "Epoch 39/50 | Train Loss: 0.0679 | Val Loss: 0.9678\n",
      "Epoch 40/50 | Train Loss: 0.0853 | Val Loss: 0.9657\n",
      "Epoch 41/50 | Train Loss: 0.0638 | Val Loss: 0.9660\n",
      "Epoch 42/50 | Train Loss: 0.0924 | Val Loss: 0.9448\n",
      "Epoch 43/50 | Train Loss: 0.0947 | Val Loss: 0.9469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:18:06,267 - root - INFO - \n",
      "Top Performing Codes:\n",
      "2025-04-08 17:18:06,268 - root - INFO - code_0 [Score: 1.92]:\n",
      "2025-04-08 17:18:06,269 - root - INFO - Code workflow:\n",
      "2025-04-08 17:18:06,269 - root - INFO - None\n",
      "2025-04-08 17:18:06,269 - root - INFO - Partial Code:\n",
      "2025-04-08 17:18:06,269 - root - INFO - def count_operations_to_make_non_decreasing(nums):\n",
      "    def find_min_adjacent_pair(nums):\n",
      "        min_sum = float('inf')\n",
      "        min_pair_index = -1\n",
      "        \n",
      "        # Loop through the array to find the adjacent pair with the minimum sum\n",
      "        for i in range(len(nums) - 1):\n",
      "            current_sum = nums[i] + nums[i + 1]\n",
      "            if current_sum < min_sum:\n",
      "                min_sum = current_sum\n",
      "                min_pair_index = i\n",
      "        \n",
      "        # Return the index of the pair along with the mi...\n",
      "\n",
      "2025-04-08 17:18:06,270 - root - INFO - code_1 [Score: 1.92]:\n",
      "2025-04-08 17:18:06,270 - root - INFO - Code workflow:\n",
      "2025-04-08 17:18:06,271 - root - INFO - None\n",
      "2025-04-08 17:18:06,271 - root - INFO - Partial Code:\n",
      "2025-04-08 17:18:06,272 - root - INFO - def count_operations_to_make_non_decreasing(nums):\n",
      "    # Function to find the minimum adjacent pair in the array\n",
      "    def find_min_adjacent_pair(nums):\n",
      "        min_sum = float('inf')\n",
      "        min_index = -1\n",
      "        for i in range(len(nums) - 1):\n",
      "            current_sum = nums[i] + nums[i + 1]\n",
      "            # Identify the pair with the minimum sum\n",
      "            if current_sum < min_sum:\n",
      "                min_sum = current_sum\n",
      "                min_index = i\n",
      "        return (min_index, min_sum)\n",
      "\n",
      "    # Functi...\n",
      "\n",
      "2025-04-08 17:18:06,272 - root - INFO - code_4 [Score: 1.62]:\n",
      "2025-04-08 17:18:06,273 - root - INFO - Code workflow:\n",
      "2025-04-08 17:18:06,273 - root - INFO - The code implements three main helper functions: `find_min_adjacent_pair`, `replace_adjacent_pair_with_sum`, and `count_operations_to_make_non_decreasing`. \n",
      "\n",
      "1. **find_min_adjacent_pair**: This function iterates through the list, calculates sums of adjacent elements, and identifies the pair with the smallest sum. It returns the index and the sum of that pair.\n",
      "  \n",
      "2. **replace_adjacent_pair_with_sum**: This function takes the original list and the identified pair to create a new list where the pair is replaced by their sum. It constructs a new list by incorporating elements before, the sum of the found pair, and elements after the pair.\n",
      "\n",
      "3. **count_operations_to_make_non_decreasing**: This function checks whether the list is non-decreasing. If any adjacent elements are found in decreasing order, it finds the minimal pair, replaces them, and increments the operation counter. This loop continues until the entire list is non-decreasing.\n",
      "\n",
      "4. **main**: The main function integrates these components and returns the count of operations necessary to make the array non-decreasing.\n",
      "\n",
      "This organized and clear approach helps ensure that the code follows the given specifications correctly, making it maintainable and easy to understand.\n",
      "2025-04-08 17:18:06,274 - root - INFO - Partial Code:\n",
      "2025-04-08 17:18:06,274 - root - INFO - def find_min_adjacent_pair(nums):\n",
      "    \"\"\"\n",
      "    Identifies the adjacent pair of elements in the array that has the minimum sum.\n",
      "    \n",
      "    Args:\n",
      "    nums (list): A list of integers with no fixed shape.\n",
      "    \n",
      "    Returns:\n",
      "    tuple: A tuple containing the index of the pair with the minimum sum and the sum itself.\n",
      "    \"\"\"\n",
      "    min_sum = float('inf')  # Initialize with infinity\n",
      "    min_index = -1  # Initialize index of the pair\n",
      "    \n",
      "    # Loop through the array up to the second last element\n",
      "    for i in ...\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44/50 | Train Loss: 0.1119 | Val Loss: 0.9596\n",
      "Epoch 45/50 | Train Loss: 0.0605 | Val Loss: 0.9329\n",
      "Epoch 46/50 | Train Loss: 0.0647 | Val Loss: 0.8699\n",
      "Epoch 47/50 | Train Loss: 0.0513 | Val Loss: 0.8521\n",
      "Epoch 48/50 | Train Loss: 0.0257 | Val Loss: 0.8711\n",
      "Epoch 49/50 | Train Loss: 0.0544 | Val Loss: 0.9230\n",
      "Epoch 50/50 | Train Loss: 0.0271 | Val Loss: 1.0057\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-08 17:18:09,158 - root - INFO - \n",
      "=== Iteration 2/2 ===\n",
      "Generating async codes:   0%|          | 0/5 [00:00<?, ?it/s]2025-04-08 17:18:21,093 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  20%|██        | 1/5 [00:12<00:48, 12.00s/it]2025-04-08 17:18:21,230 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  40%|████      | 2/5 [00:12<00:15,  5.02s/it]2025-04-08 17:18:21,494 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  60%|██████    | 3/5 [00:12<00:05,  2.82s/it]2025-04-08 17:18:25,346 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  80%|████████  | 4/5 [00:16<00:03,  3.25s/it]2025-04-08 17:18:26,072 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes: 100%|██████████| 5/5 [00:16<00:00,  3.39s/it]\n",
      "2025-04-08 17:18:26,134 - root - INFO - Evaluating codes...\n",
      "Running tests: 100%|██████████| 90/90 [00:14<00:00,  6.31it/s]\n",
      "2025-04-08 17:18:40,412 - root - INFO - Filtered test cases: 16 out of 18\n",
      "2025-04-08 17:18:40,415 - root - INFO - Score values scaled using MinMaxScaler.\n",
      "2025-04-08 17:18:40,418 - root - INFO - Built node type vocabulary with size: 34\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"def count_operations_to_make_non_decreasing(nums):\\n    # Helper function to find the pair with the minimum sum\\n    def find_min_adjacent_pair(arr):\\n        min_sum = float('inf')\\n        min_index = -1\\n        \\n        for i in range(len(arr) - 1):\\n            current_sum = arr[i] + arr[i + 1]\\n            if current_sum < min_sum:\\n                min_sum = current_sum\\n                min_index = i\\n        \\n        return (min_index, min_sum)\\n\\n    # Helper function to replace the specified pair with their sum\\n    def replace_adjacent_pair_with_sum(arr, index_val_tuple):\\n        index, value = index_val_tuple\\n        new_value = arr[index] + arr[index + 1]\\n        return arr[:index] + [new_value] + arr[index + 2:]\\n\\n    # Helper function to check if the array is non-decreasing\\n    def is_non_decreasing(arr):\\n        for i in range(1, len(arr)):\\n            if arr[i] < arr[i - 1]:\\n                return False\\n        return True\\n    \\n    operation_count = 0\\n\\n    # Continue processing until the array is non-decreasing\\n    while not is_non_decreasing(nums):\\n        # Find the leftmost minimum adjacent pair\\n        min_pair_index = find_min_adjacent_pair(nums)\\n        # Replace the adjacent pair with their sum\\n        nums = replace_adjacent_pair_with_sum(nums, min_pair_index)\\n        # Increment the operation count\\n        operation_count += 1\\n\\n    return operation_count\", \"def find_min_adjacent_pair(nums):\\n    # Initialize variables to store the minimum sum and the index of the pair\\n    min_sum = float('inf')\\n    index = -1\\n    \\n    # Loop through the array to find the adjacent pair with the minimum sum\\n    for i in range(len(nums) - 1):\\n        current_sum = nums[i] + nums[i + 1]\\n        if current_sum < min_sum:\\n            min_sum = current_sum\\n            index = i\\n            \\n    # Return the index and the minimum sum as a tuple\\n    return (index, 
min_sum)\\n\\ndef replace_adjacent_pair_with_sum(nums, pair):\\n    index, _ = pair  # Unpack index and value from tuple\\n    # Compute the sum of the two adjacent elements\\n    new_value = nums[index] + nums[index + 1]\\n    # Create a new list to store modified elements\\n    new_list = nums[:index] + [new_value] + nums[index + 2:]\\n    return new_list\\n\\ndef count_operations_to_make_non_decreasing(nums):\\n    operations_count = 0\\n    \\n    while not all(nums[i] <= nums[i + 1] for i in range(len(nums) - 1)):\\n        # Find the leftmost adjacent pair with the minimum sum\\n        pair = find_min_adjacent_pair(nums)\\n        # Replace the identified pair with their sum\\n        nums = replace_adjacent_pair_with_sum(nums, pair)\\n        operations_count += 1  # Increment the operations count\\n    \\n    return operations_count\\n\\ndef reduce_to_non_decreasing(nums):\\n    # Call the counts operation function with the given nums\\n    return count_operations_to_make_non_decreasing(nums)\", \"def find_min_adjacent_pair(nums):\\n    min_sum = float('inf')\\n    min_index = -1\\n    \\n    # Loop through the array to find adjacent pairs\\n    for i in range(len(nums) - 1):\\n        current_sum = nums[i] + nums[i + 1]\\n        # Update min_sum and min_index if a new minimum is found\\n        if current_sum < min_sum:\\n            min_sum = current_sum\\n            min_index = i\\n\\n    # Return the index of the pair and the minimum sum as a tuple\\n    return (min_index, min_sum)\\n\\ndef replace_adjacent_pair_with_sum(nums, pair):\\n    index, _ = pair\\n    # Calculate the sum of the adjacent pair\\n    sum_value = nums[index] + nums[index + 1]\\n    \\n    # Create a new list replacing the pair with their sum\\n    new_list = nums[:index] + [sum_value] + nums[index + 2:]\\n    \\n    return new_list\\n\\ndef count_operations_to_make_non_decreasing(nums):\\n    operation_count = 0\\n    \\n    # Check if the array is not non-decreasing\\n    
while any(nums[i] > nums[i + 1] for i in range(len(nums) - 1)):\\n        # Find the minimum adjacent pair\\n        min_pair = find_min_adjacent_pair(nums)\\n        # Replace the adjacent pair with their sum\\n        nums = replace_adjacent_pair_with_sum(nums, min_pair)\\n        # Increment the operation count\\n        operation_count += 1\\n\\n    return operation_count\\n\\ndef make_array_non_decreasing(nums):\\n    # Main function to initiate the counting of operations\\n    return count_operations_to_make_non_decreasing(nums)\", \"def find_min_adjacent_pair(nums):\\n    # Initialize variable to store the minimum sum found and the index of the pair\\n    min_sum = float('inf')\\n    pair_index = -1\\n    \\n    # Loop through the array from the beginning to the second last element\\n    for i in range(len(nums) - 1):\\n        # Calculate the sum of the current element and the next element\\n        current_sum = nums[i] + nums[i + 1]\\n        \\n        # Check if the current sum is less than the minimum sum found\\n        if current_sum < min_sum:\\n            min_sum = current_sum\\n            pair_index = i  # Store the index of the pair\\n    \\n    # Return the index of the pair with the minimum sum and the sum itself as a tuple\\n    return (pair_index, min_sum)\\n\\ndef replace_adjacent_pair_with_sum(nums, pair_info):\\n    # Extract the index and value (sum) of the adjacent pair from the input tuple\\n    index, value = pair_info\\n    \\n    # Compute the sum of the two adjacent elements at the specified index\\n    new_value = nums[index] + nums[index + 1]\\n    \\n    # Create a new list to hold the elements after replacement\\n    new_list = nums[:index] + [new_value] + nums[index + 2:]\\n    \\n    return new_list\\n\\ndef count_operations_to_make_non_decreasing(nums):\\n    # Initialize an operation counter\\n    operation_count = 0\\n    \\n    # Continue performing operations until the array is non-decreasing\\n    while True:\\n        
# Check if the array is already non-decreasing\\n        if all(nums[i] <= nums[i + 1] for i in range(len(nums) - 1)):\\n            break\\n        \\n        # Find the pair with the minimum sum\\n        pair_info = find_min_adjacent_pair(nums)\\n        \\n        # Replace the identified adjacent pair with their sum\\n        nums = replace_adjacent_pair_with_sum(nums, pair_info)\\n        \\n        # Increment the operation count\\n        operation_count += 1\\n    \\n    return operation_count\\n\\ndef main_function(nums):\\n    # Accept an input list of integers and initiate the operation count process\\n    return count_operations_to_make_non_decreasing(nums)\", 'def find_min_adjacent_pair(nums):\\n    \"\"\"\\n    Identify the adjacent pair of elements in the array that has the minimum sum.\\n    \\n    Args:\\n    nums (list): A list of integers.\\n    \\n    Returns:\\n    tuple: A tuple containing the index of the pair and the minimum sum.\\n    \"\"\"\\n    # Initialize the minimum sum and index\\n    min_sum = float(\\'inf\\')\\n    index = -1\\n\\n    # Loop through the array to find the minimum sum of adjacent pairs\\n    for i in range(len(nums) - 1):\\n        current_sum = nums[i] + nums[i + 1]\\n        if current_sum < min_sum:\\n            min_sum = current_sum\\n            index = i\\n\\n    return (index, min_sum)\\n\\ndef replace_adjacent_pair_with_sum(nums, adjacent_pair):\\n    \"\"\"\\n    Replace the identified adjacent pair in the array with their sum.\\n    \\n    Args:\\n    nums (list): A list of integers.\\n    adjacent_pair (tuple): A tuple containing the index of the pair and the corresponding sum.\\n    \\n    Returns:\\n    list: Updated list with the pair replaced by their sum.\\n    \"\"\"\\n    index, _ = adjacent_pair\\n    new_list = nums[:index]  # Elements before the pair\\n    new_element = nums[index] + nums[index + 1]  # The sum of the pair\\n    new_list.append(new_element)  # Add the sum\\n    
new_list.extend(nums[index + 2:])  # Elements after the pair\\n    return new_list\\n\\ndef count_operations_to_make_non_decreasing(nums):\\n    \"\"\"\\n    Perform iterations of finding and replacing adjacent pairs until the array is non-decreasing,\\n    counting the number of operations needed.\\n\\n    Args:\\n    nums (list): A list of integers.\\n    \\n    Returns:\\n    int: The number of operations performed.\\n    \"\"\"\\n    operations_count = 0\\n\\n    # Check if the array is already non-decreasing\\n    while any(nums[i] > nums[i + 1] for i in range(len(nums) - 1)):\\n        # Find the minimum adjacent pair\\n        adjacent_pair = find_min_adjacent_pair(nums)\\n        # Replace the adjacent pair with their sum\\n        nums = replace_adjacent_pair_with_sum(nums, adjacent_pair)\\n        operations_count += 1\\n\\n    return operations_count\\n\\ndef make_array_non_decreasing(nums):\\n    \"\"\"\\n    Main function to initiate the process of making the array non-decreasing.\\n\\n    Args:\\n    nums (list): A list of integers.\\n\\n    Returns:\\n    int: The total count of operations performed.\\n    \"\"\"\\n    return count_operations_to_make_non_decreasing(nums)']\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "index out of range in self",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[32], line 34\u001b[0m\n\u001b[0;32m      1\u001b[0m test_task_description \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\"\"\u001b[39m\u001b[38;5;124mGiven an array nums, you can perform the following operation any number of times:\u001b[39m\n\u001b[0;32m      2\u001b[0m \n\u001b[0;32m      3\u001b[0m \u001b[38;5;124mSelect the adjacent pair with the minimum sum in nums. If multiple such pairs exist, choose the leftmost one.\u001b[39m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     30\u001b[0m \n\u001b[0;32m     31\u001b[0m \u001b[38;5;124mThe array nums is already sorted.\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m\n\u001b[0;32m     33\u001b[0m lcdp \u001b[38;5;241m=\u001b[39m LCDP(api_key\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg\u001b[39m\u001b[38;5;124m\"\u001b[39m, model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpt-4o-mini\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 34\u001b[0m best_codes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m lcdp\u001b[38;5;241m.\u001b[39mrun(\n\u001b[0;32m     35\u001b[0m     task_description\u001b[38;5;241m=\u001b[39mtest_task_description,\n\u001b[0;32m     36\u001b[0m     max_iterations\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m,\n\u001b[0;32m     37\u001b[0m     num_plans\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m,\n\u001b[0;32m     38\u001b[0m     num_tests\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[0;32m     39\u001b[0m     num_codes\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[0;32m     40\u001b[0m     refine_rounds\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m,\n\u001b[0;32m     41\u001b[0m     use_pass_rate_for_train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m     42\u001b[0m     \u001b[38;5;66;03m# use_example=True,\u001b[39;00m\n\u001b[0;32m     
43\u001b[0m     \u001b[38;5;66;03m# example_dataset=example_codes,\u001b[39;00m\n\u001b[0;32m     44\u001b[0m )\n\u001b[0;32m     45\u001b[0m \u001b[38;5;28mprint\u001b[39m(best_codes)\n",
      "Cell \u001b[1;32mIn[31], line 181\u001b[0m, in \u001b[0;36mLCDP.run\u001b[1;34m(self, task_description, max_iterations, example_dataset, num_plans, num_tests, num_codes, refine_rounds, use_pass_rate_for_train)\u001b[0m\n\u001b[0;32m    179\u001b[0m \u001b[38;5;66;03m# Evaluate codes\u001b[39;00m\n\u001b[0;32m    180\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEvaluating codes...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 181\u001b[0m scored_codes, filtered_test_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate_codes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_codes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    182\u001b[0m \u001b[38;5;66;03m# remove the test cases that are not in the filtered_test_result\u001b[39;00m\n\u001b[0;32m    183\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_cases \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_cases\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(filtered_test_result\u001b[38;5;241m.\u001b[39mkeys())}\n",
      "Cell \u001b[1;32mIn[31], line 396\u001b[0m, in \u001b[0;36mLCDP._evaluate_codes\u001b[1;34m(self, codes)\u001b[0m\n\u001b[0;32m    390\u001b[0m     input_data[code_id] \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m    391\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m: codes[code_id][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m    392\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_results\u001b[39m\u001b[38;5;124m'\u001b[39m: results,\n\u001b[0;32m    393\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_weights\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_weights\n\u001b[0;32m    394\u001b[0m     }\n\u001b[0;32m    395\u001b[0m \u001b[38;5;66;03m# Calculate scores\u001b[39;00m\n\u001b[1;32m--> 396\u001b[0m output_scores, full_score_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcalculate_batch_scores\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    397\u001b[0m \u001b[38;5;66;03m# Combine scores with code data\u001b[39;00m\n\u001b[0;32m    398\u001b[0m output_results \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m    399\u001b[0m     code_id: {\n\u001b[0;32m    400\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m: codes[code_id][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    410\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m code_id \u001b[38;5;129;01min\u001b[39;00m codes\u001b[38;5;241m.\u001b[39mkeys()\n\u001b[0;32m    411\u001b[0m }\n",
      "Cell \u001b[1;32mIn[31], line 46\u001b[0m, in \u001b[0;36mEvaluator.calculate_batch_scores\u001b[1;34m(self, code_data)\u001b[0m\n\u001b[0;32m     44\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m     45\u001b[0m         \u001b[38;5;28mprint\u001b[39m(code_strs)\n\u001b[1;32m---> 46\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[0;32m     48\u001b[0m \u001b[38;5;66;03m# 并行计算静态分析分数\u001b[39;00m\n\u001b[0;32m     49\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m concurrent\u001b[38;5;241m.\u001b[39mfutures\u001b[38;5;241m.\u001b[39mThreadPoolExecutor() \u001b[38;5;28;01mas\u001b[39;00m executor:\n",
      "Cell \u001b[1;32mIn[31], line 40\u001b[0m, in \u001b[0;36mEvaluator.calculate_batch_scores\u001b[1;34m(self, code_data)\u001b[0m\n\u001b[0;32m     38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpass_rate_predictor \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpass_rate_predictor\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m     39\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 40\u001b[0m         prediction_scores \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpass_rate_predictor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcode_strs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     41\u001b[0m         \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m###############################################################\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     42\u001b[0m         \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrediction scores: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprediction_scores\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "Cell \u001b[1;32mIn[30], line 194\u001b[0m, in \u001b[0;36mPassRatePredictor.predict_score\u001b[1;34m(self, new_code_samples, model, scaler)\u001b[0m\n\u001b[0;32m    192\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m    193\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m loader:\n\u001b[1;32m--> 194\u001b[0m         pred \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    195\u001b[0m         preds\u001b[38;5;241m.\u001b[39mextend(pred\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy())\n\u001b[0;32m    197\u001b[0m \u001b[38;5;66;03m# 反归一化\u001b[39;00m\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[1;32mIn[30], line 132\u001b[0m, in \u001b[0;36mGNNModel.forward\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m    130\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[0;32m    131\u001b[0m     x, edge_index, batch \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mx, data\u001b[38;5;241m.\u001b[39medge_index, data\u001b[38;5;241m.\u001b[39mbatch\n\u001b[1;32m--> 132\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    133\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x, edge_index)\n\u001b[0;32m    134\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(F\u001b[38;5;241m.\u001b[39mrelu(x))\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\sparse.py:163\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    162\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 163\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    164\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    165\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\functional.py:2264\u001b[0m, in \u001b[0;36membedding\u001b[1;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[0;32m   2258\u001b[0m     \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[0;32m   2259\u001b[0m     \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[0;32m   2260\u001b[0m     \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m   2261\u001b[0m     \u001b[38;5;66;03m#   torch.embedding_renorm_\u001b[39;00m\n\u001b[0;32m   2262\u001b[0m     \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[0;32m   2263\u001b[0m     _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[1;32m-> 2264\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mIndexError\u001b[0m: index out of range in self"
     ]
    }
   ],
   "source": [
    "test_task_description = \"\"\"Given an array nums, you can perform the following operation any number of times:\n",
    "\n",
    "Select the adjacent pair with the minimum sum in nums. If multiple such pairs exist, choose the leftmost one.\n",
    "Replace the pair with their sum.\n",
    "Return the minimum number of operations needed to make the array non-decreasing.\n",
    "\n",
    "An array is said to be non-decreasing if each element is greater than or equal to its previous element (if it exists).\n",
    "\n",
    " \n",
    "\n",
    "Example 1:\n",
    "\n",
    "Input: nums = [5,2,3,1]\n",
    "\n",
    "Output: 2\n",
    "\n",
    "Explanation:\n",
    "\n",
    "The pair (3,1) has the minimum sum of 4. After replacement, nums = [5,2,4].\n",
    "The pair (2,4) has the minimum sum of 6. After replacement, nums = [5,6].\n",
    "The array nums became non-decreasing in two operations.\n",
    "\n",
    "Example 2:\n",
    "\n",
    "Input: nums = [1,2,2]\n",
    "\n",
    "Output: 0\n",
    "\n",
    "Explanation:\n",
    "\n",
    "The array nums is already sorted.\"\"\"\n",
    "\n",
    "import os\n",
    "\n",
    "# The API key is read from the environment so that no credential is stored in\n",
    "# the notebook. Set OPENAI_API_KEY before starting the kernel; a missing\n",
    "# variable raises KeyError here instead of failing later inside the pipeline.\n",
    "# NOTE(review): the key that used to be hardcoded in this cell is exposed in\n",
    "# version history and should be revoked.\n",
    "lcdp = LCDP(api_key=os.environ[\"OPENAI_API_KEY\"], model=\"gpt-4o-mini\")\n",
    "best_codes = await lcdp.run(\n",
    "    task_description=test_task_description,\n",
    "    max_iterations=2,\n",
    "    num_plans=3,\n",
    "    num_tests=5,\n",
    "    num_codes=5,\n",
    "    refine_rounds=3,\n",
    "    use_pass_rate_for_train=False,\n",
    "    # use_example=True,\n",
    "    # example_dataset=example_codes,\n",
    ")\n",
    "print(best_codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1_1': {'test_type': 'correctness', 'purpose': 'We will create a test function to verify the correctness of the implementation that counts the number of operations required to make the array non-decreasing. We will use an input `[5, 2, 3, 1]` which is not sorted. The expected output for this input is `2` as demonstrated in the example. In the first operation, the pair `(3, 1)` will be replaced with `4`, resulting in `[5, 2, 4]`. The second operation will replace the pair `(2, 4)` with `6`, resulting in `[5, 6]`. After two operations, the array becomes non-decreasing.', 'test_function': 'def test_case(func):\\n    input_data = [5, 2, 3, 1]\\n    expected_output = 2\\n    result = func(input_data)\\n    return result == expected_output'}, 'test_case_2': {'test_type': 'edge_case', 'purpose': 'This test case will check how the function handles an empty array. An empty array is considered non-decreasing by definition because there are no elements to violate the rule. We expect the function to return 0 operations as no action is needed.', 'test_function': 'def test_case(func):\\n    nums = []\\n    expected_output = 0\\n    output = func(nums)\\n    return output == expected_output'}, 'test_case_3': {'test_type': 'correctness', 'purpose': 'This test case ensures that the function can handle an array that is already non-decreasing. Therefore, we will provide an array that is already sorted, such as [1, 2, 2, 3]. We expect the output of the function to be 0 since no operations are needed to make it non-decreasing.', 'test_function': 'def test_case(func):\\n    nums = [1, 2, 2, 3]\\n    expected_output = 0\\n    output = func(nums)\\n    return output == expected_output'}, 'test_case_1_3': {'test_type': 'correctness', 'purpose': 'To test the correctness of the function that counts the operations needed to make an array non-decreasing, I will create a test case where the input array is not sorted. 
I will use the input array [5, 2, 3, 1], which requires several operations to become non-decreasing. I expect the output to be 2, since the pairs (3, 1) and (2, 4) will be merged in two operations.', 'test_function': 'def test_case(func):\\n    input_data = [5, 2, 3, 1]\\n    expected_output = 2\\n    result = func(input_data)\\n    return result == expected_output'}, 'test_case_2_1': {'test_type': 'edge_case', 'purpose': 'For edge cases, I will test an empty array and a single-element array. An empty array should return 0 because no operations are needed. A single-element array (e.g., [1]) should also return 0 since it is trivially non-decreasing. This will ensure that the function can handle inputs at the bounds of the specified requirements.', 'test_function': 'def test_case(func):\\n    input_data_empty = []\\n    expected_output_empty = 0\\n    result_empty = func(input_data_empty)\\n    if result_empty != expected_output_empty:\\n        return False\\n    \\n    input_data_single = [1]\\n    expected_output_single = 0\\n    result_single = func(input_data_single)\\n    \\n    return result_single == expected_output_single'}, 'test_case_3_1': {'test_type': 'runtime', 'purpose': 'To assess runtime performance, I will generate a large random array of integers (e.g., 10,000 elements) and measure the time it takes for the function to execute. The expectation is that the function should complete in a reasonable time frame (e.g., under 1 second). 
This test will determine if the function maintains efficiency with larger inputs.', 'test_function': 'import time\\nimport random\\n\\ndef test_case(func):\\n    input_data = [random.randint(-100, 100) for _ in range(10000)]\\n    start_time = time.time()\\n    func(input_data)\\n    end_time = time.time()\\n    execution_time = end_time - start_time\\n    return execution_time < 1  # Expecting it to run under 1 second'}, 'test_case_1_4': {'test_type': 'correctness', 'purpose': 'In this test case, I will validate the function with an already sorted array. The input will be [1, 2, 2], which is already non-decreasing. Since no operations are needed, the expected output should be 0. This test case will confirm that the function handles sorted inputs correctly and identifies that no changes are necessary.', 'test_function': 'def test_case(func):\\n    # Test input that is already sorted\\n    nums = [1, 2, 2]\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_2_2': {'test_type': 'correctness', 'purpose': 'In this test case, I will validate the function using an array that is in reverse order ([5, 4, 3, 2, 1]). This array requires complete reconstruction to become non-decreasing. The expected output is 4, as four operations are required: (5,4), (5,3), (5,2), and (5,1) will all be summed to form [5, 10, 15]. This test checks if the function can transform a completely descending list into a non-decreasing one efficiently.', 'test_function': 'def test_case(func):\\n    # Test reverse sorted input\\n    nums = [5, 4, 3, 2, 1]\\n    expected_output = 4\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_3_2': {'test_type': 'correctness', 'purpose': 'This test case will examine how the function handles an array with repeated elements. The input will be [2, 2, 1, 3]. Starting with this array, the expected output should be 1 because the pair (2,1) will be summed to form [2,3,3]. 
This case ensures that the function can handle duplicates appropriately while still maintaining order.', 'test_function': 'def test_case(func):\\n    # Test input with repeated elements\\n    nums = [2, 2, 1, 3]\\n    expected_output = 1\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_4_1': {'test_type': 'edge_case', 'purpose': 'In this test case, I will validate the function with an empty list. The expected output should be 0 since there are no elements to operate on, so the requirement for a non-decreasing order is trivially satisfied. This checks the edge case where no data is present.', 'test_function': 'def test_case(func):\\n    # Test empty input\\n    nums = []\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_5_1': {'test_type': 'edge_case', 'purpose': 'This test case will focus on a single-element array, e.g., [7]. The array is already non-decreasing, and thus the expected output should be 0. This checks the functionality when the input list contains the minimum number of elements.', 'test_function': 'def test_case(func):\\n    # Test single element input\\n    nums = [7]\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}}\n",
      "Test ID: test_case_1_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data = [5, 2, 3, 1]\n",
      "    expected_output = 2\n",
      "    result = func(input_data)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    nums = []\n",
      "    expected_output = 0\n",
      "    output = func(nums)\n",
      "    return output == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_3\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    nums = [1, 2, 2, 3]\n",
      "    expected_output = 0\n",
      "    output = func(nums)\n",
      "    return output == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_1_3\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data = [5, 2, 3, 1]\n",
      "    expected_output = 2\n",
      "    result = func(input_data)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data_empty = []\n",
      "    expected_output_empty = 0\n",
      "    result_empty = func(input_data_empty)\n",
      "    if result_empty != expected_output_empty:\n",
      "        return False\n",
      "    \n",
      "    input_data_single = [1]\n",
      "    expected_output_single = 0\n",
      "    result_single = func(input_data_single)\n",
      "    \n",
      "    return result_single == expected_output_single\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_3_1\n",
      "Test Case:\n",
      "import time\n",
      "import random\n",
      "\n",
      "def test_case(func):\n",
      "    input_data = [random.randint(-100, 100) for _ in range(10000)]\n",
      "    start_time = time.time()\n",
      "    func(input_data)\n",
      "    end_time = time.time()\n",
      "    execution_time = end_time - start_time\n",
      "    return execution_time < 1  # Expecting it to run under 1 second\n",
      "Test Type: runtime\n",
      "----------------------------------------\n",
      "Test ID: test_case_1_4\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test input that is already sorted\n",
      "    nums = [1, 2, 2]\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test reverse sorted input\n",
      "    nums = [5, 4, 3, 2, 1]\n",
      "    expected_output = 4\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_3_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test input with repeated elements\n",
      "    nums = [2, 2, 1, 3]\n",
      "    expected_output = 1\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_4_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test empty input\n",
      "    nums = []\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_5_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test single element input\n",
      "    nums = [7]\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Show the raw mapping first, then each generated test in a readable layout.\n",
    "test_cases = lcdp.test_cases\n",
    "print(test_cases)\n",
    "divider = \"-\" * 40\n",
    "for test_id, test_case_dict in test_cases.items():\n",
    "    # One print call per test; sep=\"\\n\" yields the same four-line layout.\n",
    "    print(\n",
    "        f\"Test ID: {test_id}\",\n",
    "        f\"Test Case:\\n{test_case_dict['test_function']}\",\n",
    "        f\"Test Type: {test_case_dict['test_type']}\",\n",
    "        divider,\n",
    "        sep=\"\\n\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['code_0', 'code_1', 'code_4', 'code_3'])\n",
      "dict_keys(['code', 'plan', 'main_function_name', 'score', 'test_case_results'])\n"
     ]
    }
   ],
   "source": [
    "checking_code = \"code_0\"\n",
    "print(best_codes.keys())\n",
    "print(best_codes[checking_code].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1_1': True, 'test_case_2': True, 'test_case_3': True, 'test_case_1_3': True, 'test_case_2_1': True, 'test_case_3_1': False, 'test_case_1_4': True, 'test_case_2_2': True, 'test_case_3_2': True, 'test_case_4_1': True, 'test_case_5_1': True}\n",
      "10\n",
      "11\n"
     ]
    }
   ],
   "source": [
    "# Pass/fail flag per test for the selected candidate, then passed count and total.\n",
    "test_case_results = best_codes[checking_code][\"test_case_results\"]\n",
    "print(test_case_results)\n",
    "# True counts as 1 when summed, so this is the number of passing tests.\n",
    "print(sum(test_case_results.values()))\n",
    "print(len(test_case_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def count_operations_to_sort(nums):\n",
      "    def find_min_adjacent_pair(nums):\n",
      "        min_sum = float('inf')\n",
      "        index = -1\n",
      "        \n",
      "        # Iterate through the array to find the minimum sum of adjacent pairs\n",
      "        for i in range(len(nums) - 1):\n",
      "            current_sum = nums[i] + nums[i + 1]\n",
      "            if current_sum < min_sum:\n",
      "                min_sum = current_sum\n",
      "                index = i\n",
      "        \n",
      "        return index\n",
      "\n",
      "    def replace_pair_with_sum(nums, index):\n",
      "        # Calculate the sum of the identified pair\n",
      "        sum_pair = nums[index] + nums[index + 1]\n",
      "        # Create a new list with the pair replaced by their sum\n",
      "        new_nums = nums[:index] + [sum_pair] + nums[index + 2:]\n",
      "        return new_nums\n",
      "\n",
      "    operation_count = 0\n",
      "\n",
      "    # Keep performing operations until the array is non-decreasing\n",
      "    while any(nums[i] > nums[i + 1] for i in range(len(nums) - 1)):\n",
      "        index = find_min_adjacent_pair(nums)\n",
      "        nums = replace_pair_with_sum(nums, index)\n",
      "        operation_count += 1\n",
      "\n",
      "    return operation_count\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"code\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"main_function_name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9685026303719275\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"score\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
