import json

def _caps_out_path(out_file, num_caps):
    """Return the path of the subset file keeping ``num_caps`` captions per image."""
    suffix = ".json"
    # Strip only a *trailing* ".json". ``str.replace`` would also rewrite any
    # ".json" appearing earlier in the path, and would be a no-op for paths
    # without it, making every subset land in the same file.
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    return f"{stem}_caps={num_caps}{suffix}"


def convert_train_file(ann_file, out_file, max_caps=4):
    """Write caption-limited subsets of an annotation file.

    ``ann_file`` must contain a JSON list of dicts, each with an
    ``'image_id'`` key. Entries are grouped by ``image_id`` (groups ordered by
    first appearance, entries kept in file order). For every ``k`` in
    ``1..max_caps`` a JSON list holding the first ``k`` entries of each group
    is written to ``<out_file without trailing ".json">_caps=<k>.json``.
    Nothing is written to ``out_file`` itself.

    Args:
        ann_file: Path to the input JSON annotation file.
        out_file: Template for the output paths; a trailing ``.json`` is
            replaced by ``_caps=<k>.json``.
        max_caps: Largest per-image entry limit to emit (default 4, which
            produces the ``_caps=1`` .. ``_caps=4`` files).

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If ``ann_file`` is not valid JSON.
        KeyError: If an entry has no ``'image_id'`` key.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(ann_file, 'r', encoding="utf-8") as f:
        annotations = json.load(f)

    # dicts preserve insertion order, so images come out in first-seen order.
    img_id2ann = {}
    for i, entry in enumerate(annotations):
        # Progress indicator; ``i`` counts entries of the input list.
        if i % 1000 == 0:
            print(f"Processed {i} images")
        img_id2ann.setdefault(entry['image_id'], []).append(entry)

    # Build and dump one subset at a time so only one is held in memory.
    for num_caps in range(1, max_caps + 1):
        subset = [
            entry
            for entries in img_id2ann.values()
            for entry in entries[:num_caps]
        ]
        with open(_caps_out_path(out_file, num_caps), 'w', encoding="utf-8") as f:
            json.dump(subset, f)


if __name__=="__main__":
    # Command-line entry point.
    # Example arguments:
    #   --ann_file /data/dataset/dataset_json/data/flickr30k_train.json
    #   --out_file /data/dataset/dataset_json/data_rewrite/source_flickr30k_train.json

    import os
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ann_file", type=str, required=True)
    parser.add_argument("--out_file", type=str, required=True)
    args = parser.parse_args()

    # convert_train_file does not write to --out_file itself; for the usual
    # "<stem>.json" argument it writes "<stem>_caps=K.json" for K = 1..4.
    # Checking only args.out_file would therefore never detect earlier
    # results, so the derived paths are checked as well.
    suffix = ".json"
    if args.out_file.endswith(suffix):
        stem = args.out_file[:-len(suffix)]
    else:
        stem = args.out_file
    candidate_paths = [args.out_file] + [
        f"{stem}_caps={k}{suffix}" for k in range(1, 5)
    ]

    # Refuse to run if any output is already present, to avoid clobbering it.
    if any(os.path.exists(path) for path in candidate_paths):
        print("Output file already exists. Exiting...")
    else:
        convert_train_file(args.ann_file, args.out_file)