# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.cscan import conceptual_learning as cl
from conceptual_learning.cscan import tokenization


class TokenizationTest(parameterized.TestCase):

  def test_get_request_string_compact_left_unchanged(self):
    # In the COMPACT format the request string is expected to come back
    # verbatim, for rule requests and non-rule requests alike.
    compact = tokenization.ExampleStringFormat.COMPACT
    requests_by_subtest = (
        ('rule_request', '[x1 twice] = [x1] [x1]'),
        ('non_rule_request', 'jump twice'),
    )
    for subtest_name, request in requests_by_subtest:
      with self.subTest(subtest_name):
        actual = tokenization.get_request_string(request, compact)
        self.assertEqual(request, actual)

  @parameterized.named_parameters(
      ('with_structure_tokens', tokenization.ExampleStringFormat.STANDARD),
      ('no_structure_tokens',
       tokenization.ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS),
  )
  def test_get_request_string_standard(self, string_format):
    # The same expectations are applied to both STANDARD variants: whether or
    # not structure tokens are used makes no difference to the request string.
    cases = (
        # (subtest name, raw request, expected request string)
        ('rule_request_has_spaces_added_around_punctuation',
         '[x1 twice] = [x1] [x1]', '[ x1 twice ]   =   [ x1 ]   [ x1 ]'),
        ('non_rule_request_left_unchanged', 'jump twice', 'jump twice'),
    )
    for subtest_name, request, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_request_string(request, string_format)
        self.assertEqual(expected, actual)

  @parameterized.named_parameters(
      ('rule_true', '1', 'True'),
      ('rule_false', '0', 'False'),
      ('rule_unknown', '?', 'Unknown'),
      ('non_rule_left_unchanged', 'JUMP JUMP', 'JUMP JUMP'),
  )
  def test_get_reply_string_compact(self, reply, expected_result):
    # COMPACT is expected to spell the rule replies '1', '0' and '?' out as
    # words, while any other reply passes through untouched.
    compact = tokenization.ExampleStringFormat.COMPACT
    actual = tokenization.get_reply_string(reply, compact)
    self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      ('rule_true', '1', '1'),
      ('rule_false', '0', '0'),
      ('rule_unknown', '?', '?'),
      ('non_rule', 'JUMP JUMP', 'JUMP JUMP'),
  )
  def test_get_reply_string_standard_left_unchanged(self, reply,
                                                    expected_result):
    # Structure tokens have no bearing on the reply string, so both STANDARD
    # variants are checked against the very same expectation.
    formats = tokenization.ExampleStringFormat
    formats_by_subtest = (
        ('with_structure_tokens', formats.STANDARD),
        ('no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS),
    )
    for subtest_name, string_format in formats_by_subtest:
      with self.subTest(subtest_name):
        actual = tokenization.get_reply_string(reply, string_format)
        self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      ('D', cl.Qualifier.D, 'Defeasible'),
      ('M', cl.Qualifier.M, 'Monotonic'),
  )
  def test_get_qualifier_string_compact(self, qualifier, expected_result):
    # COMPACT is expected to render each qualifier as its full word.
    compact = tokenization.ExampleStringFormat.COMPACT
    actual = tokenization.get_qualifier_string(qualifier, compact)
    self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      ('D', cl.Qualifier.D, 'D'),
      ('M', cl.Qualifier.M, 'M'),
  )
  def test_get_qualifier_string_standard(self, qualifier, expected_result):
    # The qualifier string does not depend on structure tokens, hence one
    # expectation shared by both STANDARD variants.
    formats = tokenization.ExampleStringFormat
    formats_by_subtest = (
        ('with_structure_tokens', formats.STANDARD),
        ('no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS),
    )
    for subtest_name, string_format in formats_by_subtest:
      with self.subTest(subtest_name):
        actual = tokenization.get_qualifier_string(qualifier, string_format)
        self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      ('non_rule_request',
       cl.Example(request='jump twice', reply='JUMP JUMP'),
       'jump twice JUMP JUMP'),
      ('rule_request',
       cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE),
       '[x1 twice] = [x1] [x1]'),
  )
  def test_get_context_example_string_compact(self, example, expected_result):
    # Expected COMPACT rendering of a single context example: request followed
    # by reply for a non-rule request, and just the rule itself for a rule
    # request whose reply is TRUE.
    compact = tokenization.ExampleStringFormat.COMPACT
    actual = tokenization.get_context_example_string(example, compact)
    self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      ('non_rule_request', cl.Example(request='jump twice', reply='JUMP JUMP'),
       '<  {  }  ,  jump twice ,  JUMP JUMP ,  M >'),
      ('rule_request',
       cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE),
       '<  {  }  ,   [ x1 twice ]   =   [ x1 ]   [ x1 ]  ,  1 ,  M >'),
  )
  def test_get_context_example_string_standard(self, example, expected_result):
    # Expected STANDARD rendering of a single context example: an angle
    # bracketed, comma separated sequence of nested context, request, reply
    # and qualifier.
    standard = tokenization.ExampleStringFormat.STANDARD
    actual = tokenization.get_context_example_string(example, standard)
    self.assertEqual(expected_result, actual)

  @parameterized.named_parameters(
      # Note how the comma is omitted in each of the cases below.
      ('non_rule_request', cl.Example(request='jump twice', reply='JUMP JUMP'),
       '<  {  }    jump twice   JUMP JUMP   M >'),
      ('rule_request',
       cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE),
       '<  {  }     [ x1 twice ]   =   [ x1 ]   [ x1 ]    1   M >'),
  )
  def test_get_context_example_string_standard_no_structure_tokens(
      self, example, expected_result):
    # Same layout as the STANDARD rendering of a context example, minus the
    # separating commas.
    no_structure_tokens = (
        tokenization.ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS)
    actual = tokenization.get_context_example_string(example,
                                                     no_structure_tokens)
    self.assertEqual(expected_result, actual)

  def test_get_context_string(self):
    # Renders one two-example context (a non-rule example plus a rule example)
    # in each string format and compares against the expected text.
    formats = tokenization.ExampleStringFormat
    examples = [
        cl.Example(request='jump twice', reply='JUMP JUMP'),
        cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE)
    ]
    context = cl.FrozenExampleSet.from_examples(examples)

    cases = (
        # (subtest name, string format, expected context string)
        ('compact', formats.COMPACT,
         'jump twice JUMP JUMP\n'
         '[x1 twice] = [x1] [x1]'),
        ('standard', formats.STANDARD,
         '{  <  {  }  ,  jump twice ,  JUMP JUMP ,  M > \n'
         '  <  {  }  ,   [ x1 twice ]   =   [ x1 ]   [ x1 ]  ,  1 ,  M >  }'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         '  <  {  }    jump twice   JUMP JUMP   M > \n'
         '  <  {  }     [ x1 twice ]   =   [ x1 ]   [ x1 ]    1   M >  '),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_context_string(context, string_format)
        self.assertEqual(expected, actual)

  def test_get_nested_context_string(self):
    # Only the empty nested context is covered here; it is the sole case of
    # interest for the time being.
    formats = tokenization.ExampleStringFormat
    empty_context = cl.FrozenExampleSet()

    cases = (
        # (subtest name, string format, expected nested context string)
        ('compact', formats.COMPACT, ''),
        ('standard', formats.STANDARD, '{  }'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         '{  }'),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_nested_context_string(empty_context,
                                                        string_format)
        self.assertEqual(expected, actual)

  def test_get_input_string_from_context_and_request_strings(self):
    # Placeholder strings are enough here: the test only pins down how the
    # two pieces are ordered and joined in each format.
    formats = tokenization.ExampleStringFormat
    context_string, request_string = 'context', 'request'

    cases = (
        # (subtest name, string format, expected input string)
        ('compact', formats.COMPACT, 'request\ncontext'),
        ('standard', formats.STANDARD, 'request context'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         'request context'),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_input_string_from_context_and_request_strings(
            context_string, request_string, string_format)
        self.assertEqual(expected, actual)

  def test_get_output_string_from_reply_and_qualifier_strings(self):
    # Placeholder strings are enough here: the test only pins down how the
    # reply and the qualifier are ordered and joined in each format.
    formats = tokenization.ExampleStringFormat
    reply_string, qualifier_string = 'reply', 'qualifier'

    cases = (
        # (subtest name, string format, expected output string)
        ('compact', formats.COMPACT, 'reply (Reasoning: qualifier)'),
        ('standard', formats.STANDARD, 'qualifier reply'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         'qualifier reply'),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = (
            tokenization.get_output_string_from_reply_and_qualifier_strings(
                reply_string, qualifier_string, string_format))
        self.assertEqual(expected, actual)

  def test_get_input_string(self):
    # End-to-end check of the input string (request plus context) for an
    # example with a two-example context, in each format, followed by a check
    # of the token count reported for the STANDARD format.
    formats = tokenization.ExampleStringFormat
    context_examples = [
        cl.Example(request='jump twice', reply='JUMP JUMP'),
        cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE)
    ]
    example = cl.Example(
        context=cl.FrozenExampleSet.from_examples(context_examples),
        request='[x1 thrice] = [x1]',
        reply=cl.RuleReply.UNKNOWN)

    cases = (
        # (subtest name, string format, expected input string)
        ('compact', formats.COMPACT,
         '[x1 thrice] = [x1]\n'
         'jump twice JUMP JUMP\n'
         '[x1 twice] = [x1] [x1]'),
        ('standard', formats.STANDARD,
         '[ x1 thrice ]   =   [ x1 ] '
         '{  <  {  }  ,  jump twice ,  JUMP JUMP ,  M > \n'
         '  <  {  }  ,   [ x1 twice ]   =   [ x1 ]   [ x1 ]  ,  1 ,  M >  }'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         '[ x1 thrice ]   =   [ x1 ] '
         '  <  {  }    jump twice   JUMP JUMP   M > \n'
         '  <  {  }     [ x1 twice ]   =   [ x1 ]   [ x1 ]    1   M >  '),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_input_string(example, string_format)
        self.assertEqual(expected, actual)

    with self.subTest('standard_input_length_uses_whitespace_tokenization'):
      # Because punctuation marks are padded with spaces, every word and every
      # punctuation mark counts as a token of its own; the whitespace between
      # them does not count.
      input_length = tokenization.get_input_length(example, formats.STANDARD)
      self.assertEqual(42, input_length)

  def test_get_output_string(self):
    # End-to-end check of the output string (reply plus qualifier) for an
    # example whose reply is UNKNOWN, in each format, followed by a check of
    # the token count reported for the STANDARD format.
    formats = tokenization.ExampleStringFormat
    context_examples = [
        cl.Example(request='jump twice', reply='JUMP JUMP'),
        cl.Example(request='[x1 twice] = [x1] [x1]', reply=cl.RuleReply.TRUE)
    ]
    example = cl.Example(
        context=cl.FrozenExampleSet.from_examples(context_examples),
        request='[x1 thrice] = [x1]',
        reply=cl.RuleReply.UNKNOWN)

    cases = (
        # (subtest name, string format, expected output string)
        ('compact', formats.COMPACT, 'Unknown (Reasoning: Monotonic)'),
        ('standard', formats.STANDARD, 'M ?'),
        ('standard_no_structure_tokens', formats.STANDARD_NO_STRUCTURE_TOKENS,
         'M ?'),
    )
    for subtest_name, string_format, expected in cases:
      with self.subTest(subtest_name):
        actual = tokenization.get_output_string(example, string_format)
        self.assertEqual(expected, actual)

    with self.subTest('standard_output_length_uses_whitespace_tokenization'):
      output_length = tokenization.get_output_length(example, formats.STANDARD)
      self.assertEqual(2, output_length)


# Hand control to the absltest runner when this file is executed directly
# (as opposed to being imported as a module).
if __name__ == '__main__':
  absltest.main()
