import requests
import json
import base64
import os
import time
import argparse
from tqdm import tqdm
from PIL import Image, ImageChops
import io
from openai import OpenAI
import concurrent.futures
import random
import threading
import shutil
import queue
from collections import defaultdict
import re

# 假设 oss_util.py 在同一目录下或在Python路径中
from oss_util import get_oss_image, image_upload_to_oss

class ImprovedTokenPool:
    """Thread-safe token pool with per-token concurrency caps and load balancing.

    Each token can be held by at most ``max_concurrent_per_token`` callers at a
    time. ``acquire_token`` polls (every 0.1 s, up to a timeout) for the
    available token with the lowest score, where

        score = active requests + error penalty

    and the error penalty decays linearly from 10 to 0 over the 60 s following
    the token's most recent ``release_token(..., has_error=True)``.
    """

    def __init__(self, tokens: list, max_concurrent_per_token: int = 5):
        """
        Args:
            tokens: API tokens / keys to rotate through.
            max_concurrent_per_token: Maximum number of simultaneous holders of one token.
        """
        self.tokens = tokens
        self.max_concurrent_per_token = max_concurrent_per_token

        # Per-token concurrency cap and current number of holders. Both are only
        # modified while holding self.lock, so they always agree.
        self.token_semaphores = {token: threading.Semaphore(max_concurrent_per_token) for token in tokens}
        self.token_active_count = {token: 0 for token in tokens}

        # Statistics
        self.token_usage_count = defaultdict(int)
        self.token_error_count = defaultdict(int)
        self.token_last_error_time = defaultdict(float)  # wall-clock time of the last reported error
        self.lock = threading.Lock()

    def _score(self, token, now):
        """Return the selection score of *token* at wall-clock time *now* (lower is better)."""
        error_penalty = 0
        # Membership test on the defaultdict does not insert a key.
        if token in self.token_last_error_time:
            time_since_error = now - self.token_last_error_time[token]
            if time_since_error < 60:  # errors within the last minute are penalised
                error_penalty = 10 * (1 - time_since_error / 60)
        return self.token_active_count[token] + error_penalty

    def acquire_token(self, timeout=30):
        """Acquire the least-loaded, least-recently-failing available token.

        Args:
            timeout: Seconds to keep polling before giving up.

        Returns:
            The acquired token. The caller must hand it back with ``release_token``.

        Raises:
            TimeoutError: if the pool is empty, or no token frees up within *timeout*.
        """
        if not self.tokens:
            # Polling an empty pool can never succeed; fail fast instead of sleeping.
            raise TimeoutError("token池为空，无法获取可用token")

        deadline = time.monotonic() + timeout
        while True:
            with self.lock:
                now = time.time()
                candidates = sorted(
                    (self._score(token, now), token)
                    for token in self.tokens
                    if self.token_active_count[token] < self.max_concurrent_per_token
                )
                # Non-blocking acquire in score order; never inspects the private
                # Semaphore._value.
                for _, token in candidates:
                    if self.token_semaphores[token].acquire(blocking=False):
                        self.token_active_count[token] += 1
                        self.token_usage_count[token] += 1
                        return token

            if time.monotonic() >= deadline:
                raise TimeoutError("无法在规定时间内获取可用token")
            # Nothing available right now: back off briefly and retry.
            time.sleep(0.1)

    def release_token(self, token, has_error=False):
        """Return *token* to the pool, optionally recording an error against it.

        A release without a matching acquire (the token has no active holders)
        still records the error, but does not release the semaphore. Otherwise a
        duplicate release would permanently raise the token's concurrency above
        ``max_concurrent_per_token``.

        Raises:
            KeyError: if *token* is not part of this pool.
        """
        with self.lock:
            active = self.token_active_count[token]
            if has_error:
                self.token_error_count[token] += 1
                self.token_last_error_time[token] = time.time()
            if active <= 0:
                return
            self.token_active_count[token] = active - 1
            self.token_semaphores[token].release()

    def get_stats(self):
        """Return a snapshot of per-token usage, error and active-request counts."""
        with self.lock:
            return {
                "usage": dict(self.token_usage_count),
                "errors": dict(self.token_error_count),
                "active": dict(self.token_active_count)
            }


# ==============================================================================
# Configuration
# ==============================================================================

# Gemini API configuration
GEMINI_API_BASE_URL = 'xxx'
GEMINI_API_TOKENS = []
GEMINI_MODEL_NAME = "gemini-2.5-pro"
GEMINI_MAX_CONCURRENT_PER_TOKEN = 5

# Qwen API configuration
QWEN_API_BASE_URL = "xxxx"
QWEN_API_KEYS = []
QWEN_MODEL_NAME = "Qwen3-VL-235B-A22B-Instruct"
QWEN_MAX_CONCURRENT_PER_KEY = 10
# Model name put into the request payload by generate_vqa_with_qwen and
# validate_single_vqa_with_qwen; defaults to QWEN_MODEL_NAME.
# NOTE(review): if later code reassigns this name, that assignment wins.
GENERATOR_MODEL_NAME_QWEN = QWEN_MODEL_NAME

# Kimi LLM client
llm_client = OpenAI(
    api_key="xxx",
    base_url="",
)

# Qwen3-VL-8B client (local OpenAI-compatible server)
mllm_client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:18901/v1",
    timeout=3600
)

# Key/token pools that cap concurrent requests per credential.
qwen_token_pool = ImprovedTokenPool(QWEN_API_KEYS, max_concurrent_per_token=QWEN_MAX_CONCURRENT_PER_KEY)
token_pool = ImprovedTokenPool(GEMINI_API_TOKENS, max_concurrent_per_token=GEMINI_MAX_CONCURRENT_PER_TOKEN)

crop_oss_prefix = "Data/safety_post_train/other_crop_images/"
# is_image_pair_valid requires strictly more pixels than this.
MIN_RESOLUTION = 1536 * 1536
# is_image_pair_valid randomly drops about half of the regions whose area ratio is at least this.
MAX_MASK_RATIO = 0.10


# ==============================================================================
# 新增：解析 bbox 字符串的辅助函数
# ==============================================================================

def parse_bbox_string(bbox_str):
    """Parse a bounding box given as text into integer coordinates.

    Accepted forms (parentheses, square brackets and surrounding spaces are ignored):
    - "(100,200,300,400)"
    - "100,200,300,400"
    - "[100, 200, 300, 400]"

    Non-string inputs are converted with ``str()`` first, so a list or tuple of
    four numbers also works. Fractional values are truncated toward zero by ``int()``.

    Returns:
        tuple: (x_min, y_min, x_max, y_max), or None if the input is empty, does not
        contain exactly four comma-separated numbers, or contains NaN/infinity.
    """
    if not bbox_str:
        return None

    # Strip whitespace, then drop every bracket character in a single pass.
    cleaned = str(bbox_str).strip().translate(str.maketrans("", "", "()[]"))

    try:
        coords = tuple(int(float(part.strip())) for part in cleaned.split(','))
    except (ValueError, OverflowError):
        # A non-numeric or empty part or NaN raises ValueError; +/-inf raises OverflowError.
        return None

    return coords if len(coords) == 4 else None


# ==============================================================================
# 辅助函数
# ==============================================================================

def extract_json_from_text(text):
    """Extract a JSON list from a model response string.

    Candidate substrings are tried in order, and the first one that parses wins:
    1. the body of the first fenced code block (```...``` or ```json...```);
    2. the span from the first '[' to the last ']' in the text;
    3. the whole stripped text.

    Returns:
        The parsed list, or [] if *text* is empty, nothing parses, or the first
        parsable candidate is not a JSON array.
    """
    if not text:
        return []

    candidates = []
    fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
    if fenced:
        candidates.append(fenced.group(1).strip())
    start_idx = text.find('[')
    end_idx = text.rfind(']')
    if start_idx != -1 and end_idx > start_idx:
        candidates.append(text[start_idx:end_idx + 1])
    candidates.append(text.strip())

    last_error = None
    # dict.fromkeys de-duplicates while keeping the preference order.
    for content in dict.fromkeys(candidates):
        try:
            data = json.loads(content)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            last_error = e
            continue
        return data if isinstance(data, list) else []

    print(f"JSON 解析失败: {last_error}")
    return []


def is_image_pair_valid(original_image: Image.Image, masked_image: Image.Image = None, bbox: tuple = None,
                        *, hard_max_ratio: float = 0.3) -> "tuple[bool, str, tuple | None]":
    """Decide whether an image/region pair is usable for VQA generation.

    The original image must have strictly more than ``MIN_RESOLUTION`` pixels.
    The region comes from exactly one of two sources:

    - *bbox*: the coordinates must be ordered and inside the original image. The
      area ratio is computed against the original image.
    - *masked_image*: the region is the bounding box of every pixel that is not
      pure white. The area ratio is computed against the mask's own size.

    Regions with ratio >= *hard_max_ratio* are always rejected. Regions with
    ratio >= ``MAX_MASK_RATIO`` are rejected at random with probability 0.5.

    Args:
        original_image: Full image (only ``.size`` is read).
        masked_image: Mask image; used only when *bbox* is None.
        bbox: (x_min, y_min, x_max, y_max) region in original-image pixels.
        hard_max_ratio: Area ratio at or above which a region is always rejected.

    Returns:
        (is_valid, reason, bbox). *reason* is "valid" or a "skipped_filter_*" string.
        *bbox* is None when no usable region could be determined.
    """
    w, h = original_image.size
    if w * h <= MIN_RESOLUTION:
        return False, f"skipped_filter_resolution (is {w*h}, need > {MIN_RESOLUTION})", None

    try:
        if bbox is not None:
            # Explicit bounding box: validate its shape, ordering and bounds.
            if len(bbox) != 4:
                return False, "skipped_filter_mask_ratio (invalid bbox format)", None

            x_min, y_min, x_max, y_max = bbox

            if x_min >= x_max or y_min >= y_max:
                return False, "skipped_filter_mask_ratio (invalid bbox coordinates)", None

            if x_min < 0 or y_min < 0 or x_max > w or y_max > h:
                return False, "skipped_filter_mask_ratio (bbox out of image bounds)", None

            object_area = (x_max - x_min) * (y_max - y_min)
            area_ratio = object_area / (w * h)

        elif masked_image is not None:
            # Invert so that non-white mask pixels become non-zero; getbbox() then bounds them.
            # NOTE(review): the ratio uses the mask's size, but callers crop the original
            # image with this bbox. Confirm that masks always match the original's size.
            mask_w, mask_h = masked_image.size
            rgb_masked_image = masked_image.convert("RGB")
            inverted_image = ImageChops.invert(rgb_masked_image)
            bbox = inverted_image.getbbox()

            if bbox is None:
                return False, "skipped_filter_mask_ratio (no object)", None

            object_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            area_ratio = object_area / (mask_w * mask_h)
        else:
            return False, "skipped_filter_mask_ratio (no bbox or masked_image provided)", None

        # Region is too large: always reject.
        if area_ratio >= hard_max_ratio:
            return False, f"skipped_filter_mask_ratio (is {area_ratio:.2%})", bbox

        # Moderately large: keep only about half of these to bias toward small regions.
        if area_ratio >= MAX_MASK_RATIO and random.random() > 0.5:
            return False, f"skipped_filter_mask_ratio (is {area_ratio:.2%})", bbox

        return True, "valid", bbox

    except Exception as e:
        # Best-effort filter: any failure marks the pair as skipped instead of crashing the task.
        return False, f"skipped_filter_mask_ratio (error: {e})", None


def encode_pil_image_to_base64(pil_image: Image.Image, image_format: str = "PNG") -> str:
    """Serialize *pil_image* in *image_format* and wrap it as a ``data:`` URL.

    The MIME subtype is the lower-cased format name, e.g. ``data:image/png;base64,...``.
    """
    with io.BytesIO() as buffer:
        pil_image.save(buffer, format=image_format)
        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    mime_subtype = image_format.lower()
    return f"data:image/{mime_subtype};base64,{encoded}"


def get_vqa_generation_prompt(seed_questions):
    """Build the prompt that asks a model for three crop-consistent VQA questions.

    Up to 15 of *seed_questions* are drawn with ``random.sample`` and listed as
    ``- question`` bullet points in the prompt's reference-examples section.
    """
    chosen = random.sample(seed_questions, min(15, len(seed_questions)))
    examples_str = "\n".join(f"- {question}" for question in chosen)

    prompt = f"""You are an expert specialist in generating Visual Question Answering (VQA) datasets. Your task is to generate three high-quality, valid and challenging questions based on a provided **'Original Image'** and its corresponding **'Cropped Image'**.

**Reference Examples (use these for inspiration on questioning angles):**
{examples_str}

**Core Generation Rules:**
1. **Consistency of Answers:** The answer to each question must be identical, accurate, and concise for both the 'Original Image' and the 'Cropped Image'. It can be a **short, factual, and concrete string** (e.g., a number, a noun, or text).
2. **Content Relevance:** Questions must focus exclusively on the objects or content visible within the 'Cropped Image'. Question types include, but are not limited to:
    * **Object Identification:** Identify the exact sub-component or item. (e.g., 'What is the person holding in their hand?' Answer: 'Apple').
    * **OCR:** Recognizing text within the image.
3. **Strict No-Context Rule:** Do not ask questions that require background information or context found only in the 'Original Image' to answer.
4. **Diversity of Questions:** Provided that the consistency rule is met, aim for a diverse range of questions. This includes counting, spatial relationships, scene recognition, anomaly detection, material, structure, etc. Do ask questions related to colors.

**Please carefully observe the images, and generate the high-quality, valid and challenging questions in the following JSON format:**

```json
[
  {{"question": "Question 1"}},
  {{"question": "Question 2"}},
  {{"question": "Question 3"}}
]
```"""
    return prompt


# ==============================================================================
# Gemini API 调用（优化版）
# ==============================================================================

def _collect_gemini_stream_content(response):
    """Concatenate the ``choices[0].delta.content`` pieces of an SSE chat-completion stream."""
    parts = []
    for line in response.iter_lines():
        if not line:
            continue
        decoded_line = line.decode('utf-8')
        if not decoded_line.startswith("data:"):
            continue
        json_str = decoded_line[5:].strip()
        if json_str == "[DONE]":
            break
        try:
            chunk = json.loads(json_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(chunk, dict):
            continue
        choices = chunk.get("choices") or [{}]
        # Role-only or final chunks may omit content or send an explicit null.
        parts.append((choices[0].get("delta") or {}).get("content") or "")
    return "".join(parts)


def _extract_gemini_message_content(response_text):
    """Return ``choices[0].message.content`` from a non-streaming reply body ("" if absent)."""
    json_str = response_text[5:].strip() if response_text.startswith("data:") else response_text
    response_data = json.loads(json_str)
    choices = response_data.get("choices") or [{}]
    return (choices[0].get("message") or {}).get("content") or ""


def call_gemini_api(messages, stream=True, max_retries=100, base_retry_delay=1):
    """
    Call the Gemini chat endpoint, using the shared token pool and automatic retries.

    Each attempt acquires a token from ``token_pool`` and releases it exactly
    once. HTTP 429, other non-200 statuses, blank replies and exceptions are
    reported to the pool as errors and retried after ``base_retry_delay`` seconds.

    Args:
        messages: Chat messages (OpenAI-style).
        stream: Whether to request and parse an SSE stream.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep after each failed attempt.

    Returns:
        The reply text (never blank), or None if every attempt failed.
    """
    url = GEMINI_API_BASE_URL
    # Serialize once: the body embeds base64 images and is identical on every retry.
    body = json.dumps({
        "stream": stream,
        "model": GEMINI_MODEL_NAME,
        "messages": messages
    })

    for retry_count in range(max_retries):
        current_token = None
        try:
            current_token = token_pool.acquire_token(timeout=30)
            headers = {
                "Content-Type": "application/json",
                "Authorization": current_token
            }

            full_content = ""
            # `with` closes the response on every path. With stream=True an unread
            # error body would otherwise keep the pooled connection checked out.
            with requests.post(url, data=body, headers=headers, timeout=120, stream=stream) as response:
                status_code = response.status_code
                if status_code == 200:
                    if stream:
                        full_content = _collect_gemini_stream_content(response)
                    else:
                        full_content = _extract_gemini_message_content(response.text)

            # Rate limited
            if status_code == 429:
                token_pool.release_token(current_token, has_error=True)
                current_token = None
                retry_delay = base_retry_delay
                print(f"⚠ 限流，等待 {retry_delay}s 后重试 ({retry_count + 1}/{max_retries})")
                time.sleep(retry_delay)
                continue

            if status_code != 200:
                token_pool.release_token(current_token, has_error=True)
                current_token = None
                print(f"API错误 {status_code}")
                time.sleep(base_retry_delay)
                continue

            # A blank reply counts as a failed attempt. Release exactly once, then retry.
            if not full_content.strip():
                token_pool.release_token(current_token, has_error=True)
                current_token = None
                print(f"⚠ API调用异常 ({retry_count + 1}/{max_retries}): 响应内容为空")
                time.sleep(base_retry_delay)
                continue

            token_pool.release_token(current_token, has_error=False)
            current_token = None  # prevents a second release if anything below ever raises
            return full_content

        except TimeoutError as e:
            print(f"⚠ 获取Token超时: {e}")
            if current_token:
                token_pool.release_token(current_token, has_error=True)
            time.sleep(base_retry_delay)

        except Exception as e:
            if current_token:
                token_pool.release_token(current_token, has_error=True)
            print(f"⚠ API调用异常 ({retry_count + 1}/{max_retries}): {e}")
            time.sleep(base_retry_delay)

    return None

def call_qwen_api(payload: dict, api_url: str, stream: bool = False,
                  max_retries: int = 100, base_retry_delay: float = 1.0):
    """
    Call the Qwen chat endpoint, rotating keys from ``qwen_token_pool`` and retrying.

    Each attempt acquires a key and releases it exactly once. Non-200 statuses
    (including 429/503 rate limiting or overload) and exceptions are reported to
    the pool as errors, logged, and retried after ``base_retry_delay`` seconds.

    Args:
        payload: JSON body sent via ``requests.post(json=payload)``.
        api_url: Endpoint URL.
        stream: Passed to ``requests.post``. The body is still read whole via ``resp.json()``.
        max_retries: Maximum number of attempts.
        base_retry_delay: Seconds to sleep after each failed attempt.

    Returns:
        The decoded JSON of the first HTTP 200 reply, or None if every attempt failed.
    """
    for retry_count in range(max_retries):
        current_key = None
        try:
            current_key = qwen_token_pool.acquire_token(timeout=30)

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {current_key}",
            }

            data = None
            # `with` closes the response on every path. With stream=True an unread
            # error body would otherwise keep the pooled connection checked out.
            with requests.post(
                api_url,
                headers=headers,
                json=payload,
                timeout=120,
                stream=stream
            ) as resp:
                status_code = resp.status_code
                if status_code == 200:
                    data = resp.json()

            if status_code != 200:
                # Mark the key as recently failing so the pool prefers others for a while.
                qwen_token_pool.release_token(current_key, has_error=True)
                current_key = None
                print(f"⚠ Qwen API错误 {status_code} ({retry_count + 1}/{max_retries})")
                time.sleep(base_retry_delay)
                continue

            qwen_token_pool.release_token(current_key, has_error=False)
            current_key = None
            return data

        except TimeoutError as e:
            if current_key:
                qwen_token_pool.release_token(current_key, has_error=True)
            print(f"⚠ 获取Qwen Key超时: {e}")
            time.sleep(base_retry_delay)

        except Exception as e:
            if current_key:
                qwen_token_pool.release_token(current_key, has_error=True)
            print(f"⚠ Qwen API调用异常 ({retry_count + 1}/{max_retries}): {e}")
            time.sleep(base_retry_delay)

    return None



def generate_vqa_with_gemini(original_image, masked_image, seed_questions):
    """[Generator] Ask Gemini for candidate VQA questions about a cropped region.

    Args:
        original_image: Full PIL image.
        masked_image: PIL crop taken from *original_image*.
        seed_questions: Example questions; a random subset is embedded in the prompt.

    Returns:
        The list parsed from the model's JSON reply, or None when the API call
        fails or no non-empty JSON list can be extracted.
    """
    prompt_text = get_vqa_generation_prompt(seed_questions)
    original_url = encode_pil_image_to_base64(original_image)
    crop_url = encode_pil_image_to_base64(masked_image)

    content_parts = [
        {"type": "text", "text": "This is the Original Image:"},
        {"type": "image_url", "image_url": {"url": original_url}},
        {"type": "text", "text": "This is the Cropped Image:"},
        {"type": "image_url", "image_url": {"url": crop_url}},
        {"type": "text", "text": prompt_text},
    ]
    reply = call_gemini_api([{"role": "user", "content": content_parts}], stream=True, max_retries=100)

    if not reply:
        return None
    return extract_json_from_text(reply) or None


def validate_single_vqa_with_gemini(original_image, masked_image, question):
    """Ask Gemini whether one question is valid under the crop-consistency rules.

    Used as the unit of work for the concurrent validator.

    Returns:
        {"is_valid": bool, "reason": str}. *is_valid* is False when the API call
        fails or no "VALID: Yes" verdict can be found in the reply.
    """
    base64_original_image = encode_pil_image_to_base64(original_image)
    base64_masked_image = encode_pil_image_to_base64(masked_image)

    validation_prompt = f"""
You are an expert at validating whether a question is appropriate for Visual Question Answering (VQA) under a crop-consistency setting.

You will be shown two images:
- Image 1: The original image (full context)
- Image 2: A cropped region taken from Image 1

Question: {question}

Your task: Determine whether the question is VALID according to ALL criteria below. If ANY criterion fails, the question is NOT valid.

CRITERIA (ALL must be satisfied):
1) Crop-answerable:
   - The question MUST be answerable using both two images.
   - If answering requires information outside the crop, mark INVALID.

2) Unique and unambiguous (in the original image):
   - In Image 1, there must be exactly ONE clearly correct answer.
   - If multiple instances/objects in the original image could produce different correct answers, mark INVALID.

3) Consistent with the original image:
   - The answer derived from Image 2 MUST match the answer that would be obtained from Image 1.
   - If Image 1 allows additional valid answers or changes the interpretation (e.g., there are multiple relevant objects in the full image), mark INVALID even if Image 2 looks unambiguous.

4) Clear question:
   - The question must specify the target unambiguously (which object/person/instance).
   - Avoid unclear references like "it", "this", "the object" when multiple candidates exist in Image 1 or Image 2.
   - The question should not rely on unspecified perspective ("left/right" without a clear frame is okay if it's the image frame), or vague quantifiers ("a lot", "some").

OUTPUT FORMAT (strict):
VALID: Yes/No
REASON: A brief explanation referencing which criterion/criteria passed or failed.

EXAMPLE (illustrating a common INVALID case):
- Question: "What is the number on the side of the boat hull?"
- Image 1 (original): There are TWO boats, each with a different hull number (e.g., 12 and 18).
- Image 2 (crop): The crop shows only ONE boat with number 12.
Decision: INVALID
Reason: Although Image 2 yields a single answer (12), the original image allows multiple valid answers (12 or 18), so criterion #3 (consistency with original) fails and the question is ambiguous in the full context.

Now please reason step by step, and then evaluate the given Question using the rules above.
"""

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Original Image:"},
                {"type": "image_url", "image_url": {"url": base64_original_image}},
                {"type": "text", "text": "Cropped Image:"},
                {"type": "image_url", "image_url": {"url": base64_masked_image}},
                {"type": "text", "text": validation_prompt}
            ]
        }
    ]

    full_content = call_gemini_api(messages, stream=False, max_retries=100)

    if not full_content:
        return {"is_valid": False, "reason": "API调用失败"}

    # The model reasons first and emits its verdict at the end, so the LAST match wins.
    # `[*_\s]*` tolerates markdown emphasis such as "**VALID:** Yes".
    # `\b` keeps "INVALID" from matching.
    verdicts = re.findall(r'\bVALID[*_\s]*:[*_\s]*(Yes|No)\b', full_content, re.IGNORECASE)
    is_valid = bool(verdicts) and verdicts[-1].upper() == "YES"

    reasons = re.findall(r'\bREASON[*_\s]*:[*_\s]*(.+?)(?:\n\n|\Z)', full_content, re.IGNORECASE | re.DOTALL)
    reason = reasons[-1].strip() if reasons else "No reason provided."

    return {"is_valid": is_valid, "reason": reason}


def validate_vqa_pairs_with_gemini(original_image, masked_image, vqa_pairs, max_validation_workers=6):
    """
    [Validator] Validate several VQA questions concurrently with Gemini.

    Args:
        original_image: Full PIL image.
        masked_image: PIL crop taken from *original_image*.
        vqa_pairs: Candidate items; each should be a dict with a "question" key.
        max_validation_workers: Maximum number of concurrent validation calls.

    Returns:
        One {"is_valid", "reason"} dict per item, in the same order as *vqa_pairs*.
        Items that are not dicts or have an empty question are marked invalid
        without an API call.
    """
    if not vqa_pairs:
        return []

    # Pre-sized so every result, including the empty-question ones, lands at its item's index.
    results = [None] * len(vqa_pairs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_validation_workers) as executor:
        future_to_idx = {}

        for idx, item in enumerate(vqa_pairs):
            question = item.get('question', '') if isinstance(item, dict) else ''
            if not question:
                results[idx] = {"is_valid": False, "reason": "Empty question"}
                continue

            future = executor.submit(
                validate_single_vqa_with_gemini,
                original_image,
                masked_image,
                question
            )
            future_to_idx[future] = idx

        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                results[idx] = future.result()
            except Exception as e:
                print(f"验证问题 {idx} 时出错: {e}")
                results[idx] = {"is_valid": False, "reason": f"验证异常: {str(e)}"}

    return results


# ==============================================================================
# Qwen API 调用（保持原有逻辑）
# ==============================================================================

def generate_vqa_with_qwen(original_b64_image, crop_b64_image, seed_questions, api_url):
    """Ask the Qwen3-VL endpoint for candidate VQA questions (multi-key version).

    Args:
        original_b64_image: Base64 (no data-URL prefix) of the full image.
        crop_b64_image: Base64 (no data-URL prefix) of the crop.
        seed_questions: Example questions; a random subset is embedded in the prompt.
        api_url: Chat-completions endpoint URL.

    Returns:
        The list extracted from the reply (possibly []), or None on API or shape failure.
    """
    def as_jpeg_url(b64_data):
        return f"data:image/jpeg;base64,{b64_data}"

    prompt = get_vqa_generation_prompt(seed_questions)

    user_content = [
        {"type": "text", "text": "This is the Original Image:"},
        {"type": "image_url", "image_url": {"url": as_jpeg_url(original_b64_image)}},
        {"type": "text", "text": "This is the Cropped Image:"},
        {"type": "image_url", "image_url": {"url": as_jpeg_url(crop_b64_image)}},
        {"type": "text", "text": prompt},
    ]

    payload = {
        "stream": False,
        "model": GENERATOR_MODEL_NAME_QWEN,
        "messages": [{"role": "user", "content": user_content}],
        "temperature": 0.9,
        "top_p": 0.8,
    }

    response_json = call_qwen_api(payload, api_url=api_url, stream=False, max_retries=100)
    if not response_json:
        return None

    try:
        reply = response_json["choices"][0]["message"]["content"]
        return extract_json_from_text(reply)
    except Exception:
        return None



def validate_single_vqa_with_qwen(original_b64_image, crop_b64_image, question, api_url):
    """Ask the Qwen endpoint whether one question is valid under the crop-consistency rules.

    Args:
        original_b64_image: Base64 (no data-URL prefix) of the full image.
        crop_b64_image: Base64 (no data-URL prefix) of the crop.
        question: Question text to judge.
        api_url: Chat-completions endpoint URL.

    Returns:
        {"is_valid": bool, "reason": str}. *is_valid* is False when the call fails,
        the reply has no text content, or no "VALID: Yes" verdict is found.
    """
    validation_prompt = f"""
You are an expert at validating whether a question is appropriate for Visual Question Answering (VQA) under a crop-consistency setting.

You will be shown two images:
- Image 1: The original image (full context)
- Image 2: A cropped region taken from Image 1

Question: {question}

Your task: Determine whether the question is VALID according to ALL criteria below. If ANY criterion fails, the question is NOT valid.

CRITERIA (ALL must be satisfied):
1) Crop-answerable:
   - The question MUST be answerable using both two images.
   - If answering requires information outside the crop, mark INVALID.

2) Unique and unambiguous (in the original image):
   - In Image 1, there must be exactly ONE clearly correct answer.
   - If multiple instances/objects in the original image could produce different correct answers, mark INVALID.

3) Consistent with the original image:
   - The answer derived from Image 2 MUST match the answer that would be obtained from Image 1.
   - If Image 1 allows additional valid answers or changes the interpretation (e.g., there are multiple relevant objects in the full image), mark INVALID even if Image 2 looks unambiguous.

4) Clear question:
   - The question must specify the target unambiguously (which object/person/instance).
   - Avoid unclear references like "it", "this", "the object" when multiple candidates exist in Image 1 or Image 2.
   - The question should not rely on unspecified perspective ("left/right" without a clear frame is okay if it's the image frame), or vague quantifiers ("a lot", "some").

OUTPUT FORMAT (strict):
VALID: Yes/No
REASON: A brief explanation referencing which criterion/criteria passed or failed.

EXAMPLE (illustrating a common INVALID case):
- Question: "What is the number on the side of the boat hull?"
- Image 1 (original): There are TWO boats, each with a different hull number (e.g., 12 and 18).
- Image 2 (crop): The crop shows only ONE boat with number 12.
Decision: INVALID
Reason: Although Image 2 yields a single answer (12), the original image allows multiple valid answers (12 or 18), so criterion #3 (consistency with original) fails and the question is ambiguous in the full context.

Now please reason step by step, and then evaluate the given Question using the rules above.
"""

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "This is the Original Image:"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{original_b64_image}"}},
                {"type": "text", "text": "This is the Cropped Image:"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64_image}"}},
                {"type": "text", "text": validation_prompt}
            ]
        }
    ]

    payload = {
        "stream": False,
        "model": GENERATOR_MODEL_NAME_QWEN,
        "messages": messages,
        "temperature": 0.3,
        "top_p": 0.9
    }

    resp_json = call_qwen_api(payload, api_url=api_url, stream=False, max_retries=100)
    if not resp_json:
        return {"is_valid": False, "reason": "API调用失败"}

    try:
        full_content = resp_json["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        return {"is_valid": False, "reason": "响应解析失败"}
    if not isinstance(full_content, str):
        # e.g. "content": null. re.findall would raise on a non-string.
        return {"is_valid": False, "reason": "响应解析失败"}

    # The model reasons first and emits its verdict at the end, so the LAST match wins.
    # `[*_\s]*` tolerates markdown emphasis such as "**VALID:** Yes".
    # `\b` keeps "INVALID" from matching.
    verdicts = re.findall(r'\bVALID[*_\s]*:[*_\s]*(Yes|No)\b', full_content, re.IGNORECASE)
    is_valid = bool(verdicts) and verdicts[-1].upper() == "YES"

    reasons = re.findall(r'\bREASON[*_\s]*:[*_\s]*(.+?)(?:\n\n|\Z)', full_content, re.IGNORECASE | re.DOTALL)
    reason = reasons[-1].strip() if reasons else "No reason provided."

    return {"is_valid": is_valid, "reason": reason}

def validate_vqa_pairs_with_qwen(original_b64_image, crop_b64_image, vqa_pairs, api_url, max_validation_workers=6) -> list:
    """[Validator] Validate several VQA questions concurrently with the Qwen endpoint.

    Args:
        original_b64_image: Base64 (no data-URL prefix) of the full image.
        crop_b64_image: Base64 (no data-URL prefix) of the crop.
        vqa_pairs: Candidate items; each should be a dict with a "question" key.
        api_url: Chat-completions endpoint URL.
        max_validation_workers: Maximum number of concurrent validation calls.

    Returns:
        One {"is_valid", "reason"} dict per item, in the same order as *vqa_pairs*.
        Items that are not dicts or have an empty question are marked invalid
        without an API call.
    """
    if not vqa_pairs:
        return []

    temp_results = [None] * len(vqa_pairs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_validation_workers) as executor:
        future_to_idx = {}

        for idx, item in enumerate(vqa_pairs):
            question = item.get("question", "") if isinstance(item, dict) else ""
            if not question:
                temp_results[idx] = {"is_valid": False, "reason": "Empty question"}
                continue

            future = executor.submit(
                validate_single_vqa_with_qwen,
                original_b64_image,
                crop_b64_image,
                question,
                api_url
            )
            future_to_idx[future] = idx

        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                temp_results[idx] = future.result()
            except Exception as e:
                print(f"验证问题 {idx} 时出错: {e}")
                temp_results[idx] = {"is_valid": False, "reason": f"验证异常: {str(e)}"}

    return temp_results


# ==============================================================================
# 主处理流程（新增：支持 bbox JSON 输入）
# ==============================================================================

def _bbox_skip_result(image_path, bbox_str, status):
    """Build the result record for a task rejected before any crop was produced."""
    return {
        'image_path': image_path,
        'bbox': bbox_str,
        'status': status,
        'generated_vqa': []
    }


def _load_original_image(image_path):
    """Return ``(pil_image, base64_str)`` for an OSS/HTTP path or a local file path."""
    if image_path.startswith(('oss://', 'http')):
        return get_oss_image(image_path), get_oss_image(image_path, image_type='base64')
    # Local file: decode it (closing the handle), then re-encode as JPEG for the Qwen data URL.
    with Image.open(image_path) as raw_image:
        pil_image = raw_image.convert('RGB')
    buffered = io.BytesIO()
    pil_image.save(buffered, format='JPEG')
    return pil_image, base64.b64encode(buffered.getvalue()).decode('utf-8')


def _encode_crop_as_jpeg_b64(crop_pil_image):
    """JPEG-encode a crop to base64, converting to RGB first (JPEG cannot store alpha or palette modes)."""
    rgb_image = crop_pil_image if crop_pil_image.mode == 'RGB' else crop_pil_image.convert('RGB')
    buffered = io.BytesIO()
    rgb_image.save(buffered, format='JPEG')
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _normalize_candidate_pairs(raw_pairs):
    """Keep dict candidates, wrap bare question strings as ``{"question": s}``, and drop anything else."""
    if not isinstance(raw_pairs, list):
        return []
    normalized = []
    for pair in raw_pairs:
        if isinstance(pair, dict):
            normalized.append(pair)
        elif isinstance(pair, str):
            normalized.append({"question": pair})
    return normalized


def _attach_validation(candidate_pairs, validation_results):
    """Annotate each candidate in place with ``validation_status`` and ``validation_reason``; return the list."""
    validation_results = validation_results or []
    for idx, pair in enumerate(candidate_pairs):
        result = validation_results[idx] if idx < len(validation_results) else None
        if isinstance(result, dict):
            pair['validation_status'] = 'passed' if result.get("is_valid", False) else 'failed'
            pair['validation_reason'] = result.get("reason", "No reason.")
        else:
            pair['validation_status'] = 'failed'
            pair['validation_reason'] = 'No validation result'
    return candidate_pairs


def process_bbox_task(bbox_item: dict, args, seed_questions: list) -> "dict | None":
    """
    Process one task from the bbox JSON input.

    Steps: parse the bbox, load the image, filter the pair with
    ``is_image_pair_valid``, crop (optionally uploading the crop to OSS),
    generate candidate questions, validate them, and merge the verdicts.

    Args:
        bbox_item: Dict with at least 'image_path' and 'bbox' (string form accepted
            by ``parse_bbox_string``).
        args: Parsed CLI arguments. Reads upload_crop_to_oss, generator_type,
            generator_base_url, validator_type, validator_base_url and validation_workers.
        seed_questions: Example questions for the generator prompt.

    Returns:
        None if 'image_path' or 'bbox' is missing. Otherwise a result dict whose
        'status' is a skip reason, "success" (at least one question passed) or
        "all_failed_validation".
    """
    image_path = bbox_item.get('image_path')
    bbox_str = bbox_item.get('bbox')

    if not image_path or not bbox_str:
        return None

    bbox = parse_bbox_string(bbox_str)
    if bbox is None:
        return _bbox_skip_result(image_path, bbox_str, 'skipped_invalid_bbox')

    try:
        original_pil_image, original_b64_image = _load_original_image(image_path)
    except Exception as e:
        return _bbox_skip_result(image_path, bbox_str, f'skipped_image_load_error: {str(e)}')

    if not original_pil_image:
        return _bbox_skip_result(image_path, bbox_str, 'skipped_image_load_failed')

    is_valid, reason, validated_bbox = is_image_pair_valid(original_pil_image, bbox=bbox)
    if not is_valid:
        return _bbox_skip_result(image_path, bbox_str, reason)

    try:
        crop_pil_image = original_pil_image.crop(validated_bbox)

        crop_oss_path = None
        if args.upload_crop_to_oss:
            # splitext keeps dotted stems intact ("a.b.jpg" -> "a.b"). All four
            # coordinates keep different boxes with the same top-left corner from colliding.
            stem = os.path.splitext(os.path.basename(image_path))[0]
            coords = "_".join(str(c) for c in validated_bbox)
            filename = f"{stem}_crop_{coords}.jpg"
            crop_oss_path = image_upload_to_oss(crop_pil_image, crop_oss_prefix, filename)

        crop_b64_image = _encode_crop_as_jpeg_b64(crop_pil_image)
    except Exception as e:
        return _bbox_skip_result(image_path, bbox_str, f'skipped_crop_error: {str(e)}')

    # Step 1: generate candidate questions.
    if args.generator_type == "qwen":
        raw_candidates = generate_vqa_with_qwen(
            original_b64_image, crop_b64_image,
            seed_questions,
            args.generator_base_url
        )
    else:  # gemini
        raw_candidates = generate_vqa_with_gemini(
            original_pil_image, crop_pil_image,
            seed_questions
        )

    candidate_vqa_pairs = _normalize_candidate_pairs(raw_candidates)
    if not candidate_vqa_pairs:
        return {
            'image_path': image_path,
            'bbox': bbox_str,
            'crop_path': crop_oss_path,
            'status': 'skipped_by_generator',
            'generated_vqa': []
        }

    # Step 2: validate the candidates.
    if args.validator_type == "qwen":
        validation_results = validate_vqa_pairs_with_qwen(
            original_b64_image, crop_b64_image,
            candidate_vqa_pairs,
            args.validator_base_url,
            max_validation_workers=args.validation_workers
        )
    else:  # gemini
        validation_results = validate_vqa_pairs_with_gemini(
            original_pil_image, crop_pil_image,
            candidate_vqa_pairs,
            max_validation_workers=args.validation_workers
        )

    validated_pairs = _attach_validation(candidate_vqa_pairs, validation_results)
    passed = any(p['validation_status'] == 'passed' for p in validated_pairs)

    return {
        'image_path': image_path,
        'bbox': bbox_str,
        'crop_path': crop_oss_path,
        'status': "success" if passed else "all_failed_validation",
        'generated_vqa': validated_pairs
    }


def _pil_to_jpeg_b64(image) -> str:
    """Return ``image`` encoded as a base64 (UTF-8 text) JPEG.

    JPEG cannot store alpha or palette data, so an image in any mode other
    than RGB or L (e.g. an RGBA/P crop of a PNG source) is converted to RGB
    first; otherwise ``Image.save(..., format='JPEG')`` raises ``OSError``.
    """
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    buffered = io.BytesIO()
    image.save(buffered, format='JPEG')
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _merge_line_validation_results(candidate_vqa_pairs: list, validation_results) -> list:
    """Annotate each candidate VQA pair (in place) with its validation outcome.

    ``validation_results`` is matched to ``candidate_vqa_pairs`` by position.
    A pair is marked 'passed' only when its result is a dict with a truthy
    ``is_valid``. A pair with no usable result is marked 'failed' with reason
    'No validation result'. That covers a list that is too short, a non-dict
    entry, or ``validation_results`` itself being ``None``.

    Returns the same pair objects, in their original order.
    """
    results = validation_results or []
    validated_pairs = []
    for j, pair in enumerate(candidate_vqa_pairs):
        result = results[j] if j < len(results) else None
        if isinstance(result, dict):
            pair['validation_status'] = 'passed' if result.get("is_valid", False) else 'failed'
            pair['validation_reason'] = result.get("reason", "No reason.")
        else:
            pair['validation_status'] = 'failed'
            pair['validation_reason'] = 'No validation result'
        validated_pairs.append(pair)
    return validated_pairs


def process_line_task(line_data: dict, args, seed_questions: list) -> "dict | None":
    """Generate and validate VQA pairs for every mask of one original-format record.

    Args:
        line_data: one record of the original JSONL input. Reads
            ``oss_image_path`` and ``oss_mask_image_paths``.
        args: parsed CLI arguments. Reads ``generator_type``, ``validator_type``,
            ``generator_base_url``, ``validator_base_url``,
            ``validation_workers`` and ``upload_crop_to_oss``.
        seed_questions: seed questions forwarded to the generator.

    Returns:
        ``{'oss_image_path': ..., 'generated_vqa': [...]}`` with one entry per
        mask that could be loaded. Returns ``None`` when the record has no
        ``oss_image_path`` or the original image cannot be loaded.
    """
    original_image_path = line_data.get('oss_image_path')
    if not original_image_path:
        return None

    # Fetch the original image (PIL for gemini, base64 for qwen).
    original_pil_image = get_oss_image(original_image_path)
    if not original_pil_image:
        return None
    # Fall back to a local encoding so qwen never receives None.
    original_b64_image = (
        get_oss_image(original_image_path, image_type='base64')
        or _pil_to_jpeg_b64(original_pil_image)
    )

    # `or []` also covers an explicit null in the record.
    mask_image_paths = line_data.get('oss_mask_image_paths') or []
    generated_vqa_results = []

    for mask_path in mask_image_paths:
        masked_pil_image = get_oss_image(mask_path)
        if not masked_pil_image:
            # Masks that fail to load are dropped without a result entry.
            continue

        is_valid, reason, bbox = is_image_pair_valid(original_pil_image, masked_image=masked_pil_image)
        if not is_valid:
            generated_vqa_results.append({"mask_path": mask_path, "status": reason, "vqa_pairs": []})
            continue

        crop_pil_image = original_pil_image.crop(bbox)

        crop_oss_path = None
        if args.upload_crop_to_oss:
            # NOTE(review): crop_oss_prefix is a module-level name not visible
            # here; presumably the OSS folder for uploaded crops.
            crop_oss_path = image_upload_to_oss(crop_pil_image, crop_oss_prefix, os.path.basename(mask_path))

        crop_b64_image = get_oss_image(crop_oss_path, image_type='base64') if crop_oss_path else None
        if not crop_b64_image:
            # Either the upload is disabled or the re-download failed: encode locally.
            crop_b64_image = _pil_to_jpeg_b64(crop_pil_image)

        # Step 1: generate candidate VQA pairs.
        if args.generator_type == "qwen":
            candidate_vqa_pairs = generate_vqa_with_qwen(
                original_b64_image, crop_b64_image,
                seed_questions,
                args.generator_base_url
            )
        else:  # gemini
            candidate_vqa_pairs = generate_vqa_with_gemini(
                original_pil_image, crop_pil_image,
                seed_questions
            )

        if not candidate_vqa_pairs or not isinstance(candidate_vqa_pairs, list):
            generated_vqa_results.append({
                "mask_path": mask_path,
                "crop_path": crop_oss_path,
                "status": "skipped_by_generator",
                "vqa_pairs": []
            })
            continue

        # Step 2: validate the candidates (each validator fans out internally).
        if args.validator_type == "qwen":
            validation_results = validate_vqa_pairs_with_qwen(
                original_b64_image, crop_b64_image,
                candidate_vqa_pairs,
                args.validator_base_url,
                max_validation_workers=args.validation_workers
            )
        else:  # gemini
            validation_results = validate_vqa_pairs_with_gemini(
                original_pil_image, crop_pil_image,
                candidate_vqa_pairs,
                max_validation_workers=args.validation_workers
            )

        validated_pairs = _merge_line_validation_results(candidate_vqa_pairs, validation_results)
        final_status = "success" if any(p['validation_status'] == 'passed' for p in validated_pairs) else "all_failed_validation"
        generated_vqa_results.append({
            "mask_path": mask_path,
            "crop_path": crop_oss_path,
            "status": final_status,
            "vqa_pairs": validated_pairs
        })

    return {'oss_image_path': original_image_path, 'generated_vqa': generated_vqa_results}


# ==============================================================================
# 主函数（新增：自动检测输入文件格式）
# ==============================================================================

def detect_input_format(input_file):
    """
    Detect which supported input layout ``input_file`` uses.

    The file is read once. If it is a JSON array whose first element is an
    object with both ``image_path`` and ``bbox``, it is the bbox layout.
    Otherwise the first non-empty line is parsed as a JSONL record.

    Returns:
        'bbox'     - standard JSON array produced by the bbox generator
        'original' - legacy JSONL whose first record contains ``oss_image_path``
        'unknown'  - anything else, including unreadable files
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read().strip()
    except (OSError, UnicodeDecodeError) as e:
        print(f"检测输入格式时出错: {e}")
        return 'unknown'

    # Whole-file JSON array (bbox generator output).
    if content.startswith('['):
        try:
            data_list = json.loads(content)
        except json.JSONDecodeError:
            data_list = None
        if isinstance(data_list, list) and data_list:
            first_item = data_list[0]
            # Require a dict: `in` on a str would be a substring test.
            if isinstance(first_item, dict) and 'image_path' in first_item and 'bbox' in first_item:
                return 'bbox'

    # JSONL: `content` is stripped, so its first line is the first non-empty
    # line of the file. This matches main(), which skips blank lines.
    first_line = content.splitlines()[0].strip() if content else ''
    if not first_line:
        return 'unknown'

    try:
        data = json.loads(first_line)
    except json.JSONDecodeError as e:
        print(f"检测输入格式时出错: {e}")
        return 'unknown'

    if isinstance(data, dict) and 'oss_image_path' in data:
        return 'original'

    return 'unknown'


def _load_processed_ids(output_file_path, id_key):
    """Collect the ids already written to ``output_file_path`` (resume support).

    Args:
        output_file_path: JSONL output file from a previous run.
        id_key: record field identifying an input item
            ('image_path' or 'oss_image_path').

    Returns:
        A set of ``id_key`` values, empty when the file does not exist. Lines
        that are not JSON objects are skipped, as are truncated lines (e.g.
        left by an interrupted run) and lines with an unhashable id.
    """
    processed_ids = set()
    if not os.path.exists(output_file_path):
        print(f"✓ 主输出文件 {output_file_path} 不存在，将创建新文件。")
        return processed_ids

    with open(output_file_path, 'r', encoding='utf-8') as f_out_read:
        for line in f_out_read:
            try:
                data = json.loads(line)
                if isinstance(data, dict) and id_key in data:
                    processed_ids.add(data[id_key])
            except (json.JSONDecodeError, TypeError):
                continue
    print(f"✓ 从主输出文件恢复，已处理 {len(processed_ids)} 条数据。")
    return processed_ids


def _print_pool_usage(stats, label):
    """Print one line per credential: usage count, error count, active requests.

    ``stats`` is a token pool's ``get_stats()`` result. This reads its
    'usage', 'errors' and 'active' mappings. ``label`` ('Token' or 'Key')
    prefixes the 1-based index.
    """
    errors_by_token = stats['errors']
    active_by_token = stats['active']
    for idx, (token, count) in enumerate(stats['usage'].items(), 1):
        print(f"  {label}{idx}: 使用{count}次 | 错误{errors_by_token.get(token, 0)}次 | 活跃{active_by_token.get(token, 0)}")


def _print_pool_final_stats(stats, label):
    """Print totals and each credential's share of usage for one token pool.

    Reads the 'usage' and 'errors' mappings of ``stats``, a pool's
    ``get_stats()`` result. When the pool was never used, the share is
    printed as 0.0% instead of dividing by zero.
    """
    usage = stats['usage']
    errors_by_token = stats['errors']
    total_usage = sum(usage.values())
    total_errors = sum(errors_by_token.values())
    print(f"  总使用: {total_usage} 次")
    print(f"  总错误: {total_errors} 次")
    for idx, (token, count) in enumerate(usage.items(), 1):
        share = count / total_usage * 100 if total_usage else 0.0
        print(f"  {label}{idx}: {count}次 ({share:.1f}%) | 错误: {errors_by_token.get(token, 0)}")


def main(args):
    """Generate and validate VQA data for every pending record of ``args.input_file``.

    Pipeline:
      0. detect the input layout ('bbox' JSON array or 'original' JSONL);
      1. recover ids already present in the output file so reruns resume;
      2. read the input records and the seed questions;
      3. drop failed or bbox-less records (bbox layout) and already-processed ids;
      4. process the rest on a thread pool, appending each non-empty result
         to the output JSONL as soon as it finishes.

    Token-pool statistics are printed every ``stats_interval`` finished tasks
    and again at the end.
    """
    # --- 0. Detect the input file format ---
    input_format = detect_input_format(args.input_file)

    if input_format == 'unknown':
        print(f"❌ 无法识别输入文件格式: {args.input_file}")
        print("支持的格式:")
        print("  1. Bbox JSON 格式 (标准JSON数组，包含 image_path, bbox 字段)")
        print("  2. 原有 JSONL 格式 (每行一个JSON对象，包含 oss_image_path 字段)")
        return

    print(f"✓ 检测到输入格式: {input_format.upper()}")

    # --- 1. Resume: the output file name and id field depend on the format ---
    if input_format == 'bbox':
        output_file_path = f"{args.output_file}_bbox_gen-{args.generator_type}_judge-{args.validator_type}.jsonl"
        id_key = 'image_path'
    else:
        output_file_path = f"{args.output_file}_gen-{args.generator_type}_judge-{args.validator_type}.jsonl"
        id_key = 'oss_image_path'

    processed_ids = _load_processed_ids(output_file_path, id_key)

    # --- 2. Read input data ---
    print(f"✓ 正在读取输入文件: {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f_in:
        if input_format == 'bbox':
            # bbox format: one complete JSON array.
            all_data = json.load(f_in)
        else:
            # original format: JSONL, one object per non-blank line.
            all_data = [json.loads(line) for line in f_in if line.strip()]

    with open(args.seed_json, "r", encoding="utf-8") as f:
        seed_questions = json.load(f)

    # --- 3. Filter tasks ---
    tasks = []
    skipped_failed = 0
    skipped_processed = 0

    for data in all_data:
        try:
            if input_format == 'bbox':
                # Skip records explicitly marked as failed, or without a bbox.
                if 'success' in data and not data.get('success', False):
                    skipped_failed += 1
                    continue
                if not data.get('bbox'):
                    skipped_failed += 1
                    continue

            if data.get(id_key) in processed_ids:
                skipped_processed += 1
                continue

            tasks.append(data)

        except Exception as e:
            # Malformed record (e.g. not a dict): report it and keep going.
            print(f"⚠ 解析数据失败: {e}")
            continue

    if not tasks:
        print("✓ 没有需要处理的新任务。")
        return

    print(f"✓ 待处理任务数: {len(tasks)}")
    if input_format == 'bbox':
        print(f"  - 跳过失败/无效记录: {skipped_failed}")
    print(f"  - 跳过已处理: {skipped_processed}")
    print("="*60)

    # --- 4. Concurrent processing ---
    # Every finished future increments exactly one of success/error, so
    # processed == success + error always holds.
    processed_count_this_session = 0
    success_count = 0
    error_count = 0
    stats_interval = 50  # print intermediate statistics every 50 tasks

    uses_gemini = args.generator_type == 'gemini' or args.validator_type == 'gemini'
    uses_qwen = args.generator_type == 'qwen' or args.validator_type == 'qwen'

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor, \
         open(output_file_path, 'a', encoding='utf-8') as f_out:

        if input_format == 'bbox':
            process_func = process_bbox_task
            print(f"✓ 使用 bbox 处理模式")
        else:
            process_func = process_line_task
            print(f"✓ 使用原有处理模式")

        future_to_task = {
            executor.submit(process_func, task_data, args, seed_questions): task_data.get(id_key)
            for task_data in tasks
        }

        progress_bar = tqdm(
            concurrent.futures.as_completed(future_to_task),
            total=len(tasks),
            desc="Processing",
            ncols=100
        )

        # Only this (main) thread writes to f_out, so no lock is needed.
        for future in progress_bar:
            task_id = future_to_task.get(future, "Unknown")

            try:
                # as_completed() yields finished futures only; this never blocks.
                result = future.result()
                if result:
                    f_out.write(json.dumps(result, ensure_ascii=False) + '\n')
                    f_out.flush()
                    success_count += 1
                else:
                    error_count += 1
            except concurrent.futures.TimeoutError:
                # Reached only if the task itself raised a TimeoutError.
                error_count += 1
                print(f"\n⚠ 任务 {task_id} 超时")
            except Exception as exc:
                error_count += 1
                print(f"\n⚠ 任务 {task_id} 发生异常: {exc}")

            processed_count_this_session += 1
            progress_bar.set_postfix({
                'success': success_count,
                'error': error_count,
                'rate': f"{success_count / processed_count_this_session * 100:.1f}%"
            })

            if processed_count_this_session % stats_interval == 0:
                # Diagnostics only: a failure here must not abort or miscount the run.
                try:
                    print(f"\n{'='*60}")
                    print(f"中间统计 (已处理 {processed_count_this_session} 个任务):")
                    print(f"  成功: {success_count} | 失败: {error_count}")
                    print(f"  成功率: {success_count/(success_count+error_count)*100:.2f}%")

                    # token_pool / qwen_token_pool are module-level pools defined
                    # outside this chunk.
                    if uses_gemini:
                        print(f"\nGemini Token 使用情况:")
                        _print_pool_usage(token_pool.get_stats(), 'Token')

                    if uses_qwen:
                        print(f"\nQwen Key 使用情况:")
                        _print_pool_usage(qwen_token_pool.get_stats(), 'Key')

                    print(f"{'='*60}\n")
                except Exception as stats_exc:
                    print(f"\n⚠ 打印统计信息失败: {stats_exc}")

    # Final statistics
    print("\n" + "="*60)
    print("处理完成！最终统计:")
    print(f"  总处理: {processed_count_this_session}")
    print(f"  成功: {success_count}")
    print(f"  失败: {error_count}")
    if processed_count_this_session > 0:
        print(f"  成功率: {success_count/processed_count_this_session*100:.2f}%")

    if uses_gemini:
        print(f"\nGemini Token 最终统计:")
        _print_pool_final_stats(token_pool.get_stats(), 'Token')

    if uses_qwen:
        print(f"\nQwen Key 最终统计:")
        _print_pool_final_stats(qwen_token_pool.get_stats(), 'Key')

    print(f"\n输出文件: {output_file_path}")
    print("="*60)


if __name__ == '__main__':
    # Command-line entry point: declare the CLI, parse it, then run main().
    # `parser` and `args` stay module-level names, exactly as before.
    parser = argparse.ArgumentParser(
        description="VQA Batch Generation & Validation (支持多种输入格式)"
    )

    # Where to read records from, where to write results, and the seed questions.
    parser.add_argument(
        '--input_file',
        type=str,
        required=True,
        help='输入文件路径 (支持 bbox JSON 或原有 JSONL 格式)',
    )
    parser.add_argument(
        '--output_file',
        type=str,
        default="vqa_output",
        help='输出文件前缀（会自动添加后缀）',
    )
    parser.add_argument(
        "--seed_json",
        type=str,
        default="./new_benchmark/utils_new_pipeline/seed.json",
        help='种子问题 JSON 文件路径',
    )

    # Two levels of parallelism: tasks across the pool, validations within a task.
    parser.add_argument(
        '--max_workers',
        type=int,
        default=15,
        help='主任务并发数（建议：Token数 × 每Token并发数）',
    )
    parser.add_argument(
        '--validation_workers',
        type=int,
        default=6,
        help='每个任务内验证子任务的并发数',
    )

    # Backend used for generating and for judging the VQA pairs.
    parser.add_argument(
        '--generator_type',
        type=str,
        choices=['qwen', 'gemini'],
        default='gemini',
        help='生成器模型类型',
    )
    parser.add_argument(
        '--validator_type',
        type=str,
        choices=['qwen', 'gemini'],
        default='gemini',
        help='验证器模型类型',
    )

    # API credentials / endpoints.
    # NOTE(review): the validator defaults reuse the generator's key and URL —
    # confirm this is intended.
    parser.add_argument('--generator_api_key', type=str, default=GENERATOR_API_KEY)
    parser.add_argument('--generator_base_url', type=str, default=GENERATOR_API_BASE_URL)
    parser.add_argument('--validator_api_key', type=str, default=GENERATOR_API_KEY)
    parser.add_argument('--validator_base_url', type=str, default=GENERATOR_API_BASE_URL)

    # Optional: also upload each crop to OSS.
    parser.add_argument(
        '--upload_crop_to_oss',
        action='store_true',
        help='是否将裁剪后的图片上传到 OSS',
    )

    args = parser.parse_args()
    main(args)


