"""Utility functions.
"""

import tensorflow.compat.v2 as tf

# Word-count bins used to bucket text lengths: the powers of two 1, 2, ..., 512.
DEFAULT_SIZE_BINS = tuple(2 ** exponent for exponent in range(10))


def closest_size_bucket(text, size_bins):
  """Maps the word count of `text` onto the nearest entry of `size_bins`.

  Args:
    text: text whose words are counted with `tf.strings.split` using its
      default separator (presumably a scalar string -- TODO confirm).
    size_bins: integer bin sizes.
  Returns:
    The element of `size_bins` at the index where `|bin - word_count|` is
    smallest, as chosen by `tf.math.argmin` (which does not document how it
    breaks ties).
  """
  word_count = tf.size(tf.strings.split(text))
  distance_to_bins = tf.math.abs(size_bins - word_count)
  nearest_index = tf.math.argmin(distance_to_bins)
  return tf.gather(size_bins, nearest_index)


def sample_random_subsequence(sequence, max_length, seeds):
  """Returns a randomly selected contiguous sub-sequence of `sequence`.

  The sub-sequence length is drawn uniformly from `[1, max_length)` and the
  start offset uniformly from the positions where a slice of that length
  fits. Both draws use stateless RNG ops, so the result is determined by
  `seeds`.

  Args:
    sequence: sequence to sample from. It must support `len()` and slicing
      with scalar integer tensors. Presumably a 1-D `tf.string` tensor, since
      the empty fallback below is `tf.string` -- TODO confirm with callers.
    max_length: exclusive upper bound on the sub-sequence length. Values
      above `len(sequence) + 1` are clamped to `len(sequence) + 1`.
    seeds: pair of stateless RNG seeds. `seeds[0]` drives the length and
      `seeds[1]` the start offset.
  Returns:
    A slice of `sequence`. Returns an empty `tf.string` tensor when no
    non-empty sub-sequence can be drawn (`max_length <= 1` or an empty
    `sequence`).
  """
  seq_len = len(sequence)
  # A slice holds at most `seq_len` elements, so the exclusive bound on its
  # length is at most `seq_len + 1`. Without this clamp, a larger
  # `max_length` could draw a length that does not fit. That would leave an
  # empty range for the start offset, and `stateless_uniform` would fail.
  # For max_length <= seq_len + 1 the bound (and hence the draw) is unchanged.
  length_bound = tf.minimum(max_length, seq_len + 1)

  def random_subsequence():
    sample_length = tf.random.stateless_uniform(
        [],
        minval=1,
        maxval=length_bound,
        dtype=tf.dtypes.int32,
        seed=seeds[0],
        name='rand_sample_length')
    # `maxval` is exclusive, so every start in [0, seq_len - sample_length]
    # can be drawn and the slice always fits.
    sample_start = tf.random.stateless_uniform(
        [],
        minval=0,
        maxval=seq_len - sample_length + 1,
        dtype=tf.dtypes.int32,
        seed=seeds[1],
        name='rand_sample_start')

    return sequence[sample_start:sample_start + sample_length]

  # With a bound of 1 or less the length range [1, length_bound) is empty, so
  # there is no non-empty sub-sequence to draw; return an empty one instead.
  return tf.cond(tf.less_equal(length_bound, 1),
                 lambda: tf.constant([], dtype=tf.string),
                 random_subsequence)


def add_example_postprocess_fn(
    output,
    example,
    is_target=False):
  """Postprocessor that keeps example fields alongside the target.

  Targets carry the example's `size_bucket` and `sample` entries so that
  metrics can use them later; predictions are only wrapped.

  Args:
    output: the postprocessed target or prediction value.
    example: mapping read with `.get` for `'size_bucket'` and `'sample'`
      (only accessed when `is_target` is true).
    is_target: whether `output` is a target rather than a prediction.
  Returns:
    `{'target', 'size_bucket', 'sample'}` dict for targets, otherwise
    `{'prediction': output}`.
  """
  if not is_target:
    return {'prediction': output}

  return {
      'target': output,
      'size_bucket': example.get('size_bucket'),
      'sample': example.get('sample'),
  }


def wrap_metric(metric_fn):
  """Wraps a `metric_fn` to be used with `add_example_postprocess_fn`.

  `add_example_postprocess_fn` turns each target into a dict with a
  `'target'` entry and each prediction into a dict with a `'prediction'`
  entry. The returned function unwraps those dicts, so `metric_fn` receives
  plain lists of targets and predictions.

  Args:
    metric_fn: callable taking `(targets, predictions)` lists.
  Returns:
    A callable with the same `(targets, predictions)` signature that accepts
    the dict-wrapped values and returns whatever `metric_fn` returns.
  """
  def _wrapped(targets, predictions):
    return metric_fn(
        [target['target'] for target in targets],
        [prediction['prediction'] for prediction in predictions])

  return _wrapped


def length_accuracy(targets, predictions, size_bins=None):
  """Metric that outputs fraction of generations with correct length.

  Each prediction's word count is mapped to the nearest bin, mirroring
  `closest_size_bucket`, but in plain Python. Evaluating the TF version
  through a `tf.compat.v1.Session` leaked the never-closed session. It also
  fails under eager execution, where eager tensors do not support `.eval()`,
  and in graph mode it adds fresh ops to the graph for every prediction.

  Args:
    targets: list of target dicts (as produced by
      `add_example_postprocess_fn` with `is_target=True`); the
      `"size_bucket"` entry is read from each.
    predictions: list of prediction dicts with a `"prediction"` entry, or a
      list of plain strings/bytes.
    size_bins: optional bin sizes; defaults to `DEFAULT_SIZE_BINS`.
  Returns:
    dict with `length_accuracy`: Accuracy of generating text whose length
      matches the input length bin size.
  Raises:
    ZeroDivisionError: if `targets` is empty.
  """
  if size_bins is None:
    size_bins = DEFAULT_SIZE_BINS

  def _predicted_bucket(text):
    # Count words on UTF-8 bytes: `bytes.split()` splits on runs of ASCII
    # whitespace, matching `tf.strings.split` with its default separator.
    if isinstance(text, str):
      text = text.encode("utf-8")
    num_words = len(text.split())
    # `min` returns the first of several equally close bins (lowest index).
    return min(size_bins, key=lambda size: abs(size - num_words))

  count_correct = 0
  for example, pred in zip(targets, predictions):
    text = pred["prediction"] if isinstance(pred, dict) else pred
    count_correct += (example["size_bucket"] == _predicted_bucket(text))

  return {"length_accuracy": count_correct / len(targets)}


def sample_sequence_accuracy(targets, predictions):
  """Metric that outputs fraction of generations containing the sample sequence.

  Args:
    targets: list of target dicts (as produced by
      `add_example_postprocess_fn` with `is_target=True`); the `"sample"`
      entry is read from each.
    predictions: list of prediction dicts with a `"prediction"` entry, or a
      list of plain strings.
  Returns:
    dict with `sample_sequence_accuracy`: Accuracy of generating text that
      contains the sample sequence.
  Raises:
    ZeroDivisionError: if `targets` is empty.
  """
  count_correct = 0
  for example, pred in zip(targets, predictions):
    # Unwrap `{'prediction': text}`: `sample in pred` on the dict itself
    # would test its keys rather than the generated text.
    text = pred["prediction"] if isinstance(pred, dict) else pred
    # The postprocessor stores `example.get('sample')`, which can be None;
    # treat a missing sample as the empty sequence (always contained).
    sample = example.get("sample")
    if sample is None:
      sample = ""
    # NOTE(review): assumes `sample` and `text` share a type (both str or
    # both bytes) -- TODO confirm against the preprocessor output.
    count_correct += (sample in text)
  return {"sample_sequence_accuracy": count_correct / len(targets)}
