import concurrent.futures
import multiprocessing
import threading
import time
import sys
import argparse
import json
from collections import namedtuple

# One unit of work: a single (function under test, test case) pairing.
TestTask = namedtuple('TestTask', ['func_name', 'test_name', 'func_code', 'main_func_name', 'test_code'])


class TestRunner:
    """Run every submitted function against every test case in isolation.

    Each (function, test) pair is executed in its own child process so that a
    crashing or hanging submission cannot take down the runner; a thread pool
    drives up to ``max_workers`` of those child processes concurrently.
    """

    def __init__(self, max_workers=10):
        """Create a runner.

        Args:
            max_workers: Maximum number of tests executed concurrently
                (passed straight to ``ThreadPoolExecutor``).
        """
        self.max_workers = max_workers

    def compile_code(self, code_str, main_function_name=None):
        """Execute ``code_str`` in a fresh namespace and pick a callable from it.

        Args:
            code_str: Python source to execute.
            main_function_name: Name of the callable to return. If ``None``,
                the first callable in the namespace (insertion order) is
                returned instead.

        Returns:
            The selected callable, or ``None`` if it is missing, not callable,
            or executing the code raised (the error is printed to stdout).
        """
        # NOTE(review): exec() runs arbitrary code. Running it in a child
        # process (see _run_single_test) contains crashes and hangs, but it is
        # not a security sandbox -- only feed this trusted submissions.
        try:
            namespace = {}
            exec(code_str, namespace)
            if main_function_name is not None:
                func = namespace.get(main_function_name)
                return func if callable(func) else None
            # Heuristic: the first callable wins. This can be an imported name
            # (e.g. `from math import sqrt`) if it precedes the real function.
            return next((obj for obj in namespace.values() if callable(obj)), None)
        except Exception as e:
            print(f"Compilation Error: {str(e)}, code_str:\n {code_str}")
            return None

    def _run_all_tests(self, functions, test_cases, timeout=None):
        """Run every function against every test case concurrently.

        Args:
            functions: Mapping of function name ->
                ``{'code': str, 'main_function_name': str | None}``.
            test_cases: Mapping of test name -> ``{'test_function': str}``.
            timeout: Per-test time limit in seconds, or ``None`` for no limit.

        Returns:
            ``(function_results, test_results)``: the same boolean outcomes,
            indexed as ``function_results[func_name][test_name]`` and
            ``test_results[test_name][func_name]``.

        A ``PROGRESS_TASK: <done>/<total>`` line is printed (and flushed)
        each time a test finishes so a parent process can stream progress.
        """
        tasks = [
            TestTask(
                func_name=func_name,
                test_name=test_name,
                func_code=func_info['code'],
                main_func_name=func_info['main_function_name'],
                test_code=test_case['test_function'],
            )
            for func_name, func_info in functions.items()
            for test_name, test_case in test_cases.items()
        ]

        function_results = {fn: {} for fn in functions}
        test_results = {tn: {} for tn in test_cases}
        total_tasks = len(tasks)
        completed = 0

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_task = {
                executor.submit(
                    self._run_single_test,
                    task.func_code,
                    task.main_func_name,
                    task.test_code,
                    timeout,
                ): task
                for task in tasks
            }

            for future in concurrent.futures.as_completed(future_to_task):
                task = future_to_task[future]
                result = future.result()
                function_results[task.func_name][task.test_name] = result
                test_results[task.test_name][task.func_name] = result

                completed += 1
                print(f"PROGRESS_TASK: {completed}/{total_tasks}", flush=True)

        return function_results, test_results

    def _run_single_test(self, func_code, main_func_name, test_code, timeout):
        """Run one (function, test) pair in a fresh child process.

        Args:
            func_code: Source of the function under test.
            main_func_name: Name of the callable to extract from ``func_code``
                (``None`` selects the first callable).
            test_code: Source defining the test function, which is called
                with the function under test as its only argument.
            timeout: Seconds the test may run, or ``None`` for no limit.

        Returns:
            The boolean outcome reported by the child, or ``False`` if the
            child had to be killed or exited without reporting anything.
        """
        # The child enforces `timeout` itself; the parent grants one extra
        # second for process start-up/teardown before killing it. `None` means
        # "no limit" (formerly `None + 1` raised TypeError here).
        join_timeout = None if timeout is None else timeout + 1

        # Context-managed so the manager's server process is shut down
        # deterministically after every test instead of relying on GC.
        with multiprocessing.Manager() as manager:
            result_queue = manager.Queue()
            process = multiprocessing.Process(
                target=self._worker_process,
                args=(func_code, main_func_name, test_code, timeout, result_queue),
            )
            process.start()
            process.join(join_timeout)

            if process.is_alive():
                process.terminate()
                process.join()
                return False
            # The child has exited, so nothing else can touch the queue now.
            return result_queue.get() if not result_queue.empty() else False

    def _worker_process(self, func_code, main_func_name, test_code, timeout, queue):
        """Child-process body: compile both snippets, run the test, report a bool.

        Puts ``True`` on ``queue`` if the test function returned a truthy
        value within ``timeout`` seconds (``None`` = wait indefinitely);
        puts ``False`` on a compile failure, an exception, a falsy result,
        or a timeout.
        """
        try:
            func = self.compile_code(func_code, main_func_name)
            test_func = self.compile_code(test_code)
            if not func or not test_func:
                queue.put(False)
                return

            result_container = []
            done = threading.Event()

            def worker():
                try:
                    result_container.append(test_func(func))
                except Exception:
                    result_container.append(False)
                finally:
                    done.set()

            # Daemon thread: if the test overruns the timeout, this process
            # can still exit without waiting for it to finish.
            thread = threading.Thread(target=worker, daemon=True)
            thread.start()

            done.wait(timeout)
            queue.put(bool(result_container[0]) if result_container else False)
        except Exception:
            queue.put(False)

def main():
    """Command-line entry point: run all functions against all test cases.

    Loads the functions and test cases from JSON files, runs them through
    ``TestRunner``, and prints three machine-readable lines to stdout:
    ``FUNCTION_RESULTS:<json>``, ``TEST_RESULTS:<json>`` and
    ``EXECUTION_TIME:<seconds>`` (``PROGRESS_TASK`` lines are printed while
    the tests are running).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--functions_file", type=str, required=True,
        help="JSON file mapping function name -> {code, main_function_name}",
    )
    parser.add_argument(
        "--test_cases_file", type=str, required=True,
        help="JSON file mapping test name -> {test_function}",
    )
    parser.add_argument(
        "--max_workers", type=int, default=10,
        help="maximum number of tests run concurrently",
    )
    parser.add_argument(
        "--timeout", type=float, default=5,
        help="per-test time limit in seconds",
    )

    args = parser.parse_args()
    runner = TestRunner(args.max_workers)

    # Read the input data from disk. JSON is UTF-8 by specification, so do not
    # depend on the platform's locale encoding (e.g. cp1252 on Windows).
    with open(args.functions_file, 'r', encoding='utf-8') as f:
        functions = json.load(f)

    with open(args.test_cases_file, 'r', encoding='utf-8') as f:
        test_cases = json.load(f)

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    func_results, test_results = runner._run_all_tests(functions, test_cases, args.timeout)
    elapsed_time = time.perf_counter() - start_time

    print(f"FUNCTION_RESULTS:{json.dumps(func_results)}")
    print(f"TEST_RESULTS:{json.dumps(test_results)}")
    print(f"EXECUTION_TIME:{elapsed_time}")


if __name__ == "__main__":
    main()