# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from __future__ import absolute_import, division, print_function
import math  # noqa
import time

import torch
import torch.nn as nn  # noqa
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from torch.autograd import gradcheck  # noqa

# Shared problem configuration used by the forward checks below (the
# backward and timing checks override some of these values locally).
H_in, W_in = 8, 8
N, M, D = 2, 4, 16  # input tensors are built as (N, H_in, W_in, M * D)
Kh, Kw = 3, 3
remove_center = False
# Sampling points per M-slice (the mask is shaped (..., M, P)); subtracting
# the bool drops one point when remove_center is True.
P = Kh * Kw - remove_center
offset_scale = 2.0
pad, dilation, stride = 1, 1, 1
# Standard convolution output size using the dilated kernel extent.
H_out, W_out = (
    (size + 2 * pad - (dilation * (k - 1) + 1)) // stride + 1
    for size, k in ((H_in, Kh), (W_in, Kw))
)

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the CUDA DCNv3 forward with the PyTorch reference in float64.

    Random tensors are created with ``torch.rand`` on the CPU, moved to the
    GPU and cast to double for both implementations. Prints the result of
    ``torch.allclose`` (default tolerances) and the max absolute / relative
    errors.
    """
    x = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    off = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    weights = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    # Normalise the P sampling weights of every M-slice so they sum to 1.
    weights /= weights.sum(-1, keepdim=True)
    weights = weights.reshape(N, H_out, W_out, M * P)

    x64, off64, w64 = x.double(), off.double(), weights.double()
    # Positional geometry arguments shared by both implementations.
    geometry = (Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation,
                M, D, offset_scale)

    reference = dcnv3_core_pytorch(x64, off64, w64, *geometry,
                                   remove_center).detach().cpu()

    im2col_step = 2
    candidate = DCNv3Function.apply(x64, off64, w64, *geometry, im2col_step,
                                    remove_center).detach().cpu()

    gap = (candidate - reference).abs()
    fwdok = torch.allclose(candidate, reference)
    max_abs_err = gap.max()
    max_rel_err = (gap / reference.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare the CUDA DCNv3 forward with the PyTorch reference in float32.

    Prints whether the two outputs agree under
    ``torch.allclose(rtol=1e-2, atol=1e-3)`` together with the max absolute
    and relative errors.
    """
    inp = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    # Normalise the P sampling weights of every M-slice so they sum to 1.
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(inp, offset, mask, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(inp, offset, mask, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step,
                                      remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    abs_err = (output_cuda - output_pytorch).abs()
    max_abs_err = abs_err.max()
    max_rel_err = (abs_err / output_pytorch.abs()).max()
    print('>>> forward float')
    # Leading space separates the label from the metrics (matches the
    # float64 variant's output format).
    print(f'* {fwdok} check_forward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_double(channels=4,
                                             grad_input=True,
                                             grad_offset=True,
                                             grad_mask=True):
    """Compare CUDA and PyTorch-reference DCNv3 gradients in float64.

    The same random float32 leaves are cast to double for both
    implementations, ``output.sum()`` is back-propagated, and every
    requested gradient is compared with
    ``torch.allclose(rtol=1e-2, atol=1e-3)``.

    Args:
        channels: value used for ``D`` in this check (the input has
            ``M * D`` channels with ``M = 2`` here).
        grad_input: require and compare the gradient w.r.t. the input.
        grad_offset: require and compare the gradient w.r.t. the offsets.
        grad_mask: require and compare the gradient w.r.t. the mask.

    Raises:
        ValueError: if no gradient is requested (nothing to back-propagate).
    """
    if not (grad_input or grad_offset or grad_mask):
        raise ValueError('at least one of grad_input, grad_offset, '
                         'grad_mask must be True')

    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    # Normalise the P sampling weights of every M-slice so they sum to 1.
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0.double(), offset0.double(),
                                        mask0.double(), Kh, Kw, stride, stride,
                                        Kh // 2, Kw // 2, dilation, dilation,
                                        M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    # Fresh leaves with the same values, so the CUDA path accumulates its
    # gradients separately from the reference path.
    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1.double(), offset1.double(),
                                      mask1.double(), Kh, Kw, stride, stride,
                                      Kh // 2, Kw // 2, dilation, dilation, M,
                                      D, offset_scale, im2col_step,
                                      remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    checks = (('input_grad', grad_input, input0, input1),
              ('offset_grad', grad_offset, offset0, offset1),
              ('mask_grad', grad_mask, mask0, mask1))
    for label, requested, ref, test in checks:
        # Leaves that did not require grad have .grad == None; skip them.
        if not requested:
            continue
        bwdok = torch.allclose(ref.grad, test.grad, rtol=1e-2, atol=1e-3)
        abs_err = (ref.grad - test.grad).abs()
        max_abs_err = abs_err.max()
        max_rel_err = (abs_err / ref.grad.abs()).max()
        print(f'* {bwdok} {label} check_backward_equal_with_pytorch_double:'
              f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4,
                                            grad_input=True,
                                            grad_offset=True,
                                            grad_mask=True):
    """Compare CUDA and PyTorch-reference DCNv3 gradients in float32.

    Both implementations run on the same random float32 leaves,
    ``output.sum()`` is back-propagated, and every requested gradient is
    compared with ``torch.allclose(rtol=1e-2, atol=1e-3)``.

    Args:
        channels: value used for ``D`` in this check (the input has
            ``M * D`` channels with ``M = 2`` here).
        grad_input: require and compare the gradient w.r.t. the input.
        grad_offset: require and compare the gradient w.r.t. the offsets.
        grad_mask: require and compare the gradient w.r.t. the mask.

    Raises:
        ValueError: if no gradient is requested (nothing to back-propagate).
    """
    if not (grad_input or grad_offset or grad_mask):
        raise ValueError('at least one of grad_input, grad_offset, '
                         'grad_mask must be True')

    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    # Normalise the P sampling weights of every M-slice so they sum to 1.
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0, offset0, mask0, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center)
    output_pytorch.sum().backward()

    # Fresh leaves with the same values, so the CUDA path accumulates its
    # gradients separately from the reference path.
    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1, offset1, mask1, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step, remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    checks = (('input_grad', grad_input, input0, input1),
              ('offset_grad', grad_offset, offset0, offset1),
              ('mask_grad', grad_mask, mask0, mask1))
    for label, requested, ref, test in checks:
        # Leaves that did not require grad have .grad == None; skip them.
        if not requested:
            continue
        bwdok = torch.allclose(ref.grad, test.grad, rtol=1e-2, atol=1e-3)
        abs_err = (ref.grad - test.grad).abs()
        max_abs_err = abs_err.max()
        max_rel_err = (abs_err / ref.grad.abs()).max()
        print(f'* {bwdok} {label} check_backward_equal_with_pytorch_float:'
              f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_time_cost(im2col_step=128):
    """Benchmark the CUDA DCNv3 forward and print the mean time per call.

    Runs ``repeat`` untimed calls first, then times ``repeat`` more calls
    between two ``torch.cuda.synchronize()`` barriers so that queued kernels
    are included in the measurement.

    Args:
        im2col_step: chunk size forwarded to ``DCNv3Function.apply``.
    """
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    inp = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    # Normalise the P sampling weights of every M-slice so they sum to 1.
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step}; '
          f'input {inp.shape}; points {P} ')
    repeat = 100

    def run_forward():
        # offset_scale is fixed to 1.0 for the benchmark.
        return DCNv3Function.apply(inp, offset, mask, Kh, Kw, stride, stride,
                                   Kh // 2, Kw // 2, dilation, dilation, M, D,
                                   1.0, im2col_step, remove_center)

    # Untimed warm-up calls (presumably to exclude one-off CUDA/allocator
    # start-up costs from the measurement).
    for _ in range(repeat):
        run_forward()
    torch.cuda.synchronize()
    # perf_counter is monotonic and higher resolution than time.time().
    start = time.perf_counter()
    for _ in range(repeat):
        run_forward()
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / repeat
    print(f'forward time cost: {elapsed}')


if __name__ == '__main__':
    # Forward correctness in float64 and float32.
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    # Backward correctness over a sweep of per-group channel counts,
    # including odd and non-power-of-two sizes.
    channel_sweep = (1, 16, 30, 32, 64, 71, 1025)
    for ch in channel_sweep:
        check_backward_equal_with_pytorch_double(ch, True, True, True)
    for ch in channel_sweep:
        check_backward_equal_with_pytorch_float(ch, True, True, True)

    # Forward timing for im2col_step = 128 * 2**k, k = 0, 1, 2.
    for step in (128, 256, 512):
        check_time_cost(step)
