"""Odds first task for generalization."""

import functools
from typing import Mapping

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

from tasks import task


class OddsFirst(task.GeneralizationTask):
  """A task whose goal is to output the tokens at odd indices of a string first.

  The input is a string s_1 ... s_n composed of symbols from a finite set S. The
  output is the same string, but where the values at odd indexes (1-based) have
  been put first, followed by the values at even indexes:
  s_1 s_3 s_5 ... s_2 s_4 s_6 ...

  Examples:
    00110101 -> 0100 0111
    110 -> 10 1

  In the paper, we use only binary strings (ie S = {0, 1}).
  Note that the sampling is jittable so this task is fast.
  """

  def __init__(self, vocab_size: int, *args, **kwargs):
    """Initializes the odds_first task.

    Args:
      vocab_size: The size of the alphabet.
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.
    """
    super().__init__(*args, **kwargs)

    self._vocab_size = vocab_size

  # `self`, `batch_size` and `length` are static so that the output shapes are
  # known at trace time; only `rng` is traced.
  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> Mapping[str, jnp.ndarray]:
    """Returns a batch of strings and their outputs.

    Args:
      rng: The JAX PRNG key used to sample the input strings.
      batch_size: Number of strings in the batch.
      length: Length of every string in the batch.

    Returns:
      A mapping with:
        "input": one-hot encoded strings of shape
          (batch_size, length, vocab_size).
        "output": the same one-hot tokens reordered so that the 1-based odd
          positions (s_1, s_3, ...) come first, followed by the 1-based even
          positions (s_2, s_4, ...); same shape as "input".
    """
    strings = jrandom.randint(
        rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
    one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size)
    # The docstring uses 1-based positions: s_1, s_3, s_5, ... live at 0-based
    # indices 0, 2, 4, ... (the `::2` slice) and must come first; the 1-based
    # even positions are the `1::2` slice and come second.
    odd_positions = one_hot_strings[:, ::2]
    even_positions = one_hot_strings[:, 1::2]
    output = jnp.concatenate([odd_positions, even_positions], axis=1)
    return {"input": one_hot_strings, "output": output}

  @property
  def input_size(self) -> int:
    """Returns the input size for the model."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the model."""
    return self._vocab_size

  def output_length(self, input_length: int) -> int:
    """Returns the output length for the model.

    The output is a permutation of the input tokens, so it has the same length.
    """
    return input_length

