import functools

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class OddsFirst(task.GeneralizationTask):
  """A task with the goal of outputting a string's tokens at odd indices first.

  The input is a string s_1 ... s_n composed of symbols from a finite set S. The
  output is the same string, but where the values at odd (1-based) indexes have
  been put first, followed by the values at even indexes:
  s_1 s_3 s_5 ... s_2 s_4 s_6 ...

  Examples:
    00110101 -> 0100 0111
    110 -> 10 1

  In the paper, we use only binary strings (i.e. S = {0, 1}).
  Note that the sampling is jittable so this task is fast.
  """

  def __init__(self, vocab_size: int, *args, **kwargs):
    """Initializes the odds_first task.

    Args:
      vocab_size: The size of the alphabet.
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.
    """
    super().__init__(*args, **kwargs)

    self._vocab_size = vocab_size

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and their outputs.

    Args:
      rng: The JAX random key used to sample the strings.
      batch_size: Number of strings in the batch (static under jit).
      length: Length of each string (static under jit).

    Returns:
      A dict with "input", the one-hot encoded strings of shape
      (batch_size, length, vocab_size), and "output", the same one-hot tokens
      reordered so that the 1-based odd positions s_1, s_3, ... come first,
      followed by the 1-based even positions s_2, s_4, ...
    """
    strings = jrandom.randint(
        rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
    one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size)
    # The task is stated with 1-based positions: s_1, s_3, s_5, ... live at
    # 0-based indices 0, 2, 4, ... (slice `::2`) and must come first, e.g.
    # 110 -> 10 1. s_2, s_4, ... live at 0-based indices 1, 3, ... (`1::2`).
    odd_positions = one_hot_strings[:, ::2]
    even_positions = one_hot_strings[:, 1::2]
    output = jnp.concatenate([odd_positions, even_positions], axis=1)
    return {"input": one_hot_strings, "output": output}

  @property
  def input_size(self) -> int:
    """Returns the input size for the model (one one-hot slot per symbol)."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the model (one one-hot slot per symbol)."""
    return self._vocab_size

  def output_length(self, input_length: int) -> int:
    """Returns the output length for the model.

    The output is a permutation of the input, so it has the same length.
    """
    return input_length
