import json
import os
import argparse
import random

from tqdm import tqdm
import pdb


def create_wikisquad(args):
    """Convert a title -> passages JSON dump into sharded SQuAD-style files.

    The input file must hold a JSON object mapping each article title to a
    list of passage objects that carry a ``'text'`` field. Every title becomes
    one article ``{'title': ..., 'paragraphs': [{'context': ...}, ...]}``.
    The articles are shuffled and written to ``args.output_dir`` in chunks of
    ``args.docs_per_file``; the chunks are named ``0000``, ``0001``, ... (no
    extension) and each contains ``{'data': [...]}``.

    Args:
        args: Namespace-like object with ``input_file`` (path of the JSON
            dump), ``output_dir`` (directory receiving the shards, created if
            missing) and ``docs_per_file`` (articles per shard).

    Raises:
        ValueError: If ``args.docs_per_file`` is not a positive integer.
    """
    # Validate before loading the input: a zero step makes range() fail with
    # an unrelated message, and a negative one silently writes nothing.
    if args.docs_per_file <= 0:
        raise ValueError(
            'docs_per_file must be a positive integer, got {}'.format(
                args.docs_per_file))

    # Use a context manager so the input handle is closed deterministically.
    # JSON interchange files are UTF-8, so do not rely on the locale default.
    with open(args.input_file, encoding='utf-8') as fp:
        dpr_wiki = json.load(fp)

    articles = [
        {
            'title': title,
            'paragraphs': [{'context': passage['text']} for passage in passages],
        }
        for title, passages in tqdm(dpr_wiki.items())
    ]

    # Shuffle so that every shard receives a comparable mix of articles.
    # The module-level RNG is not seeded here, so the order differs per run.
    random.shuffle(articles)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    shard_starts = range(0, len(articles), args.docs_per_file)
    for shard_idx, start_idx in enumerate(shard_starts):
        # Zero-padded shard index doubles as the file name.
        to_path = os.path.join(args.output_dir, str(shard_idx).zfill(4))
        cur_data = {'data': articles[start_idx:start_idx + args.docs_per_file]}
        with open(to_path, 'w') as fp:
            json.dump(cur_data, fp)

def get_args():
    """Parse the command-line options of the conversion script.

    ``--input_file`` and ``--output_dir`` are mandatory: without them the
    conversion would only fail later with an unhelpful ``TypeError`` when the
    ``None`` defaults are used as paths, so argparse rejects the call up front.

    Returns:
        argparse.Namespace: with ``input_file`` (str), ``output_dir`` (str)
        and ``docs_per_file`` (int, default 1000).
    """
    parser = argparse.ArgumentParser(
        description='Split a title -> passages JSON dump into SQuAD-style shards.')
    parser.add_argument(
        '--input_file', required=True,
        help='path of the JSON file mapping article titles to passage lists')
    parser.add_argument(
        '--output_dir', required=True,
        help='directory that receives the numbered shard files')
    parser.add_argument(
        '--docs_per_file', default=1000, type=int,
        help='number of articles written to each shard (default: 1000)')

    return parser.parse_args()


def main():
    """Script entry point: parse the CLI options and run the conversion."""
    create_wikisquad(get_args())

# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
