import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class LocallyTestable(task.GeneralizationTask):
  """Binary classification: does the string contain a fixed short pattern?

  Strings are sequences over the three tokens {0, 1, 2}. The pattern depends
  on `sub_length`:
    * sub_length == 1: the single token 0.
    * sub_length == 2: token 0 immediately followed by token 1.

  A string is class 1 if the pattern occurs at least once, class 0 otherwise.

  Examples (sub_length == 1):
    [1, 2, 1] -> class 0
    [2, 0, 1] -> class 1

  Examples (sub_length == 2):
    [1, 0, 2] -> class 0
    [2, 0, 1] -> class 1

  Sampling: strings are first drawn uniformly at random. In the first
  `batch_size // 2` rows every occurrence of the pattern is destroyed (so
  those rows are class 0). In the remaining rows every occurrence except the
  last one is destroyed; such a row is class 1 only if the random draw
  contained the pattern in the first place.

  Note the sampling is jittable so this task is fast.
  """

  def __init__(self, sub_length: int, *args, **kwargs):
    """Initializes the locally testable task.

    Args:
      sub_length: Length of the pattern to detect. `sample_batch` supports 1
        and 2 and raises NotImplementedError for any other value.
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.
    """
    super().__init__(*args, **kwargs)

    self._sub_length = sub_length
    # Fixed alphabet {0, 1, 2}: 0 and 1 are the pattern tokens, 2 is filler.
    self._vocab_size = 3

  def _pattern_ends(self, strings: jnp.ndarray) -> jnp.ndarray:
    """Returns a boolean mask marking where an occurrence of the pattern ends.

    Args:
      strings: Integer token array of shape (batch, length).

    Returns:
      Boolean array with the same shape as `strings`; entry [b, t] is True iff
      an occurrence of the pattern in row b ends at position t.

    Raises:
      NotImplementedError: If `sub_length` is neither 1 nor 2.
    """
    if self._sub_length == 1:
      return strings == 0
    if self._sub_length == 2:
      pairs = (strings == 0)[:, :-1] & (strings == 1)[:, 1:]
      # A pair can never end at position 0, so left-pad with False to get back
      # to the full string length.
      no_match = jnp.zeros((strings.shape[0], 1), dtype=bool)
      return jnp.concatenate((no_match, pairs), axis=-1)
    raise NotImplementedError(
        f'sub_length={self._sub_length} is not supported (expected 1 or 2).')

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings and the expected class.

    Args:
      rng: JAX PRNG key.
      batch_size: Number of strings in the batch.
      length: Length of every string.

    Returns:
      A dict with 'input', the one-hot strings of shape
      (batch_size, length, 3), and 'output', the one-hot class of shape
      (batch_size, 2).

    Raises:
      NotImplementedError: If the task was built with an unsupported
        `sub_length`.
    """
    num_negative = batch_size // 2
    strings = jrandom.randint(
        rng,
        shape=(batch_size, length),
        minval=0,
        maxval=self._vocab_size,
    )

    # Positions to overwrite: initially the end of every pattern occurrence.
    overwrite = self._pattern_ends(strings)

    # One-hot marker of the last occurrence in each row. For a row without any
    # occurrence argmax returns 0, so the marker lands on the final position;
    # that is harmless because `overwrite` is all False in such a row.
    last_end = length - overwrite[:, ::-1].argmax(axis=1) - 1
    is_last = jnn.one_hot(last_end, length).astype(bool)

    # Rows from `num_negative` onwards keep their last occurrence; the earlier
    # rows lose all of them.
    overwrite = overwrite.at[num_negative:].set(
        (overwrite & ~is_last)[num_negative:])

    # Filler tokens are drawn from [2, vocab_size), i.e. never 0 or 1, so an
    # overwrite cannot create a new occurrence of the pattern. The key is
    # derived with fold_in rather than reusing `rng` directly, which keeps the
    # filler independent of `strings` while leaving `strings` itself unchanged
    # for a given `rng`.
    filler = jrandom.randint(
        jrandom.fold_in(rng, 1),
        shape=(batch_size, length),
        minval=2,
        maxval=self._vocab_size,
    )
    strings = jnp.where(overwrite, filler, strings)

    # Labels are recomputed from the final strings rather than assumed from
    # the row index, since a row may not have contained the pattern at all.
    labels = jnp.any(self._pattern_ends(strings), axis=1)

    return {
        'input': jnn.one_hot(strings, num_classes=self._vocab_size),
        'output': jnn.one_hot(labels.astype(jnp.int32), num_classes=2),
    }

  @property
  def input_size(self) -> int:
    """Returns the input size for the models."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the models."""
    return 2