import os
import json
import random
import shutil
import shortuuid
from tqdm import tqdm

# Filesystem layout.
# metadata_path and src_root are the inputs this script reads from; the
# train_* / val_* locations are the files and directories it writes.
_USER_ROOT = '/fs/scratch/PAS2099/Jiacheng'
_TMP_ROOT  = _USER_ROOT + '/Texture_tmp'
_OUT_ROOT  = _USER_ROOT + '/Texture_'

metadata_path = _TMP_ROOT + '/output/merged_data_v4_metadata.json'
src_root      = _TMP_ROOT + '/merged_dataset_v4'

train_json    = _OUT_ROOT + '/train/train.json'
train_img_dir = _OUT_ROOT + '/train/image'

val_json      = _OUT_ROOT + '/val/val.json'
val_ans_json  = _OUT_ROOT + '/val/val_ans.json'
val_img_dir   = _OUT_ROOT + '/val/image'

# Load metadata
# The file is decoded explicitly as UTF-8 so the parsed content does not
# depend on the locale's default text encoding.
with open(metadata_path, 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# Index the metadata two ways:
#   by_attr    : texture attribute -> list of metadata entries carrying it
#   by_dataset : base dataset name -> set of texture attributes seen in it
by_attr = {}
by_dataset = {}
for record in metadata:
    texture = record['texture attribute']
    dataset = record['base_dataset']

    if texture not in by_attr:
        by_attr[texture] = []
    by_attr[texture].append(record)

    if dataset not in by_dataset:
        by_dataset[dataset] = set()
    by_dataset[dataset].add(texture)

# Per-attribute 80/20 split.
# Each attribute's entries are shuffled in place (so the lists held in
# by_attr are reordered too); the first int(0.8 * n) go to training and the
# rest to validation.
train_entries, val_entries = [], []
for samples in by_attr.values():
    random.shuffle(samples)
    cut = int(len(samples) * 0.8)
    train_entries += samples[:cut]
    val_entries += samples[cut:]

# Create the image output directories up front; already-existing ones are fine.
for image_dir in (train_img_dir, val_img_dir):
    os.makedirs(image_dir, exist_ok=True)

# Build train.json
# One multiple-choice conversation per training entry; the image is copied
# into train_img_dir under a freshly generated id.
train_data = []
for sample in tqdm(train_entries, desc='Building train samples'):
    filename  = sample['image']
    label     = sample['texture attribute']
    dataset   = sample['base_dataset']
    sample_id = shortuuid.uuid()

    # Answer choices: every attribute recorded for this entry's base dataset,
    # presented in a random order.
    choices = list(by_dataset[dataset])
    random.shuffle(choices)

    # Question lists the choices as "1. a, 2. b, ..."; the answer repeats the
    # number and text of the correct choice.
    numbered = [f"{pos}. {choice}" for pos, choice in enumerate(choices, start=1)]
    listing  = ', '.join(numbered)
    question = f"<image>\nWhat is the texture attribute of the image? Choose one from below. {listing}"
    answer   = f"{choices.index(label) + 1}. {label}"

    # NOTE(review): the recorded path starts with "Texture/" while the files
    # are written under a directory named "Texture_" - confirm this is intended.
    record = {
        "id": sample_id,
        "image": f"Texture/train/image/{sample_id}.jpg",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt",   "value": answer}
        ]
    }
    train_data.append(record)

    # Source images live at <src_root>/<attribute>/<file name>. The copy is
    # always named "<id>.jpg", whatever the source file's extension.
    shutil.copy2(
        os.path.join(src_root, label, filename),
        os.path.join(train_img_dir, f"{sample_id}.jpg"),
    )

# Randomise sample order, then write train.json
random.shuffle(train_data)
os.makedirs(os.path.dirname(train_json), exist_ok=True)
with open(train_json, 'w') as out:
    json.dump(train_data, out, indent=2)
print(f"Training samples: {len(train_data)}")

# Build val.json and val_ans.json
# Each validation entry becomes one multiple-choice question (val.json) plus
# its reference answer (val_ans.json); the image is copied into val_img_dir.
val_records = []
for entry in tqdm(val_entries, desc='Preparing val samples'):
    orig    = entry['image']
    attr    = entry['texture attribute']
    base_ds = entry['base_dataset']
    qid     = shortuuid.uuid()

    # options = all attributes of this original dataset, in random order
    options = list(by_dataset[base_ds])
    random.shuffle(options)
    prompt_opts = ', '.join(f"{i}. {opt}" for i, opt in enumerate(options, start=1))
    human_text  = "What is the texture attribute of the image? Choose one from below. " + prompt_opts
    correct_idx = options.index(attr) + 1
    gpt_text    = f"{correct_idx}. {attr}"

    val_records.append({
        "qid":       qid,
        "orig":      orig,
        "attribute": attr,
        "prompt":    human_text,
        "answer":    gpt_text
    })

# shuffle val_records
random.shuffle(val_records)

# write val.json, val_ans.json and copy images
val_list = []
val_ans_list = []
# Destination file name -> source path already copied there. Sources live at
# <src_root>/<attribute>/<file name>, so the same file name can occur under
# several attributes; without this check the later copy would overwrite the
# earlier one and two questions would point at the same image.
copied_from = {}
for rec in val_records:
    src = os.path.join(src_root, rec["attribute"], rec["orig"])

    # Keep the original file name unless a different source already claimed it;
    # in that case prefix the (unique) question id to disambiguate.
    image_name = rec["orig"]
    if copied_from.get(image_name, src) != src:
        head, tail = os.path.split(image_name)
        image_name = os.path.join(head, f"{rec['qid']}_{tail}")
    copied_from[image_name] = src

    val_list.append({
        "question_id": rec["qid"],
        "image":       image_name,
        "category":    "default",
        "text":        rec["prompt"]
    })
    val_ans_list.append({
        "question_id": rec["qid"],
        "prompt":      rec["prompt"],
        "text":        rec["answer"],
        "answer_id":   None,
        "model_id":    None,
        "metadata":    {}
    })

    dst = os.path.join(val_img_dir, image_name)
    shutil.copy2(src, dst)

os.makedirs(os.path.dirname(val_json), exist_ok=True)
with open(val_json, 'w') as f:
    json.dump(val_list, f, indent=2)
with open(val_ans_json, 'w') as f:
    json.dump(val_ans_list, f, indent=2)
print(f"Validation samples: {len(val_list)}")
