{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import asyncio\n",
    "import random\n",
    "import os\n",
    "import logging\n",
    "import concurrent.futures\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "from tqdm.asyncio import tqdm_asyncio\n",
    "\n",
    "from openai import OpenAI\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "# Module-level logger named after this module, at DEBUG verbosity.\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "# Record layout shared by both handlers. The date format is set here on the\n",
    "# formatter: logging.basicConfig() only applies `datefmt` to handlers that have\n",
    "# no formatter yet, so passing it there had no effect on the handlers below.\n",
    "formatter = logging.Formatter(\n",
    "    '%(asctime)s - %(name)s - %(levelname)s - %(message)s',\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    ")\n",
    "\n",
    "# File handler: one timestamped log file per run under ./log, DEBUG and above.\n",
    "log_dir = Path(\"log\")\n",
    "log_dir.mkdir(exist_ok=True)\n",
    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "log_filename = f\"{timestamp}.log\"\n",
    "log_path = log_dir / log_filename\n",
    "# Explicit UTF-8: some log messages contain non-ASCII text, which the platform\n",
    "# default encoding may not be able to write.\n",
    "file_handler = logging.FileHandler(log_path, encoding=\"utf-8\")\n",
    "file_handler.setLevel(logging.DEBUG)\n",
    "file_handler.setFormatter(formatter)\n",
    "\n",
    "# Console handler: INFO and above only.\n",
    "console_handler = logging.StreamHandler()\n",
    "console_handler.setLevel(logging.INFO)\n",
    "console_handler.setFormatter(formatter)\n",
    "\n",
    "# Attach both handlers to the root logger (the classes below log via `logging.*`).\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    handlers=[\n",
    "        file_handler,    # file output\n",
    "        console_handler  # console output\n",
    "    ]\n",
    ")\n",
    "\n",
    "class LLMModel():\n",
    "    \"\"\"Thin wrapper around the OpenAI chat-completions API with a sync and an async client.\"\"\"\n",
    "\n",
    "    def __init__(self, api_key, model=\"gpt-3.5-turbo\"):\n",
    "        \"\"\"Create the OpenAI clients.\n",
    "\n",
    "        Args:\n",
    "            api_key: OpenAI API key. When None, the key is read from the\n",
    "                OPENAI_API_KEY environment variable instead.\n",
    "            model: Model name used when a call does not specify one.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If api_key is None and OPENAI_API_KEY is not set.\n",
    "        \"\"\"\n",
    "        if api_key is None:\n",
    "            # Credentials must never be embedded in the notebook; read them from the environment.\n",
    "            api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
    "            if not api_key:\n",
    "                raise ValueError(\"api_key is None and the OPENAI_API_KEY environment variable is not set\")\n",
    "        self.api_key = api_key\n",
    "        self.model = model\n",
    "        self.client = OpenAI(api_key=self.api_key)\n",
    "        self.client_async = AsyncOpenAI(api_key=self.api_key)\n",
    "\n",
    "    @staticmethod\n",
    "    def _build_messages(prompt):\n",
    "        \"\"\"Turn a prompt into the `messages` list sent to the chat-completions API.\n",
    "\n",
    "        A string becomes a single user message; a list is passed through unchanged.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If prompt is neither a string nor a list.\n",
    "        \"\"\"\n",
    "        if isinstance(prompt, str):\n",
    "            return [{\"role\": \"user\", \"content\": prompt}]\n",
    "        if isinstance(prompt, list):\n",
    "            return prompt\n",
    "        # %s placeholder is required: logging formats lazily with the extra argument.\n",
    "        logging.error(\"prompt must be a string or a list of messages, current type: %s\", type(prompt))\n",
    "        raise ValueError(\"prompt must be a string or a list of messages\")\n",
    "\n",
    "    def LLM_response(self, prompt, gen_kwargs=None, model=None):\n",
    "        \"\"\"Send a chat-completion request synchronously.\n",
    "\n",
    "        Args:\n",
    "            prompt: A string (sent as one user message) or a list of chat messages.\n",
    "            gen_kwargs: Extra keyword arguments forwarded to `chat.completions.create`.\n",
    "            model: Model name; defaults to self.model.\n",
    "\n",
    "        Returns:\n",
    "            The content of the first returned choice.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If prompt is neither a string nor a list.\n",
    "        \"\"\"\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "        input_messages = self._build_messages(prompt)\n",
    "\n",
    "        completion = self.client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=input_messages,\n",
    "            **(gen_kwargs or {})\n",
    "            )\n",
    "        return completion.choices[0].message.content\n",
    "\n",
    "    async def LLM_response_async(self, prompt, gen_kwargs=None, model=None):\n",
    "        \"\"\"Asynchronous counterpart of LLM_response; same arguments, return value and errors.\"\"\"\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "        # The string prompt is sent as-is (joining it with spaces would split it per character).\n",
    "        input_messages = self._build_messages(prompt)\n",
    "\n",
    "        completion = await self.client_async.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=input_messages,\n",
    "            **(gen_kwargs or {})\n",
    "            )\n",
    "        return completion.choices[0].message.content\n",
    "\n",
    "class LLMTM():\n",
    "    def __init__(self, task_description, LLM_model):\n",
    "        \"\"\"Keep the default task description and the LLM backend used by the other methods.\"\"\"\n",
    "        self.task_description, self.LLM_model = task_description, LLM_model\n",
    "\n",
    "    def create_plan_prompt(self, task_description=None):\n",
    "        \"\"\"Build the task-decomposition prompt for a task description.\n",
    "\n",
    "        Args:\n",
    "            task_description: Task text inserted into the prompt; when None,\n",
    "                self.task_description is used.\n",
    "\n",
    "        Returns:\n",
    "            The prompt string, asking for a <components> block and an\n",
    "            <overall_plan> block, with {{TASK_DESCRIPTION}} substituted.\n",
    "        \"\"\"\n",
    "\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        # Template text; {{TASK_DESCRIPTION}} is the only placeholder and is filled in below.\n",
    "        task_decompose_prompt = \"\"\"You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.\n",
    "\n",
    "Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.\n",
    "\n",
    "Your output must strictly follow the format below:\n",
    "\n",
    "<components>\n",
    "{\n",
    "  \"component_1\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[type, shape or null]],\n",
    "    \"output_format\": [[type, shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  \"component_2\": {\n",
    "    \"step_task_description\": str,\n",
    "    \"input_format\": [[\"type\", shape or null]],\n",
    "    \"output_format\": [[\"type\", shape or null]],\n",
    "    \"work_flow\": [str],\n",
    "    \"test_case_generation_advise\": [str]\n",
    "  },\n",
    "  ...\n",
    "}\n",
    "</components>\n",
    "\n",
    "<overall_plan>\n",
    "{\n",
    "  \"input_format\": [[\"type\", shape or null]],\n",
    "  \"output_format\": [[\"type\", shape or null]],\n",
    "  \"components\": [str],\n",
    "  \"plan\": [str],\n",
    "  \"test_case_generation_advise\": [str]\n",
    "}\n",
    "</overall_plan>\n",
    "\n",
    "Here are additional detailed explanations of each field:\n",
    "\n",
    "For <components>:\n",
    "- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., \"merge_arrays\", \"calculate_median\").\n",
    "- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).\n",
    "- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:\n",
    "  - The first element indicates the data type (e.g., \"list\", \"dict\", NumPy array, torch.Tensor). DO make sure the data type is a string.\n",
    "  - The second element indicates the fixed shape if applicable; otherwise, it is null.\n",
    "- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`, note that it has to be a list of lists.\n",
    "- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "For <overall_plan>:\n",
    "- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.\n",
    "- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.\n",
    "- **components**: List the components in the order.\n",
    "- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.\n",
    "- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.\n",
    "\n",
    "Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:\n",
    "\n",
    "\"{{TASK_DESCRIPTION}}\"\n",
    "\n",
    "Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.\n",
    "\n",
    "Please provide your structured decomposition according to the instructions above.\n",
    "\"\"\"\n",
    "        # Plain str.replace is used instead of str.format because the template contains literal JSON braces.\n",
    "        task_decompose_prompt = task_decompose_prompt.replace(\"{{TASK_DESCRIPTION}}\", task_description)\n",
    "        return task_decompose_prompt\n",
    "    \n",
    "    def create_plan_refinement_prompt(self, user_feedback, previous_output, task_description=None):\n",
    "        \"\"\"Build the prompt that asks the LLM to revise a previous decomposition.\n",
    "\n",
    "        Args:\n",
    "            user_feedback: Feedback text substituted for {{USER_ADVICE}} (it appears twice in the template).\n",
    "            previous_output: Text substituted for {{PREVIOUS_OUTPUT}}.\n",
    "            task_description: Text substituted for {{TASK_DESCRIPTION}}; when None,\n",
    "                self.task_description is used.\n",
    "\n",
    "        Returns:\n",
    "            The refinement prompt string with all three placeholders filled in.\n",
    "        \"\"\"\n",
    "\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        plan_refinement_prompt = \"\"\"You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.\n",
    "\n",
    "Carefully review the previous components and overall plan, then:\n",
    "1. Preserve correct/valid elements that don't conflict with the feedback\n",
    "2. Make targeted modifications based on the user's specific advice\n",
    "3. Ensure consistency between components and overall plan\n",
    "4. Verify input/output formats and workflow logic\n",
    "5. Check for any introduced errors during modification\n",
    "\n",
    "The input consists of three elements:\n",
    "- Original Task Description: \"{{TASK_DESCRIPTION}}\"\n",
    "- Previous Decomposition Output: \n",
    "{{PREVIOUS_OUTPUT}}\n",
    "- User Feedback: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Your output must STRICTLY follow the original format with these sections:\n",
    "<components>...</components>\n",
    "<overall_plan>...</overall_plan>\n",
    "\n",
    "Follow these guidelines:\n",
    "- Explicitly address all points in the user feedback\n",
    "- Clearly document any changes made from previous version\n",
    "- Preserve JSON structure and formatting requirements\n",
    "- If feedback contradicts original requirements, prioritize feedback\n",
    "\n",
    "Again, user feedback is: \"{{USER_ADVICE}}\"\n",
    "\n",
    "Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback.\"\"\"\n",
    "\n",
    "        # Placeholders are filled in this order; text inserted by an earlier replace is\n",
    "        # itself scanned by the later ones.\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{TASK_DESCRIPTION}}\", task_description)\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{USER_ADVICE}}\", user_feedback)\n",
    "        plan_refinement_prompt = plan_refinement_prompt.replace(\"{{PREVIOUS_OUTPUT}}\", previous_output)\n",
    "        return plan_refinement_prompt\n",
    "\n",
    "    def extract_plan(self, input_str):\n",
    "        \"\"\"Extract the <components> and <overall_plan> JSON blocks from raw LLM output.\n",
    "\n",
    "        Args:\n",
    "            input_str: Raw text returned by the LLM.\n",
    "\n",
    "        Returns:\n",
    "            A dict mapping each block name found (\"components\", \"overall_plan\")\n",
    "            to its parsed JSON value, or False when no tagged block is found or\n",
    "            any block cannot be parsed.\n",
    "        \"\"\"\n",
    "        # Match <tag>...</tag> for the two known tags; DOTALL lets the content span lines.\n",
    "        pattern = r'<(components|overall_plan)>(.*?)</\\1>'\n",
    "        matches = re.findall(pattern, input_str, re.DOTALL)\n",
    "\n",
    "        result = {}\n",
    "        for block_name, content in matches:\n",
    "            cleaned_content = content.strip()\n",
    "            try:\n",
    "                # Parse as-is first, so valid JSON whose string values happen to\n",
    "                # contain \", }\" or \", ]\" is never rewritten by the repair step.\n",
    "                result[block_name] = json.loads(cleaned_content)\n",
    "                continue\n",
    "            except json.JSONDecodeError:\n",
    "                pass\n",
    "\n",
    "            # Repair pass: drop trailing commas before a closing brace/bracket, then retry.\n",
    "            repaired_content = re.sub(r',\\s*}', '}', cleaned_content)\n",
    "            repaired_content = re.sub(r',\\s*\\]', ']', repaired_content)\n",
    "            try:\n",
    "                result[block_name] = json.loads(repaired_content)\n",
    "            except json.JSONDecodeError as e:\n",
    "                logging.warning(f\"JSON解析错误: {block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}\")\n",
    "                return False\n",
    "\n",
    "        # No tagged block at all is a failed extraction too; return False rather than\n",
    "        # an empty dict so that callers checking `is False` detect it.\n",
    "        if not result:\n",
    "            return False\n",
    "        return result\n",
    "\n",
    "    def get_plan(self, task_description=None, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Ask the LLM for a task decomposition and parse it, retrying on extraction failure.\n",
    "\n",
    "        Args:\n",
    "            task_description: Task text; defaults to self.task_description.\n",
    "            gen_kwargs: Extra keyword arguments forwarded to the LLM call.\n",
    "            max_retry: Retries after the first attempt (max_retry + 1 calls at most).\n",
    "\n",
    "        Returns:\n",
    "            (extracted_plan, llm_output): the parsed plan dict and the raw LLM text it came from.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt produced a non-empty parsed plan.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_plan_prompt(task_description)\n",
    "\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            extract_plan = self.extract_plan(llm_output)\n",
    "            # Any falsy result (False or an empty dict) counts as a failed extraction.\n",
    "            if extract_plan:\n",
    "                return extract_plan, llm_output\n",
    "            logging.warning(f\"Failed to extract plan, retrying ({attempt})...\")\n",
    "\n",
    "        logging.error(f\"Failed to extract plan, current llm_output:\\n{llm_output}\")\n",
    "        raise ValueError(f\"Failed to extract plan, current llm_output:\\n{llm_output}\")\n",
    "    \n",
    "    async def get_plan_async(self, num_plan, task_description=None, gen_kwargs={}):\n",
    "        \"\"\"Request `num_plan` decompositions of the same task concurrently.\n",
    "\n",
    "        Returns a list with one entry per request, each being the result of\n",
    "        extract_plan on that response. No retry is performed here.\n",
    "        \"\"\"\n",
    "        description = self.task_description if task_description is None else task_description\n",
    "        prompt = self.create_plan_prompt(description)\n",
    "\n",
    "        # One pending LLM request per requested plan, all using the same prompt.\n",
    "        pending = []\n",
    "        for _ in range(num_plan):\n",
    "            pending.append(self.LLM_model.LLM_response_async(prompt, gen_kwargs))\n",
    "\n",
    "        raw_outputs = await tqdm_asyncio.gather(*pending)\n",
    "        plans = []\n",
    "        for raw in raw_outputs:\n",
    "            plans.append(self.extract_plan(raw))\n",
    "        return plans\n",
    "    \n",
    "    def refine_plan(self, user_feedback, previous_output, task_description=None, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Ask the LLM to revise a previous decomposition using user feedback, retrying on extraction failure.\n",
    "\n",
    "        Args:\n",
    "            user_feedback: Feedback text inserted into the refinement prompt.\n",
    "            previous_output: Previous decomposition text inserted into the prompt.\n",
    "            task_description: Task text; defaults to self.task_description.\n",
    "            gen_kwargs: Extra keyword arguments forwarded to the LLM call.\n",
    "            max_retry: Retries after the first attempt (max_retry + 1 calls at most).\n",
    "\n",
    "        Returns:\n",
    "            (extracted_plan, llm_output): the parsed plan dict and the raw LLM text it came from.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If no attempt produced a non-empty parsed plan.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            gen_kwargs = {}\n",
    "        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)\n",
    "\n",
    "        llm_output = None\n",
    "        for attempt in range(1, max_retry + 2):\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            extract_plan = self.extract_plan(llm_output)\n",
    "            # Any falsy result (False or an empty dict) counts as a failed extraction.\n",
    "            if extract_plan:\n",
    "                return extract_plan, llm_output\n",
    "            logging.warning(f\"Failed to extract plan, retrying ({attempt})...\")\n",
    "\n",
    "        logging.error(f\"Failed to extract plan, current llm_output:\\n{llm_output}\")\n",
    "        raise ValueError(f\"Failed to extract plan, current llm_output:\\n{llm_output}\")\n",
    "    \n",
    "    async def refine_multi_plan(self, num_plan, user_feedback, previous_output, task_description=None, gen_kwargs={}):\n",
    "        \"\"\"Request `num_plan` refined decompositions concurrently from one refinement prompt.\n",
    "\n",
    "        Returns a list with one entry per request, each being the result of\n",
    "        extract_plan on that response. No retry is performed here.\n",
    "        \"\"\"\n",
    "        description = self.task_description if task_description is None else task_description\n",
    "        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, description)\n",
    "\n",
    "        # One pending LLM request per requested plan, all using the same prompt.\n",
    "        pending = []\n",
    "        for _ in range(num_plan):\n",
    "            pending.append(self.LLM_model.LLM_response_async(prompt, gen_kwargs))\n",
    "\n",
    "        raw_outputs = await tqdm_asyncio.gather(*pending)\n",
    "        plans = []\n",
    "        for raw in raw_outputs:\n",
    "            plans.append(self.extract_plan(raw))\n",
    "        return plans\n",
    "    \n",
    "    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):\n",
    "        \"\"\"\n",
    "        Generates a prompt (or list of prompts) for test case generation based on task specifications.\n",
    "        \n",
    "        Parameters:\n",
    "        - task_descr_str (str): Task description inserted verbatim into the prompt.\n",
    "        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.\n",
    "          input_format and output_format are iterated as (dtype, shape) pairs.\n",
    "        - use_example (bool): Currently has no effect; both branches use an empty example text.\n",
    "        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.\n",
    "        \n",
    "        Returns:\n",
    "        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).\n",
    "        \"\"\"\n",
    "        \n",
    "        # Helper that renders the full prompt text from the already-formatted sections\n",
    "        # and a list of advisories (one bullet per advisory).\n",
    "        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = \"\"):\n",
    "            advisory_list = \"\\n\".join([f\"- {advise}\" for advise in advisories])\n",
    "            prompt = f\"\"\"You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
    "\n",
    "### Input Specifications:\n",
    "- **Task Description**:\n",
    "{task_descr_str}\n",
    "- **Input Format**: \n",
    "{input_descr_str}\n",
    "- **Output Format**: \n",
    "{output_descr_str}\n",
    "- **Components Used**: {components_str}\n",
    "- **Plan**: \n",
    "{plan_str}\n",
    "- **Test Case Advise**: \n",
    "{advisory_list}\n",
    "\n",
    "### Requirements:\n",
    "1. **Test Function Structure**:\n",
    "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):...`).\n",
    "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
    "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
    "\n",
    "2. **Test Types** (use one of these for indicating the test_type):\n",
    "   - `correctness`: Validate output against expected results for specific inputs.\n",
    "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
    "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
    "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
    "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
    "\n",
    "3. **Test Case Diversity**:\n",
    "   - Cover all provided advisories.\n",
    "   - Include at least one test per advisory and one for each test type where applicable.\n",
    "\n",
    "### Output Format:\n",
    "For each test case, you need to firstly define the Test Types to indicate what type of test case you are going to create and then give the reasoning and explanation of the test case. After that, generate the test function based on the your reasoning.\n",
    "\n",
    "For each test function, return with following structure:\n",
    "\n",
    "<Type>\n",
    "Pick one of correctness|edge_case|runtime|component_check|error_handling\n",
    "</Type>\n",
    "<Planning>\n",
    "Introduce how would you design the test function. Specify the purpose of the test function and the reasoning behind it. Explain step by step why your test case is correct and what is the expected output.\n",
    "</Planning>\n",
    "<Code>\n",
    "def test_case(func):\n",
    "    # Your test function code here\n",
    "</Code>\n",
    "\n",
    "If you are going to create multiple test cases, please separate them with <separator> tag.\n",
    "\n",
    "{example_text}\n",
    "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\n",
    "You MUST strictly follow the output format and structure. The generated test functions MUST be runnable function that use another python function as its parameter.\"\"\"\n",
    "            return prompt\n",
    "\n",
    "        # NOTE(review): both branches assign an empty string, so use_example is a\n",
    "        # no-op placeholder at the moment.\n",
    "        if use_example:\n",
    "            examples_text = \"\"\n",
    "        else:\n",
    "            examples_text = \"\"\n",
    "\n",
    "        # Render input_format as one bullet per argument (1-based index)\n",
    "        input_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_descr.append(f\"- Argument {idx}: {dtype} ({shape_info})\")\n",
    "        input_descr_str = \"\\n\".join(input_descr)\n",
    "\n",
    "        # Render output_format the same way, one bullet per output\n",
    "        output_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_descr.append(f\"- Output {idx}: {dtype} ({shape_info})\")\n",
    "        output_descr_str = \"\\n\".join(output_descr)\n",
    "\n",
    "        # Components become a comma-separated list; plan steps become one line each\n",
    "        components_str = \", \".join(task_spec['components'])\n",
    "        plan_str = \"\\n\".join(task_spec['plan'])\n",
    "\n",
    "        if bulk:\n",
    "            # Generate a single prompt with all advisories\n",
    "            advisories = task_spec['test_case_generation_advise']\n",
    "            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)\n",
    "        else:\n",
    "            # Generate a list of prompts, each with a single advisory\n",
    "            prompts = []\n",
    "            for advise in task_spec['test_case_generation_advise']:\n",
    "                single_advisory = [advise]\n",
    "                prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)\n",
    "                prompts.append(prompt)\n",
    "            return prompts\n",
    "\n",
    "    def extract_test_cases(self, output_text):\n",
    "        \"\"\"\n",
    "        Extracts test cases from LLM output text with flexible tag handling.\n",
    "        Supports case-insensitive tags, missing <Type> tags, and multi-separators.\n",
    "\n",
    "        The text is split into blocks on <separator> tags or blank lines; newlines\n",
    "        inside <code>...</code> and ```python fences are hidden during the split so\n",
    "        code is not broken apart. A block is kept only if both a test type and a\n",
    "        test function are found in it.\n",
    "\n",
    "        Returns:\n",
    "            dict keyed 'test_case_<block index>' whose values hold 'test_type',\n",
    "            'purpose' and 'test_function', or False when no test case is extracted.\n",
    "        \"\"\"\n",
    "        import re\n",
    "        test_cases = {}\n",
    "\n",
    "        def preprocess_text(text):\n",
    "            # Placeholder standing in for newlines inside code blocks; chosen to be\n",
    "            # unlikely to appear in ordinary text.\n",
    "            placeholder = \"###NL###\"\n",
    "            \n",
    "            # Replacement callback: swap every newline in the matched block for the placeholder.\n",
    "            def repl_code(match):\n",
    "                block = match.group(0)\n",
    "                return block.replace(\"\\n\", placeholder)\n",
    "            \n",
    "            # Protect <code>...</code> blocks (case-insensitive, spanning lines).\n",
    "            text = re.sub(r'(<\\s*code\\s*>.*?</\\s*code\\s*>)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)\n",
    "            # Protect ```python ... ``` fenced blocks.\n",
    "            text = re.sub(r'(```python.*?```)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)\n",
    "            \n",
    "            # Other block types that need protecting could be handled here in the same way.\n",
    "            return text, placeholder\n",
    "\n",
    "        # Pre-processing: hide newlines inside code blocks.\n",
    "        modified_text, placeholder = preprocess_text(output_text)\n",
    "        \n",
    "        # Split into blocks on <separator> tags (open, close or self-closing) or on blank lines.\n",
    "        split_pattern = r'(?:<\\s*/\\s*separator\\s*>|<\\s*separator\\s*>|<\\s*separator\\s*/>|\\n\\s*\\n\\s*)'\n",
    "        test_case_blocks = re.split(split_pattern, modified_text, flags=re.IGNORECASE)\n",
    "        test_case_blocks = [b.strip() for b in test_case_blocks if b.strip()]\n",
    "        \n",
    "        # Restore the hidden newlines in every block.\n",
    "        test_case_blocks = [b.replace(placeholder, \"\\n\") for b in test_case_blocks]\n",
    "\n",
    "        for idx, block in enumerate(test_case_blocks, 1):\n",
    "            # 1. Extract test_type\n",
    "            test_type = None\n",
    "            \n",
    "            # Case 1: explicit <type>value</type>\n",
    "            type_match = re.search(\n",
    "                r'<\\s*type\\s*>(.*?)<\\s*/\\s*type\\s*>', \n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            if type_match:\n",
    "                test_type = type_match.group(1).strip()\n",
    "            else:\n",
    "                # Case 2: fall back to the first tag whose name is not one of the known tags\n",
    "                # and use that tag name as the type.\n",
    "                known_tags = {'type', 'planning', 'code', 'reasoning', 'test_function', 'separator'}\n",
    "                for tag_match in re.finditer(r'<\\s*([^\\s>/]+)\\s*.*?>', block, re.IGNORECASE):\n",
    "                    tag_name = tag_match.group(1).lower()\n",
    "                    if tag_name not in known_tags:\n",
    "                        test_type = tag_name\n",
    "                        break  # take the first tag that is not a known one\n",
    "                \n",
    "            if not test_type:  # skip blocks without a test_type\n",
    "                continue\n",
    "            \n",
    "            # 2. Extract reasoning (accepts <planning> or <reasoning>)\n",
    "            reasoning_match = re.search(\n",
    "                r'<\\s*(?:reasoning|planning)\\s*>(.*?)<\\s*/\\s*(?:reasoning|planning)\\s*>',\n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            reasoning = reasoning_match.group(1).strip() if reasoning_match else \"\"\n",
    "            \n",
    "            # 3. Extract test_function (priority: <test_function> tag > <code> tag > bare fenced block)\n",
    "            test_func = None\n",
    "            \n",
    "            # Look for a <test_function> tag\n",
    "            test_func_match = re.search(\n",
    "                r'<\\s*test_function\\s*>(.*?)<\\s*/\\s*test_function\\s*>',\n",
    "                block, \n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            if test_func_match:\n",
    "                content = test_func_match.group(1).strip()\n",
    "                # If the tag wraps a ```python fence, keep only the fenced code.\n",
    "                code_block = re.search(r'```python\\s*(.*?)\\s*```', content, re.DOTALL)\n",
    "                test_func = code_block.group(1).strip() if code_block else content\n",
    "            else:\n",
    "                # Look for a <code> tag\n",
    "                code_match = re.search(\n",
    "                    r'<\\s*code\\s*>(.*?)<\\s*/\\s*code\\s*>',\n",
    "                    block,\n",
    "                    re.IGNORECASE | re.DOTALL\n",
    "                )\n",
    "                if code_match:\n",
    "                    content = code_match.group(1).strip()\n",
    "                    # If the tag wraps a ```python fence, keep only the fenced code.\n",
    "                    code_block = re.search(r'```python\\s*(.*?)\\s*```', content, re.DOTALL)\n",
    "                    test_func = code_block.group(1).strip() if code_block else content\n",
    "                else:\n",
    "                    # Look for a bare fenced block (```python ... ```)\n",
    "                    code_block = re.search(r'```python\\s*(.*?)\\s*```', block, re.DOTALL)\n",
    "                    if code_block:\n",
    "                        test_func = code_block.group(1).strip()\n",
    "            \n",
    "            # Keys use the block index, so numbering can have gaps when blocks are skipped.\n",
    "            if test_type and test_func:\n",
    "                test_cases[f'test_case_{idx}'] = {\n",
    "                    'test_type': test_type,\n",
    "                    'purpose': reasoning,\n",
    "                    'test_function': test_func\n",
    "                }\n",
    "        \n",
    "        if not test_cases:\n",
    "            # Nothing could be extracted: signal failure with False\n",
    "            return False\n",
    "\n",
    "        return test_cases\n",
    "    \n",
    "    def get_test_cases(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Generate test cases for a task specification, retrying when nothing can be extracted.\n",
    "\n",
    "        Args:\n",
    "            task_spec: Plan dictionary passed to create_test_prompt.\n",
    "            task_description: Task text; defaults to self.task_description.\n",
    "            use_example: Forwarded to create_test_prompt.\n",
    "            bulk: True sends one prompt covering every advisory; False sends one\n",
    "                prompt per advisory and merges the extracted test cases.\n",
    "            gen_kwargs: Extra keyword arguments forwarded to the LLM call.\n",
    "            max_retry: Retries after the first attempt, per prompt.\n",
    "\n",
    "        Returns:\n",
    "            dict of test cases as produced by extract_test_cases. With bulk=False\n",
    "            the cases of all prompts are merged and renumbered test_case_1..N in\n",
    "            prompt order.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If a prompt yields no test case after all attempts.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            gen_kwargs = {}\n",
    "        prompts = self.create_test_prompt(task_description, task_spec, use_example, bulk)\n",
    "\n",
    "        def _generate(prompt):\n",
    "            # Run the request/extract loop for one prompt string.\n",
    "            llm_output = None\n",
    "            for attempt in range(1, max_retry + 2):\n",
    "                llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "                test_cases = self.extract_test_cases(llm_output)\n",
    "                if test_cases:\n",
    "                    return test_cases\n",
    "                logging.warning(f\"Failed to extract test cases, retrying ({attempt})...\")\n",
    "            logging.error(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "            raise ValueError(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "\n",
    "        if isinstance(prompts, str):\n",
    "            return _generate(prompts)\n",
    "\n",
    "        # bulk=False: create_test_prompt returned a list of prompt strings. Each one is\n",
    "        # sent as its own request, because a list handed to the LLM wrapper is treated\n",
    "        # as a list of chat messages rather than separate prompts.\n",
    "        merged = {}\n",
    "        for prompt in prompts:\n",
    "            for case in _generate(prompt).values():\n",
    "                # Renumber so keys from different prompts never collide.\n",
    "                merged[f'test_case_{len(merged) + 1}'] = case\n",
    "        return merged\n",
    "    \n",
    "    async def get_test_cases_async(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs=None, max_retry=3):\n",
    "        \"\"\"Asynchronously generate test cases for a task specification, retrying when nothing can be extracted.\n",
    "\n",
    "        Args:\n",
    "            task_spec: Plan dictionary passed to create_test_prompt.\n",
    "            task_description: Task text; defaults to self.task_description.\n",
    "            use_example: Forwarded to create_test_prompt.\n",
    "            bulk: True sends one prompt covering every advisory; False sends one\n",
    "                prompt per advisory concurrently and merges the extracted test cases.\n",
    "            gen_kwargs: Extra keyword arguments forwarded to the LLM call.\n",
    "            max_retry: Retries after the first attempt, per prompt.\n",
    "\n",
    "        Returns:\n",
    "            dict of test cases as produced by extract_test_cases. With bulk=False\n",
    "            the cases of all prompts are merged and renumbered test_case_1..N in\n",
    "            prompt order.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If a prompt yields no test case after all attempts.\n",
    "        \"\"\"\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        if gen_kwargs is None:\n",
    "            gen_kwargs = {}\n",
    "        prompts = self.create_test_prompt(task_description, task_spec, use_example, bulk)\n",
    "\n",
    "        async def _generate(prompt):\n",
    "            # Run the request/extract loop for one prompt string.\n",
    "            llm_output = None\n",
    "            for attempt in range(1, max_retry + 2):\n",
    "                llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)\n",
    "                test_cases = self.extract_test_cases(llm_output)\n",
    "                if test_cases:\n",
    "                    return test_cases\n",
    "                logging.warning(f\"Failed to extract test cases, retrying ({attempt})..., current llm_output:\\n{llm_output}\")\n",
    "            logging.error(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "            raise ValueError(f\"Failed to extract test cases, current llm_output:\\n{llm_output}\")\n",
    "\n",
    "        if isinstance(prompts, str):\n",
    "            return await _generate(prompts)\n",
    "\n",
    "        # bulk=False: create_test_prompt returned a list of prompt strings. Each one is\n",
    "        # sent as its own request, because a list handed to the LLM wrapper is treated\n",
    "        # as a list of chat messages rather than separate prompts.\n",
    "        per_prompt = await asyncio.gather(*(_generate(prompt) for prompt in prompts))\n",
    "        merged = {}\n",
    "        for cases in per_prompt:\n",
    "            for case in cases.values():\n",
    "                # Renumber so keys from different prompts never collide.\n",
    "                merged[f'test_case_{len(merged) + 1}'] = case\n",
    "        return merged\n",
    "    \n",
    "class LLMCG():\n",
    "    def __init__(self, task_description, LLM_model):\n",
    "        \"\"\"Keep the default task description and the LLM backend for later use.\"\"\"\n",
    "        self.LLM_model, self.task_description = LLM_model, task_description\n",
    "\n",
    "    def create_code_generation_prompt(\n",
    "        self,\n",
    "        extracted_plan,\n",
    "        user_feedback=None,\n",
    "        task_description=None,\n",
    "        test_cases=None,\n",
    "        history=None,\n",
    "        next_code_line=False,\n",
    "        output_planning=False,\n",
    "        use_example=False,\n",
    "        use_task_description=False,\n",
    "        use_system_prompt=True,\n",
    "        more_comments=False,\n",
    "        ):\n",
    "\n",
    "        components = extracted_plan[\"components\"]\n",
    "        overall_plan = extracted_plan[\"overall_plan\"]\n",
    "\n",
    "        prompt_parts = []\n",
    "\n",
    "        if user_feedback:\n",
    "            system_prompt = \"You are a code refinement specialist designed to improve existing implementations based on specific feedback. Analyze the provided feedback, identify areas for improvement, and modify the code while strictly maintaining the required input/output formats and component specifications.\"\n",
    "        else:\n",
    "            system_prompt = \"You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined.\"\n",
    "\n",
    "        # Add System Prompt if enabled\n",
    "        if use_system_prompt:\n",
    "            prompt_parts.append(f\"=== Role ===\\n{system_prompt}\\n\")\n",
    "\n",
    "        # Add Task Description if enabled\n",
    "        if not task_description:\n",
    "            task_description = self.task_description\n",
    "\n",
    "        if use_task_description:\n",
    "            prompt_parts.append(f\"=== Task Description ===\\n{task_description}\\n\")\n",
    "\n",
    "        # Add Components Section\n",
    "        if components:\n",
    "            prompt_parts.append(\"=== Components ===\")\n",
    "            for comp_name, comp_details in components.items():\n",
    "                # Process Input Format\n",
    "                input_fmt = comp_details[\"input_format\"]\n",
    "                input_lines = []\n",
    "                for idx, (dtype, shape) in enumerate(input_fmt, 1):\n",
    "                    shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "                    input_lines.append(f\"Argument {idx}: {dtype} with {shape_str}\")\n",
    "                input_section = \"Input Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in input_lines])\n",
    "\n",
    "                # Process Output Format\n",
    "                output_fmt = comp_details[\"output_format\"]\n",
    "                output_lines = []\n",
    "                for idx, (dtype, shape) in enumerate(output_fmt, 1):\n",
    "                    shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "                    output_lines.append(f\"Output {idx}: {dtype} with {shape_str}\")\n",
    "                output_section = \"Output Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in output_lines])\n",
    "\n",
    "                # Build Component Details\n",
    "                component_part = [\n",
    "                    f\"\\n**Component: {comp_name}**\",\n",
    "                    f\"Step Task Description: {comp_details['step_task_description']}\",\n",
    "                    input_section,\n",
    "                    output_section,\n",
    "                    \"Workflow Steps:\",\n",
    "                    *[f\"- {step}\" for step in comp_details[\"work_flow\"]],\n",
    "                    \"Test Case Generation Advice:\",\n",
    "                    *[f\"- {advice}\" for advice in comp_details[\"test_case_generation_advise\"]],\n",
    "                    \"\\n\",\n",
    "                ]\n",
    "                prompt_parts.extend(component_part)\n",
    "\n",
    "        # Add Overall Plan Section\n",
    "        if overall_plan:\n",
    "            prompt_parts.append(\"\\n=== Overall Plan ===\")\n",
    "            # Process Input Format\n",
    "            input_fmt = overall_plan[\"input_format\"]\n",
    "            input_lines = []\n",
    "            for idx, (dtype, shape) in enumerate(input_fmt, 1):\n",
    "                shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "                input_lines.append(f\"Argument {idx}: {dtype} with {shape_str}\")\n",
    "            input_section = \"Input Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in input_lines])\n",
    "\n",
    "            # Process Output Format\n",
    "            output_fmt = overall_plan[\"output_format\"]\n",
    "            output_lines = []\n",
    "            for idx, (dtype, shape) in enumerate(output_fmt, 1):\n",
    "                shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "                output_lines.append(f\"Output {idx}: {dtype} with {shape_str}\")\n",
    "            output_section = \"Output Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in output_lines])\n",
    "\n",
    "            # Build Overall Plan Details\n",
    "            plan_part = [\n",
    "                input_section,\n",
    "                output_section,\n",
    "                f\"Components Order: {', '.join(overall_plan['components'])}\",\n",
    "                \"Plan Steps:\",\n",
    "                *[f\"- {step}\" for step in overall_plan[\"plan\"]],\n",
    "                \"Overall Test Case Advice:\",\n",
    "                *[f\"- {advice}\" for advice in overall_plan[\"test_case_generation_advise\"]],\n",
    "                \"\\n\",\n",
    "            ]\n",
    "            prompt_parts.extend(plan_part)\n",
    "\n",
    "        if user_feedback:\n",
    "            prompt_parts.append(\"\\n=== User Feedback ===\")\n",
    "            prompt_parts.append(user_feedback)\n",
    "\n",
    "        # Add Test Cases if enabled and available\n",
    "        if use_example and test_cases:\n",
    "            prompt_parts.append(\"\\n=== Test Cases ===\")\n",
    "            example_num = 3\n",
    "            for case_name, case_details in test_cases.items():\n",
    "                case_part = [\n",
    "                    f\"\\n**Test Case: {case_name}**\",\n",
    "                    f\"Purpose: {case_details['purpose']}\",\n",
    "                    f\"Type: {case_details['test_type']}\",\n",
    "                    f\"Test Function:\\n{case_details['test_function']}\",\n",
    "                    \"\\n\",\n",
    "                ]\n",
    "                prompt_parts.extend(case_part)\n",
    "                example_num -= 1\n",
    "                if example_num == 0:\n",
    "                    break\n",
    "\n",
    "        # Add History if available\n",
    "        if history:\n",
    "            prompt_parts.append(\"\\n=== Previous Generation Attempts ===\")\n",
    "            for gen_name, gen_details in history.items():\n",
    "                history_part = [\n",
    "                    f\"\\n**Generation: {gen_name}**\",\n",
    "                    f\"Score: {gen_details['score']}\",\n",
    "                    \"Generated Code:\",\n",
    "                    gen_details[\"generated_code\"],\n",
    "                    \"Generation Plan:\",\n",
    "                    *[f\"- {step}\" for step in gen_details[\"generation_plan\"]],\n",
    "                    \"\\n\",\n",
    "                ]\n",
    "                prompt_parts.extend(history_part)\n",
    "\n",
    "        # Build Refinement Instructions\n",
    "        if user_feedback:\n",
    "            refine_instructions = [\"\\n=== Refinement Requirements ===\"]\n",
    "            refine_instructions.append(\"Generate a revised implementation that:\")\n",
    "            refine_instructions.append(\"- Addresses all identified issues from the feedback analysis\")\n",
    "            refine_instructions.append(\"- Maintains strict compliance with component specifications\")\n",
    "            refine_instructions.append(\"- Preserves existing functionality that passed validation\")\n",
    "            prompt_parts.append(\"\\n\".join(refine_instructions))\n",
    "\n",
    "        # Build Instructions\n",
    "        instructions = [\"\\n=== Instructions ===\"]\n",
    "        if next_code_line:\n",
    "            instructions.append(\"Generate ONLY the next line or a small code snippet required to proceed.\")\n",
    "        else:\n",
    "            instructions.append(\"Generate the COMPLETE code based on the components and plan above.\")\n",
    "        instructions.append(\"DO MAKE SURE the complete code is a runnable function, all components are correctly integrated with in this function.\")\n",
    "        instructions.append(\"The complete function should take the input arguments as specified in the overall plan and return the output as specified.\")\n",
    "\n",
    "        if more_comments:\n",
    "            instructions.append(\"Please add as much comments as possible to your code to explain the logic and any critical steps.\")\n",
    "\n",
    "        if output_planning:\n",
    "            instructions.append(\"Structure your response as follows:\")\n",
    "            instructions.append(\"<Code>\")\n",
    "            instructions.append(\"Your code here. DO make sure the output is a single function that integrates all components.\")\n",
    "            instructions.append(\"</Code>\")\n",
    "            instructions.append(\"<Planning>\")\n",
    "            if next_code_line:\n",
    "                instructions.append(\"A concise summary of what this specific code part accomplishes.\")\n",
    "            else:\n",
    "                instructions.append(\"A detailed step-by-step explanation of the code's workflow.\")\n",
    "            instructions.append(\"</Planning>\")\n",
    "            instructions.append(\"<Main Function Name>\")\n",
    "            instructions.append(\"The name of the main function that integrates all components.\")\n",
    "            instructions.append(\"</Main Function Name>\")\n",
    "            instructions.append(\"Provide the code with the same indicator and structure as shown in Instructions. DO NOT return any test cases or example usages in your code!\")\n",
    "        else:\n",
    "            instructions.append(\"Structure your response as follows:\")\n",
    "            instructions.append(\"<Code>\")\n",
    "            instructions.append(\"Your code here\")\n",
    "            instructions.append(\"</Code>\")\n",
    "            instructions.append(\"Provide the code WITHOUT any additional explanations, and DO use the same indicator and structure as shown in Instructions.\")\n",
    "\n",
    "        prompt_parts.append(\"\\n\".join(instructions))\n",
    "\n",
    "        return \"\\n\".join(prompt_parts)\n",
    "    \n",
    "    def extract_code(self, llm_output):\n",
    "        \"\"\"Extracts code and planning sections from LLM output.\"\"\"\n",
    "        result = {\"code\": None, \"plan\": None, \"main_function_name\": None}\n",
    "        \n",
    "        # Extract code section\n",
    "        code_match = re.search(r'<Code>(.*?)(?:</Code>|<End>)', llm_output, re.DOTALL)\n",
    "        if code_match:\n",
    "            result[\"code\"] = code_match.group(1).strip()\n",
    "        else:\n",
    "            # If not found, try to extract from ```python ... ```\n",
    "            code_block_match = re.search(r'```(?:python)?\\s*(.*?)```', llm_output, re.DOTALL)\n",
    "            if code_block_match:\n",
    "                result[\"code\"] = code_block_match.group(1).strip()\n",
    "        \n",
    "        # Extract planning section\n",
    "        plan_match = re.search(r'<Planning>(.*?)(?:</Planning>|<End>)', llm_output, re.DOTALL)\n",
    "        if plan_match:\n",
    "            result[\"plan\"] = plan_match.group(1).strip()\n",
    "\n",
    "        # Extract main function name\n",
    "        main_func_match = re.search(r'<Main Function Name>(.*?)(?:</Main Function Name>|<End>)', llm_output, re.DOTALL)\n",
    "        if main_func_match:\n",
    "            result[\"main_function_name\"] = main_func_match.group(1).strip()\n",
    "        \n",
    "        return result\n",
    "    \n",
    "    def get_code(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs={}, max_retry=3):\n",
    "        retry_num=0\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        prompt = self.create_code_generation_prompt(extracted_plan, extracted_plan.get('user_feedback'), task_description, test_cases, history, next_code_line, output_planning, use_example, use_task_description, use_system_prompt, more_comments)\n",
    "        while retry_num <= max_retry:\n",
    "            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)\n",
    "            code_output = self.extract_code(llm_output)\n",
    "            if code_output[\"code\"] is None:\n",
    "                retry_num += 1\n",
    "                # print(f\"Failed to extract code, retrying ({retry_num})...\")\n",
    "                # print(f\"Current llm_output:\\n{llm_output}\")\n",
    "                logging.warning(f\"Failed to extract code, retrying ({retry_num})...\")\n",
    "                logging.warning(f\"Current llm_output:\\n{llm_output}\")\n",
    "            else:\n",
    "                break\n",
    "        if code_output[\"code\"] is None:\n",
    "            logging.error(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "            raise ValueError(\"Failed to extract code, current llm_output:\\n\", llm_output)\n",
    "        return code_output\n",
    "    \n",
    "    async def get_code_async(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs={}, max_retry=3):\n",
    "        retry_num=0\n",
    "        if task_description is None:\n",
    "            task_description = self.task_description\n",
    "        prompt = self.create_code_generation_prompt(extracted_plan, extracted_plan.get('user_feedback'), task_description, test_cases, history, next_code_line, output_planning, use_example, use_task_description, use_system_prompt, more_comments)\n",
    "        while retry_num <= max_retry:\n",
    "            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)\n",
    "            code_output = self.extract_code(llm_output)\n",
    "            if code_output[\"code\"] is None:\n",
    "                retry_num += 1\n",
    "                # print(f\"Failed to extract code, retrying ({retry_num})...\")\n",
    "                # print(f\"Current llm_output:\\n{llm_output}\")\n",
    "                logging.warning(f\"Failed to extract code, retrying ({retry_num})...\")\n",
    "                logging.warning(f\"Current llm_output:\\n{llm_output}\")\n",
    "            else:\n",
    "                break\n",
    "        if code_output[\"code\"] is None:\n",
    "            logging.error(f\"Failed to extract code, current llm_output:\\n{llm_output}\")\n",
    "            raise ValueError(\"Failed to extract code, current llm_output:\\n\", llm_output)\n",
    "        return code_output\n",
    "\n",
    "class CodeRunner():\n",
    "    def __init__(self, max_workers=5):\n",
    "        self.max_workers = max_workers\n",
    "\n",
    "    def run_test(self, func_obj, test_func):\n",
    "        try:\n",
    "            return test_func(func_obj)\n",
    "        except Exception as e:\n",
    "            return False\n",
    "    \n",
    "    def compile_code(self, code_str, main_function_name=None):\n",
    "        try:\n",
    "            local_vars = {}\n",
    "            exec(code_str, local_vars)  # Use one dict for globals and locals\n",
    "            if main_function_name is not None:\n",
    "                func = local_vars.get(main_function_name)\n",
    "                return func if callable(func) else None\n",
    "            return next((obj for obj in local_vars.values() if callable(obj)), None)\n",
    "        except Exception as e:\n",
    "            print(f\"Compilation Error: {str(e)}, code_str:\\n {code_str}\")\n",
    "            return None\n",
    "    \n",
    "    def run_all_tests(self, functions, test_cases, max_workers=5):\n",
    "        \"\"\"\n",
    "        Updated to handle new function structure with main_function_name\n",
    "        \"\"\"\n",
    "        # 编译函数（处理带主函数名称的情况）\n",
    "        compiled_functions = {\n",
    "            fid: self.compile_code(\n",
    "                code_info['code'],\n",
    "                main_function_name=code_info.get('main_function_name')\n",
    "            )\n",
    "            for fid, code_info in functions.items()\n",
    "        }\n",
    "        \n",
    "        # 编译测试用例（保持原有逻辑）\n",
    "        compiled_tests = {\n",
    "            tid: self.compile_code(code_info['test_function'])\n",
    "            for tid, code_info in test_cases.items()\n",
    "        }\n",
    "\n",
    "        # 准备结果字典\n",
    "        fun_results = {fid: {} for fid in functions}\n",
    "        test_results = {tid: {} for tid in test_cases}\n",
    "\n",
    "        total_tests = len(compiled_functions) * len(compiled_tests)\n",
    "        \n",
    "        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "            futures = {}\n",
    "            pbar = tqdm(total=total_tests, desc=\"Running tests\")\n",
    "            \n",
    "            # 提交测试任务\n",
    "            for func_id, func_obj in compiled_functions.items():\n",
    "                for test_id, test_func in compiled_tests.items():\n",
    "                    # 处理编译失败的情况\n",
    "                    if func_obj is None or test_func is None:\n",
    "                        fun_results[func_id][test_id] = False\n",
    "                        test_results[test_id][func_id] = False\n",
    "                        pbar.update(1)\n",
    "                        continue\n",
    "                    \n",
    "                    # 提交并发任务\n",
    "                    future = executor.submit(self.run_test, func_obj, test_func)\n",
    "                    futures[future] = (func_id, test_id)\n",
    "\n",
    "            # 处理测试结果\n",
    "            for future in concurrent.futures.as_completed(futures):\n",
    "                func_id, test_id = futures[future]\n",
    "                try:\n",
    "                    result = future.result()\n",
    "                except Exception:\n",
    "                    result = False\n",
    "                fun_results[func_id][test_id] = result\n",
    "                test_results[test_id][func_id] = result\n",
    "                pbar.update(1)\n",
    "            \n",
    "            pbar.close()\n",
    "        \n",
    "        return fun_results, test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-10 11:03:57,520 - numexpr.utils - INFO - Note: NumExpr detected 16 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
      "2025-04-10 11:03:57,520 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.\n",
      "2025-04-10 11:04:10,165 - root - INFO - Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import Data, Batch\n",
    "from torch_geometric.nn import GATConv, GraphConv, global_max_pool, global_mean_pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.utils.data import Dataset, Subset\n",
    "from torch_geometric.loader import DataLoader\n",
    "import logging\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "logging.info(f'Using device: {device}')\n",
    "\n",
    "class CodeGraphDataset(Dataset):\n",
    "    def __init__(self, dataframe, scaler=None, node_type_vocab=None):\n",
    "        self.invalid_count = 0\n",
    "        self.dataframe = dataframe.reset_index(drop=True)\n",
    "        self.scaler = scaler if scaler else MinMaxScaler()\n",
    "        if scaler is None:  # 仅训练集拟合\n",
    "            self.scaler.fit(self.dataframe['score'].values.reshape(-1, 1))\n",
    "        logging.info('Score values scaled using MinMaxScaler.')\n",
    "        # Build a vocabulary for AST node types\n",
    "        if node_type_vocab is None:\n",
    "            self.node_type_vocab = self.build_node_type_vocab()\n",
    "        else:\n",
    "            self.node_type_vocab = node_type_vocab\n",
    "        logging.info(f'Built node type vocabulary with size: {len(self.node_type_vocab)}')\n",
    "\n",
    "    def build_node_type_vocab(self):\n",
    "        node_types = set()\n",
    "        for idx, code in enumerate(self.dataframe['code']):\n",
    "            try:\n",
    "                tree = ast.parse(code)\n",
    "                for node in ast.walk(tree):\n",
    "                    node_types.add(type(node).__name__)\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"Error parsing code at index {idx}: {e}\")\n",
    "        node_type_to_id = {\"UNK\": 0}\n",
    "        for idx, nt in enumerate(sorted(node_types), start=1):\n",
    "            node_type_to_id[nt] = idx\n",
    "        return node_type_to_id\n",
    "\n",
    "    def ast_to_graph(self, code):\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "        except Exception as e:\n",
    "            logging.warning(f\"Error parsing code: {e}\")\n",
    "            return None\n",
    "\n",
    "        nodes = []\n",
    "        edges = []\n",
    "        node_features = []\n",
    "        node_id = 0\n",
    "        node_id_map = {}\n",
    "\n",
    "        def traverse(node, parent_id=None):\n",
    "            nonlocal node_id\n",
    "            current_id = node_id\n",
    "            node_id_map[id(node)] = current_id\n",
    "            nodes.append(current_id)\n",
    "            # Encode node type as integer\n",
    "            node_type = type(node).__name__\n",
    "            node_type_id = self.node_type_vocab.get(node_type, 0)  # Handle unknown types\n",
    "            node_features.append([node_type_id])\n",
    "            node_id += 1\n",
    "\n",
    "            if parent_id is not None:\n",
    "                edges.append((parent_id, current_id))\n",
    "\n",
    "            for child in ast.iter_child_nodes(node):\n",
    "                traverse(child, current_id)\n",
    "\n",
    "        traverse(tree)\n",
    "\n",
    "        if not nodes:\n",
    "            return None\n",
    "\n",
    "        # Convert edges to a tensor\n",
    "        if edges:\n",
    "            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "        else:\n",
    "            edge_index = torch.empty((2, 0), dtype=torch.long)\n",
    "\n",
    "        # Convert node features to a tensor\n",
    "        x = torch.tensor(node_features, dtype=torch.long)\n",
    "\n",
    "        # Create a Data object\n",
    "        data = Data(x=x, edge_index=edge_index)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataframe)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.dataframe.iloc[idx]\n",
    "        code = row['code']\n",
    "        score = row['score']\n",
    "\n",
    "        graph = self.ast_to_graph(code)\n",
    "        if graph is None:\n",
    "            # Skip samples with parsing errors by raising an exception\n",
    "            # Alternatively, implement a different handling strategy\n",
    "            logging.debug(f\"Skipping index {idx} due to parsing error.\")\n",
    "            raise ValueError(f\"Parsing failed for code at index {idx}.\")\n",
    "\n",
    "        if graph is None:\n",
    "            self.invalid_count += 1\n",
    "            logging.debug(f\"Skipping index {idx} due to parsing error.\")\n",
    "            return None\n",
    "\n",
    "        # Normalize score using the scaler\n",
    "        score_normalized = self.scaler.transform([[score]]).flatten()[0]\n",
    "\n",
    "        graph.y = torch.tensor([score_normalized], dtype=torch.float)\n",
    "        return graph\n",
    "    \n",
    "class GNNModel(nn.Module):\n",
    "    \"\"\"Two-layer GAT regressor over AST graphs that predicts one score per graph.\n",
    "\n",
    "    Node-type ids are embedded, passed through two GATConv layers, pooled per\n",
    "    graph with max and mean pooling (concatenated), and mapped to a scalar by\n",
    "    two linear layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128, scaler=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            num_node_types: Size of the node-type vocabulary (embedding rows).\n",
    "            embed_dim: Width of the node-type embedding.\n",
    "            hidden_dim: Width of the GAT layers and of the first linear layer.\n",
    "            scaler: Score scaler kept on the model for reference; it is not\n",
    "                used inside ``forward``.\n",
    "        \"\"\"\n",
    "        super(GNNModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(num_node_types, embed_dim)\n",
    "        self.conv1 = GATConv(embed_dim, hidden_dim)\n",
    "        self.conv2 = GATConv(hidden_dim, hidden_dim)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # input is max-pool and mean-pool concatenated\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)\n",
    "        self.scaler = scaler\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return one prediction per graph, as a 1-D tensor of length num_graphs.\"\"\"\n",
    "        x, edge_index, batch = data.x, data.edge_index, data.batch\n",
    "        # Flatten the ids to 1-D explicitly: a bare squeeze() would also drop the\n",
    "        # node axis when the batch holds a single node.\n",
    "        x = self.embedding(x.reshape(-1))\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = self.dropout(F.relu(x))\n",
    "        x = self.conv2(x, edge_index)\n",
    "        x = self.dropout(F.relu(x))\n",
    "        x = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        # Drop only the trailing size-1 feature axis so that a batch with a single\n",
    "        # graph still yields shape (1,) rather than a 0-d tensor.\n",
    "        return x.squeeze(-1)\n",
    "\n",
    "class PassRatePredictor():\n",
    "    def __init__(self, ini_data=None, model=None):\n",
    "        \"\"\"Create a predictor, optionally seeded with data and a model.\n",
    "\n",
    "        Args:\n",
    "            ini_data: Initial data set; presumably a DataFrame with 'code' and\n",
    "                'score' columns, matching the empty default (verify against\n",
    "                callers). None starts with an empty frame.\n",
    "            model: Optional already trained model.\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        if ini_data is None:\n",
    "            # Start with an empty data set\n",
    "            self.data = pd.DataFrame(columns=[\"code\", \"score\"])\n",
    "        else:\n",
    "            # Previously self.data was left unset in this case, so any later\n",
    "            # access to it raised AttributeError\n",
    "            self.data = ini_data\n",
    "        self.scaler = MinMaxScaler()\n",
    "        self.trained = False\n",
    "        self.node_type_vocab = None\n",
    "\n",
    "    def add_data(self, new_data, use_pass_rate=False):\n",
    "        if isinstance(new_data, dict):\n",
    "            new_data = pd.DataFrame.from_dict(new_data, orient='index').reset_index(drop=True)\n",
    "            # 仅保留 'code' 和 'score' 列\n",
    "            if use_pass_rate:\n",
    "                new_data = new_data[['code', 'pass_rate']].rename(columns={'pass_rate': 'score'})\n",
    "            else:\n",
    "                new_data = new_data[['code', 'score']]\n",
    "\n",
    "        if self.data is None:\n",
    "            self.data = new_data\n",
    "        else:\n",
    "            # 过滤重复数据\n",
    "            new_data = new_data[~new_data['code'].isin(self.data['code'])]\n",
    "            self.data = pd.concat([self.data, new_data], ignore_index=True)\n",
    "\n",
    "    def predict_score(self, new_code_samples, model=None, scaler=None):\n",
    "        \"\"\"Predict scores, on the original score scale, for new code samples.\n",
    "\n",
    "        Args:\n",
    "            new_code_samples: Sequence of source-code strings.\n",
    "            model: Model to use; defaults to ``self.model``.\n",
    "            scaler: Fitted scaler used to undo normalization; defaults to\n",
    "                ``self.scaler``.\n",
    "\n",
    "        Returns:\n",
    "            1-D numpy array of predictions; empty when no sample could be used.\n",
    "            NOTE(review): samples dropped by ``clean_score_data`` or by the\n",
    "            ``is not None`` filter below produce no entry, so the result may be\n",
    "            shorter than, and not aligned with, ``new_code_samples``.\n",
    "        \"\"\"\n",
    "        if model is None:\n",
    "            model = self.model\n",
    "        if scaler is None:\n",
    "            scaler = self.scaler\n",
    "\n",
    "        # Wrap the new samples in a DataFrame\n",
    "        new_df = pd.DataFrame({\n",
    "            \"code\": new_code_samples,\n",
    "            \"score\": [\"0s\"] * len(new_code_samples)  # placeholder value\n",
    "        })\n",
    "\n",
    "        df_clean, _ = self.clean_score_data(new_df)\n",
    "\n",
    "        # Build the dataset with the vocabulary learned during training\n",
    "        dataset = CodeGraphDataset(df_clean, scaler=scaler, node_type_vocab=self.node_type_vocab)\n",
    "        loader = DataLoader(\n",
    "            [data for data in dataset if data is not None],\n",
    "            batch_size=32\n",
    "        )\n",
    "\n",
    "        # Predict\n",
    "        model.eval()\n",
    "        preds = []\n",
    "        with torch.no_grad():\n",
    "            for batch in loader:\n",
    "                pred = model(batch)\n",
    "                # atleast_1d: a batch holding one sample can yield a 0-d result,\n",
    "                # which cannot be iterated by extend()\n",
    "                preds.extend(np.atleast_1d(pred.cpu().numpy()))\n",
    "\n",
    "        if not preds:\n",
    "            # Nothing to invert; the scaler would reject an empty array\n",
    "            return np.array([])\n",
    "\n",
    "        # Undo the normalization\n",
    "        pred_score = scaler.inverse_transform(np.array(preds).reshape(-1, 1)).flatten()\n",
    "        return pred_score\n",
    "    \n",
    "    def test_model(self, model, dataframe, train_scaler=None):\n",
    "        \"\"\"Evaluate ``model`` on ``dataframe`` and report MAE / RMSE on the original scale.\n",
    "\n",
    "        Args:\n",
    "            model: Model to evaluate.\n",
    "            dataframe: Data with 'code' and 'score' columns; it is passed\n",
    "                through ``clean_score_data`` first.\n",
    "            train_scaler: Scaler fitted on the training data. When None, a new\n",
    "                MinMaxScaler is fitted on the cleaned evaluation scores.\n",
    "\n",
    "        Returns:\n",
    "            Dict with the keys 'mae' and 'rmse'.\n",
    "        \"\"\"\n",
    "        df_clean, _ = self.clean_score_data(dataframe)\n",
    "        if train_scaler is None:\n",
    "            # Fit on the cleaned scores, i.e. the same values the dataset will transform\n",
    "            train_scaler = MinMaxScaler().fit(df_clean['score'].values.reshape(-1, 1))\n",
    "        # Reuse the vocabulary learned in training so node ids match the model's\n",
    "        # embedding table; None (no training yet) lets the dataset build its own.\n",
    "        test_dataset = CodeGraphDataset(\n",
    "            df_clean,\n",
    "            scaler=train_scaler,\n",
    "            node_type_vocab=getattr(self, 'node_type_vocab', None)\n",
    "        )\n",
    "        test_loader = DataLoader(\n",
    "            [data for data in test_dataset if data is not None],\n",
    "            batch_size=32\n",
    "        )\n",
    "\n",
    "        criterion = torch.nn.MSELoss()\n",
    "        model.eval()\n",
    "        test_loss = []\n",
    "        all_preds = []\n",
    "        all_labels = []\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for batch in test_loader:\n",
    "                # reshape(-1) keeps the prediction 1-D even for a single-sample\n",
    "                # batch, so it matches batch.y and can be iterated\n",
    "                pred = model(batch).reshape(-1)\n",
    "                loss = criterion(pred, batch.y)\n",
    "                test_loss.append(loss.item())\n",
    "                all_preds.extend(pred.cpu().numpy())\n",
    "                all_labels.extend(batch.y.cpu().numpy())\n",
    "\n",
    "        # Undo the normalization of predictions and labels\n",
    "        preds = test_dataset.scaler.inverse_transform(np.array(all_preds).reshape(-1, 1)).flatten()\n",
    "        labels = test_dataset.scaler.inverse_transform(np.array(all_labels).reshape(-1, 1)).flatten()\n",
    "\n",
    "        # Compute the metrics\n",
    "        mae = np.mean(np.abs(preds - labels))\n",
    "        rmse = np.sqrt(np.mean((preds - labels)**2))\n",
    "        print(f\"Test MAE: {mae:.4f}, Test RMSE: {rmse:.4f}\")\n",
    "        return {\"mae\": mae, \"rmse\": rmse}\n",
    "    \n",
    "    def train_model(self, dataframe=None, epochs=50, batch_size=32, lr=0.001):\n",
    "        \"\"\"Train a GNNModel on the given data, or on ``self.data`` when none is given.\n",
    "\n",
    "        The data is preprocessed, split 80/20 into training and validation\n",
    "        sets, and the scaler and node-type vocabulary are derived from the\n",
    "        training split only. The weights with the lowest validation loss are\n",
    "        written to 'best_gnn_model.pth'.\n",
    "\n",
    "        NOTE(review): the model returned and stored is the one from the last\n",
    "        epoch, not the best checkpoint; confirm whether that is intended.\n",
    "\n",
    "        Args:\n",
    "            dataframe: Training data; defaults to ``self.data``.\n",
    "            epochs: Number of training epochs.\n",
    "            batch_size: Batch size for both loaders.\n",
    "            lr: Adam learning rate.\n",
    "\n",
    "        Returns:\n",
    "            The trained model (also stored in ``self.model``).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If preprocessing leaves no usable rows.\n",
    "        \"\"\"\n",
    "        if dataframe is None:\n",
    "            dataframe = self.data\n",
    "\n",
    "        # Clean the data\n",
    "        df_preprocessed = self.preprocess_data(dataframe)\n",
    "        if df_preprocessed.empty:\n",
    "            raise ValueError(\"无有效数据可供训练\")\n",
    "\n",
    "        # Split into training and validation sets\n",
    "        train_df, val_df = train_test_split(df_preprocessed, test_size=0.2, random_state=42)\n",
    "\n",
    "        # Fit the scaler on the training split only\n",
    "        train_scaler = MinMaxScaler().fit(train_df['score'].values.reshape(-1, 1))\n",
    "        self.scaler = train_scaler\n",
    "        train_dataset = CodeGraphDataset(train_df, scaler=train_scaler)\n",
    "        # Validation reuses the training scaler AND vocabulary: a vocabulary\n",
    "        # rebuilt from val_df would assign different ids to node types and could\n",
    "        # index past the end of the model's embedding table.\n",
    "        val_dataset = CodeGraphDataset(\n",
    "            val_df,\n",
    "            scaler=train_scaler,\n",
    "            node_type_vocab=train_dataset.node_type_vocab\n",
    "        )\n",
    "\n",
    "        # Drop unusable samples and create the loaders\n",
    "        train_loader = DataLoader(\n",
    "            [data for data in train_dataset if data is not None],\n",
    "            batch_size=batch_size,\n",
    "            shuffle=True\n",
    "        )\n",
    "        val_loader = DataLoader(\n",
    "            [data for data in val_dataset if data is not None],\n",
    "            batch_size=batch_size\n",
    "        )\n",
    "\n",
    "        # Model and optimizer\n",
    "        model = GNNModel(\n",
    "            num_node_types=len(train_dataset.node_type_vocab),\n",
    "            embed_dim=64,\n",
    "            hidden_dim=128,\n",
    "            scaler=train_scaler\n",
    "        )\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "        criterion = torch.nn.MSELoss()  # mean squared error\n",
    "\n",
    "        # Training loop\n",
    "        best_val_loss = float('inf')\n",
    "        for epoch in range(epochs):\n",
    "            model.train()\n",
    "            train_loss = []\n",
    "            for batch in train_loader:\n",
    "                optimizer.zero_grad()\n",
    "                # reshape(-1): keep the prediction 1-D so it matches batch.y even\n",
    "                # when the last batch holds a single sample\n",
    "                pred = model(batch).reshape(-1)\n",
    "                loss = criterion(pred, batch.y)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                train_loss.append(loss.item())\n",
    "\n",
    "            # Evaluate on the validation set\n",
    "            model.eval()\n",
    "            val_loss = []\n",
    "            with torch.no_grad():\n",
    "                for batch in val_loader:\n",
    "                    pred = model(batch).reshape(-1)\n",
    "                    loss = criterion(pred, batch.y)\n",
    "                    val_loss.append(loss.item())\n",
    "\n",
    "            # Report progress\n",
    "            avg_train_loss = np.mean(train_loss)\n",
    "            avg_val_loss = np.mean(val_loss)\n",
    "            print(f\"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}\")\n",
    "\n",
    "            # Checkpoint the best weights seen so far\n",
    "            if avg_val_loss < best_val_loss:\n",
    "                best_val_loss = avg_val_loss\n",
    "                torch.save(model.state_dict(), \"best_gnn_model.pth\")\n",
    "\n",
    "        self.model = model\n",
    "        self.node_type_vocab = train_dataset.node_type_vocab\n",
    "        return model\n",
    "    \n",
    "    def filter_invalid_ast(self, df):\n",
    "        valid_indices = []\n",
    "        invalid_indices = []\n",
    "        \n",
    "        for idx, code in enumerate(df['code']):\n",
    "            try:\n",
    "                ast.parse(code)\n",
    "                valid_indices.append(idx)\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"索引 {idx} 的代码无法解析AST: {e}\")\n",
    "                invalid_indices.append(idx)\n",
    "        \n",
    "        # 保留有效样本\n",
    "        df_valid = df.iloc[valid_indices].reset_index(drop=True)\n",
    "        return df_valid, invalid_indices\n",
    "\n",
    "    def clean_score_data(self, df):\n",
    "        cleaned_scores = []\n",
    "        invalid_indices = []\n",
    "        \n",
    "        for idx, row in df.iterrows():\n",
    "            value = row['score']\n",
    "            try:\n",
    "                if isinstance(value, str):\n",
    "                    # 移除空格，转换为小写\n",
    "                    cleaned_str = value.strip().lower()\n",
    "                    # 提取数值和单位（正则匹配数值部分）\n",
    "                    num_match = re.match(r\"^(\\d+\\.?\\d*)\\s*([a-z]*)?\", cleaned_str)\n",
    "                    if not num_match:\n",
    "                        raise ValueError(\"无法提取数值\")\n",
    "                    num = float(num_match.group(1))\n",
    "                    unit = num_match.group(2) or 's'  # 默认单位是秒\n",
    "                    # 根据单位转换为秒\n",
    "                    if unit in {'s', 'sec', 'second', ''}:\n",
    "                        converted = num\n",
    "                    elif unit in {'ms', 'msec', 'millisecond'}:\n",
    "                        converted = num / 1000\n",
    "                    elif unit in {'m', 'min', 'minute'}:\n",
    "                        converted = num * 60\n",
    "                    elif unit in {'h', 'hour'}:\n",
    "                        converted = num * 3600\n",
    "                    else:\n",
    "                        logging.warning(f\"索引 {idx} 的未知单位 '{unit}'，假设为秒\")\n",
    "                        converted = num\n",
    "                    cleaned_scores.append(converted)\n",
    "                else:\n",
    "                    # 处理数值类型（int/float）\n",
    "                    cleaned_scores.append(float(value))\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"索引 {idx} 的score值 '{value}' 处理失败: {e}\")\n",
    "                invalid_indices.append(idx)\n",
    "                cleaned_scores.append(None)\n",
    "        \n",
    "        # 替换原列并删除无效行\n",
    "        df_clean = df.copy()\n",
    "        df_clean['score'] = cleaned_scores\n",
    "        df_clean = df_clean.dropna(subset=['score']).reset_index(drop=True)\n",
    "        return df_clean, invalid_indices\n",
    "\n",
    "    def preprocess_data(self, df):\n",
    "        # Step 1: 过滤无法解析AST的样本\n",
    "        df_ast_valid, ast_invalid = self.filter_invalid_ast(df)\n",
    "        logging.info(f\"过滤 {len(ast_invalid)} 个无效AST样本\")\n",
    "        \n",
    "        # Step 2: 清洗score字段\n",
    "        df_clean, score_invalid = self.clean_score_data(df_ast_valid)\n",
    "        logging.info(f\"过滤 {len(score_invalid)} 个无效score样本\")\n",
    "        \n",
    "        return df_clean\n",
    "\n",
    "###############################################################\n",
    "# # debug\n",
    "\n",
    "# # 1. 加载数据\n",
    "# df = pd.read_csv(r\"E:\\python_project_new\\AI4SLCDP\\leetcode_data\\leetcode Median of Two Sorted Arrays.csv\")\n",
    "\n",
    "# # 将\"runtime\" 列改为\"score\"\n",
    "# df.rename(columns={\"runtime\": \"score\"}, inplace=True)\n",
    "# print(df.head())\n",
    "\n",
    "# pass_rate_predictor = PassRatePredictor()\n",
    "# pass_rate_predictor.add_data(df)\n",
    "\n",
    "# model = pass_rate_predictor.train_model(epochs=100)\n",
    "\n",
    "# # 3. 测试模型\n",
    "# # test_df = pd.read_csv(\"test_data.csv\")\n",
    "# # test_metrics = test_model(model, test_df)\n",
    "\n",
    "# # 4. 预测新样本\n",
    "# new_samples = [\n",
    "#     \"def square(x):\\n    return x ** 2\",\n",
    "#     \"def div(a, b):\\n    return a / b\"\n",
    "# ]\n",
    "\n",
    "# predictions = pass_rate_predictor.predict_score(new_samples)\n",
    "# print(f\"Predicted scores: {predictions}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import re\n",
    "import tempfile\n",
    "import os\n",
    "import json\n",
    "\n",
    "def pylint_code_score(code):\n",
    "    try:\n",
    "        # 创建临时文件保存代码\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        # 执行 Pylint 分析\n",
    "        result = subprocess.run(\n",
    "            [\"pylint\", \"--output-format=text\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        output = result.stdout\n",
    "        # print(output)\n",
    "        # 删除临时文件\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        # 提取评分（如 \"rated at 7.50/10\"）\n",
    "        match = re.search(r\"rated at (\\d+\\.?\\d*)/10\", output)\n",
    "        return float(match.group(1)) if match else -1\n",
    "    \n",
    "    except Exception as e:\n",
    "        print(f\"Pylint 分析失败: {e}\")\n",
    "        return -1\n",
    "\n",
    "def radon_mi_code_score(code: str) -> float:\n",
    "    try:\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        result = subprocess.run(\n",
    "            [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        data = json.loads(result.stdout)\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        if data and isinstance(data, dict):\n",
    "            file_key = list(data.keys())[0]  # 获取临时文件的键名\n",
    "            return data[file_key][\"mi\"] / 10\n",
    "        return -1\n",
    "    except Exception as e:\n",
    "        print(f\"Radon 分析失败: {e}\")\n",
    "        return -1\n",
    "    \n",
    "# 示例：直接分析代码字符串\n",
    "code = \"\"\"\n",
    "import subprocess\n",
    "import re\n",
    "import tempfile\n",
    "import os\n",
    "import json\n",
    "def radon_mi_code_score(code: str) -> float:\n",
    "    try:\n",
    "        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "            tmp.write(code)\n",
    "            tmp_path = tmp.name\n",
    "        \n",
    "        result = subprocess.run(\n",
    "            [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "            capture_output=True,\n",
    "            text=True,\n",
    "            check=False\n",
    "        )\n",
    "        data = json.loads(result.stdout)\n",
    "        os.unlink(tmp_path)\n",
    "        \n",
    "        if data and isinstance(data, dict):\n",
    "            file_key = list(data.keys())[0]  # 获取临时文件的键名\n",
    "            return data[file_key][\"mi\"]\n",
    "        return -1\n",
    "    except Exception as e:\n",
    "        print(f\"Radon 分析失败: {e}\")\n",
    "        return -1\n",
    "\"\"\"\n",
    "\n",
    "## debug\n",
    "\n",
    "# pylint_score = pylint_code_score(code)\n",
    "# radon_score = radon_mi_code_score(code)\n",
    "\n",
    "# print(f\"Pylint 质量评分: {pylint_score}/10\")\n",
    "# print(f\"Radon 维护指数: {radon_score}/10\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import json\n",
    "import copy\n",
    "import tempfile\n",
    "import subprocess\n",
    "import concurrent.futures\n",
    "from tqdm import tqdm\n",
    "\n",
    "class Evaluator():\n",
    "    def __init__(self, pass_rate_predictor=None):\n",
    "        self.pass_rate_predictor = pass_rate_predictor\n",
    "\n",
    "    def calculate_pass_rate_score(self, test_results, test_weights):\n",
    "        total_weight = sum(test_weights.values())\n",
    "        if total_weight == 0:\n",
    "            return 0.0\n",
    "        \n",
    "        passed_weight = sum(weight for test_id, weight in test_weights.items() \n",
    "                           if test_results.get(test_id, False))\n",
    "        return passed_weight / total_weight\n",
    "\n",
    "    def calculate_batch_scores(self, code_data):\n",
    "        items = list(code_data.items())\n",
    "        code_ids = [k for k, _ in items]\n",
    "        code_entries = [v for _, v in items]\n",
    "        full_score_dict = {}\n",
    "\n",
    "        # 计算pass_rate_score（快速计算，无需并行）\n",
    "        pass_rate_scores = {\n",
    "            code_id: self.calculate_pass_rate_score(entry[\"test_results\"], entry[\"test_weights\"])\n",
    "            for code_id, entry in code_data.items()\n",
    "        }\n",
    "\n",
    "        # 批量预测score\n",
    "        code_strs = [entry[\"code\"] for entry in code_entries]\n",
    "        prediction_scores = [0.0] * len(code_strs)\n",
    "        if self.pass_rate_predictor is not None and self.pass_rate_predictor.model is not None:\n",
    "            try:\n",
    "                prediction_scores = self.pass_rate_predictor.predict_score(code_strs)\n",
    "                print(\"###############################################################\")\n",
    "                print(f\"Prediction scores: {prediction_scores}\")\n",
    "                print(\"###############################################################\")\n",
    "            except Exception as e:\n",
    "                print(code_strs)\n",
    "                raise e\n",
    "\n",
    "        # 并行计算静态分析分数\n",
    "        with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "            static_scores = list(tqdm(\n",
    "                executor.map(self._compute_static_scores, code_strs),\n",
    "                total=len(code_strs),\n",
    "                desc=\"Analyzing codes\"\n",
    "            ))\n",
    "\n",
    "        # 组合最终分数\n",
    "        final_scores = {}\n",
    "        for i, code_id in enumerate(code_ids):\n",
    "            final_scores[code_id] = (\n",
    "                0.7 * pass_rate_scores[code_id] +\n",
    "                0.1 * prediction_scores[i] +\n",
    "                0.1 * static_scores[i][0] +\n",
    "                0.1 * static_scores[i][1]\n",
    "            )\n",
    "            full_score_dict[code_id] = {\n",
    "                \"pass_rate_score\": pass_rate_scores[code_id],\n",
    "                \"prediction_score\": prediction_scores[i],\n",
    "                \"pylint_score\": static_scores[i][0],\n",
    "                \"radon_score\": static_scores[i][1]\n",
    "            }\n",
    "\n",
    "        return final_scores, full_score_dict\n",
    "\n",
    "    def _compute_static_scores(self, code_str):\n",
    "        return (\n",
    "            self.pylint_code_score(code_str),\n",
    "            self.radon_mi_code_score(code_str)\n",
    "        )\n",
    "\n",
    "    def pylint_code_score(self, code):\n",
    "        try:\n",
    "            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "                tmp.write(code)\n",
    "                tmp_path = tmp.name\n",
    "            \n",
    "            result = subprocess.run(\n",
    "                [\"pylint\", \"--output-format=text\", tmp_path],\n",
    "                capture_output=True,\n",
    "                text=True,\n",
    "                check=False\n",
    "            )\n",
    "            os.unlink(tmp_path)\n",
    "            \n",
    "            match = re.search(r\"rated at (\\d+\\.?\\d*)/10\", result.stdout)\n",
    "            return float(match.group(1)) if match else -1\n",
    "        \n",
    "        except Exception as e:\n",
    "            print(f\"Pylint analysis failed: {e}\")\n",
    "            return -1\n",
    "\n",
    "    def radon_mi_code_score(self, code):\n",
    "        try:\n",
    "            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:\n",
    "                tmp.write(code)\n",
    "                tmp_path = tmp.name\n",
    "            \n",
    "            result = subprocess.run(\n",
    "                [\"radon\", \"mi\", \"--json\", tmp_path],\n",
    "                capture_output=True,\n",
    "                text=True,\n",
    "                check=False\n",
    "            )\n",
    "            os.unlink(tmp_path)\n",
    "            \n",
    "            data = json.loads(result.stdout)\n",
    "            if data and isinstance(data, dict):\n",
    "                return list(data.values())[0][\"mi\"] / 10\n",
    "            return -1\n",
    "        except Exception as e:\n",
    "            print(f\"Radon analysis failed: {e}\")\n",
    "            return -1\n",
    "\n",
    "class LCDP():\n",
    "    def __init__(self, api_key, model=\"gpt-3.5-turbo\"):\n",
    "        self.llm_model = LLMModel(api_key, model)\n",
    "        self.code_runner = CodeRunner()\n",
    "        self.pass_rate_predictor = PassRatePredictor()\n",
    "        self.evaluator = Evaluator(self.pass_rate_predictor)\n",
    "        self.task_description = None\n",
    "        self.current_plan = None\n",
    "        self.test_weights = {}\n",
    "        self.test_cases = {}\n",
    "\n",
    "    async def run(self, task_description, max_iterations=3, example_dataset=None,\n",
    "                 num_plans=3, num_tests=5, num_codes=5, refine_rounds=3, use_pass_rate_for_train=False):\n",
    "        self.task_description = task_description\n",
    "        \n",
    "        # Initialize LLM Task Manager\n",
    "        self.llmtm = LLMTM(task_description, self.llm_model)\n",
    "        self.llmcg = LLMCG(task_description, self.llm_model)\n",
    "        \n",
    "        # Phase 1: Plan Generation and Refinement\n",
    "        # print(\"########################################################################\")\n",
    "        # print(\"### Phase 1: Plan Generation and Refinement\")\n",
    "        logging.info(\"########################################################################\")\n",
    "        logging.info(\"### Phase 1: Plan Generation and Refinement\")\n",
    "        plan, plan_raw = self.llmtm.get_plan()\n",
    "        self.current_plan = await self._plan_refinement_loop(self.llmtm, plan_raw, refine_rounds)\n",
    "        self.current_plan = self._plan_format_refinement(self.current_plan)\n",
    "        \n",
    "        # Phase 2: Test Case Generation and Weighting\n",
    "        # print(\"\\n########################################################################\")\n",
    "        # print(\"### Phase 2: Test Case Generation and Weighting\")\n",
    "        logging.info(\"\\n########################################################################\")\n",
    "        logging.info(\"### Phase 2: Test Case Generation and Weighting\")\n",
    "        # self.test_cases = await self._generate_tests(self.llmtm, num_tests)\n",
    "        # debug\n",
    "        self.test_cases = await self._generate_tests_async(self.llmtm, num_tests, use_example=False)\n",
    "        self.test_cases = self._filter_test_cases(self.test_cases)\n",
    "\n",
    "        # print(\"Calculating test weights...\")\n",
    "        logging.info(\"Calculating test weights...\")\n",
    "        self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset)\n",
    "        \n",
    "        # Phase 3: Iterative Code Generation\n",
    "        # print(\"\\n########################################################################\")\n",
    "        # print(\"### Phase 3: Iterative Code Generation\")\n",
    "        logging.info(\"\\n########################################################################\")\n",
    "        logging.info(\"### Phase 3: Iterative Code Generation\")\n",
    "        best_codes = {}\n",
    "        for iteration in range(max_iterations):\n",
    "            # print(f\"\\n=== Iteration {iteration+1}/{max_iterations} ===\")\n",
    "            logging.info(f\"\\n=== Iteration {iteration+1}/{max_iterations} ===\")\n",
    "            \n",
    "            # Generate new codes\n",
    "            # new_codes = await self._generate_codes(num_codes)\n",
    "            new_codes = await self._generate_codes_async(num_codes)\n",
    "            \n",
    "            # Evaluate codes\n",
    "            logging.info(\"Evaluating codes...\")\n",
    "            scored_codes, filtered_test_result = self._evaluate_codes(new_codes)\n",
    "            # remove the test cases that are not in the filtered_test_result\n",
    "            self.test_cases = {k: v for k, v in self.test_cases.items() if k in list(filtered_test_result.keys())}\n",
    "\n",
    "            logging.info(\"training pass_rate_predictor...\")\n",
    "            self.pass_rate_predictor.add_data(scored_codes, use_pass_rate=use_pass_rate_for_train)\n",
    "            self.pass_rate_predictor.train_model(epochs=50, batch_size=32, lr=0.001)\n",
    "            \n",
    "            # Update best codes\n",
    "            best_codes.update(self._select_top_codes(scored_codes, top_k=3))\n",
    "            \n",
    "            # User feedback\n",
    "            if not await self._get_user_feedback(best_codes):\n",
    "                self.current_plan['user_feedback'] = \"Based on previous outputs, please improve the code quality.\"\n",
    "        \n",
    "        return best_codes\n",
    "\n",
    "    async def _plan_refinement_loop(self, llmtm, initial_plan_raw, max_rounds):\n",
    "        current_plan_raw = initial_plan_raw\n",
    "        current_plan = llmtm.extract_plan(current_plan_raw)\n",
    "        for _ in range(max_rounds):\n",
    "            # Show current plan\n",
    "            # print(\"Current Plan:\\n\", self.plan_json_to_str(current_plan[\"overall_plan\"]))\n",
    "            logging.info(\"Current Plan:\\n\" + self.plan_json_to_str(current_plan[\"overall_plan\"]))\n",
    "            \n",
    "            # Get user feedback\n",
    "            if input(\"Refine plan? (y/n): \").lower() != 'y':\n",
    "                logging.info(\"Skipping plan refinement.\")\n",
    "                break\n",
    "            \n",
    "            feedback = input(\"Enter refinement feedback: \")\n",
    "            logging.info(f\"User feedback: {feedback}\")\n",
    "            current_plan, current_plan_raw = llmtm.refine_plan(feedback, current_plan_raw)\n",
    "        \n",
    "        return llmtm.extract_plan(current_plan_raw)\n",
    "\n",
    "    def plan_json_to_str(self, plan):\n",
    "        # Process Input Format\n",
    "        input_fmt = plan[\"input_format\"]\n",
    "        input_lines = []\n",
    "        for idx, (dtype, shape) in enumerate(input_fmt, 1):\n",
    "            shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_lines.append(f\"Argument {idx}: {dtype} with {shape_str}\")\n",
    "        input_section = \"Input Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in input_lines])\n",
    "\n",
    "        # Process Output Format\n",
    "        output_fmt = plan[\"output_format\"]\n",
    "        output_lines = []\n",
    "        for idx, (dtype, shape) in enumerate(output_fmt, 1):\n",
    "            shape_str = f\"shape={shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_lines.append(f\"Output {idx}: {dtype} with {shape_str}\")\n",
    "        output_section = \"Output Format:\\n\" + \"\\n\".join([f\"- {line}\" for line in output_lines])\n",
    "\n",
    "        # Build Overall Plan Details\n",
    "        plan_part = [\n",
    "            \"=== Current Plan ===\",\n",
    "            input_section,\n",
    "            output_section,\n",
    "            f\"Components Order: {', '.join(plan['components'])}\",\n",
    "            \"Plan Steps:\",\n",
    "            *[f\"- {step}\" for step in plan[\"plan\"]],\n",
    "            \"Overall Test Case Advice:\",\n",
    "            *[f\"- {advice}\" for advice in plan[\"test_case_generation_advise\"]],\n",
    "            \"\\n\",\n",
    "        ]\n",
    "\n",
    "        return \"\\n\".join(plan_part)\n",
    "\n",
    "    def _plan_format_refinement(self, plan_dict):\n",
    "        \"\"\"Refines the input and output formats in the plan to be lists of lists.\"\"\"\n",
    "        \n",
    "        # Create a deep copy to avoid modifying the original input\n",
    "        refined_plan = copy.deepcopy(plan_dict)\n",
    "        \n",
    "        def refine_format(formats):\n",
    "            \"\"\"Ensure each format field is a list of lists.\"\"\"\n",
    "            if isinstance(formats, list):\n",
    "                # Check if all elements are lists\n",
    "                if not all(isinstance(elem, list) for elem in formats):\n",
    "                    return [formats]\n",
    "            else:\n",
    "                # If it's not a list, wrap it into a list (though input is expected to be a list)\n",
    "                return [formats]\n",
    "            return formats\n",
    "        \n",
    "        # Process each component in 'components'\n",
    "        for component in refined_plan[\"components\"].values():\n",
    "            for key in [\"input_format\", \"output_format\"]:\n",
    "                if key in component:\n",
    "                    component[key] = refine_format(component[key])\n",
    "        \n",
    "        # Process 'overall_plan'\n",
    "        overall_plan = refined_plan.get(\"overall_plan\")\n",
    "        if overall_plan:\n",
    "            for key in [\"input_format\", \"output_format\"]:\n",
    "                if key in overall_plan:\n",
    "                    overall_plan[key] = refine_format(overall_plan[key])\n",
    "        \n",
    "        return refined_plan\n",
    "\n",
    "    async def _generate_tests(self, llmtm, num_tests):\n",
    "        test_cases = {}\n",
    "        for _ in range(num_tests):\n",
    "            test = llmtm.get_test_cases(self.current_plan['overall_plan'])\n",
    "            test_cases.update(test)\n",
    "        return test_cases\n",
    "    \n",
    "    async def _generate_tests_async(self, llmtm, num_tests, use_example=True):\n",
    "        test_cases = {}\n",
    "        task_list = [llmtm.get_test_cases_async(self.current_plan['overall_plan'], use_example=use_example) for _ in range(num_tests)]\n",
    "        \n",
    "        for task in tqdm_asyncio.as_completed(task_list, total=num_tests, desc=\"Generating async tests\"):\n",
    "            test = await task\n",
    "            for key, value in test.items():\n",
    "                # 生成唯一键名逻辑\n",
    "                new_key = key\n",
    "                suffix = 1\n",
    "                while new_key in test_cases:\n",
    "                    new_key = f\"{key}_{suffix}\"\n",
    "                    suffix += 1\n",
    "                test_cases[new_key] = value\n",
    "                \n",
    "        return test_cases\n",
    "\n",
    "    def _filter_test_cases(self, dataset):\n",
    "        print(dataset)\n",
    "        runnable_entries = {}\n",
    "        for code_id, attributes in dataset.items():\n",
    "            test_code = attributes.get(\"test_function\", \"\")\n",
    "            try:\n",
    "                # Attempt to compile the code string to check for syntax errors.\n",
    "                compile(test_code, \"<string>\", \"exec\")\n",
    "                # If no exception is raised, consider the code as runnable.\n",
    "                runnable_entries[code_id] = attributes\n",
    "            except Exception as error:\n",
    "                # If an exception is raised, skip this entry.\n",
    "                continue\n",
    "        return runnable_entries\n",
    "            \n",
    "\n",
    "    def _calculate_test_weights(self, test_cases, example_dataset):\n",
    "        if not example_dataset:\n",
    "            return {tid: 1.0 for tid in test_cases}\n",
    "        \n",
    "        # Run example dataset through tests\n",
    "        _, test_results = self.code_runner.run_all_tests(example_dataset, test_cases)\n",
    "        \n",
    "        # Calculate weights\n",
    "        weights = {}\n",
    "        for tid, results in test_results.items():\n",
    "            pass_rate = sum(results.values()) / len(results)\n",
    "            weights[tid] = 1 - abs(pass_rate - 0.5)  # Weight tests that discriminate\n",
    "        return weights\n",
    "\n",
    "    async def _generate_codes(self, num_codes):\n",
    "        codes = {}\n",
    "        for _ in range(num_codes):\n",
    "            code = self.llmcg.get_code(\n",
    "                extracted_plan=self.current_plan,\n",
    "                test_cases=self.test_cases,\n",
    "            )\n",
    "            codes[f\"code_{len(codes)}\"] = code\n",
    "        return codes\n",
    "    \n",
    "    async def _generate_codes_async(self, num_codes):\n",
    "        codes = {}\n",
    "        task_list = [self.llmcg.get_code_async(extracted_plan=self.current_plan,\n",
    "                                               test_cases=self.test_cases) for _ in range(num_codes)]\n",
    "        for task in tqdm_asyncio.as_completed(task_list, total=num_codes, desc=\"Generating async codes\"):\n",
    "            code = await task\n",
    "            codes[f\"code_{len(codes)}\"] = code\n",
    "        return codes\n",
    "\n",
    "    def transform_test_perspective(self, test_results):\n",
    "        transformed = {}\n",
    "        for test_case_id, code_results in test_results.items():\n",
    "            for code_id, result in code_results.items():\n",
    "                if code_id not in transformed:\n",
    "                    transformed[code_id] = {}\n",
    "                transformed[code_id][test_case_id] = result\n",
    "        return transformed\n",
    "\n",
    "    def _filter_test_cases_by_pass_rate(self, test_results, threshold=0.05):\n",
    "        filtered_test_case_list = []\n",
    "        filtered_test_results = {}\n",
    "        test_case_length = len(test_results)\n",
    "        for test_case_id, results in test_results.items():\n",
    "            total = len(results)\n",
    "            if total == 0:\n",
    "                continue\n",
    "            passed = sum(results.values())\n",
    "            pass_rate = passed / total\n",
    "            if pass_rate > threshold:\n",
    "                filtered_test_case_list.append(test_case_id)\n",
    "                filtered_test_results[test_case_id] = results\n",
    "        self.test_cases = {k: v for k, v in self.test_cases.items() if k in filtered_test_case_list}\n",
    "        logging.info(f\"Filtered test cases: {len(self.test_cases)} out of {test_case_length}\")\n",
    "\n",
    "        filtered_fun_results = self.transform_test_perspective(filtered_test_results)\n",
    "\n",
    "        return filtered_fun_results, filtered_test_results\n",
    "\n",
    "    def _evaluate_codes(self, codes):\n",
    "        fun_results, test_results = self.code_runner.run_all_tests(codes, self.test_cases)\n",
    "\n",
    "        filtered_fun_results, filtered_test_results = self._filter_test_cases_by_pass_rate(test_results, threshold=0.05)\n",
    "\n",
    "        input_data = {}\n",
    "        for code_id, results in filtered_fun_results.items():\n",
    "            input_data[code_id] = {\n",
    "                'code': codes[code_id]['code'],\n",
    "                'test_results': results,\n",
    "                'test_weights': self.test_weights\n",
    "            }\n",
    "        # Calculate scores\n",
    "        output_scores, full_score_dict = self.evaluator.calculate_batch_scores(input_data)\n",
    "        # Combine scores with code data\n",
    "        output_results = {\n",
    "            code_id: {\n",
    "                'code': codes[code_id]['code'],\n",
    "                'plan': codes[code_id]['plan'],\n",
    "                'main_function_name': codes[code_id]['main_function_name'],\n",
    "                'score': output_scores[code_id],\n",
    "                'pass_rate_score': full_score_dict[code_id]['pass_rate_score'],\n",
    "                'prediction_score': full_score_dict[code_id]['prediction_score'],\n",
    "                'pylint_score': full_score_dict[code_id]['pylint_score'],\n",
    "                'radon_score': full_score_dict[code_id]['radon_score'],\n",
    "                'test_case_results': filtered_fun_results[code_id],\n",
    "            }\n",
    "            for code_id in codes.keys()\n",
    "        }\n",
    "        return output_results, filtered_test_results\n",
    "        # return {\n",
    "        #     code_id: {\n",
    "        #         'code': codes[code_id]['code'],\n",
    "        #         'plan':codes[code_id]['plan'],\n",
    "        #         'main_function_name':codes[code_id]['main_function_name'],\n",
    "        #         'score': self.evaluator.calculate_score(codes[code_id]['code'] ,results, self.test_weights)\n",
    "        #     }\n",
    "        #     for code_id, results in fun_results.items()\n",
    "        # }\n",
    "\n",
    "    def _select_top_codes(self, scored_codes, top_k=3):\n",
    "        return dict(sorted(scored_codes.items(), \n",
    "                          key=lambda x: x[1]['score'], \n",
    "                          reverse=True)[:top_k])\n",
    "\n",
    "    async def _get_user_feedback(self, top_codes):\n",
    "\n",
    "        logging.info(\"\\nTop Performing Codes:\")\n",
    "        for cid, data in top_codes.items():\n",
    "            logging.info(f\"{cid} [Score: {data['score']:.2f}]:\")\n",
    "            logging.info(\"Code workflow:\")\n",
    "            logging.info(data['plan'])\n",
    "            logging.info(\"Partial Code:\")\n",
    "            logging.info(data['code'][:500] + \"...\\n\")\n",
    "        \n",
    "        if input(\"Provide feedback? (y/n): \").lower() == 'y':\n",
    "            feedback = input(\"Enter your feedback: \")\n",
    "            logging.info(f\"User feedback: {feedback}\")\n",
    "            # Store feedback for next generation cycle\n",
    "            self.current_plan['user_feedback'] = feedback\n",
    "            return True\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-10 11:04:11,235 - root - INFO - ########################################################################\n",
      "2025-04-10 11:04:11,236 - root - INFO - ### Phase 1: Plan Generation and Refinement\n",
      "2025-04-10 11:04:25,068 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "2025-04-10 11:04:25,102 - root - INFO - Current Plan:\n",
      "=== Current Plan ===\n",
      "Input Format:\n",
      "- Argument 1: list with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: int with no fixed shape\n",
      "Components Order: find_adjacent_min_pair, replace_pair_with_sum, count_operations_to_non_decreasing\n",
      "Plan Steps:\n",
      "- Begin with the input array 'nums'.\n",
      "- Use 'count_operations_to_non_decreasing' to find how many operations are needed to replace adjacent pairs to make the array non-decreasing.\n",
      "- Inside this method, repeatedly use 'find_adjacent_min_pair' to locate the pair with the minimum sum.\n",
      "- Upon finding the pair, use 'replace_pair_with_sum' to modify the array.\n",
      "- Continue this process until the array is non-decreasing.\n",
      "- Return the total number of operations performed.\n",
      "Overall Test Case Advice:\n",
      "- Test with the simplest input: an already sorted array.\n",
      "- Test with decreasing and random arrays to assess number of operations required.\n",
      "- Include edge cases such as arrays with duplicate values and empty arrays.\n",
      "- Consider performance testing with large input sizes.\n",
      "\n",
      "\n",
      "2025-04-10 11:04:50,718 - root - INFO - Skipping plan refinement.\n",
      "2025-04-10 11:04:50,719 - root - INFO - \n",
      "########################################################################\n",
      "2025-04-10 11:04:50,719 - root - INFO - ### Phase 2: Test Case Generation and Weighting\n",
      "Generating async tests:   0%|          | 0/5 [00:00<?, ?it/s]2025-04-10 11:05:00,625 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  20%|██        | 1/5 [00:09<00:39,  9.94s/it]2025-04-10 11:05:01,453 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  40%|████      | 2/5 [00:10<00:13,  4.58s/it]2025-04-10 11:05:02,797 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  60%|██████    | 3/5 [00:12<00:06,  3.10s/it]2025-04-10 11:05:11,717 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "2025-04-10 11:05:11,747 - root - WARNING - Failed to extract test cases, retrying (1)..., current llm_output:\n",
      "< correct_ness >\n",
      "< Planning >\n",
      "In this test case, I will check the function's correctness using a simple sorted array as input. The expected result is 0 since the array is already non-decreasing, meaning no operations are needed. This will confirm that the function correctly identifies a non-decreasing array as valid without performing operations.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Given sorted array\n",
      "    nums = [1, 2, 2]\n",
      "    expected_output = 0\n",
      "    \n",
      "    # Run the function and capture the result\n",
      "    result = func(nums)\n",
      "    \n",
      "    # Check if the result matches the expected output\n",
      "    return result == expected_output\n",
      "< / Code >\n",
      "\n",
      "< edge_case >\n",
      "< Planning >\n",
      "I will examine the behavior of the function when provided with an empty array as input. This is an edge case because the function must handle situations with no elements gracefully. The expected output is also 0 since an empty array is trivially non-decreasing as there are no elements to compare. This test checks if the function can return a valid result without runtime errors when faced with an empty input.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Given empty array\n",
      "    nums = []\n",
      "    expected_output = 0\n",
      "    \n",
      "    # Run the function and capture the result\n",
      "    result = func(nums)\n",
      "    \n",
      "    # Check if the result matches the expected output\n",
      "    return result == expected_output\n",
      "< / Code >\n",
      "\n",
      "< edge_case >\n",
      "< Planning >\n",
      "I will provide an array consisting of a single element. The expected output is 0 because a single-element array is always non-decreasing. This test ensures the function handles minimal input sizes effectively and returns the correct result without unnecessary operations.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Given single element array\n",
      "    nums = [5]\n",
      "    expected_output = 0\n",
      "    \n",
      "    # Run the function and capture the result\n",
      "    result = func(nums)\n",
      "    \n",
      "    # Check if the result matches the expected output\n",
      "    return result == expected_output\n",
      "< / Code >\n",
      "\n",
      "< correct_ness >\n",
      "< Planning >\n",
      "In this test case, I will provide a decreasing array to examine how many operations the function requires to convert it into a non-decreasing array. The input will be [5, 2, 3, 1], and according to the task description, the expected output is 2 as it requires two operations to make the array non-decreasing. This will verify if the function accurately computes the number of required operations.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Given decreasing array\n",
      "    nums = [5, 2, 3, 1]\n",
      "    expected_output = 2\n",
      "    \n",
      "    # Run the function and capture the result\n",
      "    result = func(nums)\n",
      "    \n",
      "    # Check if the result matches the expected output\n",
      "    return result == expected_output\n",
      "< / Code >\n",
      "\n",
      "< correct_ness >\n",
      "< Planning >\n",
      "In this case, I will test the function with an array that has duplicate values and requires multiple operations to become non-decreasing. The input will be [4, 3, 2, 3, 4], and this should test the function's ability to handle duplicates correctly while computing the expected output of 2 operations to achieve a non-decreasing state.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Given array with duplicates that needs operations\n",
      "    nums = [4, 3, 2, 3, 4]\n",
      "    expected_output = 2\n",
      "    \n",
      "    # Run the function and capture the result\n",
      "    result = func(nums)\n",
      "    \n",
      "    # Check if the result matches the expected output\n",
      "    return result == expected_output\n",
      "< / Code >\n",
      "\n",
      "< run_time >\n",
      "< Planning >\n",
      "I will test the function's execution time with a large array to ensure it operates efficiently. The input will be a list of size 10,000 with random values between 1 and 100. While it is difficult to define an exact expected output for a random array, I will measure the execution time to ensure it completes within a reasonable threshold (e.g., 1 second).\n",
      "< / Planning >\n",
      "< Code >\n",
      "import time\n",
      "import random\n",
      "\n",
      "def test_case(func):\n",
      "    # Generate a large random array\n",
      "    nums = [random.randint(1, 100) for _ in range(10000)]\n",
      "    \n",
      "    # Measure execution time\n",
      "    start_time = time.time()\n",
      "    func(nums)\n",
      "    execution_time = time.time() - start_time\n",
      "    \n",
      "    # Check if the execution time is within the threshold\n",
      "    return execution_time < 1  # Expected execution time should be less than 1 second\n",
      "< / Code >\n",
      "\n",
      "< component_check >\n",
      "< Planning >\n",
      "This test will confirm that the function correctly utilizes the specified components for finding adjacent pairs, replacing pairs with their sum, and counting the operations. I will inspect the function's code as a string and check for these specific function names to ensure it adheres to the expected structure.\n",
      "< / Planning >\n",
      "< Code >\n",
      "def test_case(func):\n",
      "    # Convert function to string to inspect contents\n",
      "    func_code = inspect.getsource(func)\n",
      "    \n",
      "    # Check for the required components in the code\n",
      "    contains_find_adjacent_min_pair = \"find_adjacent_min_pair\" in func_code\n",
      "    contains_replace_pair_with_sum = \"replace_pair_with_sum\" in func_code\n",
      "    contains_count_operations_to_non_decreasing = \"count_operations_to_non_decreasing\" in func_code\n",
      "    \n",
      "    return contains_find_adjacent_min_pair and contains_replace_pair_with_sum and contains_count_operations_to_non_decreasing\n",
      "< / Code >\n",
      "2025-04-10 11:05:16,822 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests:  80%|████████  | 4/5 [00:26<00:07,  7.41s/it]2025-04-10 11:05:21,967 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async tests: 100%|██████████| 5/5 [00:31<00:00,  6.26s/it]\n",
      "2025-04-10 11:05:22,004 - root - INFO - Calculating test weights...\n",
      "2025-04-10 11:05:22,005 - root - INFO - \n",
      "########################################################################\n",
      "2025-04-10 11:05:22,005 - root - INFO - ### Phase 3: Iterative Code Generation\n",
      "2025-04-10 11:05:22,006 - root - INFO - \n",
      "=== Iteration 1/2 ===\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1': {'test_type': 'correctness', 'purpose': 'I will create a test function that validates the behavior of the function responsible for determining the number of operations needed to make an array non-decreasing. I will check with an array that requires operations to be made non-decreasing, ensuring that the output matches the expected result after performing the operations. This will confirm that the function behaves correctly under normal circumstances.', 'test_function': 'def test_case(func):\\n    nums = [5, 2, 3, 1]  # Input array that requires two operations\\n    expected_output = 2   # Two operations needed as explained in the example\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_2': {'test_type': 'correctness', 'purpose': 'I will test an already sorted array to ensure that no operations are performed and the function can correctly identify when the input is already in non-decreasing order. In this case, the output should be 0. This checks the ability of the function to handle cases that require no modifications.', 'test_function': 'def test_case(func):\\n    nums = [1, 2, 2]  # Already sorted array\\n    expected_output = 0  # No operations needed\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_3': {'test_type': 'edge_case', 'purpose': 'To validate how the function handles edge cases, I will create a test for an empty list. According to the problem description, an empty list should not require any operations; thus, the expected output should be 0. This tests the robustness of the implementation against invalid or special input cases.', 'test_function': 'def test_case(func):\\n    nums = []  # Empty array\\n    expected_output = 0  # No operations needed for an empty array\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_4': {'test_type': 'edge_case', 'purpose': \"I will also test with arrays containing duplicate values. 
It's essential for the function to correctly process lists where multiple elements are the same to confirm correctness under diverse inputs. The expectation would be that it still processes those cases without needing any operations if the array is non-decreasing.\", 'test_function': 'def test_case(func):\\n    nums = [2, 2, 2]  # Array with duplicates already in non-decreasing order\\n    expected_output = 0  # No operations needed\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_5': {'test_type': 'runtime', 'purpose': 'For a performance test, I will generate a large list with random integers and measure the execution time of the function. This will help to validate that the implementation performs efficiently, especially when faced with larger input sizes, ensuring that the execution is completed within a reasonable time threshold.', 'test_function': 'import time\\nimport random\\n\\ndef test_case(func):\\n    nums = [random.randint(0, 1000) for _ in range(100000)]  # Generating a large list\\n    start_time = time.time()\\n    result = func(nums)\\n    end_time = time.time()\\n    execution_time = end_time - start_time\\n    return execution_time < 1  # Assuming the function should complete within 1 second'}, 'test_case_1_1': {'test_type': 'correctness', 'purpose': 'I will design a test case to validate the correctness of the function. This function should return the number of operations needed to make the given array non-decreasing. I will use an array that is initially decreasing to observe how many operations it takes to convert it to non-decreasing order. 
The expected output can be calculated based on the operations applied to the pairs in the array.', 'test_function': 'def test_case(func): \\n    # Test input: decreasing array [5, 2, 3, 1]\\n    result = func([5, 2, 3, 1]) \\n    expected = 2 \\n    return result == expected'}, 'test_case_2_1': {'test_type': 'correctness', 'purpose': \"This test case will verify the function's correctness when given an already sorted array, which requires no operations. The output should simply be 0 in this scenario.\", 'test_function': 'def test_case(func): \\n    # Test input: already sorted array [1, 2, 2]\\n    result = func([1, 2, 2]) \\n    expected = 0 \\n    return result == expected'}, 'test_case_3_1': {'test_type': 'edge_case', 'purpose': 'I will test an edge case with an empty array. Since there are no elements to operate on, the expected output should be 0 for such an input. This will help verify the function handles this special case properly.', 'test_function': 'def test_case(func): \\n    # Test input: empty array []\\n    result = func([]) \\n    expected = 0 \\n    return result == expected'}, 'test_case_4_1': {'test_type': 'edge_case', 'purpose': 'In this case, I will test with an array containing duplicate values. An input like [3, 3, 2, 1] is chosen because it has duplicates and will require a few operations to become non-decreasing. The expected number of operations should be calculated manually to validate against the output.', 'test_function': 'def test_case(func): \\n    # Test input: array with duplicates [3, 3, 2, 1]\\n    result = func([3, 3, 2, 1]) \\n    expected = 2 \\n    return result == expected'}, 'test_case_5_1': {'test_type': 'runtime', 'purpose': \"The goal of this test case is to check the function's performance when handling large input sizes. 
A large random array will be generated, and the execution time will be measured to ensure it stays within a reasonable limit.\", 'test_function': 'import time \\ndef test_case(func): \\n    # Generate a large random input array \\n    import random \\n    large_input = [random.randint(0, 100) for _ in range(10000)] \\n    \\n    start_time = time.time() \\n    func(large_input) \\n    execution_time = time.time() - start_time \\n    \\n    return execution_time < 1.0  # Ensure it runs in less than 1 second'}, 'test_case_6': {'test_type': 'error_handling', 'purpose': 'In this case, I will check how the function handles invalid inputs, such as passing an input that is not a list. I expect it to raise an appropriate error.', 'test_function': 'def test_case(func): \\n    try: \\n        func(\"not a list\")  # Invalid input \\n        return False  # Should not reach this line \\n    except TypeError: \\n        return True  # Correctly handled invalid input'}, 'test_case_1_2': {'test_type': 'correctness', 'purpose': 'I will create a test function to validate the correctness of the implementation that counts the number of operations needed to make an array non-decreasing. I will use an input that requires multiple operations to transform the array and compare the output against the expected number of operations. 
This will ensure that the function correctly identifies adjacent pairs, performs replacements, and calculates the total number of operations needed.', 'test_function': 'def test_case(func):\\n    # Input array that requires operations\\n    nums = [5, 2, 3, 1]\\n    # Expected number of operations\\n    expected_output = 2\\n    # Call the provided function and get the actual output\\n    actual_output = func(nums)\\n    # Return True if the actual output matches the expected output\\n    return actual_output == expected_output'}, 'test_case_2_2': {'test_type': 'edge_case', 'purpose': 'I will create a test function to validate the behavior of the implementation when given an empty array input. Since there are no elements in the array, no operations should be needed, so the expected output will be 0. This edge case will help ensure that the function handles empty inputs properly without errors.', 'test_function': 'def test_case(func):\\n    # Input is an empty array\\n    nums = []\\n    # Expected number of operations\\n    expected_output = 0\\n    # Call the provided function and get the actual output\\n    actual_output = func(nums)\\n    # Return True if the actual output matches the expected output\\n    return actual_output == expected_output'}, 'test_case_3_2': {'test_type': 'runtime', 'purpose': 'I will create a test function to measure the runtime of the provided implementation using a large input array. This will help determine if the function performs efficiently for large datasets. 
I will record the execution time and ensure that it runs within a reasonable threshold, for example, less than 1 second.', 'test_function': 'import time\\n\\ndef test_case(func):\\n    nums = list(range(100000, 0, -1))  # Large input of decreasing numbers\\n    start_time = time.time()           # Start timing\\n    result = func(nums)                # Call the function\\n    execution_time = time.time() - start_time  # Calculate elapsed time\\n\\n    # Check if the function returns a valid count of operations\\n    operations_needed = result >= 0  # Result must be non-negative\\n    # Return True if execution time is below the threshold and function is valid\\n    return operations_needed and execution_time < 1'}, 'test_case_4_2': {'test_type': 'component_check', 'purpose': 'I will create a test function to ensure that the implementation utilizes the specified components, such as finding adjacent pairs and replacing them with their sum. This will involve checking the implementation for function calls that match the specified names. Verifying component usage can help ensure compliance with the coding requirements before runtime and correctness checks.', 'test_function': \"def test_case(func):\\n    import inspect\\n\\n    # Check if the required components are present in the function\\n    code = inspect.getsource(func)\\n    # Look for the specific components in the code\\n    has_find_adjacent = 'find_adjacent_min_pair' in code\\n    has_replace_pair = 'replace_pair_with_sum' in code\\n    # Return True if both components are found\\n    return has_find_adjacent and has_replace_pair\"}, 'test_case_5_2': {'test_type': 'error_handling', 'purpose': 'I will create a test function to check if the implementation properly raises exceptions for invalid inputs, such as a list containing non-integer types. This is crucial for ensuring that the function can handle erroneous inputs gracefully. 
The expected behavior is that the function should raise a ValueError or TypeError when given invalid input.', 'test_function': \"def test_case(func):\\n    try:\\n        # Input with non-integer values\\n        nums = [5, 'two', None]\\n        # Call the function with invalid input\\n        func(nums)\\n        # If no exception is raised, return False\\n        return False\\n    except (ValueError, TypeError):\\n        # Expected behavior; return True if an exception is raised\\n        return True\"}, 'test_case_1_3': {'test_type': 'correctness', 'purpose': 'The test function aims to validate the correctness of the implemented solution by passing a few sample arrays representing different scenarios. It will confirm that the output matches the expected number of operations needed to make the array non-decreasing. We will consider edge cases and basic functionality, ensuring the function correctly identifies already sorted arrays versus those requiring modifications.', 'test_function': 'def test_case(func):\\n    # Test Case 1: Basic functionality with a decreasing array\\n    input1 = [5, 2, 3, 1]\\n    expected_output1 = 2\\n    actual_output1 = func(input1)\\n    if actual_output1 != expected_output1:\\n        return False\\n\\n    # Test Case 2: Already sorted array\\n    input2 = [1, 2, 2]\\n    expected_output2 = 0\\n    actual_output2 = func(input2)\\n    if actual_output2 != expected_output2:\\n        return False\\n\\n    # Test Case 3: Edge case with empty array\\n    input3 = []\\n    expected_output3 = 0\\n    actual_output3 = func(input3)\\n    if actual_output3 != expected_output3:\\n        return False\\n\\n    # Test Case 4: Array of the same elements\\n    input4 = [4, 4, 4, 4]\\n    expected_output4 = 0\\n    actual_output4 = func(input4)\\n    if actual_output4 != expected_output4:\\n        return False\\n    \\n    # Test Case 5: Array with negative numbers\\n    input5 = [-1, -2, -3]\\n    expected_output5 = 2\\n    actual_output5 = 
func(input5)\\n    if actual_output5 != expected_output5:\\n        return False\\n\\n    # Test Case 6: Mixed positive and negative numbers\\n    input6 = [1, 3, 2, 4]\\n    expected_output6 = 1\\n    actual_output6 = func(input6)\\n    if actual_output6 != expected_output6:\\n        return False\\n\\n    return True'}, 'test_case_2_3': {'test_type': 'edge_case', 'purpose': 'This test function addresses special scenarios such as empty arrays and arrays with uniform elements. It checks that the function handles these cases gracefully without errors and returns the correct number of operations.', 'test_function': 'def test_case(func):\\n    # Test Case 1: Empty array\\n    input1 = []\\n    expected_output1 = 0\\n    actual_output1 = func(input1)\\n    if actual_output1 != expected_output1:\\n        return False\\n    \\n    # Test Case 2: Array with same values\\n    input2 = [7, 7, 7, 7]\\n    expected_output2 = 0\\n    actual_output2 = func(input2)\\n    if actual_output2 != expected_output2:\\n        return False\\n\\n    # Test Case 3: Single element array\\n    input3 = [42]\\n    expected_output3 = 0\\n    actual_output3 = func(input3)\\n    if actual_output3 != expected_output3:\\n        return False\\n\\n    # Test Case 4: Two element sorted array\\n    input4 = [1, 2]\\n    expected_output4 = 0\\n    actual_output4 = func(input4)\\n    if actual_output4 != expected_output4:\\n        return False\\n\\n    # Test Case 5: Two element unsorted array\\n    input5 = [2, 1]\\n    expected_output5 = 1\\n    actual_output5 = func(input5)\\n    if actual_output5 != expected_output5:\\n        return False\\n\\n    return True'}, 'test_case_3_3': {'test_type': 'runtime', 'purpose': 'We will measure the execution time for a large input array to ensure that the function performs efficiently under high loads. 
This test case will generate a large list of random integers and verify that the function completes within a reasonable time limit.', 'test_function': 'import random\\nimport time\\n\\ndef test_case(func):\\n    # Test Case 1: Performance testing with a large list\\n    input_data = [random.randint(1, 1000) for _ in range(10000)]  # Large array of 10,000 random elements\\n    start_time = time.time()\\n    actual_output = func(input_data)\\n    elapsed_time = time.time() - start_time\\n\\n    # Ensure the function completes within 1 second\\n    if elapsed_time > 1:\\n        return False\\n\\n    return True'}, 'test_case_4_3': {'test_type': 'component_check', 'purpose': 'This test case checks whether the function is structured correctly and uses necessary components, like finding minimum pairs and performing operations. It will do this by inspecting the presence of required functions or method calls in the implementation.', 'test_function': \"def test_case(func):\\n    # Check if the function uses needed components\\n    source_code = inspect.getsource(func)\\n    \\n    # Ensure that the necessary components are defined in the function source\\n    required_components = ['find_adjacent_min_pair', 'replace_pair_with_sum', 'count_operations_to_non_decreasing']\\n    \\n    for component in required_components:\\n        if component not in source_code:\\n            return False\\n\\n    return True\"}, 'test_case_5_3': {'test_type': 'error_handling', 'purpose': 'This test function checks how the implementation handles invalid input types, such as strings or other non-list data types. 
We expect the function to either raise a TypeError or handle it gracefully without crashing.', 'test_function': 'def test_case(func):\\n    # Test Case 1: Invalid input type (string)\\n    try:\\n        func(\"invalid input\")\\n        return False\\n    except TypeError:\\n        pass\\n    \\n    # Test Case 2: Invalid input type (number)\\n    try:\\n        func(123)\\n        return False\\n    except TypeError:\\n        pass\\n\\n    # Test Case 3: Invalid input type (None)\\n    try:\\n        func(None)\\n        return False\\n    except TypeError:\\n        pass\\n\\n    return True'}, 'test_case_1_4': {'test_type': 'correctness', 'purpose': 'I will create a test case to validate the correctness of the function by providing an input that requires multiple operations to make the array non-decreasing. The input array will contain numbers that are ordered in a decreasing manner, which guarantees that some replacement operations will be necessary. The expected output will be the total number of operations needed to achieve a non-decreasing order. Given the input [5, 2, 3, 1], the function should perform two operations to become non-decreasing, resulting in the output of 2.', 'test_function': 'def test_case(func):  \\n    input_data = [5, 2, 3, 1]  \\n    expected_output = 2  \\n    result = func(input_data)  \\n    return result == expected_output'}, 'test_case_2_4': {'test_type': 'edge_case', 'purpose': 'I will design an edge case to check how the function handles an empty input array. An empty array is a valid input, and the expected output should be 0 because no operations are needed to make an empty array non-decreasing. 
This will validate that the function correctly handles cases with no elements.', 'test_function': 'def test_case(func):  \\n    input_data = []  \\n    expected_output = 0  \\n    result = func(input_data)  \\n    return result == expected_output'}, 'test_case_3_4': {'test_type': 'runtime', 'purpose': 'I will test the performance of the function by measuring the execution time for a large input array. This is to ensure that the function can handle larger datasets efficiently. The test will involve creating an array with a significant size and running the function to confirm it executes within a reasonable time frame. I will measure the time taken and verify that it stays below a predefined threshold (e.g., 1 second).', 'test_function': 'def test_case(func):  \\n    import time  \\n    input_data = [i for i in range(10000, 0, -1)]  # Creating a large decreasing array  \\n    start_time = time.time()  \\n    func(input_data)  \\n    end_time = time.time()  \\n    return (end_time - start_time) < 1  # Test passes if execution time is under 1 second'}, 'test_case_4_4': {'test_type': 'component_check', 'purpose': 'This test will ensure that the function utilizes the specified components. I will check if the function contains calls to the helper functions that are required for processing with adjacency pairs, such as finding the adjacent minimum pair, replacing pairs, and counting operations. 
The verification will include examining the code structure for the relevant functions.', 'test_function': \"def test_case(func):  \\n    import inspect  \\n    source_code = inspect.getsource(func)  \\n    required_functions = ['find_adjacent_min_pair', 'replace_pair_with_sum', 'count_operations_to_non_decreasing']  \\n    return all(fn in source_code for fn in required_functions)\"}, 'test_case_5_4': {'test_type': 'error_handling', 'purpose': 'This test will check if the function appropriately raises errors for invalid input types, such as non-list inputs (like integers or strings). The function should either raise a TypeError or handle the situation gracefully. This will ensure that the function maintains robustness by rejecting inappropriate inputs.', 'test_function': 'def test_case(func):  \\n    try:  \\n        func(42)  # Passing an integer instead of a list  \\n        return False  # If no error is raised, the test fails  \\n    except TypeError:  \\n        return True  # If a TypeError is raised, the test passes'}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating async codes:   0%|          | 0/5 [00:00<?, ?it/s]2025-04-10 11:05:30,083 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "2025-04-10 11:05:30,106 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  20%|██        | 1/5 [00:08<00:32,  8.11s/it]2025-04-10 11:05:32,876 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  60%|██████    | 3/5 [00:10<00:06,  3.14s/it]2025-04-10 11:05:33,059 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes:  80%|████████  | 4/5 [00:11<00:02,  2.12s/it]2025-04-10 11:05:33,785 - httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "Generating async codes: 100%|██████████| 5/5 [00:11<00:00,  2.36s/it]\n",
      "2025-04-10 11:05:33,822 - root - INFO - Evaluating codes...\n",
      "Running tests:  98%|█████████▊| 128/130 [00:37<00:00, 10.65it/s]"
     ]
    }
   ],
   "source": [
    "test_task_description = \"\"\"Given an array nums, you can perform the following operation any number of times:\n",
    "\n",
    "Select the adjacent pair with the minimum sum in nums. If multiple such pairs exist, choose the leftmost one.\n",
    "Replace the pair with their sum.\n",
    "Return the minimum number of operations needed to make the array non-decreasing.\n",
    "\n",
    "An array is said to be non-decreasing if each element is greater than or equal to its previous element (if it exists).\n",
    "\n",
    " \n",
    "\n",
    "Example 1:\n",
    "\n",
    "Input: nums = [5,2,3,1]\n",
    "\n",
    "Output: 2\n",
    "\n",
    "Explanation:\n",
    "\n",
    "The pair (3,1) has the minimum sum of 4. After replacement, nums = [5,2,4].\n",
    "The pair (2,4) has the minimum sum of 6. After replacement, nums = [5,6].\n",
    "The array nums became non-decreasing in two operations.\n",
    "\n",
    "Example 2:\n",
    "\n",
    "Input: nums = [1,2,2]\n",
    "\n",
    "Output: 0\n",
    "\n",
    "Explanation:\n",
    "\n",
    "The array nums is already sorted.\"\"\"\n",
    "\n",
    "lcdp = LCDP(api_key=\"sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg\", model=\"gpt-4o-mini\")\n",
    "best_codes = await lcdp.run(\n",
    "    task_description=test_task_description,\n",
    "    max_iterations=2,\n",
    "    num_plans=3,\n",
    "    num_tests=5,\n",
    "    num_codes=5,\n",
    "    refine_rounds=3,\n",
    "    use_pass_rate_for_train=False,\n",
    "    # use_example=True,\n",
    "    # example_dataset=example_codes,\n",
    ")\n",
    "print(best_codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1_1': {'test_type': 'correctness', 'purpose': 'We will create a test function to verify the correctness of the implementation that counts the number of operations required to make the array non-decreasing. We will use an input `[5, 2, 3, 1]` which is not sorted. The expected output for this input is `2` as demonstrated in the example. In the first operation, the pair `(3, 1)` will be replaced with `4`, resulting in `[5, 2, 4]`. The second operation will replace the pair `(2, 4)` with `6`, resulting in `[5, 6]`. After two operations, the array becomes non-decreasing.', 'test_function': 'def test_case(func):\\n    input_data = [5, 2, 3, 1]\\n    expected_output = 2\\n    result = func(input_data)\\n    return result == expected_output'}, 'test_case_2': {'test_type': 'edge_case', 'purpose': 'This test case will check how the function handles an empty array. An empty array is considered non-decreasing by definition because there are no elements to violate the rule. We expect the function to return 0 operations as no action is needed.', 'test_function': 'def test_case(func):\\n    nums = []\\n    expected_output = 0\\n    output = func(nums)\\n    return output == expected_output'}, 'test_case_3': {'test_type': 'correctness', 'purpose': 'This test case ensures that the function can handle an array that is already non-decreasing. Therefore, we will provide an array that is already sorted, such as [1, 2, 2, 3]. We expect the output of the function to be 0 since no operations are needed to make it non-decreasing.', 'test_function': 'def test_case(func):\\n    nums = [1, 2, 2, 3]\\n    expected_output = 0\\n    output = func(nums)\\n    return output == expected_output'}, 'test_case_1_3': {'test_type': 'correctness', 'purpose': 'To test the correctness of the function that counts the operations needed to make an array non-decreasing, I will create a test case where the input array is not sorted. 
I will use the input array [5, 2, 3, 1], which requires several operations to become non-decreasing. I expect the output to be 2, since the pairs (3, 1) and (2, 4) will be merged in two operations.', 'test_function': 'def test_case(func):\\n    input_data = [5, 2, 3, 1]\\n    expected_output = 2\\n    result = func(input_data)\\n    return result == expected_output'}, 'test_case_2_1': {'test_type': 'edge_case', 'purpose': 'For edge cases, I will test an empty array and a single-element array. An empty array should return 0 because no operations are needed. A single-element array (e.g., [1]) should also return 0 since it is trivially non-decreasing. This will ensure that the function can handle inputs at the bounds of the specified requirements.', 'test_function': 'def test_case(func):\\n    input_data_empty = []\\n    expected_output_empty = 0\\n    result_empty = func(input_data_empty)\\n    if result_empty != expected_output_empty:\\n        return False\\n    \\n    input_data_single = [1]\\n    expected_output_single = 0\\n    result_single = func(input_data_single)\\n    \\n    return result_single == expected_output_single'}, 'test_case_3_1': {'test_type': 'runtime', 'purpose': 'To assess runtime performance, I will generate a large random array of integers (e.g., 10,000 elements) and measure the time it takes for the function to execute. The expectation is that the function should complete in a reasonable time frame (e.g., under 1 second). 
This test will determine if the function maintains efficiency with larger inputs.', 'test_function': 'import time\\nimport random\\n\\ndef test_case(func):\\n    input_data = [random.randint(-100, 100) for _ in range(10000)]\\n    start_time = time.time()\\n    func(input_data)\\n    end_time = time.time()\\n    execution_time = end_time - start_time\\n    return execution_time < 1  # Expecting it to run under 1 second'}, 'test_case_1_4': {'test_type': 'correctness', 'purpose': 'In this test case, I will validate the function with an already sorted array. The input will be [1, 2, 2], which is already non-decreasing. Since no operations are needed, the expected output should be 0. This test case will confirm that the function handles sorted inputs correctly and identifies that no changes are necessary.', 'test_function': 'def test_case(func):\\n    # Test input that is already sorted\\n    nums = [1, 2, 2]\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_2_2': {'test_type': 'correctness', 'purpose': 'In this test case, I will validate the function using an array that is in reverse order ([5, 4, 3, 2, 1]). This array requires complete reconstruction to become non-decreasing. The expected output is 4, as four operations are required: (5,4), (5,3), (5,2), and (5,1) will all be summed to form [5, 10, 15]. This test checks if the function can transform a completely descending list into a non-decreasing one efficiently.', 'test_function': 'def test_case(func):\\n    # Test reverse sorted input\\n    nums = [5, 4, 3, 2, 1]\\n    expected_output = 4\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_3_2': {'test_type': 'correctness', 'purpose': 'This test case will examine how the function handles an array with repeated elements. The input will be [2, 2, 1, 3]. Starting with this array, the expected output should be 1 because the pair (2,1) will be summed to form [2,3,3]. 
This case ensures that the function can handle duplicates appropriately while still maintaining order.', 'test_function': 'def test_case(func):\\n    # Test input with repeated elements\\n    nums = [2, 2, 1, 3]\\n    expected_output = 1\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_4_1': {'test_type': 'edge_case', 'purpose': 'In this test case, I will validate the function with an empty list. The expected output should be 0 since there are no elements to operate on, so the requirement for a non-decreasing order is trivially satisfied. This checks the edge case where no data is present.', 'test_function': 'def test_case(func):\\n    # Test empty input\\n    nums = []\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}, 'test_case_5_1': {'test_type': 'edge_case', 'purpose': 'This test case will focus on a single-element array, e.g., [7]. The array is already non-decreasing, and thus the expected output should be 0. This checks the functionality when the input list contains the minimum number of elements.', 'test_function': 'def test_case(func):\\n    # Test single element input\\n    nums = [7]\\n    expected_output = 0\\n    result = func(nums)\\n    return result == expected_output'}}\n",
      "Test ID: test_case_1_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data = [5, 2, 3, 1]\n",
      "    expected_output = 2\n",
      "    result = func(input_data)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    nums = []\n",
      "    expected_output = 0\n",
      "    output = func(nums)\n",
      "    return output == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_3\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    nums = [1, 2, 2, 3]\n",
      "    expected_output = 0\n",
      "    output = func(nums)\n",
      "    return output == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_1_3\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data = [5, 2, 3, 1]\n",
      "    expected_output = 2\n",
      "    result = func(input_data)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    input_data_empty = []\n",
      "    expected_output_empty = 0\n",
      "    result_empty = func(input_data_empty)\n",
      "    if result_empty != expected_output_empty:\n",
      "        return False\n",
      "    \n",
      "    input_data_single = [1]\n",
      "    expected_output_single = 0\n",
      "    result_single = func(input_data_single)\n",
      "    \n",
      "    return result_single == expected_output_single\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_3_1\n",
      "Test Case:\n",
      "import time\n",
      "import random\n",
      "\n",
      "def test_case(func):\n",
      "    input_data = [random.randint(-100, 100) for _ in range(10000)]\n",
      "    start_time = time.time()\n",
      "    func(input_data)\n",
      "    end_time = time.time()\n",
      "    execution_time = end_time - start_time\n",
      "    return execution_time < 1  # Expecting it to run under 1 second\n",
      "Test Type: runtime\n",
      "----------------------------------------\n",
      "Test ID: test_case_1_4\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test input that is already sorted\n",
      "    nums = [1, 2, 2]\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_2_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test reverse sorted input\n",
      "    nums = [5, 4, 3, 2, 1]\n",
      "    expected_output = 4\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_3_2\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test input with repeated elements\n",
      "    nums = [2, 2, 1, 3]\n",
      "    expected_output = 1\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: correctness\n",
      "----------------------------------------\n",
      "Test ID: test_case_4_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test empty input\n",
      "    nums = []\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n",
      "Test ID: test_case_5_1\n",
      "Test Case:\n",
      "def test_case(func):\n",
      "    # Test single element input\n",
      "    nums = [7]\n",
      "    expected_output = 0\n",
      "    result = func(nums)\n",
      "    return result == expected_output\n",
      "Test Type: edge_case\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "test_cases = lcdp.test_cases\n",
    "print(test_cases)\n",
    "for test_id, test_case_dict in test_cases.items():\n",
    "    print(f\"Test ID: {test_id}\")\n",
    "    print(f\"Test Case:\\n{test_case_dict['test_function']}\")\n",
    "    print(f\"Test Type: {test_case_dict['test_type']}\")\n",
    "    print(\"-\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['code_0', 'code_1', 'code_4', 'code_3'])\n",
      "dict_keys(['code', 'plan', 'main_function_name', 'score', 'test_case_results'])\n"
     ]
    }
   ],
   "source": [
    "checking_code = \"code_0\"\n",
    "print(best_codes.keys())\n",
    "print(best_codes[checking_code].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1_1': True, 'test_case_2': True, 'test_case_3': True, 'test_case_1_3': True, 'test_case_2_1': True, 'test_case_3_1': False, 'test_case_1_4': True, 'test_case_2_2': True, 'test_case_3_2': True, 'test_case_4_1': True, 'test_case_5_1': True}\n",
      "10\n",
      "11\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"test_case_results\"])\n",
    "# get the total number of True in test_case_results\n",
    "print(sum([v for v in best_codes[checking_code][\"test_case_results\"].values()]))\n",
    "print(len(best_codes[checking_code][\"test_case_results\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def count_operations_to_sort(nums):\n",
      "    def find_min_adjacent_pair(nums):\n",
      "        min_sum = float('inf')\n",
      "        index = -1\n",
      "        \n",
      "        # Iterate through the array to find the minimum sum of adjacent pairs\n",
      "        for i in range(len(nums) - 1):\n",
      "            current_sum = nums[i] + nums[i + 1]\n",
      "            if current_sum < min_sum:\n",
      "                min_sum = current_sum\n",
      "                index = i\n",
      "        \n",
      "        return index\n",
      "\n",
      "    def replace_pair_with_sum(nums, index):\n",
      "        # Calculate the sum of the identified pair\n",
      "        sum_pair = nums[index] + nums[index + 1]\n",
      "        # Create a new list with the pair replaced by their sum\n",
      "        new_nums = nums[:index] + [sum_pair] + nums[index + 2:]\n",
      "        return new_nums\n",
      "\n",
      "    operation_count = 0\n",
      "\n",
      "    # Keep performing operations until the array is non-decreasing\n",
      "    while any(nums[i] > nums[i + 1] for i in range(len(nums) - 1)):\n",
      "        index = find_min_adjacent_pair(nums)\n",
      "        nums = replace_pair_with_sum(nums, index)\n",
      "        operation_count += 1\n",
      "\n",
      "    return operation_count\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"code\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"main_function_name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9685026303719275\n"
     ]
    }
   ],
   "source": [
    "print(best_codes[checking_code][\"score\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
