import requests
import json
import base64
import os
import time
import argparse
from tqdm import tqdm
from PIL import Image, ImageChops
import io
from openai import OpenAI
import concurrent.futures
import random
import threading
import shutil
import queue
from collections import defaultdict
import re

# 假设 oss_util.py 在同一目录下或在Python路径中
from oss_util import get_oss_image, image_upload_to_oss

class ImprovedTokenPool:
    """Thread-safe pool that hands out API tokens under a per-token concurrency cap.

    A token is chosen by lowest score, where the score is the number of
    requests currently holding the token plus a penalty that decays linearly
    over the 60 seconds following the token's most recent reported error.
    """

    def __init__(self, tokens: list, max_concurrent_per_token: int = 5):
        """
        Args:
            tokens: Tokens (API keys) managed by the pool.
            max_concurrent_per_token: Maximum number of simultaneous holders
                allowed for each individual token.
        """
        self.tokens = tokens
        self.max_concurrent_per_token = max_concurrent_per_token

        # One semaphore and one in-flight counter per token. Both are only
        # modified while holding self.lock, so they always agree.
        self.token_semaphores = {token: threading.Semaphore(max_concurrent_per_token) for token in tokens}
        self.token_active_count = {token: 0 for token in tokens}

        # Statistics.
        self.token_usage_count = defaultdict(int)
        self.token_error_count = defaultdict(int)
        self.token_last_error_time = defaultdict(float)  # wall-clock time of the last error
        self.lock = threading.Lock()

    def acquire_token(self, timeout=30):
        """Return an available token, polling every 0.1s until one frees up.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            The selected token.

        Raises:
            TimeoutError: If no token became available within ``timeout``.
        """
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            with self.lock:
                now = time.time()
                candidates = []

                for token in self.tokens:
                    active = self.token_active_count[token]
                    # Capacity check on our own counter instead of peeking at
                    # the semaphore's private ``_value`` attribute.
                    if active >= self.max_concurrent_per_token:
                        continue

                    error_penalty = 0
                    if token in self.token_last_error_time:
                        time_since_error = now - self.token_last_error_time[token]
                        if time_since_error < 60:
                            error_penalty = 10 * (1 - time_since_error / 60)

                    candidates.append((active + error_penalty, token))

                if candidates:
                    # Lowest score wins; ties fall back to token ordering.
                    selected_token = min(candidates)[1]
                    if self.token_semaphores[selected_token].acquire(blocking=False):
                        self.token_active_count[selected_token] += 1
                        self.token_usage_count[selected_token] += 1
                        return selected_token

            # Nothing free right now: back off briefly before polling again.
            time.sleep(0.1)

        raise TimeoutError("无法在规定时间内获取可用token")

    def release_token(self, token, has_error=False):
        """Give a token back to the pool.

        A release that has no matching acquire is ignored (apart from
        recording the error), so a duplicate release can never raise the
        effective concurrency above ``max_concurrent_per_token``.

        Args:
            token: Token previously returned by ``acquire_token``.
            has_error: Whether the request that used the token failed.
        """
        with self.lock:
            if has_error:
                self.token_error_count[token] += 1
                self.token_last_error_time[token] = time.time()

            if self.token_active_count[token] <= 0:
                # Spurious/duplicate release: nothing is held for this token.
                return

            self.token_active_count[token] -= 1
            self.token_semaphores[token].release()

    def get_stats(self):
        """Return a snapshot of usage, error and in-flight counters per token."""
        with self.lock:
            return {
                "usage": dict(self.token_usage_count),
                "errors": dict(self.token_error_count),
                "active": dict(self.token_active_count)
            }


# ==============================================================================
# 配置
# ==============================================================================

# Gemini API configuration: endpoint, token list handed to the token pool, model name.
GEMINI_API_BASE_URL = 'xxx'
GEMINI_API_TOKENS = []
GEMINI_MODEL_NAME = "gemini-2.5-pro"

# Qwen API configuration: endpoint, key list handed to the Qwen token pool, model name.
QWEN_API_BASE_URL = "xxxx"
QWEN_API_KEYS = []
QWEN_MODEL_NAME = "Qwen3-VL-235B-A22B-Instruct"

# Kimi LLM client (OpenAI-compatible SDK client).
llm_client = OpenAI(
    api_key="xxx",
    base_url="",
)

# Qwen3-VL-8B client (newly added); points at a local endpoint with a one-hour timeout.
mllm_client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:18901/v1",
    timeout=3600
)

# Maximum number of concurrent requests allowed per Qwen API key.
QWEN_MAX_CONCURRENT_PER_KEY = 10

# Token pool shared by all Qwen API calls.
qwen_token_pool = ImprovedTokenPool(QWEN_API_KEYS, max_concurrent_per_token=QWEN_MAX_CONCURRENT_PER_KEY)
# NOTE(review): duplicate of the QWEN_MODEL_NAME assignment above (same value).
QWEN_MODEL_NAME = "Qwen3-VL-235B-A22B-Instruct"

# Token pool shared by all Gemini API calls (5 concurrent requests per token).
token_pool = ImprovedTokenPool(GEMINI_API_TOKENS, max_concurrent_per_token=5)

# OSS prefix passed to image_upload_to_oss when storing cropped images.
crop_oss_prefix = "Data/safety_post_train/sa1b/crop_images"
# Original images whose pixel count (width * height) is not above this are skipped.
MIN_RESOLUTION = 1536 * 1536
# Object/mask area ratio at or above which a pair is dropped with 50% probability.
MAX_MASK_RATIO = 0.10


# ==============================================================================
# 辅助函数
# ==============================================================================

def extract_json_from_text(text):
    """Extract a JSON array from a model reply.

    Candidates are tried in order and the first one that parses to a list
    wins: (1) the content of the first Markdown code fence, (2) the span from
    the first ``[`` to the last ``]``, (3) the whole stripped text. Falling
    through the candidates means an unrelated or malformed code fence no
    longer hides a valid array elsewhere in the reply.

    Args:
        text: Raw text returned by the model (may be empty or None).

    Returns:
        The parsed list, or an empty list when no candidate yields one.
    """
    if not text:
        return []

    candidates = []

    json_block_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
    if json_block_match:
        candidates.append(json_block_match.group(1).strip())

    start_idx = text.find('[')
    end_idx = text.rfind(']')
    if start_idx != -1 and end_idx > start_idx:
        candidates.append(text[start_idx:end_idx + 1])

    candidates.append(text.strip())

    last_error = None
    for content in candidates:
        try:
            data = json.loads(content)
        except (ValueError, RecursionError) as e:
            # json.JSONDecodeError is a ValueError subclass.
            last_error = e
            continue
        if isinstance(data, list):
            return data

    if last_error is not None:
        print(f"JSON 解析失败: {last_error}")
    return []


def is_image_pair_valid(original_image: Image.Image, masked_image: Image.Image) -> (bool, str, tuple):
    """Decide whether an (original, mask) image pair should be processed.

    The original must have more than MIN_RESOLUTION pixels. The object region
    is the bounding box of the non-zero pixels of the inverted RGB mask; pairs
    whose box covers 30% or more of the mask are rejected, and pairs at or
    above MAX_MASK_RATIO are rejected with probability 0.5.

    Returns:
        A ``(is_valid, status_message, bbox)`` tuple; ``bbox`` is None when no
        bounding box could be computed.
    """
    width, height = original_image.size
    pixel_count = width * height
    if pixel_count <= MIN_RESOLUTION:
        return False, f"skipped_filter_resolution (is {pixel_count}, need > {MIN_RESOLUTION})", None

    try:
        mask_width, mask_height = masked_image.size
        inverted = ImageChops.invert(masked_image.convert("RGB"))
        bbox = inverted.getbbox()
        if bbox is None:
            return False, "skipped_filter_mask_ratio (no object)", None

        box_width = bbox[2] - bbox[0]
        box_height = bbox[3] - bbox[1]
        area_ratio = (box_width * box_height) / (mask_width * mask_height)

        if area_ratio >= 0.3:
            return False, f"skipped_filter_mask_ratio (is {area_ratio:.2%})", bbox

        # The random draw only happens once the ratio threshold is reached.
        if area_ratio >= MAX_MASK_RATIO and random.random() > 0.5:
            return False, f"skipped_filter_mask_ratio (is {area_ratio:.2%})", bbox

        return True, "valid", bbox
    except Exception as e:
        return False, f"skipped_filter_mask_ratio (error: {e})", None


def encode_pil_image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
    """Serialise an image into a ``data:image/<format>;base64,...`` URL string.

    Args:
        pil_image: Image to encode; saved in memory using ``image_format``.
        image_format: Format name passed to ``save``; its lower-cased form is
            used as the MIME subtype of the data URL.
    """
    with io.BytesIO() as stream:
        pil_image.save(stream, format=image_format)
        raw_bytes = stream.getvalue()

    mime_subtype = image_format.lower()
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    return f"data:image/{mime_subtype};base64,{encoded}"


def get_vqa_generation_prompt(seed_questions):
    """Build the VQA generation prompt with up to 15 randomly sampled seed questions.

    Args:
        seed_questions: Sequence of example questions to sample from.

    Returns:
        The full prompt text, with the sampled examples rendered as a
        ``- question`` bullet list.
    """
    how_many = min(15, len(seed_questions))
    picked = random.sample(seed_questions, how_many)
    bullet_list = "\n".join(f"- {example}" for example in picked)

    return f"""You are an expert specialist in generating Visual Question Answering (VQA) datasets. Your task is to generate three high-quality, valid and challenging questions based on a provided **'Original Image'** and its corresponding **'Cropped Image'**.

**Reference Examples (use these for inspiration on questioning angles):**
{bullet_list}

**Core Generation Rules:**
1. **Consistency of Answers:** The answer to each question must be identical, accurate, and concise for both the 'Original Image' and the 'Cropped Image'. It can be a **short, factual, and concrete string** (e.g., a number, a noun, or text).
2. **Content Relevance:** Questions must focus exclusively on the objects or content visible within the 'Cropped Image'. Question types include, but are not limited to:
    * **Object Identification:** Identify the exact sub-component or item. (e.g., 'What is the person holding in their hand?' Answer: 'Apple').
    * **OCR:** Recognizing text within the image.
3. **Strict No-Context Rule:** Do not ask questions that require background information or context found only in the 'Original Image' to answer.
4. **Diversity of Questions:** Provided that the consistency rule is met, aim for a diverse range of questions. This includes counting, spatial relationships, scene recognition, anomaly detection, shape, material, structure, color, etc.

**Please carefully observe the images, and generate the high-quality, valid and challenging questions in the following JSON format:**

```json
[
  {{"question": "Question 1"}},
  {{"question": "Question 2"}},
  {{"question": "Question 3"}}
]
```"""


# ==============================================================================
# Gemini API 调用（优化版）
# ==============================================================================

def _read_gemini_content(response, stream):
    """Extract the assistant text from a Gemini HTTP response (streamed or not)."""
    if stream:
        parts = []
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode('utf-8')
            if not decoded_line.startswith("data:"):
                continue
            json_str = decoded_line[5:].strip()
            if json_str == "[DONE]":
                break
            try:
                chunk = json.loads(json_str)
            except json.JSONDecodeError:
                continue
            # Tolerate chunks with an empty "choices" list or a null
            # "delta"/"content" instead of failing the whole request.
            choices = chunk.get("choices") or [{}]
            delta = choices[0].get("delta") or {}
            parts.append(delta.get("content") or "")
        return "".join(parts)

    response_text = response.text
    if response_text.startswith("data:"):
        json_str = response_text[5:].strip()
    else:
        json_str = response_text
    response_data = json.loads(json_str)
    choices = response_data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content") or ""


def call_gemini_api(messages, stream=True, max_retries=3000, base_retry_delay=1):
    """Call the Gemini chat endpoint with automatic retries and pooled tokens.

    Each attempt acquires a token from ``token_pool``, posts the request and
    releases the token exactly once, flagged as an error unless a non-empty
    reply was read. Rate limiting (HTTP 429), other non-200 statuses,
    exceptions and empty replies are all retried after ``base_retry_delay``
    seconds.

    Args:
        messages: Chat messages to send.
        stream: Whether to request and parse a streamed response.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep between attempts.

    Returns:
        The reply text, or None if every attempt failed or the reply
        contained only whitespace.
    """
    url = GEMINI_API_BASE_URL
    data = {
        "stream": stream,
        "model": GEMINI_MODEL_NAME,
        "messages": messages
    }

    for retry_count in range(max_retries):
        try:
            current_token = token_pool.acquire_token(timeout=30)
        except TimeoutError as e:
            print(f"⚠ 获取Token超时: {e}")
            time.sleep(base_retry_delay)
            continue

        full_content = None
        has_error = True
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": current_token
            }
            response = requests.post(
                url,
                data=json.dumps(data),
                headers=headers,
                timeout=120,
                stream=stream
            )
            try:
                if response.status_code == 429:
                    print(f"⚠ 限流，等待 {base_retry_delay}s 后重试 ({retry_count + 1}/{max_retries})")
                elif response.status_code != 200:
                    print(f"API错误 {response.status_code}")
                else:
                    full_content = _read_gemini_content(response, stream)
                    # Explicit check instead of ``assert`` so it survives -O.
                    if not full_content:
                        raise ValueError("响应内容为空")
                    has_error = False
            finally:
                # Always hand the connection back, including for streams that
                # stopped reading at "[DONE]".
                response.close()
        except Exception as e:
            print(f"⚠ API调用异常 ({retry_count + 1}/{max_retries}): {e}")
        finally:
            # Single release point: the token can never be released twice.
            token_pool.release_token(current_token, has_error=has_error)

        if not has_error:
            return full_content if full_content.strip() else None

        time.sleep(base_retry_delay)

    return None

def call_qwen_api(payload: dict, api_url: str, stream: bool = False,
                  max_retries: int = 3000, base_retry_delay: float = 1.0):
    """Call the Qwen endpoint with pooled API keys and automatic retries.

    Each attempt acquires a key from ``qwen_token_pool``, posts ``payload``
    as the JSON body and releases the key exactly once, flagged as an error
    unless a 200 response was decoded. Failures are logged and retried after
    ``base_retry_delay`` seconds.

    Args:
        payload: Request body, sent via ``requests.post(json=payload)``.
        api_url: Endpoint URL.
        stream: Forwarded to ``requests.post``; the body is still decoded
            with ``resp.json()``.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep between attempts.

    Returns:
        The decoded JSON response, or None if every attempt failed.
    """
    for retry_count in range(max_retries):
        try:
            current_key = qwen_token_pool.acquire_token(timeout=30)
        except TimeoutError as e:
            print(f"⚠ 获取Qwen key超时: {e}")
            time.sleep(base_retry_delay)
            continue

        data = None
        succeeded = False
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {current_key}",
            }
            resp = requests.post(
                api_url,
                headers=headers,
                json=payload,
                timeout=120,
                stream=stream
            )
            try:
                if resp.status_code == 200:
                    data = resp.json()
                    succeeded = True
                else:
                    # 429/503 (rate limit / overload) and every other non-200
                    # status are handled the same way: penalise the key, retry.
                    print(f"⚠ Qwen API错误 {resp.status_code} ({retry_count + 1}/{max_retries})")
            finally:
                resp.close()
        except Exception as e:
            print(f"⚠ Qwen API调用异常 ({retry_count + 1}/{max_retries}): {e}")
        finally:
            # Single release point for the key.
            qwen_token_pool.release_token(current_key, has_error=not succeeded)

        if succeeded:
            return data

        time.sleep(base_retry_delay)

    return None



def generate_vqa_with_gemini(original_image, masked_image, seed_questions):
    """[Generator] Ask Gemini to propose VQA questions for an image/crop pair.

    Args:
        original_image: Full image, encoded to a base64 data URL before sending.
        masked_image: Cropped image, encoded the same way.
        seed_questions: Example questions sampled into the prompt.

    Returns:
        The non-empty list parsed from the reply, or None when the call
        failed or no list could be extracted.
    """
    prompt_text = get_vqa_generation_prompt(seed_questions)
    original_url = encode_pil_image_to_base64(original_image)
    cropped_url = encode_pil_image_to_base64(masked_image)

    user_content = [
        {"type": "text", "text": "This is the Original Image:"},
        {"type": "image_url", "image_url": {"url": original_url}},
        {"type": "text", "text": "This is the Cropped Image:"},
        {"type": "image_url", "image_url": {"url": cropped_url}},
        {"type": "text", "text": prompt_text},
    ]
    conversation = [{"role": "user", "content": user_content}]

    reply = call_gemini_api(conversation, stream=True, max_retries=3000)
    if not reply:
        return None

    parsed_pairs = extract_json_from_text(reply)
    return parsed_pairs or None


def validate_single_vqa_with_gemini(original_image, masked_image, question):
    """Validate one VQA question with Gemini (designed to be run concurrently).

    The model is asked to reason step by step and then emit ``VALID: Yes/No``
    and ``REASON: ...``. Because the reasoning comes first, the LAST verdict
    in the reply is taken as final, and an echoed ``VALID: Yes/No`` template
    line or an ``INVALID: ...`` phrase is not mistaken for a verdict.

    Args:
        original_image: Full image, encoded to a base64 data URL before sending.
        masked_image: Cropped image, encoded the same way.
        question: Question text to validate.

    Returns:
        ``{"is_valid": bool, "reason": str}``; ``is_valid`` is False when the
        API call failed or no verdict could be parsed.
    """
    base64_original_image = encode_pil_image_to_base64(original_image)
    base64_masked_image = encode_pil_image_to_base64(masked_image)
    
    validation_prompt = f"""
You are an expert at validating whether a question is appropriate for Visual Question Answering (VQA) under a crop-consistency setting.

You will be shown two images:
- Image 1: The original image (full context)
- Image 2: A cropped region taken from Image 1

Question: {question}

Your task: Determine whether the question is VALID according to ALL criteria below. If ANY criterion fails, the question is NOT valid.

CRITERIA (ALL must be satisfied):
1) Crop-answerable:
   - The question MUST be answerable using both two images.
   - If answering requires information outside the crop, mark INVALID.

2) Unique and unambiguous (in the original image):
   - In Image 1, there must be exactly ONE clearly correct answer.
   - If multiple instances/objects in the original image could produce different correct answers, mark INVALID.

3) Consistent with the original image:
   - The answer derived from Image 2 MUST match the answer that would be obtained from Image 1.
   - If Image 1 allows additional valid answers or changes the interpretation (e.g., there are multiple relevant objects in the full image), mark INVALID even if Image 2 looks unambiguous.

4) Clear question:
   - The question must specify the target unambiguously (which object/person/instance).
   - Avoid unclear references like "it", "this", "the object" when multiple candidates exist in Image 1 or Image 2.
   - The question should not rely on unspecified perspective ("left/right" without a clear frame is okay if it's the image frame), or vague quantifiers ("a lot", "some").

OUTPUT FORMAT (strict):
VALID: Yes/No
REASON: A brief explanation referencing which criterion/criteria passed or failed.

EXAMPLE (illustrating a common INVALID case):
- Question: "What is the number on the side of the boat hull?"
- Image 1 (original): There are TWO boats, each with a different hull number (e.g., 12 and 18).
- Image 2 (crop): The crop shows only ONE boat with number 12.
Decision: INVALID
Reason: Although Image 2 yields a single answer (12), the original image allows multiple valid answers (12 or 18), so criterion #3 (consistency with original) fails and the question is ambiguous in the full context.

Now please reason step by step, and then evaluate the given Question using the rules above.
"""
    
    messages = [
        {
            "role": "user", 
            "content": [
                {"type": "text", "text": "Original Image:"},
                {"type": "image_url", "image_url": {"url": base64_original_image}},
                {"type": "text", "text": "Cropped Image:"},
                {"type": "image_url", "image_url": {"url": base64_masked_image}},
                {"type": "text", "text": validation_prompt}
            ]
        }
    ]
    
    full_content = call_gemini_api(messages, stream=False, max_retries=3000)
    
    if not full_content:
        return {"is_valid": False, "reason": "API调用失败"}
    
    # Parse the verdict; default to "invalid" when nothing can be parsed.
    is_valid = False
    reason = "No reason provided."

    # \b keeps "INVALID: ..." from matching; the negative lookahead skips an
    # echoed "VALID: Yes/No" template line.
    verdict_matches = list(
        re.finditer(r'\bVALID:\s*(Yes|No)\b(?!\s*/)', full_content, re.IGNORECASE)
    )
    reason_search_start = 0
    if verdict_matches:
        final_verdict = verdict_matches[-1]
        is_valid = final_verdict.group(1).upper() == "YES"
        reason_search_start = final_verdict.end()

    # Prefer the reason that follows the final verdict; otherwise fall back
    # to the first reason anywhere in the reply.
    reason_pattern = re.compile(r'REASON:\s*(.+?)(?:\n\n|\Z)', re.IGNORECASE | re.DOTALL)
    reason_match = (
        reason_pattern.search(full_content, reason_search_start)
        or reason_pattern.search(full_content)
    )
    if reason_match:
        reason = reason_match.group(1).strip()
    
    return {"is_valid": is_valid, "reason": reason}


def validate_vqa_pairs_with_gemini(original_image, masked_image, vqa_pairs, max_validation_workers=6):
    """[Validator] Validate several VQA questions concurrently with Gemini.

    Args:
        original_image: Original image.
        masked_image: Cropped image.
        vqa_pairs: List of question items (dicts with a ``question`` key).
        max_validation_workers: Maximum number of concurrent validations.

    Returns:
        A list aligned with ``vqa_pairs``; every slot holds a
        ``{"is_valid": bool, "reason": str}`` dict. Items without a usable
        question are reported as invalid with reason ``"Empty question"``.
    """
    if not vqa_pairs:
        return []

    # Results are written by index so the output order matches the input order.
    results = [None] * len(vqa_pairs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_validation_workers) as executor:
        future_to_idx = {}

        for idx, item in enumerate(vqa_pairs):
            question = item.get('question', '') if isinstance(item, dict) else ''
            if not question:
                # Store the failure in the item's own slot so the caller never
                # sees a None entry.
                results[idx] = {"is_valid": False, "reason": "Empty question"}
                continue

            future = executor.submit(
                validate_single_vqa_with_gemini,
                original_image,
                masked_image,
                question
            )
            future_to_idx[future] = idx

        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                results[idx] = future.result()
            except Exception as e:
                print(f"验证问题 {idx} 时出错: {e}")
                results[idx] = {"is_valid": False, "reason": f"验证异常: {str(e)}"}

    return results


# ==============================================================================
# Qwen API 调用（保持原有逻辑）
# ==============================================================================

def generate_vqa_with_qwen(original_b64_image, crop_b64_image, seed_questions, api_url):
    """Generate VQA questions through the Qwen3-VL API (multi-key concurrent version).

    Args:
        original_b64_image: Base64 payload of the original image, embedded in
            a ``data:image/jpeg;base64,`` URL.
        crop_b64_image: Base64 payload of the cropped image, embedded the same way.
        seed_questions: Example questions sampled into the prompt.
        api_url: Endpoint URL forwarded to ``call_qwen_api``.

    Returns:
        The list extracted from the reply (possibly empty), or None when the
        call failed or the response could not be read.
    """
    instruction = get_vqa_generation_prompt(seed_questions)

    user_content = [
        {"type": "text", "text": "This is the Original Image:"},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{original_b64_image}"}},
        {"type": "text", "text": "This is the Cropped Image:"},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64_image}"}},
        {"type": "text", "text": instruction},
    ]

    # NOTE(review): GENERATOR_MODEL_NAME_QWEN is not defined in the part of the
    # file reviewed here (only QWEN_MODEL_NAME is) - confirm it is set elsewhere.
    request_body = {
        "stream": False,
        "model": GENERATOR_MODEL_NAME_QWEN,
        "messages": [{"role": "user", "content": user_content}],
        "temperature": 0.9,
        "top_p": 0.8,
    }

    response = call_qwen_api(request_body, api_url=api_url, stream=False, max_retries=3000)
    if not response:
        return None

    try:
        reply_text = response["choices"][0]["message"]["content"]
        return extract_json_from_text(reply_text)
    except Exception:
        return None



def validate_single_vqa_with_qwen(original_b64_image, crop_b64_image, question, api_url):
    """Validate one VQA question through the Qwen API.

    The model is asked to reason step by step and then emit ``VALID: Yes/No``
    and ``REASON: ...``. Because the reasoning comes first, the LAST verdict
    in the reply is taken as final, and an echoed ``VALID: Yes/No`` template
    line or an ``INVALID: ...`` phrase is not mistaken for a verdict.

    Args:
        original_b64_image: Base64 payload of the original image, embedded in
            a ``data:image/jpeg;base64,`` URL.
        crop_b64_image: Base64 payload of the cropped image, embedded the same way.
        question: Question text to validate.
        api_url: Endpoint URL forwarded to ``call_qwen_api``.

    Returns:
        ``{"is_valid": bool, "reason": str}``; ``is_valid`` is False when the
        call failed, the response could not be read, or no verdict was parsed.
    """
    validation_prompt = f"""
You are an expert at validating whether a question is appropriate for Visual Question Answering (VQA) under a crop-consistency setting.

You will be shown two images:
- Image 1: The original image (full context)
- Image 2: A cropped region taken from Image 1

Question: {question}

Your task: Determine whether the question is VALID according to ALL criteria below. If ANY criterion fails, the question is NOT valid.

CRITERIA (ALL must be satisfied):
1) Crop-answerable:
   - The question MUST be answerable using both two images.
   - If answering requires information outside the crop, mark INVALID.

2) Unique and unambiguous (in the original image):
   - In Image 1, there must be exactly ONE clearly correct answer.
   - If multiple instances/objects in the original image could produce different correct answers, mark INVALID.

3) Consistent with the original image:
   - The answer derived from Image 2 MUST match the answer that would be obtained from Image 1.
   - If Image 1 allows additional valid answers or changes the interpretation (e.g., there are multiple relevant objects in the full image), mark INVALID even if Image 2 looks unambiguous.

4) Clear question:
   - The question must specify the target unambiguously (which object/person/instance).
   - Avoid unclear references like "it", "this", "the object" when multiple candidates exist in Image 1 or Image 2.
   - The question should not rely on unspecified perspective ("left/right" without a clear frame is okay if it's the image frame), or vague quantifiers ("a lot", "some").

OUTPUT FORMAT (strict):
VALID: Yes/No
REASON: A brief explanation referencing which criterion/criteria passed or failed.

EXAMPLE (illustrating a common INVALID case):
- Question: "What is the number on the side of the boat hull?"
- Image 1 (original): There are TWO boats, each with a different hull number (e.g., 12 and 18).
- Image 2 (crop): The crop shows only ONE boat with number 12.
Decision: INVALID
Reason: Although Image 2 yields a single answer (12), the original image allows multiple valid answers (12 or 18), so criterion #3 (consistency with original) fails and the question is ambiguous in the full context.

Now please reason step by step, and then evaluate the given Question using the rules above.
"""

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "This is the Original Image:"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{original_b64_image}"}},
                {"type": "text", "text": "This is the Cropped Image:"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64_image}"}},
                {"type": "text", "text": validation_prompt}
            ]
        }
    ]

    # NOTE(review): GENERATOR_MODEL_NAME_QWEN is not defined in the part of the
    # file reviewed here, and a generator-named model is used for validation -
    # confirm both are intended.
    payload = {
        "stream": False,
        "model": GENERATOR_MODEL_NAME_QWEN,
        "messages": messages,
        "temperature": 0.3,
        "top_p": 0.9
    }

    resp_json = call_qwen_api(payload, api_url=api_url, stream=False, max_retries=3000)
    if not resp_json:
        return {"is_valid": False, "reason": "API调用失败"}

    try:
        full_content = resp_json["choices"][0]["message"]["content"]
    except Exception:
        return {"is_valid": False, "reason": "响应解析失败"}

    # Parse the verdict; default to "invalid" when nothing can be parsed.
    is_valid = False
    reason = "No reason provided."

    # \b keeps "INVALID: ..." from matching; the negative lookahead skips an
    # echoed "VALID: Yes/No" template line.
    verdict_matches = list(
        re.finditer(r'\bVALID:\s*(Yes|No)\b(?!\s*/)', full_content, re.IGNORECASE)
    )
    reason_search_start = 0
    if verdict_matches:
        final_verdict = verdict_matches[-1]
        is_valid = final_verdict.group(1).upper() == "YES"
        reason_search_start = final_verdict.end()

    # Prefer the reason that follows the final verdict; otherwise fall back
    # to the first reason anywhere in the reply.
    reason_pattern = re.compile(r'REASON:\s*(.+?)(?:\n\n|\Z)', re.IGNORECASE | re.DOTALL)
    reason_match = (
        reason_pattern.search(full_content, reason_search_start)
        or reason_pattern.search(full_content)
    )
    if reason_match:
        reason = reason_match.group(1).strip()

    return {"is_valid": is_valid, "reason": reason}

def validate_vqa_pairs_with_qwen(original_b64_image, crop_b64_image, vqa_pairs, api_url, max_validation_workers=6) -> list:
    """Validate several VQA questions concurrently through the Qwen API.

    Args:
        original_b64_image: Base64 payload of the original image.
        crop_b64_image: Base64 payload of the cropped image.
        vqa_pairs: List of question dicts (each read via its ``question`` key).
        api_url: Endpoint URL forwarded to each single-question validation.
        max_validation_workers: Maximum number of concurrent validations.

    Returns:
        A list aligned with ``vqa_pairs`` holding one
        ``{"is_valid": bool, "reason": str}`` dict per item.
    """
    if not vqa_pairs:
        return []

    outcomes = [None] * len(vqa_pairs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_validation_workers) as pool:
        pending = {}

        for position, entry in enumerate(vqa_pairs):
            text = entry.get("question", "")
            if text:
                job = pool.submit(
                    validate_single_vqa_with_qwen,
                    original_b64_image,
                    crop_b64_image,
                    text,
                    api_url
                )
                pending[job] = position
            else:
                # Nothing to send for an empty question; fail it locally.
                outcomes[position] = {"is_valid": False, "reason": "Empty question"}

        for job in concurrent.futures.as_completed(pending):
            position = pending[job]
            try:
                outcome = job.result()
            except Exception as exc:
                outcome = {"is_valid": False, "reason": f"验证异常: {str(exc)}"}
            outcomes[position] = outcome

    return outcomes


# ==============================================================================
# 主处理流程
# ==============================================================================

def process_line_task(line_data: dict, args, seed_questions: list) -> dict or None:
    """Process one input record: crop each mask region, then generate and validate VQA.

    For every mask path in ``line_data['oss_mask_image_paths']`` the pair is
    filtered with ``is_image_pair_valid``, the bounding box is cropped from
    the original and uploaded, questions are generated with the backend
    selected by ``args.generator_type`` and validated with the backend
    selected by ``args.validator_type``.

    Args:
        line_data: Input record; reads ``oss_image_path`` and
            ``oss_mask_image_paths``.
        args: Parsed options; reads ``generator_type``, ``validator_type``,
            ``generator_base_url``, ``validator_base_url`` and
            ``validation_workers``.
        seed_questions: Example questions forwarded to the generator.

    Returns:
        ``{'oss_image_path': ..., 'generated_vqa': [...]}``, or None when the
        record has no image path or the original image could not be loaded.
    """
    original_image_path = line_data.get('oss_image_path')
    if not original_image_path:
        return None

    # Load the original image (presumably a PIL image by default and a base64
    # string with image_type='base64' - verify against oss_util).
    original_pil_image = get_oss_image(original_image_path)
    if not original_pil_image:
        return None
    original_b64_image = get_oss_image(original_image_path, image_type='base64')

    # ``or []`` also covers an explicit null value in the record.
    mask_image_paths = line_data.get('oss_mask_image_paths') or []
    generated_vqa_results = []

    for mask_path in mask_image_paths:
        masked_pil_image = get_oss_image(mask_path)
        if not masked_pil_image:
            continue

        is_valid, reason, bbox = is_image_pair_valid(original_pil_image, masked_image=masked_pil_image)
        if not is_valid:
            generated_vqa_results.append({"mask_path": mask_path, "status": reason, "vqa_pairs": []})
            continue

        crop_pil_image = original_pil_image.crop(bbox)
        filename = os.path.basename(mask_path)
        crop_oss_path = image_upload_to_oss(crop_pil_image, crop_oss_prefix, filename)
        crop_b64_image = get_oss_image(crop_oss_path, image_type='base64')

        # Step 1: generate candidate questions.
        if args.generator_type == "qwen":
            candidate_vqa_pairs = generate_vqa_with_qwen(
                original_b64_image, crop_b64_image,
                seed_questions,
                args.generator_base_url
            )
        else:  # gemini
            candidate_vqa_pairs = generate_vqa_with_gemini(
                original_pil_image, crop_pil_image,
                seed_questions
            )

        # Keep only dict items: anything else cannot carry validation fields
        # and would otherwise crash the whole record.
        if isinstance(candidate_vqa_pairs, list):
            candidate_vqa_pairs = [pair for pair in candidate_vqa_pairs if isinstance(pair, dict)]
        else:
            candidate_vqa_pairs = []

        if not candidate_vqa_pairs:
            generated_vqa_results.append({
                "mask_path": mask_path,
                "crop_path": crop_oss_path,
                "status": "skipped_by_generator",
                "vqa_pairs": []
            })
            continue

        # Step 2: validate the candidates concurrently.
        if args.validator_type == "qwen":
            validation_results = validate_vqa_pairs_with_qwen(
                original_b64_image, crop_b64_image,
                candidate_vqa_pairs,
                args.validator_base_url,
                max_validation_workers=args.validation_workers
            )
        else:  # gemini
            validation_results = validate_vqa_pairs_with_gemini(
                original_pil_image, crop_pil_image,
                candidate_vqa_pairs,
                max_validation_workers=args.validation_workers
            )

        # Merge the validation outcome into each candidate.
        validated_pairs = []
        for j, pair in enumerate(candidate_vqa_pairs):
            result = validation_results[j] if j < len(validation_results) else None
            if isinstance(result, dict):
                pair['validation_status'] = 'passed' if result.get("is_valid", False) else 'failed'
                pair['validation_reason'] = result.get("reason", "No reason.")
            else:
                # Missing or None slot: treat as a failed validation.
                pair['validation_status'] = 'failed'
                pair['validation_reason'] = 'No validation result'
            validated_pairs.append(pair)

        final_status = "success" if any(p['validation_status'] == 'passed' for p in validated_pairs) else "all_failed_validation"
        generated_vqa_results.append({
            "mask_path": mask_path,
            "crop_path": crop_oss_path,
            "status": final_status,
            "vqa_pairs": validated_pairs
        })

    return {'oss_image_path': original_image_path, 'generated_vqa': generated_vqa_results}


# ==============================================================================
# 主函数
# ==============================================================================

def main(args):
    """Run the batch VQA generation/validation job with checkpoint resume.

    Records are read as JSON lines from ``args.input_file``. Every record
    whose ``oss_image_path`` already appears in the output file is skipped;
    the rest are handed to ``process_line_task`` on a thread pool, and each
    truthy result is appended to the output file as one JSON line.

    Args:
        args: Parsed command-line namespace. This function itself reads
            ``input_file``, ``output_file``, ``seed_json``,
            ``generator_type``, ``validator_type`` and ``max_workers``; the
            whole namespace is also forwarded to ``process_line_task``.

    Returns:
        None. Progress and the final statistics are printed to stdout.
    """
    # Lines that are not valid JSON, or that do not have the expected
    # dict-with-'oss_image_path' shape, are skipped on purpose (best effort).
    # Listing the exception types keeps KeyboardInterrupt / SystemExit from
    # being swallowed the way a bare ``except:`` would.
    skippable_errors = (ValueError, KeyError, TypeError, AttributeError)

    # --- 1. Checkpoint resume: collect the IDs that were already processed ---
    processed_ids = set()
    output_file_path = f"{args.output_file}_gen-{args.generator_type}_judge-{args.validator_type}.jsonl"

    if os.path.exists(output_file_path):
        with open(output_file_path, 'r', encoding='utf-8') as f_out_read:
            for line in f_out_read:
                try:
                    processed_ids.add(json.loads(line)['oss_image_path'])
                except skippable_errors:
                    continue
        print(f"✓ 从主输出文件恢复，已处理 {len(processed_ids)} 条数据。")
    else:
        print(f"✓ 主输出文件 {output_file_path} 不存在，将创建新文件。")

    # --- 2. Load the input records and the seed questions ---
    print(f"✓ 正在读取输入文件: {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f_in:
        all_lines = f_in.readlines()

    # Special case kept from the original: for this one shard only the last
    # 400000 lines are considered (the reason is not visible in this file).
    if os.path.basename(args.input_file) == "train-0000-of-0013.jsonl":
        all_lines = all_lines[-400000:]

    with open(args.seed_json, "r", encoding="utf-8") as f:
        seed_questions = json.load(f)

    # Keep only the records that have not been written to the output yet.
    tasks = []
    for line in all_lines:
        try:
            data = json.loads(line)
            if data.get('oss_image_path') not in processed_ids:
                tasks.append(data)
        except skippable_errors:
            continue

    if not tasks:
        print("✓ 没有需要处理的新任务。")
        return

    print(f"✓ 待处理任务数: {len(tasks)} (已跳过 {len(processed_ids)} 条)")

    # --- 3. Concurrent processing ---
    processed_count_this_session = 0
    success_count = 0
    error_count = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, \
         open(output_file_path, 'a', encoding='utf-8') as f_out:

        # Submit everything up front; map each future back to its task ID so
        # failures can be reported against the right record.
        future_to_task = {
            executor.submit(process_line_task, task_data, args, seed_questions): task_data.get('oss_image_path')
            for task_data in tasks
        }

        progress_bar = tqdm(
            concurrent.futures.as_completed(future_to_task),
            total=len(tasks),
            desc="Processing",
            ncols=100
        )

        for future in progress_bar:
            task_id = future_to_task.get(future, "Unknown")

            # Every future yielded here is a finished task, whatever its
            # outcome, so the session total always equals success + failure.
            processed_count_this_session += 1

            try:
                # as_completed() only yields finished futures, so result()
                # returns (or re-raises the task's exception) immediately;
                # a timeout argument would never take effect.
                result = future.result()

                if result:
                    # Results are written only from this consuming thread, so
                    # no lock is needed. Flushing per record keeps the
                    # checkpoint file usable if the run is interrupted.
                    f_out.write(json.dumps(result, ensure_ascii=False) + '\n')
                    f_out.flush()
                    success_count += 1
                else:
                    error_count += 1

            except concurrent.futures.TimeoutError:
                # Reached only when the task itself raised a timeout error.
                error_count += 1
                print(f"\n⚠ 任务 {task_id} 超时")

            except Exception as exc:
                # One failing task (or an unserializable result) must not
                # abort the whole batch; report it and carry on.
                error_count += 1
                print(f"\n⚠ 任务 {task_id} 发生异常: {exc}")

    # Final statistics
    print("\n" + "="*60)
    print("处理完成！最终统计:")
    print(f"  总处理: {processed_count_this_session}")
    print(f"  成功: {success_count}")
    print(f"  失败: {error_count}")



if __name__ == '__main__':
    # Command-line entry point: parse the options, print a configuration
    # banner, then run the batch job.
    parser = argparse.ArgumentParser(description="VQA Batch Generation & Validation (优化版)")

    # Input / output
    parser.add_argument('--input_file', type=str,
                        default="train-0000-of-0013.jsonl")
    # Prefix only: main() appends "_gen-<generator>_judge-<validator>.jsonl".
    parser.add_argument('--output_file', type=str, default="sa1b_0000_finegained_hard")
    parser.add_argument("--seed_json", type=str,
                        default="/mnt/nas/yanlong/code/new_benchmark/utils_new_pipeline/seed.json")
    parser.add_argument('--generator_type', type=str, choices=['qwen', 'gemini'], default='gemini',
                        help='生成器模型类型')
    parser.add_argument('--validator_type', type=str, choices=['qwen', 'gemini'], default='gemini',
                        help='验证器模型类型')
    # Concurrency control
    parser.add_argument('--max_workers', type=int, default=15,
                        help='主任务并发数（建议：Token数 × 每Token并发数）')
    parser.add_argument('--validation_workers', type=int, default=6,
                        help='每个任务内验证子任务的并发数')

    parser.add_argument('--generator_api_key', type=str, default=GENERATOR_API_KEY)
    parser.add_argument('--generator_base_url', type=str, default=GENERATOR_API_BASE_URL)
    # NOTE(review): the validator options default to the GENERATOR_* values;
    # confirm this sharing is intentional and not a copy/paste slip.
    parser.add_argument('--validator_api_key', type=str, default=GENERATOR_API_KEY)
    parser.add_argument('--validator_base_url', type=str, default=GENERATOR_API_BASE_URL)

    args = parser.parse_args()

    # The banner used to print a fixed "= 15" total next to a computed token
    # count, which is only right for exactly three tokens. Derive the total
    # instead. NOTE(review): 5 is the per-token figure the banner already
    # showed; confirm it matches the token pool's actual per-token limit.
    _per_token_concurrency = 5
    _total_token_concurrency = len(GEMINI_API_TOKENS) * _per_token_concurrency

    print(f"""
╔══════════════════════════════════════════════════════════╗
║           VQA 生成与验证系统 (优化版)                    ║
╟──────────────────────────────────────────────────────────╢
║  主任务并发数: {args.max_workers:2d}                                      ║
║  验证子任务并发数: {args.validation_workers:2d}                                 ║
║  生成器: {args.generator_type:6s}  |  验证器: {args.validator_type:6s}           ║
║  Token池: {len(GEMINI_API_TOKENS)} 个 Token × {_per_token_concurrency} 并发/Token = {_total_token_concurrency} 总并发  ║
╚══════════════════════════════════════════════════════════╝
""")

    main(args)

