import re
import logging
import json
from tqdm.asyncio import tqdm_asyncio

class LLMTM():
    def __init__(self, task_description, LLM_model):
        """Store the default task description and the LLM wrapper.

        Args:
            task_description: Task text used by the other methods whenever
                they are called without an explicit ``task_description``.
            LLM_model: Object used for generation; the methods of this class
                call its ``LLM_response(prompt, gen_kwargs)`` and
                ``LLM_response_async(prompt, gen_kwargs)``.
        """
        self.task_description = task_description
        self.LLM_model = LLM_model

    def create_task_refinement_prompt(self, original_task_description, extra_information, feature_analysis, advice, use_llm=True):
        """Build the text used to obtain a refined task description.

        Args:
            original_task_description: The original task text.
            extra_information: "Extra information" section from the data
                analysis (``None`` when the section could not be extracted).
            feature_analysis: "Feature analysis" section (may be ``None``).
            advice: "Advice" section (may be ``None``).
            use_llm: When True, return a prompt asking an LLM to merge all
                pieces into one refined description. When False, return the
                pieces concatenated under ``=== ... ===`` headings.

        Returns:
            str: The LLM prompt (``use_llm=True``) or the concatenated text.
        """
        # extract_data_analysis_result() yields None for missing sections. Render
        # them as empty text: str.join would raise TypeError on None, and the
        # f-string below would otherwise insert the literal word "None".
        def _text(value):
            return "" if value is None else str(value)

        original_task_description = _text(original_task_description)
        extra_information = _text(extra_information)
        feature_analysis = _text(feature_analysis)
        advice = _text(advice)

        if not use_llm:
            parts = [
                original_task_description,
                "\n=== Extra information ===\n",
                extra_information,
                "\n=== Feature analysis ===\n",
                feature_analysis,
                "\n=== Extra Advice ===\n",
                advice,
            ]
            return "\n".join(parts)

        refinement_prompt = f"""
I have an initial task description and some analysis regarding the data and approach.
I need you to synthesize all this information into a single, clear, and more detailed refined task description.

The refined task description should incorporate insights from the feature analysis and the advice provided, making it more actionable and comprehensive for a code generation system. It should clearly state the goal, the expected inputs (data/features), and any key steps or considerations mentioned in the advice.

Here is the information:

<original_task_description_start>
{original_task_description}
<original_task_description_end>

<extra_information_start>
{extra_information}
<extra_information_end>

<feature_analysis_start>
{feature_analysis}
<feature_analysis_end>

<advice_start>
{advice}
<advice_end>

Now, please provide the refined task description. Output only the refined task description itself, without any extra conversational text or tags.
Refined Task Description:"""
        return refinement_prompt

    def create_data_prompt(self, task_description):
        """Build the prompt asking the LLM to analyse the data behind a task.

        The response is expected to contain ``<extra_information>``,
        ``<feature_analysis>`` and ``<advice>`` sections, which
        ``extract_data_analysis_result`` parses.

        Args:
            task_description: Task text embedded verbatim in the prompt.

        Returns:
            str: The prompt.
        """
        # Fixed typo in the prompt text: "task escription" -> "task description".
        data_prompt = f"""I am working on a code generation task:
<task_description_start>
{task_description}
<task_description_end>

I need you to analyze the data I will be working with and add extra information about the task description. Please structure your response as follows:

<extra_information>
Based on the task description, provide any additional information or context that might be relevant to the task. For example, adding mathematical formulas, domain-specific knowledge, or any other relevant information that can aid in the code generation process.
</extra_information>

<feature_analysis>
Provide a brief introduction to the features present in the data.
If specific feature information is not available, please describe the overall data structure you would expect or require for this task.
</feature_analysis>

<advice>
Based on the task description, extra knowledge and the data features (or expected data structure),
provide advice on how these features can be effectively used to accomplish the project, and how to incorporate the additional information into the project.
Suggest potential data transformations, feature engineering steps, or specific ways to leverage features and extra domain-specific knowledge in the code generation process.
</advice>
    """
        return data_prompt

    # def extract_data_analysis_result(self, llm_output):
    #     extra_information_match = re.search(r"<extra_information>(.*?)</extra_information>", llm_output, re.DOTALL)
    #     analysis_match = re.search(r"<feature_analysis>(.*?)</feature_analysis>", llm_output, re.DOTALL)
    #     advice_match = re.search(r"<advice>(.*?)</advice>", llm_output, re.DOTALL)

    #     extracted_data = {
    #         "extra_information": extra_information_match.group(1).strip() if extra_information_match else None,
    #         "feature_analysis": analysis_match.group(1).strip() if analysis_match else None,
    #         "advice": advice_match.group(1).strip() if advice_match else None,
    #     }
    #     return extracted_data
    
    def extract_data_analysis_result(self, llm_output):
        """Pull the three data-analysis sections out of an LLM response.

        Each section runs from its opening tag to whichever comes first: its
        own closing tag, the opening tag of another section (tolerating a
        missing closing tag), or the end of the text.

        Args:
            llm_output: Raw LLM response text.

        Returns:
            dict: Keys ``extra_information``, ``feature_analysis`` and
            ``advice``; each value is the stripped section text, or ``None``
            when the opening tag was not found.
        """
        section_patterns = {
            "extra_information": r"<extra_information>(.*?)(?:</extra_information>|(?=<feature_analysis>)|(?=<advice>)|$)",
            "feature_analysis": r"<feature_analysis>(.*?)(?:</feature_analysis>|(?=<extra_information>)|(?=<advice>)|$)",
            "advice": r"<advice>(.*?)(?:</advice>|(?=<extra_information>)|(?=<feature_analysis>)|$)",
        }
        sections = {}
        for section_name, section_pattern in section_patterns.items():
            found = re.search(section_pattern, llm_output, re.DOTALL)
            sections[section_name] = None if found is None else found.group(1).strip()
        return sections
    
    def create_plan_prompt(self, task_description=None):
        """Return the task-decomposition prompt for a task.

        Args:
            task_description: Task text inserted at the ``{{TASK_DESCRIPTION}}``
                placeholder; defaults to ``self.task_description``.

        Returns:
            str: Prompt asking for ``<components>`` and ``<overall_plan>`` JSON
            blocks (the format that ``extract_plan`` parses).
        """
        description = self.task_description if task_description is None else task_description

        template = """You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.

Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.

Your output must strictly follow the format below:

<components>
{
  "component_1": {
    "step_task_description": str,
    "input_format": [[type, shape or null]],
    "output_format": [[type, shape or null]],
    "work_flow": [str],
    "test_case_generation_advise": [str]
  },
  "component_2": {
    "step_task_description": str,
    "input_format": [["type", shape or null]],
    "output_format": [["type", shape or null]],
    "work_flow": [str],
    "test_case_generation_advise": [str]
  },
  ...
}
</components>

<overall_plan>
{
  "input_format": [["type", shape or null]],
  "output_format": [["type", shape or null]],
  "components": [str],
  "plan": [str],
  "test_case_generation_advise": [str]
}
</overall_plan>

Here are additional detailed explanations of each field:

For <components>:
- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., "merge_arrays", "calculate_median").
- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).
- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:
  - The first element indicates the data type (e.g., "list", "dict", NumPy array, torch.Tensor). DO make sure the data type is a string.
  - The second element indicates the fixed shape if applicable; otherwise, it is null.
- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`, note that it has to be a list of lists.
- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.
- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.

For <overall_plan>:
- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.
- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.
- **components**: List the components in the order.
- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.
- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.

Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:

"{{TASK_DESCRIPTION}}"

Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.

Please provide your structured decomposition according to the instructions above.
"""
        return template.replace("{{TASK_DESCRIPTION}}", description)
    
    def create_plan_refinement_prompt(self, user_feedback, previous_output, task_description=None):
        """Build the prompt asking the LLM to revise a plan based on user feedback.

        Args:
            user_feedback: Feedback text; it is inserted twice in the prompt.
            previous_output: The previous decomposition, either as raw LLM text
                or as a parsed object (e.g. the dict returned by
                ``extract_plan``), which is serialised to JSON.
            task_description: Task text; defaults to ``self.task_description``.

        Returns:
            str: The refinement prompt.

        Raises:
            TypeError: if ``task_description`` or ``user_feedback`` is not a str.
        """
        if task_description is None:
            task_description = self.task_description
        if not isinstance(previous_output, str):
            # Accept a parsed plan (dict/list) by rendering it back to JSON text.
            previous_output = json.dumps(previous_output, ensure_ascii=False, indent=2)

        plan_refinement_prompt = """You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.

Carefully review the previous components and overall plan, then:
1. Preserve correct/valid elements that don't conflict with the feedback
2. Make targeted modifications based on the user's specific advice
3. Ensure consistency between components and overall plan
4. Verify input/output formats and workflow logic
5. Check for any introduced errors during modification

The input consists of three elements:
- Original Task Description: "{{TASK_DESCRIPTION}}"
- Previous Decomposition Output: 
{{PREVIOUS_OUTPUT}}
- User Feedback: "{{USER_ADVICE}}"

Your output must STRICTLY follow the original format with these sections:
<components>...</components>
<overall_plan>...</overall_plan>

Follow these guidelines:
- Explicitly address all points in the user feedback
- Clearly document any changes made from previous version
- Preserve JSON structure and formatting requirements
- If feedback contradicts original requirements, prioritize feedback

Again, user feedback is: "{{USER_ADVICE}}"

Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback."""

        values = {
            "{{TASK_DESCRIPTION}}": task_description,
            "{{USER_ADVICE}}": user_feedback,
            "{{PREVIOUS_OUTPUT}}": previous_output,
        }
        for placeholder, value in values.items():
            if not isinstance(value, str):
                raise TypeError(f"{placeholder} must be filled with a str, got {type(value).__name__}")

        # Substitute every placeholder in ONE pass. Chained str.replace calls would
        # re-expand placeholder-like text that appears inside an inserted value
        # (e.g. feedback quoting "{{PREVIOUS_OUTPUT}}").
        placeholder_re = re.compile("|".join(re.escape(key) for key in values))
        return placeholder_re.sub(lambda match: values[match.group(0)], plan_refinement_prompt)

    def extract_plan(self, input_str):
        """Parse the ``<components>`` / ``<overall_plan>`` JSON blocks of an LLM reply.

        Literal ``\\n`` sequences in the input are first converted to real
        newlines (for double-escaped model output). Inside each block, an
        optional markdown code fence is removed, trailing commas before ``}``
        or ``]`` are dropped, and the remainder is parsed as JSON.

        Args:
            input_str: Raw LLM response text.

        Returns:
            dict | bool: Maps each block name that was found (``"components"``,
            ``"overall_plan"``) to its parsed JSON value. An empty dict means
            neither tag was present. ``False`` is returned as soon as any
            found block fails to parse.
        """
        input_str = input_str.replace("\\n", "\n")

        # <components>...</components> or <overall_plan>...</overall_plan>;
        # DOTALL lets the JSON span multiple lines.
        pattern = r'<(components|overall_plan)>(.*?)</\1>'
        matches = re.findall(pattern, input_str, re.DOTALL)

        result = {}
        for block_name, content in matches:
            cleaned_content = content.strip()
            # Models sometimes wrap the JSON in a ```json ... ``` fence inside the tag.
            cleaned_content = re.sub(r'^```[A-Za-z]*\s*|\s*```$', '', cleaned_content)
            # Drop trailing commas, which JSON does not allow.
            cleaned_content = re.sub(r',\s*}', '}', cleaned_content)
            cleaned_content = re.sub(r',\s*\]', ']', cleaned_content)
            try:
                # strict=False: the "\\n" -> newline replacement above turns JSON
                # escape sequences inside string values into raw newlines, which
                # the default (strict) parser rejects as control characters.
                result[block_name] = json.loads(cleaned_content, strict=False)
            except json.JSONDecodeError as e:
                logging.warning(f"JSON解析错误: {block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}")
                return False
        return result

    def refine_plan_by_data_analysis(self, task_description=None, gen_kwargs=None, use_llm=True):
        """Enrich a task description using an LLM data analysis.

        Asks the LLM for extra information, feature analysis and advice about
        the task. Those sections are then either merged by a second LLM call
        (``use_llm=True``) or concatenated directly (``use_llm=False``).

        Args:
            task_description: Task text; defaults to ``self.task_description``.
            gen_kwargs: Generation kwargs for the LLM calls; a fresh ``{}``
                when ``None``.
            use_llm: Whether to use a second LLM call to merge the sections.

        Returns:
            str: The refined task description.
        """
        if task_description is None:
            task_description = self.task_description
        # Fresh dict per call: a shared ``{}`` default would carry any mutation
        # made by the model wrapper into later calls.
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs

        analysis_output = self.LLM_model.LLM_response(self.create_data_prompt(task_description), gen_kwargs)
        sections = self.extract_data_analysis_result(analysis_output)
        refine_prompt = self.create_task_refinement_prompt(
            task_description,
            sections["extra_information"],
            sections["feature_analysis"],
            sections["advice"],
            use_llm=use_llm,
        )
        if not use_llm:
            # Without an LLM, the concatenated sections are the refined description.
            return refine_prompt
        return self.LLM_model.LLM_response(refine_prompt, gen_kwargs)

    def get_plan(self, task_description=None, gen_kwargs=None, max_retry=3):
        """Request a task decomposition from the LLM, retrying on parse failure.

        Args:
            task_description: Task text; defaults to ``self.task_description``.
            gen_kwargs: Generation kwargs for the LLM call; a fresh ``{}`` when
                ``None``.
            max_retry: Retries after the first attempt (at most
                ``max_retry + 1`` LLM calls).

        Returns:
            tuple: ``(plan, llm_output)``. ``plan`` is the non-empty dict from
            ``extract_plan``; ``llm_output`` is the raw response it came from.

        Raises:
            ValueError: if no attempt yields a non-empty parsed plan.
        """
        if task_description is None:
            task_description = self.task_description
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        prompt = self.create_plan_prompt(task_description)

        plan, llm_output = False, None
        for attempt in range(1, max_retry + 2):
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            plan = self.extract_plan(llm_output)
            if plan:
                break
            logging.warning(f"Failed to extract plan, retrying ({attempt})...")
        # extract_plan returns False on a JSON error but {} when no tags were
        # found, so test truthiness rather than ``is False``.
        if not plan:
            logging.error(f"Failed to extract plan, current llm_output:\n{llm_output}")
            raise ValueError(f"Failed to extract plan, current llm_output:\n{llm_output}")
        return plan, llm_output
    
    async def get_plan_async(self, num_plan, task_description=None, gen_kwargs=None):
        """Request ``num_plan`` independent decompositions concurrently (no retries).

        Args:
            num_plan: Number of LLM requests to issue.
            task_description: Task text; defaults to ``self.task_description``.
            gen_kwargs: Generation kwargs shared by all requests; a fresh ``{}``
                when ``None``.

        Returns:
            list: One ``extract_plan`` result per response: a dict of parsed
            blocks (possibly empty) or ``False`` on a JSON error.
        """
        if task_description is None:
            task_description = self.task_description
        # Fresh dict per call instead of a shared mutable default.
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        prompt = self.create_plan_prompt(task_description)

        task_list = [self.LLM_model.LLM_response_async(prompt, gen_kwargs) for _ in range(num_plan)]
        # tqdm_asyncio.gather awaits all requests while showing a progress bar.
        llm_outputs = await tqdm_asyncio.gather(*task_list)
        return [self.extract_plan(output) for output in llm_outputs]
    
    def refine_plan(self, user_feedback, previous_output, task_description=None, gen_kwargs=None, max_retry=3):
        """Ask the LLM to revise a previous decomposition, retrying on parse failure.

        Args:
            user_feedback: Feedback describing the desired changes.
            previous_output: The previous decomposition passed to
                ``create_plan_refinement_prompt``.
            task_description: Task text; defaults to ``self.task_description``.
            gen_kwargs: Generation kwargs for the LLM call; a fresh ``{}`` when
                ``None``.
            max_retry: Retries after the first attempt (at most
                ``max_retry + 1`` LLM calls).

        Returns:
            tuple: ``(plan, llm_output)`` with the non-empty parsed plan and the
            raw response it came from.

        Raises:
            ValueError: if no attempt yields a non-empty parsed plan.
        """
        if task_description is None:
            task_description = self.task_description
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)

        plan, llm_output = False, None
        for attempt in range(1, max_retry + 2):
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            plan = self.extract_plan(llm_output)
            if plan:
                break
            logging.warning(f"Failed to extract refined plan, retrying ({attempt})...")
        # extract_plan returns False on a JSON error but {} when no tags were
        # found, so test truthiness rather than ``is False``.
        if not plan:
            logging.error(f"Failed to extract refined plan, current llm_output:\n{llm_output}")
            raise ValueError(f"Failed to extract plan, current llm_output:\n{llm_output}")
        return plan, llm_output
    
    async def refine_multi_plan(self, num_plan, user_feedback, previous_output, task_description=None, gen_kwargs=None):
        """Request ``num_plan`` independent plan refinements concurrently (no retries).

        Args:
            num_plan: Number of LLM requests to issue.
            user_feedback: Feedback describing the desired changes.
            previous_output: The previous decomposition passed to
                ``create_plan_refinement_prompt``.
            task_description: Task text; defaults to ``self.task_description``.
            gen_kwargs: Generation kwargs shared by all requests; a fresh ``{}``
                when ``None``.

        Returns:
            list: One ``extract_plan`` result per response: a dict of parsed
            blocks (possibly empty) or ``False`` on a JSON error.
        """
        if task_description is None:
            task_description = self.task_description
        # Fresh dict per call instead of a shared mutable default.
        gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)

        task_list = [self.LLM_model.LLM_response_async(prompt, gen_kwargs) for _ in range(num_plan)]
        llm_outputs = await tqdm_asyncio.gather(*task_list)
        return [self.extract_plan(output) for output in llm_outputs]

    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):
        """
        Generates a prompt (or list of prompts) for test case generation based on task specifications.

        Parameters:
        - task_descr_str (str): Task description inserted verbatim into the prompt.
        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.
          input_format/output_format entries are expected as [type, shape_or_null] pairs; a bare type or a
          one-element list is also accepted and rendered as "no fixed shape".
        - use_example (bool): Placeholder for few-shot examples; no example text is currently added either way.
        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.

        Returns:
        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).
        """

        # Helper function to generate the prompt text from a modified task specification
        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = ""):
            advisory_list = "\n".join([f"- {advise}" for advise in advisories])
            prompt = f"""You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:

### Input Specifications:
- **Task Description**:
{task_descr_str}
- **Input Format**: 
{input_descr_str}
- **Output Format**: 
{output_descr_str}
- **Components Used**: {components_str}
- **Plan**: 
{plan_str}
- **Test Case Advise**: 
{advisory_list}

### Requirements:
1. **Test Function Structure**:
   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):...`).
   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.
   - Include input generation, runtime checks, code inspection, or result validation within the function.

2. **Test Types** (use one of these for indicating the test_type):
   - `correctness`: Validate output against expected results for specific inputs.
   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.
   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).
   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).
   - `error_handling`: Check if errors are raised for invalid inputs.

3. **Test Case Diversity**:
   - Cover all provided advisories.
   - Include at least one test per advisory and one for each test type where applicable.

### Output Format:
For each test case, you need to firstly define the Test Types to indicate what type of test case you are going to create and then give the reasoning and explanation of the test case. After that, generate the test function based on the your reasoning.

For each test function, return with following structure:

<Type>
Pick one of correctness|edge_case|runtime|component_check|error_handling
</Type>
<Planning>
Introduce how would you design the test function. Specify the purpose of the test function and the reasoning behind it. Explain step by step why your test case is correct and what is the expected output.
</Planning>
<Code>
def test_case(func):
    # Your test function code here
    return True or False as test result, and a message
</Code>

If you are going to create multiple test cases, please separate them with <separator> tag.

{example_text}
Generate test cases that rigorously validate the function's behavior, code structure, and performance.
You MUST strictly follow the output format and structure. The generated test functions MUST be runnable function that use another python function as its parameter and it should output both the Test result(True or False) and a message to give extra information about the test result.(For example, f"Test failed: expected X but got Y" or "Test failed: output with shape [x1, y1] but got [x2, y2]", where the X, Y and shapes need to be replaced by the actual output and expected output in test function)."""
            return prompt

        # Render format entries as "- <label> <i>: <type> (<shape info>)" lines.
        def describe_formats(entries, label):
            lines = []
            for idx, entry in enumerate(entries, 1):
                if isinstance(entry, (list, tuple)):
                    # LLM-produced specs are not always clean pairs; a missing
                    # shape element means "no fixed shape".
                    dtype = entry[0] if entry else None
                    shape = entry[1] if len(entry) > 1 else None
                else:
                    # A bare type (e.g. "int") instead of [type, shape].
                    dtype, shape = entry, None
                shape_info = f"shape {shape}" if shape is not None else "no fixed shape"
                lines.append(f"- {label} {idx}: {dtype} ({shape_info})")
            return "\n".join(lines)

        # No few-shot examples are defined yet, so use_example currently has no effect.
        examples_text = ""

        input_descr_str = describe_formats(task_spec['input_format'], "Argument")
        output_descr_str = describe_formats(task_spec['output_format'], "Output")

        # A single string must not be joined character by character.
        components = task_spec['components']
        if isinstance(components, str):
            components_str = components
        else:
            components_str = ", ".join(str(component) for component in components)

        plan = task_spec['plan']
        if isinstance(plan, str):
            plan_str = plan
        elif isinstance(plan, list):
            plan_str = "\n".join(str(step) for step in plan)
        else:
            plan_str = str(plan)

        # A single advice string must become one bullet, not one bullet per character.
        advisories = task_spec['test_case_generation_advise']
        if isinstance(advisories, str):
            advisories = [advisories]

        if bulk:
            # Generate a single prompt with all advisories
            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)
        # Generate a list of prompts, each with a single advisory
        return [
            generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, [advise], examples_text)
            for advise in advisories
        ]

    def extract_test_cases(self, output_text):
        """
        Extracts test cases from LLM output text with flexible tag handling.
        Supports case-insensitive tags, missing <Type> tags, and multi-separators.

        Literal "\\n" sequences in the input are first converted to real
        newlines. The text is then split into blocks on <separator> tags
        (opening, closing or self-closing) and on blank lines. Newlines inside
        <type>/<planning>/<reasoning>/<code>/<test_function> tags and inside
        ```python fences are protected from that split.

        Args:
            output_text: Raw LLM response text.

        Returns:
            dict | bool: ``{'test_case_<k>': {'test_type', 'purpose',
            'test_function', 'all_pass_times', 'all_fail_times'}}``, where <k>
            is the 1-based position of the block among all non-empty split
            blocks. Keys can therefore have gaps when blocks are skipped.
            Returns False when no test case could be extracted.

        NOTE(review): a blank line OUTSIDE the protected tags (e.g. between
        </Type> and <Planning>) splits one test case into separate blocks. The
        piece holding the type has no code, and the piece holding the code has
        no type, so that test case is dropped.
        """
        import re
        test_cases = {}

        # Undo double-escaped newlines in the model output.
        output_text = output_text.replace("\\n", "\n")

        tags_to_protect_newlines_in = ['type', 'planning', 'reasoning', 'code', 'test_function']
        def preprocess_text(text):
            # Returns (text with protected newlines replaced by the placeholder, placeholder).
            placeholder = "###NL###"
            
            # This inner function will be used by re.sub to process the content of matched tags
            def replace_newlines_in_content(match_obj):
                opening_tag = match_obj.group(1)
                content = match_obj.group(2)
                closing_tag = match_obj.group(3)
                
                processed_content = content.replace("\n", placeholder)
                return f"{opening_tag}{processed_content}{closing_tag}"

            # Iteratively protect newlines within the specified XML/HTML-like tags
            for tag_name_to_protect in tags_to_protect_newlines_in:
                tag_pattern = re.compile(
                    r'(<\s*' + re.escape(tag_name_to_protect) + r'\b[^>]*>)(.*?)(<\s*/\s*' + re.escape(tag_name_to_protect) + r'\s*>)',
                    flags=re.IGNORECASE | re.DOTALL
                )
                text = tag_pattern.sub(replace_newlines_in_content, text)
            
            # Also protect newlines within ```python ... ``` markdown code blocks (as before)
            def replace_newlines_in_markdown_code(match_obj):
                block = match_obj.group(0) # The whole matched block ```python ... ```
                return block.replace("\n", placeholder)
            
            markdown_code_pattern = re.compile(r'(```python.*?```)', flags=re.IGNORECASE | re.DOTALL)
            text = markdown_code_pattern.sub(replace_newlines_in_markdown_code, text)
            
            return text, placeholder

        # Preprocess: hide the newlines inside tagged sections and code fences
        modified_text, placeholder = preprocess_text(output_text)
        
        # Split into blocks on <separator> tags or on blank lines
        split_pattern = r'(?:<\s*/\s*separator\s*>|<\s*separator\s*>|<\s*separator\s*/>|\n\s*\n\s*)'
        test_case_blocks = re.split(split_pattern, modified_text, flags=re.IGNORECASE)
        test_case_blocks = [b.strip() for b in test_case_blocks if b.strip()]
        
        # Restore the hidden newlines inside each block
        test_case_blocks = [b.replace(placeholder, "\n") for b in test_case_blocks]

        # print(f"Split into {len(test_case_blocks)} blocks")
        for idx, block in enumerate(test_case_blocks, 1):
            # 1. Extract test_type
            test_type = None
            
            # Case 1: an explicit <type>value</type> tag
            type_match = re.search(
                r'<\s*type\s*>(.*?)<\s*/\s*type\s*>', 
                block, 
                re.IGNORECASE | re.DOTALL
            )
            if type_match:
                test_type = type_match.group(1).strip()
            else:
                # Case 2: use the name of the first opening tag that is not a known
                # structural tag (e.g. <edge_case>) as the type. Closing tags never
                # match this pattern; note it can also pick up '<' comparisons in code.
                known_tags = {'type', 'planning', 'code', 'reasoning', 'test_function', 'separator'}
                for tag_match in re.finditer(r'<\s*([^\s>/]+)\s*.*?>', block, re.IGNORECASE):
                    tag_name = tag_match.group(1).lower()
                    if tag_name not in known_tags:
                        test_type = tag_name
                        break  # take the first tag that is not in known_tags
                
            if not test_type:  # skip blocks without a test_type
                continue
            
            # 2. Extract reasoning (accepts <planning> or <reasoning>)
            reasoning_match = re.search(
                r'<\s*(?:reasoning|planning)\s*>(.*?)<\s*/\s*(?:reasoning|planning)\s*>',
                block, 
                re.IGNORECASE | re.DOTALL
            )
            reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
            
            # 3. Extract test_function (priority: <test_function> tag > <code> tag > standalone ```python fence)
            test_func = None
            
            # Check for a <test_function> tag (an inner ```python fence is unwrapped if present)
            test_func_match = re.search(
                r'<\s*test_function\s*>(.*?)<\s*/\s*test_function\s*>',
                block, 
                re.IGNORECASE | re.DOTALL
            )
            if test_func_match:
                content = test_func_match.group(1).strip()
                code_block = re.search(r'```python\s*(.*?)\s*```', content, re.DOTALL)
                test_func = code_block.group(1).strip() if code_block else content
            else:
                # Check for a <code> tag (an inner ```python fence is unwrapped if present)
                code_match = re.search(
                    r'<\s*code\s*>(.*?)<\s*/\s*code\s*>',
                    block,
                    re.IGNORECASE | re.DOTALL
                )
                if code_match:
                    content = code_match.group(1).strip()
                    code_block = re.search(r'```python\s*(.*?)\s*```', content, re.DOTALL)
                    test_func = code_block.group(1).strip() if code_block else content
                else:
                    # Check for a standalone code block (```python ... ```)
                    code_block = re.search(r'```python\s*(.*?)\s*```', block, re.DOTALL)
                    if code_block:
                        test_func = code_block.group(1).strip()
            
            if test_type and test_func:
                # Pass/fail counters start at zero; they are presumably updated by
                # whoever runs the tests (not visible here).
                test_cases[f'test_case_{idx}'] = {
                    'test_type': test_type,
                    'purpose': reasoning,
                    'test_function': test_func,
                    'all_pass_times': 0,
                    'all_fail_times': 0,
                }
        
        if not test_cases:
            # No test case was extracted: return False
            return False

        return test_cases
    
    # def get_test_cases(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs={}, max_retry=3):
    #     retry_num=0
    #     if task_description is None:
    #         task_description = self.task_description
    #     prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)
    #     while retry_num <= max_retry:
    #         llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
    #         test_cases = self.extract_test_cases(llm_output)
    #         if test_cases:
    #             break
    #         else:
    #             retry_num += 1
    #             # print(f"Failed to extract test cases, retrying ({retry_num})...")
    #             logging.warning(f"Failed to extract test cases, retrying ({retry_num})...")
    #     if test_cases is False:
    #         logging.error(f"Failed to extract test cases, current llm_output:\n{llm_output}")
    #         raise ValueError("Failed to extract test cases, current llm_output:\n", llm_output)
    #     return test_cases

    def _filter_test_cases(self, dataset):
        """Keep only the test cases whose ``test_function`` source compiles.

        Args:
            dataset: Mapping of test-case id to a dict that may hold the
                test's source under ``"test_function"`` (a missing key is
                treated as an empty source, which compiles).

        Returns:
            dict: The subset of ``dataset`` (same order, same value objects)
            whose source passes ``compile(..., "exec")``. Nothing is executed.
        """
        def _compiles(source):
            # Any failure (SyntaxError, non-str source, ...) disqualifies the entry.
            try:
                compile(source, "<string>", "exec")
            except Exception:
                return False
            return True

        return {
            case_id: case
            for case_id, case in dataset.items()
            if _compiles(case.get("test_function", ""))
        }

    async def get_test_cases_async(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)

        while retry_num <= max_retry:
            # debug
            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)
            test_cases = self.extract_test_cases(llm_output)
            if test_cases:
                # DEBUG
                test_cases = self._filter_test_cases(test_cases)
                if test_cases == {}:
                    logging.warning("No runnable test cases found, retrying...")
                    test_cases = False
                    retry_num += 1
                else:
                    break
            else:
                retry_num += 1
                # print(f"Failed to extract test cases, retrying ({retry_num})...")
                prompt_text = f"""###################################################
{prompt}
###################################################"""
                logging.warning(f"Failed to extract test cases, retrying ({retry_num})..., current prompt:\n{prompt_text},\ncurrent llm_output:\n{llm_output}")
        if test_cases is False:
            logging.error(f"Failed to extract test cases, current llm_output:\n{llm_output}")
            raise ValueError("Failed to extract test cases, current llm_output:\n", llm_output)
        return test_cases
    
    def get_test_cases(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)
        # debug
        # print("###############################################################")
        # print(prompt)
        while retry_num <= max_retry:
            # debug
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            test_cases = self.extract_test_cases(llm_output)
            if test_cases:
                # DEBUG
                test_cases = self._filter_test_cases(test_cases)
                if test_cases == {}:
                    logging.warning("No runnable test cases found, retrying...")
                    test_cases = False
                    retry_num += 1
                else:
                    break
            else:
                retry_num += 1
                # print(f"Failed to extract test cases, retrying ({retry_num})...")
                prompt_text = f"""###################################################
{prompt}
###################################################"""
                logging.warning(f"Failed to extract test cases, retrying ({retry_num})..., current prompt:\n{prompt_text},\ncurrent llm_output:\n{llm_output}")
        if test_cases is False:
            logging.error(f"Failed to extract test cases, current llm_output:\n{llm_output}")
            raise ValueError("Failed to extract test cases, current llm_output:\n", llm_output)
        
        # print("###############################################################")
        # print("llm output:")
        # print(llm_output)
        return test_cases