import asyncio
import io
import traceback
import threading
from contextlib import redirect_stdout
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor, TimeoutError
from concurrent.futures.process import BrokenProcessPool
from z3 import *

# Z3 "ctx-solver-simplify" tactic, built once at import time and applied by
# _canonical() to each NNF-normalized assertion before it is rendered as text.
_simplify = z3.Tactic("ctx-solver-simplify")

# Global executor for Z3 operations. The work is CPU-bound and executes
# arbitrary code snippets (via exec in _extract_assertions), so it is isolated
# in a separate process. Created lazily by get_z3_executor() and replaced
# whenever it has been shut down or marked broken.
Z3_EXECUTOR = None
Z3_EXECUTOR_LOCK = threading.Lock()  # Guards creation/replacement of Z3_EXECUTOR


def _executor_is_dead(executor) -> bool:
    """Return True if *executor* can no longer accept new work."""
    # concurrent.futures executors expose no public "is alive" query, so read
    # the private flags their submit() methods check. ProcessPoolExecutor uses
    # `_shutdown_thread` and `_broken`; `_shutdown` exists only on
    # ThreadPoolExecutor, so checking it alone never detects a dead process pool.
    return bool(
        getattr(executor, "_shutdown_thread", False)
        or getattr(executor, "_broken", False)
        or getattr(executor, "_shutdown", False)
    )


def get_z3_executor():
    """Return the shared Z3 ProcessPoolExecutor, creating a fresh one if needed.

    A new single-worker executor is created when none exists yet, or when the
    current one has been shut down or marked broken (e.g. after a worker
    process died). Submitting to such an executor would fail immediately, so
    it is discarded (best-effort, non-blocking) and replaced.

    Returns:
        The live ProcessPoolExecutor stored in ``Z3_EXECUTOR``.
    """
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        current = Z3_EXECUTOR
        if current is not None and not _executor_is_dead(current):
            return current

        if current is not None:
            print("Detected Z3_EXECUTOR is dead or shutdown. Attempting aggressive cleanup...")
            try:
                # wait=False: never block while holding the global lock.
                current.shutdown(wait=False, cancel_futures=True)
            except Exception as e:
                print(f"Error during Z3_EXECUTOR previous shutdown attempt: {e}")

        print("Initializing new Z3 ProcessPoolExecutor...")
        Z3_EXECUTOR = ProcessPoolExecutor(max_workers=1)
        return Z3_EXECUTOR

def _canonical(expr: z3.ExprRef) -> str:
    """Return a canonical s-expression string for *expr*.

    Steps: 1) push negations inward with the ``nnf`` tactic, 2) simplify with
    the module-level ``_simplify`` tactic, 3) render as an s-expression in
    which the operands of ``and``/``or`` are sorted, so that A∨B and B∨A yield
    the same text.
    """
    nnf = z3.Tactic("nnf")(expr).as_expr()  # push Not inward
    simplified = _simplify(nnf).as_expr()

    def _head(e) -> str:
        # Indexed operators (e.g. Extract(7, 0, x), AtMost(..., k)) carry decl
        # parameters that the bare name drops; without them distinct terms
        # would render identically and compare as equal.
        decl = e.decl()
        name = decl.name()
        try:
            params = decl.params()
        except Exception:
            # Older z3 builds / unsupported parameter kinds: fall back to the name.
            params = []
        if not params:
            return name
        return f"(_ {name} {' '.join(str(p) for p in params)})"

    def _sort(e) -> str:
        # Quantifiers and bound variables are not applications (they have no
        # decl/args), so render them with Z3's own s-expression.
        if not z3.is_app(e) or e.num_args() == 0:
            return e.sexpr()
        rendered = [_sort(a) for a in e.children()]
        if e.decl().name() in {"and", "or"}:
            rendered.sort()  # commutative: order-insensitive text
        return f"({_head(e)} {' '.join(rendered)})"

    return _sort(simplified)

def _extract_assertions(code: str, tag: str):
    """Execute *code* and return the assertion list of the last Solver it leaves behind.

    The snippet runs in a fresh namespace pre-populated with the ``z3`` module
    and all of its names; anything it prints to stdout is discarded. "Last"
    means the last ``z3.Solver`` found when iterating that namespace in
    insertion order.

    Args:
        code: Python source expected to build a ``z3.Solver``.
        tag: Short label ("src"/"tgt" in this module) used for the namespace
            ``__name__``, the compiled filename and error messages.

    Returns:
        A list of the solver's assertion expressions.

    Raises:
        ValueError: if *code* fails to compile, raises, or calls
            ``exit()``/``sys.exit()`` while executing, or if it defines no
            ``z3.Solver``.
    """
    # SECURITY: *code* is executed with full interpreter privileges. In this
    # module it can originate from `solution_str` (see compute_single_score_async)
    # and is reached via _same_z3_worker in a worker process: that isolates
    # crashes but is NOT a sandbox. Do not feed untrusted code here without one.
    ns = {"__name__": f"z3_{tag}", "z3": z3, **vars(z3)}
    # Temporarily redirect stdout to suppress prints from the snippet / Z3.
    with io.StringIO() as buf, redirect_stdout(buf):
        try:
            exec(compile(code, f"<{tag}>", "exec"), ns)
        except (Exception, SystemExit) as e:
            # SystemExit is not an Exception subclass; without listing it, a
            # snippet calling exit() would bypass every `except Exception`
            # handler in the scoring pipeline.
            raise ValueError(f"Error compiling or executing Z3 code ({tag}): {e}\nCode:\n{code}") from e

    solvers = [v for v in ns.values() if isinstance(v, z3.Solver)]
    if not solvers:
        raise ValueError(f"No z3.Solver instance in {tag}")
    return list(solvers[-1].assertions())

def constraints_covered(src_code: str, tgt_code: str):
    """Compare the canonicalized assertion sets of two Z3 programs.

    Returns:
        ``(covered, missing, additions)``: ``missing`` is the sorted list of
        canonical constraints found in *src_code* but not in *tgt_code*,
        ``additions`` the sorted list found only in *tgt_code*, and ``covered``
        is True exactly when both lists are empty.
    """
    def _canon_set(code, tag):
        return {_canonical(assertion) for assertion in _extract_assertions(code, tag)}

    source_set = _canon_set(src_code, "src")
    target_set = _canon_set(tgt_code, "tgt")

    missing = sorted(source_set - target_set)
    additions = sorted(target_set - source_set)
    covered = not missing and not additions
    return covered, missing, additions

def _same_z3_worker(original_code: str, translated_code: str) -> bool:
    """Compare two Z3 programs; submitted to the Z3 process pool by _z3_check_subprocess.

    Returns:
        True if ``constraints_covered`` reports identical canonical constraint
        sets. False on any difference, or on any Exception/SystemExit raised
        while comparing; such errors are logged, not re-raised.
    """
    try:
        ok, missing, addition = constraints_covered(original_code, translated_code)
    except (Exception, SystemExit) as e:
        # SystemExit (e.g. `exit()` inside an exec'd snippet) is not an
        # Exception subclass and would otherwise slip past the `except
        # Exception` handlers here and in this module's callers.
        print(f"Error in _same_z3_worker: {e}")
        traceback.print_exc()
        return False

    if ok:
        print("Z3 comparison: All constraints covered.")
        return True

    def _preview(items, limit=5):
        # Append "..." only when the list was actually truncated.
        return f"{items[:limit]}..." if len(items) > limit else f"{items}"

    print(f"Z3 comparison: Missing constraints: {_preview(missing)}")
    print(f"Z3 comparison: Additional constraints: {_preview(addition)}")
    return False

def _retire_z3_executor(executor, *, kill_workers: bool = False) -> None:
    """Stop using *executor*: unpublish it if still shared, optionally kill its workers, shut it down."""
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        # Only reset the global if it still points at the executor that failed
        # for this caller; another thread may already have installed a healthy
        # replacement, which must not be torn down.
        if Z3_EXECUTOR is executor:
            Z3_EXECUTOR = None
    if kill_workers:
        # A task already handed to a worker cannot be cancelled; terminating
        # the worker is the only way to stop it. `_processes` is a private
        # CPython attribute (None after shutdown), hence the defensive access.
        try:
            for proc in list((getattr(executor, "_processes", None) or {}).values()):
                proc.terminate()
        except Exception as kill_err:
            print(f"Error terminating Z3 worker process: {kill_err}")
    try:
        # wait=False: a broken or killed pool must not block this thread.
        executor.shutdown(wait=False, cancel_futures=True)
    except Exception as shutdown_err:
        print(f"Error during Z3_EXECUTOR shutdown on crash: {shutdown_err}")


def _z3_check_subprocess(original_code: str, translated_code: str, timeout: float = 120) -> bool:
    """Run the Z3 comparison in the shared worker process, bounded by *timeout* seconds.

    If the pool turns out to be broken (worker crash / broken pipe), it is
    retired and the comparison is retried once on a fresh executor. On
    timeout, the task is cancelled if still queued; if it was already handed
    to a worker, the worker is killed and the executor replaced, so a runaway
    Z3 call cannot occupy the single worker for later callers.

    Returns:
        The worker's boolean result, or False on timeout, a second crash, or
        any other error.
    """
    # NOTE(review): the timeout clock starts at result(), so time spent queued
    # behind other threads' tasks (the pool has max_workers=1 while
    # compute_score_batch fans out up to 64 threads) counts against it.
    for attempt in range(2):
        executor = get_z3_executor()
        try:
            future = executor.submit(_same_z3_worker, original_code, translated_code)
            result = future.result(timeout=timeout)
        except TimeoutError:
            print(f"Z3 comparison operation timed out ({timeout} seconds).")
            if not future.cancel():
                # Already running in (or dispatched to) the worker.
                _retire_z3_executor(executor, kill_workers=True)
            return False
        except (BrokenProcessPool, BrokenPipeError) as e:
            _retire_z3_executor(executor)
            if attempt == 0:
                print(f"Z3 worker pool crashed or broken pipe detected: {e}. Attempting to reinitialize and retry.")
                continue
            print(f"Z3 retry failed after crash: {e}")
            traceback.print_exc()
            return False
        except Exception as e:
            print(f"Unexpected exception during Z3 subprocess execution: {e}")
            traceback.print_exc()
            return False

        if attempt == 0:
            print("Z3 comparison completed successfully.")
        else:
            print("Z3 comparison retry successful.")
        return result
    return False  # Not reached: every path in the second attempt returns.

def _z3_format_check(z3_code: str) -> bool:
    """Check if the Z3 code snippet is syntactically valid and executable."""
    if not z3_code:
        print("Empty Z3 code for format check.")
        return False
    
    # Try to compile the code. This checks for basic syntax errors.
    try:
        compile(z3_code, '<z3_code_string>', 'exec')
        # We don't execute it here as it requires Z3 context and might be slow.
        # Execution is handled by _z3_check_subprocess which runs in a process pool.
        return True
    except SyntaxError as e:
        print(f"Z3 code syntax error: {e}")
        return False
    except Exception as e:
        print(f"Other Z3 code compilation error: {e}")
        return False

async def compute_single_score_async(solution_str: str, ground_truth: str) -> float:
    """Score one solution against its ground-truth Z3 program.

    The candidate Z3 code is everything after the last "Final Z3 Code:" marker
    in *solution_str*; it is compared with *ground_truth* via
    _z3_check_subprocess.

    Returns:
        1.0 if the Z3 check reports a match; 0.1 if the marker is present but
        the check reports a mismatch or fails; 0.0 if either input is empty,
        the marker is missing, or an unexpected error occurs.
    """
    if not solution_str or not ground_truth:
        print("Empty solution or ground truth provided for scoring.")
        return 0.0

    try:
        if "Final Z3 Code:" not in solution_str:
            return 0.0

        back_code = solution_str.split("Final Z3 Code:")[-1].strip()
        # _z3_check_subprocess blocks on a process-pool future; keep it off the event loop.
        z3_match = await asyncio.to_thread(_z3_check_subprocess, ground_truth, back_code)
        if not z3_match:
            # NOTE(review): 0.1 rather than 0.0 presumably rewards an answer that
            # at least carries the marker — confirm. The log reports the value returned.
            print("Score: 0.1 (Z3 code mismatch).")
            return 0.1
        print("Z3 code matched.")
        return 1.0

    except Exception as e:
        print(f"Critical error in compute_single_score_async: {e}")
        traceback.print_exc()
        return 0.0

def process_solution_worker(index: int, solution_str: str, ground_truth: str) -> tuple[int, float]:
    """Score one solution synchronously; intended to run in a ThreadPoolExecutor thread.

    Runs compute_single_score_async to completion on an event loop owned by
    this call.

    Returns:
        ``(index, score)``; the score is 0.1 if running the coroutine raises.
    """
    try:
        # asyncio.run creates a fresh loop for this thread, runs the coroutine,
        # cancels any leftover tasks, shuts down async generators and the
        # default executor used by asyncio.to_thread, then closes the loop and
        # detaches it from the thread — the cleanup previously hand-rolled.
        score = asyncio.run(compute_single_score_async(solution_str, ground_truth))
        return index, score
    except Exception as e:
        print(f"Error in worker process for solution {index}: {str(e)}")
        traceback.print_exc()
        return index, 0.1  # Return a minimal score on unhandled worker error


def compute_score_batch(solution_strs: list[str], ground_truths: list[str]) -> list[float]:
    """Compute scores for a batch of solutions using a ThreadPoolExecutor.

    Each pair whose solution and ground truth are both non-empty is scored by
    process_solution_worker on a pool thread. Other pairs, and any pair whose
    future raises, keep the default score 0.0. The global Z3 executor is shut
    down once the pool has been used, even if scoring is interrupted.

    Args:
        solution_strs: Solutions to score.
        ground_truths: Reference Z3 programs, aligned by index with *solution_strs*.

    Returns:
        One score per input pair, in input order.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(solution_strs) != len(ground_truths):
        raise ValueError(f"Mismatch in input lengths: {len(solution_strs)} solutions vs {len(ground_truths)} ground truths")

    results = [0.0] * len(solution_strs)

    valid_pairs = [
        (idx, solution_str, ground_truth)
        for idx, (solution_str, ground_truth) in enumerate(zip(solution_strs, ground_truths))
        if solution_str and ground_truth
    ]

    if not valid_pairs:
        print("No valid solution-ground truth pairs to process.")
        return results

    # Threads mostly block waiting on the shared Z3 process pool, so a modest
    # cap is enough.
    max_workers = min(64, len(valid_pairs))
    print(f"Using {max_workers} threads in ThreadPoolExecutor for processing {len(valid_pairs)} solutions.")

    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to its input index so failures can be attributed.
            future_to_idx = {
                executor.submit(process_solution_worker, idx, solution_str, ground_truth): idx
                for idx, solution_str, ground_truth in valid_pairs
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    _, score = future.result()
                except Exception as e:
                    # This pair keeps its 0.0 default; other results are unaffected.
                    print(f"Error retrieving result from future for solution {idx}: {e}")
                    traceback.print_exc()
                    continue
                results[idx] = score
    finally:
        # Release the Z3 worker process even if the batch is interrupted, so it
        # does not linger after this call.
        cleanup_z3_executor()

    return results

def cleanup_z3_executor():
    """Shut down and forget the global Z3 executor, if one exists.

    The executor is detached from ``Z3_EXECUTOR`` under the lock and shut down
    outside it, so a slow shutdown never blocks threads calling
    get_z3_executor(). Errors during shutdown are logged, not raised.
    """
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        executor, Z3_EXECUTOR = Z3_EXECUTOR, None
    if executor is None:
        return

    print("Shutting down global Z3_EXECUTOR...")
    try:
        # Outstanding work here typically belongs to a caller that already gave
        # up (e.g. a Z3 call that timed out but kept running in the worker);
        # shutdown(wait=True) would block until it finishes, possibly forever,
        # so terminate the workers first. This also aborts any in-flight work
        # of other callers. `_pending_work_items` / `_processes` are private
        # CPython attributes, read defensively.
        if getattr(executor, "_pending_work_items", None):
            for proc in list((getattr(executor, "_processes", None) or {}).values()):
                proc.terminate()
        executor.shutdown(wait=True, cancel_futures=True)
    except Exception as e:
        print(f"Error during final Z3_EXECUTOR shutdown: {e}")