# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for data_utils."""

import tensorflow as tf

from mixed_fl.experiments.next_char_prediction import data_utils


def _get_text_sequences():
  """Returns a string tensor holding one sample text sequence.

  The tensor is built from a single-element Python list, so its leading
  dimension has size 1.
  """
  sample_text = (
      'A good plan, violently executed now, is better than a perfect plan '
      'next week. Never tell people how to do things. Tell them what to do '
      'and they will surprise you with their ingenuity.')
  return tf.constant([sample_text])


def _get_raw_dataset():
  """Builds a `tf.data.Dataset` with one element per sample text sequence.

  `from_tensor_slices` slices along the leading dimension of the tensor
  returned by `_get_text_sequences`, so each element is a scalar string.
  """
  text_sequences = _get_text_sequences()
  return tf.data.Dataset.from_tensor_slices(text_sequences)


class DataUtilsTest(tf.test.TestCase):
  """Tests for `data_utils.preprocess_text_dataset`."""

  def test_preprocess_text_dataset_outputs_expected_shapes_and_types(self):
    dataset = data_utils.preprocess_text_dataset(_get_raw_dataset())
    # The single sample text should yield exactly one processed element.
    self.assertLen(list(dataset.as_numpy_iterator()), 1)

    expected_shape = [1, data_utils.SEQ_LENGTH]
    inputs, targets = next(iter(dataset))
    # Input and target id sequences must both be int64 tensors of shape
    # [1, SEQ_LENGTH]; the input is checked before the target.
    for sequence in (inputs, targets):
      self.assertDTypeEqual(sequence, tf.int64)
      self.assertEqual(sequence.shape, expected_shape)

  def test_preprocess_text_dataset_output_sequences_match(self):
    dataset = data_utils.preprocess_text_dataset(_get_raw_dataset())

    inputs, targets = next(iter(dataset))
    # Take the first (and only) row of each batch as a plain Python list.
    input_ids = inputs.numpy()[0].tolist()
    target_ids = targets.numpy()[0].tolist()
    # The target id at position i should equal the input id at position
    # i + 1, i.e. dropping the first input id and the last target id leaves
    # identical sequences.
    self.assertSequenceEqual(input_ids[1:], target_ids[:-1])


if __name__ == '__main__':
  tf.test.main()
