import pytest
import torch
from datasets import Dataset

from hallucinations.utils import sort_dataset_by_input_length
from hallucinations.utils.misc import print_shape


def test_dataset_sort() -> None:
    """Sorting by ``text`` length orders rows longest-first and keeps columns aligned.

    The equal-length rows "ccc" and "ddd" must keep their input order. The returned
    index list gives, for each original row position, that row's position in the
    sorted dataset.
    """
    original = Dataset.from_dict({"text": ["a", "bb", "ccc", "ddd", "eeee"], "label": [1, 2, 3, 4, 5]})

    sorted_ds, positions = sort_dataset_by_input_length(original, "text")

    # Longest text first; the "ccc"/"ddd" tie keeps its original relative order.
    assert sorted_ds["text"] == ["eeee", "ccc", "ddd", "bb", "a"]
    # Labels move together with their rows.
    assert sorted_ds["label"] == [5, 3, 4, 2, 1]
    # positions[i] is where original row i landed in sorted_ds.
    assert positions == [4, 3, 1, 2, 0]


def test_dataset_inverse_sort() -> None:
    """Selecting the returned indices from the sorted dataset undoes the sort.

    ``sort_dataset_by_input_length`` reorders the rows. Passing the index list it
    returns to ``select`` on the sorted dataset must restore every column to its
    original order.
    """
    source = Dataset.from_dict({"text": ["a", "bb", "ccc", "ddd", "eeee"], "label": [1, 2, 3, 4, 5]})
    sorted_ds, restore_idx = sort_dataset_by_input_length(source, "text")

    # Round trip: undo the length sort using the returned index list.
    restored = sorted_ds.select(restore_idx)

    assert restored["text"] == ["a", "bb", "ccc", "ddd", "eeee"]
    assert restored["label"] == [1, 2, 3, 4, 5]


def test_print_shape(capsys: pytest.CaptureFixture[str]) -> None:
    """``print_shape`` writes a summary of a nested container of tensors to stdout.

    The input is a list holding a 2-D tensor and a one-element tuple that wraps a
    1-D tensor. The captured stdout must mention each container with its length
    and each tensor with its shape.

    Args:
        capsys: pytest's text-mode capture fixture, used to read what was printed.
    """
    data = [torch.zeros(2, 3), (torch.zeros(1),)]
    print_shape(data)
    captured = capsys.readouterr()

    # Containers are reported with their lengths.
    assert "List(len=2)" in captured.out
    assert "Tuple(len=1)" in captured.out
    # The tensor substrings stop right after the shape and leave the outer "(" unclosed.
    # This tolerates any further fields the summary may print (presumably dtype etc. — confirm).
    assert "Tensor(shape=(2, 3)" in captured.out
    assert "Tensor(shape=(1,)" in captured.out
