import json
import numpy as np
from typing import List, Union, Dict, Optional
import itertools
from collections import defaultdict
import multiprocessing
import contextlib
import io
import signal
import tempfile
import os
import ast
import argparse
class TimeoutException(Exception):
    """Raised by ``time_limit`` when the guarded code exceeds its time budget."""

@contextlib.contextmanager
def time_limit(seconds: float):
    """
    Raise ``TimeoutException`` if the body of the ``with`` block runs longer than ``seconds``.

    Implemented with ``SIGALRM`` and ``ITIMER_REAL``, so it is Unix-only and
    must be used from the main thread (``signal.signal`` raises ``ValueError``
    elsewhere). The timer is always disarmed on exit. The previously installed
    SIGALRM handler is restored when Python knows it.

    Args:
        seconds: Wall-clock budget for the body, in seconds.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler *before* arming the timer. Otherwise a very short
    # timer could fire while SIGALRM still has its default disposition, which
    # terminates the whole process instead of raising.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal return None for handlers not installed from Python;
        # such a handler cannot be reinstalled, so only restore known ones.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

def unsafe_execute(code: str, test_case: str, result: list, timeout: float):
    """
    Execute ``code`` and then ``test_case`` in one fresh namespace and record the outcome.

    WARNING: this is NOT a sandbox. Both strings are run with ``exec`` and full
    builtins, so they can do anything the current process can do. The only
    isolation is whatever the caller provides (``check_correctness`` runs this
    function in a separate process).

    Args:
        code: Source of the candidate solution.
        test_case: Source of one test, executed in the same namespace as
            ``code`` so it can call what ``code`` defines.
        result: List-like sink. One ``(status, error)`` tuple is appended,
            where status is ``"passed"``, ``"timed out"`` or ``"failed"``.
        timeout: Wall-clock budget in seconds, enforced with ``time_limit``.
    """
    try:
        with time_limit(timeout):
            # A single dict serves as the globals for both exec calls, so the
            # test sees the functions/imports the solution defined.
            namespace = {}
            exec(code, namespace)
            exec(test_case, namespace)
        # Record success only after the timer is disarmed.
        result.append(("passed", None))
    except TimeoutException:
        result.append(("timed out", "Execution took too long"))
    except BaseException as e:  # noqa: BLE001 - deliberate, see below
        # BaseException rather than Exception: generated code that calls
        # exit()/sys.exit() raises SystemExit. That would otherwise escape and
        # leave ``result`` empty.
        message = str(e)
        # Include the exception type: a bare failing ``assert`` has an empty
        # message, which on its own tells the reader nothing.
        result.append(("failed", f"{type(e).__name__}: {message}" if message else type(e).__name__))

def check_correctness(code: str, test_case: str, timeout: float = 3.0) -> Dict:
    """
    Safely evaluates a single test case for the given code.

    The code runs in a child process via ``unsafe_execute``, which enforces
    ``timeout`` itself. The parent waits one extra second as a grace period and
    then kills the child if it is still running.

    Args:
        code: Source of the candidate solution.
        test_case: Source of one test case.
        timeout: Per-test budget in seconds.

    Returns:
        dict with keys ``"passed"`` (bool), ``"result"`` (``"passed"``,
        ``"failed"`` or ``"timed out"``) and ``"error"`` (message or ``None``).
    """
    # The context manager shuts the manager's server process down on exit.
    # One manager is created per test case, so don't leave cleanup to the GC.
    with multiprocessing.Manager() as manager:
        result = manager.list()

        p = multiprocessing.Process(
            target=unsafe_execute,
            args=(code, test_case, result, timeout)
        )
        p.start()
        p.join(timeout=timeout + 1)

        killed = p.is_alive()
        if killed:
            p.kill()
            p.join()  # reap the killed child so it does not linger as a zombie

        # Copy out of the proxy before the manager shuts down.
        outcomes = list(result)
        exitcode = p.exitcode
    p.close()

    if outcomes:
        # Whatever the child reported first wins.
        status, error = outcomes[0]
    elif killed:
        status, error = "timed out", "Process killed after timeout"
    else:
        # The child finished on its own but never reported (e.g. it crashed or
        # called os._exit). That is a failure, not a timeout.
        status, error = "failed", f"Process exited with code {exitcode} without returning a result"

    return {
        "passed": status == "passed",
        "result": status,
        "error": error
    }

def clean_code(output: str) -> str:
    """
    Extracts clean code from the output field of a JSON object.

    Extraction rules, in order:

    1. The first ```python fenced block. If that block is never closed (e.g. a
       truncated generation), everything after the opening marker is used.
    2. Otherwise, the text from the first ``def `` up to the next ``` fence
       (or the end of the string).

    Args:
        output (str): The output string containing markdown-formatted code

    Returns:
        str: The extracted code with surrounding whitespace stripped

    Raises:
        ValueError: If ``output`` is empty or not a string, or if it contains
            neither a ```python block nor a ``def ``.

    Example:
        Input: "```python\ndef func():\n    return True\n```"
        Output: "def func():\n    return True"
    """
    # Handle empty or invalid input
    if not output or not isinstance(output, str):
        raise ValueError("Output must be a non-empty string")

    fence = "```"
    start_marker = fence + "python"

    marker_idx = output.find(start_marker)
    if marker_idx != -1:
        start_idx = marker_idx + len(start_marker)
        end_idx = output.find(fence, start_idx)
        if end_idx == -1:
            # Unterminated block: keep everything after the marker. This
            # preserves imports that precede the first ``def``.
            end_idx = len(output)
        return output[start_idx:end_idx].strip()

    # No ```python block: fall back to the first function definition.
    def_idx = output.find("def ")
    if def_idx == -1:
        raise ValueError("No valid Python code block found in output")
    # Stop at a closing fence (e.g. from a bare ``` block) so the fence and any
    # trailing prose do not end up in the code, where they are syntax errors.
    end_idx = output.find(fence, def_idx)
    if end_idx == -1:
        end_idx = len(output)
    return output[def_idx:end_idx].strip()

def evaluate_solution(json_obj: Dict) -> tuple[int, int]:
    """
    Evaluates a single solution against its test cases.

    Args:
        json_obj: One parsed JSONL record. The keys read are ``"task_id"``,
            ``"output"`` (raw model output, passed to ``clean_code``) and
            ``"test_list"`` (sequence of test-case source strings).

    Returns:
        ``(total_tests, passed_tests)``. If no code can be extracted from
        ``"output"``, every test is counted as failed rather than aborting the
        caller's run.
    """
    test_cases = json_obj["test_list"]
    total_tests = len(test_cases)

    print(f"\nEvaluating task {json_obj['task_id']}:")

    try:
        code = clean_code(json_obj["output"])
    except ValueError as e:
        # One malformed generation must not abort evaluation of all the others.
        print(f"Could not extract code: {e}")
        print(f"\nPassed 0/{total_tests} tests")
        return total_tests, 0

    print("Code:")
    print(code)
    print("\nTest Results:")

    passed_tests = 0
    for i, test in enumerate(test_cases, start=1):
        result = check_correctness(code, test)
        if result["passed"]:
            passed_tests += 1
            print(f"Test {i}: ✓ Passed")
        else:
            print(f"Test {i}: ✗ Failed")
            print(f"  Test case: {test}")
            print(f"  Error: {result['error']}")

    print(f"\nPassed {passed_tests}/{total_tests} tests")
    return total_tests, passed_tests

def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int
) -> np.ndarray:
    """
    Return the unbiased pass@k estimate for every problem as an array.

    For a problem with ``n`` samples of which ``c`` are correct, the estimate is
    ``1 - C(n - c, k) / C(n, k)``. It is evaluated as a running product to
    avoid huge binomial coefficients.

    Args:
        num_samples: Samples per problem. Either one int shared by all
            problems, or one entry per problem.
        num_correct: Correct samples per problem.
        k: The ``k`` in pass@k.
    """

    def _single_problem(n: int, c: int, k: int) -> float:
        # With fewer than k incorrect samples, every size-k draw contains a
        # correct one.
        if k > n - c:
            return 1.0
        denominators = np.arange(n - c + 1, n + 1)
        return 1.0 - np.prod(1.0 - k / denominators)

    if isinstance(num_samples, int):
        sample_counts = [num_samples] * len(num_correct)
    else:
        assert len(num_samples) == len(num_correct)
        sample_counts = num_samples

    estimates = [
        _single_problem(int(n), int(c), k)
        for n, c in zip(sample_counts, num_correct)
    ]
    return np.array(estimates)

def main():
    """
    Command-line entry point: evaluate every sample in a JSONL file and report pass@k.

    Each non-blank line of ``--path`` is one JSON record handled by
    ``evaluate_solution``. A sample counts as correct only if it passes all of
    its tests. Samples are grouped by ``task_id``, and pass@k is computed for
    every requested k that does not exceed the sample count of any task. The
    results are printed and appended to ``passk.txt`` in the input file's
    directory.
    """
    parser = argparse.ArgumentParser(
        description="Run generated solutions against their test cases and report pass@k."
    )
    parser.add_argument("--path", type=str, required=True,
                        help="JSONL file with one generated sample per line")
    parser.add_argument("--ks", type=int, nargs="+", default=[1],
                        help="k values for pass@k (default: 1)")
    args = parser.parse_args()

    results = defaultdict(list)

    # Read JSON objects and group by task_id
    with open(args.path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            json_obj = json.loads(line)
            task_id = json_obj["task_id"]
            n_tests, n_passed = evaluate_solution(json_obj)
            results[task_id].append(n_passed == n_tests)

    # Per task: number of samples and number of fully-correct samples.
    total = np.array([len(task_results) for task_results in results.values()])
    correct = np.array([sum(task_results) for task_results in results.values()])

    # pass@k needs at least k samples for every task. With no tasks at all the
    # mean would be NaN, so report nothing in that case.
    pass_at_k = {}
    if total.size:
        pass_at_k = {
            # float(): keeps the log line readable; numpy >= 2 would write
            # ``np.float64(...)`` for the raw mean.
            f"pass@{k}": float(estimate_pass_at_k(total, correct, k).mean())
            for k in args.ks
            if (total >= k).all()
        }

    # Print results
    print("\nOverall Results:")
    if not pass_at_k:
        print("No pass@k values could be computed.")
    for k, value in pass_at_k.items():
        print(f"{k}: {value:.3f}")

    # Append a summary line next to the input file.
    log_path = os.path.join(os.path.dirname(args.path), "passk.txt")
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"{args.path}: {pass_at_k}\n")


if __name__ == "__main__":
    main()