import re
import asyncio
import io
import signal
import time
import os
import subprocess
import tempfile
from contextlib import redirect_stdout, contextmanager, asynccontextmanager
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor, TimeoutError
from concurrent.futures.process import BrokenProcessPool
import traceback
import sys
from z3 import *
from openai import AsyncOpenAI
from transformers import AutoTokenizer
import textwrap
import multiprocessing
import aiohttp # Import aiohttp for session management
import json
import random
import threading
import nltk
from nltk.translate.bleu_score import sentence_bleu
from collections import defaultdict
from typing import List, Dict  # 添加类型注解
from vllm import LLM, SamplingParams

# Make CUDA device indices 0-15 visible to this process ("0,1,...,15").
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(16)))

# Simplification tactic built once at import time; _canonical() reuses it.
_simplify = z3.Tactic("ctx-solver-simplify")

# Module-level tokenizers, filled in lazily by initialize_tokenizers().
tokenizer1 = None
tokenizer2 = None

# Local checkpoints: MODEL_PATH drives the stage-1 back-translation to Z3,
# Solving_MODEL_PATH drives the stage-2 solving attempts.
MODEL_PATH = "models--deepseek-ai--DeepSeek-R1-Distill-Qwen-32B"
Solving_MODEL_PATH = "models--Qwen--Qwen3-30B-A3B"

# Z3 comparisons are CPU-bound and execute generated code, so they run in a
# separate process pool. get_z3_executor() creates it on demand and
# compute_score_batch() shuts it down when a batch is finished.
Z3_EXECUTOR = None
# Serializes creation/replacement of Z3_EXECUTOR across threads.
Z3_EXECUTOR_LOCK = threading.Lock()

def _z3_executor_unusable(executor) -> bool:
    """Return True if *executor* has been shut down or its worker pool is broken."""
    # ProcessPoolExecutor has no public health check. In CPython it sets the
    # private flag `_shutdown_thread` in shutdown() and `_broken` (a message
    # string) once a worker process dies. `_shutdown` is the flag name used by
    # ThreadPoolExecutor; it is checked too so either executor type is handled.
    return any(
        getattr(executor, flag, False)
        for flag in ("_shutdown_thread", "_broken", "_shutdown")
    )


def get_z3_executor():
    """Return the shared Z3 ProcessPoolExecutor, (re)creating it when needed.

    A new single-worker pool is created when none exists yet, or when the
    current one can no longer accept work. That happens when it was shut down,
    or when a worker process died and the pool was marked broken. The stale
    pool is shut down without waiting (pending futures cancelled) before it is
    replaced.

    Returns:
        ProcessPoolExecutor: the live shared executor (also stored in
        the module-level ``Z3_EXECUTOR``).
    """
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        if Z3_EXECUTOR is None or _z3_executor_unusable(Z3_EXECUTOR):
            if Z3_EXECUTOR is not None:
                print("Detected Z3_EXECUTOR is dead or shutdown. Attempting aggressive cleanup...")
                try:
                    # wait=False: a dead/broken pool must not block the caller.
                    Z3_EXECUTOR.shutdown(wait=False, cancel_futures=True)
                except Exception as e:
                    print(f"Error during Z3_EXECUTOR previous shutdown attempt: {e}")

            print("Initializing new Z3 ProcessPoolExecutor...")
            Z3_EXECUTOR = ProcessPoolExecutor(max_workers=1)
        return Z3_EXECUTOR

# --- Tokenizer Initialization ---
def initialize_tokenizers():
    """Load the two module-level tokenizers on first use and return them.

    Each tokenizer is loaded only while its global is still ``None``, so
    repeated calls reuse what is already loaded. A failed load is printed and
    leaves that global as ``None``.

    Returns:
        tuple: ``(tokenizer1, tokenizer2)`` -- the Qwen3-30B-A3B tokenizer and
        the DeepSeek-R1 tokenizer; either entry may be ``None``.
    """
    global tokenizer1, tokenizer2

    # Qwen3 tokenizer. trust_remote_code=True lets the checkpoint ship custom
    # tokenizer code.
    if tokenizer1 is None:
        try:
            tokenizer1 = AutoTokenizer.from_pretrained("models--Qwen--Qwen3-30B-A3B", trust_remote_code=True)
            print("Tokenizer1 initialized successfully.")
        except Exception as exc:
            print(f"Error initializing tokenizer1: {exc}")
            tokenizer1 = None

    # DeepSeek-R1 tokenizer.
    if tokenizer2 is None:
        try:
            tokenizer2 = AutoTokenizer.from_pretrained("models--deepseek-ai--DeepSeek-R1", trust_remote_code=True)
            print("Tokenizer2 initialized successfully.")
        except Exception as exc:
            print(f"Error initializing tokenizer2: {exc}")
            tokenizer2 = None

    return (tokenizer1, tokenizer2)

def count_tokens(tokenizer, text):
    """Return how many tokens *tokenizer* produces for *text*.

    Returns 0 when *tokenizer* is ``None``, or when encoding raises (the error
    is printed).
    """
    if tokenizer is None:
        return 0
    try:
        return len(tokenizer.encode(text))
    except Exception as exc:
        print(f"Error counting tokens: {exc}")
        return 0

def _canonical(expr: z3.ExprRef) -> str:
    """Render *expr* as a canonical s-expression string.

    Steps:
      1. apply the "nnf" tactic (negation normal form, pushing ``Not`` inward);
      2. simplify with the shared ctx-solver-simplify tactic (``_simplify``);
      3. print the result, sorting the operands of every ``and``/``or`` node so
         that commutative reorderings (A∨B vs. B∨A) give identical text.
    """
    in_nnf = z3.Tactic("nnf")(expr).as_expr()
    reduced = _simplify(in_nnf).as_expr()

    def _render(node) -> str:
        # Nodes without arguments (e.g. a Bool variable or true/false) print as-is.
        if node.num_args() == 0:
            return node.sexpr()
        name = node.decl().name()
        parts = [_render(child) for child in node.children()]
        if name in {"and", "or"}:
            parts.sort()
        return f"({name} {' '.join(parts)})"

    return _render(reduced)

def _extract_assertions(code: str, tag: str):
    """Execute *code* and return the assertions of the last Solver it leaves behind.

    Args:
        code: Python source expected to create a ``z3.Solver`` and add
            constraints to it.
        tag: Label (the callers pass "src"/"tgt"). It is used in the compiled
            filename, the exec namespace's ``__name__`` and error messages.

    Returns:
        list: the assertions of the last ``z3.Solver`` found among the
        namespace's values after execution (dict insertion order).

    Raises:
        ValueError: if the code fails to compile, raises while executing,
            calls ``exit()``/``sys.exit()``, or defines no ``z3.Solver``.

    SECURITY NOTE: the target code passed in by the stage-1 scorer is
    LLM-generated and is exec'd here with full builtins. Running it in the Z3
    worker process contains crashes and hangs, but it is NOT a sandbox.
    """
    # Swallow anything the snippet prints (e.g. `print(solver.check())`).
    with io.StringIO() as buf, redirect_stdout(buf):
        ns = {"__name__": f"z3_{tag}", "z3": z3, **vars(z3)}
        try:
            exec(compile(code, f"<{tag}>", "exec"), ns)
        except (Exception, SystemExit) as e:
            # SystemExit (exit()/sys.exit() in the snippet) is not an Exception
            # subclass. A process-pool worker would send it back and
            # future.result() would re-raise it in the scoring process,
            # terminating it. Treat it as an ordinary execution failure instead.
            raise ValueError(f"Error compiling or executing Z3 code ({tag}): {e}\nCode:\n{code}") from e

    solvers = [v for v in ns.values() if isinstance(v, z3.Solver)]
    if not solvers:
        raise ValueError(f"No z3.Solver instance in {tag}")
    return list(solvers[-1].assertions())

def constraints_covered(src_code: str, tgt_code: str):
    """Compare the canonicalized Z3 assertions produced by two code snippets.

    Returns:
        tuple: ``(covered, missing, addition)``, where
        - ``missing`` is the sorted list of canonical constraints found only in
          *src_code*;
        - ``addition`` is the sorted list of those found only in *tgt_code*;
        - ``covered`` is True only when both lists are empty, i.e. the two
          canonical constraint sets are identical.

    Exceptions from ``_extract_assertions`` (ValueError) and ``_canonical``
    propagate to the caller.
    """
    src_set = {_canonical(a) for a in _extract_assertions(src_code, "src")}
    tgt_set = {_canonical(a) for a in _extract_assertions(tgt_code, "tgt")}

    missing = sorted(src_set - tgt_set)
    addition = sorted(tgt_set - src_set)
    covered = not missing and not addition
    return covered, missing, addition

def _same_z3_worker(original_code: str, translated_code: str) -> bool:
    """Process-pool entry point: do both snippets assert the same canonical constraints?

    Returns True on an exact match and False otherwise. Any exception is
    printed (with traceback) and mapped to False, so the pool's future
    resolves with a plain bool.
    """
    try:
        covered, missing, addition = constraints_covered(original_code, translated_code)
        if not covered:
            # Log only the first five entries of each difference.
            print(f"Z3 comparison: Missing constraints: {missing[:5]}...")
            print(f"Z3 comparison: Additional constraints: {addition[:5]}...")
            return False
        print("Z3 comparison: All constraints covered.")
        return True
    except Exception as err:
        print(f"Error in _same_z3_worker: {err}")
        traceback.print_exc()
        return False

def _discard_z3_executor(stale, kill: bool = False, reason: str = "crash") -> None:
    """Shut down *stale* and clear ``Z3_EXECUTOR`` if it still points at it.

    Args:
        stale: the executor that failed, broke, or timed out.
        kill: terminate its worker processes first. This is needed after a
            timeout: ``Future.cancel()`` cannot stop a call that is already
            running, so the single worker would stay busy and every later
            submission would queue behind it (and time out as well).
        reason: short label used in the error message if shutdown fails.
    """
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        if kill:
            # `_processes` (pid -> multiprocessing.Process) is CPython-private.
            # It is read *before* shutdown(), which resets it to None. Older
            # Python versions have no public API to stop busy workers.
            for proc in list((getattr(stale, "_processes", None) or {}).values()):
                try:
                    proc.terminate()
                    proc.join(timeout=5)
                except Exception as kill_err:
                    print(f"Error terminating Z3 worker process: {kill_err}")
        try:
            # Don't block on a pool whose workers were just killed.
            stale.shutdown(wait=not kill, cancel_futures=True)
        except Exception as shutdown_err:
            print(f"Error during Z3_EXECUTOR shutdown on {reason}: {shutdown_err}")
        finally:
            # Only reset the global if nobody has already replaced it.
            if Z3_EXECUTOR is stale:
                Z3_EXECUTOR = None


def _z3_check_subprocess(original_code: str, translated_code: str, timeout: float = 120) -> bool:
    """Compare two Z3 snippets in the shared worker process, with a timeout.

    Args:
        original_code: reference Z3 program (ground truth).
        translated_code: candidate Z3 program to compare against it.
        timeout: seconds to wait for each attempt (default 120).

    Returns:
        bool: the verdict of ``_same_z3_worker``. Returns False on timeout, on
        a second failure after a pool crash, or on any other error. After a
        timeout or a failed retry, the worker is killed and the pool discarded
        so the next call starts with a fresh one.
    """
    executor = get_z3_executor()
    try:
        future = executor.submit(_same_z3_worker, original_code, translated_code)
        result = future.result(timeout=timeout)
        print("Z3 comparison completed successfully.")
        return result
    except TimeoutError:
        print(f"Z3 comparison operation timed out ({timeout} seconds).")
        future.cancel()  # only effective if the task had not started yet
        # The worker may still be stuck in this task; kill it so later calls
        # don't queue behind the hung job.
        _discard_z3_executor(executor, kill=True, reason="timeout")
        return False
    except (BrokenProcessPool, BrokenPipeError) as e:
        print(f"Z3 worker pool crashed or broken pipe detected: {e}. Attempting to reinitialize and retry.")
        _discard_z3_executor(executor, reason="crash")

        # Retry once with a freshly initialized executor.
        try:
            executor = get_z3_executor()
            future = executor.submit(_same_z3_worker, original_code, translated_code)
            result = future.result(timeout=timeout)
            print("Z3 comparison retry successful.")
            return result
        except (Exception, SystemExit) as retry_error:
            print(f"Z3 retry failed after crash: {retry_error}")
            traceback.print_exc()
            # The retry's worker may be hung or broken too; don't reuse it.
            _discard_z3_executor(executor, kill=True, reason="retry failure")
            return False
    except (Exception, SystemExit) as e:
        # SystemExit here can only come from the worker, re-raised by
        # future.result(); it must not terminate the scoring process.
        print(f"Unexpected exception during Z3 subprocess execution: {e}")
        traceback.print_exc()
        return False


def _z3_format_check(z3_code: str) -> bool:
    """Check if the Z3 code snippet is syntactically valid and executable."""
    if not z3_code:
        print("Empty Z3 code for format check.")
        return False
    
    # Try to compile the code. This checks for basic syntax errors.
    try:
        compile(z3_code, '<z3_code_string>', 'exec')
        # We don't execute it here as it requires Z3 context and might be slow.
        # Execution is handled by _z3_check_subprocess which runs in a process pool.
        return True
    except SyntaxError as e:
        print(f"Z3 code syntax error: {e}")
        return False
    except Exception as e:
        print(f"Other Z3 code compilation error: {e}")
        return False

def check_symbol(content, variable_names):
    """Return True when *content* mentions none of *variable_names*.

    Matching is plain substring containment, so a name also matches inside a
    longer token (e.g. "A1" inside "A10"). Returns True as well when either
    argument is empty/falsy. The number of names found is printed.
    """
    if not (content and variable_names):
        print("No content or No variable names")
        return True

    found = sum(name in content for name in variable_names)
    print(f"Variable count in content: {found}")
    return not found

def _build_stage_1_prompt(solution_str: str) -> str:
    """Return the back-translation prompt for *solution_str*, or "" if it is malformed.

    A solution is usable only if it contains both the "Natural Language
    Content:" and "Definitions:" markers. The content is the text after the
    last "Natural Language Content:" up to the next "Definitions:". The
    definitions are the text after the last "Definitions:".
    """
    if "Natural Language Content:" not in solution_str or "Definitions:" not in solution_str:
        return ""

    content = solution_str.split("Natural Language Content:")[-1].split("Definitions:")[0].strip()
    definitions = solution_str.split("Definitions:")[-1].strip()

    return f'''Natural Language Content:
{content}

Definitions:
{definitions}
        
Based on the Definitions, translate the Natural Language Content into Z3 code. Each constraint consists of a forbidden combination of assignments for two variables.
Conclude your response with "Final Z3 Code:". Then present the generated code directly, do not enclose it in quotation marks or code blocks.

For example:
Final Z3 Code:
from z3 import *

# Create solver instance
solver = Solver()

# Create boolean variables
A1, A2, A3, B1, B2, B3 = Bools('A1 A2 A3 B1 B2 B3')

# Add constraints
solver.add(Not(And(Not(A2), Not(B1))))
solver.add(Not(And(A2, Not(B3))))
solver.add(Not(And(Not(B3), Not(B2))))
solver.add(Not(And(Not(A3), Not(A1))))
solver.add(Not(And(Not(B1), B3)))
solver.add(Not(And(B3, Not(A2))))
solver.add(Not(And(B1, A1)))
solver.add(Not(And(Not(A1), B2)))'''


def process_solution_worker_stage_1(solution_strs: List[str], ground_truths: List[dict]) -> List[float]:
    """Stage 1: check that each solution's content back-translates to its Z3 code.

    For every well-formed solution, MODEL_PATH is asked to translate the
    natural-language content into Z3 code (greedy decoding). That code is then
    compared with ``ground_truth["code"]`` through _z3_check_subprocess.

    Args:
        solution_strs: generated puzzles, each with "Natural Language Content:"
            and "Definitions:" sections.
        ground_truths: one dict per solution, aligned by index; must provide
            "code".

    Returns:
        One score per input, aligned by index: 0.1 when the back-translated
        code has the same canonical constraints as the ground-truth code, else
        0.0. This covers malformed solutions, a missing "Final Z3 Code:" label,
        a Z3 mismatch, and any error (which leaves the remaining entries at 0.0).
    """
    # Pre-fill and assign by index so every score stays aligned with its input,
    # no matter which entries are skipped or where an error stops processing.
    results = [0.0] * len(solution_strs)
    prompts = [_build_stage_1_prompt(s) for s in solution_strs]
    valid_indices = [i for i, prompt in enumerate(prompts) if prompt]
    if not valid_indices:
        return results

    try:
        # NOTE(review): a new 8-way tensor-parallel engine is built on every
        # call and never explicitly released; if this runs repeatedly in one
        # process, GPU memory may not be reclaimed -- consider caching it.
        llm = LLM(
            model=MODEL_PATH,
            tensor_parallel_size=8,
            data_parallel_size=1,
            max_model_len=32768,
            max_num_seqs=256,
            gpu_memory_utilization=0.5
        )
        params = SamplingParams(temperature=0, max_tokens=32768)
        outputs = llm.generate([prompts[i] for i in valid_indices], params)

        # llm.generate returns outputs in prompt order, so zip pairs each
        # output with the input index it came from.
        for i, output in zip(valid_indices, outputs):
            response = output.outputs[0].text
            print(response)
            if "Final Z3 Code:" not in response:
                print("No Final Z3 Code label")
                continue

            back_code = response.split("Final Z3 Code:")[-1].strip()
            original_code = ground_truths[i]["code"]

            if _z3_check_subprocess(original_code, back_code):
                results[i] = 0.1
            else:
                print("Score: 0.0 (Z3 code mismatch).")

    except Exception as e:
        print(f"Error in process_solution_worker_stage_1: {e}")
        traceback.print_exc()

    return results


def _build_stage_2_prompt(solution_str: str) -> str:
    """Return the solving prompt built from *solution_str*'s content and definitions."""
    content = solution_str.split("Natural Language Content:")[-1].split("Definitions:")[0].strip()
    definitions = solution_str.split("Definitions:")[-1].strip()

    return f'''Content:
{content}

Definitions:
{definitions}

Based on the Content and Definitions, determine the truth value (True or False) for each variable mentioned.
Respond with your final answer using the label "Final Answer". Format each line as: "[Variable name]: [True/False]". Each variable name appears at the start of its corresponding definition in the Definitions.

Example:
Final Answer:
A1: True
B2: False'''


def _parse_assignment(answer: str) -> dict:
    """Parse "Name: True/False" lines (case-insensitive values) into {name: bool}.

    Lines without ":" or whose value is neither true nor false are ignored.
    """
    assignment = {}
    for line in answer.splitlines():
        if ":" not in line:
            continue
        key, value = (part.strip() for part in line.split(":", 1))
        if value.lower() == "true":
            assignment[key] = True
        elif value.lower() == "false":
            assignment[key] = False
    return assignment


def _save_qualified_problem(solution_str: str, ground_truth: dict) -> None:
    """Append one qualified puzzle as a JSON line to the collection file.

    Errors are printed with a traceback and not re-raised.
    """
    try:
        output_dir = "puzzle_collection"
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, "sat_problem_3-3_qwen3-30b_qualified.jsonl")

        # Build the record before opening the file so a missing key does not
        # touch the output.
        record = {
            "content": solution_str.split("Natural Language Content:")[-1].split("Definitions:")[0].strip(),
            "definitions": solution_str.split("Definitions:")[-1].strip(),
            "code": ground_truth["code"],
            "answer": ground_truth["answer"]
        }
        with open(output_file, "a", encoding='utf-8') as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            f.flush()
        print(f"Successfully saved qualified problem to {output_file}.")
    except Exception as e:
        print(f"Error saving qualified problem to file: {e}")
        traceback.print_exc()


def process_solutions_stage_2(solution_strs: List[str], ground_truths: List[dict],
                              num_samples: int = 16, temperature: float = 0.6) -> List[float]:
    """Stage 2: keep puzzles the solver model gets right sometimes, but not usually.

    Each puzzle is sent *num_samples* times to Solving_MODEL_PATH. A sample
    passes when its parsed "Final Answer:" assignment equals
    ``ground_truth["answer"]``. A puzzle qualifies when at least one sample but
    fewer than half of them pass. Qualified puzzles score 1.0 and are appended
    to the collection file.

    Args:
        solution_strs: puzzles with "Natural Language Content:"/"Definitions:".
        ground_truths: aligned dicts with "answer" (dict name -> bool) and
            "code".
        num_samples: completions per puzzle (default 16).
        temperature: sampling temperature. It must be > 0 for the samples to
            differ. With greedy decoding every sample is identical, so the pass
            count is 0 or num_samples and nothing can ever qualify.

    Returns:
        One score per input, aligned by index: 1.0 if qualified, else 0.0.
        An error leaves the not-yet-scored entries at 0.0.
    """
    results = [0.0] * len(solution_strs)
    if not solution_strs:
        return results

    prompts = [_build_stage_2_prompt(s) for s in solution_strs]
    # Repeat each prompt num_samples times (one completion per request) rather
    # than using SamplingParams(n=...), so any temperature -- including 0 --
    # is accepted by vLLM.
    repeated_prompts = [p for p in prompts for _ in range(num_samples)]

    try:
        llm = LLM(
            model=Solving_MODEL_PATH,
            tensor_parallel_size=8,
            data_parallel_size=1,
            max_model_len=32768,
            max_num_seqs=256,
            gpu_memory_utilization=0.5
        )
        # Qwen3's recommended thinking-mode sampling settings (top_p/top_k).
        params = SamplingParams(temperature=temperature, top_p=0.95, top_k=20, max_tokens=32768)
        outputs = llm.generate(repeated_prompts, params)

        for i, (solution_str, ground_truth) in enumerate(zip(solution_strs, ground_truths)):
            # Outputs come back in prompt order: puzzle i owns one contiguous group.
            group = outputs[i * num_samples:(i + 1) * num_samples]
            gold = ground_truth["answer"]
            passed = 0

            for output in group:
                response = output.outputs[0].text
                print(response)
                if "Final Answer:" not in response:
                    print("No Final Answer label")
                    continue

                assignment = _parse_assignment(response.split("Final Answer:")[-1].strip())
                if assignment == gold:
                    print(f'''Correct answer - Assignment: {assignment}''')
                    passed += 1
                else:
                    print(f'''Wrong answer - Assignment: {assignment}, Gold: {gold}''')

            # Qualified: solved by at least one sample, but by fewer than half.
            if passed > 0 and 2 * passed < num_samples:
                results[i] = 1.0
                _save_qualified_problem(solution_str, ground_truth)

    except Exception as e:
        print(f"Error in process_solutions_stage_2: {e}")
        traceback.print_exc()

    return results


def compute_score_batch(solution_strs: list[str], ground_truths: list[dict]) -> list[float]:
    """Score a batch of generated puzzles through three sequential stages.

    Stages run in order, each only on the pairs that passed the previous one:

    1. process_solution_worker_stage_1 on every pair whose solution string
       and ground truth are both non-empty; its score (0.1 on success) is
       stored.
    2. process_solutions_stage_2 on pairs with a positive stage-1 score; its
       score (1.0 when it deems the puzzle qualified) is added.
    3. For pairs with a positive stage-2 score, +1.0 when check_symbol finds
       none of ``ground_truth["answer"]``'s variable names in the content.

    The stages run sequentially in this thread (no thread pool). The shared
    Z3 worker pool is shut down before returning, even if a stage raises.

    Returns:
        One score per input, aligned by index (0.0 for skipped pairs).
    """
    global Z3_EXECUTOR
    results = [0.0] * len(solution_strs)

    valid_pairs = [(idx, solution_str, ground_truth)
                   for idx, (solution_str, ground_truth) in enumerate(zip(solution_strs, ground_truths))
                   if solution_str and ground_truth]

    if not valid_pairs:
        print("No valid solution-ground truth pairs to process.")
        return results

    try:
        # Stage 1: back-translation check, called directly (no thread pool).
        print(f"Processing {len(valid_pairs)} solutions in stage 1.")
        stage_1_scores = process_solution_worker_stage_1(
            [s for _, s, _ in valid_pairs], [g for _, _, g in valid_pairs])
        # zip() stops at the shorter list: a short score list leaves the rest at 0.0.
        for (idx, _, _), score in zip(valid_pairs, stage_1_scores):
            results[idx] = score

        # Stage 2: only pairs that passed stage 1. Selecting on "> 0" instead
        # of float equality (== 0.1, == 1.1) keeps the stage routing
        # independent of the exact reward values.
        stage_2_pairs = [pair for pair in valid_pairs if results[pair[0]] > 0]
        if not stage_2_pairs:
            return results

        print(f"Processing {len(stage_2_pairs)} solutions in stage 2.")
        stage_2_scores = process_solutions_stage_2(
            [s for _, s, _ in stage_2_pairs], [g for _, _, g in stage_2_pairs])

        stage_3_pairs = []
        for pair, score in zip(stage_2_pairs, stage_2_scores):
            results[pair[0]] += score
            if score > 0:
                stage_3_pairs.append(pair)

        # Stage 3: reward content that does not spell out the variable names.
        for idx, solution_str, ground_truth in stage_3_pairs:
            content = solution_str.split("Natural Language Content:")[-1].split("Definitions:")[0].strip()
            gold = ground_truth["answer"]
            if check_symbol(content, gold.keys()):
                results[idx] += 1.0

    finally:
        # Shut down the shared Z3 pool even if a stage raised, so its worker
        # process does not linger after scoring.
        with Z3_EXECUTOR_LOCK:
            if Z3_EXECUTOR is not None:
                print("Shutting down global Z3_EXECUTOR...")
                try:
                    Z3_EXECUTOR.shutdown(wait=True, cancel_futures=True)  # wait for the worker to exit
                except Exception as e:
                    print(f"Error during final Z3_EXECUTOR shutdown: {e}")
                finally:
                    Z3_EXECUTOR = None

    return results