from generative_prediction_sets.classification_scores import HPSScore, APSScore
import numpy as np


def test_hps_score_sanity():
  """HPS: score and prediction set for a single multi-label sample.

  The sample has two true classes (indices 1 and 2, predicted 0.8 and 0.9).
  The expected score is 0.8, and inverting at that threshold should recover
  exactly those two classes.
  """
  y_true = np.array([[0, 1, 1]])
  y_pred = np.array([[0.1, 0.8, 0.9]])
  score = HPSScore()
  # Compare floats with a tolerance so the test does not depend on the exact
  # floating-point arithmetic used inside the score.
  assert np.isclose(score(y_true, y_pred)[0], 0.8)
  # assert_array_equal also checks shapes (no silent broadcasting) and
  # reports the mismatching elements on failure.
  np.testing.assert_array_equal(score.invert(y_pred, 0.8)[0], [0, 1, 1])


def test_aps_score_no_randomization():
  """APS without randomization on a single, confidently-correct sample."""
  # Simple 3-class example.
  y_true = np.array([[0, 1, 0]])  # Second class is true
  y_pred = np.array([[0.1, 0.8, 0.1]])  # High confidence on second class

  score = APSScore(randomize=False)

  # compute: the true class has the highest probability, so the expected
  # score is just its own probability (0.8).
  assert np.isclose(score.compute(y_true, y_pred)[0], 0.8)

  # invert at quantile 0.8: only the highest-probability class is included.
  pred_set = score.invert(y_pred, 0.8)[0]
  np.testing.assert_array_equal(pred_set, [0, 1, 0])

  # invert at quantile 0.9: two classes are included. The remaining two
  # classes tie at 0.1 and the expected set picks the third one, so this
  # assertion also pins the implementation's tie-breaking order.
  pred_set = score.invert(y_pred, 0.9)[0]
  np.testing.assert_array_equal(pred_set, [0, 1, 1])


def test_aps_score_multiple_samples():
  """APS without randomization on a batch of two samples.

  Sample 0 ranks its true class first; sample 1 ranks its true class last.
  """
  y_true = np.array(
    [
      [0, 1, 0],  # Second class is true
      [1, 0, 0],  # First class is true
    ]
  )
  y_pred = np.array(
    [
      [0.1, 0.8, 0.1],  # Confident and correct: true class ranked first
      [0.2, 0.5, 0.3],  # Incorrect: true class has the lowest probability
    ]
  )

  score = APSScore(randomize=False)

  # Test compute
  scores = score.compute(y_true, y_pred)
  # Sample 0: true class is ranked first, so only its probability (0.8) counts.
  assert np.isclose(scores[0], 0.8)
  # Sample 1: true class is ranked last, so all three classes are accumulated
  # (0.5 + 0.3 + 0.2). Use a tolerance because this is a float sum.
  assert np.isclose(scores[1], 1.0)

  # Test invert at quantile 0.8
  pred_sets = score.invert(y_pred, 0.8)
  # Sample 0: only the second class (probability 0.8).
  np.testing.assert_array_equal(pred_sets[0], [0, 1, 0])
  # Sample 1: the top two classes (0.5 and 0.3).
  np.testing.assert_array_equal(pred_sets[1], [0, 1, 1])


def test_aps_score_with_randomization():
  """APS with randomization: score follows the seeded formula and varies by seed.

  This test seeds NumPy's global RNG (seeds 42 and 43). The global RNG state
  from before the test is restored afterwards, so later tests do not inherit
  these seeds.
  """
  saved_rng_state = np.random.get_state()
  try:
    # Fixed seed for reproducibility.
    np.random.seed(42)

    y_true = np.array([[0, 1, 0]])  # Second class is true
    y_pred = np.array([[0.1, 0.8, 0.1]])  # High confidence on second class

    score = APSScore(randomize=True)

    # compute with randomization. The expected score is 0.8 + 0.8 * (u - 1),
    # where u is the first uniform draw after seeding with 42.
    randomized_score = score.compute(y_true, y_pred)[0]
    np.random.seed(42)  # Reset the seed to reproduce the same draw
    u = np.random.uniform()
    expected_score = 0.8 + 0.8 * (u - 1)
    assert np.allclose(randomized_score, expected_score)

    # invert at quantile 0.8: the highest-probability class is still included.
    pred_set_08 = score.invert(y_pred, 0.8)[0]
    np.testing.assert_array_equal(pred_set_08, [0, 1, 0])

    # invert at quantile 0.9: top two classes, as in the non-randomized case.
    pred_set_09 = score.invert(y_pred, 0.9)[0]
    np.testing.assert_array_equal(pred_set_09, [0, 1, 1])

    # A different seed should give a different randomized score.
    np.random.seed(43)
    different_score = score.compute(y_true, y_pred)[0]
    assert randomized_score != different_score
  finally:
    # Undo the global reseeding so other tests are unaffected.
    np.random.set_state(saved_rng_state)


def test_aps_score_sequential_sort():
  """APS with a custom ``sort_fn`` that ignores the predicted probabilities.

  The sort function orders classes by index (0, 1, 2). For this sample the
  cumulative probabilities in that order are 0.1, 0.9 and 1.0.
  """
  # Simple 3-class example.
  y_true = np.array([[0, 1, 0]])  # Second class is true
  y_pred = np.array([[0.1, 0.8, 0.1]])  # High confidence on second class

  def sequential_sort_fn(y_pred: np.ndarray) -> np.ndarray:
    """Return column indices [0, 1, ..., K-1] for every row, ignoring values."""
    n_rows, n_classes = y_pred.shape
    return np.tile(np.arange(n_classes), (n_rows, 1))

  score = APSScore(sort_fn=sequential_sort_fn, randomize=False)

  # compute: with index order and true class 1, the expected score is the
  # sum of probabilities up to and including index 1: 0.1 + 0.8 = 0.9.
  # Use a tolerance because this is a float sum.
  assert np.isclose(score.compute(y_true, y_pred)[0], 0.9)

  # invert at quantile 0.8: only the first class.
  pred_set = score.invert(y_pred, 0.8)[0]
  np.testing.assert_array_equal(pred_set, [1, 0, 0])

  # invert at quantile 0.9: the first two classes.
  pred_set = score.invert(y_pred, 0.9)[0]
  np.testing.assert_array_equal(pred_set, [1, 1, 0])

  # invert at quantile 1.0: all three classes.
  pred_set = score.invert(y_pred, 1)[0]
  np.testing.assert_array_equal(pred_set, [1, 1, 1])
