"""Makes a sentencepiece model for a math dataset."""
import io
import json
import os
from typing import List

from absl import app
from absl import flags
from absl import logging

import sentencepiece as spm
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import T5Tokenizer

from em.datasets import math_dataset
# from em.datasets import vocabs

FLAGS = flags.FLAGS

# Command-line configuration for training and exporting the tokenizer.
flags.DEFINE_string(
    'output_path', None,
    'Path to .model file where trained model will be saved. A Hugging Face '
    'tokenizer directory is also written at this path minus the ".model" '
    'suffix.')
flags.DEFINE_string(
    'dataset', None,
    "Dataset to draw training sentences from. Only 'og_true_false' is "
    'currently supported.')

flags.DEFINE_integer(
    'n_examples', None,
    'Number of examples to sample from the dataset for training.')
flags.DEFINE_integer(
    'vocab_size', None,
    'Vocabulary size of the trained SentencePiece model.')


def _is_example_true_false(x):
    return (x['answer'] == 'True') | (x['answer'] == 'False')


def get_examples() -> List[str]:
    """Collect the training sentences (questions) for the SentencePiece model.

    Loads the 'train' split of every config in
    `math_dataset.TRUE_FALSE_CONFIGS`. It keeps only examples whose answer is
    'True' or 'False' and randomly interleaves the configs with
    `sample_from_datasets`. No weights or seed are given, so the mixture
    differs between runs. Finally it takes the first `--n_examples` elements;
    if the flag is unset, every element is used.

    Returns:
      The `question` field of each selected example, decoded to `str`.

    Raises:
      ValueError: if `--dataset` is not 'og_true_false', the only dataset
        supported so far.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if FLAGS.dataset != 'og_true_false':
        raise ValueError(
            f"Unsupported --dataset={FLAGS.dataset!r}; only 'og_true_false' "
            'is supported so far.')

    ds = tf.data.experimental.sample_from_datasets([
        tfds.load(f'math_dataset/{c}', split='train').filter(_is_example_true_false)
        for c in math_dataset.TRUE_FALSE_CONFIGS
    ])
    # `Dataset.take(-1)` keeps every element. Map an unset flag to that,
    # because `take` cannot convert None to its int64 count.
    n_examples = -1 if FLAGS.n_examples is None else FLAGS.n_examples
    ds = ds.take(n_examples)
    return [tf.compat.as_str(x['question']) for x in ds.as_numpy_iterator()]


def _train_sentencepiece(sentences: List[str], vocab_size: int) -> bytes:
    """Train a SentencePiece model in memory and return its serialized bytes.

    Args:
      sentences: Training sentences, fed to the trainer through an iterator.
      vocab_size: Size of the vocabulary to learn.

    Returns:
      The serialized model, i.e. the contents of a SentencePiece `.model` file.
    """
    model_writer = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(sentences),
        model_writer=model_writer,
        vocab_size=vocab_size,
        # Tokenize numbers digit by digit rather than learning multi-digit pieces.
        split_digits=True,
    )
    return model_writer.getvalue()


def _export_hf_tokenizer(model_path: str) -> str:
    """Export the SentencePiece model at `model_path` as a Hugging Face tokenizer.

    The tokenizer is saved to `model_path` with its '.model' suffix removed
    (the caller must guarantee that suffix).

    Returns:
      The directory the tokenizer was saved to.
    """
    # We use the T5Tokenizer since it is essentially a wrapper around a
    # SentencePiece tokenizer. extra_ids=0 keeps T5's <extra_id_N> sentinel
    # tokens out of the vocabulary.
    tokenizer = T5Tokenizer(model_path, extra_ids=0)
    hf_pretrained_dir = model_path[:-len('.model')]
    tokenizer.save_pretrained(hf_pretrained_dir)

    # AutoTokenizer.from_pretrained does not work without a config.json in the
    # directory; a minimal one naming the T5 model type is enough.
    with open(os.path.join(hf_pretrained_dir, 'config.json'), 'w') as f:
        json.dump({'model_type': 't5'}, f)
    return hf_pretrained_dir


def main(_):
    """Train the SentencePiece model and save it in both output formats.

    Writes the raw model to --output_path. It then exports a Hugging Face
    tokenizer directory at --output_path without its '.model' suffix.

    Raises:
      app.UsageError: if --output_path is missing or does not end in '.model',
        or if --vocab_size is missing.
    """
    # Validate flags with real errors (asserts vanish under -O). The HF
    # directory name is derived by stripping the '.model' suffix.
    if not FLAGS.output_path or not FLAGS.output_path.endswith('.model'):
        raise app.UsageError('--output_path must be a path ending in ".model".')
    if FLAGS.vocab_size is None:
        raise app.UsageError('--vocab_size is required.')

    output_path = os.path.expanduser(FLAGS.output_path)
    model_bytes = _train_sentencepiece(get_examples(), FLAGS.vocab_size)

    # Create the destination directory so the write below cannot fail on a
    # missing parent.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(model_bytes)
    logging.info('Wrote SentencePiece model to %s', output_path)

    hf_pretrained_dir = _export_hf_tokenizer(output_path)
    logging.info('Wrote Hugging Face tokenizer to %s', hf_pretrained_dir)


if __name__ == "__main__":
    app.run(main)
