import re
import json
import asyncio
import random
import os
import logging
import concurrent.futures
from concurrent.futures import TimeoutError
from datetime import datetime
from pathlib import Path
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from openai import OpenAI
from openai import AsyncOpenAI

# Module-level logger. It has no handlers of its own: records propagate to the
# root logger, which basicConfig() below equips with the file and console handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Shared line format for both handlers.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# File handler: one timestamped file per process start under ./log, DEBUG and above.
log_dir = Path("log")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"{timestamp}.log"
log_path = log_dir / log_filename
# Explicit UTF-8: this module logs non-ASCII (Chinese) messages, which the platform
# default encoding (e.g. cp1252 on Windows) cannot represent.
file_handler = logging.FileHandler(log_path, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

# Console handler: INFO and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Attach both handlers to the root logger.
# NOTE: basicConfig() does nothing if the root logger already has handlers, and
# `datefmt` only affects a formatter basicConfig would create itself; both handlers
# above already use `formatter`, so timestamps keep the logging default format.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        file_handler,    # file output
        console_handler, # console output
    ]
)

class LLMModel():
    """Thin wrapper around the OpenAI chat-completions API (sync and async).

    Args:
        api_key: OpenAI API key. When ``None``, the ``OPENAI_API_KEY``
            environment variable is used instead.
        model: Default model name, used when a call does not pass ``model``.
    """

    def __init__(self, api_key, model="gpt-3.5-turbo"):
        # Never ship credentials in source code: fall back to the standard env var.
        # (The key that used to be hard-coded here must be treated as leaked and revoked.)
        self.api_key = api_key if api_key is not None else os.environ.get("OPENAI_API_KEY")
        self.model = model
        self.client = OpenAI(api_key=self.api_key)
        self.client_async = AsyncOpenAI(api_key=self.api_key)

    @staticmethod
    def _build_messages(prompt):
        """Normalize *prompt* into a chat ``messages`` list.

        A string becomes a single user message; a list is assumed to already be a
        list of message dicts and is passed through unchanged.

        Raises:
            ValueError: if *prompt* is neither a string nor a list.
        """
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        if isinstance(prompt, list):
            return prompt
        logging.error("prompt must be a string or a list of messages, current type: %s", type(prompt))
        raise ValueError("prompt must be a string or a list of messages")

    def LLM_response(self, prompt, gen_kwargs=None, model=None):
        """Send *prompt* synchronously and return the first choice's message content.

        Args:
            prompt: A user-message string or a list of chat message dicts.
            gen_kwargs: Extra keyword arguments forwarded to
                ``chat.completions.create`` (e.g. ``temperature``).
            model: Model name overriding ``self.model`` for this call.
        """
        completion = self.client.chat.completions.create(
            model=self.model if model is None else model,
            messages=self._build_messages(prompt),
            **(gen_kwargs or {}),
        )
        return completion.choices[0].message.content

    async def LLM_response_async(self, prompt, gen_kwargs=None, model=None):
        """Async counterpart of ``LLM_response`` using the async client."""
        # The string prompt is sent as-is (it was previously character-joined with spaces).
        completion = await self.client_async.chat.completions.create(
            model=self.model if model is None else model,
            messages=self._build_messages(prompt),
            **(gen_kwargs or {}),
        )
        return completion.choices[0].message.content

class LLMTM():
    def __init__(self, task_description, LLM_model):
        self.LLM_model = LLM_model
        self.task_description = task_description

    def create_plan_prompt(self, task_description=None):

        if task_description is None:
            task_description = self.task_description

        task_decompose_prompt = """You are an expert agent specialized in decomposing code generation tasks into structured, detailed, and clear subtasks and then give a detailed overall plan based on your defined subtasks. Given a simple high-level task description, your job is to break it down into logical subtasks that clearly illustrate the workflow and ensure easy understanding and execution.

Each decomposed subtask should aim to create a function or class as a reusable component contributing to the overall task. If the provided task is too simple or atomic to require multiple components, your decomposition should only contain a single component.

Your output must strictly follow the format below:

<components>
{
  "component_1": {
    "step_task_description": str,
    "input_format": [[type, shape or null]],
    "output_format": [[type, shape or null]],
    "work_flow": [str],
    "test_case_generation_advise": [str]
  },
  "component_2": {
    "step_task_description": str,
    "input_format": [["type", shape or null]],
    "output_format": [["type", shape or null]],
    "work_flow": [str],
    "test_case_generation_advise": [str]
  },
  ...
}
</components>

<overall_plan>
{
  "input_format": [["type", shape or null]],
  "output_format": [["type", shape or null]],
  "components": [str],
  "plan": [str],
  "test_case_generation_advise": [str]
}
</overall_plan>

Here are additional detailed explanations of each field:

For <components>:
- **component_X**: The key represents the subtask name, it should be replaced by the actual class/function name of the component (e.g., "merge_arrays", "calculate_median").
- **step_task_description**: Provide a clear and concise description of exactly what this subtask aims to achieve, specifically mentioning the intended functionality or role of the created component (function/class).
- **input_format**: Describe the format of each input argument required for this subtask. It is a list of lists, where each inner list has two elements:
  - The first element indicates the data type (e.g., "list", "dict", NumPy array, torch.Tensor). DO make sure the data type is a string.
  - The second element indicates the fixed shape if applicable; otherwise, it is null.
- **output_format**: Describe the format of each output argument generated by this subtask. It follows the same list structure as `input_format`, note that it has to be a list of lists.
- **work_flow**: Provide a detailed step-by-step plan that outlines the workflow of how the component functions to achieve the subtask.
- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases, explicitly mentioning potential edge cases and critical scenarios that need coverage.

For <overall_plan>:
- **input_format**: Describe the format of the input arguments required for the overall task. It follows the same structure as `input_format` in the component section.
- **output_format**: Describe the format of the output arguments generated by the overall task. It follows the same structure as `output_format` in the component section.
- **components**: List the components in the order.
- **plan**: Provide a detailed step-by-step plan that outlines the workflow of how the components interact with each other to achieve the overall task. This should be a high-level description of the process.
- **test_case_generation_advise**: Provide a list of detailed guidelines or suggestions aimed at generating diverse and comprehensive test cases for the overall task, explicitly mentioning potential edge cases and critical scenarios that need coverage.

Your decomposition should strive for clarity, correctness, modularity, and ensure each step can be tested independently. Now, given the following simple task description:

"{{TASK_DESCRIPTION}}"

Use <> to indicate both start and end of the component part and the overall plan. Ensure that the components and the overall plan are clearly separated.

Please provide your structured decomposition according to the instructions above.
"""
        task_decompose_prompt = task_decompose_prompt.replace("{{TASK_DESCRIPTION}}", task_description)
        return task_decompose_prompt
    
    def create_plan_refinement_prompt(self, user_feedback, previous_output, task_description=None):

        if task_description is None:
            task_description = self.task_description

        plan_refinement_prompt = """You are an expert agent specialized in refining and improving code generation plans through iterative feedback. Given a task description, previous decomposition output, and user feedback, your job is to critically analyze the existing plan and modify it accordingly while maintaining the required output format.

Carefully review the previous components and overall plan, then:
1. Preserve correct/valid elements that don't conflict with the feedback
2. Make targeted modifications based on the user's specific advice
3. Ensure consistency between components and overall plan
4. Verify input/output formats and workflow logic
5. Check for any introduced errors during modification

The input consists of three elements:
- Original Task Description: "{{TASK_DESCRIPTION}}"
- Previous Decomposition Output: 
{{PREVIOUS_OUTPUT}}
- User Feedback: "{{USER_ADVICE}}"

Your output must STRICTLY follow the original format with these sections:
<components>...</components>
<overall_plan>...</overall_plan>

Follow these guidelines:
- Explicitly address all points in the user feedback
- Clearly document any changes made from previous version
- Preserve JSON structure and formatting requirements
- If feedback contradicts original requirements, prioritize feedback

Again, user feedback is: "{{USER_ADVICE}}"

Provide your refined decomposition with clear explanations of changes in the component descriptions and overall plan. Ensure modularity, testability, and coverage of edge cases mentioned in feedback."""

        plan_refinement_prompt = plan_refinement_prompt.replace("{{TASK_DESCRIPTION}}", task_description)
        plan_refinement_prompt = plan_refinement_prompt.replace("{{USER_ADVICE}}", user_feedback)
        plan_refinement_prompt = plan_refinement_prompt.replace("{{PREVIOUS_OUTPUT}}", previous_output)
        return plan_refinement_prompt

    def extract_plan(self, input_str):
        # Updated regex pattern to match <tag>...</tag> format
        pattern = r'<(components|overall_plan)>(.*?)</\1>'
        
        # Find all matches, allowing multiline content
        matches = re.findall(pattern, input_str, re.DOTALL)
        
        result = {}
        for block_name, content in matches:
            try:
                # Strip whitespace
                cleaned_content = content.strip()
                
                # Fix trailing commas
                cleaned_content = re.sub(r',\s*}', '}', cleaned_content)
                cleaned_content = re.sub(r',\s*\]', ']', cleaned_content)
                
                # Parse JSON
                parsed_data = json.loads(cleaned_content)
                result[block_name] = parsed_data
            except json.JSONDecodeError as e:
                logging.warning(f"JSON解析错误: {block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}")
                # print(f"解析错误：{block_name}块 | 错误位置：第{e.lineno}行第{e.colno}列 | 错误原因：{e.msg}")
                result[block_name] = None
                return False
        return result

    def get_plan(self, task_description=None, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_plan_prompt(task_description)
        while retry_num <= max_retry:
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            extract_plan = self.extract_plan(llm_output)
            if extract_plan:
                break
            else:
                retry_num += 1
                # print(f"Failed to extract plan, retrying ({retry_num})...")
                logging.warning(f"Failed to extract plan, retrying ({retry_num})...")
        if extract_plan is False:
            # print("Failed to extract plan, current llm_output:\n", llm_output)
            logging.error(f"Failed to extract plan, current llm_output:\n{llm_output}")
            raise ValueError("Failed to extract plan, current llm_output:\n", llm_output)
        return extract_plan, llm_output
    
    async def get_plan_async(self, num_plan, task_description=None, gen_kwargs={}):
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_plan_prompt(task_description)
        # get multiple plans
        task_list = [self.LLM_model.LLM_response_async(prompt, gen_kwargs) for _ in range(num_plan)]
        
        llm_output = await tqdm_asyncio.gather(*task_list)
        return [self.extract_plan(output) for output in llm_output]
    
    def refine_plan(self, user_feedback, previous_output, task_description=None, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)
        while retry_num <= max_retry:
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            extract_plan = self.extract_plan(llm_output)
            if extract_plan:
                break
            retry_num += 1
        if extract_plan is False:
            raise ValueError("Failed to extract plan, current llm_output:\n", llm_output)
        return extract_plan, llm_output
    
    async def refine_multi_plan(self, num_plan, user_feedback, previous_output, task_description=None, gen_kwargs={}):
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_plan_refinement_prompt(user_feedback, previous_output, task_description)
        task_list = [self.LLM_model.LLM_response_async(prompt, gen_kwargs) for _ in range(num_plan)]
        llm_output = await tqdm_asyncio.gather(*task_list)
        return [self.extract_plan(output) for output in llm_output]
    
    def create_test_prompt(self, task_descr_str, task_spec, use_example=True, bulk=True):
        """
        Generates a prompt (or list of prompts) for test case generation based on task specifications.
        
        Parameters:
        - task_spec (dict): Dictionary containing input_format, output_format, components, plan, and test_case_generation_advise.
        - bulk (bool): If True, generate a single prompt with all advisories. If False, generate a list of prompts, each with a single advisory.
        
        Returns:
        - str or list: A single prompt string (if bulk=True) or a list of prompt strings (if bulk=False).
        """
        
        # Helper function to generate the prompt text from a modified task specification
        def generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, example_text = ""):
            advisory_list = "\n".join([f"- {advise}" for advise in advisories])
            prompt = f"""You are a test case generation agent. Your task is to create Python test functions to validate a code generation task based on the provided specifications. Follow these instructions carefully:

### Input Specifications:
- **Task Description**:
{task_descr_str}
- **Input Format**: 
{input_descr_str}
- **Output Format**: 
{output_descr_str}
- **Components Used**: {components_str}
- **Plan**: 
{plan_str}
- **Test Case Advise**: 
{advisory_list}

### Requirements:
1. **Test Function Structure**:
   - Each test function must accept **only the function under test** as its parameter (e.g., `def test_case(func):...`).
   - Return `True` if the test passes, `False` otherwise. Do not use assertions, please return a boolean value.
   - Include input generation, runtime checks, code inspection, or result validation within the function.

2. **Test Types** (use one of these for indicating the test_type):
   - `correctness`: Validate output against expected results for specific inputs.
   - `edge_case`: Test inputs like empty lists, extreme values, or invalid data.
   - `runtime`: Measure execution time (e.g., ensure it's below a threshold).
   - `component_check`: Verify the function's code uses specified components (e.g., via string inspection).
   - `error_handling`: Check if errors are raised for invalid inputs.

3. **Test Case Diversity**:
   - Cover all provided advisories.
   - Include at least one test per advisory and one for each test type where applicable.

### Output Format:
For each test case, you need to firstly define the Test Types to indicate what type of test case you are going to create and then give the reasoning and explanation of the test case. After that, generate the test function based on the your reasoning.

For each test function, return with following structure:

<Type>
Pick one of correctness|edge_case|runtime|component_check|error_handling
</Type>
<Planning>
Introduce how would you design the test function. Specify the purpose of the test function and the reasoning behind it. Explain step by step why your test case is correct and what is the expected output.
</Planning>
<Code>
def test_case(func):
    # Your test function code here
</Code>

If you are going to create multiple test cases, please separate them with <separator> tag.

{example_text}
Generate test cases that rigorously validate the function's behavior, code structure, and performance.
You MUST strictly follow the output format and structure. The generated test functions MUST be runnable function that use another python function as its parameter."""
            return prompt

        if use_example:
            examples_text = ""
        else:
            examples_text = ""

        # Process input_format into a descriptive string
        input_descr = []
        for idx, (dtype, shape) in enumerate(task_spec['input_format'], 1):
            shape_info = f"shape {shape}" if shape is not None else "no fixed shape"
            input_descr.append(f"- Argument {idx}: {dtype} ({shape_info})")
        input_descr_str = "\n".join(input_descr)

        # Process output_format into a descriptive string
        output_descr = []
        for idx, (dtype, shape) in enumerate(task_spec['output_format'], 1):
            shape_info = f"shape {shape}" if shape is not None else "no fixed shape"
            output_descr.append(f"- Output {idx}: {dtype} ({shape_info})")
        output_descr_str = "\n".join(output_descr)

        # Process components and plan
        components_str = ", ".join(task_spec['components'])
        plan_str = "\n".join(task_spec['plan'])

        if bulk:
            # Generate a single prompt with all advisories
            advisories = task_spec['test_case_generation_advise']
            return generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, advisories, examples_text)
        else:
            # Generate a list of prompts, each with a single advisory
            prompts = []
            for advise in task_spec['test_case_generation_advise']:
                single_advisory = [advise]
                prompt = generate_prompt(task_descr_str, input_descr_str, output_descr_str, components_str, plan_str, single_advisory, examples_text)
                prompts.append(prompt)
            return prompts

    def extract_test_cases(self, output_text):
        """
        Extracts test cases from LLM output text with flexible tag handling.
        Supports case-insensitive tags, missing <Type> tags, and multi-separators.
        """
        import re
        test_cases = {}

        def preprocess_text(text):
            # 定义一个占位符，避免选中正常文本的内容
            placeholder = "###NL###"
            
            # 定义替换函数：将匹配到的代码块内的换行符替换为占位符
            def repl_code(match):
                block = match.group(0)
                return block.replace("\n", placeholder)
            
            # 对 <code>...</code> 块进行替换（不区分大小写，多行匹配）
            text = re.sub(r'(<\s*code\s*>.*?</\s*code\s*>)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)
            # 对 ```python ... ``` 块进行替换
            text = re.sub(r'(```python.*?```)', repl_code, text, flags=re.IGNORECASE | re.DOTALL)
            
            # 如果还需要保护其他块，也可以在这里加上类似处理
            return text, placeholder

        # 预处理：隐藏代码块内的换行符
        modified_text, placeholder = preprocess_text(output_text)
        
        # 分块：使用<separator>标签 或 连续空行分块
        split_pattern = r'(?:<\s*/\s*separator\s*>|<\s*separator\s*>|<\s*separator\s*/>|\n\s*\n\s*)'
        test_case_blocks = re.split(split_pattern, modified_text, flags=re.IGNORECASE)
        test_case_blocks = [b.strip() for b in test_case_blocks if b.strip()]
        
        # 还原各块内被隐藏的换行符
        test_case_blocks = [b.replace(placeholder, "\n") for b in test_case_blocks]

        # print(f"共分出 {len(test_case_blocks)} 个块")
        for idx, block in enumerate(test_case_blocks, 1):
            # 1. 提取 test_type
            test_type = None
            
            # Case 1：通过 <type>value</type>
            type_match = re.search(
                r'<\s*type\s*>(.*?)<\s*/\s*type\s*>', 
                block, 
                re.IGNORECASE | re.DOTALL
            )
            if type_match:
                test_type = type_match.group(1).strip()
            else:
                # Case 2：判断是否有其他非已知标签标记的类型
                known_tags = {'type', 'planning', 'code', 'reasoning', 'test_function', 'separator'}
                for tag_match in re.finditer(r'<\s*([^\s>/]+)\s*.*?>', block, re.IGNORECASE):
                    tag_name = tag_match.group(1).lower()
                    if tag_name not in known_tags:
                        test_type = tag_name
                        break  # 取第一个不在已知标签中的
                
            if not test_type:  # 若无 test_type 则跳过该块
                continue
            
            # 2. 提取 reasoning（支持 <planning> 和 <reasoning>）
            reasoning_match = re.search(
                r'<\s*(?:reasoning|planning)\s*>(.*?)<\s*/\s*(?:reasoning|planning)\s*>',
                block, 
                re.IGNORECASE | re.DOTALL
            )
            reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
            
            # 3. 提取 test_function（优先顺序：test_function 标签 > code 标签 > 独立代码块）
            test_func = None
            
            # 检查 <test_function> 标签
            test_func_match = re.search(
                r'<\s*test_function\s*>(.*?)<\s*/\s*test_function\s*>',
                block, 
                re.IGNORECASE | re.DOTALL
            )
            if test_func_match:
                content = test_func_match.group(1).strip()
                code_block = re.search(r'```python\s*(.*?)\s*```', content, re.DOTALL)
                test_func = code_block.group(1).strip() if code_block else content
            else:
                # 检查 <code> 标签
                code_match = re.search(
                    r'<\s*code\s*>(.*?)<\s*/\s*code\s*>',
                    block,
                    re.IGNORECASE | re.DOTALL
                )
                if code_match:
                    content = code_match.group(1).strip()
                    code_block = re.search(r'```python\s*(.*?)\s*```', content, re.DOTALL)
                    test_func = code_block.group(1).strip() if code_block else content
                else:
                    # 检查独立代码块 (```python ... ```)
                    code_block = re.search(r'```python\s*(.*?)\s*```', block, re.DOTALL)
                    if code_block:
                        test_func = code_block.group(1).strip()
            
            if test_type and test_func:
                test_cases[f'test_case_{idx}'] = {
                    'test_type': test_type,
                    'purpose': reasoning,
                    'test_function': test_func
                }
        
        if not test_cases:
            # 如果没有提取到测试用例，则返回 False
            return False

        return test_cases
    
    def get_test_cases(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)
        while retry_num <= max_retry:
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            test_cases = self.extract_test_cases(llm_output)
            if test_cases:
                break
            else:
                retry_num += 1
                # print(f"Failed to extract test cases, retrying ({retry_num})...")
                logging.warning(f"Failed to extract test cases, retrying ({retry_num})...")
        if test_cases is False:
            logging.error(f"Failed to extract test cases, current llm_output:\n{llm_output}")
            raise ValueError("Failed to extract test cases, current llm_output:\n", llm_output)
        return test_cases

    def _filter_test_cases(self, dataset):
        # print(dataset)
        runnable_entries = {}
        for code_id, attributes in dataset.items():
            test_code = attributes.get("test_function", "")
            try:
                # Attempt to compile the code string to check for syntax errors.
                compile(test_code, "<string>", "exec")
                # If no exception is raised, consider the code as runnable.
                runnable_entries[code_id] = attributes
            except Exception as error:
                # If an exception is raised, skip this entry.
                continue
        return runnable_entries

    async def get_test_cases_async(self, task_spec, task_description=None, use_example=True, bulk=True, gen_kwargs={}, max_retry=3):
        retry_num=0
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_test_prompt(task_description, task_spec, use_example, bulk)

        while retry_num <= max_retry:
            # debug
            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)
            test_cases = self.extract_test_cases(llm_output)
            if test_cases:
                # DEBUG
                test_cases = self._filter_test_cases(test_cases)
                if test_cases == {}:
                    logging.warning("No runnable test cases found, retrying...")
                    test_cases = False
                    retry_num += 1
                else:
                    break
            else:
                retry_num += 1
                # print(f"Failed to extract test cases, retrying ({retry_num})...")
                logging.warning(f"Failed to extract test cases, retrying ({retry_num})..., current llm_output:\n{llm_output}")
        if test_cases is False:
            logging.error(f"Failed to extract test cases, current llm_output:\n{llm_output}")
            raise ValueError("Failed to extract test cases, current llm_output:\n", llm_output)
        return test_cases
    
class LLMCG():
    """LLM-driven code generator.

    Turns an extracted plan (per-component specs plus an overall plan) into a
    structured prompt, sends it to ``LLM_model`` and parses the reply's
    ``<Code>``, ``<Planning>`` and ``<Main Function Name>`` sections.

    Args:
        task_description: default task description, used whenever a call does
            not supply its own.
        LLM_model: client exposing ``LLM_response(prompt, gen_kwargs)`` and the
            coroutine ``LLM_response_async(prompt, gen_kwargs)``, both returning
            the reply text.
    """

    def __init__(self, task_description, LLM_model):
        self.task_description = task_description
        self.LLM_model = LLM_model

    @staticmethod
    def _as_lines(value):
        """Return the items of *value* to render as ``- item`` bullets.

        A string is split into its non-blank lines; iterating it directly would
        emit one bullet per character. ``None`` yields no bullets. Any other
        iterable is returned as a list of its items.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [line for line in value.splitlines() if line.strip()]
        return list(value)

    @staticmethod
    def _format_io(title, label, fmt):
        """Render a sequence of ``(dtype, shape)`` pairs as a titled bullet list.

        Each pair becomes ``- <label> <i>: <dtype> with shape=<shape>``, or
        ``... with no fixed shape`` when ``shape`` is None.
        """
        items = []
        for idx, (dtype, shape) in enumerate(fmt, 1):
            shape_str = f"shape={shape}" if shape is not None else "no fixed shape"
            items.append(f"- {label} {idx}: {dtype} with {shape_str}")
        return f"{title}:\n" + "\n".join(items)

    def create_code_generation_prompt(
        self,
        extracted_plan,
        user_feedback=None,
        task_description=None,
        test_cases=None,
        history=None,
        next_code_line=False,
        output_planning=False,
        use_example=False,
        use_task_description=False,
        use_system_prompt=True,
        more_comments=False,
        max_examples=3,
        ):
        """Build the code-generation prompt, or the refinement prompt when feedback is given.

        Args:
            extracted_plan: dict with ``"components"`` and ``"overall_plan"``.
                ``"components"`` maps a name to a dict with ``input_format`` and
                ``output_format`` (sequences of ``(dtype, shape)`` pairs),
                ``step_task_description``, ``work_flow`` and
                ``test_case_generation_advise``. ``"overall_plan"`` is a dict with
                ``input_format``, ``output_format``, ``components``, ``plan`` and
                ``test_case_generation_advise``.
            user_feedback: when truthy, switches to the refinement role and adds
                the feedback and refinement-requirement sections.
            task_description: falls back to ``self.task_description`` when falsy.
            test_cases: mapping name -> dict with ``purpose``, ``test_type`` and
                ``test_function``. When ``use_example`` is true, the first
                ``max_examples`` entries are included.
            history: mapping name -> dict with ``score``, ``generated_code`` and
                ``generation_plan`` from earlier attempts.
            next_code_line: ask only for the next snippet instead of complete code.
            output_planning: also request ``<Planning>`` and
                ``<Main Function Name>`` sections.
            use_example, use_task_description, use_system_prompt, more_comments:
                toggle the corresponding prompt sections.
            max_examples: maximum number of test cases shown (default 3).

        Returns:
            The prompt as a single string.
        """
        components = extracted_plan["components"]
        overall_plan = extracted_plan["overall_plan"]

        prompt_parts = []

        if user_feedback:
            system_prompt = "You are a code refinement specialist designed to improve existing implementations based on specific feedback. Analyze the provided feedback, identify areas for improvement, and modify the code while strictly maintaining the required input/output formats and component specifications."
        else:
            system_prompt = "You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined."

        if use_system_prompt:
            prompt_parts.append(f"=== Role ===\n{system_prompt}\n")

        if not task_description:
            task_description = self.task_description
        if use_task_description:
            prompt_parts.append(f"=== Task Description ===\n{task_description}\n")

        # Per-component specifications.
        if components:
            prompt_parts.append("=== Components ===")
            for comp_name, comp_details in components.items():
                prompt_parts.extend([
                    f"\n**Component: {comp_name}**",
                    f"Step Task Description: {comp_details['step_task_description']}",
                    self._format_io("Input Format", "Argument", comp_details["input_format"]),
                    self._format_io("Output Format", "Output", comp_details["output_format"]),
                    "Workflow Steps:",
                    *[f"- {step}" for step in self._as_lines(comp_details["work_flow"])],
                    "Test Case Generation Advice:",
                    *[f"- {advice}" for advice in self._as_lines(comp_details["test_case_generation_advise"])],
                    "\n",
                ])

        # Overall plan tying the components together.
        if overall_plan:
            prompt_parts.append("\n=== Overall Plan ===")
            prompt_parts.extend([
                self._format_io("Input Format", "Argument", overall_plan["input_format"]),
                self._format_io("Output Format", "Output", overall_plan["output_format"]),
                f"Components Order: {', '.join(overall_plan['components'])}",
                "Plan Steps:",
                *[f"- {step}" for step in self._as_lines(overall_plan["plan"])],
                "Overall Test Case Advice:",
                *[f"- {advice}" for advice in self._as_lines(overall_plan["test_case_generation_advise"])],
                "\n",
            ])

        if user_feedback:
            prompt_parts.append("\n=== User Feedback ===")
            prompt_parts.append(user_feedback)

        # Show at most `max_examples` test cases.
        if use_example and test_cases:
            prompt_parts.append("\n=== Test Cases ===")
            for case_name, case_details in list(test_cases.items())[:max(0, max_examples)]:
                prompt_parts.extend([
                    f"\n**Test Case: {case_name}**",
                    f"Purpose: {case_details['purpose']}",
                    f"Type: {case_details['test_type']}",
                    f"Test Function:\n{case_details['test_function']}",
                    "\n",
                ])

        if history:
            prompt_parts.append("\n=== Previous Generation Attempts ===")
            for gen_name, gen_details in history.items():
                prompt_parts.extend([
                    f"\n**Generation: {gen_name}**",
                    f"Score: {gen_details['score']}",
                    "Generated Code:",
                    gen_details["generated_code"],
                    "Generation Plan:",
                    # generation_plan may be the plain string produced by extract_code.
                    *[f"- {step}" for step in self._as_lines(gen_details["generation_plan"])],
                    "\n",
                ])

        if user_feedback:
            refine_instructions = [
                "\n=== Refinement Requirements ===",
                "Generate a revised implementation that:",
                "- Addresses all identified issues from the feedback analysis",
                "- Maintains strict compliance with component specifications",
                "- Preserves existing functionality that passed validation",
            ]
            prompt_parts.append("\n".join(refine_instructions))

        instructions = ["\n=== Instructions ==="]
        if next_code_line:
            instructions.append("Generate ONLY the next line or a small code snippet required to proceed.")
        else:
            instructions.append("Generate the COMPLETE code based on the components and plan above.")
        instructions.append("DO MAKE SURE the complete code is a runnable function, all components are correctly integrated with in this function.")
        instructions.append("The complete function should take the input arguments as specified in the overall plan and return the output as specified.")

        if more_comments:
            instructions.append("Please add as much comments as possible to your code to explain the logic and any critical steps.")

        if output_planning:
            instructions.append("Structure your response as follows:")
            instructions.append("<Code>")
            instructions.append("Your code here. DO make sure the output is a single function that integrates all components.")
            instructions.append("</Code>")
            instructions.append("<Planning>")
            if next_code_line:
                instructions.append("A concise summary of what this specific code part accomplishes.")
            else:
                instructions.append("A detailed step-by-step explanation of the code's workflow.")
            instructions.append("</Planning>")
            instructions.append("<Main Function Name>")
            instructions.append("The name of the main function that integrates all components.")
            instructions.append("</Main Function Name>")
            instructions.append("Provide the code with the same indicator and structure as shown in Instructions. DO NOT return any test cases or example usages in your code!")
        else:
            instructions.append("Structure your response as follows:")
            instructions.append("<Code>")
            instructions.append("Your code here")
            instructions.append("</Code>")
            instructions.append("Provide the code WITHOUT any additional explanations, and DO use the same indicator and structure as shown in Instructions.")

        prompt_parts.append("\n".join(instructions))

        return "\n".join(prompt_parts)

    def extract_code(self, llm_output):
        """Extract the code, planning and main-function-name sections from an LLM reply.

        Code is taken from ``<Code>...</Code>`` (or up to ``<End>``). If that tag
        is missing, it falls back to the first fenced ````` ```python ... ``` `````
        block. Missing sections are reported as ``None``.

        Returns:
            dict with keys ``"code"``, ``"plan"`` and ``"main_function_name"``.
        """
        result = {"code": None, "plan": None, "main_function_name": None}

        code_match = re.search(r'<Code>(.*?)(?:</Code>|<End>)', llm_output, re.DOTALL)
        if code_match:
            result["code"] = code_match.group(1).strip()
        else:
            code_block_match = re.search(r'```(?:python)?\s*(.*?)```', llm_output, re.DOTALL)
            if code_block_match:
                result["code"] = code_block_match.group(1).strip()

        plan_match = re.search(r'<Planning>(.*?)(?:</Planning>|<End>)', llm_output, re.DOTALL)
        if plan_match:
            result["plan"] = plan_match.group(1).strip()

        main_func_match = re.search(r'<Main Function Name>(.*?)(?:</Main Function Name>|<End>)', llm_output, re.DOTALL)
        if main_func_match:
            result["main_function_name"] = main_func_match.group(1).strip()

        return result

    def get_code(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs=None, max_retry=3):
        """Generate code synchronously, retrying until a code section can be extracted.

        ``extracted_plan.get('user_feedback')`` is used as the feedback for the
        prompt. Up to ``max_retry + 1`` LLM calls are made. Unlike
        ``get_code_async``, the extracted code is not compile-checked.

        Args:
            gen_kwargs: generation options forwarded to ``LLM_response``
                (``None`` means an empty dict).
            Other arguments are forwarded to ``create_code_generation_prompt``.

        Returns:
            The dict produced by ``extract_code``, with a non-None ``"code"``.

        Raises:
            ValueError: if no attempt yielded a code section.
        """
        if gen_kwargs is None:
            gen_kwargs = {}
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_code_generation_prompt(
            extracted_plan,
            user_feedback=extracted_plan.get('user_feedback'),
            task_description=task_description,
            test_cases=test_cases,
            history=history,
            next_code_line=next_code_line,
            output_planning=output_planning,
            use_example=use_example,
            use_task_description=use_task_description,
            use_system_prompt=use_system_prompt,
            more_comments=more_comments,
        )

        llm_output = None
        for attempt in range(1, max_retry + 2):
            llm_output = self.LLM_model.LLM_response(prompt, gen_kwargs)
            code_output = self.extract_code(llm_output)
            if code_output["code"] is not None:
                return code_output
            logging.warning(f"Failed to extract code, retrying ({attempt})...")
            logging.warning(f"Current llm_output:\n{llm_output}")

        logging.error(f"Failed to extract code, current llm_output:\n{llm_output}")
        raise ValueError(f"Failed to extract code, current llm_output:\n{llm_output}")

    def code_runnable_check(self, code_str):
        """Return True if *code_str* compiles as a Python module.

        This only checks syntax; the code is not executed.
        """
        try:
            compile(code_str, "<string>", "exec")
        except Exception:
            return False
        return True

    async def get_code_async(self, extracted_plan, task_description=None, test_cases=None, history=None, next_code_line=False, output_planning=True, use_example=True, use_task_description=True, use_system_prompt=True, more_comments=True, gen_kwargs=None, max_retry=3):
        """Async variant of ``get_code`` that also requires the code to compile.

        An attempt is retried when no code section can be extracted, or when
        the extracted code fails ``code_runnable_check``. Up to
        ``max_retry + 1`` LLM calls are made.

        Returns:
            The dict produced by ``extract_code``, with compilable ``"code"``.

        Raises:
            ValueError: if no attempt yielded compilable code.
        """
        if gen_kwargs is None:
            gen_kwargs = {}
        if task_description is None:
            task_description = self.task_description
        prompt = self.create_code_generation_prompt(
            extracted_plan,
            user_feedback=extracted_plan.get('user_feedback'),
            task_description=task_description,
            test_cases=test_cases,
            history=history,
            next_code_line=next_code_line,
            output_planning=output_planning,
            use_example=use_example,
            use_task_description=use_task_description,
            use_system_prompt=use_system_prompt,
            more_comments=more_comments,
        )

        llm_output = None
        for attempt in range(1, max_retry + 2):
            llm_output = await self.LLM_model.LLM_response_async(prompt, gen_kwargs)
            code_output = self.extract_code(llm_output)
            code_str = code_output["code"]
            if code_str is None:
                logging.warning(f"Failed to extract code, retrying ({attempt})...")
                logging.warning(f"Current llm_output:\n{llm_output}")
                continue
            if self.code_runnable_check(code_str):
                return code_output
            logging.warning(f"Code is not runnable, retrying ({attempt})...")
            logging.warning(f"Current code_output:\n{code_str}")

        logging.error(f"Failed to extract code, current llm_output:\n{llm_output}")
        raise ValueError(f"Failed to extract code, current llm_output:\n{llm_output}")

class CodeRunner():
    """Compile candidate functions and test functions from source and cross-run them.

    Args:
        max_workers: default thread-pool size used by ``run_all_tests``.
    """

    def __init__(self, max_workers=5):
        self.max_workers = max_workers

    def run_test(self, func_obj, test_func):
        """Return ``test_func(func_obj)``, or False if the call raises.

        ``SystemExit`` is caught as well, so generated code calling ``exit()``
        counts as a failed test instead of escaping the worker.
        """
        try:
            return test_func(func_obj)
        except (Exception, SystemExit):
            return False

    @staticmethod
    def _defined_in_source(obj):
        """True if *obj* is a function compiled from an exec'd string (filename ``<string>``)."""
        code = getattr(obj, "__code__", None)
        return getattr(code, "co_filename", None) == "<string>"

    def compile_code(self, code_str, main_function_name=None):
        """Execute *code_str* in a fresh namespace and return one of its callables.

        WARNING: this runs arbitrary code (e.g. LLM-generated) with no sandboxing.

        Args:
            code_str: Python source to execute.
            main_function_name: name of the callable to return. When omitted,
                the first function defined by *code_str* itself is returned.
                Imported callables such as ``from collections import
                OrderedDict`` are skipped. If the source defines no function,
                the first callable in the namespace is returned.

        Returns:
            The selected callable, or None if execution failed or no suitable
            callable exists.
        """
        namespace = {}  # one dict serves as both globals and locals
        try:
            exec(code_str, namespace)
        except (Exception, SystemExit) as e:
            logging.error(f"Compilation Error: {str(e)}, code_str:\n {code_str}")
            return None

        if main_function_name is not None:
            func = namespace.get(main_function_name)
            return func if callable(func) else None

        callables = [obj for obj in namespace.values() if callable(obj)]
        own_functions = [obj for obj in callables if self._defined_in_source(obj)]
        candidates = own_functions or callables
        return candidates[0] if candidates else None

    def run_all_tests(self, functions, test_cases, max_workers=None, timeout=5):
        """Run every compiled test against every compiled function in a thread pool.

        Args:
            functions: mapping id -> dict with ``'code'`` and optionally
                ``'main_function_name'``.
            test_cases: mapping id -> dict with ``'test_function'``, the source of
                a callable that takes the candidate function.
            max_workers: thread-pool size; defaults to ``self.max_workers``.
            timeout: accepted for interface compatibility; currently NOT
                enforced, so a hanging test blocks completion.

        Returns:
            ``(fun_results, test_results)``. ``fun_results[func_id][test_id]``
            and ``test_results[test_id][func_id]`` hold the test's return value.
            The value is False when either side failed to compile or the test raised.
        """
        workers = self.max_workers if max_workers is None else max_workers

        compiled_functions = {
            fid: self.compile_code(info['code'], main_function_name=info.get('main_function_name'))
            for fid, info in functions.items()
        }
        compiled_tests = {
            tid: self.compile_code(info['test_function'])
            for tid, info in test_cases.items()
        }

        fun_results = {fid: {} for fid in functions}
        test_results = {tid: {} for tid in test_cases}

        def record(func_id, test_id, result):
            fun_results[func_id][test_id] = result
            test_results[test_id][func_id] = result

        total_tests = len(compiled_functions) * len(compiled_tests)

        # The progress bar is closed (inner context) before the pool shuts down,
        # even if an exception escapes.
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor, \
                tqdm(total=total_tests, desc="Running tests") as pbar:
            futures = {}
            for func_id, func_obj in compiled_functions.items():
                for test_id, test_func in compiled_tests.items():
                    if func_obj is None or test_func is None:
                        # A compilation failure on either side counts as a failed test.
                        record(func_id, test_id, False)
                        pbar.update(1)
                        continue
                    future = executor.submit(self.run_test, func_obj, test_func)
                    futures[future] = (func_id, test_id)

            for future in concurrent.futures.as_completed(futures):
                func_id, test_id = futures[future]
                try:
                    result = future.result()
                except Exception:
                    result = False
                record(func_id, test_id, result)
                pbar.update(1)

        return fun_results, test_results
    
import ast
import re
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATConv, GraphConv, global_max_pool, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, Subset
from torch_geometric.loader import DataLoader
import logging

# Pick the compute device once at import time: the GPU when PyTorch reports
# CUDA as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info(f'Using device: {device}')

class CodeGraphDataset(Dataset):
    """Dataset turning Python source strings into AST graphs with scaled score labels.

    Each item is a ``Data`` object with:
    - ``x``: one node-type id per AST node, shape ``(num_nodes, 1)``.
    - ``edge_index``: parent -> child edges.
    - ``y``: the MinMax-scaled score.

    Items whose code cannot be parsed are returned as ``None``. Callers are
    expected to filter them out, and ``invalid_count`` counts them.

    Args:
        dataframe: frame with ``code`` and ``score`` columns. Scores must
            already be numeric.
        scaler: an already-fitted ``MinMaxScaler``. When omitted, a new scaler
            is fitted on this frame's ``score`` column.
        node_type_vocab: mapping AST node-type name -> integer id. It is built
            from this frame's code when omitted. Id 0 is reserved for unknown types.
    """

    def __init__(self, dataframe, scaler=None, node_type_vocab=None):
        self.invalid_count = 0  # items that failed to parse in __getitem__
        self.dataframe = dataframe.reset_index(drop=True)
        self.scaler = scaler if scaler else MinMaxScaler()
        if scaler is None:  # fit only when no scaler was supplied (i.e. the training set)
            self.scaler.fit(self.dataframe['score'].values.reshape(-1, 1))
        logging.info('Score values scaled using MinMaxScaler.')
        if node_type_vocab is None:
            self.node_type_vocab = self.build_node_type_vocab()
        else:
            self.node_type_vocab = node_type_vocab
        logging.info(f'Built node type vocabulary with size: {len(self.node_type_vocab)}')

    def build_node_type_vocab(self):
        """Map every AST node-type name seen in ``self.dataframe['code']`` to an id.

        Ids follow sorted type-name order starting at 1. ``"UNK"`` gets 0.
        Unparsable code is logged and skipped.
        """
        node_types = set()
        for idx, code in enumerate(self.dataframe['code']):
            try:
                tree = ast.parse(code)
            except Exception as e:
                logging.warning(f"Error parsing code at index {idx}: {e}")
                continue
            node_types.update(type(node).__name__ for node in ast.walk(tree))
        node_type_to_id = {"UNK": 0}
        for idx, nt in enumerate(sorted(node_types), start=1):
            node_type_to_id[nt] = idx
        return node_type_to_id

    def ast_to_graph(self, code):
        """Convert Python source into a graph of its AST.

        Nodes are numbered in pre-order (root = 0). Each node's feature is its
        type id from ``self.node_type_vocab``, or 0 for unknown types. Edges
        point from parent to child.

        Returns:
            ``Data(x, edge_index)`` with ``x`` of shape ``(num_nodes, 1)`` and
            ``edge_index`` of shape ``(2, num_edges)``. Returns None if ``code``
            cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            logging.warning(f"Error parsing code: {e}")
            return None

        node_features = []
        edges = []
        # Use an explicit stack instead of recursion so deeply nested code cannot
        # hit Python's recursion limit. Children are pushed in reverse, so they
        # are popped (and numbered) in the same pre-order as a recursive walk.
        stack = [(tree, None)]
        while stack:
            node, parent_id = stack.pop()
            current_id = len(node_features)
            node_features.append([self.node_type_vocab.get(type(node).__name__, 0)])
            if parent_id is not None:
                edges.append((parent_id, current_id))
            children = list(ast.iter_child_nodes(node))
            stack.extend((child, current_id) for child in reversed(children))

        if not node_features:
            return None

        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        x = torch.tensor(node_features, dtype=torch.long)
        return Data(x=x, edge_index=edge_index)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the graph for row *idx* with ``y`` set, or None if its code does not parse."""
        row = self.dataframe.iloc[idx]
        graph = self.ast_to_graph(row['code'])
        if graph is None:
            # Previously this raised, which made the None-filtering used by
            # callers unreachable. Return None and count the failure instead.
            self.invalid_count += 1
            logging.debug(f"Skipping index {idx} due to parsing error.")
            return None

        score_normalized = self.scaler.transform([[row['score']]]).flatten()[0]
        graph.y = torch.tensor([score_normalized], dtype=torch.float)
        return graph
    
class GNNModel(nn.Module):
    """Two-layer GAT regressor that maps a batch of AST graphs to one scalar per graph.

    Integer node-type ids are embedded and passed through two ``GATConv``
    layers (ReLU + dropout). The result is max- and mean-pooled per graph,
    concatenated, and fed to a two-layer MLP.

    Args:
        num_node_types: number of rows in the node-type embedding table.
        embed_dim: node-type embedding size.
        hidden_dim: hidden size of the GAT layers and the MLP.
        scaler: stored on the module for reference; ``forward`` does not use it.
    """

    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128, scaler=None):
        super().__init__()
        self.embedding = nn.Embedding(num_node_types, embed_dim)
        self.conv1 = GATConv(embed_dim, hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # input: max-pool ++ mean-pool
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.scaler = scaler

    def forward(self, data):
        """Return a 1-D tensor with one prediction per graph in *data*."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Assumes x holds one type id per node, e.g. shape (num_nodes, 1).
        # reshape(-1) (unlike squeeze()) keeps the node axis for a 1-node batch.
        x = self.embedding(x.reshape(-1))
        x = self.conv1(x, edge_index)
        x = self.dropout(F.relu(x))
        x = self.conv2(x, edge_index)
        x = self.dropout(F.relu(x))
        x = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Squeeze only the feature axis, so a batch of one graph gives shape (1,)
        # rather than a 0-d tensor.
        return x.squeeze(-1)

class PassRatePredictor():
    """Predicts a code sample's score from its AST using ``GNNModel``.

    Workflow:
    - ``add_data`` accumulates a ``code``/``score`` DataFrame.
    - ``train_model`` trains a model on it.
    - ``test_model`` evaluates a model.
    - ``predict_score`` scores new code.

    Scores may be numbers or strings with a time unit (see
    ``clean_score_data``). They are MinMax-scaled before training.

    Args:
        ini_data: optional initial data, in any form ``add_data`` accepts.
        model: optional pre-trained model.
    """

    # A leading integer/decimal number followed by an optional alphabetic unit.
    _SCORE_RE = re.compile(r"^(\d+\.?\d*)\s*([a-z]*)?")
    _SECOND_UNITS = {'s', 'sec', 'secs', 'second', 'seconds', ''}
    _MILLISECOND_UNITS = {'ms', 'msec', 'msecs', 'millisecond', 'milliseconds'}
    _MINUTE_UNITS = {'m', 'min', 'mins', 'minute', 'minutes'}
    _HOUR_UNITS = {'h', 'hr', 'hrs', 'hour', 'hours'}

    def __init__(self, ini_data=None, model=None):
        self.model = model
        self.scaler = MinMaxScaler()
        self.trained = False
        self.node_type_vocab = None
        if ini_data is None:
            self.data = pd.DataFrame(columns=["code", "score"])
        else:
            # add_data() treats `self.data is None` as "no data yet" and adopts ini_data.
            self.data = None
            self.add_data(ini_data)

    @staticmethod
    def _is_parsable(code):
        """True if *code* parses with ``ast.parse``."""
        try:
            ast.parse(code)
        except Exception:
            return False
        return True

    def add_data(self, new_data, use_pass_rate=False):
        """Append samples to ``self.data``, skipping code that is already stored.

        Args:
            new_data: a DataFrame with ``code``/``score`` columns, or a dict of
                per-sample record dicts (keys are discarded). A dict is
                converted to such a frame.
            use_pass_rate: for dict input, take each record's ``pass_rate``
                field as the score instead of ``score``.
        """
        if isinstance(new_data, dict):
            new_data = pd.DataFrame.from_dict(new_data, orient='index').reset_index(drop=True)
            # Keep only the code and label columns.
            label = 'pass_rate' if use_pass_rate else 'score'
            new_data = new_data[['code', label]].rename(columns={label: 'score'})

        if self.data is None:
            self.data = new_data
        else:
            # Skip samples whose code is already stored.
            new_data = new_data[~new_data['code'].isin(self.data['code'])]
            self.data = pd.concat([self.data, new_data], ignore_index=True)

    def predict_score(self, new_code_samples, model=None, scaler=None):
        """Predict scores for a sequence of code strings.

        Predictions are returned in the scaler's normalised space; no inverse
        transform is applied. The result is aligned with the input: samples
        whose code cannot be parsed get ``nan``.

        Args:
            new_code_samples: iterable of Python source strings.
            model: defaults to ``self.model``.
            scaler: defaults to ``self.scaler``.

        Raises:
            ValueError: if no model is available.
        """
        if model is None:
            model = self.model
        if model is None:
            raise ValueError("No model available: call train_model() or pass a model.")
        if scaler is None:
            scaler = self.scaler
        if self.node_type_vocab is None:
            logging.warning("node_type_vocab is not set; node-type ids may not match the model's embedding.")

        samples = list(new_code_samples)
        preds = [float("nan")] * len(samples)
        # Only parsable code can become a graph. Remember each one's position so
        # the output stays aligned with the input.
        valid_positions = [i for i, code in enumerate(samples) if self._is_parsable(code)]
        if not valid_positions:
            return preds

        new_df = pd.DataFrame({
            "code": [samples[i] for i in valid_positions],
            "score": ["0s"] * len(valid_positions),  # placeholder label
        })
        df_clean, _ = self.clean_score_data(new_df)

        dataset = CodeGraphDataset(df_clean, scaler=scaler, node_type_vocab=self.node_type_vocab)
        loader = DataLoader([data for data in dataset if data is not None], batch_size=32)

        model.eval()
        outputs = []
        with torch.no_grad():
            for batch in loader:
                # reshape(-1): a single-graph batch may come back 0-d.
                outputs.extend(model(batch).reshape(-1).cpu().numpy())

        for pos, value in zip(valid_positions, outputs):
            preds[pos] = value
        return preds

    def test_model(self, model, dataframe, train_scaler=None):
        """Evaluate *model* on *dataframe*; print and return MAE/RMSE in original score units.

        Args:
            model: a trained ``GNNModel``.
            dataframe: frame with ``code`` and ``score`` columns. Rows with
                unparsable code or unusable scores are dropped first.
            train_scaler: the scaler used for training (normally ``self.scaler``).
                When omitted, a new scaler is fitted on the cleaned scores.

        Returns:
            ``{"mae": ..., "rmse": ...}``.

        Raises:
            ValueError: if no usable sample remains.
        """
        df_clean = self.preprocess_data(dataframe)
        if df_clean.empty:
            raise ValueError("No valid samples to evaluate")
        if train_scaler is None:
            # Fit on the cleaned numeric scores; the raw column may contain strings.
            train_scaler = MinMaxScaler().fit(df_clean['score'].values.reshape(-1, 1))
        # Reuse the training vocabulary (when known). A vocabulary rebuilt from
        # the test data would give node types ids that differ from the model's
        # embedding table.
        test_dataset = CodeGraphDataset(df_clean, scaler=train_scaler, node_type_vocab=self.node_type_vocab)
        test_loader = DataLoader(
            [data for data in test_dataset if data is not None],
            batch_size=32
        )

        criterion = torch.nn.MSELoss()
        model.eval()
        test_loss = []
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in test_loader:
                pred = model(batch).reshape(-1)
                loss = criterion(pred, batch.y)
                test_loss.append(loss.item())
                all_preds.extend(pred.cpu().numpy())
                all_labels.extend(batch.y.cpu().numpy())

        if not all_preds:
            raise ValueError("No valid samples to evaluate")

        # Map predictions and labels back to the original score scale.
        preds = test_dataset.scaler.inverse_transform(np.array(all_preds).reshape(-1, 1)).flatten()
        labels = test_dataset.scaler.inverse_transform(np.array(all_labels).reshape(-1, 1)).flatten()

        mae = np.mean(np.abs(preds - labels))
        rmse = np.sqrt(np.mean((preds - labels)**2))
        print(f"Test MAE: {mae:.4f}, Test RMSE: {rmse:.4f}")
        return {"mae": mae, "rmse": rmse}

    def train_model(self, dataframe=None, epochs=50, batch_size=32, lr=0.001, checkpoint_path="best_gnn_model.pth"):
        """Train a fresh ``GNNModel`` on an 80/20 train/validation split.

        The scaler and node-type vocabulary are fitted on the training split
        only and reused for validation. The state dict with the lowest
        validation loss is saved to *checkpoint_path* (pass None to disable
        saving). The model after the final epoch is stored in ``self.model``
        and returned.

        Raises:
            ValueError: if no valid sample remains after preprocessing.
        """
        if dataframe is None:
            dataframe = self.data

        df_preprocessed = self.preprocess_data(dataframe)
        if df_preprocessed.empty:
            raise ValueError("无有效数据可供训练")

        train_df, val_df = train_test_split(df_preprocessed, test_size=0.2, random_state=42)

        # Fit the scaler on the training split only, then reuse it for validation.
        train_scaler = MinMaxScaler().fit(train_df['score'].values.reshape(-1, 1))
        self.scaler = train_scaler
        train_dataset = CodeGraphDataset(train_df, scaler=train_scaler)
        # The validation set must share the training vocabulary; its own would
        # assign different ids (possibly beyond the embedding table).
        val_dataset = CodeGraphDataset(
            val_df, scaler=train_scaler, node_type_vocab=train_dataset.node_type_vocab
        )

        train_loader = DataLoader(
            [data for data in train_dataset if data is not None],
            batch_size=batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            [data for data in val_dataset if data is not None],
            batch_size=batch_size
        )

        model = GNNModel(
            num_node_types=len(train_dataset.node_type_vocab),
            embed_dim=64,
            hidden_dim=128,
            scaler=train_scaler
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss()

        best_val_loss = float('inf')
        for epoch in range(epochs):
            model.train()
            train_loss = []
            for batch in train_loader:
                optimizer.zero_grad()
                pred = model(batch).reshape(-1)
                loss = criterion(pred, batch.y)
                loss.backward()
                optimizer.step()
                train_loss.append(loss.item())

            model.eval()
            val_loss = []
            with torch.no_grad():
                for batch in val_loader:
                    pred = model(batch).reshape(-1)
                    loss = criterion(pred, batch.y)
                    val_loss.append(loss.item())

            avg_train_loss = np.mean(train_loss)
            avg_val_loss = np.mean(val_loss)
            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                if checkpoint_path is not None:
                    torch.save(model.state_dict(), checkpoint_path)

        self.model = model
        self.node_type_vocab = train_dataset.node_type_vocab
        self.trained = True
        return model

    def filter_invalid_ast(self, df):
        """Split *df* by whether its ``code`` parses.

        Returns:
            ``(df_valid, invalid_positions)``: the parsable rows (index reset)
            and the positional indices of the rows that failed.
        """
        valid_indices = []
        invalid_indices = []

        for idx, code in enumerate(df['code']):
            try:
                ast.parse(code)
                valid_indices.append(idx)
            except Exception as e:
                logging.warning(f"索引 {idx} 的代码无法解析AST: {e}")
                invalid_indices.append(idx)

        df_valid = df.iloc[valid_indices].reset_index(drop=True)
        return df_valid, invalid_indices

    def clean_score_data(self, df):
        """Convert the ``score`` column to floats and drop rows that cannot be converted.

        Numbers are used as-is. Strings must start with a number and may carry
        a time unit, which is converted to seconds:
        - s/sec/second(s), or no unit: as-is (the default);
        - ms/msec/millisecond(s): divided by 1000;
        - m/min/minute(s): multiplied by 60;
        - h/hr/hour(s): multiplied by 3600.
        Unknown units are logged and treated as seconds.

        Returns:
            ``(df_clean, invalid_index_labels)``.
        """
        cleaned_scores = []
        invalid_indices = []

        for idx, row in df.iterrows():
            value = row['score']
            try:
                if isinstance(value, str):
                    num_match = self._SCORE_RE.match(value.strip().lower())
                    if not num_match:
                        raise ValueError("无法提取数值")
                    num = float(num_match.group(1))
                    unit = num_match.group(2) or 's'  # no unit means seconds
                    if unit in self._SECOND_UNITS:
                        converted = num
                    elif unit in self._MILLISECOND_UNITS:
                        converted = num / 1000
                    elif unit in self._MINUTE_UNITS:
                        converted = num * 60
                    elif unit in self._HOUR_UNITS:
                        converted = num * 3600
                    else:
                        logging.warning(f"索引 {idx} 的未知单位 '{unit}'，假设为秒")
                        converted = num
                    cleaned_scores.append(converted)
                else:
                    cleaned_scores.append(float(value))
            except Exception as e:
                logging.warning(f"索引 {idx} 的score值 '{value}' 处理失败: {e}")
                invalid_indices.append(idx)
                cleaned_scores.append(None)

        df_clean = df.copy()
        df_clean['score'] = cleaned_scores
        df_clean = df_clean.dropna(subset=['score']).reset_index(drop=True)
        return df_clean, invalid_indices

    def preprocess_data(self, df):
        """Drop rows with unparsable code, then clean the ``score`` column."""
        df_ast_valid, ast_invalid = self.filter_invalid_ast(df)
        logging.info(f"过滤 {len(ast_invalid)} 个无效AST样本")

        df_clean, score_invalid = self.clean_score_data(df_ast_valid)
        logging.info(f"过滤 {len(score_invalid)} 个无效score样本")

        return df_clean
    
import subprocess
import re
import tempfile
import os
import json

def pylint_code_score(code):
    """Return pylint's overall rating (the X in "rated at X/10") for ``code``.

    The source is written to a temporary ``.py`` file (UTF-8, the encoding
    pylint assumes for Python source), pylint is run on it, and the file is
    always removed afterwards, even if pylint cannot be started.

    Returns:
        The rating as a float, or -1 if pylint fails to run or its output
        contains no rating line.
    """
    try:
        tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8')
        try:
            with tmp:
                tmp.write(code)
            result = subprocess.run(
                ["pylint", "--output-format=text", tmp.name],
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            # Previously skipped when subprocess.run raised (e.g. pylint not
            # installed), leaking one temp file per call.
            os.unlink(tmp.name)

        # Extract the rating, e.g. "rated at 7.50/10".
        match = re.search(r"rated at (\d+\.?\d*)/10", result.stdout)
        return float(match.group(1)) if match else -1

    except Exception as e:
        print(f"Pylint 分析失败: {e}")
        return -1

def radon_mi_code_score(code: str) -> float:
    """Return radon's Maintainability Index for ``code`` scaled from 0-100 to 0-10.

    The source is written to a temporary UTF-8 ``.py`` file, analysed with
    ``radon mi --json``, and the file is always removed afterwards, even if
    radon cannot be started.

    Returns:
        ``mi / 10`` for the analysed file, or -1 if radon fails to run, its
        output is not valid JSON, or the JSON has no ``"mi"`` entry.
    """
    try:
        tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8')
        try:
            with tmp:
                tmp.write(code)
            result = subprocess.run(
                ["radon", "mi", "--json", tmp.name],
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            # Delete before parsing so a JSON error can no longer leak the file.
            os.unlink(tmp.name)

        data = json.loads(result.stdout)
        if data and isinstance(data, dict):
            # radon keys its JSON by file path; only the temp file was analysed.
            file_key = next(iter(data))
            return data[file_key]["mi"] / 10
        return -1
    except Exception as e:
        print(f"Radon 分析失败: {e}")
        return -1
    
import os
import re
import json
import copy
import tempfile
import subprocess
import concurrent.futures
from tqdm import tqdm

class Evaluator():
    """Score generated code samples.

    ``calculate_batch_scores`` combines, per code sample:
    0.7 * weighted test pass rate + 0.1 * predictor score
    + 0.1 * pylint score + 0.1 * radon maintainability score.
    The static-analysis helpers return -1 when the external tool cannot be
    run or its output cannot be parsed.
    """

    def __init__(self, pass_rate_predictor=None):
        # Optional predictor; used only when it is not None and its `.model`
        # attribute is not None (see calculate_batch_scores).
        self.pass_rate_predictor = pass_rate_predictor

    def calculate_pass_rate_score(self, test_results, test_weights):
        """Return the weighted fraction of tests that passed.

        Args:
            test_results: mapping test_id -> truthy when that test passed.
            test_weights: mapping test_id -> weight. Only tests listed here
                count; a weighted test missing from ``test_results`` counts
                as failed.

        Returns:
            passed weight / total weight, or 0.0 when the total weight is 0.
        """
        total_weight = sum(test_weights.values())
        if total_weight == 0:
            return 0.0

        passed_weight = sum(weight for test_id, weight in test_weights.items()
                            if test_results.get(test_id, False))
        return passed_weight / total_weight

    def calculate_batch_scores(self, code_data):
        """Score every entry of ``code_data``.

        Args:
            code_data: mapping code_id -> dict with keys ``"code"`` (source
                text), ``"test_results"`` and ``"test_weights"`` (see
                calculate_pass_rate_score).

        Returns:
            (final_scores, full_score_dict): ``final_scores`` maps code_id to
            the weighted combination described on the class;
            ``full_score_dict`` maps code_id to its four component scores.
        """
        items = list(code_data.items())
        code_ids = [k for k, _ in items]
        code_entries = [v for _, v in items]
        full_score_dict = {}

        # Pass-rate scores are cheap, so they are computed sequentially.
        pass_rate_scores = {
            code_id: self.calculate_pass_rate_score(entry["test_results"], entry["test_weights"])
            for code_id, entry in code_data.items()
        }

        # Predicted scores default to 0.0 when no trained predictor is available.
        code_strs = [entry["code"] for entry in code_entries]
        prediction_scores = [0.0] * len(code_strs)
        if self.pass_rate_predictor is not None and self.pass_rate_predictor.model is not None:
            try:
                prediction_scores = self.pass_rate_predictor.predict_score(code_strs)
                print("###############################################################")
                print(f"Prediction scores: {prediction_scores}")
                print("###############################################################")
            except Exception:
                # Dump the inputs that broke the predictor, then propagate
                # the original exception unchanged.
                print(code_strs)
                raise

        # pylint/radon each spawn a subprocess, so run them in a thread pool.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            static_scores = list(tqdm(
                executor.map(self._compute_static_scores, code_strs),
                total=len(code_strs),
                desc="Analyzing codes"
            ))

        # Combine into the final weighted score.
        final_scores = {}
        for i, code_id in enumerate(code_ids):
            final_scores[code_id] = (
                0.7 * pass_rate_scores[code_id] +
                0.1 * prediction_scores[i] +
                0.1 * static_scores[i][0] +
                0.1 * static_scores[i][1]
            )
            full_score_dict[code_id] = {
                "pass_rate_score": pass_rate_scores[code_id],
                "prediction_score": prediction_scores[i],
                "pylint_score": static_scores[i][0],
                "radon_score": static_scores[i][1]
            }

        return final_scores, full_score_dict

    def _compute_static_scores(self, code_str):
        """Return ``(pylint_score, radon_score)`` for one code string."""
        return (
            self.pylint_code_score(code_str),
            self.radon_mi_code_score(code_str)
        )

    @staticmethod
    def _run_tool_on_code(command, code):
        """Write ``code`` to a temporary UTF-8 ``.py`` file and run ``command + [path]``.

        Returns the ``subprocess.CompletedProcess`` (stdout/stderr captured as
        text). The temporary file is always deleted, even when writing it or
        launching the tool raises (e.g. the tool is not installed).
        """
        tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8')
        try:
            with tmp:
                tmp.write(code)
            return subprocess.run(
                [*command, tmp.name],
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            os.unlink(tmp.name)

    def pylint_code_score(self, code):
        """Return pylint's "rated at X/10" value for ``code``, or -1 on failure."""
        try:
            result = self._run_tool_on_code(["pylint", "--output-format=text"], code)
            match = re.search(r"rated at (\d+\.?\d*)/10", result.stdout)
            return float(match.group(1)) if match else -1

        except Exception as e:
            print(f"Pylint analysis failed: {e}")
            return -1

    def radon_mi_code_score(self, code):
        """Return radon's Maintainability Index for ``code`` divided by 10, or -1 on failure."""
        try:
            result = self._run_tool_on_code(["radon", "mi", "--json"], code)
            data = json.loads(result.stdout)
            if data and isinstance(data, dict):
                # radon keys its JSON by file path; only one file was analysed.
                return next(iter(data.values()))["mi"] / 10
            return -1
        except Exception as e:
            print(f"Radon analysis failed: {e}")
            return -1

class LCDP():
    def __init__(self, api_key, model="gpt-3.5-turbo"):
        """Build the collaborators and empty per-run state.

        Args:
            api_key: forwarded to LLMModel.
            model: model name forwarded to LLMModel.
        """
        # Collaborators, constructed in the original order.
        self.llm_model = LLMModel(api_key, model)
        self.code_runner = CodeRunner()
        self.pass_rate_predictor = PassRatePredictor()
        # The evaluator shares the same predictor instance that run() trains.
        self.evaluator = Evaluator(self.pass_rate_predictor)

        # Per-run state; populated by run().
        self.task_description = self.current_plan = None
        self.test_weights, self.test_cases = {}, {}
        self.test_timeout = None

    async def run(self, task_description, max_iterations=3, example_dataset=None,
                 num_plans=3, num_tests=5, num_codes=5, refine_rounds=3, use_pass_rate_for_train=False, test_timeout=None):
        """Run the interactive plan -> tests -> code-generation pipeline.

        Args:
            task_description: task text given to LLMTM and LLMCG.
            max_iterations: number of generate/evaluate/train rounds in phase 3.
            example_dataset: optional data used by _calculate_test_weights to
                weight test cases; when falsy every test gets weight 1.0.
            num_plans: not used by this method.
            num_tests: number of test-generation requests in phase 2.
            num_codes: number of code candidates generated per iteration.
            refine_rounds: maximum number of interactive plan-refinement rounds.
            use_pass_rate_for_train: forwarded to pass_rate_predictor.add_data.
            test_timeout: stored on self; used by _evaluate_codes.

        Returns:
            dict of the top 3 scored codes from every iteration. Code ids
            restart at "code_0" each iteration, so a repeated id gets a numeric
            suffix ("code_0_1", "code_0_2", ...) instead of overwriting an
            earlier winner.
        """
        self.test_timeout = test_timeout
        self.task_description = task_description

        # LLM task manager (plans and tests) and code generator.
        self.llmtm = LLMTM(task_description, self.llm_model)
        self.llmcg = LLMCG(task_description, self.llm_model)

        # Phase 1: plan generation and interactive refinement.
        logging.info("########################################################################")
        logging.info("### Phase 1: Plan Generation and Refinement")
        _plan, plan_raw = self.llmtm.get_plan()
        self.current_plan = await self._plan_refinement_loop(self.llmtm, plan_raw, refine_rounds)
        self.current_plan = self._plan_format_refinement(self.current_plan)

        # Phase 2: test case generation and weighting.
        logging.info("\n########################################################################")
        logging.info("### Phase 2: Test Case Generation and Weighting")
        self.test_cases = await self._generate_tests_async(self.llmtm, num_tests, use_example=False)
        self.test_cases = self._filter_test_cases(self.test_cases)

        logging.info("Calculating test weights...")
        self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset)

        # Phase 3: iterative code generation.
        logging.info("\n########################################################################")
        logging.info("### Phase 3: Iterative Code Generation")
        best_codes = {}
        for iteration in range(max_iterations):
            logging.info(f"\n=== Iteration {iteration+1}/{max_iterations} ===")

            new_codes = await self._generate_codes_async(num_codes)

            logging.info("Evaluating codes...")
            scored_codes, filtered_test_result = self._evaluate_codes(new_codes)
            # Keep only the test cases that survived pass-rate filtering.
            self.test_cases = {k: v for k, v in self.test_cases.items() if k in filtered_test_result}

            logging.info("training pass_rate_predictor...")
            self.pass_rate_predictor.add_data(scored_codes, use_pass_rate=use_pass_rate_for_train)
            self.pass_rate_predictor.train_model(epochs=50, batch_size=32, lr=0.001)

            # Accumulate this iteration's winners. Ids restart at code_0 every
            # iteration, so de-duplicate instead of overwriting earlier entries.
            for code_id, entry in self._select_top_codes(scored_codes, top_k=3).items():
                unique_id = code_id
                suffix = 1
                while unique_id in best_codes:
                    unique_id = f"{code_id}_{suffix}"
                    suffix += 1
                best_codes[unique_id] = entry

            # If the user gives no feedback, fall back to a generic request.
            if not await self._get_user_feedback(best_codes):
                self.current_plan['user_feedback'] = "Based on previous outputs, please improve the code quality."

        return best_codes

    async def _plan_refinement_loop(self, llmtm, initial_plan_raw, max_rounds):
        current_plan_raw = initial_plan_raw
        current_plan = llmtm.extract_plan(current_plan_raw)
        for _ in range(max_rounds):
            # Show current plan
            # print("Current Plan:\n", self.plan_json_to_str(current_plan["overall_plan"]))
            logging.info("Current Plan:\n" + self.plan_json_to_str(current_plan["overall_plan"]))
            
            # Get user feedback
            if input("Refine plan? (y/n): ").lower() != 'y':
                logging.info("Skipping plan refinement.")
                break
            
            feedback = input("Enter refinement feedback: ")
            logging.info(f"User feedback: {feedback}")
            current_plan, current_plan_raw = llmtm.refine_plan(feedback, current_plan_raw)
        
        return llmtm.extract_plan(current_plan_raw)

    def plan_json_to_str(self, plan):
        # Process Input Format
        input_fmt = plan["input_format"]
        input_lines = []
        for idx, (dtype, shape) in enumerate(input_fmt, 1):
            shape_str = f"shape={shape}" if shape is not None else "no fixed shape"
            input_lines.append(f"Argument {idx}: {dtype} with {shape_str}")
        input_section = "Input Format:\n" + "\n".join([f"- {line}" for line in input_lines])

        # Process Output Format
        output_fmt = plan["output_format"]
        output_lines = []
        for idx, (dtype, shape) in enumerate(output_fmt, 1):
            shape_str = f"shape={shape}" if shape is not None else "no fixed shape"
            output_lines.append(f"Output {idx}: {dtype} with {shape_str}")
        output_section = "Output Format:\n" + "\n".join([f"- {line}" for line in output_lines])

        # Build Overall Plan Details
        plan_part = [
            "=== Current Plan ===",
            input_section,
            output_section,
            f"Components Order: {', '.join(plan['components'])}",
            "Plan Steps:",
            *[f"- {step}" for step in plan["plan"]],
            "Overall Test Case Advice:",
            *[f"- {advice}" for advice in plan["test_case_generation_advise"]],
            "\n",
        ]

        return "\n".join(plan_part)

    def _plan_format_refinement(self, plan_dict):
        """Refines the input and output formats in the plan to be lists of lists."""
        
        # Create a deep copy to avoid modifying the original input
        refined_plan = copy.deepcopy(plan_dict)
        
        def refine_format(formats):
            """Ensure each format field is a list of lists."""
            if isinstance(formats, list):
                # Check if all elements are lists
                if not all(isinstance(elem, list) for elem in formats):
                    return [formats]
            else:
                # If it's not a list, wrap it into a list (though input is expected to be a list)
                return [formats]
            return formats
        
        # Process each component in 'components'
        for component in refined_plan["components"].values():
            for key in ["input_format", "output_format"]:
                if key in component:
                    component[key] = refine_format(component[key])
        
        # Process 'overall_plan'
        overall_plan = refined_plan.get("overall_plan")
        if overall_plan:
            for key in ["input_format", "output_format"]:
                if key in overall_plan:
                    overall_plan[key] = refine_format(overall_plan[key])
        
        return refined_plan

    async def _generate_tests(self, llmtm, num_tests):
        test_cases = {}
        for _ in range(num_tests):
            test = llmtm.get_test_cases(self.current_plan['overall_plan'])
            test_cases.update(test)
        return test_cases
    
    async def _generate_tests_async(self, llmtm, num_tests, use_example=True):
        """Request ``num_tests`` batches of test cases concurrently and merge them.

        Batches are merged in completion order. A test name already present is
        renamed ``name_1``, ``name_2``, ... (first free suffix) so no generated
        test is overwritten.
        """
        pending = [
            llmtm.get_test_cases_async(self.current_plan['overall_plan'], use_example=use_example)
            for _ in range(num_tests)
        ]

        merged = {}
        for next_done in tqdm_asyncio.as_completed(pending, total=num_tests, desc="Generating async tests"):
            batch = await next_done
            for base_name, case in batch.items():
                candidate, n = base_name, 0
                while candidate in merged:
                    n += 1
                    candidate = f"{base_name}_{n}"
                merged[candidate] = case

        return merged

    def _filter_test_cases(self, dataset):
        # print(dataset)
        runnable_entries = {}
        for code_id, attributes in dataset.items():
            test_code = attributes.get("test_function", "")
            try:
                # Attempt to compile the code string to check for syntax errors.
                compile(test_code, "<string>", "exec")
                # If no exception is raised, consider the code as runnable.
                runnable_entries[code_id] = attributes
            except Exception as error:
                # If an exception is raised, skip this entry.
                continue
        return runnable_entries
            
    def _calculate_test_weights(self, test_cases, example_dataset):
        if not example_dataset:
            return {tid: 1.0 for tid in test_cases}
        
        # Run example dataset through tests
        _, test_results = self.code_runner.run_all_tests(example_dataset, test_cases)
        
        # Calculate weights
        weights = {}
        for tid, results in test_results.items():
            pass_rate = sum(results.values()) / len(results)
            weights[tid] = 1 - abs(pass_rate - 0.5)  # Weight tests that discriminate
        return weights

    async def _generate_codes(self, num_codes):
        codes = {}
        for _ in range(num_codes):
            code = self.llmcg.get_code(
                extracted_plan=self.current_plan,
                test_cases=self.test_cases,
            )
            codes[f"code_{len(codes)}"] = code
        return codes
    
    async def _generate_codes_async(self, num_codes):
        """Generate ``num_codes`` candidates concurrently.

        Results are keyed ``code_0``, ``code_1``, ... in completion order, not
        submission order.
        """
        pending = [
            self.llmcg.get_code_async(extracted_plan=self.current_plan, test_cases=self.test_cases)
            for _ in range(num_codes)
        ]
        generated = {}
        for next_done in tqdm_asyncio.as_completed(pending, total=num_codes, desc="Generating async codes"):
            finished_code = await next_done
            generated[f"code_{len(generated)}"] = finished_code
        return generated

    def transform_test_perspective(self, test_results):
        transformed = {}
        for test_case_id, code_results in test_results.items():
            for code_id, result in code_results.items():
                if code_id not in transformed:
                    transformed[code_id] = {}
                transformed[code_id][test_case_id] = result
        return transformed

    def _filter_test_cases_by_pass_rate(self, test_results, threshold=0.05):
        filtered_test_case_list = []
        filtered_test_results = {}
        test_case_length = len(test_results)
        for test_case_id, results in test_results.items():
            total = len(results)
            if total == 0:
                continue
            passed = sum(results.values())
            pass_rate = passed / total
            if pass_rate > threshold:
                filtered_test_case_list.append(test_case_id)
                filtered_test_results[test_case_id] = results
        self.test_cases = {k: v for k, v in self.test_cases.items() if k in filtered_test_case_list}
        logging.info(f"Filtered test cases: {len(self.test_cases)} out of {test_case_length}")

        filtered_fun_results = self.transform_test_perspective(filtered_test_results)

        return filtered_fun_results, filtered_test_results

    def _evaluate_codes(self, codes, timeout=None):
        if timeout is None:
            timeout = self.test_timeout
        fun_results, test_results = self.code_runner.run_all_tests(codes, self.test_cases, timeout=timeout)
        # print(fun_results)
        filtered_fun_results, filtered_test_results = self._filter_test_cases_by_pass_rate(test_results, threshold=0.05)

        input_data = {}
        for code_id, results in filtered_fun_results.items():
            input_data[code_id] = {
                'code': codes[code_id]['code'],
                'test_results': results,
                'test_weights': self.test_weights
            }
        # Calculate scores
        output_scores, full_score_dict = self.evaluator.calculate_batch_scores(input_data)
        # Combine scores with code data
        output_results = {
            code_id: {
                'code': codes[code_id]['code'],
                'plan': codes[code_id]['plan'],
                'main_function_name': codes[code_id]['main_function_name'],
                'score': output_scores[code_id],
                'pass_rate_score': full_score_dict[code_id]['pass_rate_score'],
                'prediction_score': full_score_dict[code_id]['prediction_score'],
                'pylint_score': full_score_dict[code_id]['pylint_score'],
                'radon_score': full_score_dict[code_id]['radon_score'],
                'test_case_results': filtered_fun_results[code_id],
            }
            for code_id in codes.keys()
        }
        return output_results, filtered_test_results

    def _select_top_codes(self, scored_codes, top_k=3):
        return dict(sorted(scored_codes.items(), 
                          key=lambda x: x[1]['score'], 
                          reverse=True)[:top_k])

    async def _get_user_feedback(self, top_codes):
        """Log the current top codes and optionally collect feedback on stdin.

        Returns True after storing the entered text in
        ``self.current_plan['user_feedback']`` when the user answers 'y'
        (case-insensitive); otherwise returns False without changing the plan.
        """
        logging.info("\nTop Performing Codes:")
        for code_id, entry in top_codes.items():
            logging.info(f"{code_id} [Score: {entry['score']:.2f}]:")
            logging.info("Code workflow:")
            logging.info(entry['plan'])
            logging.info("Partial Code:")
            # First 500 characters only; "..." is appended unconditionally.
            logging.info(entry['code'][:500] + "...\n")

        wants_feedback = input("Provide feedback? (y/n): ").lower() == 'y'
        if not wants_feedback:
            return False

        feedback = input("Enter your feedback: ")
        logging.info(f"User feedback: {feedback}")
        # current_plan is what the code generator receives as extracted_plan.
        self.current_plan['user_feedback'] = feedback
        return True
