import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class StrictlyLocal(task.GeneralizationTask):
  """Binary classification of membership in the regular language (AB)*.

  Tokens are encoded as A = 0 and B = 1. A string belongs to the language
  (class 1) iff it alternates A, B, A, B, ... starting with A and ends with B;
  otherwise it is class 0.

  Examples:
    A, ABA -> class 0
    AB, ABAB -> class 1

  Note the sampling is jittable so this task is fast.
  """

  # Flipped (corrupted) tokens are placed within this many positions of the
  # end of each sampled string.
  _FLIP_WINDOW = 10

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and the expected class.

    Every sampled string has length `2 * length`, so the unmodified pattern
    ABAB... is a member of the language. The first `batch_size // 2` rows get
    exactly one token flipped, at a random position among the last
    `_FLIP_WINDOW` positions, which makes them non-members (class 0). The
    remaining rows are left as ABAB... (class 1).

    Args:
      rng: PRNG key used to pick the flip positions.
      batch_size: Number of strings in the batch (static under jit).
      length: Half the length of each sampled string (static under jit).

    Returns:
      A dict with 'input', the one-hot tokens of shape
      (batch_size, 2 * length, 2), and 'output', the one-hot labels of shape
      (batch_size, 2).
    """
    length *= 2
    # Canonical member string ABAB... as integer token ids (A=0, B=1).
    strings = jnp.zeros((batch_size, length), dtype=jnp.int32)
    strings = strings.at[:, 1::2].set(1)

    flip_positions = jrandom.randint(
        rng,
        shape=(batch_size,),
        minval=max(0, length - self._FLIP_WINDOW),
        maxval=length,
    )
    # Boolean mask selecting one position per row; only the first half of
    # the batch is corrupted, the second half stays a positive example.
    flip_mask = jnp.arange(length)[None, :] == flip_positions[:, None]
    flip_mask = flip_mask & (jnp.arange(batch_size) < batch_size // 2)[:, None]
    # Integer flip keeps tokens valid one-hot indices without relying on
    # implicit bool/float dtype promotion.
    strings = jnp.where(flip_mask, 1 - strings, strings)

    one_hot_strings = jnn.one_hot(strings, num_classes=2)

    # Member iff every even position is A (0) and every odd position is B (1).
    is_member = (jnp.all(strings[:, ::2] == 0, axis=1)
                 & jnp.all(strings[:, 1::2] == 1, axis=1))
    labels = jnn.one_hot(is_member.astype(jnp.int32), num_classes=2)
    return {
        'input': one_hot_strings,
        'output': labels,
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models."""
    return 2

  @property
  def output_size(self) -> int:
    """Returns the output size for the models."""
    return 2

