import functools

import jax
from jax import nn as jnn
from jax import numpy as jnp
from jax import random as jrandom
import chex

from neural_networks_chomsky_hierarchy.tasks import task

class LeftDeterministicBLM(task.GeneralizationTask):
  """Per-token membership labels for a small left-deterministic automaton.

  Tokens are integers in ``[0, vocab_size)``; below, ``a`` is token 0, ``b``
  is token 1, and so on. Each state ``s`` has exactly one forbidden token,
  ``s + num_states - 1``: reading it sends the automaton to an absorbing
  reject state. ``self._trans`` lists the transitions between non-reject
  states; any other allowed token leaves the state unchanged.

  Accepted languages of the configurations that `sample_batch` supports:
    num_states=2 (4 tokens): {c,d}*a{a,b,d}*
    num_states=3 (6 tokens): {b,d,e,f}*a{a,c,e,f}*b{a,b,c,d,f}*

  ``num_states=7`` configures a branching automaton over 20 tokens (final
  states 3, 5 and 6), but `sample_batch` has no sampler for it and raises
  NameError.

  Example (num_states=2): ``cad`` is accepted; ``bad`` is rejected at its
  first token; ``cd`` never reaches the final state and is rejected.

  Note the sampling is jittable so this task is fast.
  """

  def __init__(self, num_states: int, *args, **kwargs):
    """Initializes the automaton for the requested number of states.

    Args:
      num_states: Number of non-reject states. 2, 3 and 7 are configured;
        `sample_batch` only supports 2 and 3.
      *args: Args for the base task class.
      **kwargs: Kwargs for the base task class.

    Raises:
      NameError: If `num_states` is not 2, 3 or 7.
    """
    super().__init__(*args, **kwargs)
    self._num_states = num_states

    # State s rejects on exactly one token, s + num_states - 1.
    self._forbidden_tokens = tuple(
        s + num_states - 1 for s in range(num_states))
    self._forbidden_letters = {
        s: jnp.array([tok]) for s, tok in enumerate(self._forbidden_tokens)
    }
    if self._num_states == 7:
      self._vocab_size = 20
      self._trans = {(0, 0): 1, (1, 1): 2, (2, 2): 3, (0, 3): 4, (4, 4): 5,
                     (0, 5): 6}
      self._final_states = jnp.array([3, 5, 6])
    elif self._num_states == 3:
      self._vocab_size = 6
      self._trans = {(0, 0): 1, (1, 1): 2}
      self._final_states = jnp.array([2])
    elif self._num_states == 2:
      self._vocab_size = 4
      self._trans = {(0, 0): 1}
      self._final_states = jnp.array([1])
    else:
      raise NameError("Undefined.")

  @functools.partial(jax.jit, static_argnums=(0, 2, 3))
  def sample_batch(self, rng: jnp.ndarray, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of strings with per-token and final acceptance labels.

    Args:
      rng: PRNG key; it is split internally so every random draw is
        independent.
      batch_size: Number of strings (static under jit).
      length: Number of tokens per string (static under jit).

    Returns:
      A dict with 'input', the one-hot tokens of shape
      ``(batch_size, length, vocab_size)``, and 'output', one-hot labels of
      shape ``(batch_size, length + 1, 2)``. Label ``t < length`` is 1 iff no
      forbidden token was read in tokens ``0..t``; the last label is 1 iff the
      whole string is accepted.

    Raises:
      NameError: If sampling is not implemented for ``num_states``.
    """
    strings = self._sample_strings(rng, batch_size, length)
    return {
        'input': jnn.one_hot(strings, num_classes=self._vocab_size),
        'output': self._label_strings(strings),
    }

  def _sample_strings(self, rng: jnp.ndarray, batch_size: int,
                      length: int) -> jnp.ndarray:
    """Samples token strings and rewrites forbidden tokens section by section.

    A string is cut into sections; section ``k`` starts at the token that
    moves the automaton from state ``k - 1`` to state ``k``. Inside section
    ``k``, occurrences of state ``k``'s forbidden token are replaced by a
    random token that is neither forbidden there nor the token leaving state
    ``k``, so the section boundaries are preserved.

    The first ``batch_size // (num_states + 1)`` rows are fully rewritten.
    The remaining rows form one group per section (the last group absorbs the
    remainder). Rows in group ``k`` keep the first forbidden token of section
    ``k``, if they sampled one, and are therefore rejected.

    Args:
      rng: PRNG key.
      batch_size: Number of strings.
      length: Number of tokens per string.

    Returns:
      Integer tokens of shape ``(batch_size, length)``.

    Raises:
      NameError: If sampling is not implemented for ``num_states``.
    """
    # Token that moves the automaton from state k to state k + 1 (matches
    # `self._trans`); only these chain-shaped automata are supported.
    if self._num_states == 2:
      exit_tokens = (0,)
    elif self._num_states == 3:
      exit_tokens = (0, 1)
    else:
      raise NameError("Undefined.")
    num_sections = self._num_states

    # Never reuse a JAX key: one subkey for the raw strings, one per section.
    keys = jrandom.split(rng, num_sections + 1)
    raw = jrandom.randint(
        keys[0],
        shape=(batch_size, length),
        minval=0,
        maxval=self._vocab_size,
    )

    # entered[k][b, t] is True from the token that enters state k onwards.
    # Every mask is built from the raw strings, before any rewrite.
    entered = [jnp.ones_like(raw, dtype=bool)]
    for exit_token in exit_tokens:
      hits = (raw == exit_token) & entered[-1]
      entered.append(jnp.cumsum(hits, axis=1).astype(bool))

    rows = jnp.arange(batch_size)[:, None]
    first_group = batch_size // (num_sections + 1)
    group_size = (batch_size - first_group) // num_sections

    strings = raw
    for k in range(num_sections):
      section = entered[k]
      if k + 1 < num_sections:
        section = section & ~entered[k + 1]
      forbidden = self._forbidden_tokens[k]
      to_rewrite = (raw == forbidden) & section

      # Rows of group k keep their first forbidden token of this section.
      start = first_group + k * group_size
      stop = batch_size if k + 1 == num_sections else start + group_size
      first_hit = jnn.one_hot(to_rewrite.argmax(axis=1), length).astype(bool)
      in_group = (rows >= start) & (rows < stop)
      to_rewrite = to_rewrite & ~(first_hit & in_group)

      # Replacement must neither be forbidden here nor leave state k.
      excluded = {forbidden}
      if k < len(exit_tokens):
        excluded.add(exit_tokens[k])
      candidates = jnp.array(
          [tok for tok in range(self._vocab_size) if tok not in excluded])
      replacement = jrandom.choice(
          keys[k + 1], candidates, shape=(batch_size, length))
      strings = jnp.where(to_rewrite, replacement, strings)
    return strings

  def _label_strings(self, strings: jnp.ndarray) -> jnp.ndarray:
    """Runs the automaton over `strings` and returns one-hot labels.

    Args:
      strings: Integer tokens indexed as ``[batch, position]``.

    Returns:
      Float one-hot labels of shape ``(batch, length + 1, 2)``: position
      ``t < length`` says whether tokens ``0..t`` avoid every forbidden token
      (seen from the state each was read in); the last position says whether
      the whole string is accepted.
    """

    def step(state, token):
      # A row stays alive iff it is in a real state s (the reject state -1
      # matches none) and the token is not forbidden in s.
      alive = jnp.zeros_like(state, dtype=bool)
      for s in range(self._num_states):
        allowed = state == s
        forbidden = self._forbidden_letters.get(s)
        if forbidden is not None:
          allowed = allowed & ~jnp.isin(token, forbidden)
        alive = alive | allowed
      # Match every transition against the pre-step state so that two
      # transitions can never fire on the same token.
      next_state = state
      for (src, tok), dst in self._trans.items():
        next_state = jnp.where((state == src) & (token == tok), dst,
                               next_state)
      next_state = jnp.where(alive, next_state, -1)
      return next_state, alive

    init_state = jnp.zeros((strings.shape[0],), dtype=int)
    final_state, alive = jax.lax.scan(step, init_state, strings.T)
    alive = alive.T
    accepted = alive[:, -1] & jnp.isin(final_state, self._final_states)
    labels = jnp.concatenate((alive, accepted[:, None]), axis=1)
    return jnn.one_hot(labels.astype(jnp.float32), num_classes=2)

  @property
  def input_size(self) -> int:
    """Returns the input size for the models (the vocabulary size)."""
    return self._vocab_size

  @property
  def output_size(self) -> int:
    """Returns the output size for the models (binary labels)."""
    return 2

  def output_length(self, input_length: int) -> int:
    """Returns one label per input token plus a final whole-string label."""
    return input_length + 1

  def accuracy_fn(self, output: chex.Array, target: chex.Array) -> chex.Array:
    """Returns the F1 score of class 1 over all flattened label positions.

    Returns exactly 1 when every prediction matches the target. This also
    covers the case with no positives at all, where F1 would be 0 / 0.
    """
    y_pred = output.argmax(-1).reshape(-1)
    y_true = target.argmax(-1).reshape(-1)

    tp = ((y_pred == y_true) & (y_pred == 1)).sum()
    fp = ((y_pred != y_true) & (y_pred == 1)).sum()
    fn = ((y_pred != y_true) & (y_pred == 0)).sum()
    f1 = 2 * tp / (2 * tp + fp + fn)
    return jnp.where((y_pred == y_true).all(), 1, f1)