import logging
import os
import re
import numpy as np
import torch
from vllm import LLM, SamplingParams
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.utils import hf_tokenizer
from verl import DataProto
from tensordict import TensorDict
from formal_reward import count_python_error_outputs, check_proved_in_output
from prompts import CV_PROMPT
import time
import subprocess
import signal
from contextlib import contextmanager

# Module logger; its level comes from the VERL_PPO_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))

# ========== Configuration ==========
# Path handed to hf_tokenizer() in RewardModelWorker.init_model.
# NOTE(review): empty here -- presumably has to be filled in before use.
TRAINING_MODEL_PATH = ''
# Threshold on the per-response "total output" count used by the reward rules.
MAX_LLM_CALLS = 4
# Multiplier applied to the correctness part of the reward.
WEIGHT = 3
# NOTE(review): not referenced anywhere in the visible part of this file.
VERIFIER_PASS_TAG = "Final Decision: Yes"
# Token cap passed to truncate_string_by_tokens() for verifier prompts.
MAX_INPUT_TOKENS = 3072
# Tags that count_undefined_tags() treats as defined (i.e. does not count).
CUSTOM_TAGS = ['<code>', '</code>', '<interpreter>', '</interpreter>', '<answer>', '</answer>']
# Closing tag a response is expected to end with.
END_TAG = "</answer>"

# ========== Timeout configuration ==========
TIMEOUT_SECONDS = 2.0  # default time limit (seconds) for each guarded check
# Scores assigned to a sample when any guarded check timed out.
TIMEOUT_PENALTY = {
    "format_score": -2.0,
    "correct_score": -1.0 * WEIGHT,
}


# ========== Timeout exception and context manager ==========
class TimeoutError(Exception):
    """Signals that a time-limited operation ran past its deadline.

    Note: within this module the name shadows Python's built-in
    ``TimeoutError``; this class derives directly from ``Exception``.
    """

@contextmanager
def timeout(seconds):
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    The limit is enforced with ``SIGALRM`` through ``signal.setitimer``, so
    ``seconds`` may be an int or a float (sub-second limits work). A value of
    0 arms no timer, i.e. the body runs without a limit.

    Args:
        seconds: Time limit in seconds.

    Raises:
        TimeoutError: Raised inside the ``with`` body when the timer fires.

    Note:
        Python only allows signal handlers to be installed from the main
        thread, and ``SIGALRM`` exists on Unix only. Any real-time interval
        timer that was already pending is cancelled, not restored.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {seconds}s")

    # Install our handler, remembering the previous one so it can be restored.
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)

    try:
        # Arm inside the try block: if arming fails (e.g. invalid value) the
        # finally clause still puts the previous handler back.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel the timer
        signal.signal(signal.SIGALRM, old_handler)


# ========== 带超时保护的检测函数 ==========
def safe_is_correct_format(text: str, timeout_sec=TIMEOUT_SECONDS) -> tuple:
    """Run ``is_correct_format`` on ``text`` under a time limit.

    Args:
        text: Response text to check.
        timeout_sec: Time limit in seconds. Positive values are converted to
            whole seconds and never drop below 1; a value <= 0 disables the
            limit.

    Returns:
        (result: bool, timeout_flag: bool). On timeout the result is
        ``False`` (treated as badly formatted) and the flag is ``True``.
    """
    # int() alone would turn a positive sub-second limit into 0, and a
    # 0-second alarm is never scheduled -- that would silently switch the
    # protection off. Clamp positive limits to at least one second.
    alarm_seconds = max(1, int(timeout_sec)) if timeout_sec > 0 else 0
    try:
        with timeout(alarm_seconds):
            result = is_correct_format(text)
        return result, False
    except TimeoutError:
        logger.warning(f"⏰ is_correct_format timed out (len={len(text)})")
        return False, True  # on timeout, default to "format is wrong"


def safe_is_repetitive(format_flag, text, timeout_sec=TIMEOUT_SECONDS) -> tuple:
    """Run ``is_repetitive`` on ``text`` under a time limit.

    Args:
        format_flag: Forwarded unchanged to ``is_repetitive``.
        text: Response text to check.
        timeout_sec: Time limit in seconds. Positive values are converted to
            whole seconds and never drop below 1; a value <= 0 disables the
            limit.

    Returns:
        (result: bool, timeout_flag: bool). On timeout the result is
        ``True`` (treated as repetitive) and the flag is ``True``.
    """
    # int() alone would turn a positive sub-second limit into 0, and a
    # 0-second alarm is never scheduled -- that would silently switch the
    # protection off. Clamp positive limits to at least one second.
    alarm_seconds = max(1, int(timeout_sec)) if timeout_sec > 0 else 0
    try:
        with timeout(alarm_seconds):
            result = is_repetitive(format_flag, text)
        return result, False
    except TimeoutError:
        logger.warning(f"⏰ is_repetitive timed out (len={len(text)})")
        return True, True  # on timeout, default to "repetitive" (pathological)


def safe_count_undefined_tags(rep_flag, text, timeout_sec=TIMEOUT_SECONDS) -> tuple:
    """Run ``count_undefined_tags`` on ``text`` under a time limit.

    Args:
        rep_flag: Forwarded unchanged to ``count_undefined_tags``.
        text: Response text to scan.
        timeout_sec: Time limit in seconds. Positive values are converted to
            whole seconds and never drop below 1; a value <= 0 disables the
            limit.

    Returns:
        (count: int, timeout_flag: bool). On timeout the count is 999 and
        the flag is ``True``.
    """
    # int() alone would turn a positive sub-second limit into 0, and a
    # 0-second alarm is never scheduled -- that would silently switch the
    # protection off. Clamp positive limits to at least one second.
    alarm_seconds = max(1, int(timeout_sec)) if timeout_sec > 0 else 0
    try:
        with timeout(alarm_seconds):
            count = count_undefined_tags(rep_flag, text)
        return count, False
    except TimeoutError:
        logger.warning(f"⏰ count_undefined_tags timed out (len={len(text)})")
        return 999, True  # on timeout, report a large number of tag errors


def safe_count_python_error_outputs(text, timeout_sec=TIMEOUT_SECONDS) -> tuple:
    """Run ``count_python_error_outputs`` on ``text`` under a time limit.

    Args:
        text: Response text to scan.
        timeout_sec: Time limit in seconds. Positive values are converted to
            whole seconds and never drop below 1; a value <= 0 disables the
            limit.

    Returns:
        (result, timeout_flag: bool), where ``result`` is whatever
        ``count_python_error_outputs`` returns -- callers unpack it as
        ``(error_count, total_count)``. On timeout the result is
        ``(999, 999)`` and the flag is ``True``.
    """
    # int() alone would turn a positive sub-second limit into 0, and a
    # 0-second alarm is never scheduled -- that would silently switch the
    # protection off. Clamp positive limits to at least one second.
    alarm_seconds = max(1, int(timeout_sec)) if timeout_sec > 0 else 0
    try:
        with timeout(alarm_seconds):
            result = count_python_error_outputs(text)
        return result, False
    except TimeoutError:
        logger.warning(f"⏰ count_python_error_outputs timed out (len={len(text)})")
        return (999, 999), True  # on timeout, report a large number of errors


# ========== 原有辅助函数（保持不变）==========
def get_gpu_utilization():
    """Return the GPU utilization reported by ``nvidia-smi`` as an int.

    Only the first line of the command output is parsed, i.e. the first GPU
    that ``nvidia-smi`` lists.
    """
    query_cmd = [
        'nvidia-smi',
        '--query-gpu=utilization.gpu',
        '--format=csv,noheader,nounits',
    ]
    completed = subprocess.run(query_cmd, capture_output=True, text=True)
    first_line = completed.stdout.strip().split('\n')[0]
    return int(first_line)


def is_correct_format(text: str) -> bool:
    """Check that ``text`` ends with one well-formed ``<answer>`` section.

    The text passes when all of the following hold:
      * it contains exactly one ``<answer>`` and exactly one ``</answer>``;
      * ignoring surrounding whitespace, it ends with ``</answer>``;
      * the answer section holds between 1 and 5 occurrences of ``\\boxed{``;
      * the first ``\\boxed{`` in that section has a matching closing brace.

    Returns:
        True when every condition holds, otherwise False.
    """
    open_tag, close_tag, boxed_marker = '<answer>', '</answer>', '\\boxed{'

    if text.count(open_tag) != 1 or text.count(close_tag) != 1:
        return False
    if not text.strip().endswith(close_tag):
        return False

    section_start = text.index(open_tag)
    section_end = text.rindex(close_tag) + len(close_tag)
    answer_section = text[section_start:section_end]

    # Up to 5 boxes are tolerated: the model may restate its answer.
    if not 1 <= answer_section.count(boxed_marker) <= 5:
        return False

    # Walk the text after the first \boxed{ and look for its closing brace.
    _, _, after_first_box = answer_section.partition(boxed_marker)
    depth = 0
    for ch in after_first_box:
        if ch == '{':
            depth += 1
        elif ch == '}':
            if depth == 0:
                return True
            depth -= 1

    return False


def is_repetitive(format_flag, text, min_repeat=3):
    """Detect degenerate repetition near the end of ``text``.

    Args:
        format_flag: When truthy the check is skipped and False is returned.
        text: Text to inspect; empty or whitespace-only text returns False.
        min_repeat: Number of back-to-back repeats that counts as repetitive.

    Returns:
        True when the last 20 characters contain a run of ``min_repeat``
        identical characters, or the last 20 whitespace-separated tokens
        contain a token or phrase repeated ``min_repeat`` times in a row.
    """
    if format_flag or not text or not text.strip():
        return False

    # Pass 1: run of identical characters within the last 20 characters.
    tail_chars = text[-20:]
    run_length = 1
    for previous, current in zip(tail_chars, tail_chars[1:]):
        if current != previous:
            run_length = 1
            continue
        run_length += 1
        if run_length >= min_repeat:
            return True

    # Pass 2: token-level repetition within the last 20 tokens.
    words = text.split()
    if len(words) < min_repeat:
        return False

    recent = words[-20:]
    total = len(recent)

    # A single token repeated back to back.
    for begin in range(total - min_repeat + 1):
        if all(recent[begin + k] == recent[begin] for k in range(min_repeat)):
            return True

    # A phrase of `width` tokens repeated back to back.
    for width in range(1, total // min_repeat + 1):
        span = width * min_repeat
        for begin in range(total - span + 1):
            if recent[begin:begin + span] == recent[begin:begin + width] * min_repeat:
                return True

    return False


def count_undefined_tags(rep_flag, text, custom_tags=CUSTOM_TAGS):
    """Count tag-like ``<...>`` spans in ``text`` that are not known tags.

    Args:
        rep_flag: When truthy the scan is skipped and 0 is returned.
        text: Text to scan.
        custom_tags: Tags regarded as defined; these are not counted.

    Returns:
        Number of matches of ``<...>`` (1-15 characters between the angle
        brackets) that are shorter than 15 characters overall and are not
        in ``custom_tags``.
    """
    if rep_flag:
        return 0

    known_tags = set(custom_tags)
    return sum(
        1
        for candidate in re.findall(r'<[^>]{1,15}>', text)
        if len(candidate) < 15 and candidate not in known_tags
    )


# ========== 批量处理函数（新增超时保护）==========
def _run_detection_stage(title, stage, responses, check, timeout_info):
    """Run one guarded detection pass over all responses.

    Args:
        title: Header line printed before the pass starts.
        stage: Stage name stored in the timeout records.
        responses: The response strings to process.
        check: Callable ``(index, response) -> (result, timed_out)``.
        timeout_info: List that receives one record per timed-out sample
            (mutated in place).

    Returns:
        List with the ``result`` of ``check`` for every response, in order.
    """
    print(title)
    stage_start = time.time()
    results = []
    for idx, resp in enumerate(responses):
        t0 = time.time()
        result, is_timeout = check(idx, resp)
        results.append(result)

        elapsed = time.time() - t0
        # Report slow samples (> 0.5s) and every timeout.
        if elapsed > 0.5 or is_timeout:
            print(f"  ⚠️ #{idx} took {elapsed:.2f}s (timeout={is_timeout}, len={len(resp)})")
            if is_timeout:
                timeout_info.append({
                    "sample_idx": idx,
                    "stage": stage,
                    "response_len": len(resp)
                })

    print(f"  ✓ Total: {time.time()-stage_start:.3f}s\n")
    return results


def batch_process_responses(full_responses):
    """Run the four timeout-guarded detectors over every response.

    The passes run in a fixed order because each later pass consumes the
    flags of the previous one: format check -> repetition check (receives
    the format flag) -> undefined-tag count (receives the repetition flag)
    -> Python error-output count.

    Args:
        full_responses: Sequence of response strings.

    Returns:
        error_counts: List[Tuple[int, int]]  # (error_count, total_count)
        correct_format_flags: List[bool]
        repetitive_flags: List[bool]
        tag_counts: List[int]
        timeout_info: List[dict]  # one record per (sample, stage) that
            timed out, with keys "sample_idx", "stage", "response_len"
    """
    print(f"\n=== Processing {len(full_responses)} responses ===\n")

    timeout_info = []

    # 1. Format detection.
    correct_format_flags = _run_detection_stage(
        "1️⃣ Format detection:",
        "format_detection",
        full_responses,
        lambda idx, resp: safe_is_correct_format(resp),
        timeout_info,
    )

    # 2. Repetition detection, fed with the format flag of the same sample.
    repetitive_flags = _run_detection_stage(
        "2️⃣ Repetitive detection:",
        "repetitive_detection",
        full_responses,
        lambda idx, resp: safe_is_repetitive(correct_format_flags[idx], resp),
        timeout_info,
    )

    # 3. Undefined-tag count, fed with the repetition flag of the same sample.
    tag_counts = _run_detection_stage(
        "3️⃣ Tag detection:",
        "tag_detection",
        full_responses,
        lambda idx, resp: safe_count_undefined_tags(repetitive_flags[idx], resp),
        timeout_info,
    )

    # 4. Python error-output count.
    error_counts = _run_detection_stage(
        "4️⃣ Error detection:",
        "error_detection",
        full_responses,
        lambda idx, resp: safe_count_python_error_outputs(resp),
        timeout_info,
    )

    return error_counts, correct_format_flags, repetitive_flags, tag_counts, timeout_info


# ========== Worker 类（修改 compute_rm_score）==========
class RewardModelWorker(Worker):
    """Reward worker combining rule-based format checks with an LLM verifier.

    For every response in a batch it extracts the final answer, asks a vLLM
    verifier model to judge it against the ground truth, runs the
    timeout-guarded format/repetition/tag/error detectors, and writes one
    scalar reward at the last valid response token.
    """

    def __init__(self, config):
        """Store ``config`` and build the greedy sampling parameters."""
        super().__init__()
        self.config = config
        self.sampling_params = SamplingParams(temperature=0, max_tokens=2048)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Create the verifier engine and the tokenizer, then put the engine to sleep."""
        self.llm = LLM(
            model=self.config.model.path,
            gpu_memory_utilization=0.4,  # would report error if larger
            max_model_len=4096,
            enable_prefix_caching=True,
            dtype="bfloat16",
            max_num_seqs=256,
        )

        # The tokenizer is loaded from TRAINING_MODEL_PATH, not from the
        # verifier model path; it is used to decode responses in
        # extract_response_text and to measure/truncate text.
        self.tokenizer = hf_tokenizer(
            TRAINING_MODEL_PATH,
            trust_remote_code=self.config.model.get("trust_remote_code", False)
        )
        # Keep the engine asleep between scoring calls; compute_rm_score
        # wakes it up on demand.
        self.llm.sleep(2)
        torch.cuda.empty_cache()

    def extract_response_text(self, data_item):
        """Decode the valid part of one sample's response.

        Returns:
            (response_str, valid_resp_len, full_response_str, ""), where
            ``response_str`` is decoded without special tokens and
            ``full_response_str`` with them.
        """
        resp_ids = data_item.batch["responses"]
        response_length = resp_ids.size(-1)

        # Assumes the last `response_length` attention-mask entries belong
        # to the response and valid tokens come first -- TODO confirm.
        attn_mask = data_item.batch["attention_mask"]
        response_mask = attn_mask[-response_length:]

        valid_resp_len = int(response_mask.sum().item())
        valid_ids = resp_ids[:valid_resp_len]

        response_str = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
        full_response_str = self.tokenizer.decode(valid_ids, skip_special_tokens=False)

        return response_str, valid_resp_len, full_response_str, ""

    def process_judgment(self, judgment_str: str) -> str:
        """Extract the verifier verdict letter from its output.

        Tried in order: the last ``boxed{X}``; the whole string being exactly
        "A", "B" or "C"; then, in the text after the last "Final Judgment:",
        the last ``(X`` and finally the last bare letter A-C.

        Returns:
            "A", "B" or "C", or "" when no verdict is found.
        """
        boxed_matches = re.findall(r'boxed{([A-C])}', judgment_str)
        if boxed_matches:
            return boxed_matches[-1]

        if judgment_str in ["A", "B", "C"]:
            return judgment_str

        final_judgment_str = judgment_str.split("Final Judgment:")[-1]
        matches = re.findall(r'\(([A-C])\)*', final_judgment_str)
        if matches:
            return matches[-1]
        matches = re.findall(r'([A-C])', final_judgment_str)
        if matches:
            return matches[-1]
        return ""

    def _score_sample(self, solution, ground_truth, verification, full_response,
                      correct_format_flag, is_rep, tag_undefined_count,
                      num_total_output):
        """Score one sample for which no detector timed out.

        Returns:
            (format_score, correct_score, details)
        """
        details = {"timeout_flag": False}

        if solution is None:
            solution = "No answer"

        if is_rep or \
                full_response.count(END_TAG) > 1 or \
                num_total_output > MAX_LLM_CALLS + 10:
            # Degenerate output: harshest format penalty.
            format_score = -2
            correct_score = -1 * WEIGHT
            details["correct_answer"] = -1.0 * WEIGHT

        elif len(self.tokenizer(solution).input_ids) >= 512 or \
                solution == "No answer" or \
                num_total_output > MAX_LLM_CALLS + 2 or \
                not full_response.endswith(END_TAG):
            # Unusable answer: over-long, missing, too many calls, or the
            # response does not end with the closing tag.
            format_score = -1
            correct_score = -1 * WEIGHT
            details["correct_answer"] = -1.0 * WEIGHT
        else:
            format_score = 1.0 if correct_format_flag else 0.8

            details["undefined_tag_penalty"] = -min(tag_undefined_count * 0.005, 1.0) * 0.3
            format_score += details["undefined_tag_penalty"]

            if num_total_output > MAX_LLM_CALLS:
                format_score -= min((num_total_output - MAX_LLM_CALLS) * 0.2, 1.0) * 0.5

            if self.process_judgment(verification) == "A":
                correct_score = 1.0 * WEIGHT
                details["correct_answer"] = 1.0 * WEIGHT

                # Small penalty when the answer length (in tokens) differs
                # from the ground truth; the difference is capped at 10.
                tokenized_solution = self.tokenizer.encode(solution)
                tokenized_ground_truth = self.tokenizer.encode(ground_truth)

                difference = abs(len(tokenized_solution) - len(tokenized_ground_truth))
                difference = min(difference, 10)
                details["length_difference_penalty"] = -difference * 0.04
                correct_score += details["length_difference_penalty"]
            else:
                correct_score = -1.0 * WEIGHT
                details["correct_answer"] = -1.0 * WEIGHT

        details['format_reward'] = format_score
        details["num_call"] = num_total_output
        return format_score, correct_score, details

    def _score_batch(self, data: DataProto):
        """Compute rewards for ``data``; the verifier engine must be awake.

        Returns:
            (batch, non_tensor_batch) ready to be wrapped in a ``DataProto``.
        """
        # ========== Data extraction ==========
        t0 = time.time()
        sequence_strs = []
        ground_truths = []
        questions = []
        valid_response_lengths = []

        for i in range(len(data)):
            data_item = data[i]
            resp_str, resp_len, _, _ = self.extract_response_text(data_item)

            sequence_strs.append(resp_str)
            valid_response_lengths.append(resp_len)

            questions.append(data_item.non_tensor_batch["extra_info"]["question"])
            ground_truths.append(data_item.non_tensor_batch["reward_model"]["ground_truth"])

        print(f"⏱️ [3. Data extraction] time: {time.time()-t0:.3f}s")

        # ========== Extract solutions ==========
        t0 = time.time()
        solutions = [extract_solution(seq) for seq in sequence_strs]
        print(f"⏱️ [4. Extract solutions] time: {time.time()-t0:.3f}s")

        # ========== Build verifier prompts ==========
        t0 = time.time()
        messages = []
        for q, gt, sol in zip(questions, ground_truths, solutions):
            message = CV_PROMPT.format(question=q, gold_answer=gt, llm_response=sol)
            messages.append(
                truncate_string_by_tokens(self.tokenizer, message, MAX_INPUT_TOKENS)
            )

        print(f"⏱️ [5. Build messages] time: {time.time()-t0:.3f}s")

        # ========== Verifier inference ==========
        t0 = time.time()
        outputs = self.llm.generate(messages, self.sampling_params)
        print(f"⏱️ [6. GPU generate] time: {time.time()-t0:.3f}s")

        t0 = time.time()
        responses = [output.outputs[0].text.strip() for output in outputs]
        print(f"⏱️ [7. Extract responses] time: {time.time()-t0:.3f}s")

        # ========== Timeout-guarded detectors ==========
        t0 = time.time()

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_details_list = []

        full_responses = [seq.strip() for seq in sequence_strs]

        error_counts, correct_format_flags, repetitive_flags, tag_counts, timeout_info = \
            batch_process_responses(full_responses)

        print(f"⏱️ [8. Batch detection] time: {time.time()-t0:.3f}s")

        # Group the timed-out stages by sample for O(1) lookup below.
        timeout_stages = {}
        for info in timeout_info:
            timeout_stages.setdefault(info["sample_idx"], []).append(info["stage"])

        # ========== Per-sample reward ==========
        # Restart the timer so step 9 reports only the reward loop instead
        # of also including the batch-detection time.
        t0 = time.time()
        for i, (ground_truth, solution, verification, valid_response_length) in enumerate(
            zip(ground_truths, solutions, responses, valid_response_lengths)
        ):
            num_error_outputs, num_total_output = error_counts[i]

            if i in timeout_stages:
                # Any detector timeout earns the maximum penalty.
                format_score = TIMEOUT_PENALTY["format_score"]
                correct_score = TIMEOUT_PENALTY["correct_score"]
                details = {
                    "timeout_flag": True,
                    "timeout_stages": timeout_stages[i],
                    "format_reward": format_score,
                    "correct_answer": correct_score,
                    "num_call": num_total_output,
                }

                logger.warning(
                    f"⏰ Sample {i} TIMEOUT at stages: {details['timeout_stages']}, "
                    f"给予最大惩罚 (score={format_score + correct_score})"
                )
            else:
                format_score, correct_score, details = self._score_sample(
                    solution=solution,
                    ground_truth=ground_truth,
                    verification=verification,
                    full_response=full_responses[i],
                    correct_format_flag=correct_format_flags[i],
                    is_rep=repetitive_flags[i],
                    tag_undefined_count=tag_counts[i],
                    num_total_output=num_total_output,
                )

            score = format_score + correct_score
            reward_details_list.append(details)
            # The reward is placed on the last valid response token.
            # NOTE(review): with valid_response_length == 0 this indexes -1,
            # i.e. the last position of the row -- confirm that is intended.
            reward_tensor[i, valid_response_length - 1] = score

        print(f"⏱️ [9. Reward computation] time: {time.time()-t0:.3f}s")

        # ========== Print a few examples ==========
        for j in range(min(5, len(responses))):
            score = reward_tensor[j, valid_response_lengths[j] - 1].item()
            timeout_flag = reward_details_list[j].get("timeout_flag", False)
            print(f"Sample {j} - Score: {score:.2f}, Timeout: {timeout_flag}")

        # ========== Timeout summary ==========
        total_timeouts = len(timeout_stages)
        if total_timeouts > 0:
            print(f"\n⚠️ Total timeouts: {total_timeouts}/{len(full_responses)} samples")
            # Count timeouts per stage.
            stage_counts = {}
            for info in timeout_info:
                stage = info["stage"]
                stage_counts[stage] = stage_counts.get(stage, 0) + 1
            print(f"   Timeout breakdown: {stage_counts}")

        # ========== Prepare output ==========
        batch = TensorDict({"rm_scores": reward_tensor}, batch_size=reward_tensor.shape[0])
        non_tensor_batch = {"reward_details": np.array(reward_details_list, dtype=object)}
        return batch, non_tensor_batch

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto) -> DataProto:
        """Compute reward scores for a batch of responses.

        Args:
            data: Batch holding "responses" and "attention_mask" tensors plus
                per-sample ``extra_info["question"]`` and
                ``reward_model["ground_truth"]`` entries.

        Returns:
            DataProto whose batch holds "rm_scores" (same shape as
            "responses", the reward sitting at the last valid response
            token) and whose non-tensor batch holds "reward_details".
        """
        t_total_start = time.time()

        t0 = time.time()
        torch.cuda.empty_cache()
        print(f"⏱️ [1. empty_cache] time: {time.time()-t0:.3f}s")

        t0 = time.time()
        self.llm.wake_up()
        print(f"⏱️ [2. wake_up] time: {time.time()-t0:.3f}s")

        try:
            batch, non_tensor_batch = self._score_batch(data)
        finally:
            # Always put the engine back to sleep, even when scoring fails,
            # so an error does not leave it awake between calls.
            self.llm.sleep(2)
            torch.cuda.empty_cache()

        print(f"⏱️ ========== [TOTAL] time: {time.time()-t_total_start:.3f}s ==========\n")

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


# ========== 其他辅助函数（保持不变）==========
def truncate_string_by_tokens(tokenizer, text: str, max_tokens=3072) -> str:
    """Cut ``text`` down to at most ``max_tokens`` tokens.

    Args:
        tokenizer: Object offering ``encode(text, add_special_tokens=...)``
            and ``decode(tokens, skip_special_tokens=...)``.
        text: Text to truncate.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        ``text`` itself when it already fits; otherwise the decoding of its
        first ``max_tokens`` tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)


def extract_solution(solution_str: str) -> str:
    """Return the final answer found in ``solution_str``.

    The content of the last ``\\boxed{...}`` is preferred; when that is
    missing or empty, the result of ``extract_last_final_answer`` is
    returned instead (which may be ``None``).
    """
    return extract_last_boxed(solution_str) or extract_last_final_answer(solution_str)


# def extract_last_boxed(text: str) -> str:
#     pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
#     matches = list(re.finditer(pattern, text))
#     if matches:
#         return matches[-1].group(1)
#     return None


def extract_last_boxed(text: str) -> str:
    """Return the content of the last ``\\boxed{...}`` that is properly closed.

    The text is scanned once, left to right. Outside any box only the
    ``\\boxed{`` marker matters. Inside a box, braces are matched with a
    stack and a backslash skips the character that follows it, so ``\\{``
    and ``\\}`` do not count as braces. A box nested inside a closed box is
    part of the outer box's content; when the outer box never closes, a
    complete inner or later box is still returned instead of being
    swallowed by the unterminated one.

    Args:
        text: Text to search.

    Returns:
        The content between the braces of the last closed ``\\boxed{}``
        (possibly an empty string), or ``None`` when there is none.
    """
    marker = '\\boxed{'
    marker_len = len(marker)
    n = len(text)

    last_boxed_content = None
    # One entry per currently open brace: the content start index for a
    # \boxed{ opener, None for a plain '{'.
    open_braces = []

    i = 0
    while i < n:
        if text.startswith(marker, i):
            open_braces.append(i + marker_len)
            i += marker_len
        elif not open_braces:
            # Outside any box: stray braces and backslashes are irrelevant.
            i += 1
        else:
            ch = text[i]
            if ch == '\\':
                i += 2  # skip the escape sequence
                continue
            if ch == '{':
                open_braces.append(None)
            elif ch == '}':
                content_start = open_braces.pop()
                if content_start is not None:
                    # Boxes close in order of their end position, so the
                    # most recently closed box is the last one in the text.
                    last_boxed_content = text[content_start:i]
            i += 1

    return last_boxed_content

def extract_last_final_answer(text: str) -> str:
    """Return the answer following the last "answer"-style cue in ``text``.

    Each cue (e.g. "Final Answer:", "The answer is:") is matched
    case-insensitively and must be followed by a newline somewhere after
    the captured answer. Among all matches the one starting latest in the
    text wins. A trailing end-of-sequence marker is stripped from the
    result.

    Returns:
        The stripped answer text, or ``None`` when no cue matches.
    """
    candidate_patterns = [
        r"Final Answer:\s*((?:[^<]|<[^<])*?)\n",
        r"Final Answer is:\s*((?:[^<]|<[^<])*?)\n",
        r"The answer is:\s*((?:[^<]|<[^<])*?)\n",
        r"Answer:\s*((?:[^<]|<[^<])*?)\n",
        r"Solution:\s*((?:[^<]|<[^<])*?)\n",
        r"The solution is:\s*((?:[^<]|<[^<])*?)\n",
    ]

    # Collect (start offset, captured answer) for every match of every cue.
    candidates = [
        (found.start(), found.group(1).strip())
        for pattern in candidate_patterns
        for found in re.finditer(pattern, text, flags=re.IGNORECASE)
    ]
    # max() keeps the first candidate among equal start offsets.
    answer = max(candidates, key=lambda item: item[0])[1] if candidates else None

    for stop_word in ("</s>", "<|im_end|>", "<|endoftext|>"):
        if answer and answer.endswith(stop_word):
            answer = answer[:-len(stop_word)].strip()

    return answer