import json
import os

# Source annotation file, and the output filename template whose "{}" is
# filled with the caption index of each split file.
original_path = "/data/dataset/dataset_json/data/flickr30k_train.json"
out_path = "/data/dataset/dataset_json/data_idx/flickr30k_train_idx{}.json"

# Create the output directory up front; exist_ok makes this a no-op on reruns.
out_dir = os.path.dirname(out_path)
os.makedirs(out_dir, exist_ok=True)

# Load the full annotation list. Presumably a list of dicts, each carrying an
# 'image_id' key (that is what the grouping step reads) — confirm against the
# dataset file. Decode explicitly as UTF-8 so non-ASCII captions do not depend
# on the platform's locale encoding.
with open(original_path, 'r', encoding='utf-8') as f:
    ann = json.load(f)


# Group annotations by image: image_id -> list of that image's annotations.
# Dicts preserve insertion order, so images appear in the order they are first
# seen in `ann`, and each image's annotations keep their file order.
image_id2anns = {}
for entry in ann:
    image_id2anns.setdefault(entry['image_id'], []).append(entry)


# Split the grouped annotations into five parallel lists: idx_k_anns holds the
# k-th annotation of every image, in the same image order as image_id2anns.
# Annotations beyond the fifth for an image are ignored.
idx_splits = ([], [], [], [], [])

for image_id, anns in image_id2anns.items():
    # Every split must contain every image so the output files stay aligned;
    # fail with context instead of a bare IndexError if an image is short.
    if len(anns) < len(idx_splits):
        raise ValueError(
            f"image_id {image_id!r} has {len(anns)} annotations, "
            f"expected at least {len(idx_splits)}"
        )
    # zip stops after the fifth annotation, matching anns[0]..anns[4].
    for split, entry in zip(idx_splits, anns):
        split.append(entry)

idx_0_anns, idx_1_anns, idx_2_anns, idx_3_anns, idx_4_anns = idx_splits

# Report the size of each split, one per line in index order. Each image adds
# exactly one annotation to every split, so all five counts should be equal.
for split in (idx_0_anns, idx_1_anns, idx_2_anns, idx_3_anns, idx_4_anns):
    print(len(split))

# Write split k to out_path with k substituted for "{}", pretty-printed with a
# 4-space indent. Splits are written in index order 0..4.
all_splits = (idx_0_anns, idx_1_anns, idx_2_anns, idx_3_anns, idx_4_anns)
for k, split in enumerate(all_splits):
    save_path = out_path.format(k)
    with open(save_path, 'w') as f:
        json.dump(split, f, indent=4)