"""Compute the reverse of an input string."""

import functools
from typing import Mapping

import jax
import jax.numpy as jnp

from tasks import task
# NOTE(review): `duplicate_string` is used as the base-class module below but
# was never imported; the submodule path here is assumed -- confirm it.
from tasks.cs import duplicate_string


class ReverseString(duplicate_string.DuplicateString):
  """A task whose goal is to reverse a given string.

  Given an input string s_1 ... s_n over a finite symbol set S, the expected
  output is the same string read backwards, i.e. s_n ... s_1.

  Examples:
    011010 -> 010110
    123021 -> 120321

  In the paper, we use only binary strings (ie S = {0, 1}).
  Sampling is wrapped in `jax.jit`, so generating batches is fast.
  """

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> Mapping[str, jnp.ndarray]:
    """Samples a batch of strings paired with their reversed version.

    The inputs come from the parent class sampler; only the 'output' entry is
    overwritten here. `self`, `batch_size` and `length` are static arguments
    of the jitted function, so a new combination triggers a recompilation.

    Args:
      rng: Random key, forwarded unchanged to the parent sampler.
      batch_size: Number of strings in the batch.
      length: Length of each sampled string.

    Returns:
      The parent's batch mapping, with 'output' set to 'input' flipped along
      axis 1 (assumed to be the sequence axis -- TODO confirm against the
      parent class).
    """
    sampled = super().sample_batch(rng, batch_size, length)
    reversed_input = jnp.flip(sampled['input'], axis=1)
    # Replace the target the parent produced with the reversed input.
    sampled['output'] = reversed_input
    return sampled

  def output_length(self, input_length: int) -> int:
    """Returns the output length, which equals the given input length."""
    return input_length

