#!/usr/bin/env python3
import os
import random
import json
import shutil
import shortuuid
from tqdm import tqdm

# ---- Configuration ---- #
# Source tree: each sub-directory is treated as one scene label.
src_root = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/balanced_data_v3'

# Training split outputs: conversation JSON plus the copied images.
train_json = '/fs/scratch/PAS2099/Jiacheng/Scene_classification/train/train.json'
train_image_dir = '/fs/scratch/PAS2099/Jiacheng/Scene_classification/train/image'

# Validation split outputs: question JSON, answer JSON and the copied images.
val_json = '/fs/scratch/PAS2099/Jiacheng/Scene_classification/val/val.json'
val_ans_json = '/fs/scratch/PAS2099/Jiacheng/Scene_classification/val/val_ans.json'
val_image_dir = '/fs/scratch/PAS2099/Jiacheng/Scene_classification/val/image'

# Create the image output directories (no error if they already exist).
for _image_dir in (train_image_dir, val_image_dir):
    os.makedirs(_image_dir, exist_ok=True)

# 1. Collect labels: every sub-directory of src_root is one scene class.
# os.listdir() returns names in arbitrary, filesystem-dependent order, so the
# listing is sorted; otherwise the seeded random split further down would not
# be reproducible between runs or machines.
labels = sorted(
    d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))
)
remote_labels = [name for name in labels if name.startswith('remote sensing')]
non_remote_labels = [name for name in labels if not name.startswith('remote sensing')]

# 2. Collect every image as a (label, absolute path) pair, kept in two lists
# depending on whether the label starts with 'remote sensing'.
remote_entries = []
non_remote_entries = []
for label in labels:
    label_dir = os.path.join(src_root, label)
    if label.startswith('remote sensing'):
        bucket = remote_entries
    else:
        bucket = non_remote_entries
    # Sorted for the same reproducibility reason as the label listing above.
    for fn in sorted(os.listdir(label_dir)):
        # Only files with a recognised image extension (case-insensitive).
        if fn.lower().endswith(('.jpg', '.jpeg', '.png')):
            bucket.append((label, os.path.join(label_dir, fn)))

# 3. Split into training / validation sets (80% / 20%).
# The two label families are split independently so that each keeps the same
# 80/20 ratio, then merged.
random.seed(42)
n_rt = int(0.8 * len(remote_entries))
n_nt = int(0.8 * len(non_remote_entries))

# Membership is tested against a set: testing against the sampled list made
# the complement computation quadratic in the number of images. The resulting
# lists (content and order) are identical to the list-based version.
train_remote = random.sample(remote_entries, n_rt)
train_remote_set = set(train_remote)
test_remote = [e for e in remote_entries if e not in train_remote_set]

train_non_remote = random.sample(non_remote_entries, n_nt)
train_non_remote_set = set(train_non_remote)
test_non_remote = [e for e in non_remote_entries if e not in train_non_remote_set]

train_entries = train_remote + train_non_remote
test_entries = test_remote + test_non_remote

# Interleave the two families.
random.shuffle(train_entries)
random.shuffle(test_entries)

print(f"Total training images: {len(train_entries)}")
print(f"Total validation images: {len(test_entries)}")

# 4. Build train.json: one multiple-choice conversation per training image.
# Upper bound on the number of wrong options shown for a non-remote image.
max_distractors = 199

train_data = []
for label, img_path in tqdm(train_entries, desc='Building train.json'):
    uid = shortuuid.uuid()
    # Copy the image under a fresh unique name.
    # NOTE(review): the copy always gets a '.jpg' suffix, even when the source
    # is a .jpeg/.png file; the bytes are copied unchanged.
    dst_img = os.path.join(train_image_dir, uid + '.jpg')
    shutil.copy(img_path, dst_img)

    # Answer options: remote-sensing images are offered every remote-sensing
    # label; other images get the true label plus randomly drawn distractors.
    if label.startswith('remote sensing'):
        opts = list(remote_labels)
    else:
        others = [name for name in non_remote_labels if name != label]
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        opts = random.sample(others, min(max_distractors, len(others))) + [label]
    random.shuffle(opts)
    opts_str = ' , '.join(f"{i+1}. {opt}" for i, opt in enumerate(opts))

    human_q = (
        "<image>\n"
        "What is the scene class of the image? Choose one from below. "
        f"{opts_str}"
    )
    # The answer repeats the 1-based option number followed by the label.
    correct_idx = opts.index(label) + 1
    gpt_a = f"{correct_idx}. {label}"

    train_data.append({
        "id": uid,
        "image": f"Scene_classification/train/image/{uid}.jpg",
        "conversations": [
            {"from": "human", "value": human_q},
            {"from": "gpt",   "value": gpt_a}
        ]
    })

with open(train_json, 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=4, ensure_ascii=False)
print(f"✅ Train JSON written to {train_json}")
# 5. Build val.json (questions) and val_ans.json (reference answers).
# Upper bound on the number of wrong options shown for a non-remote image.
max_distractors = 199

val_pairs = []
# File names already written to val_image_dir during this run.
used_val_names = set()
for label, img_path in tqdm(test_entries, desc='Building val.json'):
    uid = shortuuid.uuid()
    orig_name = os.path.basename(img_path)
    # Validation images keep their original file name. Different label folders
    # can contain files with the same name, and copying them to one directory
    # would silently overwrite the earlier image; on a clash the question id
    # is appended to the stem to keep the name unique.
    if orig_name in used_val_names:
        stem, ext = os.path.splitext(orig_name)
        orig_name = f"{stem}_{uid}{ext}"
    used_val_names.add(orig_name)
    dst_img = os.path.join(val_image_dir, orig_name)
    shutil.copy(img_path, dst_img)

    # Answer options: remote-sensing images are offered every remote-sensing
    # label; other images get the true label plus randomly drawn distractors.
    if label.startswith('remote sensing'):
        opts = list(remote_labels)
    else:
        others = [name for name in non_remote_labels if name != label]
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        opts = random.sample(others, min(max_distractors, len(others))) + [label]
    random.shuffle(opts)
    opts_str = ' , '.join(f"{i+1}. {opt}" for i, opt in enumerate(opts))

    prompt_text = (
        "<image>\n"
        "What is the scene class of the image? Choose one from below. "
        f"{opts_str}"
    )
    # The answer repeats the 1-based option number followed by the label.
    answer_text = f"{opts.index(label) + 1}. {label}"

    # The category is derived from the label prefix.
    category = "AID" if label.startswith('remote sensing') else "Places"

    val_entry = {
        "question_id": uid,
        "image": f"Scene_classification/val/image/{orig_name}",
        "category": category,
        "text": prompt_text
    }
    val_ans_entry = {
        "question_id": uid,
        "prompt": prompt_text,
        "text": answer_text,
        "answer_id": None,
        "model_id": None,
        "metadata": {}
    }
    val_pairs.append((val_entry, val_ans_entry))

# Shuffle questions and answers together so both files stay aligned, then
# split the pairs. Comprehensions are used instead of unpacking zip(*...),
# which raises ValueError when there are no validation entries.
random.shuffle(val_pairs)
val_data = [question for question, _ in val_pairs]
val_ans_data = [answer for _, answer in val_pairs]

with open(val_json, 'w', encoding='utf-8') as f:
    json.dump(val_data, f, indent=4, ensure_ascii=False)
with open(val_ans_json, 'w', encoding='utf-8') as f:
    json.dump(val_ans_data, f, indent=4, ensure_ascii=False)

print(f"✅ Val JSON written to {val_json}")
print(f"✅ Val Answer JSON written to {val_ans_json}")

