"""Metrics for the sampling-based approach to conformal prediction.

All metrics take two arrays with one row per example:
y_true: a boolean array indicating which samples are admissible.
y_pred: a boolean array indicating which samples to include in the prediction set.

If every entry of a row of y_pred is inf, then that row is considered
to be an 'abstention' row."""

import functools
import numpy as np
import typing as tp

# Generic placeholders used to type the `metric` decorator and `MetricFunction`:
# P captures the wrapped function's parameters, T its return type.
P = tp.ParamSpec("P")
T = tp.TypeVar("T")


class metric:
  """Decorator factory that turns a plain function into a `MetricFunction`.

  The optional `context` list names the extra keyword arguments the metric
  needs in addition to `y_true` and `y_pred`.

  Usage:
      @metric(context=['effective_set_sizes'])
      def my_metric(y_true, y_pred, effective_set_sizes):
          ...

      result = my_metric.with_context(y_true, y_pred, {'effective_set_sizes': sizes})
  """

  def __init__(self, context: tp.Optional[tp.List[str]] = None):
    # A missing (or empty) list means the metric needs no extra context.
    self.required_context = context if context else []

  def __call__(self, fn: tp.Callable[P, T]) -> "MetricFunction[P, T]":
    """Wrap `fn` so that it carries this decorator's required context keys."""
    wrapped = MetricFunction(fn, self.required_context)
    return wrapped


class MetricFunction(tp.Generic[P, T]):
  """Callable wrapper around a metric that records which context keys it needs."""

  def __init__(self, fn: tp.Callable[P, T], required_context: tp.List[str]):
    self.fn = fn
    self.required_context = required_context
    # Copy __name__, __doc__ and friends from the wrapped function onto
    # this instance so it can stand in for the original.
    functools.update_wrapper(self, fn)

  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
    """Forward the call, unchanged, to the wrapped function."""
    wrapped = self.fn
    return wrapped(*args, **kwargs)

  def with_context(
    self, y_true: np.ndarray, y_pred: np.ndarray, context: tp.Dict[str, tp.Any]
  ) -> tp.Any:
    """Call the metric, pulling its extra keyword arguments out of `context`.

    Args:
        y_true: Ground truth array passed as the first positional argument
        y_pred: Prediction array passed as the second positional argument
        context: Mapping that must contain every key in `required_context`;
                 any additional keys are ignored

    Returns:
        Whatever the wrapped metric function returns

    Raises:
        ValueError: If one or more required keys are absent from `context`
    """
    injected = {}
    absent = []
    for key in self.required_context:
      if key in context:
        injected[key] = context[key]
      else:
        absent.append(key)

    if absent:
      raise ValueError(f"Missing required context keys: {absent}")

    # The cast only relaxes the static signature; it has no runtime effect.
    untyped_fn = tp.cast(tp.Callable[..., tp.Any], self.fn)
    return untyped_fn(y_true, y_pred, **injected)


def set_coverage(y_true, y_pred):
  """Determine, per row, whether the prediction set contains an admissible sample.

  Args:
      y_true: Boolean array of ground truth admissibility
      y_pred: Boolean array of predictions

  Returns:
      Boolean array, one entry per row: True when at least one selected sample
      is admissible, or when the row is an abstention (all inf)
  """
  # Keep the admissibility value only where the sample was selected.
  selected_admissible = np.where(y_pred == 1, y_true, 0)
  covered = selected_admissible.sum(axis=1) > 0
  # Abstention rows are counted as covered.
  covered[get_abstention_rows(y_pred)] = 1
  return covered


def get_abstention_rows(array):
  """Flag the rows whose entries are all infinite (abstentions).

  Args:
      array: Array of predictions

  Returns:
      Boolean array with one entry per row, True where every entry is inf
  """
  is_infinite = np.isinf(array)
  return is_infinite.all(axis=1)


@metric()
def abstention_rate(y_true, y_pred):
  """Flag, per example, whether the predictor abstained.

  Args:
      y_true: Boolean array of ground truth admissibility (ignored)
      y_pred: Array of predictions; rows that are all inf denote abstentions

  Returns:
      Boolean array with one entry per row, True where the row is an abstention
  """
  del y_true  # Unused
  abstained = get_abstention_rows(y_pred)
  return abstained


@metric()
def base_model_abstention_rate(y_true, y_pred):
  """Flag, per example, whether no sample at all is admissible.

  Args:
      y_true: Boolean array of ground truth admissibility
      y_pred: Boolean array of predictions (ignored)

  Returns:
      Boolean array with one entry per row, True where the row of y_true
      sums to zero
  """
  del y_pred  # Unused
  admissible_per_row = np.sum(y_true, axis=1)
  return admissible_per_row == 0


@metric()
def false_positive_abstention_rate(y_true, y_pred):
  """Flag, per example, whether the predictor abstained unnecessarily.

  A false positive abstention is a row where the predictor abstains (y_pred is
  all inf) even though the base model did not, i.e. y_true contains at least
  one admissible option.

  Args:
      y_true: Boolean array of ground truth admissibility
      y_pred: Boolean array of predictions

  Returns:
      Boolean array with one entry per row, True for false positive abstentions
  """
  predictor_abstained = get_abstention_rows(y_pred)
  has_admissible = np.sum(y_true, axis=1) != 0
  return predictor_abstained & has_admissible


@metric()
def non_abstention_coverage(y_true, y_pred):
  """Compute per-row coverage, restricted to rows with an admissible sample.

  Rows of y_true that contain no admissible sample are skipped and keep the
  value 0; every other row receives its `set_coverage` value.

  Args:
      y_true: Boolean array of ground truth admissibility
      y_pred: Boolean array of predictions

  Returns:
      Float array with one entry per row holding 0.0 or 1.0
  """
  has_admissible = np.sum(y_true, axis=1) != 0
  coverage = np.zeros(y_true.shape[0])
  # Only rows with at least one admissible sample are evaluated.
  coverage[has_admissible] = set_coverage(
    y_true[has_admissible], y_pred[has_admissible]
  )
  return coverage


@metric()
def set_sizes(y_true, y_pred):
  """Count the selected samples in each non-abstaining prediction set.

  Args:
      y_true: Boolean array of ground truth admissibility (ignored)
      y_pred: Boolean array of predictions

  Returns:
      Float array with one entry per row: the row sum of y_pred for
      non-abstaining rows and NaN for abstention rows
  """
  del y_true  # Unused
  kept_rows = ~get_abstention_rows(y_pred)
  # Start from NaN everywhere so abstention rows stay marked as undefined.
  sizes = np.full(y_pred.shape[0], np.nan)
  sizes[kept_rows] = y_pred[kept_rows].sum(axis=1)
  return sizes


@metric(context=["first_occurrence_indices"])
def effective_set_sizes(y_true, y_pred, first_occurrence_indices):
  """Compute the effective set sizes accounting for duplicates.

  For each prediction set, this computes the number of unique elements
  by using the first_occurrence_indices array. For example, if y_pred is [1, 0, 1]
  and first_occurrence_indices is [0, 1, 0], then the effective set size is 1
  because the third element is a duplicate of the first.

  Args:
      y_true: Boolean array of ground truth admissibility (unused)
      y_pred: Boolean array of predictions; rows that are all inf are abstentions
      first_occurrence_indices: Array aligned row-for-row with y_pred, where
                              each element i of a row contains the index j
                              where that element first occurred (j <= i).
                              If j == i, the element is unique up to that point.

  Returns:
      Float array with one entry per row of y_pred: the effective set size for
      non-abstaining rows and NaN for abstention rows
  """
  del y_true  # Unused
  output_arr = np.full(y_pred.shape[0], np.nan)

  # Walk the ORIGINAL row positions of the non-abstaining rows. Enumerating a
  # filtered copy of y_pred would yield positions in the filtered array, which
  # stop matching the rows of first_occurrence_indices as soon as an
  # abstention row precedes a non-abstaining one.
  for row_idx in np.flatnonzero(~get_abstention_rows(y_pred)):
    selected_indices = np.flatnonzero(y_pred[row_idx] == 1)
    first_occurrences = first_occurrence_indices[row_idx, selected_indices]
    # An empty selection has no unique values, so empty sets get size 0
    # without a special case.
    output_arr[row_idx] = np.unique(first_occurrences).size

  return output_arr


def compute_metrics(y_true, y_pred, metrics_list, context=None):
  """Evaluate several metrics on the same inputs.

  Args:
      y_true: Boolean array of ground truth admissibility
      y_pred: Boolean array of predictions
      metrics_list: List of metric functions to compute; each may be a plain
                    callable or a MetricFunction
      context: Optional dictionary of context values needed by some metrics

  Returns:
      Dictionary mapping each metric's function name to its computed value
  """
  context = context if context else {}
  results = {}

  for metric_fn in metrics_list:
    is_wrapped = isinstance(metric_fn, MetricFunction)

    # Only wrapped metrics that declare context keys need injection; everything
    # else is called directly with the two arrays.
    if is_wrapped and metric_fn.required_context:
      value = metric_fn.with_context(y_true, y_pred, context)
    else:
      value = metric_fn(y_true, y_pred)

    # Wrapped metrics are reported under the name of the underlying function.
    if is_wrapped:
      name = metric_fn.fn.__name__
    else:
      name = metric_fn.__name__

    results[name] = value

  return results
