# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

# Use parrots' ``gradcheck`` when parrots is installed, otherwise PyTorch's.
# ``_USING_PARROTS`` records which one was imported, because the two are
# called with different keyword arguments further down in this file.
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False
else:
    _USING_PARROTS = True

# Reference cases for RoIAlignRotated.
#
# Each entry of ``inputs`` is a ``(feature_map, rois)`` pair of nested lists:
# the feature map is 4-D with a leading dim of 1 (presumably NCHW), and
# ``rois`` is 2-D with a single row of six numbers. NOTE(review): the row is
# presumably (batch_index, center_x, center_y, width, height, angle) — confirm
# against the mmcv.ops.roi_align_rotated documentation.
#
# ``outputs[i]`` is the matching ``(expected_output, expected_input_grad)``
# pair for ``inputs[i]``; the gradient is what ``x.grad`` holds after
# backpropagating a tensor of ones through the op's output.
# yapf:disable
inputs = [([[[[1., 2.], [3., 4.]]]],
           [[0., 0.5, 0.5, 1., 1., 0]]),
          # Case 1: same 2x2 map as case 0, last roi field np.pi / 2.
          ([[[[1., 2.], [3., 4.]]]],
           [[0., 0.5, 0.5, 1., 1., np.pi / 2]]),
          # Case 2: two-channel 2x2 map, same roi as case 0.
          ([[[[1., 2.], [3., 4.]],
             [[4., 3.], [2., 1.]]]],
           [[0., 0.5, 0.5, 1., 1., 0]]),
          # Case 3: single-channel 4x4 map, last roi field 0.
          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
              [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
           [[0., 1.5, 1.5, 3., 3., 0]]),
          # Case 4: same 4x4 map as case 3, last roi field np.pi / 2.
          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
              [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
           [[0., 1.5, 1.5, 3., 3., np.pi / 2]])]
outputs = [([[[[1.0, 1.25], [1.5, 1.75]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.5, 1], [1.75, 1.25]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.0, 1.25], [1.5, 1.75]],
              [[4.0, 3.75], [3.5, 3.25]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]],
              [[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.9375, 4.75], [7.5625, 10.375]]]],
            [[[[0.47265625, 0.42968750, 0.42968750, 0.04296875],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.04296875, 0.03906250, 0.03906250, 0.00390625]]]]),
           ([[[[7.5625, 1.9375], [10.375, 4.75]]]],
            [[[[0.47265625, 0.42968750, 0.42968750, 0.04296875],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.04296875, 0.03906250, 0.03906250, 0.00390625]]]])]
# yapf:enable

# Module-level RoIAlignRotated configuration: a 2x2 pooled output,
# spatial scale 1.0 and a sampling ratio of 2.
pool_h, pool_w = 2, 2
spatial_scale = 1.0
sampling_ratio = 2


def _test_roialign_rotated_gradcheck(device, dtype):
    """Numerically check the gradient of ``RoIAlignRotated`` w.r.t. its input.

    Runs ``gradcheck`` on every case of the module-level ``inputs`` list,
    using the module-level pooling configuration. ``rois`` is not checked
    (it is created without ``requires_grad``).

    Args:
        device (str): Device the tensors are created on, e.g. ``'cpu'``.
        dtype (torch.dtype): Floating point type of the input and rois.
            ``torch.half`` is skipped.
    """
    try:
        from mmcv.ops import RoIAlignRotated
    except ModuleNotFoundError:
        pytest.skip('RoIAlignRotated op is not successfully compiled')
    if dtype is torch.half:
        pytest.skip('grad check does not support fp16')

    # The op configuration is identical for every case, so build it once.
    froipool = RoIAlignRotated((pool_h, pool_w), spatial_scale,
                               sampling_ratio)
    for case in inputs:
        np_input = np.array(case[0])
        np_rois = np.array(case[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        # Choose keyword arguments by which ``gradcheck`` was actually
        # imported at the top of the file: parrots' version takes
        # ``no_grads``/``delta``/``pt_atol``, PyTorch's takes ``eps``/``atol``.
        if _USING_PARROTS:
            gradcheck(
                froipool, (x, rois), no_grads=[rois], delta=1e-5, pt_atol=1e-5)
        else:
            gradcheck(froipool, (x, rois), eps=1e-5, atol=1e-5)


def _test_roialign_rotated_allclose(device, dtype):
    """Compare ``roi_align_rotated`` results with the reference values.

    For every ``(inputs[i], outputs[i])`` pair, runs the functional op
    forward, backpropagates a tensor of ones, and checks both the forward
    result and ``x.grad`` against the expected arrays (``atol=1e-3``).
    Then checks that a ``RoIAlignRotated`` built with the deprecated keyword
    names (``out_size``, ``sample_num``) produces the same output as one
    built with the current names (``output_size``, ``sampling_ratio``).

    Uses the module-level ``pool_h``, ``pool_w``, ``spatial_scale`` and
    ``sampling_ratio`` configuration.

    Args:
        device (str): Device the tensors are created on, e.g. ``'cpu'``.
        dtype (torch.dtype): Floating point type of the input and rois.
    """
    try:
        from mmcv.ops import RoIAlignRotated, roi_align_rotated
    except ModuleNotFoundError:
        pytest.skip('test requires compilation')

    for case, expected in zip(inputs, outputs):
        np_input = np.array(case[0])
        np_rois = np.array(case[1])
        np_output = np.array(expected[0])
        np_grad = np.array(expected[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        # The trailing positional ``True`` is the op's sixth argument
        # (presumably ``aligned`` -- confirm against the mmcv docs).
        result = roi_align_rotated(x, rois, (pool_h, pool_w), spatial_scale,
                                   sampling_ratio, True)
        result.backward(torch.ones_like(result))
        # Compare in float32 on CPU so fp16/GPU results can use numpy.
        assert np.allclose(
            result.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
        assert np.allclose(
            x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)

    # Deprecated keyword names must still be accepted and agree with the
    # current ones. Build the inputs from the last case explicitly instead
    # of relying on variables leaked from the loop above.
    last_input, last_rois = inputs[-1]
    x = torch.tensor(np.array(last_input), dtype=dtype, device=device)
    rois = torch.tensor(np.array(last_rois), dtype=dtype, device=device)

    roi_align_rotated_module_deprecated = RoIAlignRotated(
        out_size=(pool_h, pool_w),
        spatial_scale=spatial_scale,
        sample_num=sampling_ratio)

    output_1 = roi_align_rotated_module_deprecated(x, rois)

    roi_align_rotated_module_new = RoIAlignRotated(
        output_size=(pool_h, pool_w),
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio)

    output_2 = roi_align_rotated_module_new(x, rois)

    assert np.allclose(
        output_1.data.type(torch.float).cpu().numpy(),
        output_2.data.type(torch.float).cpu().numpy())


@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
def test_roialign_rotated(device, dtype):
    """Run the RoIAlignRotated checks for one ``device``/``dtype`` pair.

    The numerical gradient check runs only for ``torch.double``. The
    comparison against reference outputs/gradients runs for every dtype.
    """
    # Only the MLU device lacks fp64 support. Skip double there alone, so
    # that the CPU/CUDA double gradcheck still runs on machines that
    # happen to have an MLU installed.
    if device == 'mlu' and dtype is torch.double:
        pytest.skip('MLU does not support for 64-bit floating point')
    # Numerical gradient check is run for double only.
    if dtype is torch.double:
        _test_roialign_rotated_gradcheck(device=device, dtype=dtype)
    _test_roialign_rotated_allclose(device=device, dtype=dtype)
