from generative_prediction_sets.core import ConformityScore, NonConformityScore
import typing as tp
import dataclasses
import numpy as np


@dataclasses.dataclass
class HPSScore(ConformityScore):
  """
  HPS score for classification.

  This is positively oriented, meaning that a higher score is better.
  Typically, we want this to be f(y|x) for a classification model, i.e. the
  predicted probability/score of the true class.
  """

  def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Compute the HPS conformity score of each sample.

    Args:
      y_true: Label indicator array of shape [n_samples, n_classes]. Rows are
        one-hot encoded but may contain several non-zero entries
        (multi-label); every non-zero entry is treated as a true class.
      y_pred: Class probabilities/scores of shape [n_samples, n_classes].

    Returns:
      Array of shape [n_samples] holding, for each sample, the LOWEST
      predicted score among its true classes. Using the minimum means that a
      set built by `invert` with a threshold at or below this value contains
      every true class of the sample. Rows of y_true with no non-zero entry
      get a score of +inf.

    Raises:
      ValueError: If y_pred or y_true is 1-dimensional.
    """
    if y_pred.ndim == 1:
      raise ValueError(
        "y_pred must be a 2D array class probabilities/scores for each sample"
      )
    if y_true.ndim == 1:
      raise ValueError(
        "y_true must be a 2D array one-hot encoded class labels for each sample"
      )

    # Replace non-true classes with +inf so they can never be the row minimum;
    # the minimum is then taken over the true classes' scores only.
    true_class_scores = np.where(y_true, y_pred, np.inf)
    return true_class_scores.min(axis=1)

  def invert(self, y_pred: np.ndarray, quantile: float) -> np.ndarray:
    """Build prediction sets from a calibrated score threshold.

    Args:
      y_pred: Class probabilities/scores of shape [n_samples, n_classes].
      quantile: Calibrated threshold on the HPS score; a class is included
        when its score is >= quantile.

    Returns:
      Boolean array with the same shape as y_pred, True where the class is in
      the sample's prediction set. An infinite quantile (either sign) yields
      the trivial set containing every class.
    """
    if np.isinf(quantile):
      # Infinite threshold: include every class. Return a real boolean mask so
      # the result has the same dtype as the finite branch below and can be
      # used directly for boolean indexing.
      return np.ones_like(y_pred, dtype=bool)
    return y_pred >= quantile


def sequential_sort_fn(y_pred: np.ndarray) -> np.ndarray:
  """Return the identity class ordering 0, 1, 2, ... for every row.

  Can be passed as an APS `sort_fn` to accumulate classes in index order,
  independent of the predicted values.

  Args:
    y_pred: Array whose shape[0] is the number of rows and shape[1] the
      number of classes. Its values are not read.

  Returns:
    Integer array of shape [y_pred.shape[0], y_pred.shape[1]] in which every
    row equals arange(y_pred.shape[1]).
  """
  n_rows, n_classes = y_pred.shape[0], y_pred.shape[1]
  identity_row = np.arange(n_classes)[np.newaxis, :]
  return np.repeat(identity_row, n_rows, axis=0)


def _aps_default_sort_fn(y_pred: np.ndarray) -> np.ndarray:
  """Order classes from most to least probable, row by row.

  This is the default `sort_fn` used by APSScore.

  Args:
      y_pred: Array of shape [n_samples, n_classes] with predicted probabilities.

  Returns:
      Array of shape [n_samples, n_classes]; row i lists the class indices of
      sample i by descending probability.
  """
  ascending_order = np.argsort(y_pred, axis=1)
  # Reverse along the class axis so the highest-probability class comes first.
  return ascending_order[:, ::-1]


@dataclasses.dataclass
class APSScore(NonConformityScore):
  """Adaptive Prediction Set (APS) score for classification.

  The APS score is a nonconformity score that produces nested prediction sets based on a
  specified ordering of classes. By default, it orders classes by their predicted probabilities
  in descending order, but this can be customized via the sort_fn parameter.

  For a given sample:
  1. Classes are ordered according to sort_fn (default: descending probability)
  2. The score is the sum of probabilities up to and including the true class in this ordering.
     With randomization, the true class contributes u * p_true instead of p_true, where
     u ~ U(0, 1) is drawn once per sample.
  3. The prediction set for a calibrated quantile contains every class whose cumulative
     probability (computed the same way, including randomization) is <= the quantile.

  compute and invert derive cumulative probabilities from the same helper and apply the same
  randomization arithmetic. So, for the same uniform draw, a class's score in compute is
  bitwise identical to the value invert compares against the quantile.

  Args:
      sort_fn: Function that takes predicted probabilities [n_samples, n_classes] and returns
          indices [n_samples, n_classes] specifying the order to accumulate classes. Each row
          is expected to be a permutation of the class indices.
          Default orders by descending probability. Can be customized for different orderings,
          e.g. sequential order [0,1,2,...] regardless of probabilities.
      randomize: Whether to add randomization for smoother calibration. Default True.
      rng: Optional NumPy random Generator used for the randomization draws. When None
          (default), the global ``np.random`` state is used. Pass a seeded Generator to make
          randomized scores and prediction sets reproducible.

  Example:
      ```python
      # Default behavior - order by descending probability
      score = APSScore()

      # Custom sequential ordering regardless of probability
      def sequential_sort(y_pred):
          return np.tile(np.arange(y_pred.shape[1]), (y_pred.shape[0], 1))
      score = APSScore(sort_fn=sequential_sort)

      # Reproducible randomization
      score = APSScore(rng=np.random.default_rng(0))
      ```
  """

  sort_fn: tp.Callable[[np.ndarray], np.ndarray] = _aps_default_sort_fn
  randomize: bool = True
  rng: tp.Optional[np.random.Generator] = None

  def _uniform(self, n: int) -> np.ndarray:
    """Draw n i.i.d. U(0, 1) samples from self.rng, or the global NumPy RNG if unset."""
    if self.rng is None:
      return np.random.uniform(size=n)
    return self.rng.uniform(size=n)

  def _cumulative_probs(self, y_pred: np.ndarray) -> np.ndarray:
    """Cumulative probability of each class in sort_fn order, at its original position.

    Entry [i, k] is the running sum of y_pred[i] along the order sort_fn gives for
    row i, taken at class k's position in that order (i.e. including class k).
    Assumes sort_fn returns a permutation of class indices per row.
    """
    order = self.sort_fn(y_pred)
    sorted_probs = np.take_along_axis(y_pred, order, axis=1)
    sorted_cumsum = np.cumsum(sorted_probs, axis=1)
    # argsort of a permutation is its inverse: maps each class back to its slot.
    inverse_order = np.argsort(order, axis=1)
    return np.take_along_axis(sorted_cumsum, inverse_order, axis=1)

  def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Compute the APS score for given true labels and predictions.

    The score for each sample is the sum of probabilities up to and including the true class
    in the ordering specified by sort_fn. If randomization is enabled, the true class's own
    probability p is replaced by u * p with u ~ U(0, 1), one draw per sample.

    Args:
        y_true: One-hot encoded true labels of shape [n_samples, n_classes]. If a row has
            several non-zero entries, only the class with the largest entry (the first one
            on ties, per np.argmax) is used as the true class.
        y_pred: Predicted probabilities of shape [n_samples, n_classes]

    Returns:
        Array of shape [n_samples] containing the APS scores

    Raises:
        ValueError: If y_pred or y_true is 1-dimensional.
    """
    if y_pred.ndim == 1:
      raise ValueError(
        "y_pred must be a 2D array class probabilities/scores for each sample"
      )
    if y_true.ndim == 1:
      raise ValueError(
        "y_true must be a 2D array one-hot encoded class labels for each sample"
      )

    true_idx = np.argmax(y_true, axis=1)[:, None]

    # Use the same cumulative values as invert, so score/set membership agree exactly.
    cumulative_probs = self._cumulative_probs(y_pred)
    scores = np.take_along_axis(cumulative_probs, true_idx, axis=1)[:, 0]

    if self.randomize:
      true_probs = np.take_along_axis(y_pred, true_idx, axis=1)[:, 0]
      us = self._uniform(len(y_true))
      # Same expression shape as in invert: cum - p + u * p.
      scores = scores - true_probs + us * true_probs

    return scores

  def invert(self, y_pred: np.ndarray, quantile: float) -> np.ndarray:
    """Compute prediction sets by inverting the APS score.

    A class is included when its cumulative probability (in sort_fn order, including the
    class itself) is <= the given quantile. If randomization is enabled, each class's own
    probability p in that sum is replaced by u * p, where u ~ U(0, 1) is drawn once per
    sample and shared across that sample's classes.

    Args:
        y_pred: Predicted probabilities of shape [n_samples, n_classes]
        quantile: Quantile threshold for including classes in prediction set

    Returns:
        Boolean array of shape [n_samples, n_classes] indicating which classes are in
        each sample's prediction set
    """
    cumulative_probs = self._cumulative_probs(y_pred)

    if self.randomize:
      us = self._uniform(len(y_pred))
      cumulative_probs = cumulative_probs - y_pred + us[:, None] * y_pred

    return cumulative_probs <= quantile
