import pytest
import torch
from src.dataset import Batch, TSPDataset
from src.utils import solution_cost
from torch import Tensor


@pytest.mark.parametrize(
    "x, s",
    [(torch.randn(100, 2), torch.randperm(100))],
)
def test_data_aug(x: Tensor, s: Tensor):
    """The tour returned by ``TSPDataset.data_aug`` must cost the same as ``s``.

    ``x`` holds the city coordinates and ``s`` a random permutation used as the
    tour; the cost of the tour is measured with ``solution_cost`` before and
    after augmentation and the two values must be numerically close.
    """
    # Cost of the untouched tour is computed first, before augmentation runs.
    original_cost = solution_cost(x, s)
    augmented_tour = TSPDataset.data_aug(s)
    augmented_cost = solution_cost(x, augmented_tour)
    assert torch.allclose(original_cost, augmented_cost)


@pytest.mark.parametrize(
    "x, s",
    [
        (
            [torch.randn(100, 2) for _ in range(256)],
            [torch.randperm(100) for _ in range(256)],
        )
    ],
)
def test_batch_generation(x: list[Tensor], s: list[Tensor]):
    """Check the invariants of a ``Batch`` built with ``Batch.from_samples``.

    For every (coordinates, tour) sample the batch must:

    * have its ``o`` entry equal to 0,
    * not return the input coordinates unchanged in ``batch.x`` (first
      ``n_cities`` rows),
    * preserve the tour cost computed by ``solution_cost``,
    * expose a mask row ``m`` whose first ``n_cities`` entries are set exactly
      for positions ``<= e`` when ``e != o``, and are all set otherwise.
    """
    (n_cities,) = s[0].shape
    # Keep the raw samples under their own name; `batch` is the Batch object.
    samples = list(zip(x, s))
    batch = Batch.from_samples(samples, torch.device("cpu"))

    assert torch.all(batch.o == 0)

    # Position indices 0..n_cities-1, built once and compared against each
    # sample's `e` instead of looping over positions one scalar at a time.
    positions = torch.arange(n_cities, device=batch.m.device)

    for i, (x_, s_) in enumerate(zip(x, s)):
        batch_x = batch.x[i, :n_cities]
        assert not torch.allclose(x_, batch_x), f"sample {i}: coordinates unchanged"
        assert torch.allclose(
            solution_cost(x_, s_),
            solution_cost(batch_x, batch.s[i, :n_cities]),
        ), f"sample {i}: tour cost changed"

        mask = batch.m[i, :n_cities]
        # The branch depends only on the sample, so it is evaluated once per
        # sample rather than once per position.
        if batch.e[i] != batch.o[i]:
            expected = positions <= batch.e[i]
            assert torch.all(mask == expected), f"sample {i}: mask != (pos <= e)"
        else:
            assert torch.all(mask), f"sample {i}: mask not fully set when e == o"
