"""Shared layers for the forward-only model."""
from functools import partial
from typing import Callable, Union, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp

def torch_he_uniform(
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    dtype=jnp.float_,
):
    return jax.nn.initializers.variance_scaling(
        0.3333,
        "fan_in",
        "uniform",
        in_axis=in_axis,
        out_axis=out_axis,
        batch_axis=batch_axis,
        dtype=dtype
    )


@partial(jax.jit, static_argnames=('k', 'pad'))
def k_folding(x: jax.Array, p0: jax.Array, k: int, pad: bool = True) \
        -> jax.Array:
    r"""Perform a k-folding operation on the input array.

    Applies :math:`k` equally spaced folds to the input array along the
    last dimension to produce an array of shape :math:`(N, *, k)`.

    The output is a :math:`k`-dimensional vector where the :math:`i`-th component
    is the dot product between the activations and the :math:`i`-th weight vector.

    Args:
        x: The input array. An array of shape :math:`(N, *, H_{in})` where
            :math:`*` means any number of dimensions including none and
            :math:`H_{in}` is the number of input features. Note that the last
            dimension must be divisible by :math:`k`.
        p0: The activations. An array of shape where all but the last dimension
            are the same shape as ``x``. The last dimension should be the size
            of a single fold, i.e. :math:`H_{in} / k`.
        k: The number of folds to apply to the input array. Must be positive.
        pad: If ``True``, the input array is padded with zeros along the last
            dimension to ensure that it is divisible by ``k``. The activations
            ``p0`` are also padded with zeros to match the shape of each folded
            component of the input array. If ``False``, an error is raised if
            the last dimension of ``x`` is not divisible by ``k``.

    Returns:
        An array of shape :math:`(N, *, k)` where all but the last dimension are
        the same shape as ``x``.

    Remarks:
        The k-folding operation can be thought of a downsampling operation where
        the :math:`i`-th component of the output array is a linear combination
        of the activations with coefficients determined by the :math:`i`-th
        partition of the input array.

    Raises:
        ValueError: If ``k`` is not positive.
        ValueError: If the shape of ``x`` and ``p0`` are not the same except
            for the last dimension.
        ValueError: If the last dimension of ``x`` is not divisible by ``k``
            and ``pad`` is ``False``.

    Examples::

        >>> x = jax.random.normal(jax.random.PRNGKey(0), (128, 256))
        >>> p0 = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
        >>> y = k_folding(x, p0, 4)
        >>> y.shape
        (128, 4)
        >>> x = jnp.asarray([[1, 2, 3, 4, 5, 6, 7, 8]])
        >>> p0 = jnp.asarray([[-1, 2, 0.5, 0]])
        >>> y = k_folding(x, p0, 2)
        >>> y.tolist()
        [[4.5, 10.5]]
    """
    # Ensure that k is positive
    if k <= 0:
        raise ValueError(f'k must be positive, but got {k}.')

    # Ensure that x and p0 have the same shape except for the last dimension
    if x.shape[:-1] != p0.shape[:-1]:
        raise ValueError(
            f'Input array and activations must have the same shape except '
            f'for the last dimension, but got {x.shape} and {p0.shape}.'
        )

    # Ensure that the last dimension of x is divisible by k, padding if necessary
    h_in = x.shape[-1]
    if h_in % k != 0:
        if pad:
            x = jnp.pad(x, ((0, 0),) * (x.ndim - 1) + ((0, k - h_in % k),))
            h_in = x.shape[-1]
        else:
            raise ValueError(
                f'Input array must have a last dimension divisible by '
                f'{k}, but has shape {x.shape}. Please set pad=True to '
                f'automatically pad it with zeros.'
            )

    # Ensure that p0 has the correct number of features
    chunk_size = int(h_in // k)
    if p0.shape[-1] != chunk_size:
        if pad:
            p0 = jnp.pad(p0, ((0, 0),) * (p0.ndim - 1) + ((0, chunk_size - p0.shape[-1]),))
        else:
            raise ValueError(
                f'Activations must have {chunk_size} features, but got '
                f'{p0.shape[-1]} instead. Please set pad=True to '
                f'automatically pad it with zeros.'
            )

    # Reshape x to be a (k x chunk_size) matrix of weights
    x = x.reshape(x.shape[:-1] + (k, chunk_size))
    return jnp.matmul(x, jnp.expand_dims(p0, axis=-1)).squeeze(-1)


class FwdLinear(nn.Module):
    r"""Forward-only linear layer with an identity goodness function.

    Every forward-only linear layer has a forward pass that computes the
    activations and a goodness. The goodness is a vector value that is used to
    compute the loss for the goodness objective.

    Subclasses must implement the :meth:`compute_output` which compute the
    activations and the goodness.

    Args:
        features: the number of output features.
        goodness_features: number of goodness features. Default: ``1``
        use_bias: whether to add a bias term to the output. Default: ``True``
        act_fn: the activation function to use for the layer.
            Default: :func:`nn.relu`

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(N, *, H_{out})` where all but the last dimension are
          the same shape as the input and :math:`H_{out} = \text{out\_features}`.
        - Goodness: :math:`(N, *, G_{dim})` where all but the last dimension are
          the same shape as the input and :math:`G_{dim} = \text{goodness\_features}`.
    """

    features: int
    goodness_features: int
    use_bias: bool = True
    act_fn: Callable[[jax.Array], jax.Array] = nn.relu

    def setup(self) -> None:
        """Setup the layer."""
        assert self.features == self.goodness_features, \
            f'Number of features ({self.features}) must be equal to the ' \
            f'number of goodness features ({self.goodness_features}) since ' \
            f'the goodness is an identity function.'
        self.dense = nn.Dense(self.features, use_bias=self.use_bias)

    @nn.compact
    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Compute the forward pass.

        Args:
            x: The input tensor.

        Returns:
            The output tensor and the goodness tensor.
        """
        h = self.dense(x)
        return self.compute_output(x, h)

    def compute_output(self, x:jax.Array, h: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Compute the output of the layer.

        Args:
            h: Raw output of the linear layer.

        Returns:
            The output and goodness arrays. The activation function is only
            applied to the output array.
        """
        return self.act_fn(h), h


class FoldingFwdLinear(FwdLinear):
    r"""Applies a linear transformation with K-folding to the incoming data.

    Three folding modes are supported, which differ in how the layer is
    structured and how the goodness is computed:

    * Compressive folding: The layer has :math:`\text{out\_features}` units and
        is divided into :math:`\text{goodness_features} + 1` equal parts, such that
        each part is :math:`\text{out_features} / (\text{goodness_features} + 1)`
        units wide. The output of the layer is a concatenation of all the parts,
        i.e. the activations and the goodness weights.

    * Expansive folding: The layer is actually :math:`\text{goodness_features} + 1`
        times wider (larger than :math:`\text{out_features}`), and each part is
        :math:`\text{out_features}` units wide. The output of the layer is the
        first part, i.e. the activations.

    * Attention folding: The layer has :math:`\text{out\_features}` units and
        consists of two dense layers: the first for projecting the input into
        activations :math:`a` and the second for projecting the activations into
        the goodness weights :math:`w_1, \dots, w_{\text{goodness_features}}`.

    For compressive and expansive folding, the first part denotes the activations
    and the remaining parts denote the goodness weights, which is a matrix of
    shape :math:`(\text{goodness_features}, \text{out_features})`. In general,
    the goodness :math:`g= (g_1, \dots, g_{\text{goodness_features}})` is a
    linear combination of the goodness weights and the activations, given by:

    .. math::
        g_i = w_i \cdot a = \sum_{j=1}^{|w_i|} w_{ij} a_j,

    where :math:`w_i` is the :math:`i`-th row of the weight matrix and :math:`a`
    are the activations.

    Args:
        out_features: number of output features.
        goodness_features: number of goodness features. Default: ``1``
        folding_mode: The folding mode to use. Must be one of ``'compressive'``,
            ``'expansive'``, or ``'attention'``. Default: ``'compressive'``.
        normalize: Whether to normalize the output so that it has unit norm.
            This helps with training stability as the dot product tends to
            get very large as the number of features increases. Default: ``True``
        use_bias: whether to add a bias term to the output. Default: ``True``
        act_fn: The activation function to use for the layer.
            Default: :func:`nn.relu`
        weight_fn: The transformation function to use for the goodness weights.
            Default: :func:`nn.tanh`
        goodness_type: The type of goodness computation to use for 'msq' folding mode.
            Must be one of 'msq', 'mean', 'std', 'rms', 'variance', or 'weighted'.
            Default: ``'variance'``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(N, *, H_{out})` where all but the last dimension are
          the same shape as the input and :math:`H_{out} = \text{out\_features}`.
        - Goodness: :math:`(N, *, G_{dim})` where all but the last dimension are
          the same shape as the input and :math:`G_{dim} = \text{goodness\_features}`.

    Examples::

        >>> m = FoldingFwdLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output, goodness = m(input)
        >>> print(output.size(), goodness.size())
        torch.Size([128, 30]) torch.Size([128, 1])

    Remarks:
        This layer has :math:`\text{out\_features} \times (\text{goodness\_features} + 1)`
        actual output features.
    """

    folding_mode: str = 'compressive'
    normalize: bool = True
    normalization_method: str = 'layer_norm'
    weight_fn: Callable[[jax.Array], jax.Array] = nn.tanh
    goodness_type: str = 'variance'

    def setup(self) -> None:
        """Setup the layer."""
        if self.folding_mode == 'expansive':
            # Expansive folding increases the number of features
            self.y_features = self.features
            out_features = (self.goodness_features + 1) * self.features
        elif self.folding_mode == 'compressive':
            # Compressive folding keeps the number of features the same
            self.y_features = self.features // (self.goodness_features + 1)
            out_features = self.features
        elif self.folding_mode == 'attention':
            # Attention folding keeps the number of features the same
            self.y_features = self.features
            out_features = self.features
            # attn_dense projects the activations into a flattened goodness matrix
            self.attn_dense = nn.Dense(self.goodness_features * out_features, use_bias=self.use_bias)
        elif self.folding_mode == 'msq':
            # Attention folding keeps the number of features the same
            self.y_features = self.features
            out_features = self.features
            # attn_dense projects the activations into a flattened goodness matrix
            self.msq_features = out_features//20
            self.n_actions = self.goodness_features
            self.attn_dense = nn.Dense(self.msq_features * out_features, use_bias=self.use_bias)
        elif self.folding_mode == 'residual':
            # Residual folding keeps the number of features the same
            self.y_features = self.features
            out_features = self.features
            # attn_dense projects the activations into a flattened goodness matrix
            self.attn_dense = nn.Dense(self.goodness_features * out_features, use_bias=self.use_bias)
        else:
            raise ValueError(
                f'Invalid folding mode {self.folding_mode}. Must be one of '
                f'"compressive", "expansive", or "attention".'
            )

        self.dense = nn.Dense(out_features, use_bias=self.use_bias, kernel_init=torch_he_uniform())

        if self.normalize:
            if self.normalization_method == 'layer_norm':
                self.layer_norm = nn.LayerNorm(epsilon=1e-6)
            elif self.normalization_method == 'l1': pass
            elif self.normalization_method == 'l2': pass
            elif self.normalization_method == 'scaled': pass
            elif self.normalization_method == 'scaled_no_sqrt': pass
            elif self.normalization_method == 'scaled_no_sqrt_times_10': pass
            else:
                raise NotImplementedError("Only layer_norm and l2 normalization are supported.")

    def compute_output(self, x: jax.Array, h: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Computes the output."""

        if self.folding_mode == 'attention' or self.folding_mode == 'msq':
            x = x[None, :]  if x.ndim == 1 else x
            h = h[None, :]  if h.ndim == 1 else h
            sz_b = h.shape[0]

            # Attention folding
            y = self.act_fn(h)

            if self.folding_mode == 'msq':

                # INSERT ACTION AS INPUT (x)
                action_candidates = jnp.tile(jnp.arange(self.n_actions), (sz_b,))
                y_2 = jnp.expand_dims(y, axis=1)                                        # (B, 1, D)
                y_2 = jnp.tile(y_2, (1, self.n_actions, 1))                             # (B, n_actions, D)
                y_2 = y_2.reshape(-1, y_2.shape[-1])                                    # (B * n_actions, D)
                x = jnp.expand_dims(x, axis=1)                                          # (B, 1, D)
                x = jnp.tile(x, (1, self.n_actions, 1))                                 # (B, n_actions, D)
                x = x.reshape(-1, x.shape[-1])                                          # (B * n_actions, D)
                action_one_hot = jax.nn.one_hot(action_candidates, num_classes=self.n_actions).astype(jnp.float32)
                x = jnp.concatenate([x, action_one_hot], axis=-1)
                sz_b *= self.n_actions
                w = self.attn_dense(x)

            else:
                w = self.attn_dense(x)

            w = self.weight_fn(w)
        else:
            # Compressive and expansive folding
            chunks = jnp.split(h, [self.y_features], axis=-1)
            y = self.act_fn(chunks[0])
            w = self.weight_fn(chunks[1])

        if self.folding_mode == 'compressive':
            output = jnp.concatenate([y, w], axis=-1)
        else:
            output = y

        if self.normalization_method == 'layer_norm':
            output = self.layer_norm(output)
        elif self.normalization_method == 'l1':
            output = output / (jnp.linalg.norm(output, ord=1, axis=-1, keepdims=True) + 1e-4)
        elif self.normalization_method == 'l2':
            output = output / (jnp.linalg.norm(output, ord=2, axis=-1, keepdims=True) + 1e-4)
        elif self.normalization_method == 'scaled':
            output = output / output.shape[-1]**0.5
        elif self.normalization_method == 'scaled_no_sqrt':
            output = output / output.shape[-1]
        elif self.normalization_method == 'scaled_no_sqrt_times_10':
            output = output / (output.shape[-1] * 10)

        if self.folding_mode == 'msq':
            z = k_folding(w, y_2, self.msq_features)
            goodness = self._calc_logits(z).reshape(-1, self.n_actions)
        else:
            goodness = k_folding(w, y, self.goodness_features)
        return output, goodness

    def _calc_logits(self, z: jnp.ndarray, layer_idx: int = 0, update_baseline: bool = True) -> jnp.ndarray:
        if self.goodness_type == 'msq':
            logits = jnp.mean(jnp.square(z), axis=-1, keepdims=True)

        elif self.goodness_type == 'mean':
            logits = jnp.mean(z, axis=tuple(range(1, z.ndim)))
            logits = logits[:, None]

        elif self.goodness_type == 'std':
            logits = jnp.std(z, axis=tuple(range(1, z.ndim)))
            logits = logits[:, None]

        elif self.goodness_type == 'rms':
            logits = jnp.sqrt(jnp.mean(jnp.square(z), axis=tuple(range(1, z.ndim))))
            logits = logits[:, None]

        elif self.goodness_type == 'variance':
            logits = jnp.var(z, axis=tuple(range(1, z.ndim)))
            logits = logits[:, None]

        elif self.goodness_type == 'weighted':
            logits = self.goodness_weights[layer_idx](z)

        return logits


if __name__ == '__main__':
    # Benchmark a FoldingFwdLinear layer compared to a regular sequential
    # network with the same number of parameters.
    import time

    in_features = 128
    goodness_features = 10
    out_features = (goodness_features + 1) * 100

    folding_expansive = FoldingFwdLinear(
        out_features,
        goodness_features=goodness_features,
        folding_mode='expansive'
    )
    folding_compressive = FoldingFwdLinear(
        out_features,
        goodness_features=goodness_features,
        folding_mode='compressive'
    )

    linear_expansive = nn.Sequential([
        nn.Dense((goodness_features + 1) * out_features),
        nn.tanh,
        nn.Dense(out_features)
    ])
    linear_compressive = nn.Sequential([
        nn.Dense(out_features),
        nn.tanh
    ])

    # Generate some random data (num_batches, batch_size, in_features)
    key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(key3, (100, 32, in_features))

    # Benchmark the forward pass per model
    models = [
        (linear_compressive, 'nn.Sequential (compressive)'),
        (folding_compressive, 'FoldingFwdLinear (compressive)'),
        (linear_expansive, 'nn.Sequential (expansive)'),
        (folding_expansive, 'FoldingFwdLinear (expansive)')
    ]

    with jax.log_compiles(True):
        metrics = []
        for model, name in models:
            # Apply jit to the model
            model.apply = jax.jit(model.apply)

            dummy_x = jax.random.normal(key1, (in_features,))
            params = model.init(key2, dummy_x)
            num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))

            start = time.time()
            for batch in x:
                model.apply(params, batch)
            end = time.time()

            # Convert to milliseconds
            total_time = (end - start) * 1000

            print()
            print(f'{name}:')
            print(f'\t{num_params} parameters')
            print(f'\t{total_time:.4f}ms total')
            print()

            metrics.append({
                'num_params': num_params,
                'total_time': total_time
            })

    # Percent difference between compressive folding and nn.Sequential compressive equivalent
    p_diff_compressive = (metrics[1]['total_time'] - metrics[0]['total_time']) / metrics[0]['total_time']
    print(f'FoldingFwdLinear (compressive) vs nn.Sequential (compressive): {p_diff_compressive * 100:.2f}%')

    # Percent difference between expansive folding and nn.Sequential expansive equivalent
    p_diff_expansive = (metrics[3]['total_time'] - metrics[2]['total_time']) / metrics[2]['total_time']
    print(f'FoldingFwdLinear (expansive) vs nn.Sequential (expansive): {p_diff_expansive * 100:.2f}%')
