# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                             Linear, MaxPool2d, MaxPool3d)

if torch.__version__ != 'parrots':
    torch_version = '1.1'
else:
    torch_version = 'parrots'


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    # train mode
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv3d
    """
    # train mode
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h,
                           in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
                            stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    if torch.__version__ == 'parrots':
        op = 0
    torch.manual_seed(0)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
                            kernel_size, stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                     padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    wrapper = MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    ref = nn.MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
@pytest.mark.skipif(
    torch.__version__ == 'parrots' and not torch.cuda.is_available(),
    reason='parrots requires CUDA support')
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                     stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    wrapper = MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    if torch.__version__ == 'parrots':
        x_empty = x_empty.cuda()
    wrapper_out = wrapper(x_empty)
    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    ref = nn.MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    if torch.__version__ == 'parrots':
        x_normal = x_normal.cuda()
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    assert torch.equal(wrapper(x_normal), ref_out)


@patch('torch.__version__', torch_version)
@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
                                                              (20, 20, 3, 3)])
def test_linear(in_w, in_h, in_feature, out_feature):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_feature, requires_grad=True)
    torch.manual_seed(0)
    wrapper = Linear(in_feature, out_feature)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_feature)
    torch.manual_seed(0)
    ref = nn.Linear(in_feature, out_feature)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_feature)
    wrapper = Linear(in_feature, out_feature)
    wrapper.eval()
    wrapper(x_empty)


@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
def test_nn_op_forward_called():

    for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
        with patch(f'torch.nn.{m}.forward') as nn_module_forward:
            # randn input
            x_empty = torch.randn(0, 3, 10, 10)
            wrapper = eval(m)(3, 2, 1)
            wrapper(x_empty)
            nn_module_forward.assert_called_with(x_empty)

            # non-randn input
            x_normal = torch.randn(1, 3, 10, 10)
            wrapper = eval(m)(3, 2, 1)
            wrapper(x_normal)
            nn_module_forward.assert_called_with(x_normal)

    for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
        with patch(f'torch.nn.{m}.forward') as nn_module_forward:
            # randn input
            x_empty = torch.randn(0, 3, 10, 10, 10)
            wrapper = eval(m)(3, 2, 1)
            wrapper(x_empty)
            nn_module_forward.assert_called_with(x_empty)

            # non-randn input
            x_normal = torch.randn(1, 3, 10, 10, 10)
            wrapper = eval(m)(3, 2, 1)
            wrapper(x_normal)
            nn_module_forward.assert_called_with(x_normal)

    with patch('torch.nn.Linear.forward') as nn_module_forward:
        # randn input
        x_empty = torch.randn(0, 3)
        wrapper = Linear(3, 3)
        wrapper(x_empty)
        nn_module_forward.assert_called_with(x_empty)

        # non-randn input
        x_normal = torch.randn(1, 3)
        wrapper = Linear(3, 3)
        wrapper(x_normal)
        nn_module_forward.assert_called_with(x_normal)
