import json
import os
import fire
from tqdm import tqdm
import ast
from concurrent.futures import ProcessPoolExecutor, as_completed
import signal
import sys
import time
import hashlib
import tempfile
import shutil
from pathlib import Path
import multiprocessing as mp
import queue

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, every other line is decoded with
    ``json.loads``. The file is read as UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def save_jsonl(file_path, data):
    """Write ``data`` to ``file_path`` as JSON Lines (one document per line).

    Args:
        file_path: Destination path. Missing parent directories are created.
        data: Iterable of JSON-serialisable items.

    The file is written as UTF-8 and non-ASCII characters are kept verbatim
    (``ensure_ascii=False``).
    """
    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def merge_jsonl_files(file_list, output_file):
    """Concatenate several JSONL files into ``output_file``.

    Files are merged in sorted path order; paths that do not exist are
    skipped, as are blank lines.

    Args:
        file_list: Paths of the JSONL files to merge.
        output_file: Path of the merged file (overwritten if present).

    Returns:
        The number of non-blank lines written.
    """
    print(f"Merging {len(file_list)} files into {output_file}")
    total_items = 0
    with open(output_file, 'w', encoding='utf-8') as outf:
        for file_path in sorted(file_list):
            if not os.path.exists(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as inf:
                for line in inf:
                    if not line.strip():
                        continue
                    # The last line of a file may lack a trailing newline;
                    # without adding one it would be fused with the first
                    # record of the next file and corrupt both records.
                    outf.write(line if line.endswith('\n') else line + '\n')
                    total_items += 1
    print(f"Merged {total_items} items total")
    return total_items

def quick_safety_check(program: str) -> bool:
    """Return True when ``program`` contains none of the blocklisted substrings.

    The check is a plain substring search on the lower-cased source text, so
    it is case-insensitive and also matches inside longer identifiers,
    strings and comments.
    """
    lowered_source = program.lower()
    for needle in ("import os", "subprocess", "system(", "eval(", "exec(", "__import__"):
        if needle in lowered_source:
            return False
    return True

def validate_testcase_fast(testcase: str) -> tuple:
    """Check that ``testcase`` parses as a single Python expression.

    Args:
        testcase: Source text of the testcase expression.

    Returns:
        ``(is_valid, func_name)``. ``is_valid`` is False when the text cannot
        be parsed in ``eval`` mode. ``func_name`` is the called name when the
        expression is a call to a bare name (``foo(...)``) and None otherwise
        (for example ``obj.method(...)`` or ``1 + 2``).
    """
    try:
        parsed = ast.parse(testcase, mode='eval')
    except Exception:
        # Any parse failure marks the testcase invalid. ``Exception`` rather
        # than a bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        return False, None
    expression = parsed.body
    if isinstance(expression, ast.Call) and isinstance(expression.func, ast.Name):
        return True, expression.func.id
    return True, None

def execute_single_testcase(code: str, testcase: str, result_queue: mp.Queue):
    """Run ``code``, evaluate ``testcase`` against it and post the outcome.

    Intended as the target of a child process. One tuple is put on
    ``result_queue``: ``('success', repr(value))`` when the testcase
    expression evaluated, or ``('error', message)`` otherwise.

    Args:
        code: Source of the solution; executed with ``exec``.
        testcase: Source of a single expression evaluated in the namespace
            produced by ``code``.
        result_queue: Queue on which the outcome tuple is posted.
    """
    try:
        if not quick_safety_check(code):
            result_queue.put(('error', 'Dangerous code detected'))
            return
        is_valid, func_name = validate_testcase_fast(testcase)
        if not is_valid:
            result_queue.put(('error', 'Invalid testcase syntax'))
            return
        # SECURITY NOTE: ``exec`` runs the supplied code with full builtins.
        # The substring blocklist above is only a coarse filter, not a sandbox.
        execution_context = {"__builtins__": __builtins__}
        exec(code, execution_context)
        if func_name and func_name not in execution_context:
            result_queue.put(('error', f'Function {func_name} not defined'))
            return
        exec(f"__TEMP_RESULT__ = {testcase}", execution_context)
        result = execution_context["__TEMP_RESULT__"]
        result_queue.put(('success', repr(result)))
    except SystemExit as e:
        # exit()/sys.exit() in the executed code raises SystemExit, which is
        # not an Exception subclass. Without this handler nothing would be
        # posted and the consumer of the queue could only time out.
        result_queue.put(('error', f'SystemExit: {e.code}'))
    except Exception as e:
        result_queue.put(('error', str(e)))

def execute_with_process_isolation(code: str, testcase: str, timeout: float = 1.0) -> dict:
    """Run one testcase in a dedicated child process and collect its outcome.

    Args:
        code: Solution source passed to ``execute_single_testcase``.
        testcase: Testcase expression passed to ``execute_single_testcase``.
        timeout: Seconds to wait for the child to post a result.

    Returns:
        A dict with two keys. On success ``expected_output`` holds the posted
        result string and ``has_valid_testcase`` is True. Otherwise
        ``expected_output`` is ``'Error: ...'`` (``'Error: Timeout'`` when no
        result arrived in time) and ``has_valid_testcase`` is False.
    """
    process = None
    result_queue = None
    try:
        result_queue = mp.Queue()
        process = mp.Process(target=execute_single_testcase, args=(code, testcase, result_queue))
        process.start()
        try:
            status, result = result_queue.get(timeout=timeout)
        except queue.Empty:
            return {'expected_output': 'Error: Timeout', 'has_valid_testcase': False}
        # The result is already in hand; give the child a brief moment to
        # exit on its own before the cleanup below forces it.
        process.join(timeout=0.1)
        if status == 'success':
            return {'expected_output': result, 'has_valid_testcase': True}
        return {'expected_output': f'Error: {result}', 'has_valid_testcase': False}
    except Exception as e:
        return {'expected_output': f'Error: {str(e)}', 'has_valid_testcase': False}
    finally:
        # Runs on every path (success, timeout, unexpected error) so a started
        # child is never left behind and the queue's resources are released.
        try:
            if process is not None and process.is_alive():
                process.terminate()
                # Bounded wait: a child that ignores SIGTERM must not hang us.
                process.join(timeout=1.0)
                if process.is_alive():
                    process.kill()
                    process.join()
            if result_queue is not None:
                result_queue.close()
        except Exception:
            # Cleanup is best-effort; the outcome has already been decided
            # and must still be returned to the caller.
            pass

def process_chunk_safe(chunk_data: list, chunk_id: int, output_dir: str, timeout: float = 0.5) -> dict:
    """Evaluate every item of one chunk and write the results to a chunk file.

    For each item a copy is made, the ``is_valid_python`` and
    ``has_valid_testcase`` keys are dropped, and the ``solution`` /
    ``testcase_payload`` pair is run through
    ``execute_with_process_isolation``. Items missing either field are marked
    invalid without being executed. Results are saved to
    ``chunk_<id>.jsonl`` inside ``output_dir``.

    Returns:
        A dict with ``chunk_id``, ``chunk_file``, ``stats`` (counts of valid,
        invalid and processed items) and ``elapsed`` seconds.
    """
    chunk_file = os.path.join(output_dir, f"chunk_{chunk_id:04d}.jsonl")
    print(f"Processing chunk {chunk_id} with {len(chunk_data)} items (subprocess mode)")
    stats = {'valid_testcase': 0, 'invalid_testcase': 0, 'processed': 0}
    results = []
    start_time = time.time()
    for item in tqdm(chunk_data, desc=f"Chunk {chunk_id}", leave=False):
        record = item.copy()
        # Drop flags from the input so they are recomputed below.
        for stale_key in ("is_valid_python", "has_valid_testcase"):
            record.pop(stale_key, None)
        code = record.get("solution", "")
        test_case = record.get("testcase_payload", "")
        if code and test_case:
            record.update(execute_with_process_isolation(code, test_case, timeout))
        else:
            record["expected_output"] = "Error: Missing code or testcase"
            record["has_valid_testcase"] = False
        outcome_key = 'valid_testcase' if record['has_valid_testcase'] else 'invalid_testcase'
        stats[outcome_key] += 1
        stats['processed'] += 1
        results.append(record)
    save_jsonl(chunk_file, results)
    elapsed = time.time() - start_time
    if elapsed > 0:
        speed = len(chunk_data) / elapsed
    else:
        speed = 0
    print(f"Chunk {chunk_id} completed: {stats['processed']} items in {elapsed:.1f}s ({speed:.1f} items/sec)")
    return {'chunk_id': chunk_id, 'chunk_file': chunk_file, 'stats': stats, 'elapsed': elapsed}

def generate_expected_outputs_distributed(
    data_file: str,
    output_file: str,
    max_execution_time: float = 0.5,
    num_processes: int = 8,
    max_items: int = None,
    chunk_size: int = None
):
    """Compute expected outputs for the records of ``data_file`` in parallel.

    The records are split into chunks, each chunk is processed by
    ``process_chunk_safe`` in a ``ProcessPoolExecutor`` worker that writes a
    chunk file into a temporary directory, and the chunk files are merged
    into ``output_file``. The temporary directory is always removed.
    Problems are reported with ``print``; the function returns None.

    Args:
        data_file: Input JSONL file.
        output_file: Destination JSONL file; its parent directory is created
            if missing.
        max_execution_time: Per-testcase timeout in seconds, forwarded to
            ``process_chunk_safe``.
        num_processes: Number of pool workers; must be at least 1.
        max_items: If given, only the first ``max_items`` records are used.
        chunk_size: Records per chunk; must be at least 1 when given. When
            None it is ``max(100, len(data) // num_processes)``.
    """
    print(f"Processing file: {data_file}")
    print(f"Distributed processing with {num_processes} processes (subprocess mode)")
    # Reject values that would otherwise crash later with ZeroDivisionError
    # (num_processes) or a ValueError from range() (chunk_size).
    if num_processes < 1:
        print(f"Error: num_processes must be at least 1, got {num_processes}")
        return
    if chunk_size is not None and chunk_size < 1:
        print(f"Error: chunk_size must be at least 1, got {chunk_size}")
        return
    if not os.path.exists(data_file):
        print(f"Error: File {data_file} does not exist")
        return
    try:
        data = load_jsonl(data_file)
        print(f"Loaded {len(data)} items")
    except Exception as e:
        print(f"Error loading data: {e}")
        return
    if max_items is not None:
        data = data[:max_items]
        print(f"Limited to first {len(data)} items")
    if chunk_size is None:
        chunk_size = max(100, len(data) // num_processes)
        print(f"Auto-calculated chunk size: {chunk_size}")
    # Make sure the destination is writable *before* doing the work: a
    # missing directory would otherwise only surface at merge time, right
    # before the temp directory holding all results is deleted.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as e:
            print(f"Error: Cannot create output directory {output_dir}: {e}")
            return
    temp_dir = tempfile.mkdtemp(prefix="distributed_processing_subprocess_")
    print(f"Using temp directory: {temp_dir}")
    try:
        chunks = [
            (data[start:start + chunk_size], index)
            for index, start in enumerate(range(0, len(data), chunk_size))
        ]
        print(f"Split data into {len(chunks)} chunks")
        # Rough estimate only: assumes twice the timeout per item.
        estimated_time = len(data) * max_execution_time * 2 / num_processes / 60
        print(f"Estimated processing time: {estimated_time:.1f} minutes (subprocess mode)")
        chunk_results = []
        total_stats = {'valid_testcase': 0, 'invalid_testcase': 0, 'processed': 0}
        start_time = time.time()
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            future_to_chunk = {}
            for chunk_data, chunk_id in chunks:
                future = executor.submit(process_chunk_safe, chunk_data, chunk_id, temp_dir, max_execution_time)
                future_to_chunk[future] = chunk_id
            for future in tqdm(as_completed(future_to_chunk), total=len(chunks), desc="Processing chunks"):
                try:
                    result = future.result()
                    chunk_results.append(result)
                    for key, value in result['stats'].items():
                        total_stats[key] += value
                except Exception as e:
                    # A failed chunk is reported and skipped; the count check
                    # after merging flags the resulting shortfall.
                    chunk_id = future_to_chunk[future]
                    print(f"Error processing chunk {chunk_id}: {e}")
        elapsed_total = time.time() - start_time
        print("Merging chunk files...")
        chunk_files = [result['chunk_file'] for result in chunk_results if os.path.exists(result['chunk_file'])]
        if chunk_files:
            merged_count = merge_jsonl_files(chunk_files, output_file)
            if merged_count != len(data):
                print(f"Warning: Data count mismatch! Expected {len(data)}, got {merged_count}")
            else:
                print("Data integrity verified: all items processed")
        else:
            print("Error: No chunk files found!")
            return
        # Guard the ratios so a zero count or zero elapsed time cannot raise.
        processed = total_stats['processed']
        valid_pct = total_stats['valid_testcase'] / processed * 100 if processed else 0.0
        invalid_pct = total_stats['invalid_testcase'] / processed * 100 if processed else 0.0
        avg_speed = processed / elapsed_total if elapsed_total > 0 else 0.0
        print("\nFinal Statistics:")
        print(f"Total items: {processed}")
        print(f"Valid testcases: {total_stats['valid_testcase']} ({valid_pct:.1f}%)")
        print(f"Invalid testcases: {total_stats['invalid_testcase']} ({invalid_pct:.1f}%)")
        print(f"Total time: {elapsed_total/60:.1f} minutes")
        print(f"Average speed: {avg_speed:.1f} items/second")
        print("Processing mode: Subprocess isolation")
        print(f"Output saved to: {output_file}")
    finally:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Cleaned up temp directory: {temp_dir}")

def test_distributed(data_file: str, max_items: int = 10):
    """Smoke-test the distributed pipeline on the first ``max_items`` records.

    Runs ``generate_expected_outputs_distributed`` with two worker processes
    into a throwaway file in the system temp directory, prints the item count
    and the first three results, and removes the file afterwards.

    Args:
        data_file: Input JSONL file to sample from.
        max_items: Number of leading records to process.
    """
    # tempfile.gettempdir() instead of a hard-coded "/tmp" so this also works
    # on platforms or setups where /tmp does not exist.
    output_file = os.path.join(
        tempfile.gettempdir(),
        f"test_distributed_subprocess_{int(time.time())}.jsonl",
    )
    print(f"Testing distributed processing (subprocess mode) with {max_items} items")
    try:
        generate_expected_outputs_distributed(
            data_file=data_file,
            output_file=output_file,
            max_items=max_items,
            num_processes=2
        )
        if not os.path.exists(output_file):
            print("Test failed: no output file generated")
            return
        result_data = load_jsonl(output_file)
        print(f"Test completed: {len(result_data)} items in output file")
        print("\nSample results:")
        for i, item in enumerate(result_data[:3]):
            print(f"Item {i+1}:")
            print(f"  Expected output: {item.get('expected_output', 'N/A')}")
            print(f"  Valid testcase: {item.get('has_valid_testcase', 'N/A')}")
    finally:
        # Remove the throwaway file even when loading or printing fails.
        if os.path.exists(output_file):
            os.remove(output_file)

# Command-line entry point: expose the two workflows as python-fire sub-commands.
if __name__ == "__main__":
    cli_commands = {
        'generate': generate_expected_outputs_distributed,
        'test': test_distributed,
    }
    fire.Fire(cli_commands)
