import argparse
import sentencepiece as spm

# https://github.com/google/sentencepiece/blob/master/doc/options.md

# LibriSpeech text corpora paths (LM-normalized text and cleaned LM training
# text, judging by the filenames). Not referenced by generate(); presumably
# kept as convenient values to pass via --corpus — confirm against callers.
LIBRI_LM: str = "/data/librispeech/librispeech-lm-norm.txt"
LIBRI_TRANS: str = "/data/librispeech/librispeech-lm-train-cleaned.txt"


def generate(args):
    """Train a SentencePiece model from the options carried by *args*.

    Args:
        args: Namespace (e.g. the result of ``argparse.parse_args``) that
            provides ``corpus``, ``model_type``, ``prefix``, ``vocab_size``,
            ``num_input_sentences``, ``max_sentence_length``,
            ``max_sentencepiece_length``, ``num_threads`` and
            ``character_coverage``.

    The trainer writes ``<prefix>.model`` and ``<prefix>.vocab``. Special
    pieces are pinned to fixed ids: 0 ``<b>`` (blank and padding),
    1 ``<s>``, 2 ``</s>``, 3 ``<unk>``.
    """
    spm.SentencePieceTrainer.Train(
        input=args.corpus,
        model_type=args.model_type,
        model_prefix=args.prefix,
        # Total vocabulary size, which includes the four special pieces
        # below: <b>, <s>, </s> and <unk>.
        vocab_size=args.vocab_size,
        input_sentence_size=args.num_input_sentences,  # 0 = use all sentences
        max_sentence_length=args.max_sentence_length,
        max_sentencepiece_length=args.max_sentencepiece_length,
        num_threads=args.num_threads,
        character_coverage=args.character_coverage,
        # Special-piece ids, listed in id order. Id 0 serves as both the
        # blank symbol and the padding symbol.
        pad_id=0,
        pad_piece="<b>",
        bos_id=1,
        bos_piece="<s>",
        eos_id=2,
        eos_piece="</s>",
        unk_id=3,
        unk_piece="<unk>",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--corpus", type=str, help="Input corpus path")
    parser.add_argument("--vocab_size", type=int, help="Vocabulary size")
    parser.add_argument("--prefix", type=str, help="Model prefix")
    parser.add_argument("--model_type", default="unigram", type=str, help="SentencePiece type")
    parser.add_argument("--character_coverage", default=1.0, type=float, help="Coverage to determine minimum symbols")
    parser.add_argument("--num_input_sentences", default=0, type=int, help="Number of input sentence to process, all=0")
    parser.add_argument("--max_sentence_length", default=16384, type=int, help="Max sentence length in bytes")
    parser.add_argument("--max_sentencepiece_length", default=16, type=int, help="Max sentence-piece length")
    parser.add_argument("--num_threads", default=8, type=int, help="Number of threads")
    cfg = parser.parse_args()

    generate(cfg)
