"""Pax ML model for patched time-series decoder.

The file implements Residual MLPs, Patched Decoder layers and PAX ML models.
"""

import dataclasses
from typing import Callable, Optional, Tuple

import einshape as es
import jax
from jax import lax
import jax.numpy as jnp
from praxis import base_hyperparams
from praxis import base_layer
from praxis import base_model
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations
from praxis.layers import embedding_softmax
from praxis.layers import linears
from praxis.layers import normalizations
from praxis.layers import stochastics
from praxis.layers import transformer_models
from praxis.layers import transformers
from tensorflow_probability.substrates import jax as tfp  # pylint:disable=g-importing-member

import definitions


# PAX shortcuts: short module-level aliases for frequently used Praxis names.
NestedMap = py_utils.NestedMap
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
InstantiableParams = py_utils.InstantiableParams
JTensor = pytypes.JTensor
NpTensor = pytypes.NpTensor
WeightedScalars = pytypes.WeightedScalars
instantiate = base_hyperparams.instantiate
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
AuxLossStruct = base_layer.AuxLossStruct

AUX_LOSS = base_layer.AUX_LOSS
template_field = base_layer.template_field

# Standard prng key names
PARAMS = base_layer.PARAMS
RANDOM = base_layer.RANDOM

tfd = tfp.distributions
# Sentinel value compared against series entries to detect padded positions.
PAD_VAL = definitions.PAD_VAL
# Its length is used as the number of date features per time step.
DATE_FEATURE_DICT = definitions.DATE_FEATURE_DICT

# NestedMap keys used to read model inputs and to name model outputs.
_INPUT_TS = 'input_ts'
_INPUT_PADDING = definitions.NAME_INPUT_PADDING
_DATE_FEATURES = definitions.NAME_DATE_FEATURES
_STATS = 'stats'
_FREQ = definitions.NAME_FREQ
_OUTPUT_TS = 'output_ts'
_TRAIN_LOSS = 'train_loss'
_OUTPUT_TOKENS = 'output_tokens'
_AUGMENTED_BATCH = 'augmented_batch'

# Numerical Constants
# Upper / lower bounds applied to the unnormalized mixture weights in
# `_softermax`.
MIX_UB = 1e4
MIX_LB = 1e-4
# Fill values used for the extra separator position that `_add_separator`
# appends to the padding, frequency-embedding and date-feature tensors.
CHUNK_SEPARATOR_VAL_PADS = 0
CHUNK_SEPARATOR_VAL_FREQ = 0
CHUNK_SEPARATOR_VAL_DATE = 0
# Id looked up in the chunk separator embedding for the model input.
INPUT_CHUNK_EMBEDDING_ID = 0
# NOTE(review): not referenced in the visible part of this file — confirm use.
DATE_CHUNK_EMBEDDING_ID = 1

# Upper / lower bounds applied to the Normal scale parameter in `_normal_link`.
NORMAL_SCALE_UB = 1e3
NORMAL_SCALE_LB = 1e-3

# Small numerical value: tolerance for approximate float comparisons and a
# guard against division by zero.
_TOLERANCE = 1e-7


def _softerplus(x: JTensor) -> JTensor:
  """Positive mapping."""
  return jnp.divide(2.0 + x + jnp.absolute(x), 2.0 - x + jnp.absolute(x))


def _softermax(x: JTensor) -> JTensor:
  """Softmax-like normalization built on `_softerplus`.

  Each entry is mapped to a positive value, squashed into (0, 1), rescaled into
  the [MIX_LB, MIX_UB] range and finally normalized so that the last axis sums
  to one.

  Args:
    x: logits; normalization is performed over the last axis.

  Returns:
    A tensor of the same shape as `x` whose last axis sums to one.
  """
  positive = _softerplus(x)
  squashed = jnp.divide(positive, 1.0 + positive)
  bounded = squashed * (MIX_UB - MIX_LB) + MIX_LB
  total = jnp.sum(bounded, axis=-1, keepdims=True)
  return bounded / total


def _normal_link(x: JTensor) -> JTensor:
  """Link function turning a pair of logits into Normal parameters.

  Args:
    x: tensor whose last axis holds `[loc_logit, scale_logit]` at indices 0
      and 1.

  Returns:
    A tensor with last dimension 2: the location (passed through unchanged)
    followed by a scale bounded to [NORMAL_SCALE_LB, NORMAL_SCALE_UB].
  """
  loc = x[..., 0:1]
  raw_scale = x[..., 1:2]
  positive = _softerplus(raw_scale)
  unit_interval = positive / (1.0 + positive)
  scale = unit_interval * (NORMAL_SCALE_UB - NORMAL_SCALE_LB) + NORMAL_SCALE_LB
  return jnp.concatenate([loc, scale], axis=-1)


def get_normal_mixture(x: JTensor, num_components: int) -> tfd.Mixture:
  """Builds a mixture-of-Normals distribution from logits.

  For more than one component the last axis of `x` is read as
  `num_components` mixture-weight logits followed by `num_components`
  consecutive `(loc, scale)` logit pairs. With a single component the first two
  entries are used as the `(loc, scale)` logits and a plain `tfd.Normal` is
  returned instead of a mixture.

  Args:
    x: logits with last dimension of at least `3 * num_components`.
    num_components: number of Normal components.

  Returns:
    The resulting distribution.

  Raises:
    ValueError: if the last dimension of `x` is smaller than
      `3 * num_components`.
  """
  required_dims = 2 * num_components + num_components
  if x.shape[-1] < required_dims:
    raise ValueError(
        'The last dimension of `x` is insufficient to represent the mixture.'
    )

  def _as_normal(pair_logits):
    # Converts one (loc, scale) logit pair into a Normal distribution.
    params = _normal_link(pair_logits)
    return tfd.Normal(loc=params[..., 0], scale=params[..., 1])

  if num_components == 1:
    return _as_normal(x[..., 0:2])

  weights = _softermax(x[..., :num_components])
  # Component i occupies the two entries starting at num_components + 2 * i.
  components = [
      _as_normal(x[..., start : start + 2])
      for start in range(num_components, 3 * num_components, 2)
  ]
  return tfd.Mixture(
      cat=tfd.Categorical(probs=weights), components=components
  )


def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
  """Shifts rows of seq based on the first 0 in each row of the mask."""
  num = seq.shape[-2]

  # Find the index of the first 0 in each row of the mask
  first_zero_idx = jnp.argmin(mask, axis=-1)

  # Create a range array for indexing
  idx_range = jnp.arange(num)

  def shift_row(carry, x):
    seq_row, shift = x
    shifted_idx = (idx_range - shift) % num
    shifted_row = seq_row[shifted_idx]
    return carry, shifted_row

  # Use lax.scan to shift each row of seq based on the corresponding
  # first_zero_idx.
  _, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))

  return shifted_seq


def _map_freq_tensor(tensor: JTensor) -> JTensor:
  """Maps frequency tensor to zeros, ones, twos."""
  conditions = [
      tensor < 6,  # Values 0-5 map to 0
      (tensor >= 6) & (tensor < 8),  # Values 6-7 map to 1
      tensor >= 8,  # Values 8-9 map to 2
  ]

  choices = [
      jnp.zeros_like(tensor),  # Replace with zeros
      jnp.ones_like(tensor),  # Replace with ones
      2 * jnp.ones_like(tensor),  # Replace with twos
  ]

  return jnp.select(conditions, choices)


def _mix_series(
    key: jax.Array, arr: JTensor, pad: JTensor
) -> Tuple[JTensor, JTensor]:
  """Mixes every series with another series drawn from the same batch.

  The batch is permuted along axis 0 and each series is replaced by a convex
  combination of itself and its permuted partner. Positions that are padded in
  either series stay padded in the result.

  Args:
    key: PRNG key. It is split so that the mixing coefficient and the
      permutation are drawn from independent sub-keys.
    arr: series values; entries equal to PAD_VAL (within _TOLERANCE) are
      treated as padded.
    pad: padding indicators for `arr`; entries equal to 1 (within _TOLERANCE)
      are treated as padded.

  Returns:
    A tuple `(new_arr, new_pad)` holding the mixed series and the merged
    padding.
  """
  # Never reuse a key for unrelated draws: derive one sub-key for the mixing
  # coefficient and another one for the permutation.
  alpha_key, perm_key = jax.random.split(key)
  alpha = jax.random.uniform(alpha_key, shape=())
  # The permutation key is deliberately shared so that `arr` and `pad` are
  # shuffled with the same row permutation and stay aligned.
  perm_arr = jax.random.permutation(perm_key, arr, axis=0)
  perm_pad = jax.random.permutation(perm_key, pad, axis=0)

  # A position is padded if it is padded in the original or in the partner.
  new_pad = jnp.where(jnp.abs(pad - 1) < _TOLERANCE, 1, perm_pad)

  new_arr = perm_arr * alpha + (1 - alpha) * arr
  # Restore the pad sentinel wherever either of the two inputs carried it.
  new_arr = jnp.where(jnp.abs(arr - PAD_VAL) < _TOLERANCE, PAD_VAL, new_arr)
  new_arr = jnp.where(
      jnp.abs(perm_arr - PAD_VAL) < _TOLERANCE, PAD_VAL, new_arr
  )
  return new_arr, new_pad


class ResidualBlock(base_layer.BaseLayer):
  """Two-layer feedforward block with a learned linear skip connection.

  Attributes:
    input_dims: input dimension.
    hidden_dims: hidden dimension.
    output_dims: output dimension.
    dropout_prob: dropout probability applied to the feedforward branch.
    layer_norm: whether to apply layer norm to the block output.
    dropout_tpl: config for dropout.
    ln_tpl: config for layer norm.
    act_tpl: config for activation in hidden layer.
  """

  input_dims: int = 0
  hidden_dims: int = 0
  output_dims: int = 0
  dropout_prob: float = 0.0
  layer_norm: bool = False
  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
  act_tpl: LayerTpl = template_field(activations.Swish)

  def setup(self):
    """Creates the norm, dropout, hidden, output and residual child layers."""

    def _dense(in_dims, out_dims, act_cfg):
      # Config for a single feedforward layer with the given activation.
      return pax_fiddle.Config(
          linears.FeedForward,
          input_dims=in_dims,
          output_dims=out_dims,
          activation_tpl=act_cfg,
      )

    # The layer norm child is always created; it is only applied when
    # `layer_norm` is set.
    norm_cfg = self.ln_tpl.clone()
    norm_cfg.dim = self.output_dims
    self.create_child('ln_layer', norm_cfg)

    drop_cfg = self.dropout_tpl.clone()
    drop_cfg.keep_prob = 1.0 - self.dropout_prob
    self.create_child('dropout', drop_cfg)

    self.create_child(
        'hidden_layer',
        _dense(self.input_dims, self.hidden_dims, self.act_tpl.clone()),
    )
    self.create_child(
        'output_layer',
        _dense(
            self.hidden_dims,
            self.output_dims,
            pax_fiddle.Config(activations.Identity),
        ),
    )
    self.create_child(
        'residual_layer',
        _dense(
            self.input_dims,
            self.output_dims,
            pax_fiddle.Config(activations.Identity),
        ),
    )

  def __call__(self, inputs: JTensor) -> JTensor:
    """Applies the block.

    Args:
      inputs: input tensor whose last dimension is `input_dims`.

    Returns:
      The sum of the feedforward branch (after dropout) and the linear skip
      branch, passed through layer norm when `layer_norm` is set.
    """
    transformed = self.dropout(self.output_layer(self.hidden_layer(inputs)))
    skip = self.residual_layer(inputs)
    combined = transformed + skip
    if not self.layer_norm:
      return combined
    return self.ln_layer(combined)


class MoEResidualBlock(base_layer.BaseLayer):
  """Mixture-of-experts feedforward block with a learned linear skip connection.

  Attributes:
    input_dims: input dimension.
    output_dims: output dimension.
    hidden_dims: hidden dimension of the MoE layer.
    dropout_prob: dropout probability applied to the MoE branch.
    layer_norm: whether to apply layer norm to the block output.
    dropout_tpl: config for dropout.
    ln_tpl: config for layer norm.
    moe_tpl: config for the MoE feedforward layer.
  """

  input_dims: int = 0
  output_dims: int = 0
  hidden_dims: int = 0
  dropout_prob: float = 0.0
  layer_norm: bool = False
  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
  moe_tpl: LayerTpl = template_field(transformers.TransformerFeedForwardMoe)

  def setup(self):
    """Creates the norm, dropout, MoE, output and residual child layers."""

    def _linear(in_dims, out_dims):
      # Config for a feedforward layer without non-linearity.
      return pax_fiddle.Config(
          linears.FeedForward,
          input_dims=in_dims,
          output_dims=out_dims,
          activation_tpl=pax_fiddle.Config(activations.Identity),
      )

    # The layer norm child is always created; it is only applied when
    # `layer_norm` is set.
    norm_cfg = self.ln_tpl.clone()
    norm_cfg.dim = self.output_dims
    self.create_child('ln_layer', norm_cfg)

    drop_cfg = self.dropout_tpl.clone()
    drop_cfg.keep_prob = 1.0 - self.dropout_prob
    self.create_child('dropout', drop_cfg)

    moe_cfg = self.moe_tpl.clone()
    moe_cfg.hidden_dims = self.hidden_dims
    moe_cfg.input_dims = self.input_dims
    moe_cfg.norm_policy = 'post'
    self.create_child('moe_layer', moe_cfg)

    # The output projection is configured with `input_dims` inputs, i.e. it
    # consumes the MoE layer output.
    self.create_child(
        'output_layer', _linear(self.input_dims, self.output_dims)
    )
    self.create_child(
        'residual_layer', _linear(self.input_dims, self.output_dims)
    )

  def __call__(self, inputs: JTensor) -> JTensor:
    """Applies the block.

    Args:
      inputs: input tensor whose last dimension is `input_dims`.

    Returns:
      The sum of the MoE branch (after projection and dropout) and the linear
      skip branch, passed through layer norm when `layer_norm` is set.
    """
    projected = self.dropout(self.output_layer(self.moe_layer(inputs)))
    skip = self.residual_layer(inputs)
    combined = projected + skip
    if not self.layer_norm:
      return combined
    return self.ln_layer(combined)


def _masked_mean_std(
    inputs: JTensor, padding: JTensor
) -> Tuple[JTensor, JTensor]:
  """Calculates mean and standard deviation of arr across axis 1.

  It should exclude values where pad is 1.

  Args:
    inputs: A JAX array of shape [..., n, p].
    padding: A JAX array of shape [..., n, p] with values 0 or 1.

  Returns:
    A tuple containing the mean and standard deviation of arr.
    Mean and std are of shape [...] (i.e., we collapse the n and p dimensions).
    We return the statistics of the first patch with more than three non-padded 
    values.
  """
  # Selecting the first pad with more than 3 unpadded values.
  pad_sum = jnp.sum(1 - padding, axis=-1)

  def _get_patch_index(arr: JTensor):
    indices = jnp.argmax(arr >= 3, axis=-1)
    row_sum = (arr >= 3).sum(axis=-1)
    return jnp.where(row_sum == 0, arr.shape[-1] - 1, indices)

  patch_indices = _get_patch_index(pad_sum)
  arr = jnp.take_along_axis(inputs,
                            patch_indices[..., None, None],
                            axis=-2)[..., 0, :]
  pad = jnp.take_along_axis(padding,
                            patch_indices[..., None, None],
                            axis=-2)[..., 0, :]

  # Create a mask where P is 0
  mask = 1 - pad

  # Calculate the number of valid elements
  num_valid_elements = jnp.sum(mask, axis=-1)

  num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)

  # Calculate the masked sum and squared sum of M
  masked_sum = jnp.sum(arr * mask, axis=-1)
  masked_squared_sum = jnp.sum((arr * mask) ** 2, axis=-1)

  # Calculate the masked mean and standard deviation
  masked_mean = masked_sum / num_valid_elements
  masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
  masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)
  masked_std = jnp.sqrt(masked_var)

  return masked_mean, masked_std


def _create_quantiles() -> list[float]:
  return [0.1, 0.25, 0.5, 0.75, 0.9]


class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
  """Patch decoder layer for time-series foundation model.

  Attributes:
    patch_len: length of input patches.
    horizon_len: length of output patches.
    model_dims: model dimension of stacked transformer layer.
    hidden_dims: hidden dimensions in fully connected layers.
    revin: type of reversible normalization
      -- see: https://openreview.net/forum?id=cGDAkQo1C0p for details.
    probabilistic: prob. model or not.
    use_pos_emb: use (absolute) positional embeddings or not.
    num_components: number of normal components in mixture loss function.
    quantiles: list of quantiles for non prob model.
    use_date_features: use date features or not.
    use_freq: use freq. as a feature in the model.
    use_prev_chunk_vals: whether prev chunk values should be fed to transformer
      or not (in the case of chunked data, otherwise ignored)
    use_prev_chunk_separators: whether prev chunk separators should be fed to
      transformer or not (in the case of chunked data, otherwise ignored)
    attend_to_back_padded: whether to attend to back padded patches or not.
    residual_block_tpl: config for residual block.
    moe_residual_block_tpl: config for MoE residual block.
    use_moe: if true then MoE layer is used for input and output.
    stacked_transformer_params_tpl: config for stacked transformer.
    shift_pos_emb_with_pad: shift positional embedding to start at the first
      patch that is not padded.
    repad_front_zeroes: whether to repad the front zeroes after the forward
      transform.

  In all of what follows, unless specified otherwise, B is batch size,
  C is the number of chunks in each example, T is
  sequence length of time-series. N is the number of input patches that can be
  obtained from T. P is the input patch length and H is the horizon length. Q is
  number of output logits. D is model dimension.
  """

  patch_len: int = 0
  horizon_len: int = 0
  model_dims: int = 0
  hidden_dims: int = 0
  revin: definitions.RevinType = definitions.RevinType.NORMED
  probabilistic: bool = True
  use_pos_emb: bool = True
  num_components: int = 2
  quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)
  use_date_features: bool = True
  use_freq: bool = False
  use_prev_chunk_vals: bool = False
  use_prev_chunk_separators: bool = False
  attend_to_back_padded: bool = True
  # NOTE(review): this defaults to MoEResidualBlock, exactly like
  # `moe_residual_block_tpl` below, so `use_moe=False` still builds MoE blocks
  # unless the template is overridden — confirm whether ResidualBlock was
  # intended as the default here.
  residual_block_tpl: Optional[LayerTpl] = template_field(MoEResidualBlock)
  moe_residual_block_tpl: Optional[LayerTpl] = template_field(MoEResidualBlock)
  use_moe: bool = True
  stacked_transformer_params_tpl: LayerTpl = template_field(
      transformers.StackedTransformer
  )
  shift_pos_emb_with_pad: bool = True
  repad_front_zeroes: bool = False

  def setup(self) -> None:
    """Constructs the child layers of the decoder.

    Creates, in this order, the stacked transformer, the input residual block,
    the horizon (output) residual block and, depending on the configuration,
    the positional, frequency and chunk-separator embeddings.

    Raises:
      ValueError: if `revin` is RENORMED while `probabilistic` is set, or if
        the residual block template selected by `use_moe` is None.
    """
    if self.revin == definitions.RevinType.RENORMED and self.probabilistic:
      raise ValueError(
          'Probabilistic models are not supported under "RENORMED" revin'
      )
    # A mixture needs three logits per component (weight, location, scale);
    # otherwise the head emits one point forecast plus one value per quantile.
    if self.probabilistic:
      num_outputs = 3 * self.num_components
    else:
      num_outputs = len(self.quantiles) + 1

    stl = self.stacked_transformer_params_tpl.clone()
    stl.model_dims = self.model_dims
    stl.hidden_dims = self.hidden_dims
    stl.mask_self_attention = True

    # Forward the mesh configuration only when a three-axis mesh is set up.
    has_three_axis_mesh = (
        self.mesh_axis_names is not None
        and len(self.mesh_axis_names) == 3
        and self.ici_mesh_shape is not None
        and len(self.ici_mesh_shape) == 3
    )
    if has_three_axis_mesh:
      stl.ici_mesh_shape = self.ici_mesh_shape
      stl.dcn_mesh_shape = self.dcn_mesh_shape
      stl.mesh_axis_names = self.mesh_axis_names

    self.create_child('stacked_transformer_layer', stl)

    # Both the input and the horizon block are built from the same template,
    # selected once by `use_moe`.
    if self.use_moe:
      block_tpl = self.moe_residual_block_tpl
      if block_tpl is None:
        raise ValueError(
            '`moe_residual_block_tpl` must be set when `use_moe` is True.'
        )
    else:
      block_tpl = self.residual_block_tpl
      if block_tpl is None:
        raise ValueError(
            '`residual_block_tpl` must be set when `use_moe` is False.'
        )

    num_date_features = (
        len(DATE_FEATURE_DICT) if self.use_date_features else 0
    )

    # Every input patch carries its values, its padding indicators and,
    # optionally, the date features of each of its time steps.
    input_resl = block_tpl.clone()
    input_resl.input_dims = self.patch_len * (num_date_features + 2)
    input_resl.hidden_dims = self.hidden_dims
    input_resl.output_dims = self.model_dims
    self.create_child('input_ff_layer', input_resl)

    # The horizon block consumes the transformer output, optionally extended
    # by one set of date features per patch.
    horizon_resl = block_tpl.clone()
    horizon_resl.input_dims = self.model_dims + num_date_features
    horizon_resl.hidden_dims = self.hidden_dims
    horizon_resl.output_dims = self.horizon_len * num_outputs
    self.create_child('horizon_ff_layer', horizon_resl)

    if self.use_pos_emb:
      self.create_child(
          'position_emb',
          pax_fiddle.Config(
              layers.PositionalEmbedding, embedding_dims=self.model_dims
          ),
      )
    if self.use_freq:
      # Three classes, matching the buckets produced by `_map_freq_tensor`.
      self.create_child(
          'freq_emb',
          pax_fiddle.Config(
              embedding_softmax.Embedding,
              num_classes=3,
              input_dims=self.model_dims,
          ),
      )
    if self.use_prev_chunk_separators:
      self.create_child(
          'chunk_separator_emb',
          pax_fiddle.Config(
              embedding_softmax.Embedding,
              num_classes=1,
              input_dims=self.model_dims,
          ),
      )

  def _add_separator(self,
                     inputs: JTensor,
                     id: float | None=None,
                     val: float | None=None) -> JTensor:
    """Returns the inputs with a separator at the end of each time-series.

    The separator is either the chunk separator embedding looked up for `id`
    or the constant `val`, broadcast to a single position along axis -2.
    `id` takes precedence when both are given.

    Args:
      inputs: input tensor of shape [..., N, D].
      id: the id of the separator embedding to look up. The name shadows the
        builtin but is kept because callers pass it by keyword.
      val: constant fill value for the separator, used when `id` is None.

    Returns:
      tensor of shape [..., N+1, D].

    Raises:
      ValueError: if neither `id` nor `val` is provided.
    """
    if id is not None:
      sep_value = self.chunk_separator_emb(id)
    elif val is not None:
      sep_value = val
    else:
      raise ValueError('Either `id` or `val` must be provided.')
    sep_shape = (*inputs.shape[:-2], 1, inputs.shape[-1])
    sep = jnp.broadcast_to(sep_value, sep_shape)
    # `jnp.concatenate` matches the rest of this file and is available in all
    # JAX versions, unlike the newer `jnp.concat` alias.
    return jnp.concatenate([inputs, sep], axis=-2)

  def transform_decode_state(
      self, transform_fn: base_layer.DecodeStateTransformFn
  ) -> None:
    """Applies `transform_fn` to the decode state of the transformer stack.

    Args:
      transform_fn: transformation forwarded to the stacked transformer layer.
    """
    transformer_stack = self.stacked_transformer_layer
    transformer_stack.transform_decode_state(transform_fn)

  def _forward_transform(
      self, inputs: JTensor, patched_pads: JTensor
  ) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:
    """Standardizes patched inputs with per-series statistics.

    Args:
      inputs: patched series of shape [..., N, P].
      patched_pads: padding indicators with the same shape as `inputs`.

    Returns:
      A tuple `(outputs, (mu, sigma))` where `outputs` has the shape of
      `inputs` and `mu` / `sigma` are the statistics used for the
      normalization. Entries equal to PAD_VAL are left at PAD_VAL.
    """
    mean, std = _masked_mean_std(inputs, patched_pads)
    # A (near) zero deviation would blow up the division; use 1 instead.
    std = jnp.where(std < _TOLERANCE, 1.0, std)
    centered = inputs - mean[..., None, None]
    scaled = centered / std[..., None, None]
    is_pad_value = jnp.abs(inputs - PAD_VAL) < _TOLERANCE
    normalized = jnp.where(is_pad_value, PAD_VAL, scaled)
    return normalized, (mean, std)

  def _reverse_transform(
      self, outputs: JTensor, stats: Tuple[JTensor, JTensor]
  ) -> JTensor:
    """Undoes the standardization applied by the forward transform.

    Args:
      outputs: model outputs; documented upstream as shape [B, C, N, P, Q].
      stats: tuple `(mu, sigma)`; each is broadcast over the last three axes
        of `outputs`.

    Returns:
      `outputs * sigma + mu` with the shape of `outputs`.
    """
    mean, std = stats
    trailing = (..., None, None, None)
    return outputs * std[trailing] + mean[trailing]

  def _preprocess_input(
      self,
      input_ts: JTensor,
      input_padding: JTensor,
      date_features: JTensor,
      freq: JTensor | None = None,
      pos_emb: Optional[JTensor] = None,
      n_chunks: int = 1,
      use_prev_chunk_vals: bool = False,
      use_prev_chunk_separators: bool = False,
      attend_to_back_padded: bool = True,
  ) -> Tuple[
      JTensor,
      JTensor,
      JTensor,
      Optional[JTensor],
      Optional[Tuple[JTensor, JTensor]],
      JTensor,
  ]:
    """Preprocess input for stacked transformer.

    Patches the series, optionally normalizes it, embeds the patches with the
    input residual block, arranges chunks for the transformer and adds the
    positional / frequency embeddings.

    Args:
      input_ts: series values; the last axis is split into patches of length
        `patch_len`.
      input_padding: padding indicators laid out like `input_ts`; entries equal
        to 1 mark padded positions.
      date_features: date features with a time axis at -2 and a feature axis at
        -1. It is read even when `use_date_features` is False.
      freq: frequency ids; only used (and then required) when `use_freq` is
        set.
      pos_emb: optional positional embedding used instead of the one produced
        by the `position_emb` child.
      n_chunks: number of chunks per example.
      use_prev_chunk_vals: whether patches of other chunks are fed to the
        transformer together with the current chunk.
      use_prev_chunk_separators: whether a separator position is appended to
        every chunk (only when `n_chunks > 1`).
      attend_to_back_padded: if False, patches whose last position is padded
        are masked out in the segment mask.

    Returns:
      A tuple of:
        - model_input: embedded patches fed to the transformer.
        - patched_padding: per-patch padding, the minimum over the positions of
          each patch.
        - patched_diff_features: differenced date features of the last time
          step of every patch.
        - segment_mask: additive attention mask, or None when chunks are not
          combined.
        - stats: `(mu, sigma)` of the normalization, or None when `revin` is
          neither NORMED nor RENORMED.
        - patched_inputs: the (possibly normalized) patched series.
    """
    # Reshape into patches.
    stats = None
    patched_inputs = es.jax_einshape('...(np)->...np',
                                     input_ts,
                                     p=self.patch_len)
    patched_pads = es.jax_einshape(
        '...(np)->...np',
        input_padding,
        p=self.patch_len
    )
    # Zero out the values at padded positions.
    patched_inputs = jnp.where(
        jnp.abs(patched_pads - 1.0) < _TOLERANCE, 0.0, patched_inputs
    )
    # Positions that carry the pad sentinel are marked as padded as well.
    patched_pads = jnp.where(
        jnp.abs(patched_inputs - PAD_VAL) < _TOLERANCE, 1, patched_pads
    )

    # First differences of the date features along time; a zero row is
    # prepended so that the time axis keeps its length.
    diff_features = jnp.diff(date_features, axis=-2)
    diff_features = jnp.concatenate(
        [
            jnp.zeros(
                (*date_features.shape[:-2], 1, date_features.shape[-1])
            ),
            diff_features,
        ],
        axis=-2,
    )
    # Padded time steps get -1. After flattening each patch, only the trailing
    # F entries are kept, i.e. the features of the last time step of the patch.
    patched_diff_features = es.jax_einshape(
        '...(np)f->...n(pf)',
        jnp.where(input_padding[..., jnp.newaxis] == 1.0, -1.0, diff_features),
        p=self.patch_len,
    )[..., -date_features.shape[-1]:]

    # Raw date features of all time steps of a patch, flattened per patch.
    patched_features = es.jax_einshape(
        '...(np)f->...n(pf)',
        jnp.where(input_padding[..., jnp.newaxis] == 1.0, -1.0, date_features),
        p=self.patch_len,
    )
    if self.revin in [
        definitions.RevinType.NORMED,
        definitions.RevinType.RENORMED,
    ]:
      patched_inputs, stats = self._forward_transform(
          patched_inputs, patched_pads
      )

    # Initial open model trained without the following line.
    if self.repad_front_zeroes:
      patched_inputs = patched_inputs * (1.0 - patched_pads)

    if self.use_date_features:
      concat_inputs = jnp.concatenate(
          [patched_inputs, patched_pads, patched_features], axis=-1
      )
    else:
      concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)

    # The segment mask is only needed when several chunks share one
    # transformer sequence. Here 1 means "do not attend" and 0 means "attend";
    # the mask is scaled to a large negative number at the end.
    segment_mask = None
    if n_chunks > 1 and (use_prev_chunk_separators or use_prev_chunk_vals):
      # This unpacking requires 4-dimensional patched inputs.
      batch_size, _, n_patches, _ = patched_inputs.shape
      if use_prev_chunk_separators and not use_prev_chunk_vals:
        # Attend to the separator patches, but nothing else
        segment_mask = jnp.ones((
            batch_size,
            n_chunks,
            n_chunks,
            n_patches + 1,
        ))
        segment_mask = segment_mask.at[..., -1].set(0)
      elif not use_prev_chunk_separators and use_prev_chunk_vals:
        # Attend to everything (no separators are created)
        segment_mask = jnp.zeros((
            batch_size,
            n_chunks,
            n_chunks,
            n_patches,
        ))
      else:
        # Attend to everything, since use_prev_chunk_separators
        # and use_prev_chunk_vals are true
        segment_mask = jnp.zeros((
            batch_size,
            n_chunks,
            n_chunks,
            n_patches+1,
        ))
      if not attend_to_back_padded:
        # Ignore the patches from previous chunks which were back-padded
        # Recall patched_pads of shape [B, C, N, P]
        # So if patched_pads[..., -1] == 1 then the patch is back-padded
        # Here back_padded is of shape [B, C, N(+1)]
        back_padded = jnp.concatenate(
            [
                patched_pads[..., -1],
                jnp.zeros(
                    (batch_size, n_chunks, int(use_prev_chunk_separators))
                ),
            ],
            axis=-1,
        )
        # Attend to patches that are not back-padded
        # So if a patch is masked already, OR if it is back-padded, then
        # we set the mask to 1.
        # Note that the causal mask is already applied in the transformer,
        # and these masks will be merged
        segment_mask = jnp.maximum(
            segment_mask, back_padded[:, jnp.newaxis, :, :]
        )
      # Ensure we always attend to everything in the current patch
      segment_mask = segment_mask.at[
          :, jnp.arange(n_chunks), jnp.arange(n_chunks), :
      ].set(0)
      # Repeat the mask for every position of a chunk, then flatten the
      # (chunk, patch) pairs on both attention axes.
      segment_mask = jnp.repeat(
          segment_mask[:, :, jnp.newaxis, ...],
          repeats=n_patches + use_prev_chunk_separators,
          axis=2,
      )
      segment_mask = es.jax_einshape('bcncn->b(cn)(cn)', segment_mask)
      segment_mask = segment_mask[:, jnp.newaxis, ...]
      # Since the mask needs to be ready to pass to logits
      segment_mask *= py_utils.get_large_negative_number(jnp.float32)

    model_input = self.input_ff_layer(concat_inputs)
    if use_prev_chunk_vals or use_prev_chunk_separators:
      if n_chunks > 1 and use_prev_chunk_separators:
        model_input = self._add_separator(model_input,
                                          id=INPUT_CHUNK_EMBEDDING_ID)
        # NOTE: We set the pad separator to 0 so that the transformer
        # will attend to the separator patches.
        patched_pads = self._add_separator(patched_pads,
                                           val=CHUNK_SEPARATOR_VAL_PADS)
        if self.use_date_features:
          # NOTE: This is just a dummy separator for the date features
          # of the separator token. Ultimately this is just appended to the
          # transformer's output, which is then ignored since it's the
          # separator. So this will have no effect on the output.
          patched_diff_features = self._add_separator(
              patched_diff_features,
              val=CHUNK_SEPARATOR_VAL_DATE)
      # Merge the chunk and patch axes so that all chunks of an example form
      # one sequence.
      model_input = es.jax_einshape('bcn...->b(cn)...',
                                    model_input)
      patched_pads = es.jax_einshape('bcn...->b(cn)...', patched_pads)
      # patched_features = es.jax_einshape('bcn...->b(cn)...',
      #                                             patched_features)
      patched_diff_features = es.jax_einshape(
          'bcn...->b(cn)...', patched_diff_features
      )
    elif n_chunks > 1:
      # In this case, we don't want to feed previous chunks to the transformer.
      # Thus, we reshape things so that each chunk is fed separately as an
      # example.
      model_input = es.jax_einshape('bcn...->(bc)n...',
                                    model_input)
      patched_pads = es.jax_einshape('bcn...->(bc)n...', patched_pads)
      # patched_features = es.jax_einshape('bcn...->(bc)n...',
      #                                             patched_features)
      patched_diff_features = es.jax_einshape(
          'bcn...->(bc)n...', patched_diff_features
      )

    # A patch should not be padded even if there is at least one zero.
    patched_padding = jnp.min(patched_pads, axis=-1)
    if self.use_pos_emb:
      if pos_emb is None:
        position_emb = self.position_emb(seq_length=model_input.shape[-2])
      else:
        position_emb = pos_emb
      # The shift is only applied in eval mode.
      if self.shift_pos_emb_with_pad and self.do_eval:
        if position_emb.shape[0] != model_input.shape[0]:
          position_emb = jnp.broadcast_to(position_emb, model_input.shape)
        position_emb = _shift_padded_seq(patched_padding, position_emb)
      model_input += position_emb
    if self.use_freq:
      freq = _map_freq_tensor(freq)
      f_emb = self.freq_emb(freq)  # B (x C) x 1 x D
      n_per_chunk = patched_inputs.shape[-2]
      f_emb = jnp.repeat(f_emb, n_per_chunk, axis=-2) # B (x C) x N x D
      # Mirror the chunk layout applied to `model_input` above.
      if use_prev_chunk_vals or use_prev_chunk_separators:
        if n_chunks > 1 and use_prev_chunk_separators:
          f_emb = self._add_separator(f_emb,
                                      val=CHUNK_SEPARATOR_VAL_FREQ) # BxCxN+1xD
        f_emb = es.jax_einshape('bcn...->b(cn)...', f_emb) # BxC.(N or N+1)xD
      elif n_chunks > 1:
        f_emb = es.jax_einshape('bcn...->(bc)n...', f_emb) # B.CxNxD
      model_input += f_emb
    return (
        model_input,
        patched_padding,
        patched_diff_features,
        segment_mask,
        stats,
        patched_inputs,
    )

  def _postprocess_output(
      self,
      model_output: JTensor,
      num_outputs: int,
      stats: Tuple[JTensor, JTensor],
      n_chunks: int,
      use_prev_chunk_vals: bool,
      use_prev_chunk_separators: bool,
  ) -> JTensor:
    """Turns the transformer output into per-patch forecasts.

    Args:
      model_output: output of the stacked transformer (optionally extended by
        date features), with chunks laid out as produced by the preprocessing.
      num_outputs: number of output logits Q per forecast step.
      stats: `(mu, sigma)` statistics, only used when `revin` is RENORMED.
      n_chunks: number of chunks per example.
      use_prev_chunk_vals: whether chunks were merged into one sequence.
      use_prev_chunk_separators: whether separator positions were appended.

    Returns:
      Forecasts of shape B (x C) x N x H x Q; predictions at separator
      positions are dropped.
    """
    chunks_share_sequence = use_prev_chunk_vals or use_prev_chunk_separators
    if chunks_share_sequence:
      # Chunks were merged along the sequence axis: split them off again.
      model_output = es.jax_einshape(
          'b(cn)...->bcn...', model_output, c=n_chunks
      )
    elif n_chunks > 1:
      # Chunks were merged into the batch axis: split them off again.
      model_output = es.jax_einshape(
          '(bc)n...->bcn...', model_output, c=n_chunks
      )

    # Project every position to H * Q values, then separate horizon and logits.
    flat_forecast = self.horizon_ff_layer(model_output)
    output_ts = es.jax_einshape(
        '...n(hq)->...nhq', flat_forecast, q=num_outputs, h=self.horizon_len
    )
    if self.revin == definitions.RevinType.RENORMED:
      output_ts = self._reverse_transform(output_ts, stats)

    # A separator is only appended when there is more than one chunk.
    has_separator = use_prev_chunk_separators and n_chunks > 1
    if not has_separator:
      return output_ts
    # Drop the predictions made at the separator position of every chunk.
    return output_ts[..., :-1, :, :]

  def _quantile_loss(
      self, pred: JTensor, actual: JTensor, quantile: float
  ) -> JTensor:
    """Calculates the elementwise quantile (pinball) loss, scaled by two.

    Args:
      pred: B x C x T
      actual: B x C x T
      quantile: quantile at which loss is computed.

    Returns:
      Elementwise loss with the broadcast shape of `pred` and `actual`.
    """
    residual = actual - pred
    under_forecast = residual * quantile
    over_forecast = -residual * (1.0 - quantile)
    pinball = jnp.where(under_forecast >= 0, under_forecast, over_forecast)
    return 2 * pinball

  def _mse_loss(self, pred: JTensor, actual: JTensor) -> JTensor:
    """Returns the elementwise squared error between `pred` and `actual`."""
    error = pred - actual
    return jnp.square(error)

  def _mixture_loss(self, logits: JTensor, actual: JTensor) -> JTensor:
    """Returns the negative log-likelihood of `actual` under the mixture.

    Args:
      logits: mixture logits as expected by `get_normal_mixture`.
      actual: observed values.

    Returns:
      The negative log-probability of `actual`.
    """
    distribution = get_normal_mixture(logits, self.num_components)
    log_likelihood = distribution.log_prob(actual)
    return -log_likelihood

  def compute_loss(
      self,
      output_ts: JTensor,
      actual_ts: JTensor,
      stats: Optional[Tuple[JTensor, JTensor]] = None,
  ) -> JTensor:
    """Computes the weighted training loss.

    Probabilistic models use the mixture negative log-likelihood. Other models
    use the squared error of the first output plus one quantile loss per
    configured quantile, read from the following outputs.

    Args:
      output_ts: B (x C) x N x H x Q
      actual_ts: B (x C) x N x H
      stats: optional statistics tensor tuple; when given, every element is
        weighted by the squared second entry (the variance).

    Returns:
      Scalar loss: the weighted average over all targets that do not equal
      PAD_VAL.
    """
    if self.probabilistic:
      loss = self._mixture_loss(output_ts, actual_ts)
    else:
      loss = self._mse_loss(output_ts[..., 0], actual_ts)
      # Output 0 is the point forecast; outputs 1.. are the quantiles.
      for slot, level in enumerate(self.quantiles, start=1):
        loss = loss + self._quantile_loss(
            output_ts[..., slot], actual_ts, level
        )

    # Padded targets do not contribute to the loss.
    is_padded = jnp.abs(actual_ts - PAD_VAL) < _TOLERANCE
    weights = jnp.where(is_padded, 0, 1.0)
    if stats is not None:
      variance = stats[1] ** 2
      weights = weights * variance[..., None, None]

    weighted_total = jnp.sum(loss * weights)
    normalizer = jnp.sum(weights) + _TOLERANCE
    return weighted_total / normalizer

  def __call__(self, inputs: NestedMap) -> NestedMap:
    """Runs the patched decoder on a batch.

    Args:
      inputs: A NestedMap containing: the input sequence of shape [B(, C), T]
        where T must be a multiple of the patch length; the input padding map;
        the date features of shape B (x C) x T x F where F is the number of
        keys in DATE_FEATURE_DICT; and, when `use_freq` is set, the frequency.

    Returns:
      A nested map with three keys:
      (1) 'output_tokens': the transformer output, with the differenced date
          features appended on the last axis when `use_date_features` is set.
      (2) 'output_ts' of shape [B(, C), N, H, Q]
      (3) 'stats' a Tuple of statistics for renormalization.

    Raises:
      ValueError: if the input sequence is neither 2- nor 3-dimensional.
    """
    input_ts = inputs[_INPUT_TS]
    input_padding = inputs[_INPUT_PADDING]
    date_features = inputs[_DATE_FEATURES]
    freq = inputs[_FREQ].astype(jnp.int32) if self.use_freq else None

    rank = len(input_ts.shape)
    if rank == 3:
      # Chunked input: the chunk options of the layer apply.
      n_chunks = input_ts.shape[1]
      use_prev_chunk_vals = self.use_prev_chunk_vals
      use_prev_chunk_separators = self.use_prev_chunk_separators
    elif rank == 2:
      # Plain input: a single chunk, so the chunk options are ignored.
      n_chunks = 1
      use_prev_chunk_vals = False
      use_prev_chunk_separators = False
    else:
      raise ValueError(
          'Input_ts must be of shape [B, T] or [B, C, T].'
      )

    num_outputs = (
        3 * self.num_components
        if self.probabilistic
        else len(self.quantiles) + 1
    )

    (
        model_input,
        patched_padding,
        patched_diff_features,
        segment_mask,
        stats,
        _,
    ) = self._preprocess_input(
        input_ts=input_ts,
        input_padding=input_padding,
        date_features=date_features,
        freq=freq,
        n_chunks=n_chunks,
        use_prev_chunk_vals=use_prev_chunk_vals,
        use_prev_chunk_separators=use_prev_chunk_separators,
        attend_to_back_padded=self.attend_to_back_padded,
    )

    model_output = self.stacked_transformer_layer(
        model_input, patched_padding, segment_mask=segment_mask
    )
    if self.use_date_features:
      # The horizon block also sees the differenced date features.
      model_output = jnp.concatenate(
          [model_output, patched_diff_features], axis=-1
      )

    output_ts = self._postprocess_output(
        model_output,
        num_outputs,
        stats,
        n_chunks,
        use_prev_chunk_vals,
        use_prev_chunk_separators,
    )

    return NestedMap(
        {_OUTPUT_TOKENS: model_output, _OUTPUT_TS: output_ts, _STATS: stats}
    )

  def decode(
      self,
      inputs: NestedMap,
      horizon_len: int,
      output_patch_len: Optional[int] = None,
      max_len: int = 512,
      return_forecast_on_context: bool = False,
  ) -> tuple[JTensor, JTensor]:
    """Auto-regressive decoding without caching.

    Every step re-runs the full model on the (at most `max_len` long) tail of
    the series decoded so far and appends the first `output_patch_len` values
    predicted by the last patch.

    Args:
      inputs: input time-series and paddings. Time-series shape B (x Ch) x C,
        padding shape B (x Ch) x (C + H) where H is the prediction length.
        freq shape B (x Ch) x 1; it is only read when `use_freq` is set,
        otherwise zeros are used.
        date_features shape B (x Ch) x (C + H) x F where F is the number of keys
        on DATE_FEATURE_DICT; it is only read when `use_date_features` is set,
        otherwise zeros are used.
      horizon_len: prediction length.
      output_patch_len: output length to be fetched from one step of
        auto-regressive decoding. Defaults to the layer's `horizon_len`.
      max_len: maximum training context length.
      return_forecast_on_context: whether to return the model forecast on the
        context except the first input patch.

    Returns:
      Tuple of two forecasting results:
      - Point (mean) output predictions as a tensor with shape B (x Ch) x H.
      - Full predictions (mean and quantiles) as a tensor with shape
        B (x Ch) x H x (1 + # quantiles).

    Raises:
      ValueError: if the model is probabilistic, or if the length of the
        paddings (or of the date features, when they are used) differs from
        the context length plus `horizon_len`.
    """
    if self.probabilistic:
      raise ValueError('Probabilistic decoding is not supported.')
    final_out = inputs[_INPUT_TS]
    paddings = inputs[_INPUT_PADDING]
    if self.use_freq:
      freq = inputs[_FREQ].astype(jnp.int32)
    else:
      freq = jnp.zeros([*final_out.shape[:-1], 1],
                       dtype=jnp.int32)
    full_outputs = []
    context_len = final_out.shape[-1]
    total_len = context_len + horizon_len
    if paddings.shape[-1] != total_len:
      raise ValueError(
          'Length of paddings must match length of input + horizon_len:'
          f' {paddings.shape[-1]} != {context_len} + {horizon_len}'
      )
    if self.use_date_features:
      # Read the date features only when the model consumes them, so callers
      # of a model without date features do not have to supply the key.
      date_features = inputs[_DATE_FEATURES]
      if date_features.shape[-2] != total_len:
        raise ValueError(
            'Length of date_features must match length of input + horizon_len'
        )
    else:
      # The forward pass always indexes date features, so feed zeros.
      date_features = jnp.zeros((
          *final_out.shape[:-1],
          total_len,
          len(DATE_FEATURE_DICT),
      ))
    if output_patch_len is None:
      output_patch_len = self.horizon_len
    # Ceiling division: number of steps needed to cover the horizon.
    num_decode_patches = (
        horizon_len + output_patch_len - 1
    ) // output_patch_len
    # In NORMED revin mode the model outputs need to be renormalized with the
    # statistics returned by the forward pass.
    renormalize = self.revin == definitions.RevinType.NORMED
    for step_index in range(num_decode_patches):
      # Recompute inputs based on outputs from last step,
      # since we are decoding one patch at a time.
      decoded_len = final_out.shape[-1]
      current_padding = paddings[..., 0:decoded_len]
      current_df = date_features[..., 0:decoded_len, :]
      model_input = NestedMap(
          input_ts=final_out[..., -max_len:],
          input_padding=current_padding[..., -max_len:],
          date_features=current_df[..., -max_len:, :],
          freq=freq,
      )
      model_output = self(model_input)
      fprop_outputs = model_output[_OUTPUT_TS]
      stats = model_output[_STATS]

      if return_forecast_on_context and step_index == 0:
        # For the first decoding step, collect the model forecast on the
        # context except the unavailable first input patch forecast.
        new_full_ts = fprop_outputs[..., :-1, : self.patch_len, :]
        new_full_ts = es.jax_einshape('...nph->...(np)h', new_full_ts)
        if renormalize:
          mean, stddev = stats
          new_full_ts = (
              new_full_ts * stddev[..., None, None] + mean[..., None, None]
          )
        full_outputs.append(new_full_ts)

      # (full batch, last patch, output_patch_len, index of mean forecast = 0)
      new_ts = fprop_outputs[..., -1, :output_patch_len, 0]
      # (full batch, last patch, output_patch_len, all output indices)
      new_full_ts = fprop_outputs[..., -1, :output_patch_len, :]
      if renormalize:
        mean, stddev = stats
        new_ts = new_ts * stddev[..., None] + mean[..., None]
        new_full_ts = (
            new_full_ts * stddev[..., None, None] + mean[..., None, None]
        )
      full_outputs.append(new_full_ts)
      # The mean forecast becomes part of the context of the next step.
      final_out = jnp.concatenate([final_out, new_ts], axis=-1)

    if return_forecast_on_context:
      # `full_outputs` indexing starts at after the first input patch.
      full_outputs = jnp.concatenate(full_outputs, axis=-2)[
          ..., : (context_len - self.patch_len + horizon_len), :
      ]
    else:
      # `full_outputs` indexing starts at the forecast horizon.
      full_outputs = jnp.concatenate(full_outputs,
                                     axis=-2)[..., 0:horizon_len, :]

    return (full_outputs[..., 0], full_outputs)


class PatchedTimeSeriesModel(base_model.BaseModel):
  """Pax ML training model wrapping the patched time-series decoder.

  Attributes:
    core_layer_tpl: config for PatchedTimeSeriesDecoder.
    last_patch_only: whether the horizon loss is computed over only the last
      patch instead of over all patches.
    mix_up_fraction: fraction of augmented mixed time-series to add per batch;
      only applied when it is positive and the model is not in eval mode.
    metrics: custom metrics logged during model training, keyed by name. Each
      callable is invoked with the decoded forecast and the last `horizon_len`
      values of the input series.
  """

  core_layer_tpl: pax_fiddle.Config[base_model.BaseModel] = (
      pax_fiddle.template_field(PatchedTimeSeriesDecoder)
  )
  last_patch_only: bool = False
  mix_up_fraction: float = 0.0
  metrics: dict[str, Callable[[JTensor, JTensor], JTensor]] | None = None

  def setup(self) -> None:
    """Creates the `core_layer` child from a copy of `core_layer_tpl`.

    The model's own mesh and sharding settings are copied onto the cloned
    template before the child is created.
    """
    tpl = self.core_layer_tpl.clone()
    for field_name in (
        'ici_mesh_shape',
        'dcn_mesh_shape',
        'mesh_axis_names',
        'weight_split_dims_mapping',
        'activation_split_dims_mapping',
    ):
      setattr(tpl, field_name, getattr(self, field_name))
    self.create_child('core_layer', tpl)

  def compute_predictions(self, input_batch: NestedMap) -> NestedMap:
    """Runs the core layer on the (optionally mix-up augmented) batch.

    Args:
      input_batch: Nestedmap with key `input_ts` mapping to input tensor of
        shape B (x C) x T. Also has key `date_features` of shape B(x C) x T x F,
        the input padding and the frequency entry.

    Returns:
      Output from ptsd model, with the batch that was actually fed to it
      stored under the augmented-batch key.

    Raises:
      ValueError: if the frequency entry is missing from `input_batch`.
    """
    series = input_batch[_INPUT_TS]
    padding = input_batch[_INPUT_PADDING]
    date_feats = input_batch[_DATE_FEATURES]
    if _FREQ not in input_batch:
      raise ValueError(f'Model input does not have freq: {input_batch}')
    freq = input_batch[_FREQ]

    apply_mix_up = self.mix_up_fraction > 0.0 and not self.do_eval
    if apply_mix_up:
      batch_size = series.shape[0]
      # Number of mixed series to append, rounded down to a multiple of 8.
      num_mixed = int(self.mix_up_fraction * batch_size / 8) * 8
      mixed_series, mixed_padding = _mix_series(
          self.next_prng_key(), series, padding
      )
      head = slice(0, num_mixed)
      # Mixed series reuse the date features and freq of the leading examples.
      series = jnp.concatenate([series, mixed_series[head, ...]], axis=0)
      padding = jnp.concatenate([padding, mixed_padding[head, ...]], axis=0)
      date_feats = jnp.concatenate(
          [date_feats, date_feats[head, ...]], axis=0
      )
      freq = jnp.concatenate([freq, freq[head, ...]], axis=0)

    augmented = NestedMap(
        input_ts=series,
        input_padding=padding,
        date_features=date_feats,
        freq=freq,
    )
    predictions = self.core_layer(augmented)
    predictions[_AUGMENTED_BATCH] = augmented
    return predictions

  def _convert_inputs(
      self, inputs: JTensor, stats: Optional[Tuple[JTensor, JTensor]] = None
  ) -> Tuple[JTensor, int]:
    """Slices the input series into the per-patch horizon actuals.

    Args:
      inputs: Tensor of shape B (x C) x T. With P the patch length and H the
        horizon length of the core layer, the windows inputs[..., P : P + H],
        inputs[..., 2 * P : 2 * P + H], ...,
        inputs[..., num * P : num * P + H] are extracted, where
        `num = (T - H) // P`, and arranged into shape B (x C) x num x H.
      stats: context statistics in case revin is active; `stats[0]` is
        subtracted from and `stats[1]` divides the non-padded inputs.

    Returns:
      Tuple of (output (B (x C) x num x H), num).
    """
    core = self.core_layer
    horizon_len = core.horizon_len
    patch_len = core.patch_len
    revin = core.revin
    num = (inputs.shape[-1] - horizon_len) // patch_len

    targets = inputs
    if revin == definitions.RevinType.NORMED:
      normed = (inputs - stats[0][..., None]) / stats[1][..., None]
      # Keep padded positions at PAD_VAL instead of normalizing them.
      is_pad = jnp.abs(inputs - PAD_VAL) < _TOLERANCE
      targets = jnp.where(is_pad, PAD_VAL, normed)

    leading_shape = targets.shape[:-1]
    zero_offsets = (0,) * len(leading_shape)
    window_shape = (*leading_shape, horizon_len)

    def take_window(offset):
      # Full extent on every leading axis, `horizon_len` steps from `offset`
      # on the time axis.
      return lax.dynamic_slice(
          targets,
          start_indices=zero_offsets + (offset,),
          slice_sizes=window_shape,
      )

    offsets = jnp.arange(1, num + 1) * patch_len
    windows = jax.vmap(take_window)(offsets)

    # vmap stacks the windows on axis 0; move that axis next to the horizon.
    return es.jax_einshape('n...h->...nh', windows), num

  def compute_loss(
      self, prediction_output: NestedMap, input_batch: NestedMap
  ) -> Tuple[NestedMap, NestedMap]:
    """Computes the training loss and optional metrics from the predictions.

    Args:
      prediction_output: output of `compute_predictions`; the augmented batch,
        the output time-series and the stats are read from it.
      input_batch: unused; the augmented batch stored in `prediction_output`
        is used instead.

    Returns:
      Tuple of (NestedMap mapping each loss/metric name to a
      (value, weight) pair, empty per-example NestedMap).
    """
    del input_batch
    core = self.core_layer
    series = prediction_output[_AUGMENTED_BATCH][_INPUT_TS]
    output_ts = prediction_output[_OUTPUT_TS]
    stats = prediction_output[_STATS]
    actual_ts, num = self._convert_inputs(series, stats)

    # In this revin mode, the loss is weighted by the variance. This makes
    # revin NORMED and RENORMED equivalent for non probabilistic model.
    stats_for_loss = (
        stats if core.revin == definitions.RevinType.NORMED else None
    )

    if self.last_patch_only:
      predicted = output_ts[..., num - 1 : num, :, :]
      target = actual_ts[..., -1:, :]
    else:
      predicted = output_ts[..., 0:num, :, :]
      target = actual_ts
    loss = core.compute_loss(predicted, target, stats_for_loss)

    loss_weight = jnp.array(1.0, dtype=jnp.float32)
    losses = {_TRAIN_LOSS: (loss, loss_weight)}

    if self.metrics is not None:
      horizon_len = core.horizon_len
      decoder_ts = core.decode_step_with_call_output(
          prediction_output,
          ts_idx=0,
          output_patch_len=horizon_len,
          patch_idx=horizon_len // core.patch_len,
      )
      # Metrics compare the decoded forecast with the tail of the inputs.
      actual_horizon = series[..., (-horizon_len):]
      for metric_name, metric_fn in self.metrics.items():
        losses[metric_name] = (
            metric_fn(decoder_ts, actual_horizon),
            loss_weight,
        )
    return NestedMap(losses), NestedMap()