import json

def convert_train_file(ann_file, out_file, caps_k=None):
    """Write captions from a JSON annotation file to a text file, one per line.

    The annotation file must contain a JSON list of objects, each having an
    ``"image_id"`` and a ``"caption"`` entry. For every image id only the
    first ``caps_k`` captions (in file order) are kept; later ones are skipped.
    Newline characters inside a caption are removed so that each caption
    occupies exactly one output line.

    Args:
        ann_file: Path of the JSON annotation file to read.
        out_file: Path of the text file to write. It is opened in append
            mode, so existing content is preserved and new lines are added
            after it.
        caps_k: Maximum number of captions kept per image id. When ``None``,
            the value is taken from a module-level ``args.caps_k`` if such an
            object exists (script usage), otherwise it defaults to 5.

    Side effects:
        Prints every annotation whose caption is empty (it is still written
        as an empty line), a progress message for every kept annotation whose
        index is a multiple of 1000, and the total number of captions written.
    """
    if caps_k is None:
        # Backward compatibility: the script entry point stores the parsed
        # command line in a module-level `args`; use it when present instead
        # of failing with NameError when the function is imported elsewhere.
        cli_args = globals().get("args")
        caps_k = getattr(cli_args, "caps_k", 5)

    with open(ann_file, 'r', encoding="utf-8") as f:
        ann = json.load(f)

    n = 0
    captions_seen = {}  # image_id -> number of captions encountered so far
    # Open the output once instead of re-opening it for every caption.
    with open(out_file, 'a', encoding="utf-8") as out:
        for i, a in enumerate(ann):
            image_id = a['image_id']
            captions_seen[image_id] = captions_seen.get(image_id, 0) + 1
            if captions_seen[image_id] > caps_k:
                continue

            if not a["caption"]:
                # Surface empty captions; they are still written as blank lines.
                print(a)

            # Important: embedded newlines would break the one-caption-per-line
            # layout of the output file.
            caption = a["caption"].replace("\n", "")

            n += 1
            out.write(f"{caption}\n")

            if i % 1000 == 0:
                print(f"Processed {i} images")

    print(f"Total number of captions: {n}")

if __name__ == "__main__":
    # Command-line entry point: dump the captions of --ann_file into
    # --out_file, keeping at most --caps_k captions per image. An output file
    # that already exists is never touched.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    for required_flag in ("--ann_file", "--out_file"):
        parser.add_argument(required_flag, type=str, required=True)
    parser.add_argument("--caps_k", type=int, default=5)
    # `args` has to stay a module-level name: convert_train_file looks up
    # args.caps_k from the module globals.
    args = parser.parse_args()

    if not os.path.exists(args.out_file):
        convert_train_file(args.ann_file, args.out_file)
    else:
        print("Output file already exists. Exiting...")