{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3715f062",
   "metadata": {},
   "outputs": [],
   "source": [
    "    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):\n",
    "        \"\"\"\n",
    "        Generates a prompt (or list of prompts) for test case generation based on task specifications.\n",
    "        \n",
    "        Parameters:\n",
    "        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.\n",
    "        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.\n",
    "        \n",
    "        Returns:\n",
    "        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).\n",
    "        \"\"\"\n",
    "        \n",
    "        # Helper function to generate the prompt text from a modified task specification\n",
    "        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = \"\"):\n",
    "            advisory_list = \"\\n\".join([f\"- {advise}\" for advise in advisories])\n",
    "            prompt = f\"\"\"You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
    "\n",
    "### Input Specifications:\n",
    "- **Task Description**:\n",
    "{task_descr_str}\n",
    "- **Input Format**: \n",
    "{input_descr_str}\n",
    "- **Output Format**: \n",
    "{output_descr_str}\n",
    "- **Components Used**: {components_str}\n",
    "- **Plan**: \n",
    "{plan_str}\n",
    "- **Test Case Advise**: \n",
    "{advisory_list}\n",
    "\n",
    "### Requirements:\n",
    "1. **Test Function Structure**:\n",
    "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):`).\n",
    "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
    "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
    "\n",
    "2. **Test Types** (use one of these for `test_type`):\n",
    "   - `correctness`: Validate output against expected results for specific inputs.\n",
    "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
    "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
    "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
    "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
    "\n",
    "3. **Test Case Diversity**:\n",
    "   - Cover all provided advisories.\n",
    "   - Include at least one test per advisory and one for each test type where applicable.\n",
    "\n",
    "### Output Format:\n",
    "Return a JSON dictionary with test cases in this structure:\n",
    "{{\n",
    "  \"test_case_1\": {{\n",
    "    \"purpose\": \"Briefly describe the test's purpose...\",\n",
    "    \"test_function\": \"Create your test function code here...\",\n",
    "    \"test_type\": \"correctness|edge_case|runtime|component_check|error_handling\"\n",
    "  }},\n",
    "  ...\n",
    "}}\n",
    "{example_text}\n",
    "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\n",
    "You MUST strictly follow the output format and structure. The output function in test_function MUST be a runnable function that use another python function as its parameter.\"\"\"\n",
    "            return prompt\n",
    "\n",
    "        if use_example:\n",
    "            examples_text = \"\"\"\n",
    "### Example:\n",
    "{{\n",
    "  \"test_case_1\": {{\n",
    "    \"purpose\": \"Test with arrays of different sizes.\",\n",
    "    \"test_function\": \"def test_case(func):\\\\n    arr1 = [1, 3, 5]\\\\n    arr2 = [2, 4]\\\\n    merged = sorted(arr1 + arr2)\\\\n    expected = (merged[2] + merged[1]) / 2\\\\n    return func(arr1, arr2) == expected\",\n",
    "    \"test_type\": \"correctness\"\n",
    "  }},\n",
    "  \"test_case_2\": {{\n",
    "    \"purpose\": \"Check if 'merge_arrays' component is used.\",\n",
    "    \"test_function\": \"def test_case(func):\\\\n    import inspect\\\\n    source = inspect.getsource(func)\\\\n    return 'merge_arrays(' in source\",\n",
    "    \"test_type\": \"component_check\"\n",
    "  }}\n",
    "}}\n",
    "\"\"\"\n",
    "        else:\n",
    "            examples_text = \"\"\n",
    "\n",
    "        # Process input_format into a descriptive string\n",
    "        input_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_descr.append(f\"- Argument {idx}: {dtype} ({shape_info})\")\n",
    "        input_descr_str = \"\\n\".join(input_descr)\n",
    "\n",
    "        # Process output_format into a descriptive string\n",
    "        output_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_descr.append(f\"- Output {idx}: {dtype} ({shape_info})\")\n",
    "        output_descr_str = \"\\n\".join(output_descr)\n",
    "\n",
    "        # Process components and plan\n",
    "        components_str = \", \".join(task_spec['components'])\n",
    "        plan_str = \"\\n\".join(task_spec['plan'])\n",
    "\n",
    "        if bulk:\n",
    "            # Generate a single prompt with all advisories\n",
    "            advisories = task_spec['test_case_generation_advise']\n",
    "            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)\n",
    "        else:\n",
    "            # Generate a list of prompts, each with a single advisory\n",
    "            prompts = []\n",
    "            for advise in task_spec['test_case_generation_advise']:\n",
    "                single_advisory = [advise]\n",
    "                prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)\n",
    "                prompts.append(prompt)\n",
    "            return prompts\n",
    "        \n",
    "    def extract_test_cases(self, llm_output):\n",
    "        \"\"\"\n",
    "        Extract test cases from LLM's JSON output.\n",
    "        \n",
    "        Args:\n",
    "            llm_output (str/dict): Raw text output from LLM containing JSON, or a dictionary\n",
    "        \n",
    "        Returns:\n",
    "            dict: Parsed test cases in dictionary format\n",
    "        \"\"\"\n",
    "        # Handle if input is already a dictionary\n",
    "        if isinstance(llm_output, dict):\n",
    "            return llm_output\n",
    "        \n",
    "        # Normalize JSON formatting\n",
    "        cleaned_output = llm_output.strip()\n",
    "        \n",
    "        # Handle code block formatting\n",
    "        if cleaned_output.startswith(\"```json\"):\n",
    "            cleaned_output = re.sub(r'^```json\\s*|\\s*```$', '', cleaned_output, flags=re.MULTILINE)\n",
    "        elif cleaned_output.startswith(\"```\"):\n",
    "            cleaned_output = re.sub(r'^```\\s*|\\s*```$', '', cleaned_output, flags=re.MULTILINE)\n",
    "        \n",
    "        # Parse JSON\n",
    "        try:\n",
    "            test_cases = json.loads(cleaned_output)\n",
    "        except Exception as e:\n",
    "            logging.error(f\"Failed to parse JSON: {e}\")\n",
    "            return False\n",
    "        \n",
    "        \n",
    "        # Validate structure\n",
    "        for key, value in test_cases.items():\n",
    "            if not all(k in value for k in ('purpose', 'test_function', 'test_type')):\n",
    "                logging.error(f\"Invalid test case structure in key: {key}, llm_output: {llm_output}\")\n",
    "                return False\n",
    "                \n",
    "        return test_cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c1da1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):\n",
    "        \"\"\"\n",
    "        Generates a prompt (or list of prompts) for test case generation based on task specifications.\n",
    "        \n",
    "        Parameters:\n",
    "        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.\n",
    "        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.\n",
    "        \n",
    "        Returns:\n",
    "        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).\n",
    "        \"\"\"\n",
    "        \n",
    "        # Helper function to generate the prompt text from a modified task specification\n",
    "        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = \"\"):\n",
    "            advisory_list = \"\\n\".join([f\"- {advise}\" for advise in advisories])\n",
    "            prompt = f\"\"\"You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:\n",
    "\n",
    "### Input Specifications:\n",
    "- **Task Description**:\n",
    "{task_descr_str}\n",
    "- **Input Format**: \n",
    "{input_descr_str}\n",
    "- **Output Format**: \n",
    "{output_descr_str}\n",
    "- **Components Used**: {components_str}\n",
    "- **Plan**: \n",
    "{plan_str}\n",
    "- **Test Case Advise**: \n",
    "{advisory_list}\n",
    "\n",
    "### Requirements:\n",
    "1. **Test Function Structure**:\n",
    "   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):...`).\n",
    "   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.\n",
    "   - Include input generation, runtime checks, code inspection, or result validation within the function.\n",
    "\n",
    "2. **Test Types** (use one of these for indicating the test_type):\n",
    "   - `correctness`: Validate output against expected results for specific inputs.\n",
    "   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.\n",
    "   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).\n",
    "   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).\n",
    "   - `error_handling`: Check if errors are raised for invalid inputs.\n",
    "\n",
    "3. **Test Case Diversity**:\n",
    "   - Cover all provided advisories.\n",
    "   - Include at least one test per advisory and one for each test type where applicable.\n",
    "\n",
    "### Output Format:\n",
    "For each test case, you need to firstly define the Test Types to indicate what type of test case you are going to create and then give the reasoning and explanation of the test case. After that, generate the test function based on the your reasoning.\n",
    "\n",
    "For each test function, return with following structure:\n",
    "\n",
    "<test type>\n",
    "Pick one of correctness|edge_case|runtime|component_check|error_handling\n",
    "</test type>\n",
    "\n",
    "<reasoning>\n",
    "Specify the purpose of the test function and the reasoning behind it. Explain step by step why your test case is correct and what is the expected output.\n",
    "</reasoning>\n",
    "\n",
    "<test function>\n",
    "def test_case(func):\n",
    "    # Your test function code here\n",
    "</test function>\n",
    "\n",
    "If you are going to create multiple test cases, please separate them with <separator> tag.\n",
    "\n",
    "{example_text}\n",
    "Generate test cases that rigorously validate the function's behavior, code structure, and performance.\n",
    "You MUST strictly follow the output format and structure. The generated test functions MUST be runnable function that use another python function as its parameter.\"\"\"\n",
    "            return prompt\n",
    "\n",
    "        if use_example:\n",
    "            examples_text = \"\"\n",
    "        else:\n",
    "            examples_text = \"\"\n",
    "\n",
    "        # Process input_format into a descriptive string\n",
    "        input_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            input_descr.append(f\"- Argument {idx}: {dtype} ({shape_info})\")\n",
    "        input_descr_str = \"\\n\".join(input_descr)\n",
    "\n",
    "        # Process output_format into a descriptive string\n",
    "        output_descr = []\n",
    "        for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):\n",
    "            shape_info = f\"shape {shape}\" if shape is not None else \"no fixed shape\"\n",
    "            output_descr.append(f\"- Output {idx}: {dtype} ({shape_info})\")\n",
    "        output_descr_str = \"\\n\".join(output_descr)\n",
    "\n",
    "        # Process components and plan\n",
    "        components_str = \", \".join(task_spec['components'])\n",
    "        plan_str = \"\\n\".join(task_spec['plan'])\n",
    "\n",
    "        if bulk:\n",
    "            # Generate a single prompt with all advisories\n",
    "            advisories = task_spec['test_case_generation_advise']\n",
    "            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)\n",
    "        else:\n",
    "            # Generate a list of prompts, each with a single advisory\n",
    "            prompts = []\n",
    "            for advise in task_spec['test_case_generation_advise']:\n",
    "                single_advisory = [advise]\n",
    "                prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)\n",
    "                prompts.append(prompt)\n",
    "            return prompts\n",
    "        \n",
    "def extract_test_cases(self, output_text):\n",
    "    \"\"\"\n",
    "    Extracts test cases from the generated output text into a structured dictionary.\n",
    "    \n",
    "    Parameters:\n",
    "    - output_text (str): The raw text output from the LLM containing test cases.\n",
    "    \n",
    "    Returns:\n",
    "    - dict: A dictionary where each key is 'test_case_N' and the value is another dictionary\n",
    "            with 'test_type', 'purpose', and 'test_function' keys.\n",
    "    \"\"\"\n",
    "    test_cases = {}\n",
    "    # Split the output into individual test case blocks using the separator\n",
    "    test_case_blocks = re.split(r'<separator>', output_text, flags=re.IGNORECASE)\n",
    "    \n",
    "    for idx, block in enumerate(test_case_blocks, 1):\n",
    "        # Extract test type using regex, allowing for any whitespace\n",
    "        test_type_match = re.search(r'<test[_\\s]*type>\\s*(.*?)\\s*</test[_\\s]*type>', block, re.DOTALL)\n",
    "        test_type = test_type_match.group(1).strip() if test_type_match else None\n",
    "        \n",
    "        # Extract reasoning (purpose) from the reasoning tags\n",
    "        reasoning_match = re.search(r'<reasoning>\\s*(.*?)\\s*</reasoning>', block, re.DOTALL)\n",
    "        reasoning = reasoning_match.group(1).strip() if reasoning_match else \"\"\n",
    "        \n",
    "        # Extract test function, checking both <test_function> tags and Python code blocks\n",
    "        test_func = None\n",
    "        # First check within <test_function> tags\n",
    "        test_func_match = re.search(r'<test[_\\s]*function>\\s*(.*?)\\s*</test[_\\s]*function>', block, re.DOTALL)\n",
    "        if test_func_match:\n",
    "            inner_content = test_func_match.group(1).strip()\n",
    "            # Check if the inner content is a Python code block\n",
    "            code_block_match = re.search(r'```python\\s*(.*?)\\s*```', inner_content, re.DOTALL)\n",
    "            if code_block_match:\n",
    "                test_func = code_block_match.group(1).strip()\n",
    "            else:\n",
    "                test_func = inner_content\n",
    "        else:\n",
    "            # If no tags, check for a standalone Python code block\n",
    "            code_block_match = re.search(r'```python\\s*(.*?)\\s*```', block, re.DOTALL)\n",
    "            if code_block_match:\n",
    "                test_func = code_block_match.group(1).strip()\n",
    "        \n",
    "        # Only add to test cases if both test_type and test_func are present\n",
    "        if test_type and test_func:\n",
    "            case_key = f'test_case_{idx}'\n",
    "            test_cases[case_key] = {\n",
    "                'test_type': test_type,\n",
    "                'purpose': reasoning,\n",
    "                'test_function': test_func\n",
    "            }\n",
    "    \n",
    "    return test_cases"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
