import concurrent.futures
import multiprocessing
import threading
import time
import sys
import argparse
import json
from collections import namedtuple

class TestTask:
    """One (function, test case) pairing scheduled for execution.

    The attributes mirror the constructor arguments one-to-one:

    * ``func_name`` / ``test_name`` -- keys under which the result is filed.
    * ``func_code`` -- source text of the function under test.
    * ``main_func_name`` -- name to look up in ``func_code`` once executed.
    * ``test_code`` -- source text of the test function.
    """

    def __init__(self, func_name, test_name, func_code, main_func_name, test_code):
        # Identifiers used to index the result tables.
        self.func_name, self.test_name = func_name, test_name
        # Candidate implementation and the entry point inside it.
        self.func_code, self.main_func_name = func_code, main_func_name
        # Source of the test that exercises the candidate.
        self.test_code = test_code

class TestRunner:
    """Run every test case against every candidate function in isolation.

    Each (function, test) pair is executed in its own child process so that a
    crashing or hanging snippet cannot take the runner down.  A thread pool in
    the parent drives up to ``max_workers`` child processes at a time.

    NOTE(security): snippets are executed with ``exec``.  Only feed this class
    code you trust, or run the whole tool inside a sandbox.
    """

    def __init__(self, max_workers=10):
        """Store the maximum number of tests that may run concurrently."""
        self.max_workers = max_workers

    def compile_code(self, code_str, main_function_name=None):
        """Execute ``code_str`` and pull a callable out of the resulting namespace.

        Args:
            code_str: Python source to execute in a fresh namespace.
            main_function_name: When given, the name that must be bound to a
                callable after execution.  When ``None``, a callable is picked
                automatically (see below).

        Returns:
            ``(callable, None)`` on success, or ``(None, error_message)`` when
            the code raises or no suitable callable exists.

        Automatic selection prefers the first callable that was *defined by the
        snippet itself* (its ``__globals__`` is the execution namespace).  This
        keeps an early ``from math import sqrt`` or ``from typing import List``
        from being mistaken for the test function.  If the snippet defines no
        function of its own, the first callable of any kind is returned.
        """
        try:
            namespace = {}
            exec(code_str, namespace)

            if main_function_name is not None:
                func = namespace.get(main_function_name)
                if callable(func):
                    return func, None
                return None, f"Main function '{main_function_name}' not found or not callable"

            first_callable = None
            for obj in namespace.values():
                if not callable(obj):
                    continue
                # Functions created by the exec'd code share its globals dict;
                # imported callables carry their own module's globals (or none).
                if getattr(obj, "__globals__", None) is namespace:
                    return obj, None
                if first_callable is None:
                    first_callable = obj

            if first_callable is not None:
                return first_callable, None
            return None, "No callable found in code"
        except Exception as e:
            return None, f"Compilation Error: {str(e)}"

    def _run_all_tests(self, functions, test_cases, timeout=None):
        """Run the full cross product of ``functions`` x ``test_cases``.

        Args:
            functions: Mapping of function name to a dict with the keys
                ``'code'`` and ``'main_function_name'``.
            test_cases: Mapping of test name to a dict with the key
                ``'test_function'``.
            timeout: Per-test time limit in seconds, or ``None`` for no limit.

        Returns:
            ``(function_results, test_results)`` -- the same result dicts
            indexed as ``[func_name][test_name]`` and ``[test_name][func_name]``
            respectively.

        A ``PROGRESS_TASK: done/total`` line is printed after each test ends.
        """
        tasks = [
            TestTask(
                func_name=func_name,
                test_name=test_name,
                func_code=func_info['code'],
                main_func_name=func_info['main_function_name'],
                test_code=test_case['test_function'],
            )
            for func_name, func_info in functions.items()
            for test_name, test_case in test_cases.items()
        ]

        function_results = {fn: {} for fn in functions}
        test_results = {tn: {} for tn in test_cases}
        total_tasks = len(tasks)

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_task = {
                executor.submit(
                    self._run_single_test,
                    task.func_code,
                    task.main_func_name,
                    task.test_code,
                    timeout,
                ): task
                for task in tasks
            }

            completed = 0
            for future in concurrent.futures.as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    result = future.result()
                except Exception as e:
                    # A failure to launch or supervise one child (for example
                    # an OSError when the system is out of processes) must not
                    # discard the results of every other test.
                    result = {
                        'success': False,
                        'reason': 'runner_error',
                        'message': str(e)
                    }
                function_results[task.func_name][task.test_name] = result
                test_results[task.test_name][task.func_name] = result
                completed += 1
                print(f"PROGRESS_TASK: {completed}/{total_tasks}", flush=True)

        return function_results, test_results

    def _run_single_test(self, func_code, main_func_name, test_code, timeout):
        """Run one test in a child process and return its result dict.

        Args:
            func_code: Source of the function under test.
            main_func_name: Name of the callable to extract from ``func_code``.
            test_code: Source of the test function.
            timeout: Time limit in seconds, or ``None`` to wait indefinitely.

        Returns:
            A dict with the keys ``'success'``, ``'reason'`` and ``'message'``.
        """
        # The manager owns a server process; the ``with`` block guarantees it
        # is shut down after every test instead of waiting for garbage
        # collection to reclaim it.
        with multiprocessing.Manager() as manager:
            queue = manager.Queue()
            process = multiprocessing.Process(
                target=self._worker_process,
                args=(func_code, main_func_name, test_code, timeout, queue)
            )
            process.start()
            # The child enforces ``timeout`` itself; the extra second gives it
            # room to report that before the parent gives up on it.
            process.join(timeout + 1 if timeout is not None else None)

            if process.is_alive():
                process.terminate()
                process.join()
                return {
                    'success': False,
                    'reason': 'timeout',
                    'message': 'Process did not finish within the allowed time'
                }

            # Must be read before the manager (and with it the queue) goes away.
            if not queue.empty():
                return queue.get()
            return {
                'success': False,
                'reason': 'unknown',
                'message': 'No result from process'
            }

    def _worker_process(self, func_code, main_func_name, test_code, timeout, queue):
        """Child-process entry point: compile both snippets and run the test.

        Exactly one result dict is put on ``queue``.  The test itself runs in a
        daemon thread so this function can stop waiting after ``timeout``
        seconds and report a timeout while the test is still running.

        Args:
            func_code: Source of the function under test.
            main_func_name: Name of the callable to extract from ``func_code``.
            test_code: Source of the test function; it is called with the
                function under test and must return ``(passed, message)``.
            timeout: Seconds to wait for the test, or ``None`` for no limit.
            queue: Object with a ``put`` method that receives the result dict.
        """
        try:
            func, func_error = self.compile_code(func_code, main_func_name)
            if func is None:
                queue.put({
                    'success': False,
                    'reason': 'main_func_compile_error',
                    'message': func_error
                })
                return

            test_func, test_error = self.compile_code(test_code)
            if test_func is None:
                queue.put({
                    'success': False,
                    'reason': 'test_func_compile_error',
                    'message': test_error
                })
                return

            result_container = []
            event = threading.Event()

            def worker():
                try:
                    test_result, test_message = test_func(func)
                    if test_result:
                        result_container.append({
                            'success': True,
                            'reason': None,
                            'message': test_message
                        })
                    else:
                        result_container.append({
                            'success': False,
                            'reason': 'test_failed',
                            'message': test_message
                        })
                except Exception as e:
                    result_container.append({
                        'success': False,
                        'reason': 'exception',
                        'message': str(e)
                    })
                finally:
                    # Always wake the waiter, even if the test raised.
                    event.set()

            # Daemon thread: a test that never returns must not keep the child
            # process alive once the timeout has been reported.
            thread = threading.Thread(target=worker)
            thread.daemon = True
            thread.start()

            event_occurred = event.wait(timeout)
            if not event_occurred:
                queue.put({
                    'success': False,
                    'reason': 'timeout',
                    'message': f'Test execution exceeded {timeout} seconds'
                })
            elif result_container:
                queue.put(result_container[0])
            else:
                queue.put({
                    'success': False,
                    'reason': 'unknown',
                    'message': 'No result captured after event was set'
                })
        except Exception as e:
            queue.put({
                'success': False,
                'reason': 'worker_process_error',
                'message': str(e)
            })

def main():
    """Command-line entry point: load inputs, run all tests, print the results.

    Arguments (all parsed from ``sys.argv``):
        --functions_file   Path to a JSON file describing the functions.
        --test_cases_file  Path to a JSON file describing the test cases.
        --max_workers      Number of tests to run concurrently (default 10).
        --timeout          Per-test time limit in seconds (default 5).

    Output is three lines on stdout, prefixed ``FUNCTION_RESULTS:``,
    ``TEST_RESULTS:`` and ``EXECUTION_TIME:``, in addition to the progress
    lines printed while the tests run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--functions_file", type=str, required=True)
    parser.add_argument("--test_cases_file", type=str, required=True)
    parser.add_argument("--max_workers", type=int, default=10)
    parser.add_argument("--timeout", type=float, default=5)

    args = parser.parse_args()
    runner = TestRunner(args.max_workers)

    # Read the input data from files.  JSON text is UTF-8, so decode it
    # explicitly instead of relying on the platform's default encoding, which
    # would corrupt or reject non-ASCII source snippets on some systems.
    with open(args.functions_file, 'r', encoding='utf-8') as f:
        functions = json.load(f)

    with open(args.test_cases_file, 'r', encoding='utf-8') as f:
        test_cases = json.load(f)

    start_time = time.time()
    func_results, test_results = runner._run_all_tests(functions, test_cases, args.timeout)
    elapsed_time = time.time() - start_time

    print(f"FUNCTION_RESULTS:{json.dumps(func_results)}")
    print(f"TEST_RESULTS:{json.dumps(test_results)}")
    print(f"EXECUTION_TIME:{elapsed_time}")

# Run the CLI only when this file is executed as a script.  The guard also
# matters for multiprocessing: under the "spawn" start method child processes
# re-import this module, and without the guard each of them would call main().
if __name__ == "__main__":
    main()