# %%
import json
import os
import random

# === Configuration ===
# Input/output locations: train.json is shuffled alone, while val.json and
# val_ans.json are shuffled together. All three are overwritten in place.
TRAIN_JSON, VAL_JSON, VAL_ANS_JSON = (
    "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/train/train.json",
    "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/val/val.json",
    "/fs/scratch/PAS2099/Lemeng/DatasetResult/Spatial_bbox/val/val_ans.json",
)

# Seed the module-level RNG once so every random.shuffle in this script
# yields the same permutation for the same input files.
SEED = 42
random.seed(SEED)

# === Shuffle train.json ===
# Load the training entries, shuffle them with the seeded module RNG and
# write them back to the same path. Assumes the top-level JSON value is a
# list (random.shuffle requires a mutable sequence) — TODO confirm schema.
with open(TRAIN_JSON, "r", encoding="utf-8") as f:
    train_data = json.load(f)

random.shuffle(train_data)

# Write to a temporary sibling first and atomically swap it in, so a failure
# during serialisation (disk full, interrupted run) cannot truncate the only
# copy of the data.
train_tmp = TRAIN_JSON + ".tmp"
with open(train_tmp, "w", encoding="utf-8") as f:
    json.dump(train_data, f, indent=2)
os.replace(train_tmp, TRAIN_JSON)

print(f"✅ Shuffled and saved: train.json ({len(train_data)} entries)")

# === Shuffle val.json and val_ans.json together ===
# Both files are reordered with one shared permutation, matched by
# "question_id", so entry i of val.json still corresponds to entry i of
# val_ans.json after the rewrite.
with open(VAL_JSON, "r", encoding="utf-8") as f:
    val_data = json.load(f)

with open(VAL_ANS_JSON, "r", encoding="utf-8") as f:
    val_ans_data = json.load(f)

# Pair them by question_id
val_map = {entry["question_id"]: entry for entry in val_data}
val_ans_map = {entry["question_id"]: entry for entry in val_ans_data}

# Check consistency. Duplicate ids would collapse in the dicts above and the
# rewritten files would silently lose entries, so reject them explicitly.
# Explicit raises are used instead of assert, which is stripped under -O.
for file_name, entries, by_id in (
    ("val.json", val_data, val_map),
    ("val_ans.json", val_ans_data, val_ans_map),
):
    if len(by_id) != len(entries):
        raise ValueError(
            f"❌ Duplicate question_ids in {file_name}: "
            f"{len(entries)} entries but only {len(by_id)} unique ids"
        )

if val_map.keys() != val_ans_map.keys():
    only_in_val = val_map.keys() - val_ans_map.keys()
    only_in_ans = val_ans_map.keys() - val_map.keys()
    raise ValueError(
        f"❌ Mismatched question_ids! {len(only_in_val)} only in val.json, "
        f"{len(only_in_ans)} only in val_ans.json"
    )

# Shuffle as pairs
paired = list(val_map.items())
random.shuffle(paired)

shuffled_val = [v for _, v in paired]
shuffled_val_ans = [val_ans_map[qid] for qid, _ in paired]

# Save back. Both temporaries are fully written before either original is
# replaced. A serialisation failure therefore leaves both originals intact,
# and the window in which only one of the pair has been rewritten is
# narrowed to the two os.replace calls.
val_tmp = VAL_JSON + ".tmp"
val_ans_tmp = VAL_ANS_JSON + ".tmp"
with open(val_tmp, "w", encoding="utf-8") as f:
    json.dump(shuffled_val, f, indent=2)
with open(val_ans_tmp, "w", encoding="utf-8") as f:
    json.dump(shuffled_val_ans, f, indent=2)
os.replace(val_tmp, VAL_JSON)
os.replace(val_ans_tmp, VAL_ANS_JSON)

print(f"✅ Shuffled and saved: val.json + val_ans.json ({len(shuffled_val)} pairs)")
