import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class LocalThreshold(task.GeneralizationTask):
  """Classifies whether a string over {A, B} contains exactly one B.

  Examples:
    A, BB -> class 0
    B, AB -> class 1

  Note the sampling is jittable so this task is fast.
  """

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and the expected class.

    Rows are filled in four consecutive groups so both classes are well
    represented:
      * batch_size // 6 rows get two adjacent Bs (class 0 when length >= 2),
      * batch_size // 6 rows get two Bs at independent random positions
        (class 0, or class 1 when both positions happen to coincide),
      * batch_size // 2 rows get a single B (class 1),
      * the remaining rows contain no B at all (class 0).
    Labels are always computed from the final strings, so they are exact.

    Args:
      rng: PRNG key used for sampling.
      batch_size: Number of strings in the batch (static under jit).
      length: Length of every string (static under jit).

    Returns:
      A dict with 'input' of shape (batch_size, length, 2), one-hot over
      {A: 0, B: 1}, and 'output' of shape (batch_size, 2), one-hot over the
      class {0: not exactly one B, 1: exactly one B}.
    """
    end_adjacent = batch_size // 6
    end_random_pair = end_adjacent + batch_size // 6
    end_single = end_random_pair + batch_size // 2

    # Independent keys for the two draws: reusing a JAX key makes the draws
    # correlated instead of independent.
    rng_first, rng_second = jrandom.split(rng)

    # Every row of the first three groups starts with one B at a random spot.
    first_b = jrandom.randint(
        rng_first, shape=(end_single,), minval=0, maxval=length)
    strings = jnn.one_hot(first_b, num_classes=length)

    # Adjacent group: add a second B right after the first one, or right
    # before it when the first B is the last symbol, so the row really gets
    # two Bs. For length == 1 the index is -1 and one_hot yields all zeros,
    # because there is no room for a second B.
    head = first_b[:end_adjacent]
    adjacent_b = jnp.where(head + 1 < length, head + 1, head - 1)
    strings = strings.at[:end_adjacent].add(
        jnn.one_hot(adjacent_b, num_classes=length))

    # Random-pair group: add a second B at an independent random position.
    # Clipping to 1 means a collision with the first B leaves a single B.
    second_b = jrandom.randint(
        rng_second,
        shape=(end_random_pair - end_adjacent,),
        minval=0,
        maxval=length,
    )
    pair_strings = jnp.minimum(
        strings[end_adjacent:end_random_pair]
        + jnn.one_hot(second_b, num_classes=length),
        1,
    )
    strings = strings.at[end_adjacent:end_random_pair].set(pair_strings)

    # Remaining rows are all A (no B).
    no_b = jnp.zeros((batch_size - end_single, length), dtype=strings.dtype)
    strings = jnp.concatenate([strings, no_b], axis=0)

    # strings holds exact 0/1 values, so one_hot's equality test is exact.
    one_hot_strings = jnn.one_hot(strings, num_classes=2)
    # Class 1 iff the string contains exactly one B.
    labels = (jnp.sum(strings, axis=1) == 1).astype(jnp.int32)
    labels = jnn.one_hot(labels, num_classes=2)
    return {
        'input': one_hot_strings,
        'output': labels,
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models (one-hot over {A, B})."""
    return 2

  @property
  def output_size(self) -> int:
    """Returns the output size for the models (two classes)."""
    return 2
