import json

def convert_lines2train_file(lines_file, orig_train_file, out_train_file):
    """Replace each annotation's caption with the matching line of a text file.

    Line ``i`` of *lines_file* (with surrounding whitespace stripped) becomes
    the ``'caption'`` value of entry ``i`` in the JSON list loaded from
    *orig_train_file*. Every other key of each entry is left untouched. The
    updated list is written to *out_train_file* as JSON indented by 4 spaces.
    The number of lines read is printed to stdout.

    Args:
        lines_file: Path to a UTF-8 text file holding one caption per line.
        orig_train_file: Path to a UTF-8 JSON file whose top level is
            presumably a list of annotation dicts (each entry is indexed with
            ``['caption']``, so entries must support item assignment).
        out_train_file: Path the updated annotation list is written to.

    Raises:
        IndexError: if *lines_file* has fewer lines than there are
            annotations. Nothing is written in that case.
    """
    import sys

    # Explicit UTF-8: captions may contain non-ASCII text, and the platform
    # default encoding is not guaranteed to be UTF-8 (e.g. on Windows).
    with open(lines_file, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f]

    print(len(lines))

    with open(orig_train_file, 'r', encoding='utf-8') as f:
        ann = json.load(f)

    # Fail fast, before writing anything, when captions cannot cover every
    # annotation; keep IndexError so existing callers see the same type.
    if len(lines) < len(ann):
        raise IndexError(
            f"{lines_file} has {len(lines)} lines but {orig_train_file} "
            f"has {len(ann)} annotations"
        )
    if len(lines) > len(ann):
        # Surplus lines are ignored (as before), but they usually indicate
        # that the two files are out of alignment, so say so.
        print(
            f"warning: ignoring {len(lines) - len(ann)} extra line(s) "
            f"in {lines_file}",
            file=sys.stderr,
        )

    # Entries are updated in place; zip pairs them with their captions.
    for a, caption in zip(ann, lines):
        a['caption'] = caption

    with open(out_train_file, 'w', encoding='utf-8') as f:
        json.dump(ann, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory string flags
    # that are forwarded positionally to convert_lines2train_file.
    import argparse

    cli = argparse.ArgumentParser()
    for flag in ("--lines_file", "--orig_train_file", "--out_train_file"):
        cli.add_argument(flag, type=str, required=True)
    opts = cli.parse_args()

    convert_lines2train_file(
        opts.lines_file,
        opts.orig_train_file,
        opts.out_train_file,
    )