import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task

class PiecewiseTestable(task.GeneralizationTask):
  """Subsequence task: do the symbols 0, 1, ..., k-1 appear in order?

  With k = `sub_length` and alphabet A = {0, ..., vocab_size - 1}, a string is
  in class 1 iff it matches A*0A*1A*...A*(k-1)A*, i.e. 0, 1, ..., k-1 occur as
  a (not necessarily contiguous) subsequence. For k = 3, writing a, b, c for
  0, 1, 2, this is the language A*aA*bA*cA*.

  Examples (k = 2, a = 0, b = 1, c = 2):
    ba, bbca -> class 0
    ab, cacb -> class 1

  Note the sampling is jittable so this task is fast.
  """

  def __init__(self, vocab_size: int, sub_length: int, *args, **kwargs):
    """Initializes the piecewise-testable (subsequence) task.

    Args:
      vocab_size: The size of the alphabet; symbols are 0, ..., vocab_size-1.
      sub_length: Length k of the target subsequence 0, 1, ..., k-1.
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.
    """
    super().__init__(*args, **kwargs)

    self._vocab_size = vocab_size
    self._sub_length = sub_length

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and the expected class.

    Strings are drawn uniformly from the alphabet. Rows are then edited in
    four slices to produce near-boundary examples for the pattern prefix
    (0, 1). Edited positions are overwritten with the filler symbol 2. With
    B = batch_size, r1 = B // 4, r2 = B // 2 and r3 = r1 + r2:
      * rows [0, r1): every 1 after the first 0 except the last such 1 is
        replaced, leaving at most one 1 after the first 0;
      * rows [r1, r2): every 0 before the last 1 except the first such 0 is
        replaced, leaving at most one 0 before the last 1;
      * rows [r2, r3): every 1 after the first 0 is replaced, so no 0 is ever
        followed by a 1;
      * rows [r3, B): every 0 before the last 1 is replaced, so no 0 is ever
        followed by a 1.
    Labels are always recomputed from the final strings, so they are exact
    regardless of these edits.

    NOTE(review): if vocab_size <= 2 the filler 2 lies outside the one-hot
    range and is encoded as an all-zero vector; if sub_length >= 3 the filler
    is itself a pattern symbol. Confirm both are intended.

    Args:
      rng: PRNG key used to draw the strings.
      batch_size: Number of strings B (static under jit).
      length: Length of every string (static under jit).

    Returns:
      A dict with 'input', the one-hot strings of shape
      (B, length, vocab_size), and 'output', the one-hot labels of shape
      (B, 2) where index 1 means the pattern is present.
    """
    reserve1 = batch_size // 4
    reserve2 = batch_size // 2
    reserve3 = reserve1 + reserve2
    strings = jrandom.randint(
        rng,
        shape=(batch_size, length),
        minval=0,
        maxval=self._vocab_size,
    )

    # Positions at or after the first 0 in each row.
    after_first_a = jnp.cumsum(strings == 0, axis=1).astype(bool)
    masks_b = (strings == 1) & after_first_a
    # One-hot of the last 1 that follows the first 0. When there is none this
    # marks position length-1, which is harmless since masks_b is all False.
    last_b = jnn.one_hot(
        length - masks_b[:, ::-1].argmax(axis=1) - 1, length).astype(bool)
    masks_b = masks_b.at[:reserve1].set((masks_b & ~last_b)[:reserve1])
    masks_b = masks_b.at[reserve1:reserve2].set(False)
    masks_b = masks_b.at[reserve3:].set(False)

    # Positions at or before the last 1 in each row.
    before_last_b = jnp.cumsum(
        (strings == 1)[:, ::-1], axis=1)[:, ::-1].astype(bool)
    masks_a = (strings == 0) & before_last_b
    # One-hot of the first 0 preceding the last 1. When there is none this
    # marks position 0, which is harmless since masks_a is all False.
    first_a = jnn.one_hot(masks_a.argmax(axis=1), length).astype(bool)
    masks_a = masks_a.at[reserve1:reserve2].set(
        (masks_a & ~first_a)[reserve1:reserve2])
    masks_a = masks_a.at[:reserve1].set(False)
    masks_a = masks_a.at[reserve2:reserve3].set(False)

    # Both masks write the same filler, so a single combined write suffices.
    strings = jnp.where(masks_a | masks_b, 2, strings)
    one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size)

    positions = jnp.arange(length)

    def contains_pattern(string):
      """Greedily checks that 0, ..., sub_length-1 is a subsequence."""

      def advance(symbol, mark):
        # Earliest occurrence of `symbol` strictly after `mark`, or `length`
        # when there is none; `length` is a sentinel no later step can pass.
        candidates = jnp.where((string == symbol) & (positions > mark),
                               positions, length)
        return candidates.min()

      return jax.lax.fori_loop(0, self._sub_length, advance, -1) < length

    labels = jax.vmap(contains_pattern)(strings).astype(jnp.int32)
    return {
        'input': one_hot_strings,
        'output': jnn.one_hot(labels, num_classes=2),
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the models."""
    return 2