from typing import Any, Mapping, NamedTuple, Optional

import chex
import haiku as hk
import jax
import jax.nn as jnn
import jax.numpy as jnp

_EPSILON = 0.001


class NDStack(NamedTuple):
  """Internal state of the non-deterministic stack.

  alpha and top_stack are both derived from gamma, so the three fields are
  always rebuilt together.

  Shapes use B = batch, T = stack size, Q = number of states and S = number of
  stack symbols.

  Attributes:
    gamma: Edge weights of the stack graph, shape (B, T, T, Q, S, Q, S).
    alpha: Per-node weights, shape (B, T, Q, S).
    top_stack: Normalized reading of the stack, shape (B, S), or (B, Q * S)
      when the NPDA states are read as well.
  """
  gamma: chex.Array  # (B, T, T, Q, S, Q, S)
  alpha: chex.Array  # (B, T, Q, S)
  top_stack: chex.Array  # (B, S) or (B, Q * S)


def _update_stack(ndstack: NDStack,
                  push_actions: chex.Array,
                  pop_actions: chex.Array,
                  replace_actions: chex.Array,
                  timestep: int,
                  read_states: bool = True) -> NDStack:
  """Returns an updated NDStack.

  Fills in column `timestep` of gamma (every edge i -> timestep of the stack
  graph) from the push, pop and replace actions. It then computes alpha at
  `timestep` and the normalized reading of the stack.

  Args:
    ndstack: See above. Contains the internals needed to simulate a
      non-deterministic stack.
    push_actions: A tensor of shape (B, Q, S, Q, S).
    pop_actions: A tensor of shape (B, Q, S, Q).
    replace_actions: A tensor of shape (B, Q, S, Q, S).
    timestep: The current timestep while processing the sequence. It may be a
      traced JAX scalar, so its range cannot be checked here. Callers are
      expected to pass 1 <= timestep < T (T being the stack size). Indices
      outside that range are handled by JAX's out-of-bounds indexing rules,
      so they silently stop updating the stack instead of raising.
    read_states: Whether to read the state of the NPDA as well.

  Returns:
    An NDStack with gamma[:, :, timestep] and alpha[:, timestep] filled in. Its
    top_stack is the normalized reading at `timestep`, of shape (B, Q * S) if
    `read_states` is True, else (B, S).
  """
  stack_size = ndstack.gamma.shape[2]

  # Push: the only new edge is (timestep - 1 -> timestep). Column `timestep` of
  # the mask is one-hot at row timestep - 1. Multiplying by that column alone
  # gives the same values as building the full (B, T, T, ...) tensor and
  # slicing it, without materializing that tensor.
  mask = jnp.zeros((stack_size, stack_size))
  mask = mask.at[timestep - 1, timestep].set(1)
  new_push_gamma_t = jnp.einsum('bqxry,t->btqxry', push_actions,
                                mask[:, timestep])

  # Pop: edge (i -> timestep) accumulates
  # gamma[i -> k] * gamma[k -> timestep - 1] * pop over the intermediate nodes
  # k with i < k < timestep - 1. The (T, T) mask is broadcast against gamma
  # rather than expanded to gamma's full size with a tensor of ones.
  node_k = jnp.arange(stack_size)[None, :]
  node_i = jnp.arange(stack_size)[:, None]
  index_mask = jnp.logical_and(node_k > node_i, node_k < timestep - 1)
  masked_gamma = index_mask[None, :, :, None, None, None, None] * ndstack.gamma
  new_pop_gamma_t = jnp.einsum(
      'bikqxuy,bkuysz,bszr->biqxry',
      masked_gamma,
      ndstack.gamma[:, :, timestep - 1],
      pop_actions,
  )

  # Replace: edge (i -> timestep - 1) is extended to (i -> timestep) with the
  # top symbol rewritten.
  new_replace_gamma_t = jnp.einsum('biqxsz,bszry->biqxry',
                                   ndstack.gamma[:, :, timestep - 1],
                                   replace_actions)

  new_gamma = ndstack.gamma.at[:, :, timestep].set(
      new_replace_gamma_t + new_pop_gamma_t + new_push_gamma_t)

  # alpha at the new node: sum over source nodes i of alpha[i] times the
  # weight of the new edge (i -> timestep).
  alpha_t = jnp.einsum('biqx,biqxry->bry', ndstack.alpha,
                       new_gamma[:, :, timestep])
  new_alpha = ndstack.alpha.at[:, timestep].set(alpha_t)

  if read_states:
    batch_size, states, symbols = alpha_t.shape
    obs = jnp.reshape(alpha_t, (batch_size, states * symbols))
  else:
    # Sum out the NPDA state, keeping only the top stack symbol.
    obs = jnp.sum(alpha_t, axis=1)

  # Normalize the reading; _EPSILON guards against an all-zero alpha_t.
  obs = obs / (jnp.sum(obs, axis=-1, keepdims=True) + _EPSILON)
  return NDStack(new_gamma, new_alpha, top_stack=obs)


# Recurrent state threaded through NDStackRNNCore, in this order: the nd-stack,
# the per-batch-element timestep counter, and the inner core's hidden state.
_NDStackRnnState = tuple[
    NDStack,  # Non-deterministic stack internals.
    chex.Array,  # Current timestep.
    chex.Array,  # Hidden state of the inner RNN core.
]


class NDStackRNNCore(hk.RNNCore):
  """Core for the non-deterministic stack RNN.

  At every step the inner core reads the input concatenated with the current
  reading of the stack. A linear layer, followed by a single softmax over all
  push, pop and replace actions, then drives the nd-stack update.
  """

  def __init__(
      self,
      stack_symbols: int,
      stack_states: int,
      stack_size: int = 30,
      inner_core: type[hk.RNNCore] = hk.VanillaRNN,
      read_states: bool = False,
      name: Optional[str] = None,
      **inner_core_kwargs: Any
  ):
    """Initializes.

    Args:
      stack_symbols: The number of symbols which can be used in the stack.
      stack_states: The number of states of the non-deterministic stack.
        This is the branching factor of the stack graph: the number of
        simultaneous stacks grows roughly as stack_states ** t.
      stack_size: The total size of the stacks, i.e. the number of nodes in the
        stack graph (one per timestep). Be careful when increasing this value
        since the computation is in O(stack_size ^ 3).
      inner_core: The inner RNN core builder.
      read_states: Whether to read the states of the NPDA or only the top of
        the stack.
      name: See base class.
      **inner_core_kwargs: The arguments to be passed to the inner RNN core
        builder.
    """
    super().__init__(name=name)
    self._rnn_core = inner_core(**inner_core_kwargs)
    self._stack_symbols = stack_symbols
    self._stack_states = stack_states
    self._stack_size = stack_size
    self._read_states = read_states

  def __call__(
      self, inputs: chex.Array, prev_state: _NDStackRnnState
  ) -> tuple[chex.Array, _NDStackRnnState]:
    """Steps the non-deterministic stack RNN core.

    See base class docstring.

    Args:
      inputs: An input array of shape (batch_size, input_size). The time
        dimension is not included since it is an RNNCore, which is unrolled over
        the time dimension.
      prev_state: A _NDStackRnnState tuple, consisting of the previous nd-stack,
        the previous timestep and the previous state of the inner core.

    Returns:
      - output: An output array of shape (batch_size, output_size).
      - next_state: Same format as prev_state.
    """
    ndstack, timestep, old_core_state = prev_state
    batch_size = ndstack.gamma.shape[0]

    # The network can always read the top of the stack.
    inputs = jnp.concatenate([inputs, ndstack.top_stack], axis=-1)
    new_core_output, new_core_state = self._rnn_core(inputs, old_core_state)

    n_push_actions = (self._stack_states * self._stack_symbols)**2
    n_pop_actions = self._stack_states**2 * self._stack_symbols
    n_replace_actions = (self._stack_states * self._stack_symbols)**2
    # One softmax normalizes jointly over every push, pop and replace action.
    actions = hk.Linear(n_push_actions + n_pop_actions + n_replace_actions)(
        new_core_output)
    actions = jnn.softmax(actions, axis=-1)

    push_actions = jnp.reshape(
        actions[:, :n_push_actions],
        (batch_size, self._stack_states, self._stack_symbols,
         self._stack_states, self._stack_symbols))

    pop_actions = jnp.reshape(
        actions[:, n_push_actions:n_push_actions + n_pop_actions],
        (batch_size, self._stack_states, self._stack_symbols,
         self._stack_states))

    replace_actions = jnp.reshape(
        actions[:, -n_replace_actions:],
        (batch_size, self._stack_states, self._stack_symbols,
         self._stack_states, self._stack_symbols))

    next_timestep = timestep + 1
    new_ndstack = _update_stack(
        ndstack,
        push_actions,
        pop_actions,
        replace_actions,
        # The timestep starts at zero for every batch element and is
        # incremented uniformly, so the first element stands for all of them.
        next_timestep[0],
        read_states=self._read_states)
    return new_core_output, (new_ndstack, next_timestep, new_core_state)

  def initial_state(self, batch_size: Optional[int]) -> _NDStackRnnState:
    """Returns the initial state of the core, a hidden state and an empty stack.

    `batch_size` is used directly as the leading dimension of the stack
    tensors, so an integer is required here.
    """
    core_state = self._rnn_core.initial_state(batch_size)

    # Gamma, the transition matrix, is initialized to full zeros: there is no
    # connection in the graph at the beginning.
    gamma = jnp.zeros(
        (batch_size, self._stack_size, self._stack_size, self._stack_states,
         self._stack_symbols, self._stack_states, self._stack_symbols))

    # Alpha is zero everywhere except for the first node, which is (0, q0, 0).
    alpha = jnp.zeros(
        (batch_size, self._stack_size, self._stack_states, self._stack_symbols))
    alpha = jax.vmap(lambda x: x.at[0, 0, 0].set(1))(alpha)

    # No reading has been computed before the first step, so the initial
    # reading is all zeros, even though the first node holds symbol 0.
    if self._read_states:
      top_stack = jnp.zeros(
          (batch_size, self._stack_states * self._stack_symbols))
    else:
      top_stack = jnp.zeros((batch_size, self._stack_symbols))

    ndstack = NDStack(gamma, alpha, top_stack)
    return ndstack, jnp.zeros((batch_size,), dtype=jnp.int32), core_state
