import numpy as np
from generative_prediction_sets.sampling_metrics import (
  effective_set_sizes,
  set_coverage,
  compute_metrics,
  abstention_rate,
  base_model_abstention_rate,
  set_sizes,
)


def test_effective_set_sizes_basic():
  """Effective sizes match the selected counts when no position is a duplicate."""
  # Rows mark one, two and three candidates respectively; the prediction
  # selects exactly the same positions as the labels.
  labels = np.array([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
  selections = np.array([[1, 0, 0], [1, 1, 0], [1, 1, 1]])
  # Identity mapping on every row: each position is its own first occurrence.
  first_seen = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])

  expected = [1, 2, 3]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_with_duplicates():
  """Selected positions that map to the same first occurrence count only once."""
  labels = np.array([[1, 0, 1], [1, 1, 1], [1, 1, 0]])
  # The prediction selects the same positions as the labels.
  selections = np.array([[1, 0, 1], [1, 1, 1], [1, 1, 0]])
  # Per-row duplicate mapping (position -> index of its first occurrence).
  first_seen = np.array(
    [
      [0, 1, 0],  # position 2 duplicates position 0 -> expected size 1
      [0, 1, 0],  # position 2 duplicates position 0 -> expected size 2
      [0, 1, 2],  # no duplicates -> expected size 2
    ]
  )

  expected = [1, 2, 2]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_with_abstentions():
  """An abstention row (all np.inf) is expected to report an infinite size."""
  labels = np.array([[1, 0, 1], [1, 1, 0]])
  abstained_row = [np.inf, np.inf, np.inf]
  regular_row = [1, 1, 0]
  selections = np.array([abstained_row, regular_row])
  # Identity mapping for both rows; the test comments treat the mapping of the
  # abstention row as irrelevant to the outcome.
  first_seen = np.array([[0, 1, 2], [0, 1, 2]])

  expected = [np.inf, 2]  # abstention row should be inf
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_complex_duplicates():
  """Two separate duplicate pairs in one row collapse to two unique elements."""
  # Single row: all four candidates are valid and all four are selected.
  labels = np.array([[1, 1, 1, 1]])
  selections = np.array([[1, 1, 1, 1]])
  # Positions 0-1 share one first occurrence and positions 2-3 share another.
  first_seen = np.array([[0, 0, 2, 2]])

  expected = [2]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_empty_predictions():
  """An all-zero prediction row has size 0; a fully selected unique row has 3."""
  labels = np.array([[1, 1, 1], [1, 1, 1]])
  empty_row = [0, 0, 0]
  full_row = [1, 1, 1]
  selections = np.array([empty_row, full_row])
  # Identity mapping on both rows (no duplicates anywhere).
  first_seen = np.array([[0, 1, 2], [0, 1, 2]])

  expected = [0, 3]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_transitive_duplicates():
  """A group of three duplicates plus one singleton yields an effective size of 2."""
  # Single row: all four candidates are valid and all four are selected.
  labels = np.array([[1, 1, 1, 1]])
  selections = np.array([[1, 1, 1, 1]])
  # Positions 0, 1 and 2 all point at position 0; position 3 is unique.
  first_seen = np.array([[0, 0, 0, 3]])

  expected = [2]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_effective_set_sizes_mixed_abstentions():
  """Cover a deduplicated row, an abstention, an empty set and a plain row."""
  labels = np.array([[1, 0, 1], [1, 1, 0], [1, 0, 1], [1, 1, 0]])
  abstained_row = [np.inf, np.inf, np.inf]
  empty_row = [0, 0, 0]
  selections = np.array([[1, 0, 1], abstained_row, empty_row, [1, 1, 0]])
  first_seen = np.array(
    [
      [0, 1, 0],  # position 2 duplicates position 0
      [0, 1, 2],  # abstention row; identity mapping
      [0, 1, 2],  # empty row; identity mapping
      [0, 1, 2],  # no duplicates
    ]
  )

  # Row order: deduplicated pair, abstention, empty set, two unique picks.
  expected = [1, np.inf, 0, 2]
  observed = effective_set_sizes(labels, selections, first_seen)
  np.testing.assert_array_equal(observed, expected)


def test_set_sizes_with_abstentions():
  """set_sizes reports counts for normal rows, inf for abstentions, 0 for empty."""
  # Every label row marks two valid candidates.
  labels = np.array([[1, 0, 1], [1, 1, 0], [1, 0, 1], [1, 1, 0]])
  abstained_row = [np.inf, np.inf, np.inf]
  empty_row = [0, 0, 0]
  selections = np.array([[1, 0, 1], abstained_row, empty_row, [1, 1, 0]])

  # Row order: two picks, abstention, empty set, two picks.
  expected = [2, np.inf, 0, 2]
  observed = set_sizes(labels, selections)
  np.testing.assert_array_equal(observed, expected)


def test_set_coverage_basic():
  """A row is covered exactly when the prediction selects its true position."""
  # One true position per row: first, second and third respectively.
  labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  selections = np.array(
    [
      [1, 0, 0],  # selects the true position -> covered
      [0, 0, 1],  # misses the true position -> not covered
      [1, 1, 1],  # selects everything -> covered
    ]
  )

  expected = [1, 0, 1]
  observed = set_coverage(labels, selections)
  np.testing.assert_array_equal(observed, expected)


def test_set_coverage_multiple_true():
  """Selecting any one of several true positions is enough for coverage."""
  labels = np.array([[1, 1, 0], [1, 1, 1]])
  selections = np.array(
    [
      [0, 1, 0],  # hits the second of two true positions
      [0, 0, 1],  # hits the third position; every position is true
    ]
  )

  expected = [1, 1]
  observed = set_coverage(labels, selections)
  np.testing.assert_array_equal(observed, expected)


def test_set_coverage_with_abstentions():
  """An abstention row is expected to count as covered; an empty set is not."""
  labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  abstained_row = [np.inf, np.inf, np.inf]
  empty_row = [0, 0, 0]
  selections = np.array([[1, 0, 0], abstained_row, empty_row])

  # Row order: correct prediction, abstention, empty set.
  expected = [1, 1, 0]
  observed = set_coverage(labels, selections)
  np.testing.assert_array_equal(observed, expected)


def test_set_coverage_empty_predictions():
  """Empty prediction sets never cover the true position."""
  labels = np.array([[1, 0, 0], [0, 1, 0]])
  # Both rows select nothing.
  selections = np.array([[0, 0, 0], [0, 0, 0]])

  expected = [0, 0]
  observed = set_coverage(labels, selections)
  np.testing.assert_array_equal(observed, expected)


def test_compute_metrics_basic():
  """compute_metrics returns per-row results for metrics called without context."""
  labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  abstained_row = [np.inf, np.inf, np.inf]
  empty_row = [0, 0, 0]
  # Row order: correct prediction, abstention, empty set.
  selections = np.array([[1, 0, 0], abstained_row, empty_row])

  requested = [abstention_rate, base_model_abstention_rate]
  outcome = compute_metrics(labels, selections, requested)

  # Results are looked up by the metric function's name.
  assert "abstention_rate" in outcome
  assert "base_model_abstention_rate" in outcome

  # Only the second row is an abstention.
  expected_abstention = [False, True, False]
  np.testing.assert_array_equal(outcome["abstention_rate"], expected_abstention)

  # Every label row has one true option.
  expected_base_abstention = [False, False, False]
  np.testing.assert_array_equal(
    outcome["base_model_abstention_rate"], expected_base_abstention
  )


def test_compute_metrics_with_context():
  """compute_metrics runs effective_set_sizes using indices passed via context."""
  # Both rows mark two valid candidates and the prediction selects both.
  labels = np.array([[1, 0, 1], [1, 1, 0]])
  selections = np.array([[1, 0, 1], [1, 1, 0]])

  # Row 0: position 2 duplicates position 0. Row 1: all positions unique.
  first_seen = np.array([[0, 1, 0], [0, 1, 2]])

  outcome = compute_metrics(
    labels,
    selections,
    [effective_set_sizes],
    context={"first_occurrence_indices": first_seen},
  )

  assert "effective_set_sizes" in outcome
  expected = [1, 2]
  np.testing.assert_array_equal(outcome["effective_set_sizes"], expected)


def test_compute_metrics_missing_context():
  """Check that compute_metrics raises ValueError when required context is absent.

  effective_set_sizes is requested without passing any context. The test
  expects a ValueError whose message contains "Missing required context keys",
  and fails if the call returns normally.
  """
  y_true = np.array([[1, 0, 0]])
  y_pred = np.array([[1, 0, 0]])

  metrics_list = [effective_set_sizes]  # requires first_occurrence_indices

  # Keep only the call under test inside the try body.
  try:
    compute_metrics(y_true, y_pred, metrics_list)
  except ValueError as e:
    message = str(e)
  else:
    # Raise explicitly: `assert False` is removed when Python runs with -O,
    # which would let a missing exception go unnoticed.
    raise AssertionError("Expected ValueError for missing context")

  assert "Missing required context keys" in message


def test_compute_metrics_mixed():
  """compute_metrics handles context and non-context metrics in a single call."""
  labels = np.array([[1, 0, 1], [1, 1, 0]])
  abstained_row = [np.inf, np.inf, np.inf]
  # Row 0 is a normal prediction; row 1 is an abstention.
  selections = np.array([[1, 0, 1], abstained_row])

  # Row 0: position 2 duplicates position 0. Row 1: identity mapping.
  first_seen = np.array([[0, 1, 0], [0, 1, 2]])

  requested = [abstention_rate, effective_set_sizes, set_coverage]
  outcome = compute_metrics(
    labels,
    selections,
    requested,
    context={"first_occurrence_indices": first_seen},
  )

  # Each requested metric must appear under its function name.
  assert "abstention_rate" in outcome
  assert "effective_set_sizes" in outcome
  assert "set_coverage" in outcome

  # Only the second row is an abstention.
  expected_abstention = [False, True]
  np.testing.assert_array_equal(outcome["abstention_rate"], expected_abstention)

  # Row 0 collapses to one unique element; the abstention row is inf.
  expected_sizes = [1, np.inf]
  np.testing.assert_array_equal(outcome["effective_set_sizes"], expected_sizes)

  # Both rows are covered (the abstention counts as covered).
  expected_coverage = [True, True]
  np.testing.assert_array_equal(outcome["set_coverage"], expected_coverage)
