# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tokenization logic with external dependency on T5X tokenizer.

This test file should be kept minimal, as tests with this external dependency
are non-hermetic and take longer to run.
"""

from absl.testing import absltest

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import tokenization


class TokenizationWithT5XTokenizerTest(absltest.TestCase):
  """Checks tokenized lengths that are computed with the real T5X tokenizer.

  Because the actual tokenizer is used, assertions about token counts are
  expressed as loose ranges rather than exact values.
  """

  def test_get_input_output_length_compact(self):
    compact = tokenization.ExampleStringFormat.COMPACT
    primitive_example = cl.Example(request='jump twice', reply='JUMP JUMP')
    rule_example = cl.Example(
        request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE)
    example = cl.Example(
        context=cl.FrozenExampleSet.from_examples(
            [primitive_example, rule_example]),
        request='[x1 thrice] = [x1]',
        reply=cl.RuleReply.UNKNOWN)

    with self.subTest('compact_input_string'):
      # Already covered by "tokenization_test.py". It is repeated here so that
      # the expected compact input length checked below is easy to reason
      # about.
      actual_input = tokenization.get_input_string(example, compact)
      self.assertEqual(
          '[x1 thrice] = [x1]\n'
          'jump twice JUMP JUMP\n'
          '[x1 twice] = [x1] [x1]',
          actual_input)

    with self.subTest('compact_input_length_uses_t5x_tokenization'):
      # The bounds are deliberately loose, so that the test does not depend
      # too much on how fine-grained T5X's SentencePiece tokenizer is. A
      # length above 40 would hint that the longer STANDARD representation
      # was tokenized instead of the COMPACT one.
      input_length = tokenization.get_input_length(example, compact)
      self.assertBetween(input_length, 20, 40)

    with self.subTest('compact_output_string'):
      # Already covered by "tokenization_test.py". It is repeated here so that
      # the expected compact output length checked below is easy to reason
      # about.
      actual_output = tokenization.get_output_string(example, compact)
      self.assertEqual('Unknown (Reasoning: Monotonic)', actual_output)

    with self.subTest('compact_output_length_uses_t5x_tokenization'):
      # Loose bounds again. In this case the COMPACT output is actually
      # *longer* than the STANDARD one, because of the extra
      # "(Reasoning: ...)" suffix. A length below 6 would therefore hint that
      # the STANDARD representation was tokenized by mistake.
      output_length = tokenization.get_output_length(example, compact)
      self.assertBetween(output_length, 6, 18)

  def test_calling_get_tokenized_length_repeatedly_does_not_time_out(self):
    # Verifies that the T5X tokenizer is cached between calls. Without
    # caching, every get_input_length / get_output_length call would load a
    # fresh tokenizer, which is quite time consuming.
    num_iterations = 1000
    sum_of_lengths = sum(
        tokenization.get_tokenized_length(
            str(n), tokenization.ExampleStringFormat.COMPACT)
        for n in range(num_iterations))
    # What matters is that the calls above finish without timing out. This
    # assertion only confirms that they actually ran and produced lengths.
    self.assertGreaterEqual(sum_of_lengths, num_iterations)


if __name__ == '__main__':
  # Run the tests defined in this module through absltest's entry point.
  absltest.main()
