import asyncio
import json
import os
import logging
import time
from typing import List, Dict, Any, Tuple
from LCDP import LCDP
from tqdm import tqdm

# # Configure logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# MBPP problem structure expected by this script
# {
# 'task_id': int,
# 'text': str, (problem description; the Hugging Face "sanitized" config names this field 'prompt')
# 'code': str, (reference solution; not used for generation, kept for information)
# 'test_list': list[str], (list of assert statements)
# 'test_setup_code': str, (setup code needed before the tests, e.g. imports; "sanitized" provides a 'test_imports' list instead)
# 'entry_point': str (name of the function to be implemented; not a field of the Hugging Face MBPP dataset)
# }

def _infer_mbpp_entry_point(code: str, test_list: List[str]) -> str:
    """Guess the name of the function in ``code`` that the MBPP tests call.

    Returns the first top-level function of ``code`` that is called in any
    test, else the last top-level function, else ``'candidate'``.
    """
    import ast
    import re

    try:
        tree = ast.parse(code or "")
    except (SyntaxError, ValueError):
        return "candidate"

    function_names = [
        node.name
        for node in tree.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    for name in function_names:
        call_pattern = re.compile(rf"\b{re.escape(name)}\s*\(")
        if any(call_pattern.search(test) for test in test_list):
            return name
    return function_names[-1] if function_names else "candidate"


def _normalize_mbpp_problem(problem: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``problem`` carrying every key the evaluation code reads.

    Only missing keys are filled in; values already present are left untouched.
    """
    record = dict(problem)
    if "text" not in record:
        # The "sanitized" configuration stores the description under 'prompt'.
        record["text"] = record.get("prompt", "")
    if "test_setup_code" not in record:
        # The "sanitized" configuration provides a list of import lines instead.
        record["test_setup_code"] = "\n".join(record.get("test_imports") or [])
    if not record.get("entry_point"):
        record["entry_point"] = _infer_mbpp_entry_point(
            record.get("code", ""), record.get("test_list") or []
        )
    return record


def load_mbpp_dataset(split="test", num_samples=None) -> List[Dict[str, Any]]:
    """
    Loads the MBPP dataset ("sanitized" configuration) from the Hugging Face datasets hub.

    Every returned problem is a dict with at least the keys 'task_id', 'text',
    'code', 'test_list', 'test_setup_code' and 'entry_point'. Keys the dataset
    does not provide are derived from the fields it does provide, so the rest
    of the script can rely on them.

    Args:
        split: Dataset split to load (e.g. "test", "train", "validation").
        num_samples: Maximum number of problems to return; None returns the
            whole split.

    Returns:
        The list of problems. If the dataset cannot be loaded for any reason
        (missing 'datasets' package, no access, ...), the error is logged and a
        fixed two-problem dummy set is returned instead; ``num_samples`` is not
        applied to that fallback.
    """
    try:
        from datasets import load_dataset
        logging.info(f"Loading MBPP dataset, split: {split}...")
        # Using 'sanitized' version as it's often preferred
        mbpp_dataset = load_dataset("mbpp", "sanitized", split=split, trust_remote_code=True)
        problems = list(mbpp_dataset)
        # `is not None` so that an explicit 0 means "no problems", not "all".
        if num_samples is not None:
            problems = problems[:num_samples]
        problems = [_normalize_mbpp_problem(problem) for problem in problems]
        logging.info(f"Loaded {len(problems)} problems from MBPP {split} split.")
        return problems
    except Exception as e:
        # Deliberate best-effort boundary: any failure falls back to dummy data.
        logging.error(f"Failed to load MBPP dataset: {e}")
        logging.info("Please ensure 'datasets' library is installed ('pip install datasets') and you can access the dataset.")
        logging.info("As a fallback, returning a dummy problem set.")
        return [
            {
                'task_id': 1,
                'text': "Write a Python function to add two numbers.",
                'code': "def add(a, b):\n  return a + b",
                'test_list': ["assert candidate(1, 2) == 3", "assert candidate(-1, 1) == 0"],
                'test_setup_code': "",
                'entry_point': 'candidate' # Assuming 'candidate' is the common entry point
            },
            {
                'task_id': 2,
                'text': "Write a Python function to find the largest number in a list.",
                'code': "def find_largest(lst):\n  return max(lst)",
                'test_list': ["assert candidate([1, 2, 3]) == 3", "assert candidate([-1, -5, 0]) == 0"],
                'test_setup_code': "",
                'entry_point': 'candidate'
            }
        ]

def execute_code_and_run_tests(
    code_str: str,
    entry_point: str, # Expected name of the function defined in code_str
    test_setup_code: str,
    test_list: List[str],
    timeout_seconds: int = 10 # Time budget for running all tests of this one code snippet
) -> Tuple[bool, List[Dict[str, Any]]]:
    """
    Executes the provided code string and runs test cases against it.

    The setup code, the generated code and every test statement are exec'd in
    one fresh namespace. After the generated code has run, the name `candidate`
    is made available for the tests:
      * if `entry_point` is defined, `candidate` becomes an alias for it;
      * otherwise, if `candidate` is not defined either and the generated code
        introduced exactly one callable that is not a class, that callable is
        used; if there is no such unique callable a warning is logged and the
        tests run as-is (typically failing with NameError).

    SECURITY NOTE: `code_str`, `test_setup_code` and the tests are executed
    in-process with `exec`, without any sandbox. Only run this on code you are
    willing to execute with the privileges of this process.

    The timeout is cooperative: it is only checked between tests, so a test
    (or the generated code itself) that never returns still blocks forever.

    Args:
        code_str: Source code of the generated solution.
        entry_point: Name of the function the solution is expected to define.
        test_setup_code: Code to run before the solution (e.g. imports); may be empty.
        test_list: Statements (normally asserts) to execute as tests.
        timeout_seconds: Once this many seconds have elapsed since the first
            test started, the next test is recorded as "Timeout" and the run stops.

    Returns:
        (all_tests_passed, execution_results) where each result is a dict with
        the keys "test", "passed" and "error". A failure while running the
        setup or the generated code yields a single result whose "test" is
        "Code or Setup Execution".
    """
    execution_results: List[Dict[str, Any]] = []
    all_tests_passed = True

    # Fresh namespace per snippet so that executions cannot interfere.
    local_scope: Dict[str, Any] = {}

    try:
        if test_setup_code:
            exec(test_setup_code, local_scope)
        # Remember what existed before the snippet ran, so the fallback below
        # only considers names the snippet itself introduced.
        names_before_code = set(local_scope) | {"__builtins__"}
        exec(code_str, local_scope)
    except (Exception, SystemExit) as e:
        # SystemExit is not an Exception subclass; without catching it a
        # snippet calling sys.exit() would terminate the whole evaluation.
        execution_results.append({"test": "Code or Setup Execution", "passed": False, "error": f"Error during code/setup: {type(e).__name__}: {e}"})
        return False, execution_results

    if entry_point in local_scope:
        if entry_point != "candidate":
            local_scope["candidate"] = local_scope[entry_point]
    elif "candidate" not in local_scope:
        # Neither expected name exists: fall back to the only callable the
        # snippet defined, if that is unambiguous.
        defined_funcs = [
            name for name, obj in local_scope.items()
            if name not in names_before_code and callable(obj) and not isinstance(obj, type)
        ]
        if len(defined_funcs) == 1:
            local_scope["candidate"] = local_scope[defined_funcs[0]]
        else:
            logging.warning(f"Could not find function '{entry_point}' or 'candidate' in generated code. Defined functions: {defined_funcs}")

    start_time = time.monotonic()
    for test_str in test_list:
        if time.monotonic() - start_time > timeout_seconds:
            logging.warning(f"Test execution for a code snippet exceeded timeout of {timeout_seconds}s.")
            all_tests_passed = False
            execution_results.append({"test": test_str, "passed": False, "error": "Timeout"})
            break # Stop further tests for this code snippet

        try:
            exec(test_str, local_scope)
        except AssertionError as e:
            # A wrong answer: record it and keep running the remaining tests.
            all_tests_passed = False
            execution_results.append({"test": test_str, "passed": False, "error": f"AssertionError: {e}"})
        except (Exception, SystemExit) as e:
            # Anything else (NameError, TypeError, ...) is treated as fatal for this snippet.
            all_tests_passed = False
            execution_results.append({"test": test_str, "passed": False, "error": f"ExecutionError: {type(e).__name__}: {e}"})
            break
        else:
            execution_results.append({"test": test_str, "passed": True, "error": None})

    return all_tests_passed, execution_results


async def evaluate_lcdp_on_mbpp(
    lcdp_instance: LCDP,
    mbpp_problems: List[Dict[str, Any]],
    lcdp_params: Dict[str, Any],
    max_generated_codes_to_check: int # Corresponds to num_codes in lcdp_params
):
    """
    Run LCDP on every MBPP problem and score its generated codes with the MBPP tests.

    For each problem, `lcdp_instance.run(...)` is awaited and every returned
    code is executed against the problem's tests, in the iteration order of
    the returned mapping. That order is treated as the ranking.
    NOTE(review): this assumes LCDP inserts its codes best-first - confirm
    against the LCDP implementation, or sort by score before iterating.

    Args:
        lcdp_instance: Object whose awaitable `run(task_description=..., **lcdp_params)`
            returns a mapping {code_id: {"code": str, ...}}.
        mbpp_problems: Problems with the keys 'task_id', 'text', 'entry_point',
            'test_list' and optionally 'test_setup_code'.
        lcdp_params: Extra keyword arguments forwarded to `lcdp_instance.run`.
        max_generated_codes_to_check: Largest k for which pass@k is reported.

    Returns:
        A summary dict with:
          * "total_problems": number of problems given;
          * "problems_solved_at_least_one_code": problems for which any code passed all tests;
          * "pass_at_k": {k: count} for k in 1..max_generated_codes_to_check, where a
            problem counts towards k if its first passing code has rank <= k;
          * "detailed_results": one entry per problem, including per-code test details
            and, if processing the problem raised, an "error" message.
    """
    results_summary = {
        "total_problems": len(mbpp_problems),
        "problems_solved_at_least_one_code": 0,
        "pass_at_k": {k: 0 for k in range(1, max_generated_codes_to_check + 1)},
        "detailed_results": []
    }
    pass_at_k = results_summary["pass_at_k"]

    for i, problem in tqdm(enumerate(mbpp_problems), desc="Evaluating MBPP Problems", total=len(mbpp_problems)):
        task_id = problem['task_id']
        problem_text = problem['text']
        entry_point = problem['entry_point'] # The function name LCDP should generate

        # Tell LCDP which function name the tests will look for.
        lcdp_task_description = f"{problem_text}\n\nWrite a Python function named '{entry_point}'. Your solution should only include the function definition and necessary imports."

        logging.info(f"Processing MBPP problem {i+1}/{len(mbpp_problems)} (ID: {task_id}) - Entry Point: {entry_point}")

        problem_result = {
            "task_id": task_id,
            "problem_text": problem_text,
            "entry_point": entry_point,
            "solved": False,
            "passed_k_value": -1, # Rank (1-based) of the first code that passed, -1 if none
            "generated_code_evaluations": []
        }

        try:
            best_codes_dict = await lcdp_instance.run(
                task_description=lcdp_task_description,
                **lcdp_params
            )
            num_generated = len(best_codes_dict)

            # Iterate key and value together: no reverse lookup of the key is
            # needed, and two entries with equal content keep their own ids.
            for rank, (code_id, code_info) in enumerate(best_codes_dict.items(), start=1):
                code_str = code_info["code"]

                logging.debug(f"  Evaluating generated code {rank}/{num_generated} (ID: {code_id}) for problem {task_id}")

                # NOTE: this call is synchronous and blocks the event loop while the tests run.
                passed_all_tests, test_details = execute_code_and_run_tests(
                    code_str=code_str,
                    entry_point=entry_point,
                    test_setup_code=problem.get('test_setup_code', ''),
                    test_list=problem['test_list'],
                    timeout_seconds=10 # Timeout for this specific code snippet's tests
                )

                # Accept either spelling of the score field; None if LCDP supplied neither.
                if "score" in code_info:
                    lcdp_scores = code_info["score"]
                else:
                    lcdp_scores = code_info.get("scores")

                problem_result["generated_code_evaluations"].append({
                    "code_id": code_id,
                    "code_generated": code_str,
                    "lcdp_scores": lcdp_scores,
                    "passed_all_mbpp_tests": passed_all_tests,
                    "mbpp_test_details": test_details
                })

                if not passed_all_tests:
                    continue

                if not problem_result["solved"]:
                    # First passing code: record its rank and update both
                    # counters together so they can never disagree.
                    problem_result["solved"] = True
                    problem_result["passed_k_value"] = rank
                    results_summary["problems_solved_at_least_one_code"] += 1
                    for k in range(rank, max_generated_codes_to_check + 1):
                        pass_at_k[k] += 1

                # Remaining codes are still evaluated to collect per-code details.
                logging.info(f"  Problem {task_id} SOLVED with code {rank} (ID: {code_id}).")

        except Exception as e:
            # One failing problem must not abort the whole evaluation run.
            logging.error(f"Error processing problem {task_id}: {e}", exc_info=True)
            problem_result["error"] = str(e)

        results_summary["detailed_results"].append(problem_result)

    return results_summary


async def main():
    """
    Configure and run the LCDP evaluation on MBPP, log a summary and save the details.

    The API key is read from the LCDP_API_KEY environment variable; if it is
    not set the function logs an error and returns without running anything.
    Detailed results are written to "mbpp_evaluation_detailed_results.json" in
    the current working directory.

    NOTE(review): nothing in this file configures logging (the basicConfig call
    near the imports is commented out), so the INFO-level report below is only
    visible if the root logger is configured elsewhere - confirm.
    """
    # --- Configuration ---
    # Credentials must come from the environment; never embed a key in source.
    lcdp_api_key = os.getenv("LCDP_API_KEY")
    if not lcdp_api_key:
        logging.error("LCDP_API_KEY environment variable is not set. Exiting.")
        return

    lcdp_model = "gpt-4o-mini" # Or your desired model
    lcdp_max_workers = 10 # Adjust as per your LCDP setup
    ignore_user_advice = True # Forwarded to LCDP as ignore_advice

    # Parameters for lcdp.run()
    # num_codes will determine the N for Pass@N
    lcdp_run_params = {
        "max_iterations": 3,
        "num_plans": 3,
        "num_tests": 5,      # LCDP's internal test generation count
        "num_codes": 3,      # Number of code solutions to generate per problem (for Pass@3)
        "refine_rounds": 2,
        "use_pass_rate_for_train": True,
        "test_timeout": 10,  # Timeout for LCDP's internal tests
    }

    # Evaluation settings
    # Set to a small number for quick testing, None for full dataset
    num_mbpp_samples_to_test = 5
    mbpp_split = "test" # "test", "train", "validation"

    # --- Initialization ---
    lcdp_client = LCDP(api_key=lcdp_api_key, model=lcdp_model, max_workers=lcdp_max_workers, ignore_advice=ignore_user_advice)

    mbpp_problems = load_mbpp_dataset(split=mbpp_split, num_samples=num_mbpp_samples_to_test)

    if not mbpp_problems:
        logging.error("No MBPP problems loaded. Exiting.")
        return

    # --- Run Evaluation ---
    start_time = time.time()
    evaluation_results = await evaluate_lcdp_on_mbpp(
        lcdp_instance=lcdp_client,
        mbpp_problems=mbpp_problems,
        lcdp_params=lcdp_run_params,
        max_generated_codes_to_check=lcdp_run_params["num_codes"]
    )
    end_time = time.time()

    # --- Report Results ---
    logging.info("\n--- MBPP Evaluation Results ---")
    logging.info(f"Total problems processed: {evaluation_results['total_problems']}")
    logging.info(f"Problems solved (at least one code passed): {evaluation_results['problems_solved_at_least_one_code']}")

    success_rate = (evaluation_results['problems_solved_at_least_one_code'] / evaluation_results['total_problems'] * 100) \
        if evaluation_results['total_problems'] > 0 else 0
    logging.info(f"Overall Success Rate: {success_rate:.2f}%")

    logging.info("Pass@k rates:")
    for k, count in evaluation_results['pass_at_k'].items():
        pass_k_rate = (count / evaluation_results['total_problems'] * 100) \
            if evaluation_results['total_problems'] > 0 else 0
        logging.info(f"  Pass@{k}: {count}/{evaluation_results['total_problems']} = {pass_k_rate:.2f}%")

    logging.info(f"Evaluation completed in {end_time - start_time:.2f} seconds.")

    # Save detailed results to a JSON file (explicit encoding so the output
    # does not depend on the platform's default).
    results_filename = "mbpp_evaluation_detailed_results.json"
    with open(results_filename, "w", encoding="utf-8") as f:
        json.dump(evaluation_results, f, indent=2)
    logging.info(f"Detailed results saved to {results_filename}")


if __name__ == "__main__":
    # Script entry point. Requires the LCDP module with its dependencies and
    # the 'datasets' library; set the LCDP_API_KEY environment variable when
    # running against the actual LCDP.
    evaluation_coroutine = main()
    asyncio.run(evaluation_coroutine)

