import requests
import json
import base64
import os
import time
import argparse
from tqdm import tqdm
from PIL import Image
import io
from openai import OpenAI
import concurrent.futures
import threading
import re
from collections import Counter, defaultdict
from mathruler.grader import extract_boxed_content

# 假设 oss_util.py 在同一目录下
from oss_util import get_oss_image

# ==============================================================================
# API 配置
# ==============================================================================

# Gemini API configuration.
# NOTE(review): 'xxx' and the empty token list look like redacted placeholders —
# fill them in before running; the token pool built from an empty list has
# nothing to hand out.
GEMINI_API_BASE_URL = 'xxx'
GEMINI_API_TOKENS = []
GEMINI_MODEL_NAME = "gemini-2.5-pro"

# Qwen API configuration (same caveat: URL and key list appear to be placeholders).
QWEN_API_BASE_URL = "xxxx"
QWEN_API_KEYS = []
QWEN_MODEL_NAME = "Qwen3-VL-235B-A22B-Instruct"

# Kimi LLM client: the text-only judge passed as `client` to the LLM
# evaluation functions in this file.
llm_client = OpenAI(
    api_key="xxx",
    base_url="",
)

# Qwen3-VL-8B client (newly added): OpenAI-compatible endpoint on localhost,
# used by answer_question_with_qwen8b. The timeout value is 3600.
mllm_client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:18901/v1",
    timeout=3600
)


# ==============================================================================
# Token池管理
# ==============================================================================

class ImprovedTokenPool:
    """Thread-safe pool that hands out API tokens under a per-token concurrency cap.

    Each token is guarded by a semaphore sized ``max_concurrent_per_token``.
    When several tokens are free, the one with the lowest score wins, where
    score = current in-flight requests + a penalty that decays linearly from
    10 to 0 during the 60 seconds following the token's last reported error.
    """
    def __init__(self, tokens: list, max_concurrent_per_token: int = 5):
        self.tokens = tokens
        self.max_concurrent_per_token = max_concurrent_per_token
        self.token_semaphores = {token: threading.Semaphore(max_concurrent_per_token) for token in tokens}
        # Requests currently in flight per token.
        self.token_active_count = {token: 0 for token in tokens}
        # Lifetime statistics.
        self.token_usage_count = defaultdict(int)
        self.token_error_count = defaultdict(int)
        self.token_last_error_time = defaultdict(float)
        # Protects the counters above and serializes token selection.
        self.lock = threading.Lock()

    def acquire_token(self, timeout=30):
        """Return the best available token, polling every 0.1 s.

        Args:
            timeout: Maximum number of seconds to keep polling.

        Returns:
            The selected token; the caller must hand it back via
            ``release_token``.

        Raises:
            TimeoutError: If no token could be acquired within ``timeout``.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            with self.lock:
                current_time = time.time()
                token_scores = []

                for token in self.tokens:
                    # Gate on our own bookkeeping rather than peeking at the
                    # semaphore's private ``_value``; this also keeps the cap
                    # intact if a caller ever releases a token twice.
                    if self.token_active_count[token] >= self.max_concurrent_per_token:
                        continue

                    error_penalty = 0
                    # Membership test (not indexing) so the defaultdict does
                    # not grow an entry for tokens that never failed.
                    if token in self.token_last_error_time:
                        time_since_error = current_time - self.token_last_error_time[token]
                        if time_since_error < 60:
                            error_penalty = 10 * (1 - time_since_error / 60)

                    score = self.token_active_count[token] + error_penalty
                    token_scores.append((score, token))

                token_scores.sort()
                # Try candidates in score order; a token whose semaphore has
                # not been released yet is skipped instead of wasting the
                # whole polling pass.
                for _, candidate in token_scores:
                    if self.token_semaphores[candidate].acquire(blocking=False):
                        self.token_active_count[candidate] += 1
                        self.token_usage_count[candidate] += 1
                        return candidate

            time.sleep(0.1)

        raise TimeoutError("无法在规定时间内获取可用token")

    def release_token(self, token, has_error=False):
        """Hand a token back to the pool.

        Args:
            token: A token previously returned by ``acquire_token``.
            has_error: When True, the error counter is bumped and the error
                timestamp refreshed, which temporarily lowers the token's
                selection priority.
        """
        with self.lock:
            self.token_active_count[token] = max(0, self.token_active_count[token] - 1)
            if has_error:
                self.token_error_count[token] += 1
                self.token_last_error_time[token] = time.time()
        self.token_semaphores[token].release()

# Create the module-level token pools: up to 5 concurrent requests per Gemini
# token and up to 10 per Qwen key.
gemini_token_pool = ImprovedTokenPool(GEMINI_API_TOKENS, max_concurrent_per_token=5)
qwen_token_pool = ImprovedTokenPool(QWEN_API_KEYS, max_concurrent_per_token=10)

# ==============================================================================
# 辅助函数
# ==============================================================================

def encode_pil_image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
    """Serialize a PIL image into a ``data:image/<format>;base64,...`` URL.

    Args:
        pil_image: Image to encode.
        image_format: Format name handed to ``Image.save``; its lower-cased
            form is used as the MIME subtype.

    Returns:
        The data URL as a string.
    """
    stream = io.BytesIO()
    pil_image.save(stream, format=image_format)
    encoded_payload = base64.b64encode(stream.getvalue()).decode('utf-8')
    mime_subtype = image_format.lower()
    return f"data:image/{mime_subtype};base64,{encoded_payload}"

# ==============================================================================
# API 调用函数
# ==============================================================================

def _extract_gemini_content(response, stream):
    """Pull the assistant text out of a successful (HTTP 200) Gemini response."""
    if stream:
        pieces = []
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode('utf-8')
            if not decoded_line.startswith("data:"):
                continue
            json_str = decoded_line[5:].strip()
            if json_str == "[DONE]":
                break
            try:
                chunk = json.loads(json_str)
            except json.JSONDecodeError:
                # Skip malformed chunks, keep reading the stream.
                continue
            # A delta may carry an explicit null content; treat it as empty
            # instead of failing the whole request.
            pieces.append(chunk.get("choices", [{}])[0].get("delta", {}).get("content", "") or "")
        return "".join(pieces)

    response_text = response.text
    # Some gateways prefix even non-streaming bodies with "data:".
    if response_text.startswith("data:"):
        json_str = response_text[5:].strip()
    else:
        json_str = response_text
    response_data = json.loads(json_str)
    return response_data.get("choices", [{}])[0].get("message", {}).get("content", "")


def call_gemini_api(messages, stream=False, max_retries=300, base_retry_delay=1):
    """Generic Gemini chat-completion call with token rotation and retries.

    Args:
        messages: Chat messages in OpenAI-style format.
        stream: Whether to request (and parse) a streamed response.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep after every failed attempt.

    Returns:
        The response text. ``None`` if the content consists only of
        whitespace, or if every attempt failed (token-pool timeout, non-200
        status, network/parse error, or empty content).
    """
    url = GEMINI_API_BASE_URL
    data = {
        "stream": stream,
        "model": GEMINI_MODEL_NAME,
        "messages": messages
    }

    for _ in range(max_retries):
        current_token = None
        succeeded = False
        try:
            current_token = gemini_token_pool.acquire_token(timeout=30)
            headers = {
                "Content-Type": "application/json",
                "Authorization": current_token
            }

            response = requests.post(
                url,
                data=json.dumps(data),
                headers=headers,
                timeout=120,
                stream=stream
            )
            try:
                # Any non-200 status (429 included) counts as a failed attempt.
                if response.status_code == 200:
                    full_content = _extract_gemini_content(response, stream)
                    # Empty/None content is a failure and is retried.
                    if full_content:
                        succeeded = True
                        return full_content if full_content.strip() else None
            finally:
                # Return the connection to the pool, in particular when a
                # streamed body was abandoned after "[DONE]".
                response.close()

        except Exception:
            # Deliberate best-effort: pool timeouts, network errors and
            # malformed payloads all just trigger the next attempt.
            pass

        finally:
            # Single release point: the token goes back exactly once per
            # attempt, even when the attempt is interrupted.
            if current_token is not None:
                gemini_token_pool.release_token(current_token, has_error=not succeeded)

        time.sleep(base_retry_delay)

    return None

def call_qwen_api(payload: dict, max_retries: int = 300, base_retry_delay: float = 1.0):
    """POST ``payload`` to the Qwen endpoint, rotating API keys and retrying.

    Args:
        payload: JSON-serializable request body.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep after every failed attempt.

    Returns:
        The decoded JSON response of the first HTTP 200 reply, or ``None``
        when all attempts fail.
    """
    for _ in range(max_retries):
        api_key = None
        try:
            api_key = qwen_token_pool.acquire_token(timeout=30)
            request_headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }

            resp = requests.post(
                QWEN_API_BASE_URL,
                headers=request_headers,
                json=payload,
                timeout=120,
                stream=False
            )

            if resp.status_code == 200:
                body = resp.json()
                qwen_token_pool.release_token(api_key, has_error=False)
                return body

            # Every non-200 status (429/503 included) is a failed attempt:
            # flag the key and fall through to the retry delay.
            qwen_token_pool.release_token(api_key, has_error=True)
            api_key = None

        except Exception:
            # Covers the pool's TimeoutError as well as network/JSON errors.
            if api_key:
                qwen_token_pool.release_token(api_key, has_error=True)

        time.sleep(base_retry_delay)

    return None

# ==============================================================================
# VQA Answer Generation
# ==============================================================================

def answer_question_with_gemini(image, question, num_samples=4):
    """Ask Gemini the question about ``image`` several times.

    Args:
        image: PIL image to send along with the question.
        question: Question text.
        num_samples: Number of independent samples to draw.

    Returns:
        A list with ``num_samples`` entries. Each entry is the boxed answer
        when the reply contains "boxed", otherwise the raw reply; a failed or
        empty call yields "[ERROR: No response]".
    """
    base64_image = encode_pil_image_to_base64(image)

    # The request is identical for every sample, so build it once.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": base64_image}},
                {"type": "text", "text": f"Question: {question}\n\nPlease reason step by step, and then put your answer within \\boxed{{}}."}
            ]
        }
    ]

    answers = []
    for _ in range(num_samples):
        answer = call_gemini_api(messages, stream=False)
        # call_gemini_api returns None once its retries are exhausted; guard
        # before the substring test so one failed call cannot raise TypeError
        # and throw away the whole batch.
        if answer and 'boxed' in answer:
            answer = extract_boxed_content(answer)
        answers.append(answer if answer else "[ERROR: No response]")

    return answers

def answer_question_with_qwen(image_b64, question, num_samples=4):
    """Ask the Qwen API the question about the image several times.

    Args:
        image_b64: Base64-encoded image bytes (without the data-URL prefix).
        question: Question text.
        num_samples: Number of independent samples to draw.

    Returns:
        A list with ``num_samples`` entries: the stripped boxed answer (or the
        stripped raw reply when nothing is boxed), "[ERROR: Parse failed]"
        when the response could not be interpreted, or "[ERROR: No response]"
        when the API call returned nothing.
    """
    # The request is identical for every sample, so build it once.
    payload = {
        "stream": False,
        "model": QWEN_MODEL_NAME,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                    {"type": "text", "text": f"Question: {question}\n\nPlease reason step by step, and then put your answer within \\boxed{{}}."}
                ]
            }
        ],
        "temperature": 0.7,
        "top_p": 0.9
    }

    answers = []
    for _ in range(num_samples):
        resp_json = call_qwen_api(payload)
        if not resp_json:
            answers.append("[ERROR: No response]")
            continue

        try:
            content = resp_json["choices"][0]["message"]["content"]
            content = extract_boxed_content(content) if 'boxed' in content else content
            answers.append(content.strip())
        except Exception:
            # Unexpected response shape or null content. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
            answers.append("[ERROR: Parse failed]")

    return answers

def answer_question_with_qwen8b(img_path, question, num_samples=4):
    """Ask the Qwen3-VL-8B endpoint the question several times.

    The media is fetched via ``get_oss_image`` on every sample; paths ending
    in ".mp4" are sent as video, everything else as an image. Any exception
    during a sample is printed, followed by a one-second pause, and recorded
    as "[ERROR: No response]".

    Returns:
        A list with ``num_samples`` entries (stripped answers or error markers).
    """
    collected = []
    text_suffix = "\n\nPlease reason step by step, and put your final answer within \\boxed{}."

    for _ in range(num_samples):
        try:
            media_b64 = get_oss_image(img_path, image_type='base64')

            if img_path.endswith(".mp4"):
                media_part = {
                    "type": "video_url",
                    "video_url": {"url": f'data:video/mp4;base64,{media_b64}'},
                }
            else:
                media_part = {
                    "type": "image_url",
                    "image_url": {"url": f"data:image;base64,{media_b64}"},
                }

            text_part = {"type": "text", "text": question + text_suffix}
            user_turn = {"role": "user", "content": [media_part, text_part]}

            response = mllm_client.chat.completions.create(
                model="Qwen3-VL-8B-Instruct",
                messages=[user_turn],
                temperature=1.0,
                max_tokens=20480
            )

            reply = response.choices[0].message.content
            if 'boxed' in reply:
                reply = extract_boxed_content(reply)
            elif '<|begin_of_box|>' in reply:
                # Alternative box markers used instead of \boxed{}.
                after_open = reply.split('<|begin_of_box|>')[1]
                reply = after_open.split('<|end_of_box|>')[0]

            collected.append(reply.strip())

        except Exception as e:
            print(f"Error in Qwen8B answer generation: {str(e)}")
            time.sleep(1)
            collected.append("[ERROR: No response]")

    return collected





# ==============================================================================
# LLM Evaluation Functions
# ==============================================================================

def evaluate_consistency_with_llm(question, answers, ground_truth, client):
    """Use a language model to judge how consistent a set of answers is.

    The judge groups semantically equivalent answers, reports the most common
    group and its size, and says whether that group matches ``ground_truth``.
    The prompt and the score denominator follow ``len(answers)`` rather than
    assuming exactly four samples.

    Args:
        question: The question the answers respond to.
        answers: Sampled answers to compare.
        ground_truth: Reference answer shown to the judge.
        client: OpenAI-compatible client exposing ``chat.completions.create``.

    Returns:
        dict with keys ``most_common_answer``, ``consistency_count``,
        ``consistency_score`` (count / number of answers),
        ``matches_ground_truth``, ``llm_response``, ``success`` and ``error``.
        If every attempt fails, ``success`` is False and ``error`` holds the
        last exception text.
    """
    num_answers = len(answers)
    answers_text = "\n".join([f"Answer {i+1}: {ans}" for i, ans in enumerate(answers)])

    prompt = f"""You are an expert at evaluating the consistency of multiple answers to the same question.

Question: {question}

Ground Truth Answer: {ground_truth}

Here are {num_answers} different answers to this question:

{answers_text}

Your task:
1. Carefully analyze all {num_answers} answers
2. Group semantically similar or equivalent answers together
3. Identify the most common answer group (the mode)
4. Count how many answers belong to this most common group
5. Compare with the ground truth to see if the most common answer is correct

Consider them equivalent if:
- Different phrasings of the same answer (e.g., "5 people" vs "five people")
- Slight variations in description that refer to the same entity
- Minor differences in precision that don't change the core answer (such as two colors that are similar: e.g., brown and dark brown)

Do NOT consider them equivalent if:
- They refer to different objects, numbers, or concepts
- They represent different interpretations of the question

Please think step by step, then provide:
- The most common answer (normalized/standardized form)
- The count of how many answers are consistent with this most common answer (a number from 1 to {num_answers})
- Whether the most common answer matches the ground truth (Yes/No)

Format your response as:
Most Common Answer: [your answer here]
Consistency Count: [number]
Matches Ground Truth: [Yes/No]
"""

    max_retries = 2000
    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="Kimi-K2-Instruct-0905",
                messages=[
                    {'role': 'user', 'content': prompt},
                ],
                temperature=0.3,
                max_tokens=20480,
            )

            response = completion.choices[0].message.content

            # Parse the response. The judge reasons before answering, so the
            # last occurrence of each field wins.
            most_common_answer = None
            consistency_count = 0
            matches_ground_truth = False

            for line in response.split('\n'):
                if 'Most Common Answer:' in line:
                    most_common_answer = line.split('Most Common Answer:')[1].strip()
                elif 'Consistency Count:' in line:
                    count_str = line.split('Consistency Count:')[1].strip()
                    numbers = re.findall(r'\d+', count_str)
                    if numbers:
                        # The judge cannot agree with more answers than exist.
                        consistency_count = min(int(numbers[0]), num_answers)
                elif 'Matches Ground Truth:' in line:
                    match_str = line.split('Matches Ground Truth:')[1].strip().lower()
                    matches_ground_truth = 'yes' in match_str

            return {
                "most_common_answer": most_common_answer,
                "consistency_count": consistency_count,
                "consistency_score": consistency_count / num_answers if num_answers else 0.0,
                "matches_ground_truth": matches_ground_truth,
                "llm_response": response,
                "success": True,
                "error": None
            }

        except Exception as e:
            if attempt == max_retries - 1:
                return {
                    "most_common_answer": None,
                    "consistency_count": 0,
                    "consistency_score": 0.0,
                    "matches_ground_truth": False,
                    "llm_response": None,
                    "success": False,
                    "error": str(e)
                }
            time.sleep(2)

def compare_original_vs_crop_answers(original_eval, crop_eval, client, question):
    """Ask an LLM whether the most common original-image and crop answers agree.

    Args:
        original_eval: Consistency result for the original image (dict with
            ``most_common_answer``).
        crop_eval: Consistency result for the cropped image.
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        question: The question both answers respond to.

    Returns:
        dict with ``answers_match`` (True/False, or None when no judgment was
        possible), ``success`` and ``error``; once both answers are available
        it also carries ``original_answer``, ``crop_answer``, ``explanation``
        and ``comparison_response``.
    """
    if not (original_eval and crop_eval):
        return {
            "answers_match": None,
            "comparison_response": None,
            "success": False,
            "error": "Missing evaluation results"
        }

    original_answer = original_eval.get('most_common_answer')
    crop_answer = crop_eval.get('most_common_answer')

    if not (original_answer and crop_answer):
        return {
            "answers_match": None,
            "comparison_response": None,
            "success": False,
            "error": "Missing most common answer"
        }

    prompt = f"""You are an expert at comparing answers for semantic equivalence.

Question: {question}

Answer1: {original_answer}
Answer2: {crop_answer}

Your task:
Determine whether these two answers are semantically equivalent or refer to the same thing.
Consider them equivalent if:
- Different phrasings of the same answer (e.g., "5 people" vs "five people")
- Slight variations in description that refer to the same entity
- Minor differences in precision that don't change the core answer (such as two colors that are similar: e.g., brown and dark brown)

Do NOT consider them equivalent if:
- They refer to different objects, numbers, or concepts
- They represent different interpretations of the question

Please think step by step, then provide your judgment.

Format your response as:
Answers Match: Yes/No
Explanation: [Your reasoning here]
"""

    verdict_key = 'Answers Match:'
    explanation_key = 'Explanation:'
    max_retries = 2000
    last_attempt = max_retries - 1

    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="Kimi-K2-Instruct-0905",
                messages=[
                    {'role': 'user', 'content': prompt},
                ],
                temperature=0.3,
                max_tokens=20480,
            )
            reply_text = completion.choices[0].message.content

            # Scan the reply line by line; later occurrences overwrite
            # earlier ones.
            verdict = False
            reasoning = ""
            for row in reply_text.split('\n'):
                if verdict_key in row:
                    verdict = 'yes' in row.split(verdict_key)[1].strip().lower()
                elif explanation_key in row:
                    reasoning = row.split(explanation_key)[1].strip()

            return {
                "answers_match": verdict,
                "original_answer": original_answer,
                "crop_answer": crop_answer,
                "explanation": reasoning,
                "comparison_response": reply_text,
                "success": True,
                "error": None
            }

        except Exception as e:
            if attempt != last_attempt:
                time.sleep(2)
                continue
            # Retries exhausted: report the last error.
            return {
                "answers_match": None,
                "original_answer": original_answer,
                "crop_answer": crop_answer,
                "explanation": None,
                "comparison_response": None,
                "success": False,
                "error": str(e)
            }

def check_any_answer_matches_ground_truth(answers, ground_truth, client, question):
    """
    Check whether any of the candidate answers matches the ground truth.

    Args:
        answers: list of 4 answers
        ground_truth: the ground truth answer (a2)
        client: LLM client
        question: the original question

    Returns:
        bool: True if the judge reports a matching answer; False otherwise,
        including when every retry fails.
    """
    numbered_answers = "\n".join(
        f"Answer {position}: {candidate}"
        for position, candidate in enumerate(answers, start=1)
    )

    prompt = f"""You are an expert at comparing answers for semantic equivalence.

Question: {question}

Ground Truth Answer: {ground_truth}

Here are 4 candidate answers:
{numbered_answers}

Your task:
Check if ANY of these 4 answers is semantically equivalent to the ground truth answer.
Consider them equivalent if:
- Different phrasings of the same answer (e.g., "5 people" vs "five people")
- Slight variations in description that refer to the same entity
- Minor differences in precision that don't change the core answer (such as two colors that are similar: e.g., brown and dark brown)

Do NOT consider them equivalent if:
- They refer to different objects, numbers, or concepts
- They represent different interpretations of the question

Please think step by step, then provide your judgment.

Format your response as:
Has Matching Answer: Yes/No
Matching Answer Number: [1-4 or None]
Explanation: [Your reasoning here]
"""

    verdict_key = 'Has Matching Answer:'
    max_retries = 2000

    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="Kimi-K2-Instruct-0905",
                messages=[
                    {'role': 'user', 'content': prompt},
                ],
                temperature=0.3,
                max_tokens=20480,
            )
            reply_text = completion.choices[0].message.content

            # Only the first line carrying the verdict is considered.
            verdict_line = next(
                (row for row in reply_text.split('\n') if verdict_key in row),
                None,
            )
            if verdict_line is None:
                return False
            return 'yes' in verdict_line.split(verdict_key)[1].strip().lower()

        except Exception as e:
            if attempt == max_retries - 1:
                # All retries failed: conservatively report no match.
                return False
            time.sleep(2)

    return False

def check_qwen8b_correct_count(qwen8b_answers, ground_truth, client, question):
    """
    Count how many of the Qwen8B answers match the ground truth.

    Args:
        qwen8b_answers: list of 4 answers from Qwen8B
        ground_truth: the ground truth answer (a2)
        client: LLM client
        question: the original question

    Returns:
        int: the first number on the judge's "Correct Count:" line; 0 when
        that line or number is missing, or when every retry fails.
    """
    numbered_answers = "\n".join(
        f"Answer {position}: {candidate}"
        for position, candidate in enumerate(qwen8b_answers, start=1)
    )

    prompt = f"""You are an expert at comparing answers for semantic equivalence.

Question: {question}

Ground Truth Answer: {ground_truth}

Here are 4 candidate answers from a model:
{numbered_answers}

Your task:
Count how many of these 4 answers are semantically equivalent to the ground truth answer.
Consider them equivalent if:
- Different phrasings of the same answer (e.g., "5 people" vs "five people")
- Slight variations in description that refer to the same entity
- Minor differences in precision that don't change the core answer (such as two colors that are similar: e.g., brown and dark brown)

Do NOT consider them equivalent if:
- They refer to different objects, numbers, or concepts
- They represent different interpretations of the question

Please analyze each answer step by step, then provide your count.

Format your response as:
Correct Count: [0-4]
Correct Answer Numbers: [list like "1, 3" or "None"]
Explanation: [Your reasoning here]
"""

    count_key = 'Correct Count:'
    max_retries = 2000

    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="Kimi-K2-Instruct-0905",
                messages=[
                    {'role': 'user', 'content': prompt},
                ],
                temperature=0.3,
                max_tokens=20480,
            )
            reply_text = completion.choices[0].message.content

            # Only the first line carrying the count is considered.
            count_line = next(
                (row for row in reply_text.split('\n') if count_key in row),
                None,
            )
            if count_line is None:
                return 0

            digits = re.findall(r'\d+', count_line.split(count_key)[1].strip())
            return int(digits[0]) if digits else 0

        except Exception as e:
            if attempt == max_retries - 1:
                # All retries failed: conservatively report zero.
                return 0
            time.sleep(2)

    return 0

# ==============================================================================
# 主处理流程
# ==============================================================================



def process_single_vqa(original_image, crop_image, original_b64, question, original_image_path):
    """
    Run the full validation pipeline for a single VQA question.

    The crop image is answered by Gemini first; if those answers are
    consistent enough, their most common answer (a2) becomes the ground
    truth. The original image is then answered by progressively stronger
    models (Qwen3-VL-8B, Qwen 235B, Gemini) and the question is labelled by
    the first stage that reproduces a2.

    Args:
        original_image: original image, passed to answer_question_with_gemini.
        crop_image: cropped image, passed to answer_question_with_gemini.
        original_b64: base64 form of the original image, passed to
            answer_question_with_qwen.
        question: the question text.
        original_image_path: path of the original image, passed to
            answer_question_with_qwen8b.

    Returns: {
        "question": str,
        "status": "valid very easy vqa" | "valid easy vqa" | "valid middle vqa"
                  | "invalid vqa" | "need annotation" | "hard vqa" | "error",
        "a2": str (crop answer - ground truth),
        "a1": str (qwen 235B on original; None if that stage was not reached),
        "a3": str (gemini on original; None if that stage was not reached),
        "crop_eval": dict,
        "qwen8b_eval": dict,
        "original_qwen_eval": dict,
        "original_gemini_eval": dict,
        "original_gemini_raw_answers": list,
        "comparisons": dict
    }
    Depending on the exit path the dict may also carry "error", "reason" or
    "has_matching_individual_answer".
    """
    result = {
        "question": question,
        "status": "processing",
        "a2": None,
        "a1": None,
        "a3": None,
        "crop_eval": None,
        "qwen8b_eval": None,  # added: the 8B model is tested first
        "original_qwen_eval": None,
        "original_gemini_eval": None,
        "original_gemini_raw_answers": None,
        "comparisons": {}
    }
    
    # Step 1: answer the question on the crop image with Gemini (4 samples)
    try:
        crop_answers = answer_question_with_gemini(crop_image, question, num_samples=4)
    except Exception as e:
        result["status"] = "invalid vqa"
        result["error"] = f"Crop answering failed: {str(e)}"
        return result
    
    # Step 2: evaluate the consistency of the crop answers
    # (no ground truth exists yet, hence "N/A")
    crop_eval = evaluate_consistency_with_llm(
        question, 
        crop_answers, 
        ground_truth="N/A",
        client=llm_client
    )
    result["crop_eval"] = crop_eval
    result["a2"] = crop_eval.get("most_common_answer")
    
    # Step 3: check the consistency score
    consistency_score = crop_eval.get("consistency_score", 0.0)
    
    # Fewer than 3 of 4 agreeing crop answers: the question is rejected
    if consistency_score < 0.75:
        result["status"] = "invalid vqa"
        result["reason"] = f"Low consistency: {consistency_score}"
        return result
    
    # Step 4: consistency is high, so a2 serves as the ground truth
    a2 = result["a2"]
    
    # Step 5 (added): first test Qwen8B on the original image, 4 samples
    try:
        qwen8b_answers = answer_question_with_qwen8b(original_image_path, question, num_samples=4)
        
        # Count how many answers are correct
        correct_count_8b = check_qwen8b_correct_count(
            qwen8b_answers,
            a2,
            llm_client,
            question
        )
        
        result["qwen8b_eval"] = {
            "answers": qwen8b_answers,
            "correct_count": correct_count_8b,
            "ground_truth": a2
        }
        
        # If the 8B model is right at least twice, label as very easy and return
        if correct_count_8b >= 2:
            result["status"] = "valid very easy vqa"
            result["a1"] = None  # the 235B model was not tested
            return result
        
    except Exception as e:
        # If the Qwen8B call fails, record the error but continue the pipeline
        result["qwen8b_eval"] = {
            "error": str(e),
            "correct_count": 0
        }
        print(f"Warning: Qwen8B test failed for question '{question}': {str(e)}")
    
    # Step 6: the 8B model was right fewer than 2 times; ask Qwen235B on the original image
    try:
        original_qwen_answers = answer_question_with_qwen(original_b64, question, num_samples=4)
    except Exception as e:
        result["status"] = "error"
        result["error"] = f"Qwen235B answering failed: {str(e)}"
        return result
    
    # Step 7: evaluate the consistency of the Qwen235B answers
    original_qwen_eval = evaluate_consistency_with_llm(
        question,
        original_qwen_answers,
        ground_truth=a2,
        client=llm_client
    )
    result["original_qwen_eval"] = original_qwen_eval
    result["a1"] = original_qwen_eval.get("most_common_answer")
    
    # Step 8: compare a1 with a2
    comparison_a1_a2 = compare_original_vs_crop_answers(
        original_qwen_eval,
        crop_eval,
        llm_client,
        question
    )
    result["comparisons"]["a1_vs_a2"] = comparison_a1_a2
    
    if comparison_a1_a2.get("answers_match"):
        # The 235B model answered correctly: label as easy vqa
        result["status"] = "valid easy vqa"
        return result
    
    # Step 9: a1 != a2, so ask Gemini on the original image
    try:
        original_gemini_answers = answer_question_with_gemini(original_image, question, num_samples=4)
        result["original_gemini_raw_answers"] = original_gemini_answers
    except Exception as e:
        result["status"] = "error"
        result["error"] = f"Gemini answering failed: {str(e)}"
        return result
    
    # Step 10: evaluate the consistency of the Gemini answers
    original_gemini_eval = evaluate_consistency_with_llm(
        question,
        original_gemini_answers,
        ground_truth=a2,
        client=llm_client
    )
    result["original_gemini_eval"] = original_gemini_eval
    result["a3"] = original_gemini_eval.get("most_common_answer")
    
    # Step 11: compare a3 with a2
    comparison_a3_a2 = compare_original_vs_crop_answers(
        original_gemini_eval,
        crop_eval,
        llm_client,
        question
    )
    result["comparisons"]["a3_vs_a2"] = comparison_a3_a2
    
    if comparison_a3_a2.get("answers_match"):
        result["status"] = "valid middle vqa"
    else:
        # The majority Gemini answer differs from a2; check whether at least
        # one individual Gemini sample still matches it
        has_matching_answer = check_any_answer_matches_ground_truth(
            original_gemini_answers,
            a2,
            llm_client,
            question
        )
        
        if has_matching_answer:
            result["status"] = "hard vqa"
            result["has_matching_individual_answer"] = True
        else:
            result["status"] = "need annotation"
            result["has_matching_individual_answer"] = False
    
    return result



def process_single_image_task(line_data):
    """Validate every eligible question attached to one image record.

    A question is eligible when its ``generated_vqa`` entry has status
    'success', a loadable ``crop_path``, and the pair itself has
    ``validation_status`` 'passed' plus a non-empty question.

    Args:
        line_data: One input record (dict) with ``oss_image_path`` and
            ``generated_vqa``.

    Returns:
        ``{"oss_image_path": ..., "vqa_results": [...]}`` when at least one
        question was processed (failures are recorded with status "error"),
        otherwise ``None``.
    """
    image_path = line_data.get('oss_image_path')
    if not image_path:
        return None

    # Load the original image; give up on this record if it is unavailable.
    full_image = get_oss_image(image_path)
    if not full_image:
        return None

    full_image_b64 = get_oss_image(image_path, image_type='base64')

    collected = []

    for entry in line_data.get('generated_vqa', []):
        # Only entries that were generated successfully are processed.
        if entry.get('status') != 'success':
            continue

        crop_path = entry.get('crop_path')
        if not crop_path:
            continue

        # Load the crop image.
        crop_image = get_oss_image(crop_path)
        if not crop_image:
            continue

        for qa_pair in entry.get('vqa_pairs', []):
            if qa_pair.get('validation_status') != 'passed':
                continue

            question = qa_pair.get('question')
            if not question:
                continue

            try:
                outcome = process_single_vqa(
                    full_image,
                    crop_image,
                    full_image_b64,
                    question,
                    image_path  # original image path, needed by the 8B stage
                )

                # Attach metadata.
                outcome['original_image_path'] = image_path
                outcome['crop_path'] = crop_path
                outcome['mask_path'] = entry.get('mask_path')

                collected.append(outcome)

            except Exception as e:
                print(f"Error processing question '{question}': {e}")
                collected.append({
                    "question": question,
                    "status": "error",
                    "error": str(e),
                    "original_image_path": image_path,
                    "crop_path": crop_path
                })

    if not collected:
        return None

    return {
        "oss_image_path": image_path,
        "vqa_results": collected
    }


# ==============================================================================
# 主函数
# ==============================================================================


def main(args):
    # 读取输入文件
    print(f"✓ 正在读取输入文件: {args.input_file}")
    
    with open(args.input_file, 'r', encoding='utf-8') as f:
        all_lines = [json.loads(line) for line in f]
    
    print(f"✓ 共读取 {len(all_lines)} 条数据")
    
    # 断点恢复
    processed_paths = set()
    if os.path.exists(args.output_file):
        with open(args.output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    processed_paths.add(data.get('oss_image_path'))
                except:
                    continue
        print(f"✓ 从输出文件恢复，已处理 {len(processed_paths)} 条数据")
    
    # 过滤任务
    tasks = [line for line in all_lines if line.get('oss_image_path') not in processed_paths]
    
    if not tasks:
        print("✓ 没有需要处理的新任务")
        return
    
    print(f"✓ 待处理任务数: {len(tasks)}")
    
    # 统计信息
    total_processed = 0
    total_valid_very_easy = 0  # 新增
    total_valid_easy = 0       # 新增
    total_valid_middle = 0     # 新增
    total_invalid = 0
    total_need_annotation = 0
    total_hard = 0
    
    write_lock = threading.Lock()
    
    # 并发处理
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, \
         open(args.output_file, 'a', encoding='utf-8') as f_out:
        
        future_to_task = {
            executor.submit(process_single_image_task, task): task.get('oss_image_path')
            for task in tasks
        }
        
        progress_bar = tqdm(
            concurrent.futures.as_completed(future_to_task),
            total=len(tasks),
            desc="Processing VQA Validation",
            ncols=140
        )
        
        for future in progress_bar:
            task_id = future_to_task[future]
            
            try:
                result = future.result(timeout=600)  # 10分钟超时
                
                if result:
                    # 统计
                    vqa_results = result.get('vqa_results', [])
                    for vqa in vqa_results:
                        total_processed += 1
                        status = vqa.get('status')
                        if status == 'valid very easy vqa':
                            total_valid_very_easy += 1
                        elif status == 'valid easy vqa':
                            total_valid_easy += 1
                        elif status == 'valid middle vqa':
                            total_valid_middle += 1
                        elif status == 'invalid vqa':
                            total_invalid += 1
                        elif status == 'need annotation':
                            total_need_annotation += 1
                        elif status == 'hard vqa':
                            total_hard += 1
                    
                    # 写入文件
                    with write_lock:
                        f_out.write(json.dumps(result, ensure_ascii=False) + '\n')
                        f_out.flush()
                    
                    # 更新进度条
                    progress_bar.set_postfix({
                        'v_easy': total_valid_very_easy,
                        'easy': total_valid_easy,
                        'mid': total_valid_middle,
                        'hard': total_hard,
                        'invalid': total_invalid,
                        'need_anno': total_need_annotation
                    })
                
            except concurrent.futures.TimeoutError:
                print(f"\n⚠ 任务 {task_id} 超时")
                
            except Exception as exc:
                print(f"\n⚠ 任务 {task_id} 发生异常: {exc}")
    
    # 最终统计
    total_valid = total_valid_very_easy + total_valid_easy + total_valid_middle
    print("\n" + "="*100)
    print("处理完成！最终统计:")
    print(f"  总处理VQA数: {total_processed}")
    print(f"  总Valid VQA: {total_valid} ({total_valid/total_processed*100:.2f}%)" if total_processed > 0 else "  总Valid VQA: 0")
    print(f"    ├─ Valid Very Easy VQA: {total_valid_very_easy} ({total_valid_very_easy/total_processed*100:.2f}%)" if total_processed > 0 else "    ├─ Valid Very Easy VQA: 0")
    print(f"    ├─ Valid Easy VQA: {total_valid_easy} ({total_valid_easy/total_processed*100:.2f}%)" if total_processed > 0 else "    ├─ Valid Easy VQA: 0")
    print(f"    └─ Valid Middle VQA: {total_valid_middle} ({total_valid_middle/total_processed*100:.2f}%)" if total_processed > 0 else "    └─ Valid Middle VQA: 0")
    print(f"  Hard VQA: {total_hard} ({total_hard/total_processed*100:.2f}%)" if total_processed > 0 else "  Hard VQA: 0")
    print(f"  Invalid VQA: {total_invalid} ({total_invalid/total_processed*100:.2f}%)" if total_processed > 0 else "  Invalid VQA: 0")
    print(f"  Need Annotation: {total_need_annotation} ({total_need_annotation/total_processed*100:.2f}%)" if total_processed > 0 else "  Need Annotation: 0")
    print("="*100)


if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) rows, registered in order.
    cli_options = (
        (
            '--input_file',
            str,
            "/mnt/nas/yanlong/code/new_benchmark/utils_new_pipeline/sa1b_0000_finegained_gen-gemini_judge-gemini.jsonl",
            '输入JSONL文件路径',
        ),
        (
            '--output_file',
            str,
            "sa1b_0000_validated_vqa_gen-gemini_judge-gemini.jsonl",
            '输出JSONL文件路径',
        ),
        (
            '--max_workers',
            int,
            12,
            '并发处理的最大worker数',
        ),
    )

    parser = argparse.ArgumentParser(description="VQA Answer Validation Pipeline")
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    # Startup banner: echoes the effective settings and summarizes the pipeline steps.
    banner = f"""
╔════════════════════════════════════════════════════════════════════════════╗
║                      VQA Answer Validation Pipeline                        ║
╟────────────────────────────────────────────────────────────────────────────╢
║  输入文件: {args.input_file:58s} ║
║  输出文件: {args.output_file:58s} ║
║  并发数: {args.max_workers:2d}                                                              ║
╟────────────────────────────────────────────────────────────────────────────╢
║  处理流程:                                                                  ║
║  1. 用Gemini对crop图回答问题（采样4次）                                     ║
║  2. Kimi评估一致性，得到a2（ground truth）                                  ║
║  3. 如果一致性<0.75，标记为"invalid vqa"                                    ║
║  4. 先用Qwen3-VL-8B对原图测试4次                                            ║
║     - 如果正确次数≥2，标记为"valid very easy vqa"并结束                     ║
║  5. 如果8B正确次数<2，用Qwen3-VL-235B对原图回答，得到a1                     ║
║  6. Kimi比较a1和a2，如果相同，标记为"valid easy vqa"并结束                  ║
║  7. 如果a1≠a2，用Gemini对原图回答，得到a3                                   ║
║  8. Kimi比较a3和a2，如果相同，标记为"valid middle vqa"                      ║
║  9. 如果a3≠a2，检查4个gemini答案中是否有任一个与a2匹配                      ║
║     - 如果有匹配，标记为"hard vqa"                                          ║
║     - 否则标记为"need annotation"                                           ║
╚════════════════════════════════════════════════════════════════════════════╝
"""
    print(banner)

    main(args)


