### Imports ###########################################################################################################

import pytest
import numpy as np
import torch
from gcontrol.utils.grad_im_preprocessors import grad_normalize, grad_scale, grad_resize, grad_centercrop
from fixtures.grad_preprocessors_fixtures import *
from torchvision.transforms import CenterCrop

#######################################################################################################################

# Tight absolute/relative tolerances shared by the torch.allclose checks against ground truth.
ATOL, RTOL = 1e-8, 1e-8

### Tests #############################################################################################################


# grad_normalize tests
@pytest.mark.parametrize(
    "mat, mean, std, gtruth",
    [
        ("mat_22", "mean_0", "std_0", "nres_22_0"),
        ("mat_22", "mean_1", "std_1", "nres_22_0"),
        ("mat_122", "mean_0", "std_0", "nres_122_0"),
        ("mat_122", "mean_1", "std_1", "nres_122_0"),
        ("mat_234", "mean_0", "std_0", "nres_234_0"),
        ("mat_234", "mean_1", "std_1", "nres_234_0"),
        ("mat_23112", "mean_0", "std_0", "nres_23112_0"),
        ("mat_23112", "mean_1", "std_1", "nres_23112_0"),
    ],
)
def test_grad_im_normalize_0(mat, mean, std, gtruth, request):
    """Compare grad_normalize with precomputed ``nres_*_0`` ground truth.

    Both the ``*_0`` and the ``*_1`` mean/std fixtures are expected to produce the same
    result. Parameters are fixture *names*, resolved through ``request``.
    """
    mat, mean, std, gtruth = (request.getfixturevalue(name) for name in (mat, mean, std, gtruth))

    normalized = grad_normalize(mat, mean, std)
    # Round to 4 decimals so the tight tolerances compare against the stored ground truth.
    rounded = torch.round(normalized, decimals=4)
    assert torch.allclose(rounded, gtruth, rtol=RTOL, atol=ATOL)
    assert rounded.dtype == gtruth.dtype


@pytest.mark.parametrize(
    "mat, mean, std, gtruth",
    [
        ("mat_234", "mean_2", "std_2", "nres_234_2"),
        ("mat_23112", "mean_231", "std_231", "nres_23112_231"),
    ],
)
def test_grad_im_normalize_2(mat, mean, std, gtruth, request):
    """Compare grad_normalize with ground truth for the ``mean_2``/``mean_231`` statistics.

    NOTE(review): these fixtures presumably hold non-scalar (per-channel) statistics that
    broadcast against the input — confirm in ``fixtures.grad_preprocessors_fixtures``.
    """
    image, mu, sigma, expected = [request.getfixturevalue(name) for name in (mat, mean, std, gtruth)]

    out = torch.round(grad_normalize(image, mu, sigma), decimals=4)
    assert torch.allclose(out, expected, rtol=RTOL, atol=ATOL)
    assert out.dtype == expected.dtype


@pytest.mark.parametrize(
    "mat, mean, std, gtruth, dtype",
    [
        ("mat_22", "mean_0", "std_0", "nres_22_0", torch.float16),
        ("mat_23112", "mean_231", "std_231", "nres_23112_231", torch.float64),
    ],
)
def test_grad_im_normalize_typing(mat, mean, std, gtruth, dtype, request):
    """grad_normalize must return the same floating dtype it was given.

    The input and the ground truth are both cast to ``dtype``. Tolerances are loosened to
    1e-2 because float16 cannot hold the 4-decimal ground truth exactly.
    """
    mat, mean, std, gtruth = (request.getfixturevalue(name) for name in (mat, mean, std, gtruth))

    expected = gtruth.to(dtype)
    result = torch.round(grad_normalize(mat.to(dtype), mean, std), decimals=4)
    assert torch.allclose(result, expected, rtol=1e-2, atol=1e-2)
    assert result.dtype == dtype


@pytest.mark.parametrize(
    "mat, mean, std, err",
    [
        (np.array([1, 2]), "mean_0", "std_0", TypeError),
        (torch.tensor([1, 2], dtype=torch.int32), "mean_0", "std_0", ValueError),
        ("mat_122", "mean_231", "std_0", ValueError),
        ("mat_122", "mean_0", "std_231", ValueError),
    ],
)
def test_grad_im_normalize_errorhandle(mat, mean, std, err, request):
    """grad_normalize must reject invalid inputs with the expected exception type.

    ``mat`` is either a literal (a numpy array or an int32 tensor) or a fixture name. The
    fixture cases pair ``mat_122`` with mismatching ``*_231`` statistics.
    """
    if isinstance(mat, str):
        mat = request.getfixturevalue(mat)
    mean = request.getfixturevalue(mean)
    std = request.getfixturevalue(std)

    # pytest.raises fails with an explicit "DID NOT RAISE <err>" message when no error occurs.
    with pytest.raises(err):
        grad_normalize(mat, mean, std)


@pytest.mark.parametrize(
    "mat, mean, std, gtruth",
    [("mat_22", "mean_0", "std_0", "nres_22_0"), ("mat_23112", "mean_231", "std_231", "nres_23112_231")],
)
def test_grad_im_normalize_autograd(mat, mean, std, gtruth, request):
    """grad_normalize must be differentiable with respect to its input.

    The gradient of ``sum(grad_normalize(x))`` at the first element of the last two dims is
    expected to equal ``1 / std``. The forward output must still match the ground truth.
    """
    mat = request.getfixturevalue(mat)
    mean = request.getfixturevalue(mean)
    std = request.getfixturevalue(std)
    gtruth = request.getfixturevalue(gtruth)

    # torch.autograd.Variable is deprecated. A detached clone with requires_grad is the
    # modern leaf tensor and leaves the fixture tensor itself untouched.
    vmat = mat.detach().clone().requires_grad_(True)
    result = grad_normalize(vmat, mean, std)
    torch.sum(result).backward()
    result = torch.round(result, decimals=4)
    gmat = vmat.grad

    assert torch.allclose(gmat[..., 0, 0], 1 / std, rtol=RTOL, atol=ATOL)
    assert gmat.dtype == mat.dtype

    assert torch.allclose(result, gtruth, rtol=RTOL, atol=ATOL)
    assert result.dtype == gtruth.dtype


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.parametrize(
    "mat, mean, std, gtruth",
    [("mat_22", "mean_0", "std_0", "nres_22_0"), ("mat_23112", "mean_231", "std_231", "nres_23112_231")],
)
def test_grad_im_normalize_cuda(mat, mean, std, gtruth, request):
    """grad_normalize on CUDA tensors must reproduce the ground truth and keep its dtype.

    Skipped when CUDA is unavailable. Every operand is moved onto the input's device.
    """
    c_mat = request.getfixturevalue(mat).cuda()
    device = c_mat.device
    c_mean, c_std, c_gtruth = (request.getfixturevalue(name).to(device=device) for name in (mean, std, gtruth))

    c_result = torch.round(grad_normalize(c_mat, c_mean, c_std), decimals=4)
    assert torch.allclose(c_result, c_gtruth, rtol=RTOL, atol=ATOL)
    assert c_result.dtype == c_gtruth.dtype


# grad_scale tests
@pytest.mark.parametrize(
    "mat, std, gtruth",
    [
        ("mat_22", "std_0", "sres_22_0"),
        ("mat_22", "std_1", "sres_22_0"),
        ("mat_122", "std_0", "sres_122_0"),
        ("mat_122", "std_1", "sres_122_0"),
        ("mat_234", "std_0", "sres_234_0"),
        ("mat_234", "std_1", "sres_234_0"),
        ("mat_23112", "std_0", "sres_23112_0"),
        ("mat_23112", "std_1", "sres_23112_0"),
    ],
)
def test_grad_im_scale_0(mat, std, gtruth, request):
    """Compare grad_scale with precomputed ``sres_*_0`` ground truth.

    Both ``std_0`` and ``std_1`` are expected to yield the same result.
    """
    mat, std, gtruth = map(request.getfixturevalue, (mat, std, gtruth))

    scaled = torch.round(grad_scale(mat, std), decimals=4)
    assert torch.allclose(scaled, gtruth, rtol=RTOL, atol=ATOL)
    assert scaled.dtype == gtruth.dtype


@pytest.mark.parametrize(
    "mat, std, gtruth",
    [
        ("mat_234", "std_2", "sres_234_2"),
        ("mat_23112", "std_231", "sres_23112_231"),
    ],
)
def test_grad_im_scale_2(mat, std, gtruth, request):
    """Compare grad_scale with ground truth for the ``std_2``/``std_231`` fixtures.

    NOTE(review): these presumably hold non-scalar std values that broadcast against the
    input — confirm in the fixtures module.
    """
    image, sigma, expected = [request.getfixturevalue(name) for name in (mat, std, gtruth)]

    out = torch.round(grad_scale(image, sigma), decimals=4)
    assert torch.allclose(out, expected, rtol=RTOL, atol=ATOL)
    assert out.dtype == expected.dtype


@pytest.mark.parametrize(
    "mat, std, gtruth, dtype",
    [
        ("mat_22", "std_0", "sres_22_0", torch.float16),
        ("mat_23112", "std_231", "sres_23112_231", torch.float64),
    ],
)
def test_grad_im_scale_typing(mat, std, gtruth, dtype, request):
    """grad_scale must return the same floating dtype it was given (1e-2 tolerances)."""
    mat, std, gtruth = map(request.getfixturevalue, (mat, std, gtruth))

    expected = gtruth.to(dtype)
    scaled = torch.round(grad_scale(mat.to(dtype), std), decimals=4)
    assert torch.allclose(scaled, expected, rtol=1e-2, atol=1e-2)
    assert scaled.dtype == dtype


@pytest.mark.parametrize(
    "mat, std, err",
    [
        (np.array([1, 2]), "std_0", TypeError),
        (torch.tensor([1, 2], dtype=torch.int32), "std_0", ValueError),
        ("mat_122", "std_231", ValueError),
    ],
)
def test_grad_im_scale_errorhandle(mat, std, err, request):
    """grad_scale must reject invalid inputs with the expected exception type.

    ``mat`` is either a literal (a numpy array or an int32 tensor) or a fixture name.
    """
    if isinstance(mat, str):
        mat = request.getfixturevalue(mat)
    std = request.getfixturevalue(std)

    # pytest.raises fails with an explicit "DID NOT RAISE <err>" message when no error occurs.
    with pytest.raises(err):
        grad_scale(mat, std)


@pytest.mark.parametrize(
    "mat, std, gtruth",
    [("mat_22", "std_0", "sres_22_0"), ("mat_23112", "std_231", "sres_23112_231")],
)
def test_grad_im_scale_autograd(mat, std, gtruth, request):
    """grad_scale must be differentiable with respect to its input.

    The gradient of ``sum(grad_scale(x))`` at the first element of the last two dims is
    expected to equal ``std``. The gradient must keep the input dtype, and the forward
    output must still match the ground truth.
    """
    mat = request.getfixturevalue(mat)
    std = request.getfixturevalue(std)
    gtruth = request.getfixturevalue(gtruth)

    # torch.autograd.Variable is deprecated. A detached clone is an independent leaf tensor.
    vmat = mat.detach().clone().requires_grad_(True)
    tmp = grad_scale(vmat, std)
    torch.sum(tmp).backward()
    tmp = torch.round(tmp, decimals=4)
    gmat = vmat.grad

    assert torch.allclose(gmat[..., 0, 0], std, rtol=RTOL, atol=ATOL)
    # Check the gradient's dtype (not the fixture's), consistent with the normalize autograd test.
    assert gmat.dtype == mat.dtype

    assert torch.allclose(tmp, gtruth, rtol=RTOL, atol=ATOL)
    assert tmp.dtype == gtruth.dtype


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.parametrize(
    "mat, std, gtruth",
    [("mat_22", "std_0", "sres_22_0"), ("mat_23112", "std_231", "sres_23112_231")],
)
def test_grad_im_scale_cuda(mat, std, gtruth, request):
    """grad_scale on CUDA tensors must reproduce the ground truth and keep its dtype.

    Skipped when CUDA is unavailable.
    """
    device_mat = request.getfixturevalue(mat).cuda()
    device_std, device_gtruth = (
        request.getfixturevalue(name).to(device=device_mat.device) for name in (std, gtruth)
    )

    device_result = torch.round(grad_scale(device_mat, device_std), decimals=4)
    assert torch.allclose(device_result, device_gtruth, rtol=RTOL, atol=ATOL)
    assert device_result.dtype == device_gtruth.dtype


# grad_resize tests
@pytest.mark.parametrize(
    "mat, in_shape, reshape, out_shape",
    [
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), (15, 15), (10, 3, 15, 15)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), (35, 50), (10, 3, 35, 50)),
        (torch.linspace(0, 100, 1 * 3 * 20 * 25, dtype=torch.float32), (1, 3, 20, 25), (15, 15), (1, 3, 15, 15)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float32), (10, 3, 20, 25), (35, 50), (10, 3, 35, 50)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float32), (10, 3, 20, 25), (15, 15), (10, 3, 15, 15)),
        (torch.linspace(0, 100, 1 * 3 * 20 * 25, dtype=torch.float64), (1, 3, 20, 25), (35, 50), (1, 3, 35, 50)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), 40, (10, 3, 40, 50)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), 50, (10, 3, 50, 62)),
        (torch.linspace(0, 100, 1 * 3 * 20 * 20, dtype=torch.float32), (1, 3, 20, 20), 15, (1, 3, 15, 15)),
        (torch.linspace(0, 100, 10 * 3 * 25 * 25, dtype=torch.float32), (10, 3, 25, 25), 50, (10, 3, 50, 50)),
        (torch.linspace(0, 100, 1 * 3 * 20 * 25, dtype=torch.float64), (1, 3, 20, 25), 15, (1, 3, 15, 19)),
    ],
)
def test_grad_im_resize_4dims(mat, in_shape, reshape, out_shape):
    """grad_resize on 4-D input: check the output shape and that the dtype is preserved.

    A tuple ``reshape`` gives the exact output size. Per the expected shapes above, an int
    sets the shorter spatial side and scales the other proportionally. Both the "nearest"
    and "bilinear" modes are exercised.
    """
    image = torch.reshape(mat, in_shape)
    expected_shape = torch.Size(out_shape)

    # Compute both modes first, then check them in order ("nearest" before "bilinear").
    outputs = {mode: grad_resize(image, reshape, True, mode) for mode in ("nearest", "bilinear")}
    for resized in outputs.values():
        assert resized.shape == expected_shape
        assert resized.dtype == image.dtype


@pytest.mark.parametrize(
    "mat, in_shape, reshape, out_shape",
    [
        (torch.linspace(0, 100, 1 * 20 * 25, dtype=torch.float16), (1, 20, 25), (15, 15), (1, 15, 15)),
        (torch.linspace(0, 100, 8 * 20 * 25, dtype=torch.float16), (8, 20, 25), (35, 50), (8, 35, 50)),
        (torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float32), (5, 20, 25), (15, 15), (5, 15, 15)),
        (torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float32), (5, 20, 25), (35, 50), (5, 35, 50)),
        # (An exact duplicate of the float32 (5, 20, 25) -> (15, 15) case was removed here.)
        (torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float64), (5, 20, 25), (35, 50), (5, 35, 50)),
        (torch.linspace(0, 100, 1 * 20 * 25, dtype=torch.float16), (1, 20, 25), 40, (1, 40, 50)),
        (torch.linspace(0, 100, 8 * 20 * 25, dtype=torch.float16), (8, 20, 25), 50, (8, 50, 62)),
        (torch.linspace(0, 100, 5 * 20 * 20, dtype=torch.float32), (5, 20, 20), 15, (5, 15, 15)),
        (torch.linspace(0, 100, 5 * 25 * 25, dtype=torch.float32), (5, 25, 25), 50, (5, 50, 50)),
        (torch.linspace(0, 100, 7 * 20 * 25, dtype=torch.float64), (7, 20, 25), 15, (7, 15, 19)),
    ],
)
def test_grad_im_resize_3dims(mat, in_shape, reshape, out_shape):
    """grad_resize on 3-D input: check the output shape and that the dtype is preserved.

    Exercises both the "nearest" and "bilinear" modes with tuple and int ``reshape`` sizes.
    """
    mat = torch.reshape(mat, in_shape)
    n_res = grad_resize(mat, reshape, True, "nearest")
    b_res = grad_resize(mat, reshape, True, "bilinear")
    assert n_res.shape == torch.Size(out_shape)
    assert n_res.dtype == mat.dtype
    assert b_res.shape == torch.Size(out_shape)
    assert b_res.dtype == mat.dtype


@pytest.mark.parametrize(
    "mat, in_shape, reshape, out_shape",
    [
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float16), (20, 25), (15, 15), (15, 15)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float16), (20, 25), (35, 50), (35, 50)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float32), (20, 25), (15, 15), (15, 15)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float32), (20, 25), (35, 50), (35, 50)),
        # (An exact duplicate of the float32 (20, 25) -> (15, 15) case was removed here.)
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float64), (20, 25), (35, 50), (35, 50)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float16), (20, 25), 40, (40, 50)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float16), (20, 25), 50, (50, 62)),
        (torch.linspace(0, 100, 20 * 20, dtype=torch.float32), (20, 20), 15, (15, 15)),
        (torch.linspace(0, 100, 25 * 25, dtype=torch.float32), (25, 25), 50, (50, 50)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float64), (20, 25), 15, (15, 19)),
    ],
)
def test_grad_im_resize_2dims(mat, in_shape, reshape, out_shape):
    """grad_resize on 2-D input: check the output shape and that the dtype is preserved.

    Exercises both the "nearest" and "bilinear" modes with tuple and int ``reshape`` sizes.
    """
    mat = torch.reshape(mat, in_shape)
    n_res = grad_resize(mat, reshape, True, "nearest")
    b_res = grad_resize(mat, reshape, True, "bilinear")
    assert n_res.shape == torch.Size(out_shape)
    assert n_res.dtype == mat.dtype
    assert b_res.shape == torch.Size(out_shape)
    assert b_res.dtype == mat.dtype


@pytest.mark.parametrize(
    "mat, in_shape, reshape, err",
    [
        (np.linspace(0, 100, 6 * 5 * 4, dtype=np.float32), (6, 5, 4), (5, 5), TypeError),
        (torch.linspace(0, 100, 6 * 5 * 4 * 3 * 2, dtype=torch.float32), (6, 5, 4, 3, 2), (5, 5), ValueError),
        (torch.linspace(0, 100, 5, dtype=torch.float32), (5,), (5, 5), ValueError),
        (torch.linspace(0, 100, 6 * 5 * 4 * 3, dtype=torch.float32), (6, 5, 4, 3), "a", TypeError),
        (torch.linspace(0, 100, 6 * 5 * 4 * 3, dtype=torch.float32), (6, 5, 4, 3), (5, 5.3), TypeError),
        (torch.linspace(0, 100, 6 * 5 * 4 * 3, dtype=torch.float64), (6, 5, 4, 3), (5, 5, 4), ValueError),
    ],
)
def test_grad_im_resize_errorhandle(mat, in_shape, reshape, err):
    """grad_resize must reject bad inputs with the expected exception type.

    The cases cover a numpy (non-tensor) input, 1-D and 5-D tensors, and malformed size
    arguments: a string, a non-integer entry, and a 3-element tuple.
    """
    if isinstance(mat, torch.Tensor):
        mat = torch.reshape(mat, in_shape)
    elif isinstance(mat, np.ndarray):
        mat = np.reshape(mat, in_shape)

    # pytest.raises fails with an explicit "DID NOT RAISE <err>" message when no error occurs.
    with pytest.raises(err):
        grad_resize(mat, reshape, False, "nearest")


@pytest.mark.parametrize(
    "mat, in_shape, reshape",
    [
        (torch.linspace(0, 100, 1 * 3 * 20 * 25, dtype=torch.float16), (1, 3, 20, 25), (15, 15)),
        (torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float32), (5, 20, 25), (15, 15)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float64), (20, 25), (15, 15)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), 40),
        (torch.linspace(0, 100, 5 * 20 * 20, dtype=torch.float32), (5, 20, 20), 15),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float64), (20, 25), 15),
    ],
)
def test_grad_im_resize_autograd(mat, in_shape, reshape):
    """Gradients must flow back through grad_resize for both interpolation modes.

    Each mode gets its own leaf tensor, so the two backward passes fill separate ``.grad``
    buffers. Each buffer must contain some positive entries and match the input's shape
    and dtype.
    """
    mat = torch.reshape(mat, in_shape)
    # torch.autograd.Variable is deprecated. Detached clones are independent leaf tensors.
    n_mat = mat.detach().clone().requires_grad_(True)
    b_mat = mat.detach().clone().requires_grad_(True)
    n_res = torch.sum(2 * grad_resize(n_mat, reshape, True, "nearest"))
    b_res = torch.sum(2 * grad_resize(b_mat, reshape, True, "bilinear"))

    n_res.backward()
    b_res.backward()

    assert torch.any(n_mat.grad > 0)
    assert n_mat.grad.shape == torch.Size(in_shape)
    assert n_mat.grad.dtype == mat.dtype

    assert torch.any(b_mat.grad > 0)
    assert b_mat.grad.shape == torch.Size(in_shape)
    assert b_mat.grad.dtype == mat.dtype


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.parametrize(
    "mat, in_shape, reshape, out_shape",
    [
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), (15, 15), (10, 3, 15, 15)),
        (torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float32), (5, 20, 25), (15, 15), (5, 15, 15)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float32), (20, 25), (15, 15), (15, 15)),
        (torch.linspace(0, 100, 10 * 3 * 20 * 25, dtype=torch.float16), (10, 3, 20, 25), 40, (10, 3, 40, 50)),
        (torch.linspace(0, 100, 5 * 20 * 20, dtype=torch.float32), (5, 20, 20), 15, (5, 15, 15)),
        (torch.linspace(0, 100, 20 * 25, dtype=torch.float64), (20, 25), 15, (15, 19)),
    ],
)
def test_grad_im_resize_cuda(mat, in_shape, reshape, out_shape):
    """grad_resize on CUDA must keep the expected shape, the dtype and the input's device.

    Skipped when CUDA is unavailable.
    """
    image = torch.reshape(mat, in_shape).cuda()
    expected_shape = torch.Size(out_shape)

    outputs = [grad_resize(image, reshape, True, mode) for mode in ("nearest", "bilinear")]
    for resized in outputs:
        assert resized.shape == expected_shape
        assert resized.dtype == image.dtype
        assert resized.device == image.device


# grad_centercrop tests
@pytest.mark.parametrize(
    "mat, size",
    [
        ("mat_22", (3, 1)),
        ("mat_22", (3, 2)),
        ("mat_22", (1, 3)),
        ("mat_22", (1, 2)),
        ("mat_22", 3),
        ("mat_22", (4, 4)),
        ("mat_22", (1, 1)),
        ("mat_234", (4, 2)),
        ("mat_234", (4, 4)),
        ("mat_234", (1, 5)),
        ("mat_234", (3, 5)),
        ("mat_234", (5, 5)),
        ("mat_234", 6),
        ("mat_234", (2, 3)),
        ("mat_23112", (6, 6)),
        ("mat_23112", (2, 3)),
    ],
)
def test_grad_centercrop_crop(mat, size, request):
    """grad_centercrop must match torchvision's ``CenterCrop`` for every parametrized size."""
    image = request.getfixturevalue(mat)
    expected = CenterCrop(size)(image)

    cropped = grad_centercrop(image, size)
    assert torch.allclose(cropped, expected, rtol=RTOL, atol=ATOL)
    assert cropped.dtype == expected.dtype


@pytest.mark.parametrize(
    "mat, size, dtype",
    [
        ("mat_22", (4, 4), torch.float16),
        ("mat_22", (1, 1), torch.int),
        ("mat_234", (4, 2), torch.int),
        ("mat_234", (4, 4), torch.float64),
        ("mat_23112", (6, 6), torch.bfloat16),
        ("mat_23112", (2, 3), torch.int),
    ],
)
def test_grad_centercrop_typing(mat, size, dtype, request):
    """grad_centercrop must match ``CenterCrop`` and keep float16/float64/bfloat16/int dtypes."""
    image = request.getfixturevalue(mat).to(dtype=dtype)
    expected = CenterCrop(size)(image)

    cropped = grad_centercrop(image, size)
    assert torch.allclose(cropped, expected, rtol=RTOL, atol=ATOL)
    # Output dtype must agree with both the torchvision reference and the requested dtype.
    assert cropped.dtype == expected.dtype == dtype


@pytest.mark.parametrize(
    "mat, size, err",
    [
        ("mat_22", (1, 1, 3), ValueError),
        ("mat_234", (4, "a", 2), TypeError),
        ("mat_234", (4, -1), ValueError),
        ("mat_23112", -1, ValueError),
        ("mat_23112", "b", TypeError),
    ],
)
def test_grad_centercrop_errorhandle(mat, size, err, request):
    """grad_centercrop must reject malformed ``size`` arguments with the expected exception.

    The cases cover 3-element sizes, negative sizes and non-integer entries.
    """
    mat = request.getfixturevalue(mat)

    # pytest.raises fails with an explicit "DID NOT RAISE <err>" message when no error occurs.
    with pytest.raises(err):
        grad_centercrop(mat, size)


@pytest.mark.parametrize(
    "mat, size, num_used",
    [
        ("mat_22", (3, 1), 2),
        ("mat_22", (3, 2), 4),
        ("mat_22", (1, 3), 2),
        ("mat_22", (1, 2), 2),
        ("mat_22", 3, 4),
        ("mat_22", (4, 4), 4),
        ("mat_22", (1, 1), 1),
        ("mat_234", (4, 2), 12),
        ("mat_234", (4, 4), 24),
        ("mat_234", (1, 5), 8),
        ("mat_234", (3, 5), 24),
        ("mat_234", (5, 5), 24),
        ("mat_234", 6, 24),
        ("mat_234", (2, 3), 12),
        ("mat_23112", (6, 6), 12),
        ("mat_23112", (2, 3), 12),
    ],
)
# Named distinctly: reusing ``test_grad_im_normalize_autograd`` rebound that module-level name
# and silently stopped pytest from collecting the normalize autograd test.
def test_grad_centercrop_autograd(mat, size, num_used, request):
    """Gradients of grad_centercrop must reach exactly the input elements kept by the crop.

    With loss ``sum(2 * crop(x))``, each input element retained by the crop is expected to
    receive a gradient of 2. The number of such entries must equal ``num_used``.
    """
    mat = request.getfixturevalue(mat)

    # torch.autograd.Variable is deprecated. A detached clone is an independent leaf tensor.
    vmat = mat.detach().clone().requires_grad_(True)

    torch.sum(2 * grad_centercrop(vmat, size)).backward()

    assert torch.sum(vmat.grad == 2) == num_used
    assert vmat.grad.dtype == mat.dtype


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.parametrize(
    "mat, size",
    [
        ("mat_22", (3, 1)),
        ("mat_22", (3, 2)),
        ("mat_22", (1, 3)),
        ("mat_22", (1, 2)),
        ("mat_22", 3),
        ("mat_22", (4, 4)),
        ("mat_22", (1, 1)),
        ("mat_234", (4, 2)),
        ("mat_234", (4, 4)),
        ("mat_234", (1, 5)),
        ("mat_234", (3, 5)),
        ("mat_234", (5, 5)),
        ("mat_234", 6),
        ("mat_234", (2, 3)),
        ("mat_23112", (6, 6)),
        ("mat_23112", (2, 3)),
    ],
)
def test_grad_centercrop_cuda(mat, size, request):
    """grad_centercrop on CUDA tensors must match torchvision's ``CenterCrop``.

    Skipped when CUDA is unavailable.
    """
    image = request.getfixturevalue(mat).cuda()
    # The reference is computed from the CUDA image and pinned to that same device.
    expected = CenterCrop(size)(image).to(device=image.device)

    cropped = grad_centercrop(image, size)
    assert torch.allclose(cropped, expected, rtol=RTOL, atol=ATOL)
    assert cropped.dtype == expected.dtype


#######################################################################################################################
