import os
import json
import random
import shutil
import shortuuid
from tqdm import tqdm

# Path configuration
metadata_path = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/output/merged_data_v4_metadata.json'
src_root      = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/merged_dataset_v4'
train_json    = '/fs/scratch/PAS2099/Jiacheng/Texture/train/train.json'
train_img_dir = '/fs/scratch/PAS2099/Jiacheng/Texture/train/image'
val_json      = '/fs/scratch/PAS2099/Jiacheng/Texture/val/val.json'
val_ans_json  = '/fs/scratch/PAS2099/Jiacheng/Texture/val/val_ans.json'
val_img_dir   = '/fs/scratch/PAS2099/Jiacheng/Texture/val/image'

# Fraction of each attribute's entries that goes to the training split.
TRAIN_RATIO = 0.8
# Number of wrong options shown next to the correct attribute.
NUM_DISTRACTORS = 9
# Question stem shared by the training and validation prompts.
QUESTION_TEXT = "What is the texture attribute of the image? Choose one from below. "


def load_metadata(path):
    """Load and return the JSON metadata stored at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def group_by_attribute(metadata):
    """Group metadata entries by their 'texture attribute' value.

    Returns a dict mapping attribute -> list of entries. Keys appear in
    order of first occurrence in *metadata*, which gives a deterministic
    attribute ordering (unlike iterating a set of strings).
    """
    by_attr = {}
    for entry in metadata:
        by_attr.setdefault(entry['texture attribute'], []).append(entry)
    return by_attr


def split_entries(by_attr, train_ratio=TRAIN_RATIO, rng=random):
    """Shuffle each attribute's entries and split them into train/val.

    For every attribute the first ``int(len * train_ratio)`` shuffled
    entries go to train and the remainder to val. *rng* is anything that
    provides ``shuffle`` (the ``random`` module or a ``random.Random``).

    Returns a ``(train_entries, val_entries)`` tuple of lists.
    """
    train_entries = []
    val_entries = []
    for entries in by_attr.values():
        shuffled = list(entries)  # shuffle a copy; leave the caller's lists intact
        rng.shuffle(shuffled)
        k = int(len(shuffled) * train_ratio)
        train_entries.extend(shuffled[:k])
        val_entries.extend(shuffled[k:])
    return train_entries, val_entries


def build_question(attr, all_attrs, num_distractors=NUM_DISTRACTORS, rng=random):
    """Build a multiple-choice prompt and its answer for *attr*.

    Up to *num_distractors* other attributes are drawn from *all_attrs*,
    mixed with the correct one and numbered from 1.

    Returns ``(prompt, answer)`` where *prompt* is the question stem
    followed by the numbered options and *answer* is "<index>. <attr>".
    """
    candidates = [a for a in all_attrs if a != attr]
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request for datasets with few distinct attributes.
    distractors = rng.sample(candidates, min(num_distractors, len(candidates)))
    options = distractors + [attr]
    rng.shuffle(options)
    prompt_opts = ', '.join(f"{i}. {opt}" for i, opt in enumerate(options, start=1))
    correct_idx = options.index(attr) + 1
    return QUESTION_TEXT + prompt_opts, f"{correct_idx}. {attr}"


def write_json(path, data):
    """Write *data* to *path* as indented JSON, creating the parent directory."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)


def build_train_data(train_entries, all_attrs):
    """Create the training conversations and copy each image into train_img_dir.

    Every sample gets a fresh short UUID that is used both as its id and
    as the file name of the copied image. Returns the shuffled sample list.
    """
    train_data = []
    for entry in tqdm(train_entries, desc='构建训练样本'):
        orig = entry['image']
        attr = entry['texture attribute']
        uid = shortuuid.uuid()
        prompt, answer = build_question(attr, all_attrs)

        train_data.append({
            "id": uid,
            "image": f"Texture/train/image/{uid}.jpg",
            "conversations": [
                {"from": "human", "value": "<image>\n" + prompt},
                {"from": "gpt",   "value": answer}
            ]
        })

        # NOTE(review): the copy is always named "<uid>.jpg", whatever the
        # source file's extension is - confirm all source images are JPEGs.
        shutil.copy2(
            os.path.join(src_root, attr, orig),
            os.path.join(train_img_dir, f"{uid}.jpg"),
        )

    random.shuffle(train_data)
    return train_data


def build_val_records(val_entries, all_attrs):
    """Create one intermediate record per validation entry and shuffle them."""
    val_records = []
    for entry in tqdm(val_entries, desc='准备验证样本'):
        attr = entry['texture attribute']
        qid = shortuuid.uuid()
        prompt, answer = build_question(attr, all_attrs)
        val_records.append({
            "qid":       qid,
            "orig":      entry['image'],
            "attribute": attr,
            "prompt":    prompt,
            "answer":    answer
        })
    random.shuffle(val_records)
    return val_records


def build_val_outputs(val_records):
    """Convert validation records into the question list and the answer list.

    Returns ``(val_list, val_ans_list)``; both lists follow the order of
    *val_records* and are linked through "question_id".
    """
    val_list = []
    val_ans_list = []
    for rec in val_records:
        val_list.append({
            "question_id": rec["qid"],
            "image":       rec["orig"],
            "category":    "default",
            "text":        rec["prompt"]
        })
        val_ans_list.append({
            "question_id": rec["qid"],
            "prompt":       rec["prompt"],
            "text":         rec["answer"],
            "answer_id":    None,
            "model_id":     None,
            "metadata":     {}
        })
    return val_list, val_ans_list


def copy_val_images(val_records):
    """Copy validation images into val_img_dir, keeping their original names."""
    # NOTE(review): the destination is flat, so two images with the same file
    # name under different attribute folders overwrite each other - verify
    # that image names are unique across attributes.
    for rec in val_records:
        shutil.copy2(
            os.path.join(src_root, rec["attribute"], rec["orig"]),
            os.path.join(val_img_dir, rec["orig"]),
        )


def main(seed=None):
    """Build the train/val texture-attribute datasets from the metadata file.

    Args:
        seed: optional seed for the ``random`` module. When given, the
            split, the sampled options and the final ordering are
            repeatable; the generated short UUIDs are not affected.
    """
    if seed is not None:
        random.seed(seed)

    # Load all metadata, group it by attribute and collect the attribute names
    metadata = load_metadata(metadata_path)
    by_attr = group_by_attribute(metadata)
    all_attrs = list(by_attr)

    # 80/20 split within each attribute
    train_entries, val_entries = split_entries(by_attr)

    # Create the output image directories
    os.makedirs(train_img_dir, exist_ok=True)
    os.makedirs(val_img_dir, exist_ok=True)

    # Build and write train.json
    train_data = build_train_data(train_entries, all_attrs)
    write_json(train_json, train_data)
    print(f"训练集样本数: {len(train_data)}")

    # Build val.json and val_ans.json, copying the validation images first
    val_records = build_val_records(val_entries, all_attrs)
    val_list, val_ans_list = build_val_outputs(val_records)
    copy_val_images(val_records)
    write_json(val_json, val_list)
    write_json(val_ans_json, val_ans_list)
    print(f"验证集样本数: {len(val_list)}")


if __name__ == "__main__":
    main()
