# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for generating and tokenizing cSCAN Example string representations.

Note that, in principle, each solution is free to choose its own tokenizer and
to perform arbitrary manipulations of the example format in its input pipeline.
In this library, we implement several variations on the logic for generating
and tokenizing cSCAN Example string representations, which are used in the
baselines that we are publishing together with the cSCAN benchmark.

Other solution implementations are free to either reuse the logic implemented
here or to implement their own custom logic.
"""

import enum
import functools
import re
import seqio
import t5.data
import tensorflow as tf

from conceptual_learning.cscan import conceptual_learning as cl


@enum.unique
class ExampleStringFormat(str, enum.Enum):
  """Format in which examples are represented as strings.

  Every member also subclasses `str` and has a value identical to its name, so
  a member compares equal to the plain string of the same name (for example,
  `ExampleStringFormat.COMPACT == 'COMPACT'`).

  Members are intentionally left without type annotations: static type
  checkers expect enum members to be unannotated and infer their types from
  the assigned values.
  """
  # COMPACT format is currently used for the T5X baselines. It is considerably
  # more compact than STANDARD, with most structural punctuation removed, and
  # reads a little more like natural language. It is designed for use with T5X's
  # default SentencePieceVocabulary tokenizer.
  COMPACT = 'COMPACT'
  # STANDARD format is based on the standard string representation defined in
  # the conceptual_learning library, with just some minor adjustments around
  # punctuation to work with the whitespace tokenizer.
  STANDARD = 'STANDARD'
  # STANDARD_NO_STRUCTURE_TOKENS is the same as STANDARD format, but with commas
  # and outer curly braces omitted. This slightly more compact format was
  # motivated by the etc_relative architecture, in which structural information
  # is captured by relational attention with global tokens, rather than via
  # punctuation.
  STANDARD_NO_STRUCTURE_TOKENS = 'STANDARD_NO_STRUCTURE_TOKENS'


def _format_punctuation(string, use_structure_tokens):
  """Pads punctuation with spaces so the result can be split on whitespace.

  Args:
    string: Text to reformat.
    use_structure_tokens: If falsy, all commas are removed from the result.

  Returns:
    The stripped text in which each of the characters `{}<>,[]=()+?'` is
    surrounded by single spaces, except that the arrow '->' is kept together
    as one token.
  """
  padded = re.sub(r"([{}<>,\[\]=\(\)+?'])", r' \1 ', string)
  padded = padded.strip()
  # Padding '>' splits every arrow into '- >'; glue it back into one token.
  with_arrows = re.sub(r'(- >)', '->', padded)
  if use_structure_tokens:
    return with_arrows
  # Structure tokens are not wanted: drop the commas.
  return re.sub('([,])', '', with_arrows)


@functools.lru_cache()
def _get_t5x_tokenizer():
  """Returns the default T5 vocabulary, memoized after the first call.

  This is the tokenizer used for the COMPACT representation.
  """
  # Building the vocabulary takes about a second per call, so the lru_cache
  # decorator above ensures that this cost is only paid once.
  default_vocabulary = t5.data.get_default_vocabulary()
  return default_vocabulary


def get_request_string(request, string_format):
  """Returns the string representation for the given request.

  Args:
    request: The request to represent.
    string_format: Format in which to generate the string representation.

  Returns:
    The request unchanged for COMPACT format; otherwise the request with its
    punctuation formatted for whitespace tokenization (and with commas removed
    for STANDARD_NO_STRUCTURE_TOKENS).

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    # COMPACT uses the request verbatim.
    return request

  if string_format == ExampleStringFormat.STANDARD:
    keep_structure_tokens = True
  elif string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    keep_structure_tokens = False
  else:
    raise ValueError(f'Unknown example string format: {string_format}')
  return _format_punctuation(request, keep_structure_tokens)


def get_reply_string(reply, string_format):
  """Returns the string representation for the given reply.

  Args:
    reply: The reply to represent.
    string_format: Format in which to generate the string representation.

  Returns:
    For COMPACT format, 'True', 'False' or 'Unknown' when the reply equals the
    corresponding `cl.RuleReply` value, and the reply unchanged otherwise. For
    the STANDARD formats, the reply with its punctuation formatted for
    whitespace tokenization.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    compact_rule_replies = (
        (cl.RuleReply.TRUE, 'True'),
        (cl.RuleReply.FALSE, 'False'),
        (cl.RuleReply.UNKNOWN, 'Unknown'),
    )
    for rule_reply, compact_text in compact_rule_replies:
      if reply == rule_reply:
        return compact_text
    # Anything that is not one of the rule replies is passed through as is.
    return reply

  if string_format == ExampleStringFormat.STANDARD:
    keep_structure_tokens = True
  elif string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    keep_structure_tokens = False
  else:
    raise ValueError(f'Unknown example string format: {string_format}')
  return _format_punctuation(reply, keep_structure_tokens)


def get_qualifier_string(qualifier, string_format):
  """Returns the string representation for the given qualifier.

  Args:
    qualifier: The qualifier to represent.
    string_format: Format in which to generate the string representation.

  Returns:
    For COMPACT format, 'Defeasible' if the qualifier equals `cl.Qualifier.D`
    and 'Monotonic' for any other qualifier. For the STANDARD formats, the
    qualifier with its punctuation formatted for whitespace tokenization.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    return 'Defeasible' if qualifier == cl.Qualifier.D else 'Monotonic'

  if string_format == ExampleStringFormat.STANDARD:
    keep_structure_tokens = True
  elif string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    keep_structure_tokens = False
  else:
    raise ValueError(f'Unknown example string format: {string_format}')
  return _format_punctuation(qualifier, keep_structure_tokens)


def get_context_example_string(context_example, string_format):
  """Returns the string representation for a given context example.

  Args:
    context_example: The context example to represent.
    string_format: Format in which to generate the string representation.

  Returns:
    For COMPACT format, "<request> <reply>" when the example's request type is
    `cl.RequestType.NON_RULE`, and just the request otherwise. For the STANDARD
    formats, `str(context_example)` with its punctuation formatted for
    whitespace tokenization.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    is_question = (
        context_example.get_request_type() == cl.RequestType.NON_RULE)
    if not is_question:
      # The request asserts a rule, so there is no reply to show.
      return f'{context_example.request}'
    # The request and reply already read as a question followed by its answer.
    return f'{context_example.request} {context_example.reply}'

  if string_format == ExampleStringFormat.STANDARD:
    return _format_punctuation(str(context_example), True)
  if string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    return _format_punctuation(str(context_example), False)
  raise ValueError(f'Unknown example string format: {string_format}')


def get_context_string(context, string_format):
  """Returns the string representation for a given top-level example context.

  Args:
    context: The context to represent. It is iterated over for COMPACT format
      and converted via its `to_string()` method for the STANDARD formats.
    string_format: Format in which to generate the string representation.

  Returns:
    For COMPACT format, the representations of the individual context examples,
    one per line. For STANDARD format, the punctuation-formatted context
    string. For STANDARD_NO_STRUCTURE_TOKENS format, the same but without
    commas and without the first and last character.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    example_strings = [
        get_context_example_string(example, string_format)
        for example in context
    ]
    return '\n'.join(example_strings)

  if string_format == ExampleStringFormat.STANDARD:
    return _format_punctuation(context.to_string(), True)

  if string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    with_brackets = _format_punctuation(context.to_string(), False)
    # Drop the first and last character, i.e., the surrounding brackets.
    return with_brackets[1:-1]

  raise ValueError(f'Unknown example string format: {string_format}')


def get_nested_context_string(context, string_format):
  """Returns the string representation for a given context example context.

  Args:
    context: The nested context to represent. Only used for the STANDARD
      formats, where it is converted via its `to_string()` method.
    string_format: Format in which to generate the string representation.

  Returns:
    The empty string for COMPACT format; otherwise the punctuation-formatted
    context string (with commas removed for STANDARD_NO_STRUCTURE_TOKENS).

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    # Nested contexts are left out entirely in COMPACT format. This is not a
    # problem for cSCAN, as nested contexts are currently always empty.
    return ''

  if string_format == ExampleStringFormat.STANDARD:
    keep_structure_tokens = True
  elif string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS:
    # In contrast to top-level contexts, the surrounding brackets are kept
    # here; only the commas are dropped.
    keep_structure_tokens = False
  else:
    raise ValueError(f'Unknown example string format: {string_format}')
  return _format_punctuation(context.to_string(), keep_structure_tokens)


def get_input_tensor_from_context_and_request_tensors(
    context_string, request_string,
    string_format):
  """Combines context and request string tensors into an input string tensor.

  Args:
    context_string: Context string represented as a tensor. Assumed to have been
      generated in the same format as `string_format`.
    request_string: Request string represented as a tensor. Assumed to have been
      generated in the same format as `string_format`.
    string_format: Format in which to generate the input string representation.

  Returns:
    Input string in the specified format, represented as a tensor. In every
    format the request comes first; it is separated from the context by a
    newline for COMPACT format and by a single space for the STANDARD formats.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    return tf.strings.join([request_string, '\n', context_string])

  is_standard_variant = (
      string_format == ExampleStringFormat.STANDARD or
      string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS)
  if not is_standard_variant:
    raise ValueError(f'Unknown example string format: {string_format}')
  return tf.strings.join([request_string, context_string], separator=' ')


def get_output_tensor_from_reply_and_qualifier_tensors(
    reply_string, qualifier_string,
    string_format):
  """Combines reply and qualifier string tensors into an output string tensor.

  Args:
    reply_string: Reply string represented as a tensor. Assumed to have been
      generated in the same format as `string_format`.
    qualifier_string: Qualifier string represented as a tensor. Assumed to have
      been generated in the same format as `string_format`.
    string_format: Format in which to generate the output string representation.

  Returns:
    Output string in the specified format, represented as a tensor. For COMPACT
    format this is "<reply> (Reasoning: <qualifier>)"; for the STANDARD formats
    it is "<qualifier> <reply>".

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    compact_parts = [reply_string, ' (Reasoning: ', qualifier_string, ')']
    return tf.strings.join(compact_parts)

  is_standard_variant = (
      string_format == ExampleStringFormat.STANDARD or
      string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS)
  if not is_standard_variant:
    raise ValueError(f'Unknown example string format: {string_format}')
  # Note that the qualifier precedes the reply in the STANDARD formats.
  return tf.strings.join([qualifier_string, reply_string], separator=' ')


def get_input_string_from_context_and_request_strings(
    context_string, request_string,
    string_format):
  """Returns an input string built from the context and request strings.

  This wraps both strings in constant tensors, delegates to the tensor-based
  implementation, and decodes the resulting tensor value as UTF-8.

  Args:
    context_string: Context string, generated in the format `string_format`.
    request_string: Request string, generated in the format `string_format`.
    string_format: Format in which to generate the input string representation.

  Returns:
    Input string in the specified format.
  """
  context_tensor = tf.constant(context_string)
  request_tensor = tf.constant(request_string)
  input_tensor = get_input_tensor_from_context_and_request_tensors(
      context_tensor, request_tensor, string_format)
  input_bytes = input_tensor.numpy()
  return input_bytes.decode('utf-8')


def get_output_string_from_reply_and_qualifier_strings(
    reply_string, qualifier_string,
    string_format):
  """Returns an output string built from the reply and qualifier strings.

  This wraps both strings in constant tensors, delegates to the tensor-based
  implementation, and decodes the resulting tensor value as UTF-8.

  Args:
    reply_string: Reply string, generated in the format `string_format`.
    qualifier_string: Qualifier string, generated in the format `string_format`.
    string_format: Format in which to generate the output string representation.

  Returns:
    Output string in the specified format.
  """
  reply_tensor = tf.constant(reply_string)
  qualifier_tensor = tf.constant(qualifier_string)
  output_tensor = get_output_tensor_from_reply_and_qualifier_tensors(
      reply_tensor, qualifier_tensor, string_format)
  output_bytes = output_tensor.numpy()
  return output_bytes.decode('utf-8')


def get_input_string(example, string_format):
  """Returns the input string for a given top-level example.

  Args:
    example: Example whose `context` and `request` are to be represented.
    string_format: Format in which to generate the input string representation.

  Returns:
    Input string combining the example's context and request in the given
    format.
  """
  return get_input_string_from_context_and_request_strings(
      get_context_string(example.context, string_format),
      get_request_string(example.request, string_format),
      string_format)


def get_output_string(example, string_format):
  """Returns the output string for a given top-level example.

  Args:
    example: Example whose `reply` and `qualifier` are to be represented.
    string_format: Format in which to generate the output string representation.

  Returns:
    Output string combining the example's reply and qualifier in the given
    format.
  """
  return get_output_string_from_reply_and_qualifier_strings(
      get_reply_string(example.reply, string_format),
      get_qualifier_string(example.qualifier, string_format),
      string_format)


def get_tokenized_length(string, string_format):
  """Returns the number of tokens in the given string.

  Args:
    string: The string to tokenize. Assumed to have been generated in the
      format `string_format`.
    string_format: Format that determines which tokenizer is applied.

  Returns:
    The number of tokens produced by the T5X tokenizer for COMPACT format, or
    the number of whitespace-separated tokens for the STANDARD formats.

  Raises:
    ValueError: If `string_format` is not a known format.
  """
  if string_format == ExampleStringFormat.COMPACT:
    # COMPACT representation uses T5X's default SentencePiece tokenization.
    token_ids = _get_t5x_tokenizer().encode(string)
    return len(token_ids)

  is_standard_variant = (
      string_format == ExampleStringFormat.STANDARD or
      string_format == ExampleStringFormat.STANDARD_NO_STRUCTURE_TOKENS)
  if not is_standard_variant:
    raise ValueError(f'Unknown example string format: {string_format}')

  # STANDARD representation uses whitespace tokenization. We are assuming that
  # all relevant preprocessing of the string for whitespace tokenization
  # (e.g., adding extra spaces around punctuation marks) has already been
  # performed. If the string was generated by one of the "get_*_string"
  # functions of this library, then this is guaranteed to be the case.
  whitespace_tokens = string.split()
  return len(whitespace_tokens)


def get_input_length(example, string_format):
  """Returns the number of input string tokens for a top-level example.

  Args:
    example: Example whose input string is to be measured.
    string_format: Format used both for generating and for tokenizing the
      input string.

  Returns:
    The token count of the example's input string.
  """
  return get_tokenized_length(
      get_input_string(example, string_format), string_format)


def get_output_length(example, string_format):
  """Returns the number of output string tokens for a top-level example.

  Args:
    example: Example whose output string is to be measured.
    string_format: Format used both for generating and for tokenizing the
      output string.

  Returns:
    The token count of the example's output string.
  """
  return get_tokenized_length(
      get_output_string(example, string_format), string_format)
