#!/usr/bin/env python

import os
import argparse
from huggingface_hub import hf_hub_download


def get(fname, dataset):
    """Download *fname* from the HuggingFace repo matching *dataset*.

    The file lands in $MTP_ROOT/data/<subdir>/; the network call is
    skipped when the file already exists locally.

    Args:
        fname: Name of the .bin chunk file inside the dataset repo.
        dataset: Either 'fineweb' or 'fineweb-edu'.

    Raises:
        ValueError: if *dataset* is not a recognized dataset name.
        KeyError: if the MTP_ROOT environment variable is unset.
    """
    # Map dataset name -> (HF repo id, local subdirectory). A single
    # table replaces the two previously duplicated branch bodies.
    _REPOS = {
        'fineweb': ('kjj0/fineweb10B-gpt2', 'fineweb10B'),
        'fineweb-edu': ('kjj0/finewebedu10B-gpt2', 'finewebedu10B'),
    }
    if dataset not in _REPOS:
        raise ValueError('Unknown dataset: %s' % dataset)
    repo_id, subdir = _REPOS[dataset]
    local_dir = os.path.join(os.environ['MTP_ROOT'], 'data', subdir)
    # Only hit the network when the chunk is not already on disk.
    if not os.path.exists(os.path.join(local_dir, fname)):
        hf_hub_download(repo_id=repo_id, filename=fname,
                        repo_type="dataset", local_dir=local_dir)


# Download the GPT-2 tokens of Fineweb10B from huggingface. This
# saves about an hour of startup time compared to regenerating them.
if __name__ == "__main__":

    parser = argparse.ArgumentParser('Download datasets')
    parser.add_argument('--numchunks', type=int, default=103,
                        help='number of training chunks to download')
    # required=True: previously, omitting --dataset left args.dataset as
    # None, so neither branch below ran and the script silently exited
    # without downloading anything.
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['fineweb', 'fineweb-edu'],
                        help='which tokenized dataset to fetch')

    args = parser.parse_args()

    if args.dataset == 'fineweb':
        # Validation split is a single chunk numbered 0; training chunks
        # are numbered from 1.
        get("fineweb_val_%06d.bin" % 0, args.dataset)
        for i in range(1, args.numchunks + 1):
            get("fineweb_train_%06d.bin" % i, args.dataset)
    elif args.dataset == 'fineweb-edu':
        get("finewebedu_val_%06d.bin" % 0, args.dataset)
        # The fineweb-edu repo only contains 99 training chunks, so cap
        # the requested count.
        num_chunks = min(args.numchunks, 99)
        for i in range(1, num_chunks + 1):
            get("finewebedu_train_%06d.bin" % i, args.dataset)
