import numpy as np
import pytest

from generative_prediction_sets.utils import (
  pairwise_comparisons as pairwise_comparisons_no_dedup,
  compute_first_occurrence_indices,
)


# Deterministic binary function used as the comparison callback in the tests.
def simple_comparison_fn(x, y):
  """Return ``"<x>_<y>"``, i.e. both operands formatted and joined by "_"."""
  return "{}_{}".format(x, y)


def idfn(val):
  """Build a readable pytest id fragment for a single parameter value.

  NumPy arrays are labelled by their shape, booleans are labelled as the
  cache flag, and anything else falls back to ``str``.
  """
  if isinstance(val, np.ndarray):
    label = "array_{}".format(val.shape)
  elif isinstance(val, bool):
    label = "cache_{}".format(val)
  else:
    label = str(val)
  return label


# Each case is (input array, expected pairwise matrix, readable case id).
# Every expected matrix follows the same convention: entry (i, j) equals
# simple_comparison_fn(arr[min(i, j)], arr[max(i, j)]), so the matrix is
# symmetric and e.g. position (1, 0) of case 1 reads "1_2", not "2_1".
test_cases = [
  # 1. Distinct integers.
  (
    np.array([1, 2, 3]),
    np.array(
      [
        ["1_1", "1_2", "1_3"],
        ["1_2", "2_2", "2_3"],
        ["1_3", "2_3", "3_3"],
      ]
    ),
    "basic_functionality_with_integers",
  ),
  # 2. Every element identical, so every comparison yields the same string.
  (
    np.array(["a", "a", "a"]),
    np.full((3, 3), "a_a"),
    "all_identical_elements_with_strings",
  ),
  # 3. Mixed int/str literal. NumPy coerces this to a single string dtype,
  #    so every element reaches the comparison function as a string.
  (
    np.array([1, "a", 2]),
    np.array(
      [
        ["1_1", "1_a", "1_2"],
        ["1_a", "a_a", "a_2"],
        ["1_2", "a_2", "2_2"],
      ]
    ),
    "array_with_mixed_types",
  ),
  # 4. Repeated integers (no deduplication of equal values is expected).
  (
    np.array([1, 2, 2, 3, 1]),
    np.array(
      [
        ["1_1", "1_2", "1_2", "1_3", "1_1"],
        ["1_2", "2_2", "2_2", "2_3", "2_1"],
        ["1_2", "2_2", "2_2", "2_3", "2_1"],
        ["1_3", "2_3", "2_3", "3_3", "3_1"],
        ["1_1", "2_1", "2_1", "3_1", "1_1"],
      ]
    ),
    "array_with_repeating_integers",
  ),
  # 5. Empty input yields an empty 0x0 matrix.
  (
    np.array([]),
    np.empty((0, 0)),
    "empty_array",
  ),
]


@pytest.mark.parametrize(
  "arr, expected_result, use_cache",
  [
    pytest.param(arr, expected_result, use_cache, id=f"{case_id}_cache_{use_cache}")
    for arr, expected_result, case_id in test_cases
    for use_cache in (True, False)
  ],
  # Fallback only: every pytest.param above carries an explicit id, which
  # pytest uses in preference to an ``ids`` callable.
  ids=idfn,
)
def test_pairwise_comparisons_no_dedup(arr, expected_result, use_cache):
  """Check pairwise_comparisons against hand-written expected matrices.

  Each case in ``test_cases`` runs once with ``cache=True`` and once with
  ``cache=False``; both runs are compared against the same expected matrix,
  so caching must not change the result.
  """
  result = pairwise_comparisons_no_dedup(arr, simple_comparison_fn, cache=use_cache)
  np.testing.assert_array_equal(
    result,
    expected_result,
    err_msg=f"Test case with array {arr} and cache={use_cache} failed",
  )


# Test cases for compute_first_occurrence_indices.
# Each case is (input items, expected indices, readable case id). Entry k of
# the expected indices is the position at which items[k] first appears.
first_occurrence_test_cases = [
  # Duplicates point back at the first occurrence.
  (
    ["a", "b", "a", "c"],
    np.array([0, 1, 0, 3]),
    "basic_list_with_duplicates",
  ),
  # With no duplicates every element is its own first occurrence.
  (
    ["x", "y", "z"],
    np.arange(3),
    "all_unique_elements",
  ),
  # With all elements equal everything maps to index 0.
  (
    ["a"] * 4,
    np.zeros(4, dtype=int),
    "all_identical_elements",
  ),
  # Empty input gives an empty index array.
  (
    [],
    np.zeros(0, dtype=np.int32),
    "empty_list",
  ),
  # Heterogeneous element types.
  (
    [1, "a", 1, 2.0, "a"],
    np.array([0, 1, 0, 3, 1]),
    "mixed_types",
  ),
  # A single element.
  (
    ["x"],
    np.array([0]),
    "single_element",
  ),
]


@pytest.mark.parametrize(
  "items, expected_indices, test_id",
  [
    # Attach the readable case name as the pytest id. An ``ids`` callable is
    # invoked once per individual argument value (never with the whole case
    # tuple), so it cannot select the third element of a case; an explicit
    # pytest.param id is the reliable way to name each case.
    pytest.param(items, expected, case_id, id=case_id)
    for items, expected, case_id in first_occurrence_test_cases
  ],
)
def test_compute_first_occurrence_indices(items, expected_indices, test_id):
  """Check compute_first_occurrence_indices against hand-written expectations.

  ``test_id`` is the case name, also used in the assertion message.
  """
  result = compute_first_occurrence_indices(items)
  np.testing.assert_array_equal(
    result, expected_indices, err_msg=f"Test case '{test_id}' failed"
  )
