import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class StarFree(task.GeneralizationTask):
  """Classifies strings over {A, B, C} by whether no B occurs after a C.

  Tokens are encoded as A=0, B=1, C=2. A string is positive (class 1) when
  every B precedes every C, i.e. no B appears anywhere after a C; otherwise
  it is negative (class 0).

  Examples:
    CB -> class 0
    AB, AC, ABC -> class 1

  Note the sampling is jittable so this task is fast.
  """

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and the expected class.

    All strings are first drawn uniformly from {A, B, C}. The first
    `batch_size // 2` rows are then rewritten into positive examples (presumably
    to balance the classes, since uniformly random strings rarely satisfy the
    constraint): row i gets a cut point int(length / reserve * i), every B at
    or after the cut becomes a C and every C before the cut becomes a B. The
    remaining rows stay uniformly random.

    Args:
      rng: PRNG key used to draw the random strings.
      batch_size: Number of strings in the batch (static under jit).
      length: Length of every string (static under jit).

    Returns:
      A dict with 'input', the one-hot strings of shape
      (batch_size, length, 3), and 'output', the one-hot labels of shape
      (batch_size, 2).
    """
    reserve = batch_size // 2
    strings = jrandom.randint(
        rng,
        shape=(batch_size, length),
        minval=0,
        maxval=3,
    )
    if reserve > 0:
      # Cut points are computed with Python floats so they match
      # int(length / reserve * i) exactly. Building them once as a constant
      # and broadcasting avoids unrolling a per-row loop of scatter ops into
      # the traced graph.
      cuts = jnp.array(
          [int(length / reserve * i) for i in range(reserve)], dtype=jnp.int32)
      after_cut = jnp.arange(length)[None, :] >= cuts[:, None]
      head = strings[:reserve]
      # B at/after the cut -> C. The second rewrite is restricted to
      # ~after_cut, so it never touches the Cs created by the first one.
      head = jnp.where((head == 1) & after_cut, 2, head)
      head = jnp.where((head == 2) & ~after_cut, 1, head)
      strings = strings.at[:reserve].set(head)
    one_hot_strings = jnn.one_hot(strings, num_classes=3)

    # A string is negative iff some B is preceded by a C. The cumulative sum
    # marks every position at or after the first C (a position cannot be both
    # B and C, so "at or after" equals "after" for Bs). Unlike an argmax over
    # the length axis, this also works when length == 0.
    seen_c = jnp.cumsum(strings == 2, axis=1) > 0
    labels = ~jnp.any(seen_c & (strings == 1), axis=1)
    # one_hot expects integer class indices.
    labels = jnn.one_hot(labels.astype(jnp.int32), num_classes=2)
    return {
        'input': one_hot_strings,
        'output': labels,
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models (one-hot size of A, B, C)."""
    return 3

  @property
  def output_size(self) -> int:
    """Returns the output size for the models (two classes)."""
    return 2