from torch import nn

from torchjd.criterion import SplitTensorCriterion


class _FakeLoss(nn.Module):
    def __str__(self):
        return "FakeLoss"


def test_str_split_batch():
    """A split along dim 0 should be rendered as a batch split with its chunk size."""
    split_criterion = SplitTensorCriterion(
        loss_function=_FakeLoss(),
        dim=0,
        chunk_size=8,
    )
    expected = "SplitBatch-8 FakeLoss"
    assert str(split_criterion) == expected


def test_str_other():
    """A split along a non-zero dim should be rendered with both its dim and chunk size."""
    split_criterion = SplitTensorCriterion(
        loss_function=_FakeLoss(),
        dim=1,
        chunk_size=4,
    )
    expected = "SplitTensor-1-4 FakeLoss"
    assert str(split_criterion) == expected
