#!/usr/bin/env python3
import os
import random
import json
import shutil
import shortuuid
from tqdm import tqdm

# ---- Configuration ---- #
# Input tree that is scanned for class sub-folders.
src_root = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/balanced_data_v4'
# Root folder of the generated dataset.
base_out = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v3'

# Training split: copied images and the conversation-style JSON.
train_image_dir = os.path.join(base_out, 'train/image')
train_json      = os.path.join(base_out, 'train/train.json')

# Validation split: copied images, the question JSON and the answer JSON.
val_image_dir = os.path.join(base_out, 'val/image')
val_json      = os.path.join(base_out, 'val/val.json')
val_ans_json  = os.path.join(base_out, 'val/val_ans.json')

# Create the output image folders (and their parents) if they are missing.
for image_dir in (train_image_dir, val_image_dir):
    os.makedirs(image_dir, exist_ok=True)

# 1. Collect all labels (i.e. the names of the sub-folders of src_root).
# os.listdir() returns entries in arbitrary order, so the names are sorted:
# the seeded random split and shuffles later in this script can only be
# reproducible if the label order is the same on every run.
labels = sorted(
    d for d in os.listdir(src_root)
    if os.path.isdir(os.path.join(src_root, d))
)

# 2. Tell AID labels from Places labels: a label counts as AID when its name
#    contains at least one uppercase character.
def is_aid_label(label: str) -> bool:
    """Return True if *label* contains at least one uppercase character."""
    for ch in label:
        if ch.isupper():
            return True
    return False

# Label names of each dataset family, in the same order as `labels`.
aid_labels    = [name for name in labels if is_aid_label(name)]
places_labels = [name for name in labels if not is_aid_label(name)]

# 3. Collect every image path as a (label, path) pair, grouped by family.
image_exts     = ('.jpg', '.jpeg', '.png')
aid_entries    = []
places_entries = []
for label in labels:
    folder = os.path.join(src_root, label)
    # The family depends only on the folder name, so pick the target list
    # once per folder instead of re-checking it for every file.
    bucket = aid_entries if is_aid_label(label) else places_entries
    # os.listdir() order is arbitrary; sorting keeps the entry order (and
    # therefore the seeded train/val split) identical from run to run.
    for fn in sorted(os.listdir(folder)):
        # Extension check is case-insensitive; anything else is skipped.
        if fn.lower().endswith(image_exts):
            bucket.append((label, os.path.join(folder, fn)))

# 4. Split into training / validation sets (80% train / 20% val).
#    AID and Places are split separately so each family keeps the 80/20 ratio.
random.seed(42)
n_aid_train    = int(0.8 * len(aid_entries))
n_places_train = int(0.8 * len(places_entries))

train_aid = random.sample(aid_entries, n_aid_train)
# Membership is tested against a set: checking `e not in <list>` for every
# entry is quadratic in the number of images. The selected entries and the
# order of the validation list are the same as with the list-based check.
train_aid_set = set(train_aid)
val_aid = [e for e in aid_entries if e not in train_aid_set]

train_places = random.sample(places_entries, n_places_train)
train_places_set = set(train_places)
val_places = [e for e in places_entries if e not in train_places_set]

train_entries = train_aid + train_places
val_entries   = val_aid   + val_places

# Mix the two families so they are interleaved in the output.
random.shuffle(train_entries)
random.shuffle(val_entries)

print(f"Total training images:   {len(train_entries)}")
print(f"Total validation images: {len(val_entries)}")

# 5. Build train.json
# A Places question lists the correct label plus at most this many other
# Places labels; an AID question lists every AID label.
max_distractors = 199

train_data = []
for label, img_path in tqdm(train_entries, desc='Build train.json'):
    uid = shortuuid.uuid()
    # Copy the image under a fresh unique name.
    # NOTE(review): the copy is always named "<uid>.jpg" while the bytes are
    # copied unchanged, and the collection step also accepts .jpeg/.png
    # sources - confirm that downstream loaders do not rely on the extension.
    dst_img = os.path.join(train_image_dir, uid + '.jpg')
    shutil.copy(img_path, dst_img)

    # Build the list of answer options.
    if is_aid_label(label):
        opts = aid_labels.copy()
    else:
        others = [l for l in places_labels if l != label]
        # random.sample() raises ValueError when asked for more items than
        # the population holds, so cap the request at what is available.
        n_others = min(max_distractors, len(others))
        opts = random.sample(others, n_others) + [label]
    random.shuffle(opts)

    # Human question: the options are numbered from 1 in shuffled order.
    opts_str = ' , '.join(f"{i+1}. {opt}" for i, opt in enumerate(opts))
    human_q  = (
        "<image>\n"
        "What is the scene class of the image? Choose one from below. "
        f"{opts_str}"
    )

    # GPT answer keeps both the option number and the label text.
    correct_idx = opts.index(label) + 1
    gpt_a       = f"{correct_idx}. {label}"

    train_data.append({
        "id": uid,
        "image": f"Scene_classification_v3/train/image/{uid}.jpg",
        "conversations": [
            {"from": "human", "value": human_q},
            {"from": "gpt",   "value": gpt_a}
        ]
    })

with open(train_json, 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=4, ensure_ascii=False)
print(f"✅ Train JSON written to {train_json}")

# 6. Build val.json & val_ans.json
# A Places question lists the correct label plus at most this many other
# Places labels; an AID question lists every AID label.
max_distractors = 199

val_pairs  = []
used_names = set()  # file names already written to val/image in this run
for label, img_path in tqdm(val_entries, desc='Build val.json'):
    uid = shortuuid.uuid()
    # Keep the original file name when copying into val/image. Images from
    # different class folders may share a basename; without a guard the later
    # copy would overwrite the earlier one and both questions would point at
    # the same file. Only in that case the name is prefixed with the uid.
    orig_name = os.path.basename(img_path)
    if orig_name in used_names:
        orig_name = f"{uid}_{orig_name}"
    used_names.add(orig_name)
    shutil.copy(img_path, os.path.join(val_image_dir, orig_name))

    # Build the list of answer options.
    if is_aid_label(label):
        opts = aid_labels.copy()
    else:
        others = [l for l in places_labels if l != label]
        # random.sample() raises ValueError when asked for more items than
        # the population holds, so cap the request at what is available.
        n_others = min(max_distractors, len(others))
        opts = random.sample(others, n_others) + [label]
    random.shuffle(opts)

    # Prompt: the options are numbered from 1 in shuffled order.
    opts_str    = ' , '.join(f"{i+1}. {opt}" for i, opt in enumerate(opts))
    prompt_text = (
        "<image>\n"
        "What is the scene class of the image? Choose one from below. "
        f"{opts_str}"
    )
    # The reference answer also keeps both the option number and the label.
    answer_text = f"{opts.index(label) + 1}. {label}"

    category = "AID" if is_aid_label(label) else "Places"

    val_entry = {
        "question_id": uid,
        "image":       f"Scene_classification_v3/val/image/{orig_name}",
        "category":    category,
        "text":        prompt_text
    }
    val_ans_entry = {
        "question_id": uid,
        "prompt":      prompt_text,
        "text":        answer_text,
        "answer_id":   None,
        "model_id":    None,
        "metadata":    {}
    }
    # Question and answer stay paired so they are shuffled together below.
    val_pairs.append((val_entry, val_ans_entry))

# Shuffle and write out. The two lists are built with comprehensions rather
# than `zip(*val_pairs)`, which cannot be unpacked into two names when the
# validation set is empty.
random.shuffle(val_pairs)
val_data     = [question for question, _ in val_pairs]
val_ans_data = [answer for _, answer in val_pairs]

with open(val_json, 'w', encoding='utf-8') as f:
    json.dump(val_data, f, indent=4, ensure_ascii=False)
with open(val_ans_json, 'w', encoding='utf-8') as f:
    json.dump(val_ans_data, f, indent=4, ensure_ascii=False)

print(f"✅ Val JSON written to {val_json}")
print(f"✅ Val Answer JSON written to {val_ans_json}")
