# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data utilities for working with text data (StackOverflow and Wikipedia)."""

from typing import Optional

import tensorflow as tf

# Number of character ids in each input (and each target) sequence.
SEQ_LENGTH = 100
# Minibatch size used when the caller does not pass one explicitly.
DEFAULT_BATCH_SIZE = 8
# Buffer size handed to `tf.data.Dataset.shuffle`.
SHUFFLE_BUFFER = 10000

# Fixed vocabulary of ASCII characters (these happen to occur in Shakespeare's
# works). A character's position in this list is its integer id.
VOCAB = list(
    'dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\n'
    'aeimquyAEIMQUY]!%)-159\r')


def preprocess_text_dataset(
    text_ds: tf.data.Dataset,
    batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
    shuffle: bool = True,
    num_epochs: int = 1) -> tf.data.Dataset:
  """Returns a preprocessed dataset of (input, target) character-id sequences.

  Each string in `text_ds` is split into bytes (i.e. characters, for ASCII
  text), and each byte is mapped to an int64 id given by its position in
  `VOCAB`; bytes not in `VOCAB` map to id 0. The resulting id stream is cut
  into sequences of `SEQ_LENGTH + 1` ids (a trailing partial sequence is
  dropped). The sequences are then optionally shuffled and optionally batched.
  Finally they are split into `(input, target)` pairs, where the target is the
  input shifted by one position.

  Args:
    text_ds: Raw text dataset, to be processed. Each element is reshaped to a
      length-1 string vector, so it must hold exactly one string (e.g. a
      scalar string tensor).
    batch_size: Batch size of output dataset. If None, don't batch; each
      output element is then a single `(input, target)` pair. When batching,
      the final batch may be smaller than `batch_size`.
    shuffle: If True, shuffle the sequences using a buffer of `SHUFFLE_BUFFER`
      elements.
    num_epochs: The number of epochs to repeat the raw dataset in the processed
      dataset.

  Returns:
    A preprocessed, possibly batched, and possibly shuffled/repeated dataset of
    `(input, target)` tuples of int64 character ids. Each tuple element has
    size `SEQ_LENGTH` along its last axis.
  """

  # Construct a lookup table to map string chars to indexes, using the vocab
  # defined above. Unknown chars fall back to id 0, which is also the id of
  # VOCAB[0].
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(
          keys=VOCAB, values=tf.constant(list(range(len(VOCAB))),
                                         dtype=tf.int64)),
      default_value=0)

  def _to_ids(x):
    # bytes_split works on a vector of strings, so wrap the single string in a
    # length-1 vector and take the flat values of the resulting ragged tensor.
    s = tf.reshape(x, shape=[1])
    chars = tf.strings.bytes_split(s).values
    return table.lookup(chars)

  def _split_input_target(chunk):
    # Slice along the last axis so this works both for a single sequence
    # (rank 1, when batch_size is None) and a batch of sequences (rank 2).
    input_text = chunk[..., :-1]
    target_text = chunk[..., 1:]
    return (input_text, target_text)

  text_ds = (
      # Map ASCII chars to int64 indexes using the vocab
      text_ds.map(_to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH + 1
      .batch(SEQ_LENGTH + 1, drop_remainder=True))
  if shuffle:
    text_ds = text_ds.shuffle(SHUFFLE_BUFFER)
  if batch_size is not None:
    # Form minibatches; the remainder is kept as a smaller final batch.
    text_ds = text_ds.batch(batch_size, drop_remainder=False)
  # And finally split into (input, target) tuples, each of length SEQ_LENGTH.
  text_ds = text_ds.map(_split_input_target)

  # Shuffle and batch come before repeat so we shuffle and batch within each
  # epoch, but process complete epochs before repeating.
  text_ds = text_ds.repeat(num_epochs)

  # Note: .prefetch is an optimization which will begin preparing later elements
  # while current elements are being processed. It consumes more memory but
  # should save time. See documentation at:
  # https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch
  return text_ds.prefetch(tf.data.experimental.AUTOTUNE)
