"""Model evaluation."""

import functools
from typing import (
  Callable,
  Dict,
  Iterable,
  Mapping,
  Optional,
  Sequence,
  Tuple,
  Type,
  Union,
)

from absl import logging
from clu import metrics
import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from ebm_obj.lib import losses
from ebm_obj.lib import utils
import tensorflow as tf
from tqdm import tqdm


# Shorthand for a JAX array, used in the annotations throughout this module.
Array = jnp.ndarray
# Recursive annotation for nested containers of arrays: a single array, an
# iterable of trees, or a string-keyed mapping of trees. The forward references
# are strings because the alias refers to itself.
ArrayTree = Union[
  Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]
]  # pytype: disable=not-supported-yet
# PRNG keys are annotated as plain arrays.
PRNGKey = Array


def get_eval_metrics(
  preds: Dict[str, ArrayTree],
  batch: Dict[str, Array],
  loss_fn: losses.LossFn,
  eval_metrics_cls: Type[metrics.Collection],
  predicted_max_num_instances: int,
  ground_truth_max_num_instances: int,
) -> Union[None, metrics.Collection]:
  """Build a metrics update from model predictions and the matching batch.

  The loss is evaluated on `(preds, batch)` and handed, together with the
  auxiliary loss outputs and the segmentation/mask entries, to
  `eval_metrics_cls.gather_from_model_output`.

  NOTE(review): `gather_from_model_output` is presumably what aggregates values
  across devices, so this function is expected to run inside a `jax.pmap` with
  a named axis -- confirm against the metrics collection in use.

  Args:
    preds: Model predictions. `preds["outputs"]` must support `.get`; its
      optional "segmentations" entry is passed through
      `utils.remove_singleton_dim`.
    batch: Inputs that were evaluated. The optional keys "segmentations",
      "padding_mask" and "mask" are forwarded to the metrics (missing keys are
      forwarded as `None`).
    loss_fn: Callable invoked as `loss_fn(preds, batch)`; must return a
      `(loss, loss_aux)` pair where `loss_aux` is a string-keyed mapping that
      is expanded into keyword arguments for the metrics.
    eval_metrics_cls: Metrics collection class used to build the update.
    predicted_max_num_instances: Maximum number of instances in prediction.
    ground_truth_max_num_instances: Maximum number of instances in ground
      truth, including background (which counts as a separate instance).

  Returns:
    Whatever `eval_metrics_cls.gather_from_model_output` returns for this batch.
  """
  loss, loss_aux = loss_fn(preds, batch)

  # Named intermediates keep the keyword list below easy to scan.
  model_outputs = preds["outputs"]
  predicted_segmentations = utils.remove_singleton_dim(
    model_outputs.get("segmentations")  # pytype: disable=attribute-error
  )
  ground_truth_segmentations = batch.get("segmentations")

  return eval_metrics_cls.gather_from_model_output(
    loss=loss,
    **loss_aux,
    predicted_segmentations=predicted_segmentations,
    ground_truth_segmentations=ground_truth_segmentations,
    predicted_max_num_instances=predicted_max_num_instances,
    ground_truth_max_num_instances=ground_truth_max_num_instances,
    padding_mask=batch.get("padding_mask"),
    mask=batch.get("mask"),
  )


def eval_step(
  model: nn.Module,
  state: utils.TrainState,
  batch: Dict[str, Array],
  rng: PRNGKey,
) -> Dict[str, ArrayTree]:
  """Run the model in inference mode on one batch and return its predictions.

  The model is applied to `batch["image"]` with `train=False`. Parameters are
  read from `state.optimizer.target` and the remaining variable collections
  from `state.state_vars`. `evaluate` wraps this function in `jax.pmap` so that
  it runs on every local device.

  Args:
    model: Model used in the eval step.
    state: Training state providing `state_vars` (non-parameter variable
      collections) and `optimizer.target` (model parameters).
    batch: Inputs that should be evaluated; only the "image" entry is fed to
      the model.
    rng: PRNGKey for the forward pass. It is split into two independent keys
      for the "z_state_init" and "eps" RNG streams.

  Returns:
    The model's predictions. If the forward pass recorded an "intermediates"
    collection, it is attached (unfrozen) under the "intermediates" key.
  """
  logging.info("eval_step(batch=%s)", batch)
  state_variables = state.state_vars
  params = state.optimizer.target
  z_rng, eps_rng = jax.random.split(rng, num=2)

  # Only the "intermediates" collection is mutable, so the forward pass can
  # record intermediate values without touching any other variables.
  preds, mutable_vars = model.apply(
    {"params": params, **state_variables},
    batch["image"],
    mutable="intermediates",
    rngs={"z_state_init": z_rng, "eps": eps_rng},
    train=False,
  )

  if "intermediates" in mutable_vars:
    preds["intermediates"] = flax.core.unfreeze(
      mutable_vars["intermediates"])

  return preds


def evaluate(
  model: nn.Module,
  state: utils.TrainState,
  eval_ds: tf.data.Dataset,
  loss_fn: Callable[[Dict[str, ArrayTree], Dict[str, ArrayTree]], Array],
  eval_metrics_cls: Type[metrics.Collection],
  predicted_max_num_instances: int,
  ground_truth_max_num_instances: int,
  metrics_on_cpu: bool = False,
  num_eval_steps: int = 1,
) -> Tuple[metrics.Collection, Dict[str, ArrayTree], Dict[str, ArrayTree]]:
  """Evaluate the model on the given dataset.

  Each batch is pushed through a pmapped `eval_step`; the resulting predictions
  are turned into a metrics update by a pmapped `get_eval_metrics`, and the
  updates of all processed batches are merged.

  Args:
    model: Model to evaluate.
    state: Training state. `state.rng[0]` seeds the evaluation PRNG stream and
      the state itself is forwarded to `eval_step`.
    eval_ds: Dataset yielding batches. Every leaf is converted with
      `np.asarray` before use.
    loss_fn: Loss function forwarded to `get_eval_metrics`, which calls it as
      `loss_fn(preds, batch)` and unpacks the result as `(loss, loss_aux)`.
    eval_metrics_cls: Metrics collection class used to build the updates.
    predicted_max_num_instances: Maximum number of instances in prediction.
    ground_truth_max_num_instances: Maximum number of instances in ground
      truth.
    metrics_on_cpu: If True, compute the metrics on the CPU devices instead of
      the local accelerator devices. Single-host only.
    num_eval_steps: Stop after this many batches. A negative value disables the
      limit and consumes all of `eval_ds`. The check runs after a batch has
      been processed, so 0 still evaluates one batch.

  Returns:
    A tuple `(eval_metrics, batch, preds)`: the merged metrics, plus the last
    processed batch and its predictions (both reshaped for the metric devices
    when `metrics_on_cpu` is set).

  Raises:
    NotImplementedError: If `metrics_on_cpu` is set in a multi-host setup.
    AssertionError: If `eval_ds` yielded no batches.
  """
  eval_metrics = None
  batch = None
  preds = None
  rng = state.rng[0]  # Get training state PRNGKey from first replica.

  if metrics_on_cpu and jax.process_count() > 1:
    raise NotImplementedError(
      "metrics_on_cpu feature cannot be used in a multi-host setup."
      " This experiment is using {} hosts.".format(jax.process_count())
    )
  metric_devices = jax.devices(
    "cpu") if metrics_on_cpu else jax.local_devices()

  # Argument 0 (the model) is static; state, batch and rngs are mapped.
  p_eval_step = jax.pmap(
    eval_step,
    axis_name="batch",
    static_broadcasted_argnums=(0,),
  )

  # loss_fn, the metrics class and both instance limits are static arguments.
  p_get_eval_metrics = jax.pmap(
    get_eval_metrics,
    axis_name="batch",
    static_broadcasted_argnums=(2, 3, 4, 5),
    devices=metric_devices,
    backend="cpu" if metrics_on_cpu else None,
  )

  def reshape_fn(x):
    """Function to reshape preds and batch before calling p_get_eval_metrics."""
    # Merge the two leading dims and re-split them so that the new leading dim
    # equals the number of metric devices; trailing dims are left untouched.
    return np.reshape(x, [len(metric_devices), -1] + list(x.shape[2:]))

  for step, batch in enumerate(tqdm(eval_ds, desc="Evaluating"), start=1):
    rng, eval_rng = jax.random.split(rng)
    eval_rng = jax.random.fold_in(eval_rng, jax.process_index())  # Bind to host.
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
    batch = jax.tree_util.tree_map(np.asarray, batch)

    preds = p_eval_step(model, state, batch, eval_rngs)

    if metrics_on_cpu:
      # Reshape replica dim and batch-dims to work with metric_devices.
      preds = jax.tree_util.tree_map(reshape_fn, preds)
      batch = jax.tree_util.tree_map(reshape_fn, batch)
    # Get metric updates.
    update = p_get_eval_metrics(
      preds,
      batch,
      loss_fn,
      eval_metrics_cls,
      predicted_max_num_instances,
      ground_truth_max_num_instances,
    )
    update = flax.jax_utils.unreplicate(update)
    eval_metrics = update if eval_metrics is None else eval_metrics.merge(
      update)

    # A negative num_eval_steps means "no limit".
    if num_eval_steps >= 0 and step >= num_eval_steps:
      break

  assert eval_metrics is not None, "eval_ds did not yield any batches."
  return eval_metrics, batch, preds
