#!/usr/bin/env python3
import os
import json
import random
import re

# —— Configuration —— #
# Root of the v5 dataset; the three JSON files below are read and then
# overwritten in place by this script.
base_v5      = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v5'
train_json   = os.path.join(base_v5, 'train/train.json')
val_json     = os.path.join(base_v5, 'val/val.json')
val_ans_json = os.path.join(base_v5, 'val/val_ans.json')

# Fix the random seed so the sampling can be reproduced
random.seed(42)
# Regex that recognises an option line (optional whitespace, digits, then a dot)
first_num_re = re.compile(r'^\s*\d+\.')

# One pattern both detects a numbered option line and captures its label, so
# detection and prefix-stripping can never disagree about a line.
_OPTION_LINE_RE = re.compile(r'^\s*\d+\.\s*(.*)$')


def subsample_prompt(text):
    """Split a prompt into its prefix lines and its un-numbered options.

    A line counts as an option when it starts with optional whitespace,
    one or more digits and a dot (e.g. ``"12. beach"``). Everything after
    that number prefix, stripped of surrounding whitespace, is the option
    text. All other lines are kept verbatim as prefix lines.

    Despite its name, this function does not sample anything; it only parses.

    Args:
        text: The full prompt, with lines separated by ``"\\n"``.

    Returns:
        A ``(prefix_lines, option_texts)`` tuple of two lists of strings,
        each in the order the lines appeared in ``text``.
    """
    prefix = []
    opts = []
    for line in text.strip().split('\n'):
        match = _OPTION_LINE_RE.match(line)
        if match is None:
            prefix.append(line)
            continue
        # Take the label from the regex capture rather than splitting on
        # '. ', so "2.forest" or "2.\tforest" is not silently discarded.
        label = match.group(1).strip()
        if label:
            # A bare number such as "5." carries no option text; skip it.
            opts.append(label)
    return prefix, opts

def process_train(num_negatives=29):
    """Rewrite ``train_json`` in place with a reduced option list per entry.

    For every entry this:
      1. switches the image path from the v4 to the v5 dataset prefix;
      2. splits the human prompt into prefix lines and options;
      3. keeps the correct label plus randomly drawn distractors;
      4. rebuilds the human prompt with the kept options shuffled and
         renumbered from 1;
      5. replaces the GPT answer with the lower-cased correct label.

    Args:
        num_negatives: Number of distractor options to keep per prompt.
            An entry that offers fewer distractors keeps all of them
            instead of aborting the whole run with the ``ValueError`` that
            ``random.sample`` raises for an over-sized sample.
    """
    # Use a context manager so the read handle is closed before the same
    # path is reopened for writing at the end.
    with open(train_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for entry in data:
        # 1) Update the image path prefix
        entry['image'] = entry['image'].replace(
            'Scene_classification_v4/',
            'Scene_classification_v5/'
        )

        # 2) Split the human prompt into prefix and options
        conversations = entry['conversations']
        human = conversations[0]['value']
        # The answer is expected to carry no "N. " number prefix already.
        correct_label = conversations[1]['value'].strip()

        prefix, opts = subsample_prompt(human)

        # 3) Sample distractors and add the one correct option. Capping at
        #    the population size leaves the random stream unchanged whenever
        #    enough distractors exist.
        negatives = [o for o in opts if o != correct_label]
        keep = min(num_negatives, len(negatives))
        sampled = random.sample(negatives, keep) + [correct_label]
        random.shuffle(sampled)

        # 4) Rebuild the human prompt
        new_human = prefix + [f"{i+1}. {lbl}" for i, lbl in enumerate(sampled)]
        conversations[0]['value'] = '\n'.join(new_human)

        # 5) Rebuild the GPT answer (label text only, lower-cased)
        conversations[1]['value'] = correct_label.lower()

    with open(train_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print("✅ train.json updated")

def process_val(num_negatives=29):
    """Rewrite ``val_json`` and ``val_ans_json`` in place with fewer options.

    The two files are paired positionally: entry *i* of the answers file
    supplies the correct label for entry *i* of the validation file. For
    every pair this switches the image path from the v4 to the v5 prefix,
    rebuilds the prompt with the correct label plus randomly drawn
    distractors (shuffled, renumbered from 1), copies the new prompt into
    the answer's ``prompt`` field and lower-cases the answer's ``text``.

    Args:
        num_negatives: Number of distractor options to keep per prompt.
            An entry that offers fewer distractors keeps all of them
            instead of aborting the whole run with the ``ValueError`` that
            ``random.sample`` raises for an over-sized sample.

    Raises:
        ValueError: If the two files hold a different number of entries.
    """
    # Use context managers so the read handles are closed before the same
    # paths are reopened for writing at the end.
    with open(val_json, 'r', encoding='utf-8') as f:
        val_data = json.load(f)
    with open(val_ans_json, 'r', encoding='utf-8') as f:
        ans_data = json.load(f)

    # zip() would silently stop at the shorter list, leaving the remaining
    # entries unprocessed yet still written back; fail loudly instead.
    if len(val_data) != len(ans_data):
        raise ValueError(
            f"val.json has {len(val_data)} entries but val_ans.json has "
            f"{len(ans_data)}; cannot pair them by position"
        )

    for v, a in zip(val_data, ans_data):
        # 1) Update the image path prefix
        v['image'] = v['image'].replace(
            'Scene_classification_v4/',
            'Scene_classification_v5/'
        )

        # 2) Split the prompt into prefix and options
        prefix, opts = subsample_prompt(v['text'])
        # The answer is expected to carry no "N. " number prefix already.
        correct_label = a['text'].strip()

        # 3) Sample distractors and add the one correct option. Capping at
        #    the population size leaves the random stream unchanged whenever
        #    enough distractors exist.
        negatives = [o for o in opts if o != correct_label]
        keep = min(num_negatives, len(negatives))
        sampled = random.sample(negatives, keep) + [correct_label]
        random.shuffle(sampled)

        # 4) Rebuild the prompt text stored in val.json
        new_text = prefix + [f"{i+1}. {lbl}" for i, lbl in enumerate(sampled)]
        v['text'] = '\n'.join(new_text)

        # 5) Keep val_ans.json in sync with the rebuilt prompt
        a['prompt'] = v['text']
        a['text'] = correct_label.lower()

    with open(val_json, 'w', encoding='utf-8') as f:
        json.dump(val_data, f, indent=4, ensure_ascii=False)
    with open(val_ans_json, 'w', encoding='utf-8') as f:
        json.dump(ans_data, f, indent=4, ensure_ascii=False)
    print("✅ val.json and val_ans.json updated")

if __name__ == '__main__':
    # Train is processed before val. Both steps draw from the same
    # module-level RNG seeded at import time, so swapping the order would
    # change which options each file ends up with.
    process_train()
    process_val()
