import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task

class BoundedDyck(task.GeneralizationTask):
  """Recognizes balanced bracket strings whose nesting depth never exceeds 2.

  Each sample is a sequence of `2 * length` tokens over a two-symbol alphabet:
  token 0 is an opening bracket and token 1 is a closing bracket. A string is
  labelled class 1 when every prefix has depth in [0, 2] and the whole string
  ends at depth 0, and class 0 otherwise.

  Examples (0 = open, 1 = close):
    0101, 0011 -> class 1
    0001, 1010 -> class 0

  Note the sampling is jittable so this task is fast.
  """

  # Negative examples are produced by negating one token drawn from this many
  # trailing positions of the sequence.
  _FLIP_WINDOW = 10

  def __init__(self, *args, **kwargs):
    """Initializes the bounded Dyck task.

    Args:
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.
    """
    super().__init__(*args, **kwargs)

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of bracket strings and their membership labels.

    The first `batch_size // 2` rows are corrupted (one token negated) and the
    remaining rows are well-formed strings of depth at most 2.

    Args:
      rng: PRNG key; split internally for the shuffle and the corruption.
      batch_size: Number of sequences in the batch.
      length: Number of bracket pairs per sequence; each sequence has
        `2 * length` tokens.

    Returns:
      A dict with 'input' of shape (batch_size, 2 * length, 2) holding one-hot
      tokens (0 = open, 1 = close) and 'output' of shape (batch_size, 2)
      holding one-hot labels (index 1 = well-formed with depth <= 2).
    """
    # Derive independent subkeys: reusing `rng` for two different draws would
    # make the shuffle and the corruption position statistically correlated.
    perm_rng, flip_rng = jrandom.split(rng)
    length *= 2

    # Start from equal numbers of closes (-1) and opens (+1), shuffled
    # independently per row; +1/-1 are the depth increments.
    strings = jnp.full((batch_size, length), -1, dtype=jnp.int32)
    strings = strings.at[:, length // 2:].set(1)
    strings = jrandom.permutation(perm_rng, strings, axis=1, independent=True)

    def revise_string(carry, step):
      # `depth` is the bracket depth before this token. `debt` is
      # (#closes turned into opens) - (#opens turned into closes); the scan
      # pays it back so the revised string stays in [0, 2] and ends at 0.
      depth, debt = carry
      # Both conditions are evaluated on the ORIGINAL token; they are mutually
      # exclusive because one requires step == -1 and the other step == 1.
      open_instead = ((depth == 0) | ((depth < 2) & (debt < 0))) & (step == -1)
      close_instead = (((depth > 0) & (debt > 0)) | (depth > 1)) & (step == 1)

      debt = jnp.where(open_instead, debt + 1, debt)
      step = jnp.where(open_instead, 1, step)

      debt = jnp.where(close_instead, debt - 1, debt)
      step = jnp.where(close_instead, -1, step)

      return (depth + step, debt), step

    init = (jnp.zeros(batch_size, dtype=jnp.int32),
            jnp.zeros(batch_size, dtype=jnp.int32))
    # lax.scan iterates over the leading axis, so put the time axis first.
    _, strings = jax.lax.scan(revise_string, init, jnp.transpose(strings))
    strings = jnp.transpose(strings)

    # Negate one token near the end of each of the first `batch_size // 2`
    # rows. This shifts the total sum by +/-2, so those rows never end at
    # depth 0 and are always labelled negative.
    flip_positions = jrandom.randint(
        flip_rng,
        shape=(batch_size,),
        minval=max(0, length - self._FLIP_WINDOW),
        maxval=length,
    )
    flips = jnn.one_hot(flip_positions, num_classes=length)
    flips = flips.at[batch_size // 2:, :].set(0)
    strings = jnp.where(flips == 1, -strings, strings)

    depth = jnp.cumsum(strings, axis=1)
    labels = (jnp.all(depth >= 0, axis=1)
              & jnp.all(depth <= 2, axis=1)
              & (depth[:, -1] == 0))
    labels = jnn.one_hot(labels.astype(jnp.int32), num_classes=2)

    # Map +1 (open) -> token 0 and -1 (close) -> token 1.
    tokens = (strings == -1).astype(jnp.int32)
    one_hot_strings = jnn.one_hot(tokens, num_classes=2)

    return {
        'input': one_hot_strings,
        'output': labels,
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models (two bracket symbols)."""
    return 2

  @property
  def output_size(self) -> int:
    """Returns the output size for the models (two classes)."""
    return 2