"""Bucket sort task for generalization."""

import functools
from typing import Mapping

import chex
import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from tasks import task


class BucketSort(task.GeneralizationTask):
  """A task whose goal is to sort tokens drawn from a fixed alphabet.

  The input string is composed of tokens from a fixed-size alphabet, i.e.,
  `{0, 1, ..., vocab_size - 1}`, and the goal is to return the same tokens
  arranged in non-decreasing order of token value.

  Examples:
    10204112  ->  00111224  (with `vocab_size = 5`)
    1110001   ->  0001111   (with `vocab_size = 2`)
  """

  def __init__(self, *args, vocab_size: int = 5, **kwargs) -> None:
    """Initializes the task.

    Args:
      *args: The args for the base task class.
      vocab_size: The size of the alphabet. Must be at least 1.
      **kwargs: The kwargs for the base task class.

    Raises:
      ValueError: If `vocab_size` is smaller than 1.
    """
    # Fail fast: with a non-positive vocabulary, sampling would not raise but
    # would silently produce one-hot encodings with an empty class axis.
    if vocab_size < 1:
      raise ValueError(
          f'`vocab_size` must be a positive integer, got {vocab_size}.')
    super().__init__(*args, **kwargs)
    self._vocab_size = vocab_size

  # `self`, `batch_size` and `length` are static arguments: they determine the
  # array shapes, so each new combination triggers a fresh compilation.
  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(
      self,
      rng: chex.PRNGKey,
      batch_size: int,
      length: int,
  ) -> Mapping[str, chex.Array]:
    """Returns a batch of random strings and their sorted counterparts.

    Args:
      rng: The PRNG key used to sample the input tokens.
      batch_size: The number of strings in the batch.
      length: The number of tokens in each string.

    Returns:
      A mapping with two one-hot encoded entries:
        'input': The sampled strings, with tokens drawn uniformly from
          `{0, ..., vocab_size - 1}` and `input_size` classes on the last axis.
        'output': The same strings with their tokens sorted in non-decreasing
          order of token value, with `output_size` classes on the last axis.
    """
    # `maxval` is exclusive, so tokens lie in `[0, vocab_size)`.
    strings = jrandom.randint(
        rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
    # Sort every string independently along its token axis.
    sorted_strings = jnp.sort(strings, axis=-1)

    return {
        'input': jnn.one_hot(strings, num_classes=self.input_size),
        'output': jnn.one_hot(sorted_strings, num_classes=self.output_size),
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models, i.e., the alphabet size."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the models, i.e., the alphabet size."""
    return self._vocab_size

  def output_length(self, input_length: int) -> int:
    """Returns the output length, which equals the given input length."""
    return input_length

