import torch

from smlmsim.dynamics.discretize import batch_discretize_periods, discretize_periods


def test_discretize():
    """Check ``discretize_periods`` one sample at a time.

    Each sample is a list of ``[start, end]`` pairs. The expected values
    equal, for every frame ``k`` out of 8, the summed overlap between the
    periods and the unit interval ``[k, k + 1]``. The ``[inf, inf]`` row is
    expected to contribute nothing (presumably it acts as padding).
    """
    sample_a = [
        [1.9, 2.3],
        [3.1, 3.3],
        [3.8, 4.1],
        [6.0, 7.2],
    ]
    sample_b = [
        [0.2, 0.6],
        [0.9, 3.2],
        [5.7, 6.9],
        [torch.inf, torch.inf],
    ]
    all_periods = torch.tensor([sample_a, sample_b])

    # One row of per-frame fractions for each sample above.
    targets = torch.tensor(
        [
            [0.0, 0.1, 0.3, 0.4, 0.1, 0.0, 1.0, 0.2],
            [0.5, 1.0, 1.0, 0.2, 0.0, 0.3, 0.9, 0.0],
        ]
    )

    for row, sample_periods in enumerate(all_periods):
        computed = discretize_periods(sample_periods, n_frames=8)
        torch.testing.assert_close(computed, targets[row])


def test_batch_discretize():
    """Check ``batch_discretize_periods`` on a batch of two samples.

    The whole ``(2, 4, 2)`` tensor of ``[start, end]`` pairs is passed in a
    single call, and the result is compared against one row of 8 per-frame
    fractions per sample.
    """
    batch = [
        # Sample 0: four finite periods.
        [[1.9, 2.3], [3.1, 3.3], [3.8, 4.1], [6.0, 7.2]],
        # Sample 1: three finite periods plus an all-infinite row, which is
        # expected to add nothing to any frame.
        [[0.2, 0.6], [0.9, 3.2], [5.7, 6.9], [torch.inf, torch.inf]],
    ]
    wanted = [
        [0.0, 0.1, 0.3, 0.4, 0.1, 0.0, 1.0, 0.2],
        [0.5, 1.0, 1.0, 0.2, 0.0, 0.3, 0.9, 0.0],
    ]

    periods = torch.tensor(batch)
    expected_fractions = torch.tensor(wanted)

    result = batch_discretize_periods(periods, n_frames=8)
    torch.testing.assert_close(result, expected_fractions)
