import asyncio
import json
import argparse
import os
import sys
import tempfile
import time
import re # For function name extraction and LCDP placeholder
from LCDP import LCDP
import random

import re

# def get_function_name_from_code(code_string):
#     match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", code_string)
#     if match:
#         # The first captured group is the function name
#         return match.group(1)
#     return None

# def extract_mbpp_function_name(prompt, code, test_list):
#     # The most reliable way to get the function name is from the 'code' string.
#     if code:
#         function_name = get_function_name_from_code(code)
#         if function_name:
#             return function_name
#     return None

def get_all_function_names_from_code(code_string):
    """Return every function name introduced by ``def`` in *code_string*.

    The scan is a plain regex search for ``def <identifier>(`` — nested
    functions and methods are included, duplicates are kept, and names come
    back in the order they appear in the source.
    """
    def_pattern = re.compile(r"def\s+([a-zA-Z_]\w*)\s*\(")
    return [hit.group(1) for hit in def_pattern.finditer(code_string)]

def extract_mbpp_function_name(prompt, code, test_list):
    """Return the name of the function an MBPP task expects, or ``None``.

    The reference solution in *code* is treated as the source of truth:

    * exactly one ``def`` in *code* -> that name is returned;
    * several ``def``s (helpers plus the entry point) -> the first defined
      name that is actually called in *test_list* is returned, scanning the
      tests in order and the calls inside each test left to right.

    Args:
        prompt: Task text. Not used by the current heuristic; kept so the
            call signature stays stable for callers.
        code: Reference solution source (may be empty or ``None``).
        test_list: Iterable of test-assertion strings (may be empty or
            ``None``).

    Returns:
        The function name, or ``None`` when it cannot be determined.
    """
    if not code:
        return None  # Cannot determine function name without code

    defined_function_names = get_all_function_names_from_code(code)
    if not defined_function_names:
        return None
    if len(defined_function_names) == 1:
        return defined_function_names[0]

    # Multiple functions are defined: pick the one the tests actually call.
    defined_funcs_set = set(defined_function_names)
    call_pattern = re.compile(r"\b([a-zA-Z_]\w*)\s*\(")
    # ``test_list`` may be None when a dataset record stores an explicit null;
    # treat that as "no tests" instead of crashing with a TypeError.
    for test_entry in test_list or ():
        for func_candidate in call_pattern.findall(test_entry):
            if func_candidate in defined_funcs_set:
                # Called in a test and defined in the reference code block.
                return func_candidate
    return None

def load_mbpp_dataset(file_path):
    """
    Loads the MBPP dataset from a JSONL (or JSON) file.

    Accepted layouts:
      * JSON Lines: one problem object per line; blank lines are ignored and
        malformed lines are reported (with their line number) and skipped.
      * A single JSON array of problems, on one line or pretty-printed.
      * A single JSON object (possibly pretty-printed), treated as a
        one-problem dataset.

    Exits the process with status 1 if the file is missing or unreadable.

    Returns:
        list: The loaded problem records.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            raw_text = f.read()
    except FileNotFoundError:
        print(f"Error: MBPP dataset file not found at {file_path}")
        sys.exit(1)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error: Could not read MBPP dataset file {file_path}: {e}")
        sys.exit(1)

    # First try the whole file as one JSON document; a multi-line JSONL file
    # fails here ("Extra data") and falls through to line-by-line parsing.
    try:
        whole_document = json.loads(raw_text)
    except json.JSONDecodeError:
        whole_document = None

    if isinstance(whole_document, list):
        dataset = whole_document
    elif isinstance(whole_document, dict):
        dataset = [whole_document]
    else:
        dataset = []
        # split("\n") rather than splitlines(): JSON strings may legally
        # contain raw U+2028/U+2029, which splitlines() would treat as breaks.
        for line_no, line in enumerate(raw_text.split("\n"), start=1):
            stripped = line.strip()
            if not stripped:
                continue  # blank lines (e.g. a trailing newline) are not errors
            try:
                dataset.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                print(f"Warning: Skipping invalid JSON line {line_no} in {file_path}: {e}")
        # Handle a single line that contains the whole list of problems.
        if len(dataset) == 1 and isinstance(dataset[0], list):
            dataset = dataset[0]

    print(f"Successfully loaded {len(dataset)} problems from {file_path}")
    return dataset

def _indent_test_snippet(test_case_str, indent="        "):
    """Indent every line of a test snippet so it nests inside the generated ``try:`` body."""
    return "\n".join(indent + line for line in str(test_case_str).split("\n"))


def _build_mbpp_test_script(generated_code_str, test_imports, test_list, success_marker):
    """Return the source of a standalone script that runs *test_list* against the code.

    Layout: the ``test_imports`` lines, the candidate code, then a
    ``__run_mbpp_tests__`` function that executes each test in its own
    ``try`` block. Only when every test passes does the script write
    *success_marker* (plus a newline) straight to file descriptor 1;
    otherwise it exits with status 1.
    """
    parts = [f"{imp}\n" for imp in (test_imports or [])]
    parts.append("\n")
    parts.append(generated_code_str + "\n\n")
    parts.append("def __run_mbpp_tests__():\n")
    parts.append("    __test_results__ = []\n")
    parts.append("    __all_passed__ = True\n")  # Flag to track overall success
    for test_case_str in test_list:
        parts.append("    try:\n")
        parts.append(_indent_test_snippet(test_case_str) + "\n")
        parts.append("        __test_results__.append(True)\n")
        parts.append("    except AssertionError:\n")
        parts.append("        __test_results__.append(False)\n")
        parts.append("        __all_passed__ = False\n")
        parts.append("    except Exception:\n")
        parts.append("        __test_results__.append(False)\n")
        parts.append("        __all_passed__ = False\n")
    marker_bytes = (success_marker + "\n").encode("ascii")
    parts.append("    if __all_passed__:\n")
    # os.write on fd 1 bypasses any print()/sys.stdout rebinding done by the candidate code.
    parts.append(f"        __import__('os').write(1, {marker_bytes!r})\n")
    parts.append("    else:\n")
    parts.append("        raise SystemExit(1)\n")
    parts.append("\n__run_mbpp_tests__()\n")
    return "".join(parts)


def _kill_process_quietly(process):
    """Kill *process* if it is still running, tolerating the race where it just exited."""
    if process.returncode is None:
        try:
            process.kill()
        except ProcessLookupError:
            pass


async def execute_code_with_tests(generated_code_str, test_imports, test_list, timeout_seconds=15):
    """
    Executes the generated code against the provided test cases in an isolated subprocess.

    Args:
        generated_code_str: Candidate Python source to test.
        test_imports: Iterable of import lines placed before the code
            (``None`` is treated as empty).
        test_list: Iterable of test snippets (typically ``assert ...`` lines);
            each runs in its own ``try`` block inside the child script.
        timeout_seconds: Wall-clock limit for the whole child process.

    Returns:
        bool: True only if the child exited with status 0 *and* emitted this
        run's success marker, i.e. every test actually ran and passed.
        False on any test failure, crash, or timeout.
    """
    # A fresh random marker per run cannot be forged by the candidate code, so an
    # early exit(0) / sys.exit(0) / unittest.main() that terminates the child
    # before the tests run is no longer mistaken for success.
    success_marker = f"__MBPP_ALL_TESTS_PASSED_{os.urandom(16).hex()}__"
    script_content = _build_mbpp_test_script(generated_code_str, test_imports, test_list, success_marker)

    tmp_script_name = None
    process = None
    try:
        # delete=False: the subprocess opens the file by name after we close it;
        # it is removed manually in the finally block.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding='utf-8') as tmpf:
            # Record the name before writing so a failed write does not leak the file.
            tmp_script_name = tmpf.name
            tmpf.write(script_content)

        process = await asyncio.create_subprocess_exec(
            sys.executable, tmp_script_name,
            # Candidate code calling input() must not read (or block on) our stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout, _stderr = await asyncio.wait_for(process.communicate(), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            print(f"Execution timed out after {timeout_seconds} seconds for script {tmp_script_name}.")
            _kill_process_quietly(process)
            await process.wait()  # Reap the killed child
            return False  # Indicate failure due to timeout

        return process.returncode == 0 and success_marker.encode("ascii") in (stdout or b"")

    finally:
        # Also covers cancellation and unexpected errors: never leave the child running.
        if process is not None and process.returncode is None:
            _kill_process_quietly(process)
        if tmp_script_name and os.path.exists(tmp_script_name):
            os.remove(tmp_script_name)


async def evaluate_single_problem(lcdp_instance, problem_data, lcdp_run_params, problem_idx, total_problems):
    """
    Evaluates a single MBPP problem using the LCDP framework.

    The prompt sent to LCDP is the original MBPP prompt, plus the expected
    function name (only when it can be recovered from the reference solution)
    and up to three of the problem's test assertions. The first code entry
    returned by ``lcdp_instance.run`` is then executed against the full
    ``test_list``.

    Args:
        lcdp_instance: Object with an async ``run(task_description=..., **params)``
            returning a dict whose values are dicts with a ``'code'`` key.
        problem_data: One MBPP record (``task_id``, ``prompt``, ``code``,
            ``test_list``, ``test_imports``).
        lcdp_run_params: Keyword arguments forwarded to ``lcdp_instance.run``;
            its ``'test_timeout'`` entry also sizes the execution timeout here.
        problem_idx: Zero-based index, used for progress output and as a
            fallback task id.
        total_problems: Total number of problems, for progress output.

    Returns:
        tuple: ``(task_id, passed, details, prompt_to_lcdp)`` where *details*
        is ``"Passed"`` on success, the failing code when tests fail, or an
        error message.
    """
    import traceback

    task_id = problem_data.get('task_id', f"task_idx_{problem_idx}")
    original_prompt = problem_data.get('prompt', "")
    test_list = problem_data.get('test_list', [])
    # ``or []`` also covers records that store an explicit null.
    test_imports = problem_data.get('test_imports') or []
    gt_code = problem_data.get('code', "")

    # Recover the expected function name once; both prompts below use it.
    func_name = extract_mbpp_function_name(original_prompt, gt_code, test_list)
    # Only mention the name when it was recovered; otherwise the prompt would
    # instruct the model to name its function `None`.
    name_clause = f"The function should be named `{func_name}`." if func_name else ""

    if original_prompt:
        enhanced_task_description_fallback = (
            f"{original_prompt}\n\nYour goal is to implement the Python function described above."
            + (f" {name_clause}" if name_clause else "")
        )
    else:
        enhanced_task_description_fallback = "Error: Original prompt was empty or missing for this task."

    print(f"Processing problem {problem_idx+1}/{total_problems} (Task ID: {task_id})...")

    if not test_list:
        print(f"Warning: No test cases found for Task ID {task_id}. Skipping.")
        return task_id, False, "No test cases provided", enhanced_task_description_fallback

    # --- Enhance the prompt for LCDP ---
    try:
        parts = [
            f"{original_prompt}\n\n",
            "Your goal is to implement the Python function described above.",
        ]
        if name_clause:
            parts.append(f" {name_clause}")
        parts.append("\n\n")
        # test_list is non-empty here; show at most three example assertions.
        parts.append("Here are some example assertions the function should satisfy:\n")
        parts.extend(f"- `{test_case_str}`\n" for test_case_str in test_list[:3])
        parts.append("\nPlease provide the complete Python code for this function.")
        enhanced_task_description = "".join(parts)
    except Exception as prompt_exc:
        print(f"Task ID {task_id}: Error during prompt enhancement: {prompt_exc}")
        traceback.print_exc()
        return task_id, False, f"Prompt enhancement error: {prompt_exc}", enhanced_task_description_fallback
    # --- End of prompt enhancement ---

    try:
        best_codes_dict = await lcdp_instance.run(
            task_description=enhanced_task_description,  # Use the enhanced prompt
            **lcdp_run_params
        )

        if not isinstance(best_codes_dict, dict) or not best_codes_dict:
            print(f"Task ID {task_id}: LCDP returned no or invalid codes.")
            return task_id, False, "LCDP returned no/invalid codes", enhanced_task_description

        try:
            first_code_id = next(iter(best_codes_dict))
            selected_code_str = best_codes_dict[first_code_id]['code']
        except (StopIteration, KeyError, TypeError) as e:
            print(f"Task ID {task_id}: Error extracting code from LCDP output: {e}. Output: {best_codes_dict}")
            return task_id, False, f"Error extracting code: {e}", enhanced_task_description

        if not selected_code_str:
            print(f"Task ID {task_id}: No code string found in LCDP output.")
            return task_id, False, "No code string from LCDP", enhanced_task_description

        # An explicit None for test_timeout falls back to the default instead of
        # crashing on ``None + 5``; 0 is respected as given.
        test_timeout = lcdp_run_params.get('test_timeout')
        if test_timeout is None:
            test_timeout = 10

        passed_all_tests = await execute_code_with_tests(
            selected_code_str,
            test_imports,
            test_list,  # Pass the full test list for execution
            timeout_seconds=test_timeout + 5
        )

        status_msg = "Passed" if passed_all_tests else "Failed"
        print(f"Task ID {task_id}: {status_msg} tests.")

        details = "Passed" if passed_all_tests else selected_code_str
        return task_id, passed_all_tests, details, enhanced_task_description

    except Exception as e:
        print(f"Task ID {task_id}: Error during LCDP run or code evaluation: {e}")
        traceback.print_exc()
        return task_id, False, f"Evaluation error: {e}", enhanced_task_description


async def main(args):
    """Evaluate LCDP on (a random sample of) MBPP and report Pass@1.

    Problems are processed one at a time. A summary is printed to stdout and,
    when ``args.results_output_path`` is set, a JSON report containing the
    summary, the LCDP run parameters and per-problem records is written.
    """
    problems = load_mbpp_dataset(args.mbpp_dataset_path)
    if not problems:
        return

    limit = args.limit_problems
    if limit is not None and limit > 0:
        problems = random.sample(problems, min(limit, len(problems)))
        print(f"Limiting evaluation to {len(problems)} randomly selected problems.")

    lcdp = LCDP(
        api_key=args.api_key,
        model=args.model,
        max_workers=args.lcdp_max_workers,
        ignore_advice=True,
        use_pr_predictor=False,
    )

    run_params = dict(
        max_iterations=args.max_iterations,
        num_plans=args.num_plans,
        num_tests=args.num_tests,  # LCDP's internal test generation
        num_codes=args.num_codes,
        refine_rounds=args.refine_rounds,
        use_pass_rate_for_train=args.use_pass_rate_for_train,
        test_timeout=args.test_timeout,  # LCDP's internal test execution timeout
        use_async_generation=False,
        best_only=args.best_only,
    )

    total = len(problems)
    outcomes = []
    started_at = time.time()

    # Sequential processing. evaluate_single_problem handles its own errors;
    # anything that still escapes is recorded as the exception object itself.
    for idx, problem in enumerate(problems):
        try:
            result = await evaluate_single_problem(lcdp, problem, run_params, idx, total)
        except Exception as exc:
            failed_id = problem.get('task_id', f"unknown_task_at_idx_{idx}")
            print(f"Task ID {failed_id}: A critical unexpected exception occurred during its sequential processing: {exc}")
            import traceback
            traceback.print_exc()
            result = exc
        outcomes.append(result)

    finished_at = time.time()

    passed_count = 0
    records = []
    for idx, outcome in enumerate(outcomes):
        entry = problems[idx] if idx < len(problems) else {}
        context_id = entry['task_id'] if 'task_id' in entry else f"unknown_task_at_idx_{idx}"

        if isinstance(outcome, Exception):
            print(f"Task ID {context_id}: An unexpected exception was caught and recorded for this problem: {outcome}")
            records.append({
                "task_id": context_id,
                "passed": False,
                "error": str(outcome),
                "prompt_to_lcdp": "Error during problem evaluation; prompt details may be missing or incomplete.",
            })
            continue

        if not (isinstance(outcome, tuple) and len(outcome) == 4):
            print(f"Task ID {context_id}: Received an unexpected outcome format: {outcome}")
            records.append({
                "task_id": context_id,
                "passed": False,
                "error": "Unexpected outcome format from evaluation.",
                "details": str(outcome),
                "prompt_to_lcdp": "Unknown due to unexpected outcome.",
            })
            continue

        # Regular result: (task_id, passed, code_or_msg, enhanced_prompt)
        tid, passed, details, prompt_used = outcome
        records.append({
            "task_id": tid,
            "passed": passed,
            "details": details,
            "prompt_to_lcdp": prompt_used,
        })
        if passed:
            passed_count += 1

    elapsed = finished_at - started_at
    pass_at_1 = (passed_count / total) * 100 if total > 0 else 0

    print("\n--- Evaluation Summary ---")
    print(f"Total problems evaluated: {total}")
    print(f"Problems passed: {passed_count}")
    print(f"Pass@1 Score: {pass_at_1:.2f}%")
    print(f"Total evaluation time: {elapsed:.2f} seconds")
    if total > 0:
        print(f"Average time per problem: {elapsed / total:.2f} seconds")

    out_path = args.results_output_path
    if not out_path:
        return

    report = {
        "summary": {
            "total_problems": total,
            "passed_problems": passed_count,
            "pass_at_1_percentage": pass_at_1,
            "total_time_seconds": elapsed,
            "average_time_per_problem_seconds": elapsed / total if total > 0 else 0,
        },
        "lcdp_params": run_params,
        "detailed_results": records,
    }
    try:
        with open(out_path, 'w', encoding='utf-8') as fh:
            json.dump(report, fh, indent=4)
        print(f"Detailed results saved to {out_path}")
    except Exception as exc:
        print(f"Error saving results to {out_path}: {exc}")


def parse_bool_flag(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MBPP Evaluation Pipeline for LCDP Framework (Enhanced - Sequential Execution)")

    parser.add_argument("--mbpp_dataset_path", type=str, required=True, help="Path to the MBPP dataset JSONL file.")
    parser.add_argument("--results_output_path", type=str, default=None, help="Optional path to save detailed results in JSON format.")
    # main() draws a random sample (random.sample), not a prefix of the dataset.
    parser.add_argument("--limit_problems", type=int, default=None, help="Limit evaluation to N randomly sampled problems from the dataset (use --random_seed for reproducibility).")

    parser.add_argument("--api_key", type=str, default=os.environ.get("LCDP_API_KEY", os.environ.get("OPENAI_API_KEY")), help="API key for LCDP. Defaults to LCDP_API_KEY or OPENAI_API_KEY env var.")
    parser.add_argument("--model", type=str, default="gpt-4o-mini", help="Model name for LCDP.")
    parser.add_argument("--lcdp_max_workers", type=int, default=20, help="Max workers for LCDP's internal operations (if applicable).")

    parser.add_argument("--max_iterations", type=int, default=3, help="max_iterations for LCDP run.")
    parser.add_argument("--num_plans", type=int, default=3, help="num_plans for LCDP run.")
    parser.add_argument("--num_tests", type=int, default=5, help="num_tests for LCDP run (internal test generation).")
    parser.add_argument("--num_codes", type=int, default=5, help="num_codes for LCDP run.")
    parser.add_argument("--refine_rounds", type=int, default=3, help="refine_rounds for LCDP run.")
    parser.add_argument("--use_pass_rate_for_train", action='store_true', help="Set use_pass_rate_for_train to True for LCDP run.")
    parser.add_argument("--test_timeout", type=int, default=10, help="test_timeout (seconds) for LCDP's internal tests.")
    parser.add_argument("--random_seed", type=int, default=None, help="Random seed for reproducibility.")
    # type=bool made "--best_only False" evaluate to True. parse_bool_flag parses
    # the text; nargs="?" + const=True also lets a bare "--best_only" mean True,
    # while "--best_only True" keeps working as before.
    parser.add_argument("--best_only", type=parse_bool_flag, nargs="?", const=True, default=False, help="Only use best code for code refinement.")
    # NOTE(review): parsed but not included in the LCDP run parameters that main() builds.
    parser.add_argument("--error_test_num", type=int, default=False, help="number of failed test sampled in code refinment prompt.")

    args = parser.parse_args()

    if args.random_seed is not None:
        random.seed(args.random_seed)
        print(f"Using random seed: {args.random_seed}")

    if not args.api_key:
        print("Warning: API key not provided via --api_key argument or relevant environment variables (LCDP_API_KEY, OPENAI_API_KEY). LCDP might require it.")

    asyncio.run(main(args))