import functools
import math
from typing import Callable

import einops
import jax
import jax.ad_checkpoint as adc
import jax.numpy as jnp
import ninjax as nj
import numpy as np

# Dtype that activations and cast parameters are converted to in this file.
COMPUTE_DTYPE = jnp.bfloat16
# Hook taking (tensor, name); the default returns the tensor unchanged. It is
# not referenced in the visible part of this file.
LAYER_CALLBACK = lambda tensor, name: tensor

# Short alias for float32, used both as a dtype and as a cast function.
f32 = jnp.float32


def cast(xs, force=False):
  """Convert the leaves of a pytree to COMPUTE_DTYPE.

  Args:
    xs: Pytree of arrays.
    force: If truthy, convert every leaf; otherwise only floating-point
      leaves are converted and all other leaves pass through unchanged.

  Returns:
    Pytree with the same structure as `xs`.
  """
  def convert(leaf):
    if force or jnp.issubdtype(leaf.dtype, jnp.floating):
      return COMPUTE_DTYPE(leaf)
    return leaf
  return jax.tree.map(convert, xs)


def act(name):
  """Return the activation function registered under `name`.

  The names 'none', 'mish', 'relu2' and 'swiglu' are implemented here; any
  other name is looked up as an attribute of `jax.nn`.
  """
  def swiglu(x):
    # Split the last axis in half: the first half gates the second half.
    gate, value = jnp.split(x, 2, -1)
    return jax.nn.silu(gate) * value

  custom = {
      'none': lambda x: x,
      'mish': lambda x: x * jnp.tanh(jax.nn.softplus(x)),
      'relu2': lambda x: jnp.square(jax.nn.relu(x)),
      'swiglu': swiglu,
  }
  if name in custom:
    return custom[name]
  return getattr(jax.nn, name)


def init(name):
  """Resolve an initializer spec into a callable.

  Callables are returned as they are. Strings name a distribution, optionally
  followed by a '_in', '_out' or '_avg' suffix that selects the fan mode; the
  fan mode is 'in' when no suffix is present.
  """
  if callable(name):
    return name
  dist, fan = name, 'in'
  if name.endswith(('_in', '_out', '_avg')):
    dist, fan = name.rsplit('_', 1)
  return Initializer(dist, fan, 1.0)


def dropout(x, prob, training):
  """Zero entries of `x` with probability `prob` and rescale the rest.

  Returns `x` itself when `prob` is falsy or `training` is falsy. Otherwise a
  Bernoulli keep-mask is drawn with a key from `nj.seed()` and the kept values
  are divided by the keep probability.
  """
  if prob and training:
    keep = jax.random.bernoulli(nj.seed(), 1.0 - prob, x.shape)
    return x * keep / (1.0 - prob)
  return x


def symlog(x):
  """Sign-preserving logarithm: sign(x) * log(1 + |x|)."""
  magnitude = jnp.log1p(jnp.abs(x))
  return jnp.sign(x) * magnitude


def symexp(x):
  """Sign-preserving exponential: sign(x) * (exp(|x|) - 1)."""
  magnitude = jnp.expm1(jnp.abs(x))
  return jnp.sign(x) * magnitude


def where(condition, xs, ys):
  """Select between two pytrees leaf by leaf using a boolean condition.

  Args:
    condition: Boolean array. Trailing singleton axes are appended to it until
      its rank matches each leaf.
    xs: Pytree whose values are taken where the condition is true.
    ys: Pytree with the same structure and leaf shapes as `xs`, used where the
      condition is false.

  Returns:
    Pytree with the same structure as `xs`.
  """
  assert condition.dtype == bool, condition.dtype

  def select(a, b):
    assert a.shape == b.shape, (a.shape, b.shape)
    trailing = list(range(condition.ndim, a.ndim))
    return jnp.where(jnp.expand_dims(condition, trailing), a, b)

  return jax.tree.map(select, xs, ys)


def mask(xs, mask):
  """Zero out the entries of every leaf in `xs` where `mask` is false."""
  zeros = jax.tree.map(jnp.zeros_like, xs)
  return where(mask, xs, zeros)


def available(*trees, bdims=None):
  """Compute boolean masks marking entries that hold real data.

  Floating-point leaves count as missing where they equal -inf and signed
  integer leaves where they equal -1. Unsigned integer and boolean leaves are
  always treated as present.

  Args:
    *trees: Pytrees with matching structure; masks of corresponding leaves are
      combined with logical AND.
    bdims: If given, all axes from this index onward are reduced with `all`,
      leaving one flag per leading index.

  Returns:
    Pytree of boolean arrays with the structure of the input trees.

  Raises:
    NotImplementedError: For leaves of any other dtype.
  """
  def valid(leaf):
    kind = leaf.dtype
    if jnp.issubdtype(kind, jnp.floating):
      flags = (leaf != -jnp.inf)
    elif jnp.issubdtype(kind, jnp.signedinteger):
      flags = (leaf != -1)
    elif (
        jnp.issubdtype(kind, jnp.unsignedinteger) or
        jnp.issubdtype(kind, bool)):
      # No sentinel exists for these dtypes, so everything is available.
      dims = leaf.shape if bdims is None else leaf.shape[:bdims]
      flags = jnp.full(dims, True, bool)
    else:
      raise NotImplementedError(kind)
    if bdims is None:
      return flags
    return flags.all(tuple(range(bdims, flags.ndim)))

  def combine(*leaves):
    return jnp.stack([valid(leaf) for leaf in leaves], 0).all(0)

  return jax.tree.map(combine, *trees)


@functools.partial(jax.custom_vjp, nondiff_argnums=[1, 2])
def ensure_dtypes(x, fwd=None, bwd=None):
  """Identity that asserts the dtype of `x` and of its incoming gradient.

  Args:
    x: Array to check; it is returned unchanged.
    fwd: Required dtype of `x`. A falsy value selects COMPUTE_DTYPE.
    bwd: Required dtype of the cotangent seen by the backward rule. A falsy
      value selects COMPUTE_DTYPE.

  Raises:
    AssertionError: If `x.dtype` differs from `fwd`.
  """
  fwd = fwd or COMPUTE_DTYPE
  bwd = bwd or COMPUTE_DTYPE
  assert x.dtype == fwd, (x.dtype, fwd)
  return x
def ensure_dtypes_fwd(x, fwd=None, bwd=None):
  # Forward rule: run the dtype check and save an empty residual.
  fwd = fwd or COMPUTE_DTYPE
  bwd = bwd or COMPUTE_DTYPE
  return ensure_dtypes(x, fwd, bwd), ()
def ensure_dtypes_bwd(fwd, bwd, cache, dx):
  # Backward rule: the non-differentiable arguments come first, then the
  # residual and the cotangent. The cotangent is checked and passed through.
  fwd = fwd or COMPUTE_DTYPE
  bwd = bwd or COMPUTE_DTYPE
  assert dx.dtype == bwd, (dx.dtype, bwd)
  return (dx,)
ensure_dtypes.defvjp(ensure_dtypes_fwd, ensure_dtypes_bwd)


def rms(xs):
  """Return the root mean square over all elements of a pytree as float32.

  Each leaf is cast to float32 before it is squared, so low-precision inputs
  (float16, bfloat16) neither overflow nor lose precision while the squares
  are accumulated.

  Args:
    xs: Pytree of arrays with at least one leaf.

  Returns:
    Scalar float32 array.
  """
  leaves = jax.tree.leaves(xs)
  count = sum(x.size for x in leaves)
  # Cast first, then square: squaring in the input dtype overflows float16
  # for magnitudes above roughly 255.
  sumsq = jnp.stack([
      jnp.square(jnp.float32(x)).sum() for x in leaves]).sum()
  return jnp.sqrt(sumsq / jnp.float32(count))


def rope(x, ts=None, inverse=False, maxlen=4096):
  """Apply rotary position embeddings to the last axis of `x`.

  The last axis is split into two halves that are rotated against each other
  by an angle of `position / maxlen ** (2 * i / D)` for frequency index `i`.

  Args:
    x: Array of shape [B, T, H, D] with even D.
    ts: Optional positions of shape [B, T]. Defaults to 0..T-1 for every
      batch row.
    inverse: If true, positions are negated, which undoes a previous rotation.
    maxlen: Base of the geometric progression of timescales.

  Returns:
    Array with the same shape and dtype as `x`.
  """
  B, T, _, D = x.shape
  assert D % 2 == 0, D
  if ts is None:
    ts = jnp.ones(B, jnp.int32)[:, None] * jnp.arange(T)[None, :]  # [B, T]
  assert ts.shape == (B, T), (ts.shape, (B, T))
  if inverse:
    ts = -ts
  freq_exponents = (2.0 / D) * jnp.arange(D // 2)  # [D/2]
  timescale = maxlen ** freq_exponents
  radians = ts[:, :, None] / timescale[None, None, :]  # [B, T, D/2]
  # Evaluate the trigonometric functions in at least float32. Rounding the
  # angle to a low-precision dtype first (bfloat16 resolves positions near
  # 1000 only in steps of 4) would produce wrong rotations for long sequences.
  angle_dtype = jnp.promote_types(x.dtype, jnp.float32)
  radians = radians[..., None, :].astype(angle_dtype)  # [B, T, 1, D/2]
  sin = jnp.sin(radians).astype(x.dtype)
  cos = jnp.cos(radians).astype(x.dtype)
  x1, x2 = jnp.split(x, 2, axis=-1)  # [B, T, H, D/2]
  res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  return res


class Harmonizer(nj.Module):
  """Rescales a loss by a learned log-scale and adds a penalty on that scale.

  The module owns a single float32 parameter 'harmony_s' of shape (1,) that is
  initialized to zero. The loss is divided by exp(harmony_s) and a penalty
  weighted by `k` is added.
  """

  def __init__(self, k=1.0):
    self.k = k

  def __call__(self, loss, regularize=True):
    """Return the rescaled loss plus the scale penalty.

    Args:
      loss: Loss value to rescale.
      regularize: If true the penalty is log(exp(s) + 1), otherwise it is s
        itself.
    """
    harmony_s = self.value('harmony_s', jnp.zeros, 1, f32)
    scaled = loss / (jnp.exp(harmony_s))
    if regularize:
      # softplus(s) equals log(exp(s) + 1) but does not overflow to inf when
      # s grows large.
      penalty = jax.nn.softplus(harmony_s)
    else:
      penalty = harmony_s
    return scaled + self.k * penalty

  def get_harmony(self):
    """Return the current value of the 'harmony_s' parameter."""
    harmony_s = self.value('harmony_s', jnp.zeros, 1, f32)
    return harmony_s


class Initializer:
  """Callable that creates parameter arrays from a named distribution.

  Attributes:
    dist: One of 'zeros', 'uniform', 'normal', 'trunc_normal' or 'normed'.
    fan: One of 'in', 'out', 'avg' or 'none'; selects which fan count scales
      the distribution.
    scale: Multiplier applied to the sampled values.
  """

  def __init__(self, dist='trunc_normal', fan='in', scale=1.0):
    self.dist = dist
    self.fan = fan
    self.scale = scale

  def __call__(self, shape, dtype=jnp.float32, fshape=None):
    """Create an array of the given shape.

    Args:
      shape: Int or sequence of positive ints.
      dtype: Dtype of the returned array.
      fshape: Optional shape used instead of `shape` for computing the fans.

    Returns:
      Array of shape `shape` and dtype `dtype`.

    Raises:
      NotImplementedError: If `dist` is not a known distribution.
    """
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    assert all(isinstance(x, int) for x in shape), (
        shape, [type(x) for x in shape])
    assert all(x > 0 for x in shape), shape
    fanin, fanout = self.compute_fans(shape if fshape is None else fshape)
    fan = {
        'avg': (fanin + fanout) / 2, 'in': fanin, 'out': fanout, 'none': 1,
    }[self.fan]
    if self.dist == 'zeros':
      x = jnp.zeros(shape, dtype)
    elif self.dist == 'uniform':
      limit = np.sqrt(1 / fan)
      x = jax.random.uniform(nj.seed(), shape, dtype, -limit, limit)
    elif self.dist == 'normal':
      x = jax.random.normal(nj.seed(), shape)
      x *= np.sqrt(1 / fan)
    elif self.dist == 'trunc_normal':
      x = jax.random.truncated_normal(nj.seed(), -2, 2, shape)
      # The constant is applied on top of the fan scaling to offset the
      # reduction in standard deviation caused by truncating at +-2.
      x *= 1.1368 * np.sqrt(1 / fan)
    elif self.dist == 'normed':
      x = jax.random.uniform(nj.seed(), shape, dtype, -1, 1)
      # Normalize every column of the flattened [-1, shape[-1]] view.
      x *= (1 / jnp.linalg.norm(x.reshape((-1, shape[-1])), 2, 0))
    else:
      raise NotImplementedError(self.dist)
    x *= self.scale
    x = x.astype(dtype)
    return x

  def __repr__(self):
    return f'Initializer({self.dist}, {self.fan}, {self.scale})'

  def __eq__(self, other):
    # Comparing against unrelated objects must not raise AttributeError.
    if not isinstance(other, Initializer):
      return NotImplemented
    attributes = ('dist', 'fan', 'scale')
    return all(getattr(self, k) == getattr(other, k) for k in attributes)

  def __hash__(self):
    # Defined together with __eq__ so that equal initializers hash equally
    # and instances stay usable as dict keys and set members.
    return hash((self.dist, self.fan, self.scale))

  @staticmethod
  def compute_fans(shape):
    """Return (fan_in, fan_out) for a parameter of the given shape.

    The last two axes are treated as input and output features and any
    leading axes multiply both fans.
    """
    if len(shape) == 0:
      return (1, 1)
    elif len(shape) == 1:
      return (1, shape[0])
    elif len(shape) == 2:
      return (shape[0], shape[1])
    else:
      space = math.prod(shape[:-2])
      return (shape[-2] * space, shape[-1] * space)


class Embed(nj.Module):
  """Embedding table lookup for integer inputs with an optional event shape.

  The table has shape (*shape, classes, units), so every position of the event
  shape owns its own set of `classes` embedding vectors.
  """

  einit: str | Callable = Initializer('trunc_normal', 'out')
  combine: bool = False

  def __init__(self, classes, units, shape=()):
    self.classes = classes
    self.units = units
    self.shape = shape

  def __call__(self, x):
    """Look up embeddings for the indices in `x`.

    Args:
      x: Index array whose trailing axes equal `self.shape`.

    Returns:
      Array of shape (*batch, units) when `combine` is set, where the
      embeddings of all event positions are summed; otherwise an array of
      shape (*batch, *shape, units).
    """
    split = x.ndim - len(self.shape)
    batch_shape, event_shape = x.shape[:split], x.shape[split:]
    assert event_shape == self.shape, (self.shape, x.shape)
    slots = math.prod(self.shape)
    table = self.value(
        'table', init(self.einit), (*self.shape, self.classes, self.units))
    table = table.reshape(slots, self.classes, self.units)
    table = table.astype(COMPUTE_DTYPE)
    # Pair every event position with its own slice of the table.
    flat = x.reshape(-1, slots)
    embed = table[jnp.arange(slots), flat]
    if self.combine:
      return embed.sum(-2).reshape(*batch_shape, self.units)
    return embed.reshape(*batch_shape, *self.shape, self.units)


class Linear(nj.Module):
  """Dense layer over the last axis with an optionally multi-axis output."""

  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  outscale: float = 1.0

  def __init__(self, units):
    if isinstance(units, int):
      units = (units,)
    self.units = tuple(units)

  def __call__(self, x):
    """Project the last axis of `x` to `units`.

    The kernel has shape (x.shape[-1], prod(units)); the output's last axis is
    reshaped into `units`.
    """
    ensure_dtypes(x)
    width = math.prod(self.units)
    kernel = self.value(
        'kernel', self._scaled_winit, (x.shape[-1], width))
    out = x @ kernel.astype(x.dtype)
    if self.bias:
      offset = self.value('bias', init(self.binit), width)
      out += offset.astype(out.dtype)
    return out.reshape((*out.shape[:-1], *self.units))

  def _scaled_winit(self, *args, **kwargs):
    # Kernel initializer with `outscale` multiplied into the sampled values.
    return init(self.winit)(*args, **kwargs) * self.outscale


class BlockLinear(nj.Module):
  """Block-diagonal dense layer.

  The input features are split into `blocks` equal groups and every group is
  projected by its own kernel to units // blocks output features.
  """

  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  outscale: float = 1.0

  def __init__(self, units, blocks):
    assert isinstance(units, int), (units, type(units))
    assert blocks <= units and units % blocks == 0, (blocks, units)
    self.units = units
    self.blocks = blocks

  def __call__(self, x):
    """Apply the per-block projection to the last axis of `x`."""
    ensure_dtypes(x)
    assert x.shape[-1] % self.blocks == 0, (x.shape, self.blocks)
    insize = x.shape[-1]
    per_in = insize // self.blocks
    per_out = self.units // self.blocks
    kernel = self.value(
        'kernel', self._scaled_winit, (self.blocks, per_in, per_out))
    kernel = kernel.astype(x.dtype)
    grouped = x.reshape((*x.shape[:-1], self.blocks, per_in))
    out = jnp.einsum('...ki,kio->...ko', grouped, kernel)
    out = out.reshape((*out.shape[:-2], self.units))
    if self.bias:
      offset = self.value('bias', init(self.binit), self.units)
      out += offset.astype(out.dtype)
    return out

  def _scaled_winit(self, *args, **kwargs):
    # Kernel initializer with `outscale` multiplied into the sampled values.
    return init(self.winit)(*args, **kwargs) * self.outscale


class Conv2D(nj.Module):
  """2D convolution over NHWC inputs, optionally fractionally strided."""

  transp: bool = False
  groups: int = 1
  pad: str = 'same'
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  outscale: float = 1.0

  def __init__(self, depth, kernel, stride=1):
    self.depth = depth
    self.kernel = (kernel,) * 2 if isinstance(kernel, int) else kernel
    self.stride = stride

  def __call__(self, x):
    """Convolve `x` and return an array with `depth` output channels."""
    ensure_dtypes(x)
    in_per_group = x.shape[-1] // self.groups
    kernel = self.value(
        'kernel', self._scaled_winit,
        (*self.kernel, in_per_group, self.depth)).astype(x.dtype)
    if not self.transp:
      stride = (self.stride, self.stride)
    else:
      assert self.pad == 'same', self.pad
      # Manual implementation of fractionally strided convolution because the
      # cuDNN implementation used by XLA has bugs and performance issues.
      # The input is upsampled by repetition and all but one pixel of every
      # repeated group is zeroed before a regular stride-1 convolution.
      x = x.repeat(self.stride, -2).repeat(self.stride, -3)
      rows = jnp.arange(x.shape[-3])
      cols = jnp.arange(x.shape[-2])
      maskh = ((rows - 1) % self.stride == 0)[:, None]
      maskw = ((cols - 1) % self.stride == 0)[None, :]
      x *= (maskh * maskw)[:, :, None]
      stride = (1, 1)
    x = jax.lax.conv_general_dilated(
        x, kernel, stride, self.pad.upper(),
        feature_group_count=self.groups,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    if self.bias:
      offset = self.value('bias', init(self.binit), self.depth)
      x += offset.astype(x.dtype)
    return x

  def _scaled_winit(self, *args, **kwargs):
    # Kernel initializer with `outscale` multiplied into the sampled values.
    return init(self.winit)(*args, **kwargs) * self.outscale


class Conv3D(nj.Module):
  """3D convolution or transposed convolution over NTHWC inputs."""

  transp: bool = False
  groups: int = 1
  pad: str = 'same'
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')

  def __init__(self, depth, kernel, stride=1):
    self.depth = depth
    self.kernel = (kernel,) * 3 if isinstance(kernel, int) else kernel
    self.stride = (stride,) * 3 if isinstance(stride, int) else stride

  def __call__(self, x):
    """Convolve `x` and return an array with `depth` output channels."""
    ensure_dtypes(x)
    if self.transp:
      # The transposed path does not support grouped convolution.
      assert self.groups == 1, self.groups
    in_per_group = x.shape[-1] // self.groups
    kernel = self.value(
        'kernel', init(self.winit),
        (*self.kernel, in_per_group, self.depth)).astype(x.dtype)
    layout = ('NTHWC', 'THWIO', 'NTHWC')
    padding = self.pad.upper()
    if self.transp:
      x = jax.lax.conv_transpose(
          x, kernel, self.stride, padding, dimension_numbers=layout)
    else:
      x = jax.lax.conv_general_dilated(
          x, kernel, self.stride, padding,
          feature_group_count=self.groups,
          dimension_numbers=layout)
    if self.bias:
      offset = self.value('bias', init(self.binit), self.depth)
      x += offset.astype(x.dtype)
    return x


class Norm(nj.Module):
  """RMS or layer normalization computed in float32.

  `impl` is 'none', 'rms' or 'layer'. A suffix of the form '1em<N>' (for
  example 'rms1em6') overrides `eps` with 10 ** -N.
  """

  axis: tuple = (-1,)
  eps: float = 1e-4
  scale: bool = True
  shift: bool = True

  def __init__(self, impl):
    if '1em' in impl:
      impl, exponent = impl.split('1em')
      self._fields['eps'] = 10 ** -int(exponent)
    self.impl = impl

  def __call__(self, x):
    """Normalize `x` over `axis` and return it in its original dtype."""
    ensure_dtypes(x)
    out_dtype = x.dtype
    x = f32(x)
    axis = [a % x.ndim for a in self.axis]
    # Parameter shape: spans from the first normalized axis to the last axis,
    # with size 1 on every axis that is not normalized.
    shape = [
        x.shape[i] if i in axis else 1 for i in range(min(axis), x.ndim)]
    if self.impl == 'rms':
      mean2 = jnp.square(x).mean(axis, keepdims=True)
      mean2 = adc.checkpoint_name(mean2, 'small')
      gain = self._scale(shape, x.dtype)
      x = x * (jax.lax.rsqrt(mean2 + self.eps) * gain)
    elif self.impl == 'layer':
      mean = x.mean(axis, keepdims=True)
      mean2 = jnp.square(x).mean(axis, keepdims=True)
      mean2 = adc.checkpoint_name(mean2, 'small')
      # Clamp at zero because E[x^2] - E[x]^2 can turn slightly negative.
      var = jnp.maximum(0, mean2 - jnp.square(mean))
      var = adc.checkpoint_name(var, 'small')
      gain = self._scale(shape, x.dtype)
      offset = self._shift(shape, x.dtype)
      x = (x - mean) * (jax.lax.rsqrt(var + self.eps) * gain) + offset
    elif self.impl != 'none':
      raise NotImplementedError(self.impl)
    return x.astype(out_dtype)

  def _scale(self, shape, dtype):
    # Learned gain initialized to one, or a constant one when disabled.
    if self.scale:
      return self.value('scale', jnp.ones, shape, f32).astype(dtype)
    return jnp.ones(shape, dtype)

  def _shift(self, shape, dtype):
    # Learned offset initialized to zero, or a constant zero when disabled.
    if self.shift:
      return self.value('shift', jnp.zeros, shape, f32).astype(dtype)
    return jnp.zeros(shape, dtype)


class Attention(nj.Module):
  # Multi-head self-attention over a [B, T, D] input. Queries, keys and values
  # are all projected from the same input `x`.

  # Number of query heads.
  heads: int = 8
  # Number of key/value heads; 0 means the same as `heads`.
  kv_heads: int = 0
  # Dropout probability applied to the attention weights.
  dropout: float = 0.0
  # Whether to apply rotary position embeddings to queries and keys.
  rope: bool = True
  # Norm implementation applied to queries and keys; 'none' disables it.
  qknorm: str = 'none'
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  # Scale forwarded only to the output projection.
  outscale: float = 1.0

  def __call__(self, x, mask=None, ts=None, training=True):
    """Attend over the time axis of `x`.

    Args:
      x: Array of shape [B, T, D].
      mask: Optional array of shape [B, Tq, Tk]; logits are replaced by -1e30
        where it is false.
      ts: Optional positions forwarded to `rope`.
      training: Forwarded to `dropout`.

    Returns:
      Array with last axis of size D, produced by the output projection.
    """
    kw = dict(bias=self.bias, winit=self.winit, binit=self.binit)
    B, T, D = x.shape
    kv_heads = self.kv_heads or self.heads
    assert self.heads % kv_heads == 0
    # Number of query heads that share one key/value head.
    head_ratio = self.heads // kv_heads
    if head_ratio == 1:
      # Equal head counts allow a single fused projection.
      qkv = self.sub('qkv', Linear, 3 * D, **kw)(x)
      q, k, v = jnp.split(qkv, 3, -1)
    else:
      q = self.sub('q', Linear, D, **kw)(x)
      k = self.sub('k', Linear, D // head_ratio, **kw)(x)
      v = self.sub('v', Linear, D // head_ratio, **kw)(x)
    q = einops.rearrange(q, 'b t (h d) -> b t h d', h=self.heads)
    k = einops.rearrange(k, 'b t (h d) -> b t h d', h=kv_heads)
    v = einops.rearrange(v, 'b t (h d) -> b t h d', h=kv_heads)

    if self.qknorm != 'none':
      q = self.sub('normq', Norm, self.qknorm)(q)
      k = self.sub('normk', Norm, self.qknorm)(k)

    if self.rope:
      q = rope(q, ts)
      k = rope(k, ts)

    # Group the query heads by the key/value head they attend with.
    q = einops.rearrange(q, 'b t (h g) d -> b t h g d', h=kv_heads)
    logits = einops.einsum(q, k, 'b tq h g d, b tk h d -> b h g tq tk')
    logits = logits * (1.0 / np.sqrt(k.shape[-1]))
    # The softmax is evaluated in float32.
    logits = f32(logits)
    if mask is not None:
      Tq, Tk = q.shape[1], k.shape[1]
      assert mask.shape == (B, Tq, Tk), (mask.shape, (B, Tq, Tk))
      mask = einops.rearrange(mask, 'b tq tk -> b 1 1 tq tk')
      logits = jnp.where(mask, logits, -1e30)
    weights = jax.nn.softmax(logits)
    weights = weights.astype(x.dtype)
    weights = dropout(weights, self.dropout, training)
    x = einops.einsum(weights, v, 'b h g tq tk, b tk h d -> b tq h g d')
    x = einops.rearrange(x, 'b t h g d -> b t (h g d)')
    x = self.sub('proj', Linear, D, **kw, outscale=self.outscale)(x)
    return x


class DictConcat:
  """Encodes a dict of inputs and concatenates them along the last axis.

  Discrete spaces are one-hot encoded, all other spaces are passed through
  `squish` and cast to COMPUTE_DTYPE. Entries reported as unavailable by
  `available` are zeroed before and after encoding.
  """

  def __init__(self, spaces, fdims, squish=lambda x: x):
    assert 1 <= fdims, fdims
    self.keys = sorted(spaces.keys())
    self.spaces = spaces
    self.fdims = fdims
    self.squish = squish

  def __call__(self, xs):
    """Return the concatenated encoding of `xs`.

    Args:
      xs: Dict holding an array for every key in `spaces`.

    Raises:
      NotImplementedError: For uint8 spaces of rank 2 or 3.
    """
    assert all(k in xs for k in self.spaces), (self.spaces, xs.keys())
    first = self.keys[0]
    # The number of batch axes is inferred from the first key.
    bdims = xs[first].ndim - len(self.spaces[first].shape)
    encoded = []
    for key in self.keys:
      space = self.spaces[key]
      x = xs[key]
      m = available(x, bdims=bdims)
      x = mask(x, m)
      assert x.shape[bdims:] == space.shape, (key, bdims, space.shape, x.shape)
      is_image = space.dtype == jnp.uint8 and len(space.shape) in (2, 3)
      if is_image:
        raise NotImplementedError('Images are not supported.')
      if space.discrete:
        classes = np.asarray(space.classes).flatten()
        # All entries of a discrete space must share one class count.
        assert (classes == classes[0]).all(), classes
        classes = classes[0].item()
        x = jax.nn.one_hot(x.astype(jnp.int32), classes, dtype=COMPUTE_DTYPE)
      else:
        x = self.squish(x).astype(COMPUTE_DTYPE)
      x = mask(x, m)
      # Keep the batch axes plus fdims - 1 further axes and flatten the rest.
      lead = x.shape[:bdims + self.fdims - 1]
      encoded.append(x.reshape((*lead, -1)))
    return jnp.concatenate(encoded, -1)


class DictEmbed(nj.Module):
  """Embeds a dict of inputs into one vector by summing per-key embeddings.

  Discrete spaces are embedded either through a one-hot encoding followed by
  a Linear layer (impl='onehot') or through an Embed lookup (impl='lookup').
  All other spaces are passed through `squish`, flattened and fed to a Linear
  layer. A learned vector stored under the name 'init' is added to the sum.
  """

  squish: Callable = lambda x: x
  padone: bool = True
  bias: bool = True
  einit: str | Callable = Initializer('trunc_normal', 'out')
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  impl: str = 'onehot'

  def __init__(self, spaces, units):
    self.keys = sorted(spaces.keys())
    self.spaces = spaces
    self.units = units
    self.ekw = dict(einit=self.einit)
    self.lkw = dict(bias=self.bias, winit=self.winit, binit=self.binit)

  def __call__(self, xs, bshape):
    """Return the summed embedding of shape (*bshape, units).

    Args:
      xs: Dict holding an array for every key in `spaces`, with dtypes that
        match the corresponding space.
      bshape: Tuple with the batch shape shared by all inputs.

    Raises:
      NotImplementedError: For discrete uint8 spaces of rank 2 or 3, or for
        an unknown `impl`.
    """
    assert isinstance(bshape, tuple), bshape
    assert all(k in xs for k in self.spaces), (self.spaces, xs.keys())
    ys = []
    # Resolve `einit` through init() so that string specs work here exactly
    # like they do in Embed. The local must not be called `init`, otherwise
    # it would shadow that module-level helper.
    base = self.value('init', init(self.einit), (self.units,))
    base = jnp.broadcast_to(base, (*bshape, self.units))
    base = COMPUTE_DTYPE(base)
    ys.append(base)
    for key in self.keys:
      try:
        space = self.spaces[key]
        x = xs[key]
        assert x.dtype == space.dtype, (key, space.dtype, x.dtype, x.shape)
        m = available(x, bdims=len(bshape))
        x = mask(x, m)
        if space.discrete:
          if space.dtype == jnp.uint8 and len(space.shape) in (2, 3):
            raise NotImplementedError('Images are not supported.')
          classes = int(np.asarray(space.classes).max())
          assert classes <= 256, (key, space, classes)
          if self.impl == 'lookup':
            x = self.sub(
                key, Embed, classes, self.units, space.shape,
                combine=True, **self.ekw)(x)
          elif self.impl == 'onehot':
            x = jax.nn.one_hot(x, classes, dtype=COMPUTE_DTYPE)
            x = x.reshape((*x.shape[:len(bshape)], -1))
            x = self.sub(key, Linear, self.units, **self.lkw)(x)
          else:
            raise NotImplementedError(self.impl)
        else:
          x = self.squish(x)
          x = x.astype(COMPUTE_DTYPE)
          x = x.reshape((*x.shape[:len(bshape)], -1))
          x = self.sub(key, Linear, self.units, **self.lkw)(x)
        # Unavailable entries contribute nothing to the sum.
        x = mask(x, m)
        ys.append(x)
      except Exception:
        # Name the offending key before the error propagates.
        print(f"Error encoding key '{key}' with space {space}.")
        raise
    x = sum(ys)
    return x


class MLP(nj.Module):
  """Stack of Linear -> Norm -> activation layers applied to the last axis."""

  act: str = 'silu'
  norm: str = 'rms'
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')

  def __init__(self, layers=5, units=1024):
    self.layers = layers
    self.units = units
    self.kw = dict(bias=self.bias, winit=self.winit, binit=self.binit)

  def __call__(self, x):
    """Return features of shape (*x.shape[:-1], last layer width)."""
    lead = x.shape[:-1]
    # All leading axes are flattened into one batch axis for the layers.
    feat = x.astype(COMPUTE_DTYPE)
    feat = feat.reshape([-1, feat.shape[-1]])
    for index in range(self.layers):
      feat = self.sub(f'linear{index}', Linear, self.units, **self.kw)(feat)
      feat = self.sub(f'norm{index}', Norm, self.norm)(feat)
      feat = act(self.act)(feat)
    return feat.reshape((*lead, feat.shape[-1]))


class Transformer(nj.Module):
  """Pre-norm Transformer stack with residual attention and feed-forward."""

  units: int = 1024
  layers: int = 12
  heads: int = 8
  ffup: int = 4
  act: str = 'silu'
  norm: str = 'rms'
  glu: bool = False
  rope: bool = True
  qknorm: str = 'none'
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  outscale: float = 1.0

  def __call__(self, x, mask=None, ts=None, training=True):
    """Run all layers followed by a final norm.

    Args:
      x: Array whose last axis has size `units`.
      mask: Optional attention mask forwarded to every Attention layer.
      ts: Optional positions forwarded to every Attention layer.
      training: Forwarded to every Attention layer.
    """
    kw = dict(bias=self.bias, winit=self.winit, binit=self.binit)
    ak = dict(
        heads=self.heads, rope=self.rope, qknorm=self.qknorm,
        outscale=self.outscale)
    D = x.shape[-1]
    assert D == self.units, (D, self.units)
    for index in range(self.layers):
      with nj.scope(f'layer{index}'):
        # Attention sub-block with residual connection.
        residual = x
        x = self.sub('norm1', Norm, self.norm)(x)
        x = self.sub('mha', Attention, **kw, **ak)(x, mask, ts, training)
        x += residual
        # Feed-forward sub-block with residual connection.
        residual = x
        x = self.sub('norm2', Norm, self.norm)(x)
        if self.glu:
          # Hidden width is 2/3 of the plain width rounded down to a multiple
          # of 32, but never below D.
          U = max(D, int((D * self.ffup * 2 / 3) // 32 * 32))
          gate = self.sub('ff1', Linear, U, **kw)
          value = self.sub('ff2', Linear, U, **kw)
          down = self.sub('ff3', Linear, D, **kw, outscale=self.outscale)
          x = down(act(self.act)(gate(x)) * value(x))
        else:
          up = self.sub('ff1', Linear, D * self.ffup, **kw)
          down = self.sub('ff2', Linear, D, **kw, outscale=self.outscale)
          x = down(act(self.act)(up(x)))
        x += residual
    return self.sub('outnorm', Norm, self.norm)(x)


class GRU(nj.Module):
  """Gated recurrent cell with input normalization and per-step resets."""

  units: int = 1024
  bias: bool = True
  winit: str | Callable = Initializer('trunc_normal')
  binit: str | Callable = Initializer('zeros')
  norm: str = 'rms'
  update_bias: float = -1.0

  def initial(self, batch_size):
    """Return an all-zero carry of shape (batch_size, units)."""
    return jnp.zeros((batch_size, self.units), COMPUTE_DTYPE)

  def __call__(self, carry, inputs, resets, single=False):
    """Advance the cell by one step or scan it over axis 1.

    Args:
      carry: Recurrent state in COMPUTE_DTYPE.
      inputs: Inputs in COMPUTE_DTYPE.
      resets: Boolean flags; the carry is zeroed where they are true.
      single: If true, `inputs` and `resets` describe one step; otherwise
        they are scanned along axis 1.

    Returns:
      Tuple of the final carry and the output(s).
    """
    assert carry.dtype == COMPUTE_DTYPE, carry.dtype
    assert inputs.dtype == COMPUTE_DTYPE, inputs.dtype
    assert resets.dtype == bool, resets.dtype
    if not single:
      def body(state, args):
        return self.step(state, *args)
      carry, outputs = nj.scan(body, carry, (inputs, resets), axis=1)
      return carry, outputs
    return self.step(carry, inputs, resets)

  def step(self, carry, inp, reset):
    """Compute one recurrent update and return (carry, output)."""
    # NOTE: When passing previous actions as input, ensure to zero out past
    # actions on is_first and clip actions to bounds if needed.
    linear_kw = dict(bias=self.bias, winit=self.winit, binit=self.binit)
    prev = mask(carry, ~reset)
    feat = jnp.concatenate([prev, inp], -1)
    feat = self.sub('norm', Norm, self.norm)(feat)
    feat = self.sub('linear', Linear, 3 * self.units, **linear_kw)(feat)
    gate, cand, update = jnp.split(feat, 3, -1)
    cand = jnp.tanh(jax.nn.sigmoid(gate) * cand)
    update = jax.nn.sigmoid(update + self.update_bias)
    # Blend the candidate with the previous state; output equals new carry.
    new = update * cand + (1 - update) * prev
    return new, new

