import pytest
import pandas as pd
import numpy as np
import pytest
import pandas as pd
import numpy as np
from pandas.testing import assert_frame_equal

# Import the function to test
from llm_inference.utils import pd_unique  # Adjust this import as needed


@pytest.mark.parametrize(
  "input_data, subset, expected_unique_count",
  [
    (
      # Basic functionality test
      {
        "A": [1, 1, 2, 2, 3, 3],
        "B": ["a", "a", "b", "b", "c", "c"],
        "C": [10, 10, 20, 20, 30, 30],
        "D": [100, 200, 300, 400, 500, 600],
      },
      ["A", "B", "C"],
      3,
    ),
    (
      # All unique test
      {"A": [1, 2, 3, 4, 5], "B": ["a", "b", "c", "d", "e"]},
      ["A", "B"],
      5,
    ),
    (
      # All duplicate test
      {"A": [1, 1, 1, 1, 1], "B": ["a", "a", "a", "a", "a"]},
      ["A", "B"],
      1,
    ),
    (
      # Mixed case with some duplicates
      {"A": [1, 2, 1, 3, 2, 4], "B": ["a", "b", "a", "c", "b", "d"]},
      ["A", "B"],
      4,
    ),
    (
      # Test with non-consecutive duplicates
      {"A": [1, 2, 3, 1, 2, 3], "B": ["a", "b", "c", "a", "b", "c"]},
      ["A", "B"],
      3,
    ),
  ],
)
def test_pd_unique(input_data, subset, expected_unique_count):
  """Check the (unique_idx, inverse_idx) pair returned by pd_unique.

  Args:
    input_data: Column-name -> values mapping used to build the DataFrame.
    subset: Column names passed to pd_unique as the deduplication key.
    expected_unique_count: Number of distinct rows over `subset`.

  The test verifies that:
    * `unique_idx` has one entry per distinct row,
    * `inverse_idx` has one entry per original row, each a valid
      non-negative position into `unique_idx`,
    * rows sharing the same `subset` values map to the same unique slot,
    * `df.iloc[unique_idx].iloc[inverse_idx]` reproduces the `subset`
      columns of the original frame.
  """
  df = pd.DataFrame(input_data)
  unique_idx, inverse_idx = pd_unique(df, subset)

  # Check if the number of unique rows is correct
  assert len(unique_idx) == expected_unique_count

  # Check if inverse_idx has the same length as the original dataframe
  assert len(inverse_idx) == len(df)

  # Check if all values in inverse_idx are within the range of unique_idx.
  # The lower bound matters too: a negative entry would pass the upper-bound
  # check and then silently wrap around when used as a NumPy index below.
  assert np.all(inverse_idx >= 0)
  assert np.all(inverse_idx < len(unique_idx))

  # Verify broadcasting functionality: give every unique row a distinct
  # value (its slot number) and broadcast it back to the original shape.
  slot_values = np.arange(len(unique_idx))
  broadcasted_values = slot_values[inverse_idx]

  # Check if the broadcasted values match the expected pattern
  for _, group in df.groupby(subset):
    # All rows in the group should have received the same value
    assert len(set(broadcasted_values[group.index])) == 1

  # Additional check: reconstruct the original dataframe using unique_idx and inverse_idx
  df_reconstructed = df.iloc[unique_idx].iloc[inverse_idx].reset_index(drop=True)
  assert_frame_equal(df[subset], df_reconstructed[subset])


@pytest.mark.parametrize(
  ("input_data", "subset"),
  [
    # Empty DataFrame: it has no columns, so neither "A" nor "B" exists.
    pytest.param(pd.DataFrame(), ["A", "B"]),
    # Only column "A" exists; "B" is requested but missing.
    pytest.param(pd.DataFrame({"A": [1, 2, 3]}), ["A", "B"]),
  ],
)
def test_pd_unique_edge_cases(input_data, subset):
  """Expect pd_unique to raise when `subset` names columns the frame lacks.

  Args:
    input_data: DataFrame missing at least one of the requested columns.
    subset: Column names passed to pd_unique.
  """
  # NOTE(review): Exception is deliberately broad here; the concrete error
  # type raised by pd_unique is not visible from this test -- narrow it once
  # confirmed.
  with pytest.raises(Exception):
    pd_unique(input_data, subset)
