import asyncio
import io
import traceback
import threading
from contextlib import redirect_stdout
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor, TimeoutError
from concurrent.futures.process import BrokenProcessPool
from typing import List
from z3 import *
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os

# Module-level state
# Shared executor for Z3 work. Nothing in this part of the file creates it;
# cleanup_z3_executor() shuts it down and resets it to None.
# (presumably a process pool -- TODO confirm where it is instantiated)
Z3_EXECUTOR = None
# Held by cleanup_z3_executor() while it shuts down and clears Z3_EXECUTOR.
Z3_EXECUTOR_LOCK = threading.Lock()
ground_truth = None  # must be initialised before use (read by compute_single_score_async)

# Import-time side effect: expose GPU indices 0-15 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(16))

# Model path, used for both the tokenizer and the vLLM engine
MODEL_PATH = "models--deepseek-ai--DeepSeek-R1-Distill-Qwen-32B"

# Global tokenizer, loaded lazily by get_tokenizer()
_tokenizer = None

def get_tokenizer():
    """Return the shared tokenizer, loading it from ``MODEL_PATH`` on first use.

    The loaded tokenizer is cached in the module-level ``_tokenizer``.  If
    loading fails, the failure is remembered so that later calls return
    ``None`` immediately instead of retrying the (potentially slow) load and
    re-printing the fallback message on every call.

    Returns:
        The tokenizer instance, or ``None`` when it could not be loaded, in
        which case callers are expected to fall back to regex tokenization.
    """
    global _tokenizer
    # The failure flag lives on the function object so that no additional
    # module-level global is required.
    if _tokenizer is None and not getattr(get_tokenizer, "_load_failed", False):
        try:
            _tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
            print(f"Loaded tokenizer from {MODEL_PATH}")
        except Exception as e:
            print(f"Failed to load tokenizer from {MODEL_PATH}: {e}")
            print("Falling back to simple regex tokenizer")
            get_tokenizer._load_failed = True
    return _tokenizer

def _z3_format_check(z3_code: str) -> bool:
    """Report whether a Z3 code snippet compiles as Python source.

    Only compilation is attempted; the snippet is never executed here.

    Args:
        z3_code: Source text of the snippet.

    Returns:
        True when ``compile`` accepts the snippet; False when the snippet is
        empty or ``compile`` raises.
    """
    if z3_code:
        try:
            # Compilation alone catches syntax errors without running the code.
            compile(z3_code, '<z3_code_string>', 'exec')
        except SyntaxError as syntax_err:
            print(f"Z3 code syntax error: {syntax_err}")
        except Exception as other_err:
            print(f"Other Z3 code compilation error: {other_err}")
        else:
            return True
        return False

    print("Empty Z3 code for format check.")
    return False

def bleu(candidate: str, reference: str = None, n_gram: int = 4) -> float:
    """Compute a sentence-level BLEU score between two texts.

    Both texts are tokenized with the tokenizer returned by
    ``get_tokenizer()``; when that is unavailable or raises, a lowercase regex
    word tokenizer is used instead.  The score is the brevity penalty times
    the geometric mean of the clipped 1..``n_gram`` n-gram precisions, with no
    smoothing.

    Args:
        candidate: Generated text to score.
        reference: Text to compare against.  NOTE: when ``None`` the candidate
            is compared with itself (placeholder behaviour), so real callers
            should always pass a reference.
        n_gram: Highest n-gram order included in the score (default 4).

    Returns:
        A float in [0, 1].  0.0 is returned when either text yields no tokens,
        when tokenization raises, when ``n_gram`` is below 1, or when any
        n-gram order has zero precision.
    """
    import math
    import re
    from collections import Counter

    def tokenize_simple(text: str) -> list:
        """Fallback tokenizer: lowercase the text and extract word tokens."""
        return re.findall(r'\b\w+\b', text.lower())

    def tokenize_with_model(text: str) -> list:
        """Tokenize with the model tokenizer, falling back to the regex one."""
        tokenizer = get_tokenizer()
        if tokenizer is not None:
            try:
                # Encode to ids, then decode each id on its own to recover
                # per-token strings.
                token_ids = tokenizer.encode(text, add_special_tokens=False)
                pieces = (tokenizer.decode([token_id]).strip() for token_id in token_ids)
                # Drop tokens that are empty or whitespace-only.
                return [piece for piece in pieces if piece]
            except Exception as e:
                print(f"Error using model tokenizer: {e}, falling back to regex tokenizer")
        return tokenize_simple(text)

    def get_ngrams(tokens: list, n: int) -> list:
        """Return all contiguous n-grams of ``tokens`` (empty if too short)."""
        return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    def calculate_precision(candidate_tokens: list, reference_tokens: list, n: int) -> float:
        """Return the clipped n-gram precision of the candidate."""
        candidate_ngrams = get_ngrams(candidate_tokens, n)
        if not candidate_ngrams:
            return 0.0
        # Counter intersection keeps, per n-gram, the minimum of both counts,
        # which is exactly the clipped match count.
        matched = Counter(candidate_ngrams) & Counter(get_ngrams(reference_tokens, n))
        return sum(matched.values()) / len(candidate_ngrams)

    def brevity_penalty(candidate_len: int, reference_len: int) -> float:
        """Return the BLEU brevity penalty for the given token counts."""
        if candidate_len > reference_len:
            return 1.0
        if candidate_len == 0:
            return 0.0
        return math.exp(1 - reference_len / candidate_len)

    # Placeholder behaviour: with no reference the candidate is scored
    # against itself.
    if reference is None:
        reference = candidate

    try:
        candidate_tokens = tokenize_with_model(candidate)
        reference_tokens = tokenize_with_model(reference)
    except Exception as e:
        print(f"Error in tokenization: {e}")
        return 0.0

    if not candidate_tokens or not reference_tokens:
        return 0.0

    # Work in log space to avoid numeric underflow in the geometric mean.
    log_precisions = []
    for n in range(1, n_gram + 1):
        precision = calculate_precision(candidate_tokens, reference_tokens, n)
        if precision == 0:
            # Any zero precision makes the (unsmoothed) BLEU score zero.
            return 0.0
        log_precisions.append(math.log(precision))

    if not log_precisions:
        return 0.0

    geometric_mean = math.exp(sum(log_precisions) / len(log_precisions))
    bp = brevity_penalty(len(candidate_tokens), len(reference_tokens))

    # Clamp to guard against floating-point drift outside [0, 1].
    return min(1.0, max(0.0, bp * geometric_mean))
async def compute_single_score_async(solution_str: str) -> float:
    """Compute the stage-1 score for a single solution asynchronously.

    The text after the last "Final Z3 Code:" marker is compile-checked in a
    worker thread.  A solution without the marker scores 0.0, matching
    ``process_solution_worker_stage_1``.

    Args:
        solution_str: Full model response expected to contain the marker.

    Returns:
        0.1 when the extracted code compiles; 0.0 when the solution or the
        module-level ``ground_truth`` is empty, the marker is missing, the
        code does not compile, or an unexpected error occurs.
    """
    if not solution_str or not ground_truth:
        print("Empty solution or ground truth provided for scoring.")
        return 0.0

    try:
        marker = "Final Z3 Code:"
        # Without the marker the whole response would be compiled as code;
        # treat it as a failed answer, as the stage-1 worker does.
        if marker not in solution_str:
            return 0.0

        code = solution_str.split(marker)[-1].strip()

        # Run the compile check off the event loop thread.
        if not await asyncio.to_thread(_z3_format_check, code):
            print("Score: 0.0 (Z3 code can not be compiled).")
            return 0.0

        print("Z3 code can be compiled.")
        return 0.1

    except Exception as e:
        print(f"Critical error in compute_single_score_async: {e}")
        traceback.print_exc()
        return 0.0

def process_solutions_stage_2(solution_strs: List[str], ground_truths: List[dict]) -> List[float]:
    """Score solutions by back-translating their Z3 code and comparing with BLEU.

    For each solution, the text after the last "Final Z3 Code:" marker is
    combined with the "definitions" of the ground truth at the same position
    into a prompt.  A vLLM engine is created from ``MODEL_PATH`` on every
    call and generates one response per prompt with temperature 0.  The text
    after "Natural Language Content:" in each response is scored with
    ``bleu`` against the ground truth's "content".

    Args:
        solution_strs: Model responses that passed stage 1.
        ground_truths: Ground-truth entries aligned with ``solution_strs``;
            dict entries are read for the "definitions" and "content" keys,
            non-dict entries are treated as having neither.

    Returns:
        One score per generated output: the BLEU score, or 0.0 when the
        response lacks the "Natural Language Content:" label.  If anything in
        the generation/scoring step raises, a list of 0.0 with one entry per
        solution is returned instead.
    """
    results = []
    prompts = []

    for position, solution_str in enumerate(solution_strs):
        code = solution_str.split("Final Z3 Code:")[-1].strip()

        # Use the ground truth paired with this solution (not the first one
        # for everybody), tolerating missing or non-dict entries.
        truth = ground_truths[position] if position < len(ground_truths) else None
        definitions = truth.get("definitions", "") if isinstance(truth, dict) else ""

        prompt = f'''Z3 Code:
{code}

Definitions:
{definitions}

Integrate all information from the Z3 code into the Background to generate a challenging natural language content. Do not refer to or quote the code directly, and do not use symbolic identifiers (e.g., "A1", "C5") in the narrative.
                
Conclude your response with following format:
Natural Language Content:
[content]'''

        prompts.append(prompt)

    try:
        llm = LLM(
            model=MODEL_PATH, 
            tensor_parallel_size=8,
            data_parallel_size=1,
            max_model_len=32768, 
            max_num_seqs=256,
            gpu_memory_utilization=0.5
        )
        params = SamplingParams(temperature=0, max_tokens=32768)
        outputs = llm.generate(prompts, params)

        for output, truth in zip(outputs, ground_truths):
            response = output.outputs[0].text
            print(response)
            if "Natural Language Content:" not in response:
                print("No Natural Language Content label")
                results.append(0.0)
                continue

            back_content = response.split("Natural Language Content:")[-1].strip()
            original_content = truth.get("content", "") if isinstance(truth, dict) else ""

            # BLEU between the back-translated text and the original content.
            bleu_score = bleu(back_content, original_content)
            print(f"bleu score: {bleu_score}")
            results.append(bleu_score)

    except Exception as e:
        print(f"Error in process_solutions_stage_2: {e}")
        traceback.print_exc()
        # Discard partial results so the output stays aligned with the input.
        results = [0.0] * len(solution_strs)

    return results

def process_solution_worker_stage_1(idx: int, solution_str: str) -> tuple:
    """Stage-1 worker: compile-check the Z3 code of a single solution.

    Args:
        idx: Position of the solution in the batch; returned unchanged.
        solution_str: Model response expected to contain "Final Z3 Code:".

    Returns:
        ``(idx, 0.1)`` when the text after the last marker compiles, and
        ``(idx, 0.0)`` when the marker is missing, the code does not compile,
        or any error occurs.
    """
    marker = "Final Z3 Code:"
    score = 0.0
    try:
        if marker in solution_str:
            # Everything after the last marker is treated as the code.
            z3_source = solution_str.rpartition(marker)[2].strip()
            if _z3_format_check(z3_source):
                print(f"Z3 code can be compiled.")
                score = 0.1
            else:
                print(f"Score: 0.0 (Z3 code can not be compiled).")
    except Exception as e:
        print(f"Critical error in process_solution_worker_stage_1: {e}")
        traceback.print_exc()
        score = 0.0
    return idx, score

def compute_score_batch(solution_strs: List[str], ground_truths: List[dict], max_workers: int = 1) -> List[float]:
    """Score a batch of solutions in two stages.

    Stage 1 runs ``process_solution_worker_stage_1`` for every solution on a
    thread pool.  Stage 2 passes the solutions that scored exactly 0.1 to
    ``process_solutions_stage_2`` and adds the returned scores on top.

    Args:
        solution_strs: Model responses to score.
        ground_truths: Ground-truth entries, one per solution.
        max_workers: Thread-pool size for stage 1.

    Returns:
        One score per solution, in input order.

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(solution_strs) != len(ground_truths):
        raise ValueError(f"Mismatch in input lengths: {len(solution_strs)} solutions vs {len(ground_truths)} ground truths")

    results = [0.0 for _ in solution_strs]

    # Stage 1: compile-check the Z3 code of every solution.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for position, text in enumerate(solution_strs):
            pending.append(pool.submit(process_solution_worker_stage_1, position, text))

        for finished in as_completed(pending):
            try:
                slot, stage_1_score = finished.result()
                if slot is not None and 0 <= slot < len(results):
                    results[slot] = stage_1_score
            except Exception as e:
                # One failed future must not invalidate the other results.
                print(f"Error retrieving result from future: {e}")
                traceback.print_exc()

    # Release the global Z3 executor once stage 1 has finished.
    cleanup_z3_executor()

    # Stage 2: only solutions that passed stage 1 are back-translated.
    survivors = [
        (position, text, truth)
        for position, (stage_1_score, text, truth) in enumerate(zip(results, solution_strs, ground_truths))
        if stage_1_score == 0.1
    ]

    if survivors:
        stage_2_indices = [position for position, _, _ in survivors]
        stage_2_scores = process_solutions_stage_2(
            [text for _, text, _ in survivors],
            [truth for _, _, truth in survivors],
        )

        # Add each stage-2 score to the stage-1 score at its original index.
        for offset, bonus in enumerate(stage_2_scores):
            target = stage_2_indices[offset]
            results[target] = results[target] + bonus

    return results

def cleanup_z3_executor():
    """Shut down the global Z3 executor, if one exists, and clear it.

    Runs under ``Z3_EXECUTOR_LOCK``.  Shutdown errors are printed rather
    than raised, and ``Z3_EXECUTOR`` is reset to ``None`` either way.
    """
    global Z3_EXECUTOR
    with Z3_EXECUTOR_LOCK:
        executor = Z3_EXECUTOR
        if executor is None:
            return

        print("Shutting down global Z3_EXECUTOR...")
        try:
            # Block until running work finishes; drop anything still queued.
            executor.shutdown(wait=True, cancel_futures=True)
        except Exception as e:
            print(f"Error during final Z3_EXECUTOR shutdown: {e}")
        finally:
            Z3_EXECUTOR = None