# Copyright 2024 The ALTA Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library for running inference."""

from collections.abc import Callable
import jax
import jax.numpy as jnp


def predict(
    params: list[tuple[jnp.ndarray, jnp.ndarray]],
    activations: jnp.ndarray,
    activation_function: Callable[[jnp.ndarray], jnp.ndarray],
):
  """Runs a feed-forward pass of `activations` through the layers in `params`.

  Every layer computes the affine map `jnp.dot(w, x) + b`. All layers except
  the last then apply `activation_function`; the last layer's affine output is
  returned as-is, with no nonlinearity.

  Args:
    params: Per-layer `(w, b)` pairs, in the order they are applied. Must be
      non-empty; an empty sequence raises `IndexError`.
    activations: Input fed to the first layer. NOTE(review): presumably a
      single (unbatched) example — confirm against callers.
    activation_function: Nonlinearity applied after each non-final layer.

  Returns:
    The raw (un-activated) output of the final layer, i.e. the logits.
  """

  def affine(layer, x):
    # One dense layer without its nonlinearity: w · x + b.
    weight, bias = layer
    return jnp.dot(weight, x) + bias

  x = activations
  for layer in params[:-1]:
    x = activation_function(affine(layer, x))
  # The output layer stays linear so callers receive unnormalized logits.
  return affine(params[-1], x)


# Vectorized form of `predict`: maps over axis 0 of `activations`, while
# `params` and `activation_function` are passed through unmapped and shared
# by every example. Per-example outputs are stacked along axis 0.
batched_predict = jax.vmap(
    predict,
    in_axes=(None, 0, None),  # (params, activations, activation_function)
    out_axes=0,
)
