#!/usr/bin/env python3
import os
import json
import re

# Input/output file paths. The process_* functions below read each of these
# JSON files and overwrite it in place.
base = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v4'
train_json   = os.path.join(base, 'train/train.json')
val_json     = os.path.join(base, 'val/val.json')
val_ans_json = os.path.join(base, 'val/val_ans.json')

# Regex locating the first option index ("<digits>." plus optional whitespace,
# e.g. "1. "); used to split a prompt into its question prefix and options.
# NOTE(review): any "<digits>." appearing earlier in the question text would
# also match and cut the prefix short — confirm the prompts contain none.
first_opt_re = re.compile(r'(\d+\.\s*)')

def transform_human(value: str) -> str:
    """Reformat a prompt so that each numbered option sits on its own line.

    The text before the first ``<digits>.`` marker (found with
    ``first_opt_re``) is kept as the prefix. Trailing whitespace and periods
    are removed from it, and a colon is appended unless it already ends with
    one. Every ``num. label`` item after that point is re-emitted as a new
    line ``num. label``, with the label stripped and lower-cased. A label
    runs up to the next comma or the end of the string.

    Args:
        value: The original prompt text.

    Returns:
        The reformatted prompt, or ``value`` unchanged when it contains no
        ``<digits>.`` marker.
    """
    m = first_opt_re.search(value)
    if not m:
        return value
    # Strip every kind of trailing whitespace (not just spaces) as well as
    # periods. A prefix such as "Options:\n" must still end in ':' rather
    # than gaining a second colon on a line of its own.
    prefix = value[:m.start()].rstrip(' .\t\n\r\f\v')
    if not prefix.endswith(':'):
        prefix += ':'
    opts_str = value[m.start():]
    # Each item is "num. label"; the label extends to the next comma or the
    # end of the string.
    items = re.findall(r'(\d+)\.\s*([^,]+)', opts_str)
    # Rebuild the options, one per line, with lower-cased labels.
    opts = ''.join(f"\n{num}. {label.strip().lower()}" for num, label in items)
    return prefix + opts

def transform_gpt(value: str) -> str:
    """Remove a leading option index from an answer and lower-case it.

    For example, ``"107. Loft"`` becomes ``"loft"``. Only a leading
    ``<digits>.`` index is removed. A dot elsewhere in the label (e.g.
    ``"St. Mark's"``) is kept. An answer without an index is only stripped
    and lower-cased.

    Args:
        value: The answer text, e.g. ``"107. loft"``.

    Returns:
        The label without its index, stripped and lower-cased.
    """
    # Anchor on a numeric index instead of splitting on the first '.'.
    # Splitting would truncate labels that contain a dot, and rerunning the
    # script on converted data would truncate them further.
    return re.sub(r'^\s*\d+\.', '', value, count=1).strip().lower()

def process_train():
    """Rewrite ``train_json`` in place with v4 image paths and reformatted turns.

    For every entry that has an ``image`` field, ``Scene_classification_v3/``
    in the path is replaced with ``Scene_classification_v4/``. Each
    ``human`` turn in ``conversations`` is passed through
    ``transform_human``, and each ``gpt`` turn through ``transform_gpt``.
    Turns from other speakers are left unchanged.

    Raises:
        KeyError: if a conversation turn lacks a ``from`` key, or a
            human/gpt turn lacks ``value``.
    """
    # Close the input handle before reopening the same path for writing.
    with open(train_json, encoding='utf-8') as f:
        data = json.load(f)
    for entry in data:
        # Only touch entries that already have an image. Writing back
        # unconditionally would add an empty "image" key to text-only
        # entries.
        if 'image' in entry:
            entry['image'] = entry['image'].replace(
                'Scene_classification_v3/', 'Scene_classification_v4/')
        for msg in entry.get('conversations', []):
            if msg['from'] == 'human':
                msg['value'] = transform_human(msg['value'])
            elif msg['from'] == 'gpt':
                msg['value'] = transform_gpt(msg['value'])
    with open(train_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print("✅ Updated train.json")

def process_val():
    """Rewrite ``val_json`` in place with v4 image paths and reformatted prompts.

    For every entry that has an ``image`` field, ``Scene_classification_v3/``
    in the path is replaced with ``Scene_classification_v4/``. The ``text``
    field (the human question) is passed through ``transform_human``.

    Raises:
        KeyError: if an entry has no ``text`` field.
    """
    # Close the input handle before reopening the same path for writing.
    with open(val_json, encoding='utf-8') as f:
        data = json.load(f)
    for entry in data:
        # Only touch entries that already have an image. Writing back
        # unconditionally would add an empty "image" key to entries without
        # one.
        if 'image' in entry:
            entry['image'] = entry['image'].replace(
                'Scene_classification_v3/', 'Scene_classification_v4/')
        entry['text'] = transform_human(entry['text'])
    with open(val_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print("✅ Updated val.json")

def process_val_ans():
    """Rewrite ``val_ans_json`` in place with reformatted prompts and answers.

    Each entry's ``prompt`` is passed through ``transform_human`` and its
    ``text`` (the answer) through ``transform_gpt``. Image paths are not
    modified here.

    Raises:
        KeyError: if an entry lacks ``prompt`` or ``text``.
    """
    # Close the input handle before reopening the same path for writing.
    with open(val_ans_json, encoding='utf-8') as f:
        data = json.load(f)
    for entry in data:
        entry['prompt'] = transform_human(entry['prompt'])
        entry['text'] = transform_gpt(entry['text'])
    with open(val_ans_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print("✅ Updated val_ans.json")

if __name__ == '__main__':
    # Convert each dataset file in turn: train.json, val.json, then val_ans.json.
    for step in (process_train, process_val, process_val_ans):
        step()
