import ast
import inspect
import json
import os
import random
import re
import subprocess
import sys
import tempfile
import textwrap
import traceback
import uuid

class LLMCG():
    def __init__(self, LLM_model):
        """Store the LLM client used by every generation step.

        Args:
            LLM_model: Client object. Methods in this class call
                ``LLM_response``, ``LLM_response_async`` and
                ``Embedding_response`` on it.
        """
        self.LLM_model = LLM_model

    def refine_codes(self, generated_codes, code_results, test_cases, error_test_num):
        """Placeholder refinement step: tag every code record as refined.

        Appends a ``# refined`` marker line to the ``"content"`` entry of each
        record in ``generated_codes`` (in place) and returns the same mapping.
        ``code_results``, ``test_cases`` and ``error_test_num`` are accepted
        but not used by this mock implementation.
        """
        print("INFO: Refining codes based on evaluation results...")
        # Mock behaviour only: the marker is appended, nothing is regenerated.
        for record in generated_codes.values():
            record["content"] += "\n  # refined"
        return generated_codes

    
    def _validate_and_create_function(self, code_string):
        """
        Validates the generated Python code string for syntax and correctness.
        If valid, it returns the compiled function object. Otherwise, returns None.
        """
        try:
            # Safely compile the code to check for syntax errors
            compile(code_string, '<string>', 'exec')
            
            # Execute the code in a temporary namespace to define the function
            local_namespace = {}
            exec(code_string, {}, local_namespace)
            
            # Ensure 'test_function' was defined and is a callable function
            if 'test_function' in local_namespace and callable(local_namespace['test_function']):
                return local_namespace['test_function']
            else:
                print(f"Warning: 'test_function' not found in executed code.\nCode: {code_string}")
                return None
        except Exception as e:
            print(f"Warning: Failed to validate or create function due to {e}.\nCode: {code_string}")
            return None

    def generate_test_cases_from_raw(self, raw_llm_outputs, test_type):

        # 1. Normalize the input to always be a list of strings
        if isinstance(raw_llm_outputs, str):
            outputs_list = [raw_llm_outputs]
        else:
            outputs_list = raw_llm_outputs

        if not outputs_list:
            return []

        # 2. Generate an extraction prompt for each raw output
        prompts = [self._get_extraction_prompt(raw_output) for raw_output in outputs_list]

        # 3. Call the LLM asynchronously to format all prompts
        llm_results, all_successful = self.LLM_model.LLM_response_async(prompts)

        # if not all_successful:
        #     print("Warning: At least one LLM call failed during test formatting.")

        # 4. Process each LLM result
        all_test_cases = []
        for i, (success, formatted_response, _original_prompt) in enumerate(llm_results):
            if not success:
                print(f"Warning: Failed to get formatted response. Error: {formatted_response}")
                continue

            # Retrieve the original description for this test
            original_raw_output = outputs_list[i]

            # Parse the callable test functions from the formatted response
            test_functions, code_strs = self._extract_and_parse_tests(formatted_response)

            # 5. Create a dictionary for each successfully parsed function
            for test_func, code_str in zip(test_functions, code_strs):
                all_test_cases.append({
                    "test_function": test_func,
                    "test_function_string": code_str,
                    "description": original_raw_output,
                    "type": test_type,
                    "weight": 1.0
                })

        return all_test_cases

    def adding_new_test_case(self, test_cases, new_test_case):
        """Append one or more test-case dicts to ``test_cases`` under new ids.

        New ids continue from the current maximum key (starting at 0 for an
        empty mapping). ``test_cases`` is modified in place and also returned.

        Args:
            test_cases: Mapping of integer id to test-case dict.
            new_test_case: A test-case dict, or a list of them. Non-dict list
                items are skipped with a printed warning.

        Raises:
            TypeError: If ``new_test_case`` is neither a dict nor a list.
        """
        upcoming_id = max(test_cases.keys(), default=-1) + 1

        if not isinstance(new_test_case, (dict, list)):
            raise TypeError("new_test_case must be a dictionary or a list of dictionaries.")

        # A single dict is handled as a one-element batch.
        pending = [new_test_case] if isinstance(new_test_case, dict) else new_test_case

        for candidate in pending:
            if not isinstance(candidate, dict):
                print(f"Warning: Skipping non-dictionary item in list: {candidate}")
                continue
            test_cases[upcoming_id] = candidate
            upcoming_id += 1

        return test_cases

    def generate_tests(self, num, task_description, plan=None, original_test_cases=None, debug=False):
        """Generate test cases for a task with the LLM.

        Three kinds of raw test descriptions are requested from the LLM
        (examples found in the task text, a runnable "smoke" check, and
        ``num`` independent correctness tests) and each is converted into
        executable test-case records.

        Args:
            num: Number of correctness-test generations to request.
            task_description: Text describing the task.
            plan: Optional plan text included in the prompts.
            original_test_cases: Optional existing ``{id: test_case}`` mapping.
                It is copied, so the caller's mapping is left untouched.
            debug: When true, also return the prompts and raw LLM outputs.

        Returns:
            The resulting ``{id: test_case}`` mapping, or a
            ``(test_cases, debug_info)`` tuple when ``debug`` is true.
        """
        # Work on a shallow copy so the caller's mapping is never mutated.
        test_cases = {} if original_test_cases is None else dict(original_test_cases)

        # Build the three generation prompts.
        prompt_in_content = self.get_test_prompt_in_content(task_description, plan)
        prompt_correctness = self.get_test_prompt_correctness(task_description, plan)
        prompt_runnable_check = self.get_test_prompt_runnable(task_description, plan)

        # Ask the LLM for the raw (free-text) test descriptions.
        test_in_contents_raw = self.LLM_model.LLM_response(prompt_in_content)
        test_runnable_checks_raw = self.LLM_model.LLM_response(prompt_runnable_check)
        test_correctness_raw_list_generation, success_result = self.LLM_model.LLM_response_async([prompt_correctness]*num, max_workers=5)

        if not success_result:
            print("Warning: Some correctness test case generations failed.")

        # Failed generations keep an empty-string placeholder at their index.
        test_correctness_raw_list = ["" for _ in range(num)]
        for i, temp_list in enumerate(test_correctness_raw_list_generation):
            if temp_list[0]:
                test_correctness_raw_list[i] = temp_list[1]

        # Only real outputs go to extraction: an empty raw text would leave the
        # extraction prompt with nothing but its built-in example to copy.
        usable_correctness_raw = [raw for raw in test_correctness_raw_list if raw]

        # 1. Examples stated in the task description
        test_cases = self.adding_new_test_case(test_cases, self.generate_test_cases_from_raw(test_in_contents_raw, 'in_content'))

        # 2. Runnable (smoke) checks
        test_cases = self.adding_new_test_case(test_cases, self.generate_test_cases_from_raw(test_runnable_checks_raw, 'runnable'))

        # 3. Correctness checks
        test_cases = self.adding_new_test_case(test_cases, self.generate_test_cases_from_raw(usable_correctness_raw, 'correctness'))

        if debug:
            debug_info = {"prompt_in_content": prompt_in_content,
                          "prompt_correctness": prompt_correctness,
                          "prompt_runnable_check": prompt_runnable_check,
                          "raw_in_content": test_in_contents_raw,
                          "raw_runnable_check": test_runnable_checks_raw,
                          "raw_correctness": test_correctness_raw_list,
                          "test_cases": test_cases}
            return test_cases, debug_info

        return test_cases

    def _get_extraction_prompt(self, raw_llm_output):
        """Build the prompt that converts raw test text into test functions.

        The prompt instructs the LLM to emit one ``test_case(func_to_test)``
        function per test, each returning a ``(bool, str)`` tuple, all inside
        a single python code block and separated by the line
        ``---TEST-CASE-SEPARATOR---``.

        Args:
            raw_llm_output: Free-text test description to embed in the prompt.

        Returns:
            The prompt text with leading and trailing whitespace stripped.
        """
        # Doubled braces and backslashes inside the f-string are escapes: the
        # example code in the prompt shows literal `{input_val}` and `\n`.
        prompt = f"""
You are an expert programmer specializing in writing Python unit tests.
Your task is to extract all test cases from the provided raw text and convert them into executable Python test functions.

**Instructions:**

1.  **Function Signature**: Each test function must have the exact signature `def test_case(func_to_test):`. The `func_to_test` parameter is the function that will be under test.
2.  **Return Value**: Each test function must return a tuple `(bool, str)`.
    * The first element is a boolean: `True` if the test passes, `False` otherwise.
    * The second element is a string message. If the test fails, this message must be informative, explaining the reason for failure (e.g., "Expected: <expected_value>, Got: <actual_value>").
3.  **Error Handling**: Wrap the call to `func_to_test` in a `try...except` block to catch any exceptions during its execution. If an exception occurs, the test should fail and the message should include the exception details.
4.  **Multiple Tests**: If the raw text contains multiple test cases, generate a separate `test_case` function for each.
5.  **Formatting**:
    * Enclose all the generated Python code within a single markdown code block (e.g., ```python ... ```).
    * Use the exact separator `---TEST-CASE-SEPARATOR---` on its own line between each distinct `test_case` function.

**Example:**

* **Raw Text:**
    "The function must handle edge cases like an empty list. For `[]`, it should return `0`. Also, check for a list with a single element like `[5]`, which should return `5`."

* **Expected Output:**
    ```python
    import traceback

    def test_case(func_to_test):
        try:
            input_val = []
            expected_output = 0
            actual_output = func_to_test(input_val)
            if actual_output == expected_output:
                return True, "Test passed for empty list."
            else:
                return False, f"Test failed for input {{input_val}}. Expected: {{expected_output}}, Got: {{actual_output}}"
        except Exception as e:
            return False, f"Test failed for input {{input_val}} with exception: {{e}}\\n{{traceback.format_exc()}}"

    ---TEST-CASE-SEPARATOR---

    import traceback

    def test_case(func_to_test):
        try:
            input_val = [5]
            expected_output = 5
            actual_output = func_to_test(input_val)
            if actual_output == expected_output:
                return True, "Test passed for single-element list."
            else:
                return False, f"Test failed for input {{input_val}}. Expected: {{expected_output}}, Got: {{actual_output}}"
        except Exception as e:
            return False, f"Test failed for input {{input_val}} with exception: {{e}}\\n{{traceback.format_exc()}}"
    ```

Now, please process the following raw text and generate the corresponding Python test functions.

**Raw Text to Process:**

---
{raw_llm_output}
---
"""
        return prompt.strip()

    def _extract_and_parse_tests(self, llm_output):

        # 1. Find the Python code block using a regular expression
        code_block_match = re.search(r"```python\n(.*?)```", llm_output, re.DOTALL)
        if not code_block_match:
            print("Warning: Could not find a python code block in the LLM output.")
            print(f"LLM Output:\n{llm_output}")
            return [], []

        code_block = code_block_match.group(1).strip()
        
        # 2. Split the code block into individual test case strings
        test_case_strings = code_block.split("---TEST-CASE-SEPARATOR---")

        test_functions = []
        code_strs = []
        for i, code_str in enumerate(test_case_strings):
            code_str = code_str.strip()
            code_strs.append(code_str)
            if not code_str:
                continue

            try:
                # 3. Execute the code in an isolated namespace to define the function
                local_namespace = {}
                exec(code_str, {}, local_namespace)

                # 4. Retrieve the compiled function from the namespace
                if 'test_case' in local_namespace and callable(local_namespace['test_case']):
                    test_functions.append(local_namespace['test_case'])
                else:
                    print(f"Warning: 'test_case' function not found in code snippet #{i+1}.")

            except SyntaxError as e:
                print(f"Warning: Syntax error parsing test case #{i+1}: {e}")
                print(f"Problematic code:\n---\n{code_str}\n---")
            except Exception as e:
                print(f"Warning: An unexpected error occurred while parsing test case #{i+1}: {e}")
                print(f"Problematic code:\n---\n{code_str}\n---")

        return test_functions, code_strs
    
    def get_test_prompt_in_content(self, task_description, plan=None):
        """Build the prompt that extracts examples already present in the task.

        Args:
            task_description: Task text to scan for explicit examples.
            plan: Optional plan text; when truthy it is added as a ``Plan``
                section and mentioned in the instructions.

        Returns:
            The prompt text with surrounding whitespace stripped.
        """
        # The plan section and the wording that refers to it appear together.
        context_description = "task description and plan" if plan else "task description"
        plan_section = f"**Plan:**\n---\n{plan}\n---" if plan else ""

        return f"""
You are a meticulous assistant. Your task is to carefully read the following {context_description} and extract any explicit examples or test cases mentioned within it.

**Task Description:**
---
{task_description}
---
{plan_section}
**Your Instructions:**
1.  Read the {context_description} carefully.
2.  Identify every example that demonstrates the function's behavior, usually in the form of an input and its corresponding expected output.
3.  List each input-output pair you find.
4.  **Crucially**, if you find absolutely no examples or test cases in the text, you must return the single word: **None**.

Do not invent any new test cases. Only extract what is explicitly written.
""".strip()

    def get_test_prompt_correctness(self, task_description, plan=None):
        """Build the prompt that asks for one novel correctness test case.

        Args:
            task_description: Task text the test must be derived from.
            plan: Optional plan text; when truthy it is added as a ``Plan``
                section and mentioned in the instructions.

        Returns:
            The prompt text with surrounding whitespace stripped.
        """
        # The plan section and the wording that refers to it appear together.
        context_description = "task description and plan" if plan else "task description"
        plan_section = f"**Plan:**\n---\n{plan}\n---" if plan else ""

        return f"""
You are a senior software quality engineer. Your goal is to design a single, high-quality test case to verify the correctness of a function based on the {context_description} provided. This test case must be **novel** and **not** one of the examples already mentioned in the description.

**Task Description:**
---
{task_description}
---
{plan_section}
**Your Thought Process (Follow these steps meticulously):**

1.  **Analyze the Task:**
    * What is the primary goal of the function?
    * What are the specified inputs and their data types?
    * What is the specified output and its data type?
    * What are the constraints and edge cases (e.g., empty inputs, negative numbers, large values, specific formats)?

2.  **Identify a Novel Test Scenario:**
    * Review the examples in the description (if any) and consciously choose a different category of input. Consider edge cases, corner cases, or common failure points for this type of problem. For example, if the description tests positive numbers, consider testing negative numbers, zero, or a mix.

3.  **Generate the Test Case (Show your work):**
    * **Step A: Propose the Input.** State the exact input you will use for your test.
    * **Step B: Reason Step-by-Step.** Walk through the logic required by the task description, applying it to your chosen input. Explain how you arrive at the final output. This reasoning is the most important part.
    * **Step C: State the Expected Output.** Clearly state the final, correct output based on your reasoning.

Provide your analysis and the final test case (Input, Reasoning, and Output).
""".strip()
    
    def get_test_prompt_runnable(self, task_description, plan=None):
        """Build the prompt that asks for a "smoke test" description.

        The requested test only checks that the function can be called with a
        simple valid input without raising; it does not check the result.

        Args:
            task_description: Task text the mock input is derived from.
            plan: Optional plan text; when truthy it is added as a ``Plan``
                section and mentioned in the instructions.

        Returns:
            The prompt text with surrounding whitespace stripped.
        """
        # The plan section and the wording that refers to it appear together.
        context_description = "task description and plan" if plan else "task description"
        plan_section = f"**Plan:**\n---\n{plan}\n---" if plan else ""

        return f"""
You are a build engineer creating a "smoke test". The goal is **not** to check for correctness, but simply to ensure a function is runnable (i.e., it can be called without crashing due to syntax errors or basic type mismatches).

**Task Description:**
---
{task_description}
---
{plan_section}
**Your Instructions:**

1.  **Analyze the Function Signature:** Based on the {context_description}, determine the expected data type and structure of the input arguments. For example, does it take a single integer, a list of strings, two arguments?

2.  **Propose a Mock Input:** Create the simplest possible, valid input that conforms to the signature you identified.

3.  **Describe the Test Logic:** Explain that the test involves calling the function with this mock input. The test passes if the function executes and returns *anything* without raising an exception. The actual return value does not matter for this specific test. Describe the necessary components for a test that would run the function and catch any potential errors.
""".strip()
    
    def _sample_test_case(self, test_cases, num_test):
        """Randomly pick up to ``num_test`` test-function source strings.

        Args:
            test_cases: Mapping whose values are test-case dicts; only those
                with a ``'test_function_string'`` key are candidates.
            num_test: Maximum number of strings to return.

        Returns:
            A list of sampled source strings (without replacement); empty when
            there are no candidates or ``num_test`` is not positive.
        """
        if not test_cases or num_test <= 0:
            return []

        pool = [
            record['test_function_string']
            for record in test_cases.values()
            if 'test_function_string' in record
        ]

        # The sample size is capped at the pool size.
        return random.sample(pool, min(num_test, len(pool))) if pool else []

    def _construct_prompt(self, plan, task_description, sampled_tests, use_task_description=True):
        """Assemble the code-generation prompt.

        Args:
            plan: Plan text to implement (always included).
            task_description: Task text, included only when
                ``use_task_description`` is true.
            sampled_tests: List of test-case source strings shown as examples;
                the section is omitted when the list is empty.
            use_task_description: Whether to add the task-description section.

        Returns:
            The prompt sections joined with newlines.
        """
        sections = [
            "You are an expert Python algorithm engineer. Your task is to generate a complete and runnable Python script based on the provided plan and context.",
            "Follow these instructions carefully:",
            "1. **Reasoning First**: Before writing any code, provide a step-by-step reasoning of your approach. Explain the chosen algorithms, data structures, and the logic for the main function. This thought process is critical.",
            "2. **Code Generation**: After your reasoning, provide the complete Python code in a single block. The code must be fully functional and self-contained.",
            "3. **Main Function**: The script MUST include a `main` function that serves as the entry point. This function must accept inputs and return outputs exactly as described in the plan, as it will be used for automated evaluation.",
            "4. **No Type Hints**: Do not use type hints from the `typing` module in your code.",
            "\n---\n",
            # The plan is the only mandatory context section.
            "## Plan to Implement\n" + plan,
        ]

        if use_task_description:
            sections.append("## Task Description\n" + task_description)

        if sampled_tests:
            examples = "\n".join(sampled_tests)
            sections.append(
                "## Example Test Cases\n"
                "Here are some example test cases to help you understand the required input/output format. Your solution should be able to pass these.\n"
                + examples
            )

        sections += [
            "\n---\n",
            "Now, begin with your reasoning, followed by the complete code.",
        ]

        return "\n".join(sections)

    def generate_code_with_reasoning(self, task_description, plans, num_codes, test_cases=None, num_test=0, use_task_description=True):
        sampled_tests = []
        if test_cases and num_test > 0:
            sampled_tests = self._sample_test_case(test_cases, num_test)

        prompt = self._construct_prompt(
            plans,
            task_description,
            use_task_description,
            sampled_tests
        )

        if num_codes <= 0:
            return []

        if num_codes == 1:
            try:
                # Use synchronous call for a single request
                response = self.LLM_model.LLM_response(prompt)
                return [response] if response else []
            except Exception as e:
                print(f"An error occurred during single LLM call: {e}")
                return []
        else:
            # Use asynchronous call for multiple requests for efficiency
            prompts_list = [prompt] * num_codes
            try:
                results, all_successful = self.LLM_model.LLM_response_async(prompts_list)
                if not all_successful:
                    print("Warning: Not all asynchronous LLM calls were successful.")
                
                # Extract successful responses
                successful_responses = [res[1] for res in results if res[0]]
                return successful_responses
            except Exception as e:
                print(f"An error occurred during asynchronous LLM calls: {e}")
                return []

    def _code_extract_prompt(self, raw_llm_output):
        """Build the prompt that extracts code and entry-point name as JSON.

        The LLM is asked to answer with a JSON object holding ``code_str`` and
        ``main_func_name`` (both empty strings when nothing can be found).

        Args:
            raw_llm_output: Text containing reasoning followed by code.

        Returns:
            The prompt text, returned as written (not stripped).
        """
        # Doubled braces are f-string escapes for the literal JSON braces.
        return f"""
You are a precise code parsing tool. Your task is to extract the complete Python code block and the name of the main function from the text provided below. The text contains reasoning followed by the code.

Respond ONLY with a JSON object in the following format:
{{
  "code_str": "...",
  "main_func_name": "..."
}}

- The value for "code_str" should be the entire, clean Python code as a single string. This includes all necessary imports and functions.
- The value for "main_func_name" should be the name of the main entry point function as a string.
- If you cannot find a valid Python code block or a main function, return a JSON object with empty strings for both values. Do not add any explanation.

--- TEXT TO PARSE ---
{raw_llm_output}
--- END OF TEXT ---
"""

    def code_extraction(self, stage_one_outputs, plan_id):
        """Extract code and entry-point name from raw stage-one outputs.

        Every raw output is sent to the LLM with the extraction prompt and the
        JSON reply is parsed into a record.

        Args:
            stage_one_outputs: List of raw LLM texts (reasoning plus code).
            plan_id: Identifier copied into every record.

        Returns:
            One dict per processed output with the keys ``code_str``,
            ``main_func_name``, ``reasoning`` (the raw output) and
            ``plan_id``. When the LLM call fails or its reply cannot be
            parsed, ``code_str`` and ``main_func_name`` are empty strings.
            An empty list is returned for empty input.
        """
        if not stage_one_outputs:
            return []

        def build_record(reasoning, code_str="", main_func_name=""):
            # The defaults describe a failed extraction.
            return {
                "code_str": code_str,
                "main_func_name": main_func_name,
                "reasoning": reasoning,
                "plan_id": plan_id,
            }

        prompts = [self._code_extract_prompt(output) for output in stage_one_outputs]

        try:
            results, _all_successful = self.LLM_model.LLM_response_async(prompts)
        except Exception as e:
            print(f"An error occurred during asynchronous LLM calls for extraction: {e}")
            # The whole batch failed: report a failure record for every input.
            return [build_record(original_output) for original_output in stage_one_outputs]

        extracted_data = []
        for original_output, (success, response, _) in zip(stage_one_outputs, results):
            if not success:
                # The LLM call itself failed for this item.
                extracted_data.append(build_record(original_output))
                continue

            try:
                if not isinstance(response, str):
                    raise TypeError("LLM response is not a string.")

                # Models frequently wrap JSON in a markdown fence; unwrap it.
                payload = response.strip()
                fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", payload, re.DOTALL | re.IGNORECASE)
                if fenced:
                    payload = fenced.group(1)

                parsed_json = json.loads(payload)
                if not isinstance(parsed_json, dict):
                    # Valid JSON, but not the requested object (e.g. a list).
                    raise TypeError("JSON response is not an object.")

                code_str = parsed_json.get("code_str", "")
                main_func_name = parsed_json.get("main_func_name", "")
                if not isinstance(code_str, str) or not isinstance(main_func_name, str):
                    raise TypeError("JSON values are not strings.")
            except (json.JSONDecodeError, TypeError) as e:
                print(f"Failed to parse LLM response for extraction: {e}")
                extracted_data.append(build_record(original_output))
            else:
                extracted_data.append(build_record(original_output, code_str, main_func_name))

        return extracted_data

    def generate_codes(self, num_codes, task_description, plan_id, plan, test_cases=None, num_test_cases=1, use_task_description=True):
        """Generate code candidates for a plan and return them keyed by UUID.

        Stage one asks the LLM for reasoning plus code; stage two extracts the
        code string and entry-point name from each raw answer.

        Args:
            num_codes: Number of candidates to request.
            task_description: Task text.
            plan_id: Identifier stored in every resulting record.
            plan: Plan text to implement.
            test_cases: Optional ``{id: test_case}`` mapping used to sample
                example tests for the prompt.
            num_test_cases: Maximum number of example tests to include.
            use_task_description: Whether the prompt includes the task text.

        Returns:
            A dict mapping a fresh UUID string to each extracted code record;
            an empty dict when stage one produced nothing.
        """
        # `stage_one_generator` is not set anywhere in this class, so fall back
        # to this instance, which implements `generate_code_with_reasoning`.
        # An externally assigned generator is still honoured.
        stage_one = getattr(self, "stage_one_generator", self)
        raw_outputs = stage_one.generate_code_with_reasoning(
            task_description=task_description,
            plans=plan,
            num_codes=num_codes,
            test_cases=test_cases,
            num_test=num_test_cases,
            use_task_description=use_task_description
        )

        # Nothing to extract from: return an empty mapping so callers always
        # receive the same container type.
        if not raw_outputs:
            print("Code Generation Warning: Stage One did not produce any output.")
            return {}

        # Stage two: parse and structure the raw outputs.
        extracted_codes = self.code_extraction(stage_one_outputs=raw_outputs, plan_id=plan_id)

        # Re-key the records by unique id.
        return {str(uuid.uuid4()): code_info for code_info in extracted_codes}
    
    def preprocess_codes(self, codes):
        """Add an AST dump and an embedding to every code record lacking one.

        Args:
            codes: Mapping of code id to record dict. Each record may contain
                ``code_str`` and, if already processed, ``ast`` and/or
                ``embedding``.

        Returns:
            A new mapping with copied records. Every record has an ``ast``
            entry (the ``ast.dump`` string, or an error message when the code
            cannot be parsed) and an ``embedding`` entry. The input mapping
            and its records are left unmodified.
        """
        # Copy each record as well as the outer mapping; a shallow copy of the
        # outer dict alone would still write into the caller's records.
        updated_codes = {code_id: dict(info) for code_id, info in codes.items()}

        # --- Step 1: parse missing ASTs and collect records needing embeddings ---
        ids_to_embed = []
        strings_to_embed = []

        for code_id, info in updated_codes.items():
            code_str = info.get('code_str', '')

            if 'ast' not in info:
                try:
                    # Store the AST as a serializable string.
                    info['ast'] = ast.dump(ast.parse(code_str))
                except (SyntaxError, ValueError):
                    # ValueError covers sources containing null bytes on
                    # Python versions that do not report them as SyntaxError.
                    info['ast'] = "Error: Invalid Python syntax"

            if 'embedding' not in info:
                ids_to_embed.append(code_id)
                strings_to_embed.append(code_str)

        # --- Step 2: one batched embedding call for all unprocessed codes ---
        if strings_to_embed:
            embeddings = self.LLM_model.Embedding_response(strings_to_embed)

            # --- Step 3: attach the embeddings in request order ---
            for code_id, embedding_vector in zip(ids_to_embed, embeddings):
                updated_codes[code_id]['embedding'] = embedding_vector

        return updated_codes
    
class CodeRunner:
    """Run functions against test cases in a separate Python process.

    The inputs are serialized to temporary JSON files, an external runner
    script is started with their paths, and the results are read back from
    specially prefixed lines on the script's standard output.
    """

    def __init__(self, max_workers=5):
        """Remember the default ``--max_workers`` value passed to the runner script."""
        self.max_workers = max_workers

    def run_code_runner_in_subprocess(self, functions, test_cases, max_workers=None, timeout=30, script_path=r"E:\python_project_new\AI4SLCDP\src\Test_runner.py"):
        """Execute the runner script and collect its results.

        Lines printed by the script are interpreted as follows:
        ``PROGRESS_TASK: <current>/<total>`` updates a progress bar (shown
        only when ``tqdm`` is installed), ``FUNCTION_RESULTS:<json>`` and
        ``TEST_RESULTS:<json>`` carry the final results.

        Args:
            functions: JSON-serializable description of the functions.
            test_cases: JSON-serializable description of the test cases.
            max_workers: Value for ``--max_workers``; defaults to the
                instance setting.
            timeout: Value for the script's ``--timeout`` argument.
            script_path: Path of the runner script.

        Returns:
            ``(func_results, test_results, output)`` where the first two are
            the decoded JSON payloads (empty dicts when the script printed
            none) and ``output`` is the script's combined stdout/stderr.

        Raises:
            TypeError: If an input is not JSON serializable.
            json.JSONDecodeError: If a result line is not valid JSON.
            OSError: If the subprocess cannot be started.
        """
        if max_workers is None:
            max_workers = self.max_workers

        # The progress bar is a convenience: run without it when tqdm is
        # not installed instead of failing on the first progress line.
        try:
            from tqdm import tqdm
        except ImportError:
            tqdm = None

        temp_paths = []
        output = []
        pbar = None
        total_tasks = None
        try:
            # Write the inputs to temporary files for the child process.
            for payload in (functions, test_cases):
                with tempfile.NamedTemporaryFile(mode='w', delete=False) as handle:
                    # Record the path first so the file is removed even when
                    # serialization fails.
                    temp_paths.append(handle.name)
                    json.dump(payload, handle)
            func_file_path, test_file_path = temp_paths

            cmd = [
                # Use the interpreter running this code so the child sees the
                # same environment; fall back to PATH lookup if unknown.
                sys.executable or "python", script_path,
                "--functions_file", func_file_path,
                "--test_cases_file", test_file_path,
                "--max_workers", str(max_workers),
                "--timeout", str(timeout)
            ]

            # The context manager closes the pipe and waits for the child.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1
            ) as process:
                for raw_line in process.stdout:
                    line = raw_line.strip()
                    output.append(line)

                    if not line.startswith("PROGRESS_TASK: "):
                        continue

                    # Parse the "<current>/<total>" progress data.
                    progress_part = line.split("PROGRESS_TASK: ")[1]
                    try:
                        current, total = map(int, progress_part.split('/'))
                    except ValueError:
                        continue

                    if tqdm is None:
                        continue

                    # Create the bar on the first valid progress line.
                    # `is None` is used because a bar's truth value depends
                    # on its total.
                    if pbar is None:
                        total_tasks = total
                        pbar = tqdm(
                            total=total,
                            desc="Testing Progress",
                            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}",
                            ascii=True
                        )

                    # Update the position, clamped to the valid range.
                    current = max(0, min(current, total_tasks))
                    if pbar.n != current:
                        pbar.n = current
                        pbar.refresh()

            # Show the bar as complete once the process has finished.
            if pbar is not None:
                pbar.n = total_tasks
                pbar.refresh()

            # Parse the final results from the captured output.
            func_results = {}
            test_results = {}
            for line in output:
                if line.startswith("FUNCTION_RESULTS:"):
                    func_results = json.loads(line[len("FUNCTION_RESULTS:"):])
                elif line.startswith("TEST_RESULTS:"):
                    test_results = json.loads(line[len("TEST_RESULTS:"):])

            return func_results, test_results, "\n".join(output)
        finally:
            if pbar is not None:
                pbar.close()
            # Remove the temporary files on success and on failure alike.
            for path in temp_paths:
                try:
                    os.unlink(path)
                except OSError as e:
                    print(f"Error cleaning temp files: {e}")

    def run_all_tests(self, functions, test_cases, max_workers=None, timeout=5):
        """Run the subprocess runner and return only the two result payloads.

        Returns:
            ``(func_results, test_results)`` as produced by
            ``run_code_runner_in_subprocess``; the raw output is discarded.
        """
        if max_workers is None:
            max_workers = self.max_workers
        fr, tr, _ = self.run_code_runner_in_subprocess(functions, test_cases, max_workers, timeout)
        return fr, tr