import pytest
import torch
from src.model import Model
from src.solve import solve
from torch import Tensor

# Seed the global RNG at import time: the parametrize lists below call
# torch.randn (and construct Model instances) while this module is being
# imported, so seeding here keeps those test fixtures reproducible across runs.
torch.random.manual_seed(0)


def trivial_solve(x: Tensor, model: Model) -> Tensor:
    """Reference decoder: build a tour by querying ``model`` one step at a time.

    Every batch row starts at city 0. The model is then called
    ``n_cities - 1`` times, and the argmax of its output along dim 1 is
    taken as the next city. ``test_solve`` uses this function as the oracle
    that ``solve`` must reproduce exactly.

    Args:
        x: Input tensor whose shape unpacks as ``(batch_size, n_cities, _)``.
        model: Called as ``model(x, mask, last, first)``. It is assumed to
            return per-city scores of shape ``(batch_size, n_cities)``
            (TODO: confirm against ``Model.forward``).

    Returns:
        Tensor of shape ``(batch_size, n_cities)``. Column 0 is all zeros,
        and column ``k`` holds the index chosen at step ``k``.
    """
    batch_size, n_cities, _ = x.shape
    dev = x.device

    mask = torch.ones((batch_size, n_cities), dtype=torch.bool, device=dev)
    # `first` never changes, and `last` is rebound (never mutated) on every
    # step, so sharing the initial tensor between the two names is harmless.
    first = torch.zeros((batch_size,), dtype=torch.int, device=dev)
    last = first
    rows = torch.arange(batch_size, device=dev)

    tour = [last]
    for _step in range(n_cities - 1):
        scores = model(x, mask, last, first)
        # The mask is updated only after the model call. It clears the entry
        # for the most recent pick, then sets the first city's entry back to
        # True. What True/False mean to Model is not visible here.
        # NOTE(review): confirm the intended mask semantics.
        mask[rows, last] = False
        mask[rows, first] = True
        # argmax returns int64, so after step 0 `last` is int64 while `first`
        # stays int32. The stacked result therefore mixes integer dtypes,
        # presumably promoted to int64 by torch.stack (confirm).
        last = scores.argmax(dim=1)
        tour.append(last)

    return torch.stack(tour, dim=1)


@pytest.mark.parametrize(
    "model, x",
    [
        (Model(16, 32, 1, 1, use_coords=True), torch.randn((32, 16, 2))),
    ],
)
def test_trivial_solve(model: Model, x: Tensor):
    """Every decoded tour must be a permutation of ``0 .. n_cities - 1``."""
    _, n_cities, _ = x.shape
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)
    x = x.to(device)

    tours = trivial_solve(x, model)

    # Each row is one tour. Sorting a row must give 0, 1, ..., n_cities - 1,
    # i.e. every index appears exactly once. All rows are checked at once.
    sorted_tours = tours.sort(dim=-1).values
    expected = torch.arange(n_cities, device=device)
    assert torch.all(sorted_tours == expected)


@pytest.mark.parametrize(
    "model, x",
    [
        (Model(16, 32, 1, 1, use_coords=True), torch.randn((32, 16, 2))),
        (Model(16, 32, 1, 1, use_coords=True), torch.randn((1, 16, 2))),
    ],
)
def test_solve(model: Model, x: Tensor):
    """``solve`` must return exactly the same tours as ``trivial_solve``."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    model.to(device)

    s = solve(x, model)
    s_ = trivial_solve(x, model)

    # Compare shapes first. `==` broadcasts, so a `solve` that dropped the
    # batch dimension would still pass the batch-size-1 case above: its
    # (n_cities,) output would broadcast against the (1, n_cities) reference.
    assert s.shape == s_.shape, (
        f"shape mismatch: solve -> {tuple(s.shape)}, "
        f"trivial_solve -> {tuple(s_.shape)}"
    )
    assert torch.all(s == s_)
