import pytest
import torch
from torch import Tensor

from torchjd.criterion._utils import _move_dim_back, _move_dim_to_front


@pytest.mark.parametrize(
    ["t", "dim", "expected_output"],
    [
        (torch.ones(2, 3), 1, torch.ones(3, 2)),
        (torch.ones(2, 3, 4, 5, 6), 0, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(2, 3, 4, 5, 6), 1, torch.ones(3, 2, 4, 5, 6)),
        (torch.ones(2, 3, 4, 5, 6), 2, torch.ones(4, 2, 3, 5, 6)),
        (torch.ones(2, 3, 4, 5, 6), 3, torch.ones(5, 2, 3, 4, 6)),
        (torch.ones(2, 3, 4, 5, 6), 4, torch.ones(6, 2, 3, 4, 5)),
        (torch.ones(2), 0, torch.ones(2)),
    ],
)
def test__move_dim_to_front(t: Tensor, dim: int, expected_output: Tensor):
    """
    Tests that `_move_dim_to_front` puts dimension `dim` first and keeps the other dimensions in
    their original relative order, both in terms of shape and of element placement.
    """

    output = _move_dim_to_front(t, dim)
    assert torch.equal(output, expected_output)

    # The parametrized tensors are filled with ones, so the assertion above can only detect a
    # wrong shape. Repeat the check with pairwise-distinct values so that a result with the right
    # shape but wrongly arranged elements (e.g. a reshape instead of a permutation) is caught too.
    distinct = torch.arange(t.numel(), dtype=t.dtype).reshape(t.shape)
    order = [dim] + [d for d in range(t.dim()) if d != dim]
    expected_distinct = distinct.permute(order)

    # Sanity check: the explicit reference permutation must agree with the parametrized shape.
    assert expected_distinct.shape == expected_output.shape
    assert torch.equal(_move_dim_to_front(distinct, dim), expected_distinct)


@pytest.mark.parametrize(
    ["t", "dim", "expected_output"],
    [
        (torch.ones(3, 2), 1, torch.ones(2, 3)),
        (torch.ones(2, 3, 4, 5, 6), 0, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(3, 2, 4, 5, 6), 1, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(4, 2, 3, 5, 6), 2, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(5, 2, 3, 4, 6), 3, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(6, 2, 3, 4, 5), 4, torch.ones(2, 3, 4, 5, 6)),
        (torch.ones(2), 0, torch.ones(2)),
    ],
)
def test__move_dim_back(t: Tensor, dim: int, expected_output: Tensor):
    """
    Tests that `_move_dim_back` moves the first dimension to position `dim` and keeps the other
    dimensions in their original relative order, both in terms of shape and of element placement.
    """

    output = _move_dim_back(t, dim)
    assert torch.equal(output, expected_output)

    # The parametrized tensors are filled with ones, so the assertion above can only detect a
    # wrong shape. Repeat the check with pairwise-distinct values so that a result with the right
    # shape but wrongly arranged elements (e.g. a reshape instead of a permutation) is caught too.
    distinct = torch.arange(t.numel(), dtype=t.dtype).reshape(t.shape)
    # Output dimension order: input dims 1..n-1 in order, with input dim 0 inserted at `dim`.
    order = list(range(1, t.dim()))
    order.insert(dim, 0)
    expected_distinct = distinct.permute(order)

    # Sanity check: the explicit reference permutation must agree with the parametrized shape.
    assert expected_distinct.shape == expected_output.shape
    assert torch.equal(_move_dim_back(distinct, dim), expected_distinct)
