import asyncio
import json
import argparse
import os
import sys
import tempfile
import time
import re # For function name extraction and LCDP placeholder
from LCDP import LCDP
from tqdm.asyncio import tqdm as async_tqdm

def extract_function_name(prompt_text, test_list):
    """
    Extracts the Python function name from the problem's prompt or test list.

    The strategies are tried in order and the first hit wins:
      1. a ``def name(`` signature embedded in the prompt;
      2. a quoted/backticked name after "function" or "function named";
      3. the first call in the first test case that is neither a Python
         keyword, an attribute call (``math.isclose(``) nor a common wrapper
         such as ``set(`` / ``sorted(``.

    Args:
        prompt_text: Natural-language task description (``None`` is treated
            as an empty prompt).
        test_list: List of test-case strings; only the first entry is
            inspected, and only if it is a string.

    Returns:
        The inferred function name, or ``"solution_function"`` when nothing
        plausible is found.
    """
    # Function-scope import keeps this block self-contained.
    import keyword

    # Calls that usually wrap the function under test inside an assertion
    # rather than being the function itself.
    wrapper_call_names = frozenset({
        'list', 'dict', 'tuple', 'set', 'frozenset', 'str', 'int', 'float',
        'bool', 'len', 'print', 'range', 'sorted', 'abs', 'round',
        'isinstance', 'all', 'any',
    })

    prompt_text = prompt_text or ""

    # Attempt 1: Look for "def function_name(" or "def function_name ("
    match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", prompt_text)
    if match:
        return match.group(1)

    # Attempt 2: Look for "function named `function_name`", "function 'function_name'", etc.
    # This also covers "Write a function `func_name`": any text matched by a
    # dedicated "Write|Create|Define ... function `x`" pattern necessarily
    # contains "function `x`", so a separate attempt for it would be dead code.
    match = re.search(r"function\s+(?:named\s+)?(?:`|'|\")([a-zA-Z_]\w*)(?:`|'|\")", prompt_text)
    if match:
        return match.group(1)

    # Attempt 3: Infer from the first test case, e.g. "assert func_name(...) == ...".
    if isinstance(test_list, list) and test_list and isinstance(test_list[0], str):
        # The lookbehind rejects attribute calls (``math.isclose(``) and stops
        # the regex from matching the tail of a longer identifier.
        for call in re.finditer(r"(?<![\w.])([a-zA-Z_]\w*)\s*\(", test_list[0]):
            candidate = call.group(1)
            # Skip keywords ("assert (f(x))", "not (f(x))") and wrapper calls
            # ("set(f(x))") and keep scanning for the wrapped function.
            if keyword.iskeyword(candidate) or candidate in wrapper_call_names:
                continue
            return candidate

    return "solution_function"  # Default if no name can be reliably extracted


def load_mbpp_dataset(file_path):
    """
    Loads the MBPP dataset from a JSONL file.

    Each non-blank line is parsed as one JSON document. Lines that fail to
    parse are reported and skipped. If the file consists of a single line
    holding a JSON array, that array is returned as the dataset.

    Args:
        file_path: Path to the dataset file (read as UTF-8).

    Returns:
        A list with one entry per problem.

    Exits the process with status 1 if ``file_path`` does not exist.
    """
    dataset = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not malformed records.
                    continue
                try:
                    dataset.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping invalid JSON line in {file_path}: {e}")
    except FileNotFoundError:
        print(f"Error: MBPP dataset file not found at {file_path}")
        sys.exit(1)

    # Unwrap only when the single parsed line is itself an array of problems;
    # a JSONL file with exactly one problem must stay a one-element list.
    if len(dataset) == 1 and isinstance(dataset[0], list):
        dataset = dataset[0]
    print(f"Successfully loaded {len(dataset)} problems from {file_path}")
    return dataset

def _as_assert_statement(test_case_str):
    """Return the test case as one ``assert`` statement, adding the keyword only if missing."""
    statement = str(test_case_str).strip()
    # ``\b`` keeps names such as ``assert_equal(...)`` from counting as the keyword.
    if re.match(r"assert\b", statement):
        return statement
    return f"assert {statement}"


def _build_mbpp_test_script(generated_code_str, test_imports, test_list):
    """Assemble the standalone script: test imports, candidate code, then the test runner."""
    lines = [str(imp) for imp in (test_imports or [])]
    lines += ["", generated_code_str, "", ""]
    lines += [
        "def __run_mbpp_tests__():",
        "    __all_passed__ = True",  # Flag to track overall success
    ]
    for test_case_str in test_list or []:
        lines += [
            "    try:",
            f"        {_as_assert_statement(test_case_str)}",
            # AssertionError is an Exception, so one handler covers both a
            # failed assertion and a crash inside the candidate code.
            "    except Exception:",
            "        __all_passed__ = False",
        ]
    lines += [
        # SystemExit is raised directly instead of calling the ``exit`` helper,
        # which only exists when the ``site`` module has been loaded.
        "    raise SystemExit(0 if __all_passed__ else 1)",
        "",
        "__run_mbpp_tests__()",
        "",
    ]
    return "\n".join(lines)


async def execute_code_with_tests(generated_code_str, test_imports, test_list, timeout_seconds=15):
    """
    Executes the generated code against the provided test cases in an isolated subprocess.

    A temporary script is written containing the test imports, the generated
    code and one guarded assertion per test case, then run with the current
    interpreter. Test cases may be given either as full ``assert ...``
    statements or as bare expressions.

    Args:
        generated_code_str: Source code of the candidate solution.
        test_imports: Import statements placed at the top of the script
            (``None`` is treated as empty).
        test_list: Test cases, each a string.
        timeout_seconds: Maximum time allowed for the subprocess.

    Returns:
        True if the subprocess exits with status 0 (every test passed);
        False on any test failure, crash, timeout or launch error.
    """
    script_content = _build_mbpp_test_script(generated_code_str, test_imports, test_list)

    tmp_script_name = None
    process = None

    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding='utf-8') as tmpf:
            # Record the name before writing so the file is still cleaned up
            # if the write itself fails (e.g. unencodable characters).
            tmp_script_name = tmpf.name
            tmpf.write(script_content)

        process = await asyncio.create_subprocess_exec(
            sys.executable, tmp_script_name,
            # Do not let candidate code read (or block on) the parent's stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        # Output is drained so the child cannot block on a full pipe; only the
        # exit status matters for the verdict.
        await asyncio.wait_for(process.communicate(), timeout=timeout_seconds)
        return process.returncode == 0

    except asyncio.TimeoutError:
        return False
    except Exception as e:
        print(f"Error running test script ({tmp_script_name or 'N/A'}): {e}", file=sys.stderr)
        return False
    finally:
        if process is not None and process.returncode is None:
            # kill() rather than terminate(): the candidate code may ignore
            # SIGTERM, which would make the following wait() hang.
            try:
                process.kill()
                await process.wait()
            except ProcessLookupError:
                pass
            except Exception as term_exc:
                print(f"Error during process termination: {term_exc}")
        if tmp_script_name:
            try:
                os.remove(tmp_script_name)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not change the verdict.
                pass


async def evaluate_single_problem(lcdp_instance, problem_data, lcdp_run_params, problem_idx, total_problems):
    """
    Evaluates a single MBPP problem using the LCDP framework.
    The prompt to LCDP is enhanced with function name and example tests.

    Args:
        lcdp_instance: Object whose awaitable ``run(task_description=..., **params)``
            returns a mapping of code id -> entry holding a ``'code'`` string.
        problem_data: Problem record; ``'task_id'`` and ``'prompt'`` are
            required, ``'test_imports'`` and ``'test_list'`` are optional.
        lcdp_run_params: Keyword arguments forwarded to ``lcdp_instance.run``;
            its ``'test_timeout'`` (default 10) plus 5 seconds is used as the
            local test timeout.
        problem_idx: Zero-based index of the problem (for progress output).
        total_problems: Total number of problems (for progress output).

    Returns:
        A tuple ``(task_id, passed, details, prompt)`` where ``details`` is
        ``"Passed"``, the failing code, or an error message, and ``prompt`` is
        the enhanced prompt (``None`` when the problem has no test cases).
    """
    task_id = problem_data['task_id']
    original_prompt = problem_data['prompt']
    test_imports = problem_data.get('test_imports') or []
    test_list = problem_data.get('test_list') or []

    print(f"Processing problem {problem_idx+1}/{total_problems} (Task ID: {task_id})...")

    if not test_list:
        print(f"Warning: No test cases found for Task ID {task_id}. Skipping.")
        return task_id, False, "No test cases provided", None

    # --- Enhance the prompt for LCDP ---
    extracted_func_name = extract_function_name(original_prompt, test_list)

    prompt_parts = [
        f"{original_prompt}\n\n",
        "Your goal is to implement the Python function described above. ",
        f"The function should be named `{extracted_func_name}`.\n\n",
        "Here are some example assertions the function should satisfy:\n",
    ]
    # Show up to 3 example tests
    for test_case_str in test_list[:3]:
        example = str(test_case_str).strip()
        # Test cases may already be full ``assert`` statements; add the
        # keyword only when missing so the prompt never shows "assert assert".
        if not re.match(r"assert\b", example):
            example = f"assert {example}"
        prompt_parts.append(f"- `{example}`\n")
    prompt_parts.append("\nPlease provide the complete Python code for this function.")
    enhanced_task_description = "".join(prompt_parts)
    # --- End of prompt enhancement ---

    try:
        best_codes_dict = await lcdp_instance.run(
            task_description=enhanced_task_description,  # Use the enhanced prompt
            **lcdp_run_params
        )

        if not isinstance(best_codes_dict, dict) or not best_codes_dict:
            print(f"Task ID {task_id}: LCDP returned no or invalid codes.")
            return task_id, False, "LCDP returned no/invalid codes", enhanced_task_description

        # For pass@1, select the top-scoring code.
        # Assumes the first entry is the highest-scoring or desired one.
        try:
            first_code_id = next(iter(best_codes_dict))
            selected_code_str = best_codes_dict[first_code_id]['code']
        except (StopIteration, KeyError, TypeError) as e:
            print(f"Task ID {task_id}: Error extracting code from LCDP output: {e}. Output: {best_codes_dict}")
            return task_id, False, f"Error extracting code: {e}", enhanced_task_description

        if not selected_code_str:
            print(f"Task ID {task_id}: No code string found in LCDP output.")
            return task_id, False, "No code string from LCDP", enhanced_task_description

        passed_all_tests = await execute_code_with_tests(
            selected_code_str,
            test_imports,
            test_list,  # Pass the full test list for execution
            timeout_seconds=lcdp_run_params.get('test_timeout', 10) + 5
        )

        status_msg = "Passed" if passed_all_tests else "Failed"
        print(f"Task ID {task_id}: {status_msg} tests.")

        # Keep the failing code in the result so it can be inspected later.
        details = "Passed" if passed_all_tests else selected_code_str
        return task_id, passed_all_tests, details, enhanced_task_description

    except Exception as e:
        # Boundary handler: one failing problem must not abort the whole run.
        print(f"Task ID {task_id}: Error during evaluation: {e}")
        import traceback
        traceback.print_exc()
        return task_id, False, f"Evaluation error: {e}", enhanced_task_description


async def main(args):
    """
    Runs the full evaluation: load the dataset, evaluate every problem
    concurrently (bounded by a semaphore), print a summary and optionally
    write detailed results as JSON.

    Args:
        args: Parsed command-line namespace. Reads ``mbpp_dataset_path``,
            ``limit_problems``, ``api_key``, ``model``, ``lcdp_max_workers``,
            ``max_iterations``, ``num_plans``, ``num_tests``, ``num_codes``,
            ``refine_rounds``, ``use_pass_rate_for_train``, ``test_timeout``,
            ``max_concurrent_problems`` and ``results_output_path``.
    """
    mbpp_dataset = load_mbpp_dataset(args.mbpp_dataset_path)
    if not mbpp_dataset:
        return

    if args.limit_problems is not None and args.limit_problems > 0:
        mbpp_dataset = mbpp_dataset[:args.limit_problems]
        print(f"Limiting evaluation to the first {len(mbpp_dataset)} problems.")

    lcdp_instance = LCDP(
        api_key=args.api_key,
        model=args.model,
        max_workers=args.lcdp_max_workers,
        ignore_advice=True,
    )

    lcdp_run_params = {
        "max_iterations": args.max_iterations,
        "num_plans": args.num_plans,
        "num_tests": args.num_tests,  # This is for LCDP's internal test generation
        "num_codes": args.num_codes,
        "refine_rounds": args.refine_rounds,
        "use_pass_rate_for_train": args.use_pass_rate_for_train,
        "test_timeout": args.test_timeout,  # LCDP's internal test execution timeout
        "use_async_generation": False,
    }

    # Semaphore to control concurrency of problem processing (LCDP runs + local test executions).
    # A value of 0 would block every task forever and a negative value raises, so clamp to 1.
    max_concurrent = max(1, args.max_concurrent_problems)
    if max_concurrent != args.max_concurrent_problems:
        print(f"Warning: --max_concurrent_problems must be >= 1; using {max_concurrent}.")
    problem_processing_semaphore = asyncio.Semaphore(max_concurrent)

    async def run_with_semaphore(semaphore, coro):
        async with semaphore:
            return await coro

    num_total = len(mbpp_dataset)
    tasks = [
        run_with_semaphore(
            problem_processing_semaphore,
            evaluate_single_problem(lcdp_instance, problem_data, lcdp_run_params, i, num_total),
        )
        for i, problem_data in enumerate(mbpp_dataset)
    ]

    # perf_counter is monotonic, so the duration is immune to wall-clock adjustments.
    start_time = time.perf_counter()
    evaluation_outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    total_time_taken = time.perf_counter() - start_time

    num_passed = 0
    detailed_results = []

    for i, outcome in enumerate(evaluation_outcomes):
        # gather(return_exceptions=True) can also hand back BaseException
        # subclasses such as CancelledError, so test against BaseException.
        if isinstance(outcome, BaseException):
            # The record may be exactly what caused the failure (e.g. it has no
            # 'task_id'), so look the id up defensively.
            problem = mbpp_dataset[i]
            task_id_context = problem.get('task_id', f"index_{i}") if isinstance(problem, dict) else f"index_{i}"
            print(f"Task ID {task_id_context}: An unexpected exception occurred during its processing: {outcome}")
            detailed_results.append({
                "task_id": task_id_context,
                "passed": False,
                "error": str(outcome) or repr(outcome),
                "prompt_to_lcdp": "Error before prompt generation or prompt not available"
            })
            continue

        # outcome is (task_id, bool_passed, code_or_msg, enhanced_prompt)
        _task_id, passed, code_or_msg, enhanced_prompt_val = outcome
        detailed_results.append({
            "task_id": _task_id,
            "passed": passed,
            "details": code_or_msg,
            "prompt_to_lcdp": enhanced_prompt_val
        })
        if passed:
            num_passed += 1

    pass_at_1 = (num_passed / num_total) * 100 if num_total > 0 else 0
    average_time = total_time_taken / num_total if num_total > 0 else 0

    print("\n--- Evaluation Summary ---")
    print(f"Total problems evaluated: {num_total}")
    print(f"Problems passed: {num_passed}")
    print(f"Pass@1 Score: {pass_at_1:.2f}%")
    print(f"Total evaluation time: {total_time_taken:.2f} seconds")
    if num_total > 0:
        print(f"Average time per problem: {average_time:.2f} seconds")

    if args.results_output_path:
        try:
            with open(args.results_output_path, 'w', encoding='utf-8') as f_out:
                json.dump({
                    "summary": {
                        "total_problems": num_total,
                        "passed_problems": num_passed,
                        "pass_at_1_percentage": pass_at_1,
                        "total_time_seconds": total_time_taken,
                        "average_time_per_problem_seconds": average_time,
                        "max_concurrent_problems_setting": max_concurrent
                    },
                    "lcdp_params": lcdp_run_params,
                    "detailed_results": detailed_results
                }, f_out, indent=4)
            print(f"Detailed results saved to {args.results_output_path}")
        except Exception as e:
            # Saving is best-effort: the summary has already been printed.
            print(f"Error saving results to {args.results_output_path}: {e}")


if __name__ == "__main__":
    # Command-line entry point. Every option is declared once in the table
    # below, in the order it should appear in --help, and registered in a loop.
    cli_options = (
        # Dataset / run control
        ("--mbpp_dataset_path", dict(type=str, required=True, help="Path to the MBPP dataset JSONL file.")),
        ("--results_output_path", dict(type=str, default=None, help="Optional path to save detailed results in JSON format.")),
        ("--limit_problems", dict(type=int, default=None, help="Limit evaluation to the first N problems from the dataset.")),
        ("--max_concurrent_problems", dict(type=int, default=5, help="Maximum number of MBPP problems to process concurrently.")),
        # LCDP construction
        ("--api_key", dict(type=str, default=os.environ.get("LCDP_API_KEY", os.environ.get("OPENAI_API_KEY")), help="API key for LCDP. Defaults to LCDP_API_KEY or OPENAI_API_KEY env var.")),
        ("--model", dict(type=str, default="gpt-4o-mini", help="Model name for LCDP.")),
        ("--lcdp_max_workers", dict(type=int, default=20, help="Max workers for LCDP's internal operations (if applicable).")),
        # LCDP run parameters
        ("--max_iterations", dict(type=int, default=3, help="max_iterations for LCDP run.")),
        ("--num_plans", dict(type=int, default=3, help="num_plans for LCDP run.")),
        ("--num_tests", dict(type=int, default=5, help="num_tests for LCDP run (internal test generation).")),
        ("--num_codes", dict(type=int, default=5, help="num_codes for LCDP run.")),
        ("--refine_rounds", dict(type=int, default=3, help="refine_rounds for LCDP run.")),
        ("--use_pass_rate_for_train", dict(action='store_true', help="Set use_pass_rate_for_train to True for LCDP run.")),
        ("--test_timeout", dict(type=int, default=10, help="test_timeout (seconds) for LCDP's internal tests.")),
    )

    parser = argparse.ArgumentParser(description="MBPP Evaluation Pipeline for LCDP Framework (Enhanced)")
    for option_flag, option_spec in cli_options:
        parser.add_argument(option_flag, **option_spec)

    args = parser.parse_args()

    # A missing key is only a warning: the run is still attempted.
    if not args.api_key:
        print("Warning: API key not provided via --api_key argument or relevant environment variables (LCDP_API_KEY, OPENAI_API_KEY). LCDP might require it.")

    asyncio.run(main(args))
