import json
import os
import fire
from tqdm import tqdm
import ast
from concurrent.futures import ProcessPoolExecutor, as_completed
import time
import tempfile
import shutil
import multiprocessing as mp
import queue

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped and every other line is decoded with
    ``json.loads``. The file is read as UTF-8.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        return [json.loads(text) for text in stripped_lines if text]

def save_jsonl(file_path, data):
    """Write ``data`` to ``file_path`` as JSON Lines, replacing any existing file.

    Missing parent directories are created first. Each item is serialised
    with ``ensure_ascii=False`` and written on its own line as UTF-8.

    Args:
        file_path: Destination path. May be a bare filename, in which case
            the file is written to the current working directory.
        data: Iterable of JSON-serialisable items.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def append_jsonl(file_path, data):
    """Append ``data`` to ``file_path`` as JSON Lines, creating the file if needed.

    Missing parent directories are created first. Each item is serialised
    with ``ensure_ascii=False`` and written on its own line as UTF-8.

    Args:
        file_path: Destination path. May be a bare filename, in which case
            the file lives in the current working directory.
        data: Iterable of JSON-serialisable items.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def merge_jsonl_files(file_list, output_file):
    """Concatenate several JSONL files into ``output_file``.

    Input files are processed in sorted path order. Paths that do not exist
    are skipped, as are blank lines. Lines are copied verbatim (they are not
    parsed as JSON), except that a final line lacking a newline gets one so
    that records from consecutive files never end up on the same line.

    Args:
        file_list: Paths of the JSONL files to merge.
        output_file: Destination path; overwritten if it exists. Missing
            parent directories are created.

    Returns:
        The number of non-blank lines written.
    """
    print(f"Merging {len(file_list)} files into {output_file}")
    output_dir = os.path.dirname(output_file)
    # Guard against '' (bare filename): os.makedirs('') raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    total_items = 0
    with open(output_file, 'w', encoding='utf-8') as outf:
        for file_path in sorted(file_list):
            if not os.path.exists(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as inf:
                for line in inf:
                    if not line.strip():
                        continue
                    # Only the last line of a file can lack '\n'; without this
                    # it would be fused with the next file's first record.
                    outf.write(line if line.endswith('\n') else line + '\n')
                    total_items += 1
    print(f"Merged {total_items} items total")
    return total_items

def quick_safety_check(program: str) -> bool:
    """Return True when ``program`` contains none of the blocked substrings.

    The comparison is done on the lower-cased source text, so it is
    case-insensitive. This is a plain substring match, not a parser: it can
    be bypassed (e.g. ``from os import ...``) and can reject harmless code
    that merely mentions one of the substrings.

    Args:
        program: Source code to inspect.

    Returns:
        False if any blocked substring occurs in the source, True otherwise.
    """
    blocked_snippets = ("import os", "subprocess", "system(", "eval(", "exec(", "__import__")
    lowered_source = program.lower()
    for snippet in blocked_snippets:
        if snippet in lowered_source:
            return False
    return True

def validate_testcase_fast(testcase: str) -> tuple:
    """Check that ``testcase`` parses as a single Python expression.

    Args:
        testcase: Expression source, e.g. ``"solve(1, 2)"``.

    Returns:
        A ``(is_valid, func_name)`` tuple. ``is_valid`` is False when the
        text cannot be parsed in ``eval`` mode. ``func_name`` is the called
        name when the expression is a call to a plain name (``foo(...)``),
        and None otherwise (attribute calls, non-call expressions, or
        invalid input).
    """
    try:
        parsed = ast.parse(testcase, mode='eval')
    except Exception:
        # Any parse failure (syntax error, wrong input type, ...) means the
        # testcase is unusable. Deliberately not a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return False, None
    expression = parsed.body
    if isinstance(expression, ast.Call) and isinstance(expression.func, ast.Name):
        return True, expression.func.id
    return True, None

def execute_code_detection(code: str, testcase: str, result_queue: mp.Queue):
    """Execute ``code``, evaluate ``testcase`` against it and report via a queue.

    Used as the target of a child process. For every outcome handled here a
    single ``(status, payload)`` tuple is put on ``result_queue``:

    * ``('success', repr(value))`` when the testcase expression evaluates.
    * ``('error', message)`` when the safety check or testcase validation
      fails, the called function is not defined by ``code``, or anything
      raises while executing.

    Args:
        code: Source code to execute.
        testcase: Expression evaluated in the namespace populated by ``code``.
        result_queue: Queue that receives the result tuple.
    """
    try:
        if not quick_safety_check(code):
            result_queue.put(('error', 'Dangerous code detected'))
            return
        is_valid, func_name = validate_testcase_fast(testcase)
        if not is_valid:
            result_queue.put(('error', 'Invalid testcase syntax'))
            return
        # SECURITY: ``code`` and ``testcase`` come from the input data and are
        # run with full builtins. The substring check above and the process
        # boundary are the only protection; do not treat this as a sandbox.
        execution_context = {"__builtins__": __builtins__}
        exec(code, execution_context)
        # NOTE(review): only names defined by ``code`` pass this check, so a
        # testcase whose outermost call is a builtin (e.g. ``len(f())``) is
        # reported as undefined - confirm that this is intended.
        if func_name and func_name not in execution_context:
            result_queue.put(('error', f'Function {func_name} not defined'))
            return
        exec(f"__TEMP_RESULT__ = {testcase}", execution_context)
        result = execution_context["__TEMP_RESULT__"]
        result_queue.put(('success', repr(result)))
    except SystemExit as e:
        # exit()/sys.exit() in the executed code raises SystemExit, which is
        # not an Exception subclass. Without this branch nothing is queued
        # and the parent waits out its full timeout and misreports a timeout.
        result_queue.put(('error', f'SystemExit: {e.code}'))
    except Exception as e:
        result_queue.put(('error', str(e)))

def execute_with_process_isolation(code: str, testcase: str, timeout: float = 1.0) -> str:
    """Run ``execute_code_detection`` in a child process with a time limit.

    Args:
        code: Source code to execute in the child.
        testcase: Expression to evaluate in the child.
        timeout: Seconds to wait for the child to report a result.

    Returns:
        The payload reported by the child on success, ``'Error: <message>'``
        when the child reports an error or when setting up / talking to the
        child fails, and ``'Error: Timeout'`` when no result arrives within
        ``timeout`` seconds. This function does not raise ``Exception``
        subclasses; failures are folded into the returned string.
    """
    result_queue = None
    process = None
    try:
        result_queue = mp.Queue()
        process = mp.Process(target=execute_code_detection, args=(code, testcase, result_queue))
        process.start()
        try:
            status, result = result_queue.get(timeout=timeout)
        except queue.Empty:
            return 'Error: Timeout'
        # The result is in; give the child a brief moment to exit on its own
        # before the cleanup below stops it forcibly.
        process.join(timeout=0.1)
        if status == 'success':
            return result
        return f'Error: {result}'
    except Exception as e:
        return f'Error: {str(e)}'
    finally:
        # Runs on every path (success, timeout, unexpected error) so a child
        # is never left behind. Cleanup is best-effort: a failure here must
        # not replace the result computed above.
        try:
            if process is not None and process.is_alive():
                process.terminate()
                # Bounded wait: the executed code may ignore SIGTERM, in
                # which case an unbounded join() would hang this worker.
                process.join(timeout=1.0)
                if process.is_alive():
                    process.kill()
                    process.join()
            if result_queue is not None:
                result_queue.close()
        except Exception:
            pass

def normalize_output(output_str: str) -> str:
    """Collapse whitespace in ``output_str`` for lenient comparison.

    Every run of whitespace becomes a single space and leading/trailing
    whitespace is dropped. Non-string input is converted with ``str`` and
    returned without further normalisation.

    Args:
        output_str: Value to normalise.

    Returns:
        The normalised string.
    """
    if isinstance(output_str, str):
        # split() with no argument already discards leading/trailing
        # whitespace, so the joined result needs no extra strip.
        tokens = output_str.split()
        return ' '.join(tokens)
    return str(output_str)

def compare_outputs(actual: str, expected: str) -> bool:
    """Decide whether an actual output should count as matching the expected one.

    Outputs match when they are equal, when they are equal after
    ``normalize_output``, or when both are error reports (both start with
    ``"Error:"``, regardless of the message).

    Args:
        actual: Output produced by the code under test.
        expected: Reference output.

    Returns:
        True when the outputs are considered equivalent, False otherwise.
    """
    if actual == expected:
        return True
    if normalize_output(actual) == normalize_output(expected):
        return True
    # normalize_output tolerates non-string values, so this check must too:
    # calling .startswith on e.g. an int loaded from JSON would raise
    # AttributeError instead of reporting a mismatch.
    return str(actual).startswith("Error:") and str(expected).startswith("Error:")

def detect_code_problem(item: dict, timeout: float = 1.0) -> dict:
    """Run one item's code against its testcase and record whether it misbehaves.

    Works on a shallow copy of ``item``; the input dict is not modified.

    Args:
        item: Record read via ``is_invalid_testcase``, ``extracted_code``,
            ``testcase_payload`` and ``expected_output``.
        timeout: Seconds allowed for executing the code.

    Returns:
        The copied record with ``has_detected``, ``actual_output`` and
        ``detection_reason`` filled in. ``raw_extracted_output`` is added
        only when the outputs differ. Items flagged as invalid testcases or
        missing a required field are not executed and get
        ``has_detected = False``.
    """
    result_item = item.copy()
    if result_item.get('is_invalid_testcase', False):
        result_item["has_detected"] = False
        result_item["actual_output"] = 'N/A'
        result_item["detection_reason"] = "Invalid testcase: failed on ground truth"
        return result_item

    extracted_code = result_item.get("extracted_code", "")
    testcase_payload = result_item.get("testcase_payload", "")
    expected_output = result_item.get("expected_output", "")
    if not (extracted_code and testcase_payload and expected_output):
        result_item["has_detected"] = False
        # Keep the result schema uniform across branches; setdefault leaves
        # any value already carried by the input record untouched.
        result_item.setdefault("actual_output", 'N/A')
        result_item["detection_reason"] = "Missing required fields"
        return result_item

    try:
        actual_output = execute_with_process_isolation(extracted_code, testcase_payload, timeout)
        outputs_match = compare_outputs(actual_output, expected_output)
        result_item["has_detected"] = not outputs_match
        result_item["actual_output"] = actual_output
        if not outputs_match:
            result_item["raw_extracted_output"] = actual_output
            result_item["detection_reason"] = f"Output mismatch: expected '{expected_output}', got '{actual_output}'"
        else:
            result_item["detection_reason"] = "Outputs match"
    except Exception as e:
        # A failure of the harness itself is counted as a detected problem.
        result_item["has_detected"] = True
        result_item["actual_output"] = f"Execution failed: {str(e)}"
        result_item["detection_reason"] = f"Execution error: {str(e)}"
    return result_item

def process_chunk_detection(chunk_data: list, chunk_id: int, output_dir: str, timeout: float = 1.0) -> dict:
    """Run detection on every item of one chunk and save the results to a chunk file.

    Results are written to ``chunk_<id>.jsonl`` (id zero-padded to four
    digits) inside ``output_dir``.

    Args:
        chunk_data: Items to process.
        chunk_id: Index of this chunk; used in the file name and progress text.
        output_dir: Directory that receives the chunk file.
        timeout: Per-item execution timeout passed to ``detect_code_problem``.

    Returns:
        A dict with ``chunk_id``, ``chunk_file``, ``stats`` (counts for
        ``detected``, ``not_detected``, ``processed``, ``invalid_testcases``)
        and ``elapsed`` seconds.
    """
    chunk_file = os.path.join(output_dir, f"chunk_{chunk_id:04d}.jsonl")
    print(f"Processing chunk {chunk_id} with {len(chunk_data)} items (subprocess mode)")

    stats = dict.fromkeys(('detected', 'not_detected', 'processed', 'invalid_testcases'), 0)
    results = []
    start_time = time.time()

    progress = tqdm(chunk_data, desc=f"Chunk {chunk_id}", leave=False)
    for entry in progress:
        outcome = detect_code_problem(entry, timeout)
        results.append(outcome)
        # Each item lands in exactly one bucket; invalid testcases take
        # precedence over the detection flag.
        if outcome.get('is_invalid_testcase', False):
            bucket = 'invalid_testcases'
        elif outcome['has_detected']:
            bucket = 'detected'
        else:
            bucket = 'not_detected'
        stats[bucket] += 1
        stats['processed'] += 1

    save_jsonl(chunk_file, results)

    elapsed = time.time() - start_time
    if elapsed > 0:
        speed = len(chunk_data) / elapsed
    else:
        speed = 0
    print(f"Chunk {chunk_id} completed: {stats['processed']} items in {elapsed:.1f}s ({speed:.1f} items/sec)")

    return {
        'chunk_id': chunk_id,
        'chunk_file': chunk_file,
        'stats': stats,
        'elapsed': elapsed,
    }

def detect_code_problems(
    data_file: str,
    output_file: str,
    max_execution_time: float = 1.0,
    num_processes: int = 8,
    max_items: int = None,
    chunk_size: int = None
):
    """Run problem detection over a JSONL file in parallel and write merged results.

    Every input item is tagged with ``is_invalid_testcase`` (True when its
    ``has_valid_testcase`` field is missing or falsy). All items, including
    the invalid ones, are split into chunks, processed by a pool of worker
    processes that write per-chunk files into a temporary directory, and the
    chunk files are then merged into ``output_file``. The temporary
    directory is always removed afterwards.

    Problems (missing input, bad arguments, unreadable data) are reported
    with a printed message and an early return rather than an exception.

    Args:
        data_file: Input JSONL path.
        output_file: Path of the merged JSONL output.
        max_execution_time: Per-item execution timeout in seconds.
        num_processes: Number of pool workers; must be at least 1.
        max_items: If given, only the first ``max_items`` items are processed.
        chunk_size: Items per chunk; must be at least 1 when given. Defaults
            to ``max(10, item_count // num_processes)``.

    Returns:
        None.
    """
    print(f"Processing file: {data_file}")
    print(f"Detecting code problems with {num_processes} processes (subprocess mode)")
    if not os.path.exists(data_file):
        print(f"Error: File {data_file} does not exist")
        return
    # Reject values that would otherwise surface further down as a
    # ZeroDivisionError (auto chunk size) or a ValueError from range().
    if num_processes < 1:
        print(f"Error: num_processes must be at least 1, got {num_processes}")
        return
    if chunk_size is not None and chunk_size < 1:
        print(f"Error: chunk_size must be at least 1, got {chunk_size}")
        return

    try:
        data = load_jsonl(data_file)
        print(f"Loaded {len(data)} items")
    except Exception as e:
        print(f"Error loading data: {e}")
        return
    if not data:
        print("No items found in input file!")
        return

    # Invalid testcases are flagged, not dropped: they stay in the output.
    num_invalids = 0
    for item in data:
        is_invalid = not item.get('has_valid_testcase', False)
        item['is_invalid_testcase'] = is_invalid
        if is_invalid:
            num_invalids += 1
    print(f"Flagged {num_invalids} of {len(data)} items as invalid testcases (kept in output)")

    items = data
    if max_items is not None:
        items = items[:max_items]
        print(f"Limited to first {len(items)} items")
    if chunk_size is None:
        chunk_size = max(10, len(items) // num_processes)
        print(f"Auto-calculated chunk size: {chunk_size}")

    # Create the output directory before doing any work, so an unwritable
    # destination is reported now instead of after all chunks have run.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as e:
            print(f"Error: Cannot create output directory {output_dir}: {e}")
            return

    temp_dir = tempfile.mkdtemp(prefix="code_detection_subprocess_")
    print(f"Using temp directory: {temp_dir}")
    try:
        chunks = [
            (items[start:start + chunk_size], chunk_id)
            for chunk_id, start in enumerate(range(0, len(items), chunk_size))
        ]
        print(f"Split data into {len(chunks)} chunks")
        estimated_time = len(items) * max_execution_time * 2 / num_processes / 60
        print(f"Estimated processing time: {estimated_time:.1f} minutes (subprocess mode)")

        chunk_results = []
        total_stats = {'detected': 0, 'not_detected': 0, 'processed': 0, 'invalid_testcases': 0}
        start_time = time.time()
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            future_to_chunk = {}
            for chunk_data, chunk_id in chunks:
                future = executor.submit(process_chunk_detection, chunk_data, chunk_id, temp_dir, max_execution_time)
                future_to_chunk[future] = chunk_id
            for future in tqdm(as_completed(future_to_chunk), total=len(chunks), desc="Processing chunks"):
                try:
                    result = future.result()
                    chunk_results.append(result)
                    for key, value in result['stats'].items():
                        total_stats[key] += value
                except Exception as e:
                    # A failed chunk is reported and skipped; the count check
                    # after merging makes the resulting gap visible.
                    chunk_id = future_to_chunk[future]
                    print(f"Error processing chunk {chunk_id}: {e}")
        elapsed_total = time.time() - start_time

        print("Merging chunk files...")
        chunk_files = [result['chunk_file'] for result in chunk_results if os.path.exists(result['chunk_file'])]
        if not chunk_files:
            print("Error: No chunk files found!")
            return
        merged_count = merge_jsonl_files(chunk_files, output_file)
        if merged_count != len(items):
            print(f"Warning: Data count mismatch! Expected {len(items)}, got {merged_count}")
        else:
            print("Data integrity verified: all items processed")

        processed = total_stats['processed']

        def percent(count):
            # Share of processed items, guarded against an empty total.
            return count / processed * 100 if processed else 0.0

        # Same zero-elapsed guard as the per-chunk speed calculation.
        average_speed = processed / elapsed_total if elapsed_total > 0 else 0
        print("\nFinal Statistics:")
        print(f"Total items processed: {processed}")
        print(f"Code problems detected: {total_stats['detected']} ({percent(total_stats['detected']):.1f}%)")
        print(f"No problems detected: {total_stats['not_detected']} ({percent(total_stats['not_detected']):.1f}%)")
        print(f"Invalid testcases: {total_stats['invalid_testcases']} ({percent(total_stats['invalid_testcases']):.1f}%)")
        print(f"Total time: {elapsed_total/60:.1f} minutes")
        print(f"Average speed: {average_speed:.1f} items/second")
        print("Processing mode: Subprocess isolation")
        print(f"Output saved to: {output_file}")
    finally:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Cleaned up temp directory: {temp_dir}")

def test_detection(data_file: str, item_index: int = 0):
    """Run detection on a single valid item of ``data_file`` and print the outcome.

    Only items whose ``has_valid_testcase`` field is truthy are considered.
    An ``item_index`` at or beyond the number of such items falls back to
    index 0. Detection runs with a 2 second timeout.

    Args:
        data_file: Input JSONL path.
        item_index: Position of the item within the valid items.
    """
    records = load_jsonl(data_file)
    candidates = [record for record in records if record.get('has_valid_testcase', False)]
    if len(candidates) == 0:
        print("No valid testcases found in the file!")
        return

    if not item_index < len(candidates):
        print(f"Index {item_index} out of range, using index 0")
        item_index = 0
    chosen = candidates[item_index]

    code_preview = chosen.get('extracted_code', '')[:100]
    print(f"Testing detection on item {item_index} (subprocess mode)")
    print(f"Extracted code: {code_preview}...")
    print(f"Testcase: {chosen.get('testcase_payload', '')}")
    print(f"Expected output: {chosen.get('expected_output', '')}")

    outcome = detect_code_problem(chosen, timeout=2.0)

    print("\nDetection result:")
    print(f"Has detected problem: {outcome['has_detected']}")
    print(f"Actual output: {outcome.get('actual_output', 'N/A')}")
    print(f"Detection reason: {outcome.get('detection_reason', 'N/A')}")

if __name__ == "__main__":
    # Command-line entry point: `detect` runs the full pipeline and `test`
    # checks a single item.
    cli_commands = {'detect': detect_code_problems, 'test': test_detection}
    fire.Fire(cli_commands)
