from itertools import product

import pytest
import torch
from src.utils import pairwise_distances, solution_cost
from torch import Tensor

# Seed torch's global RNG *before* the @pytest.mark.parametrize decorators
# below are evaluated: they call torch.randn / torch.randperm at import time,
# so fixing the seed here makes the generated test inputs reproducible.
_SEED = 0
torch.manual_seed(_SEED)


@pytest.mark.parametrize("x", [torch.randn((100, 2))])
def test_pairwise_distances(x: Tensor):
    """Check ``pairwise_distances`` against a naive double-loop reference.

    Entry ``[i, j]`` of the reference is the Euclidean (L2) distance between
    rows ``x[i]`` and ``x[j]`` of the ``(n_points, dim)`` input.
    """
    n_points, _ = x.shape
    # Allocate the reference with x's dtype/device instead of a hard-coded
    # float32 CPU buffer: otherwise float64 distances would be silently
    # downcast on assignment, and torch.allclose rejects mismatched dtypes.
    expected = x.new_zeros((n_points, n_points))
    for i, j in product(range(n_points), repeat=2):
        expected[i, j] = torch.norm(x[i] - x[j], p=2)

    assert torch.allclose(expected, pairwise_distances(x))


@pytest.mark.parametrize("x", [torch.randn((32, 100, 2))])
def test_batch_pairwise_distances(x: Tensor):
    """Check that ``pairwise_distances`` works under ``torch.vmap``.

    ``x`` is indexed as ``(batch, n_points, dim)``. The reference satisfies
    ``expected[b, i, j] == ||x[b, i] - x[b, j]||_2``. It is computed by
    broadcasting rather than an element-wise loop, which would need
    ``batch * n_points**2`` interpreter-level torch calls (320,000 for the
    default input). Broadcasting also keeps the reference in x's dtype and
    device.
    """
    # (B, N, 1, D) - (B, 1, N, D) -> (B, N, N, D), with diffs[b, i, j] = x[b, i] - x[b, j]
    diffs = x.unsqueeze(-2) - x.unsqueeze(-3)
    expected = torch.linalg.vector_norm(diffs, ord=2, dim=-1)

    assert torch.allclose(expected, torch.vmap(pairwise_distances)(x))


@pytest.mark.parametrize("x, s", [(torch.randn((100, 2)), torch.randperm(100))])
def test_solution_cost(x: Tensor, s: Tensor):
    """Check ``solution_cost`` against an explicit closed-tour reference.

    ``s`` orders the city indices (a ``torch.randperm`` in the
    parametrization). The reference sums the L2 length of every edge
    ``s[i] -> s[(i + 1) % n]``, including the edge from the last city back to
    the first.

    Edge lengths are computed directly from ``x`` rather than through
    ``pairwise_distances``, so this test does not inherit a bug in that
    separately tested function.
    """
    # roll(-1) pairs each city with its successor and wraps the last back to the first.
    successors = torch.roll(s, shifts=-1)
    edge_lengths = (
        torch.linalg.vector_norm(x[a] - x[b], ord=2)
        for a, b in zip(s.tolist(), successors.tolist())
    )
    # Start from a zero of x's dtype/device so the sum is not pinned to float32/CPU.
    expected = sum(edge_lengths, start=x.new_zeros(()))
    assert torch.allclose(expected, solution_cost(x, s))


@pytest.mark.parametrize(
    "x, s",
    [
        (torch.randn((32, 100, 2)), torch.stack([torch.randperm(100) for _ in range(32)])),
    ],
)
def test_batch_solution_cost(x: Tensor, s: Tensor):
    """Check that ``solution_cost`` under ``torch.vmap`` matches per-sample calls.

    Each entry of the vmapped result must agree with calling
    ``solution_cost`` on the corresponding ``(x[b], s[b])`` pair on its own.
    """
    costs = torch.vmap(solution_cost)(x, s)
    # Per-sample names keep the batched x and s bound to the full batch
    # instead of being shadowed by their slices inside the loop.
    for b, (x_b, s_b, cost_b) in enumerate(zip(x, s, costs)):
        assert torch.allclose(cost_b, solution_cost(x_b, s_b)), f"batch element {b} differs"
