# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prediction functions for PaliGemma."""

import collections
import functools

from big_vision.pp import registry
import big_vision.utils as u
import einops
import jax
import jax.numpy as jnp
import numpy as np


# Short alias; used below to build the partition specs of `NamedSharding`s.
P = jax.sharding.PartitionSpec

# pylint: disable=missing-function-docstring


def get_all(model):
  """Returns `predict_fns` for evaluators.

  Args:
    model: the model object; it is bound as the `model` keyword argument of
      every returned function.

  Returns:
    A dict mapping the predict-fn name to the corresponding function of this
    module with `model` already bound.
  """
  named_fns = (
      ("logits", _logits),
      ("image_avg_repr", _image_avg_repr),
      ("decode", _decode),
      ("decode_with_logp", _decode_with_logp),
      ("beam_decode", _beam_decode),
  )
  predict_fns = {}
  for name, fn in named_fns:
    predict_fns[name] = functools.partial(fn, model=model)
  return predict_fns


def _logits(train_state, batch, *, model):
  """Applies the model to the images and the text without its last position.

  Args:
    train_state: dict holding the model parameters under "params".
    batch: dict with "image", "text" and "mask_ar" entries.
    model: object whose `apply` is called with the parameters and inputs.

  Returns:
    The `(text_logits, out)` pair returned by `model.apply`.
  """
  variables = {"params": train_state["params"]}
  # The last position of both the text and its mask is dropped.
  text_in = batch["text"][:, :-1]
  mask_in = batch["mask_ar"][:, :-1]
  text_logits, out = model.apply(variables, batch["image"], text_in, mask_in)
  return text_logits, out


def _image_avg_repr(train_state, batch, *, model, key="img/pre_logits"):
  """Returns one averaged image representation per example.

  Args:
    train_state: dict holding the model parameters under "params".
    batch: dict with an "image" entry.
    model: object providing `apply` and `embed_image`.
    key: if truthy, the representation is looked up under this key in the
      auxiliary outputs instead of using the primary output of `embed_image`.

  Returns:
    `(zimg, out)`: the representation averaged over all axes except the first
    and the last one, and the auxiliary outputs of `embed_image`.
  """
  variables = {"params": train_state["params"]}
  embedded, out = model.apply(
      variables, image=batch["image"], method=model.embed_image)
  tokens = u.tree_get(out, key) if key else embedded
  # `tokens` is a (batch of) sequence of image tokens, because the model is
  # assumed to be a vit with "none" head. The fewshot evaluator needs a
  # reasonably-sized vector, hence the average over the middle axes.
  token_axes = range(1, tokens.ndim - 1)
  return jnp.mean(tokens, axis=token_axes), out


def _decode_with_logp(
    train_state, batch, *, model, devices, max_decode_len, eos_token,
    best_of_n=1, sampler="greedy", replicate_out=False, eos_look_behind=0):
  """Sample token continuations to the input sequences.

  Args:
    train_state: dict holding the model parameters under "params".
    batch: dict with "image", "text", "mask_input", "mask_ar" and "_mask".
      "_mask" marks real examples; it is False for examples that only pad the
      batch.
    model: model object forwarded to the cache prefill/extend functions.
    devices: devices used to build a mesh with a single "devices" axis.
    max_decode_len: maximum number of tokens sampled per sequence.
    eos_token: token id that terminates a sequence.
    best_of_n: number of samples drawn per example; the one with the highest
      score is returned.
    sampler: name of a sampler registered under "paligemma_sampler.".
    replicate_out: if True outputs are replicated, otherwise they are
      partitioned along the "devices" mesh axis.
    eos_look_behind: the loop inspects the stop flag computed this many steps
      earlier instead of the latest one. Setting it >0 removes blocking
      transfer with small batches.

  Returns:
    `(tokens, logp)` of the selected sample of every example.
  """
  params = train_state["params"]

  mesh = jax.sharding.Mesh(devices, ("devices",))
  replicated = jax.sharding.NamedSharding(mesh, P())
  out_spec = P() if replicate_out else P("devices")
  result_sharding = jax.sharding.NamedSharding(mesh, out_spec)

  # Jitted building blocks of the decoding loop.
  prefill = jax.jit(
      _prefill_cache,
      out_shardings=result_sharding,
      static_argnames=("model", "max_decode_len"),
  )
  repeat_for_bon = jax.jit(_bon_repeat, static_argnames=("n",))
  sample_step = jax.jit(
      _decode_sample_output,
      static_argnames=("max_decode_len", "sampler"),
  )
  all_done = jax.jit(
      _decode_early_stop,
      out_shardings=replicated,
      static_argnames=("eos_token",),
  )
  extend_cache = jax.jit(
      _extend_cache,
      donate_argnums=1,
      static_argnames=("model",),
  )
  select_best = jax.jit(
      _bon_select,
      out_shardings=result_sharding,
      static_argnames=("n", "eos_token"),
  )

  # Prefill the model cache and generate logits for first token.
  prompt = {
      name: batch[name]
      for name in ("image", "text", "mask_input", "mask_ar")
  }
  logits, cache = prefill(
      params, prompt, model=model, max_decode_len=max_decode_len)

  # Repeat each example (and its real-example flag) when picking best of n.
  logits, cache, real_examples = repeat_for_bon(
      (logits, cache, batch["_mask"]), n=best_of_n)

  # Keep sampling tokens from last logits until EOS or max_decode_len.
  state = None
  # Only the oldest of the last `1 + eos_look_behind` stop flags is fetched.
  recent_stops = collections.deque(maxlen=1 + eos_look_behind)
  for step in range(max_decode_len):
    tokens, state = sample_step(
        state, logits, max_decode_len=max_decode_len, sampler=sampler)

    is_last_step = step + 1 >= max_decode_len
    if is_last_step:
      break

    recent_stops.append(all_done(state, real_examples, eos_token=eos_token))
    window_full = len(recent_stops) == recent_stops.maxlen
    if window_full and jax.device_get(recent_stops[0]):
      break

    # Compute logits for next token.
    logits, cache = extend_cache(params, cache, tokens, model=model)

  # Select the best of n sample for each example.
  _, tokens, logp = select_best(state, n=best_of_n, eos_token=eos_token)
  return tokens, logp


def _decode(train_state, batch, **kwargs):
  """Like `_decode_with_logp`, but returns only the sampled tokens."""
  tokens, _unused_logp = _decode_with_logp(train_state, batch, **kwargs)
  return tokens


def _bon_repeat(tree, *, n):
  """Repeats every leaf of `tree` `n` times along its first axis."""

  def repeat_leaf(leaf):
    return jnp.repeat(leaf, n, axis=0)

  return jax.tree.map(repeat_leaf, tree)


def _compute_score(tokens, logp, eos_token):
  """Compute log-probability of each sequence up to first eos (including it).

  Args:
    tokens: token ids; the last axis is the sequence axis.
    logp: per-token log-probabilities with the same shape as `tokens`.
    eos_token: token id that terminates a sequence.

  Returns:
    The sum of `logp` over the scored positions, with the last axis reduced.
    Sequences without an eos are summed over their full length.
  """
  is_eos = tokens == eos_token
  # Positions strictly before the first eos have seen zero eos tokens so far.
  before_first_eos = jnp.cumsum(is_eos, axis=-1) == 0
  num_scored = jnp.sum(before_first_eos, axis=-1) + 1  # +1 includes the eos.
  positions = jnp.arange(tokens.shape[-1])
  scored = positions < num_scored[..., None]
  return jnp.sum(logp * scored, axis=-1)


def _bon_select(state, *, n, eos_token):
  """Pick the sampled sequence with the highest likelihood for each example.

  Args:
    state: `(seqlen, tokens, logp)` tuple whose leading axis enumerates the
      `n` samples of each example consecutively.
    n: number of samples per example.
    eos_token: token id that terminates a sequence, used for scoring.

  Returns:
    The state restricted to the best-scoring sample of each example.
  """
  _, tokens, logp = state

  def group_by_example(x, pattern):
    return einops.rearrange(x, pattern, n=n)

  scores = group_by_example(
      _compute_score(tokens, logp, eos_token), "(b n) -> b n")
  # Index of the winning sample per example, shaped to index axis 1 below.
  winners = jnp.argmax(scores, -1)[:, None, None]  # [b, 1, 1]

  def keep_winner(leaf):
    grouped = group_by_example(leaf, "(b n) l -> b n l")
    return jnp.take_along_axis(grouped, winners, axis=1)[:, 0]

  # Filter state to only keep the best of each example.
  return jax.tree.map(keep_winner, state)


def _decode_sample_output(state, logits, *, max_decode_len, sampler):
  """Samples one token per sequence and records it in the decode state.

  Args:
    state: None on the first call, otherwise the `(seqlen, tokens, logp)`
      tuple returned by the previous call.
    logits: logits to sample the next token from; the first axis is the batch.
    max_decode_len: length of the token/logp buffers created on first call.
    sampler: sampler specification forwarded to `_sample_logits`.

  Returns:
    `(sampled_tokens, new_state)`, where `new_state` holds the incremented
    lengths and the buffers with the new token and its logp written at the
    previous length.
  """
  if state is not None:
    seqlen, tokens, logp = state
  else:
    # Decode state keeps track of sampled tokens and their logp.
    batch_size = logits.shape[0]
    buffer_shape = (batch_size, max_decode_len)
    seqlen = jnp.zeros((batch_size, 1), dtype=jnp.int32)
    tokens = jnp.zeros(buffer_shape, dtype=jnp.int32)
    logp = jnp.zeros(buffer_shape, dtype=logits.dtype)

  sampled_tokens, sampled_logp = _sample_logits(logits, sampler=sampler)

  # Write the samples at the current length, then advance the length.
  updated_state = (
      seqlen + 1,
      _put_along_last_axis(tokens, seqlen, sampled_tokens),
      _put_along_last_axis(logp, seqlen, sampled_logp),
  )
  return sampled_tokens, updated_state


def _decode_early_stop(state, mask, *, eos_token):
  """Returns whether every real example has already sampled an eos token.

  Args:
    state: `(seqlen, tokens, logp)` decode state.
    mask: per-example flag, False for examples that only pad the batch.
    eos_token: token id that terminates a sequence.

  Returns:
    A scalar boolean array, True if decoding can stop.
  """
  seqlen, tokens, _ = state
  positions = jnp.arange(tokens.shape[-1])[None, :]
  # Only positions below the current length hold sampled tokens.
  sampled_eos = (tokens == eos_token) & (positions < seqlen)
  finished = jnp.any(sampled_eos, axis=-1)
  return jnp.all(finished | jnp.logical_not(mask))


def _put_along_last_axis(arr, indices, values):
  """Like np.put_along_axis(..., axis=-1), since jax is missing it.

  Args:
    arr: array to write into (a new array is returned).
    indices: positions along the last axis of `arr`; same rank as `arr`.
    values: values to write at `indices`; same rank as `arr`.

  Returns:
    An array shaped like `arr` with `values` placed at `indices`. Positions
    that no index points to keep the value from `arr`.
  """
  ranks = (arr.ndim, indices.ndim, values.ndim)
  assert arr.ndim == indices.ndim == values.ndim, ranks
  scatter_spec = "...i,...in->...n"
  slots = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
  # Non-zero wherever at least one index points to the position.
  written = jnp.einsum(
      scatter_spec, jnp.ones(values.shape, jnp.int32), slots)
  scattered = jnp.einsum(scatter_spec, values, slots)
  return jnp.where(written, scattered, arr)


def _prefill_cache(params, batch, *, model, max_decode_len):
  """Initialize the model cache for decoding with the prompts.

  Args:
    params: model parameters.
    batch: dict with "image", "text", "mask_input" and "mask_ar".
    model: object providing `apply`, `embed_image_and_text` and
      `prefill_cache`.
    max_decode_len: number of positions reserved in the cache on top of the
      embedded prompt length.

  Returns:
    `(last_logits, cache)` as produced by `model.prefill_cache`.
  """
  variables = {"params": params}
  embedded, _ = model.apply(
      variables, batch["image"], batch["text"],
      input_mask=batch["mask_input"],
      mask_ar=batch["mask_ar"],
      method=model.embed_image_and_text)
  x, input_mask, mask_ar = embedded

  # The cache must hold the prompt plus everything that will be decoded.
  prompt_len = x.shape[1]
  last_logits, updated = model.apply(
      variables, x, input_mask, mask_ar,
      cache_size=prompt_len + max_decode_len,
      method=model.prefill_cache,
      mutable=("cache",))
  return last_logits, updated["cache"]


def _extend_cache(params, cache, tokens, *, model):
  """Extend the model cache for decoding with one token per sequence.

  Args:
    params: model parameters.
    cache: cache returned by the previous prefill/extend call.
    tokens: the newly sampled tokens to embed and feed to the model.
    model: object providing `apply`, `embed_text` and `extend_cache`.

  Returns:
    `(last_logits, cache)` with the updated cache.
  """
  variables = dict(params=params, cache=cache)
  embedded, _ = model.apply(variables, tokens, method=model.embed_text)
  last_logits, updated = model.apply(
      variables, embedded, method=model.extend_cache, mutable=("cache",))
  return last_logits, updated["cache"]


def _sample_logits(logits, sampler):
  """Returns a sampled token and its logp from logits.

  Args:
    logits: unnormalized logits; the last axis enumerates the vocabulary.
    sampler: sampler specification looked up in the registry under the
      "paligemma_sampler." prefix, e.g. "greedy", "nucleus(0.2)" or
      "temperature(t=1.0)".

  Returns:
    `(sampled_tokens, sampled_logp)`, where `sampled_logp` is the
    log-softmax of `logits` evaluated at the sampled tokens.
  """
  # Note: Consider making it possible for evaluators to pass rng seed to
  # decode functions. For now generate it from jax.lax and avoid evaluators
  # having to deal with it.
  seed = jax.lax.rng_uniform(0, np.iinfo(np.int32).max, tuple())
  rng = jax.random.PRNGKey(seed)

  # The registry resolves the spec string into a callable sampler.
  sample_fn = registry.Registry.lookup("paligemma_sampler." + sampler)
  sampled_tokens = sample_fn(logits=logits, rng=rng)

  # Find the log probability (normalized logits) of selected tokens.
  all_logp = jax.nn.log_softmax(logits, axis=-1)
  picked_logp = jnp.take_along_axis(all_logp, sampled_tokens[..., None], -1)
  return sampled_tokens, picked_logp[..., 0]


@registry.Registry.register("paligemma_sampler.greedy")
def _greedy_sampling(*, logits, rng):
  """Returns the index of the largest logit along the last axis."""
  del rng  # Unused: picking the argmax needs no randomness.
  return jnp.argmax(logits, -1)


@registry.Registry.register("paligemma_sampler.temperature")
def _temperature_sampling(t, *, logits, rng):
  """Samples from the categorical distribution given by `logits / t`."""
  tempered = logits / t
  return jax.random.categorical(rng, tempered)


@registry.Registry.register("paligemma_sampler.nucleus")
def _nucleus_sampling(p: float, t: float = 1.0, *, logits, rng):
  """Samples from the most likely tokens whose cumulative mass reaches `p`.

  Args:
    p: cumulative probability threshold of the kept tokens.
    t: temperature the logits are divided by before anything else.
    logits: unnormalized logits; the last axis enumerates the vocabulary.
    rng: random key used for the categorical draw.

  Returns:
    The sampled token ids.
  """
  scaled = logits / t
  descending = jnp.sort(scaled, axis=-1, descending=True)
  mass_so_far = jnp.cumsum(jax.nn.softmax(descending, axis=-1), axis=-1)
  # The number of sorted tokens whose cumulative mass is still below `p` is
  # the sorted position of the token that crosses the threshold.
  cutoff_index = jnp.sum(mass_so_far < p, axis=-1, keepdims=True)
  cutoff_logit = jnp.take_along_axis(descending, cutoff_index, axis=-1)
  # Effective negative infinity.
  suppressed = jnp.full_like(scaled, np.array(-1.0e7))
  truncated = jnp.where(scaled < cutoff_logit, suppressed, scaled)
  return jax.random.categorical(rng, truncated)


def _beam_decode(train_state, batch, *,
                 model, devices, max_decode_len,
                 eos_token, beam_size, replicate_out=False):
  """Beam search (greedy/top-k exploration).

  Args:
    train_state: dict holding the model parameters under "params".
    batch: dict with "image", "text", "mask_input", "mask_ar" and "_mask".
      "_mask" marks real examples; it is False for examples that only pad the
      batch.
    model: model object forwarded to the cache prefill/extend functions.
    devices: devices used to build a mesh with a single "devices" axis.
    max_decode_len: maximum number of decoding steps.
    eos_token: token id that terminates a sequence.
    beam_size: number of live candidates kept per example.
    replicate_out: if True outputs are replicated, otherwise they are
      partitioned along the "devices" mesh axis.

  Returns:
    The best finished token sequence found for every example.
  """
  params = train_state["params"]

  mesh = jax.sharding.Mesh(devices, ("devices",))
  replicated = jax.sharding.NamedSharding(mesh, P())
  out_spec = P() if replicate_out else P("devices")
  result_sharding = jax.sharding.NamedSharding(mesh, out_spec)

  # Jitted building blocks of the search loop.
  prefill = jax.jit(
      _prefill_cache,
      out_shardings=result_sharding,
      static_argnames=("model", "max_decode_len"),
  )
  beam_step = jax.jit(
      _beam_sample_output,
      static_argnames=("max_decode_len", "beam_size", "eos_token"),
  )
  all_done = jax.jit(
      _beam_early_stop,
      out_shardings=replicated,
      static_argnames=("eos_token",),
  )
  extend_cache = jax.jit(
      _extend_cache,
      donate_argnums=1,
      static_argnames=("model",),
  )
  make_output = jax.jit(_beam_make_output, out_shardings=result_sharding)

  # Prefill the model cache and generate logits for first token.
  prompt = {
      name: batch[name]
      for name in ("image", "text", "mask_input", "mask_ar")
  }
  logits, cache = prefill(
      params, prompt, model=model, max_decode_len=max_decode_len)

  # Mask indicating real examples. False if example is used to pad the batch.
  real_examples = batch["_mask"]

  # Keep expanding candidates from last logits until done or max_decode_len.
  state = None
  for step in range(max_decode_len):
    tokens, state, cache = beam_step(
        state, logits, cache,
        max_decode_len=max_decode_len, beam_size=beam_size,
        eos_token=eos_token)

    finished = all_done(state, real_examples, eos_token=eos_token)
    if jax.device_get(finished) or (step + 1 >= max_decode_len):
      break

    # Compute logits for next token.
    logits, cache = extend_cache(params, cache, tokens, model=model)

  return make_output(state)


def _beam_early_stop(state, mask, eos_token):
  """Returns whether beam search can stop for every real example.

  An example is done once the score of its best finished sequence is strictly
  higher than the score of its best live candidate.

  Args:
    state: `(best_tokens, best_logp, seqlen, tokens, logp)` beam state as
      produced by `_beam_sample_output`. The `best_*` entries are
      `[example, kept, length]` and the live entries are
      `[example, candidate, length]` (`seqlen` has a trailing axis of 1).
    mask: per-example flag, False for examples that only pad the batch.
    eos_token: token id that terminates a sequence.

  Returns:
    A scalar boolean array, True if every real example is done.
  """
  (best_tokens, best_logp, seqlen, unused_tokens, logp) = state

  # Scores of finalized sequences. `_compute_score` only reduces the length
  # axis, so the result is [example, kept]. Reduce it to one value per example
  # so that it lines up with the per-example live scores below. Without this,
  # comparing [example] against [example, 1] broadcasts to
  # [example, example] and compares scores across different examples.
  best_scores = _compute_score(best_tokens, best_logp, eos_token)
  best_scores = jnp.max(best_scores, axis=1)

  # Scores of live sequences: sum of the logp written so far, then the best
  # candidate of each example.
  live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen
  live_scores = jnp.sum(logp * live_mask, axis=-1)
  live_scores = jnp.max(live_scores, axis=1)

  done = live_scores < best_scores
  return jnp.all(jnp.logical_or(done, jnp.logical_not(mask)))


def _beam_make_output(state):
  """Returns the first kept best sequence of every example from the state."""
  best_tokens = state[0]
  return best_tokens[:, 0]


def _beam_sample_output(state, logits, cache, *,
                        beam_size, max_decode_len, eos_token):
  """Runs one beam-search step: update the best output and expand the beam.

  Args:
    state: None on the first call, otherwise the
      `(best_tokens, best_logp, seqlen, tokens, logp)` tuple returned by the
      previous call.
    logits: logits for the next token with shape[1] == 1. The leading axis
      enumerates (example, candidate) pairs: one candidate per example on the
      first call, `beam_size` afterwards.
    cache: model cache pytree; every leaf is regrouped along its leading axis
      in the same (example, candidate) layout as `logits`.
    beam_size: number of live candidates kept per example.
    max_decode_len: length of the token/logp buffers.
    eos_token: token id that terminates a sequence.

  Returns:
    `(sampled_tokens, state, cache)`: the next token of every kept candidate
    flattened back to a leading (example, candidate) axis, the updated beam
    state, and the cache reordered to match the kept candidates.
  """
  assert logits.shape[1] == 1
  logits = jax.nn.log_softmax(logits[:, 0, :])  # Normalize logits

  if state is None:
    bs = logits.shape[0]
    # Beam decode state keeps track of:
    # A) Best sampled output for each example. At initialization these have
    # shape[1]=0, but end up with shape[1]=1 after first call.
    best_tokens = jnp.zeros((bs, 0, max_decode_len), dtype=jnp.int32)
    best_logp = jnp.zeros((bs, 0, max_decode_len), dtype=logits.dtype)
    # B) N candidate sequences for each example. At initialization these have
    # beam_size=1, but end up with correct beam_size when expanded.
    seqlen = jnp.zeros((bs, 1, 1), dtype=jnp.int32)
    tokens = jnp.zeros((bs, 1, max_decode_len), dtype=jnp.int32)
    logp = jnp.zeros((bs, 1, max_decode_len), dtype=logits.dtype)
  else:
    (best_tokens, best_logp, seqlen, tokens, logp) = state
    # After the first call the leading axis of logits holds beam_size
    # candidates per example.
    bs = logits.shape[0] // beam_size
    assert best_tokens.shape[0] == bs

  # Reshape cache to [example, candidate, ...].
  # Note: on first call the number of candidates is 1. Later it is beam_size.
  cache, logits = jax.tree.map(
      lambda x: einops.rearrange(x, "(b n) ... -> b n ...", b=bs),
      (cache, logits))

  # Consider a live sequence could end now and update the best finished
  # sequences so far for each example. This strategy is found in some beam
  # implementations such as in praxis.
  # The code below also adjusts the best shape[1]=0 -> 1 during first call.
  eos_tokens = jnp.array(eos_token)[None, None, None]
  # Hypothetical "candidate + EOS" sequences and their logp.
  new_tokens = _put_along_last_axis(tokens, seqlen, eos_tokens)
  new_logp = _put_along_last_axis(logp, seqlen, logits[:, :, eos_token, None])

  # Pool the previous best with the newly finished sequences and keep the
  # single highest-scoring one per example.
  best_tokens = jnp.concatenate([best_tokens, new_tokens], axis=1)
  best_logp = jnp.concatenate([best_logp, new_logp], axis=1)
  best_scores = _compute_score(best_tokens, best_logp, eos_token=eos_token)
  _, top_indices = jax.lax.top_k(best_scores, k=1)

  best_tokens = jnp.take_along_axis(best_tokens, top_indices[..., None], axis=1)
  best_logp = jnp.take_along_axis(best_logp, top_indices[..., None], axis=1)

  # To find the next best N live candidates we expand each candidate and keep
  # the best N (ignoring EOS tokens). In this case we expand into (N+1)
  # candidates and set their likelihood to "-inf" (if EOS) after the fact.
  live_mask = jnp.arange(logp.shape[-1])[None, None] < seqlen
  live_scores = jnp.sum(logp * live_mask, axis=-1)
  topk_logits, topk_tokens = jax.lax.top_k(logits, beam_size+1)
  # Score of every expansion = score of its parent + logp of the new token.
  scores = live_scores[..., None] + topk_logits
  scores = jnp.where(
      topk_tokens != eos_token, scores, jnp.finfo(scores.dtype).min)

  # From the N*(N+1) candidates find the top N for each example.
  topk_logits, topk_tokens, scores = jax.tree.map(
      lambda x: einops.rearrange(x, "b n np1 -> b (n np1)"),
      (topk_logits, topk_tokens, scores))
  _, topk_indices = jax.lax.top_k(scores, k=beam_size)
  # Flattened index -> index of the parent candidate it was expanded from.
  sampled_indices = topk_indices // (beam_size+1)
  sampled_tokens = jnp.take_along_axis(
      topk_tokens, topk_indices, axis=-1)[..., None]
  sampled_logits = jnp.take_along_axis(
      topk_logits, topk_indices, axis=-1)[..., None]

  # Adjust cache and state so it matches the selected top N input candidates.
  # This also adjusts the beam_size=1->n during first call.
  def take_candidates(x):
    # Gather along the candidate axis expressed as a one-hot contraction:
    # output slot `o` receives the entry of parent `sampled_indices[b, o]`.
    one_hot_matrix = jax.nn.one_hot(sampled_indices, x.shape[1], dtype=x.dtype)
    return jnp.einsum("bi...,boi->bo...", x, one_hot_matrix)
  cache, seqlen, tokens, logp = jax.tree.map(
      take_candidates, (cache, seqlen, tokens, logp))

  # Write the sampled tokens/logits on the reshuffled state.
  tokens = _put_along_last_axis(tokens, seqlen, sampled_tokens)
  logp = _put_along_last_axis(logp, seqlen, sampled_logits)
  seqlen = seqlen + 1

  state = (best_tokens, best_logp, seqlen, tokens, logp)

  # Reshape to [(example, candidate), ...].
  sampled_tokens, cache = jax.tree.map(
      lambda x: einops.rearrange(x, "b n ... -> (b n) ..."),
      (sampled_tokens, cache))

  return sampled_tokens, state, cache
