import re
import os
import json
import shutil
import unicodedata
import sys
import torch
from typing import Any, Dict, Iterable, List, Set
import math
from transformers import AutoConfig

# ------- InternVL 官方推荐的图像预处理（README Quick Start 里的那一套） -------
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

# Per-channel (R, G, B) mean and standard deviation handed to T.Normalize
# by _build_internvl_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_internvl_transform(input_size: int):
    """Build the torchvision pipeline applied to each image tile.

    The pipeline, in order: convert to RGB when the image is in another mode,
    bicubic-resize to an ``input_size`` x ``input_size`` square, convert to a
    tensor, and normalize with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Edge length (pixels) of the square output.

    Returns:
        A ``T.Compose`` instance.
    """

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else.
        return img if img.mode == "RGB" else img.convert("RGB")

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize an image onto a grid of square tiles and cut it into those tiles.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (the caller in
            this file passes a PIL image).
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        List of tiles in row-major order, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size

    # Every (cols, rows) grid whose tile count lies within [min_num, max_num].
    grid_options = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    grid_options = sorted(grid_options, key=lambda g: g[0] * g[1])
    grid = _find_closest_aspect_ratio(src_w / src_h, grid_options, src_w, src_h, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]
    tiles_per_row = canvas_w // image_size

    canvas = image.resize((canvas_w, canvas_h))
    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == tile_count
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


def _internvl_load_image(image_file: str, input_size=448, max_num=6) -> torch.Tensor:
    """Load an image file and turn it into a stacked tensor of tiles.

    The image is opened as RGB, tiled by ``_dynamic_preprocess`` (thumbnail
    enabled), and every tile is passed through the transform returned by
    ``_build_internvl_transform``.

    Args:
        image_file: Path handed to ``Image.open``.
        input_size: Tile edge length.
        max_num: Maximum number of tiles (excluding the thumbnail).

    Returns:
        ``torch.stack`` of the transformed tiles, one entry per tile along
        the first dimension.
    """
    rgb = Image.open(image_file).convert("RGB")
    to_tensor = _build_internvl_transform(input_size=input_size)
    tiles = _dynamic_preprocess(
        rgb,
        image_size=input_size,
        use_thumbnail=True,
        max_num=max_num,
    )
    return torch.stack([to_tensor(tile) for tile in tiles])

# ========================================
# Answer-extraction regex patterns
# ========================================

# Character class for a single option letter. Every pattern below is compiled
# case-insensitively, so lowercase a-d is captured as well.
_LETTER = r"[A-D]"   # SeedBench Plus is limited to A–D

# Regex building blocks
# AFTER: zero-width lookahead placed right after the captured letter; it requires
# end of string, one of the listed separator characters, or the character 项.
AFTER = r'(?=$|[\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜·\-–—~〜～*`]|项)'
# SEP_CHARS: separator characters meant to be embedded inside a character class
# (extract_final uses it for the right-hand boundary of a standalone letter).
SEP_CHARS = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～/\\\*`'
# Primary form: "the answer is (X)", brackets optional.
RE_MAIN_PAREN = re.compile(rf'(?i)the\s+answer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}')
# Variant forms tried when the primary form does not match. Each pattern has a
# single capture group holding the letter.
RE_VARIANTS = [
    # 1) Strong English anchors
    re.compile(rf'(?i)\b(?:final|correct)\s+answer\s*(?:is\s*[:：]?)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}'),  # "is" and the colon are both optional
    re.compile(rf'(?i)\bthe\s+answer\s+is\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\banswer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\banswer\s*[:：]\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),               # colon form
    re.compile(rf'(?i)\b(?:correct|final)\s*choice\s*(?:is|:)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\bthe\s+answer\s+is\s*[:：]?\s*(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),

    # 2) Common Chinese phrasings
    re.compile(rf'(?i)(?:正确答案|答案)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)(?:故选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)(?:选择(?:答案)?|选(?:择)?)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?' + AFTER),
    re.compile(rf'(?i)(?:选项|选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*项?'+AFTER),

    # English action verbs (intended to avoid matching option-enumeration lines)
    re.compile(rf'(?i)\b(?:choose|chosen|pick|picked|select|selected)\s+(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?' + AFTER),

    # 3) Markdown formatting
    re.compile(rf'(?i)\b(?:final|correct)?\s*answer\s*[:：]?\s*\**\s*[\[\{{\(\（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\**'+AFTER),
    # re.compile(r'(?i)\*\*\s*({_LETTER})\s*\*\*' + AFTER),

    # 4) LaTeX (disabled)
    # re.compile(r'(?i)\\boxed\{\s*\(?\s*({_LETTER})\s*\)?\s*\}'),
    # re.compile(r'(?i)\\boxed\{\s*(?:\\(?:textbf|mathbf|mathrm|mathsf)\{\s*({_LETTER})\s*\}|({_LETTER}))\s*\}'),
    # re.compile(r'(?i)\\\(\s*({_LETTER})\s*\\\)'),                                # letter wrapped in \( \)
    # (optional addition: \(\s*({_LETTER})\s*\), for forms such as \(D\))
]

# Whole-line anchored patterns (^...$); the author expects these to rarely
# produce false positives.
RE_LASTLINE_FORMAT_FALLBACKS = [
    # **A** / **b**
    re.compile(rf'^\s*\*\*\s*({_LETTER})\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # **(A)**
    re.compile(rf'^\s*\*\*\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # \boxed{A} / \boxed{\textbf{A}} / \boxed{\mathrm{a}} …
    # This pattern has two capture groups; callers take the first non-empty one.
    re.compile(
        rf'^\s*\\boxed\{{\s*(?:\\(?:textbf|mathbf|mathrm|mathsf|bfseries|boldsymbol)\{{\s*({_LETTER})\s*\}}|({_LETTER}))\s*\}}\s*[\.\!\?\u3002\uFF1F]*\s*$',
        re.I
    ),
    # \(A\)
    re.compile(rf'^\s*\\\(\s*({_LETTER})\s*\\\)\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # (A) / [A] / {A}, including full-width brackets
    re.compile(rf'^\s*[\(\[\{{（【]\s*({_LETTER})\s*[\)\]\}}）】]\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
]

### Support for trimming an incomplete (cut-off) trailing sentence ###
# Sentence terminator: CJK/ASCII end punctuation, an ellipsis, or a period that is
# followed by whitespace, end of string or a closing quote/bracket; any trailing
# closing quotes/brackets and whitespace are consumed as part of the match.
_SENT_END_RE = re.compile(r'(?:[。！？!?]|\.{3,}|…|\.(?=(?:\s|\u3000|$|[”"\'’）\)\]】》])))[\s\u3000]*[”"\'’）\)\]】》]*[\s\u3000]*')
# Phrases that introduce an answer statement (lead-in only; most entries do not
# capture the option itself).
ANSWER_PHRASES = [
    # === English, general (tightened) ===
    r"\b(?:so|thus|therefore)\b.*?\b(?:final\s+|correct\s+)?(?:answer|result|solution|choice)\s*(?:is|are|:|：)",  # with a causal / concluding word
    r"\b(?:the\s+)?(?:final\s+|correct\s+)?answer\s*(?:is|are|:|：)",   # the/final/correct answer ...
    r"\banswer\s*(?:is|:|：)",                                          # answer is / answer:
    r"\b(?:final|correct)\s+choice\s*(?:is|:|：)",                      # final/correct choice ...
    r"\bthe\s+answer\s+is\s*(?:option\s*)?",                            # the answer is (option) ...
    r"\b(?:result|solution)\s*[:：]",                                   # result/solution accepted only in the colon form

    # === Chinese, general (no leading space required) ===
    r"(?:因此|所以|综上(?:所述)?|故|可知)\s*(?:最终\s*)?(?:答案|结果|选项|选择)\s*(?:是|为|:|：)",
    r"(?:最终\s*)?(?:答案|结果)\s*(?:是|为|:|：)",

    # === Terse form (no verb) ===
    r"(?:answer|答案)\s*[:：]",

    # === Strong multiple-choice closers (tightened) ===
    # Only "故选" / "应选" are kept, and they must be followed by A–D, full-width
    # Ａ–Ｄ, digits, or a Chinese numeral.
    r"(?:故\s*选|应\s*选)\s*[A-DＡ-Ｄ](?:\s*项|(?:\s*选项)?)?",
    r"(?:故\s*选|应\s*选)\s*(?:[0-9]+|[一二三四五六七八九十])(?:\s*项|(?:\s*选项)?)?",
]
# Compiled forms. Only the first phrase is compiled with re.S, so its ".*?" may
# span newlines; all the others use re.I alone.
ANSWER_PHRASES_RE = [re.compile(ANSWER_PHRASES[0], re.I | re.S)] + \
                    [re.compile(p, re.I) for p in ANSWER_PHRASES[1:]]

# 如果尾巴有完整答案，则不截断（理论上发生几率较小）
def _tail_has_complete_answer(tail: str) -> bool:
    """Return True when ``tail`` already contains a fully stated answer.

    A hit on ``RE_MAIN_PAREN`` counts immediately; otherwise any pattern in
    ``RE_VARIANTS`` that matches with at least one non-empty capture group
    counts.
    """
    if RE_MAIN_PAREN.search(tail) is not None:
        return True
    for pattern in RE_VARIANTS:
        hit = pattern.search(tail)
        if hit is not None and any(hit.groups()):
            return True
    return False

def _contains_answer_phrase(text: str) -> bool:
    """Return True if any pattern in ``ANSWER_PHRASES_RE`` is found in ``text``."""
    for phrase_re in ANSWER_PHRASES_RE:
        if phrase_re.search(text):
            return True
    return False

# ========================================
# 答案解析 / 文本处理
# ========================================

# 去掉不完整的结尾句子
def trim_incomplete_sentence(ans: str) -> str:
    """Drop a dangling, unfinished final sentence when it could mislead parsing.

    The text is split at the end of the last sentence terminator found by
    ``_SENT_END_RE``. The trailing fragment is removed only when all of the
    following hold:

    * the fragment contains an answer lead-in phrase,
    * the text before it also contains an answer lead-in phrase,
    * the fragment does not itself contain a complete answer.

    In every other case (including empty input or no terminator at all) the
    input is returned unchanged.
    """
    if not ans:
        return ans

    cut = None
    for boundary in _SENT_END_RE.finditer(ans):
        cut = boundary.end()  # keep only the last terminator's end offset
    if cut is None:
        return ans

    head, tail = ans[:cut], ans[cut:]
    if not (_contains_answer_phrase(tail) and _contains_answer_phrase(head)):
        return ans
    if _tail_has_complete_answer(tail):
        return ans
    return head.strip()

# 暂时不考虑 few-shot
def format_prompt(example):
    """Render the question and its non-empty options as prompt text.

    Args:
        example: Mapping with a ``"question"`` string and optional
            ``choice_A`` .. ``choice_D`` entries.

    Returns:
        ``"Question:\\n<question>\\nOptions:\\n"`` followed by one
        ``"<letter>. <text>\\n"`` line per option that is present and truthy.
        Few-shot examples are not supported.
    """
    pieces = ["Question:\n", example["question"] + "\n", "Options:\n"]
    for letter in "ABCD":
        field = "choice_" + letter
        if field in example and example[field]:
            pieces.append("{}. {}\n".format(letter, example[field]))
    return "".join(pieces)

# question, choice_A/B/C/D, answer
def parse_input(item):
    """Return the full model prompt: fixed instruction plus the formatted question.

    ``item`` is forwarded to ``format_prompt`` (question and choice_A..D fields).
    """
    instruction = "The following is a multiple choice question. Think step by step and then finish your answer with \"the answer is (X)\" where X is one of the options (e.g. one of ABCD).\n"
    return instruction + format_prompt(item)

# 将字符串转换为规范形式，去除空白，压缩空白，保留换行
def _normalize(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize('NFKC', s)
    s = s.replace("（", "(").replace("）", ")")
    for z in ("\u200b", "\xa0", "\u202f", "\ufeff", "\u200d", "\u2060", "\t"):
        s = s.replace(z, " ")
    # 仅压缩空白但保留换行：避免破坏“最后一行”逻辑
    s = re.sub(r'[^\S\r\n]+', ' ', s)
    return s

def extract_answer(text: str, option_answer_text: str):
    """Extract the chosen option from a model response.

    The text is normalized and common LaTeX text/style macros are unwrapped
    to their contents. The last case-insensitive occurrence of
    "the answer is (X)" wins; when there is none the work is delegated to
    ``extract_again``.

    Returns:
        The upper-cased letter, whatever ``extract_again`` returns, or
        ``None`` for empty input.
    """
    if not text:
        return None

    cleaned = _normalize(text)
    # Unwrap \text{A}, \mathbf{A}, ... so the letter inside becomes visible.
    cleaned = re.sub(r'\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}', r'\1', cleaned)

    final_hit = None
    for final_hit in RE_MAIN_PAREN.finditer(cleaned):
        pass  # only the last match is of interest
    if final_hit is not None:
        return final_hit.group(1).upper()
    return extract_again(cleaned, option_answer_text)

# 变体匹配，优先选择强锚点
def extract_again(text: str, option_answer_text: str):
    """Second-stage extraction using the variant patterns.

    Every pattern in ``RE_VARIANTS`` is scanned over the whole text; the match
    that ends furthest to the right wins (on equal end offsets the earlier
    pattern in the list is kept). If nothing matches, ``extract_final`` is
    used as the fallback.
    """
    winner = None
    furthest = -1
    for pattern in RE_VARIANTS:
        for hit in pattern.finditer(text):
            # Take the first non-empty capture group of the match.
            letter = next((grp for grp in hit.groups() if grp), None)
            if letter and hit.end() > furthest:
                furthest, winner = hit.end(), letter.upper()

    if winner is None:
        return extract_final(text, option_answer_text)
    return winner

def _strip_latex_noise_oneline(s: str) -> str:
    # 去掉行内/陈列数学环境定界符
    s = re.sub(r'\\\[|\\\]|\\\(|\\\)|\$\$?', ' ', s)
    # 去掉常见 \frac / \dfrac / \tfrac 形式
    s = re.sub(r'\\[dt]?frac\s*\{[^{}]*\}\s*\{[^{}]*\}', ' ', s)
    # 去掉通用 \cmd{...}（保留 \boxed{...} 等已在上面兜底过的）
    s = re.sub(r'\\(?!boxed\b)[A-Za-z]+\s*\{([^{}]*)\}', r' \1 ', s)
    return s

# 兜底：看最后三行
def extract_final(text: str, option_answer_text: str):
    """Last-resort extraction that inspects the final three non-blank lines.

    Lines are examined from the last one backwards. For each line, in order:

    1. every pattern in ``RE_VARIANTS``, then every pattern in
       ``RE_LASTLINE_FORMAT_FALLBACKS``; the first non-empty capture is
       returned upper-cased;
    2. if ``option_answer_text`` is given, non-letter content pulled out by
       ``extract_latex_non_letter`` is returned as-is when it equals the
       option text after ``_normalize_for_compare``;
    3. after LaTeX noise removal, a line consisting of a single letter
       (optionally bracketed / punctuated) yields that letter;
    4. otherwise a standalone letter delimited by separator characters.

    NOTE(review): step 4 returns the FIRST standalone letter on the line,
    although the original comment described taking the last one. Behavior is
    kept as implemented; confirm which is intended.

    Returns:
        The extracted answer, or ``None`` when nothing is found.
    """
    nonblank = [stripped for stripped in (ln.strip() for ln in text.splitlines()) if stripped]
    if not nonblank:
        return None

    # Standalone-letter pattern: the left boundary excludes "/" and "\" that
    # the right-hand separator set allows.
    SEP_CHARS_RIGHT = SEP_CHARS
    SEP_CHARS_LEFT  = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～*`'
    pattern = rf'(?i)(?:^|[{SEP_CHARS_LEFT}])\s*[\(\[\{{（【]?\b({_LETTER})\b\s*[\)\]\}}）】]?(?!\s*[/\\=^\(])(?=$|[{SEP_CHARS_RIGHT}]|项)'

    for raw_line in reversed(nonblank[-3:]):
        last = _normalize(raw_line)

        # Anchored phrasings first, then whole-line formatting fallbacks.
        for pattern_group in (RE_VARIANTS, RE_LASTLINE_FORMAT_FALLBACKS):
            for pat in pattern_group:
                hit = pat.search(last)
                if not hit:
                    continue
                letter = next((grp for grp in hit.groups() if grp), None)
                if letter:
                    return letter.upper()

        if option_answer_text:
            cand = extract_latex_non_letter(last)
            if cand and _normalize_for_compare(cand) == _normalize_for_compare(option_answer_text):
                # Extracted non-letter content equals the correct option's text.
                return cand

        last_sanit = _strip_latex_noise_oneline(last)

        # The whole line is one letter, possibly bracketed and punctuated.
        lone = re.match(rf'^\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*[\.\!\?\u3002\uFF1F]*\s*$', last_sanit, re.I)
        if lone:
            return lone.group(1).upper()

        # First standalone letter surrounded by separators.
        standalone = re.search(pattern, last_sanit, re.I)
        if standalone:
            return standalone.group(1).upper()

    return None

# 标准化字符串，去除空白，压缩空白，保留换行，方便答案比较
def _normalize_for_compare(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize("NFKC", s)
    s = re.sub(r"\s+", " ", s) # 压缩连续空白符
    return s.strip()

def _latex_clean_payload(s: str) -> str:
    if s is None:
        return ""
    cur = s

    # —— 分数直接删掉（含 \frac/\tfrac/\dfrac 与 纯文本 a/b 形式）——
    # 1) 去掉 \frac{...}{...} / \tfrac / \dfrac
    cur = re.sub(r"\\[td]?frac\s*\{\s*[^{}]*?\s*\}\s*\{\s*[^{}]*?\s*\}", " ", cur)
    # 2) 去掉纯文本分数 a/b（含可选正负号与空格）
    cur = re.sub(r"[-+−]?\s*\d+\s*/\s*\d+", " ", cur)

    # 统一 Unicode 负号
    cur = cur.replace("\u2212", "-")

    # 递归剥外层常见文本/样式宏壳
    prev = None
    while prev != cur:
        prev = cur
        cur = re.sub(
            r"\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}",
            r"\1", cur
        )

    # 去掉数学定界符与剩余裸命令
    cur = re.sub(r"(\\\[|\\\]|\$\$|\\\(|\\\))", " ", cur)
    cur = re.sub(r"\\[A-Za-z]+", " ", cur)   # \left \right 等
    cur = cur.replace("{", " ").replace("}", " ")

    # 规整空白
    cur = _normalize_for_compare(cur)

    # 去掉引导语
    cur = re.sub(r"(?i)\b(?:the\s+answer\s+is|答案(?:为|是)?)\b\s*[:：]?\s*", "", cur).strip()

    # 若整体仅被一对括号包裹，剥掉
    if re.fullmatch(r"[\(\[\{（【〔〈《]\s*[^()\[\]{}（）【】〔〕〈〉《》]*\s*[\)\]\}）】〕〉》]", cur):
        cur = cur[1:-1].strip()
    return cur

# 抽“非字母选项内容”。采用“后出现优先”：候选收集到列表后，从后往前清洗与判定；命中即返回，其余候选不会传播。
def extract_latex_non_letter(text: str) -> str | None:
    """
    抽“非字母选项内容”。采用“后出现优先”：
    候选收集到列表后，从后往前清洗与判定；命中即返回，其余候选不会传播。
    """
    if not text:
        return None

    t = text
    candidates: list[str] = []

    # 1) \boxed{...}（可能出现多次，先收集，后出现优先）
    last_boxed = None
    for m in re.finditer(r"\\boxed\s*\{\s*([^{}]+?)\s*\}", t):
        last_boxed = m.group(1)
    if last_boxed:
        candidates.append(last_boxed)

    # 2) 常见文本/样式宏
    last_text_like = None
    for macro in ("text", "textbf", "mathrm", "mathbf", "mathsf", "operatorname", "textrm", "rm"):
        for m in re.finditer(rf"\\{macro}\s*\{{\s*([^{{}}]+?)\s*\}}", t):
            last_text_like = m.group(1)
    if last_text_like:
        candidates.append(last_text_like)

    # 3) “the answer is ...” + 数字（非分数；分数已被清洗逻辑去掉）
    m_num = re.search(r"(?i)the\s+answer\s+is\s*[:：]?\s*\(?\s*([-+−]?\d+(?:\.\d+)?)", t)
    if m_num:
        candidates.append(m_num.group(1))

    # 4) 数学环境 \( ... \)、\[ ... \]、$$ ... $$
    for m in re.finditer(r"(?:\\\(|\\\[|\$\$)(.*?)(?:\\\)|\\\]|\$\$)", t, re.S):
        inner = m.group(1)
        # math 内部 \boxed{...}
        for m2 in re.finditer(r"\\boxed\s*\{\s*([^{}]+?)\s*\}", inner):
            candidates.append(m2.group(1))
        # 其次取最后出现的“数字”（非分数）
        nums = re.findall(r"[-+−]?\d+(?:\.\d+)?", inner)
        if nums:
            candidates.append(nums[-1])

    # 过滤：单个 A–D 交给字母流程
    try:
        LETTER = _LETTER  # 如果你的项目里已有
    except NameError:
        LETTER = r"[A-D]"

    for raw in reversed(candidates):  # “后出现优先”
        s = _latex_clean_payload(raw)
        if not s:
            continue
        if re.fullmatch(rf"\(?\s*{LETTER}\s*\)?", s, re.I):
            continue
        return s  # 命中非字母选项
    return None

# 去除字符串末尾的标点符号和空白字符
def _strip_trailing_punct(s: str) -> str:
    # 去掉常见句尾标点（英文/中文）
    return re.sub(r'[\.。!,，；;：:\s]+$', '', s or '').strip()

def seedbench_parse(pred, true, hit_limit, option_answer_text):
    """Parse a SeedBench Plus prediction and score it.

    Args:
        pred: Raw model output.
        true: Gold answer string (compared after strip + upper-casing).
        hit_limit: ``True`` when generation stopped at the length limit.
        option_answer_text: Text of the correct option, used for content match.

    Returns:
        ``(extracted, is_correct)``. When nothing could be extracted the first
        element is ``"Incomplete"`` if ``hit_limit is True`` else ``None``, and
        the second is ``False``. A prediction is correct when it equals the
        gold answer case-insensitively, or when its lower-cased,
        punctuation-stripped, normalized form equals that of the option text.
    """
    extracted = extract_answer(pred, option_answer_text)
    if not extracted or extracted.strip() == "":
        return ("Incomplete" if hit_limit is True else None), False

    if true.strip().upper() == extracted.upper():
        return extracted, True

    predicted_text = _normalize_for_compare(_strip_trailing_punct(extracted.lower()))
    gold_text = _normalize_for_compare(_strip_trailing_punct((option_answer_text or "").lower()))
    return extracted, predicted_text == gold_text


# ========================================
# SeedBench 数据相关
# ========================================

# 清空当前目录下的所有文件和子目录
def clear_dir(dir_path: str):
    """Delete everything inside ``dir_path`` while keeping the directory itself.

    Regular files and symbolic links are removed with ``os.remove``;
    sub-directories with ``shutil.rmtree``. A failure on one entry is printed
    and does not stop the remaining deletions. Nothing happens when
    ``dir_path`` is not a directory.
    """
    if not os.path.isdir(dir_path):
        return

    for filename in os.listdir(dir_path):
        file_path = os.path.join(dir_path, filename)
        try:
            # Check file/link first so a link to a directory is unlinked, not followed.
            is_leaf = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_leaf:
                os.remove(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"删除失败: {file_path}, 原因: {e}")

def get_item_id(item: Dict[str, Any]) -> str:
    """Return the sample's identifier as a stripped string.

    Keys are tried in this order, skipping values that are ``None`` or ``""``:
    ``question_id`` (SeedBench Plus), ``data_id``, ``unique_id``, ``id``.
    Falsy-but-valid values such as ``0`` are accepted. ``""`` is returned when
    no usable value exists.
    """
    val = None
    for key in ("question_id", "data_id", "unique_id", "id"):
        val = item.get(key)
        if not (val is None or val == ""):
            break
    if val is None:
        return ""
    return str(val).strip()

def get_options_list(item: Dict[str, Any]) -> List[str]:
    """Collect the option texts of ``item`` in A-to-D order.

    Only ``choice_A`` .. ``choice_D`` are considered; missing or falsy entries
    are skipped.
    """
    fields = ("choice_A", "choice_B", "choice_C", "choice_D")
    return [item[field] for field in fields if field in item and item[field]]

def get_option_answer_text(item: Dict[str, Any]) -> str:
    """Return the text of the option named by ``item["answer"]``.

    The answer is upper-cased and stripped; when it is one of A/B/C/D and the
    matching ``choice_<letter>`` key exists, that value is returned unchanged.
    In every other case the result is ``""``.
    """
    raw_answer = item.get("answer", "")
    if not raw_answer:
        return ""

    letter = raw_answer.upper().strip()
    if letter not in ("A", "B", "C", "D"):
        return ""

    field = "choice_" + letter
    return item[field] if field in item else ""

def get_image_path(item: Dict[str, Any]) -> str:
    """Build the full path of the image referenced by ``item``.

    The path is ``<project root>/new_benchmark/seedbench_plus2/<image_path>``,
    where the project root is two directory levels above the directory that
    contains this file.

    Args:
        item: Sample mapping; its ``"image_path"`` value is the image's path
            relative to the dataset directory.

    Returns:
        The joined path, or ``""`` when ``image_path`` is missing or empty.
    """
    image_path = item.get("image_path", "")
    if not image_path:
        return ""

    base_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.normpath(os.path.join(base_dir, "..", ".."))
    return os.path.join(project_root, "new_benchmark", "seedbench_plus2", image_path)

# ========================================
# 运行管理
# ========================================

# eos 规范化
def _normalize_to_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]

class CheckpointManager:
    """Persist run progress as the set of already-processed sample IDs.

    The state (processed IDs, completion flag, total sample count) is stored
    as JSON at ``path``. With ``load_existing=False`` an existing checkpoint
    file is ignored and gets overwritten by the first save, i.e. "start from
    scratch but still write checkpoints".
    """

    def __init__(self, path: str, load_existing: bool = True):
        self.path = path
        self.processed_ids: Set[str] = set()
        self.completed: bool = False
        self.total_samples: int = 0
        # When not loading, any old file stays on disk until the first
        # _atomic_save replaces it.
        if load_existing:
            self._load()

    @staticmethod
    def _as_key(_id: Any) -> str:
        """Return the stripped string form of an ID; ``""`` for ``None``."""
        return "" if _id is None else str(_id).strip()

    def _load(self) -> None:
        """Read state from ``self.path`` if it exists; on any error print and move on."""
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self.processed_ids = {str(entry) for entry in data.get("processed_ids", [])}
            self.completed = bool(data.get("completed", False))
            self.total_samples = int(data.get("total_samples", 0))
        except Exception as e:
            print(f"[ckpt] 读取失败，忽略旧文件：{self.path} ({e})")

    def reset(self) -> None:
        """Clear in-memory progress; the file is overwritten by the next save.

        ``total_samples`` is kept (``set_total`` updates it later) and nothing
        is written immediately, so the file on disk is not clobbered by accident.
        """
        self.processed_ids.clear()
        self.completed = False

    def set_total(self, n: int) -> None:
        """Record the total number of samples, saving only when it changed."""
        if self.total_samples == n:
            return
        self.total_samples = int(n)
        self._atomic_save()

    def mark_batch_processed(self, ids: Iterable[str]) -> None:
        """Add a batch of IDs (blank / ``None`` ones skipped) and save.

        ``completed`` is set once a positive ``total_samples`` has been
        reached by the number of distinct processed IDs.
        """
        keys = (self._as_key(_id) for _id in ids)
        self.processed_ids.update(key for key in keys if key)
        if 0 < self.total_samples <= len(self.processed_ids):
            self.completed = True
        self._atomic_save()

    def is_processed(self, _id: Any) -> bool:
        """Return True if the (stripped, stringified) ID is non-empty and recorded."""
        key = self._as_key(_id)
        return bool(key) and (key in self.processed_ids)

    def _atomic_save(self) -> None:
        """Write the state to ``<path>.tmp`` and move it over ``path`` with os.replace."""
        snapshot = {
            "processed_ids": sorted(self.processed_ids),
            "completed": self.completed,
            "total_samples": self.total_samples,
        }
        tmp = self.path + ".tmp"
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(snapshot, f, ensure_ascii=False)
        os.replace(tmp, self.path)


def split_model_internvl(model_path):
    """Build a device map that spreads an InternVL model across visible GPUs.

    The model is identified by a case-insensitive substring match of a known
    model name in ``model_path``; the first name in table order that matches
    is used. Language-model layers are dealt out GPU by GPU, with GPU 0
    counted as half a GPU because it also hosts the vision tower. The vision
    model, projector, embeddings, final norm, rotary embedding, output heads
    and the last decoder layer are pinned to GPU 0.

    Args:
        model_path: Path or name containing one of the supported model names.

    Returns:
        Dict mapping module names to GPU indices.

    Raises:
        KeyError: If ``model_path`` contains none of the supported model names.
    """
    # Single source of truth: model name -> number of decoder layers.
    # Insertion order is the matching priority.
    layers_by_model = {
        'InternVL2_5-1B': 24, 'InternVL2_5-2B': 24, 'InternVL2_5-4B': 36, 'InternVL2_5-8B': 32,
        'InternVL2_5-26B': 48, 'InternVL2_5-38B': 64, 'InternVL2_5-78B': 80, 'InternVL2-1B': 24,
        'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32, 'InternVL2-26B': 48,
        'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80, 'Mini-InternVL-2B-V1-5': 24,
        'Mini-InternVL-4B-V1-5': 32, 'InternVL-Chat-V1-5': 48,
    }

    path_lower = model_path.lower()
    model_name = next((name for name in layers_by_model if name.lower() in path_lower), None)
    if model_name is None:
        # Same exception type as the former failed dict lookup, but with a
        # message that says what went wrong.
        raise KeyError(
            f"unrecognized InternVL model path {model_path!r}; "
            f"expected it to contain one of: {', '.join(layers_by_model)}"
        )
    num_layers = layers_by_model[model_name]

    world_size = torch.cuda.device_count()
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    per_gpu = math.ceil(num_layers / (world_size - 0.5))
    shares = [per_gpu] * world_size
    shares[0] = math.ceil(shares[0] * 0.5)

    device_map = {}
    layer_cnt = 0
    for gpu, share in enumerate(shares):
        for _ in range(share):
            device_map[f'language_model.model.layers.{layer_cnt}'] = gpu
            layer_cnt += 1

    # Modules kept together with the vision tower on GPU 0.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map

def split_model_internvl3(model_path):
    """Build a multi-GPU device map, reading the layer count from the model config.

    The number of decoder layers comes from
    ``AutoConfig.from_pretrained(model_path, trust_remote_code=True)``
    (``llm_config.num_hidden_layers``). Layers are dealt out GPU by GPU, with
    GPU 0 counted as half a GPU because it also hosts the vision tower; the
    vision model, projector, embeddings, final norm, rotary embedding, output
    heads and the last decoder layer are pinned to GPU 0.

    Returns:
        Dict mapping module names to GPU indices.
    """
    gpu_count = torch.cuda.device_count()
    llm_depth = AutoConfig.from_pretrained(model_path, trust_remote_code=True).llm_config.num_hidden_layers

    # Since the first GPU will be used for ViT, treat it as half a GPU.
    full_share = math.ceil(llm_depth / (gpu_count - 0.5))
    shares = [full_share] * gpu_count
    shares[0] = math.ceil(shares[0] * 0.5)

    # owners[k] is the GPU that receives decoder layer k.
    owners = [gpu for gpu, share in enumerate(shares) for _ in range(share)]
    device_map = {
        f'language_model.model.layers.{layer}': gpu for layer, gpu in enumerate(owners)
    }

    pinned_to_first_gpu = (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.model.rotary_emb',
        'language_model.lm_head',
        f'language_model.model.layers.{llm_depth - 1}',
    )
    for module_name in pinned_to_first_gpu:
        device_map[module_name] = 0

    return device_map

def _extend_prefill_inputs(batch_size, inputs, extra_tokens: int, fill_token_id: int):
    if extra_tokens <= 0:
        return inputs
    if "input_ids" not in inputs:
        return inputs

    input_ids = inputs["input_ids"]

    extra_ids = torch.full(
        (batch_size, extra_tokens),
        int(fill_token_id),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    inputs["input_ids"] = torch.cat([input_ids, extra_ids], dim=1)

    if "attention_mask" in inputs and inputs["attention_mask"] is not None:
        am = inputs["attention_mask"]
        extra_mask = torch.ones((batch_size, extra_tokens), dtype=am.dtype, device=am.device)
        inputs["attention_mask"] = torch.cat([am, extra_mask], dim=1)

    return inputs

def apply_prefill_extra_tokens(
    batch_size,
    inputs,
    prefill_extra_tokens: int,
    tokenizer,
    prefill_token_id=None,
):
    """Append filler tokens to the right of the inputs via ``_extend_prefill_inputs``.

    The filler id is chosen as: ``prefill_token_id``, else
    ``tokenizer.eos_token_id``, else ``0``. When either candidate is a list or
    tuple its first element is used (an empty one counts as unset).
    ``inputs`` is returned untouched when ``prefill_extra_tokens`` is falsy or
    not positive after ``int()`` conversion.
    """

    def _first_or_self(value):
        # Lists/tuples collapse to their first element; empty -> None.
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    if not prefill_extra_tokens:
        return inputs
    extra = int(prefill_extra_tokens)
    if extra <= 0:
        return inputs

    tok_id = _first_or_self(prefill_token_id)
    if tok_id is None:
        tok_id = _first_or_self(getattr(tokenizer, "eos_token_id", None))
    if tok_id is None:
        tok_id = 0

    return _extend_prefill_inputs(batch_size, inputs, extra, int(tok_id))
