import os.path as osp

import mmcv
import pytest
import torch

from mmdet import digit_version
from mmdet.models.necks import FPN, YOLOV3Neck
from .utils import ort_validate

# Skip every test in this module when the installed torch is too old.
# NOTE(review): the comparison is inclusive, so torch 1.5.0 itself is also
# skipped, while the message says "below 1.5.0" - confirm which is intended.
if digit_version(torch.__version__) <= digit_version('1.5.0'):
    pytest.skip(
        'ort backend does not support version below 1.5.0',
        allow_module_level=True)

# Maps each FPN test step name to the index fpn_neck_config() uses to pick
# the model variant. The index is the position of the name in this tuple.
fpn_test_step_names = {
    step_name: step_index
    for step_index, step_name in enumerate((
        'fpn_normal',
        'fpn_wo_extra_convs',
        'fpn_lateral_bns',
        'fpn_bilinear_upsample',
        'fpn_scale_factor',
        'fpn_extra_convs_inputs',
        'fpn_extra_convs_laterals',
        'fpn_extra_convs_outputs',
    ))
}

# Maps each YOLOv3 neck test step name to the index yolo_neck_config() uses
# to pick the model variant.
yolo_test_step_names = dict(yolo_normal=0)

# 'data' directory located next to this test file; yolo_neck_config() loads
# its input features from here.
data_path = osp.join(osp.dirname(__file__), 'data')


def fpn_neck_config(test_step_name):
    """Build the FPN variant for a test step plus random input features.

    Args:
        test_step_name (str): Key of ``fpn_test_step_names`` selecting the
            FPN variant to construct.

    Returns:
        tuple: ``(fpn_model, feats)``, where ``fpn_model`` is the configured
            ``FPN`` and ``feats`` is a list of four ``torch.rand`` tensors
            of shape ``(1, C, S, S)`` with ``C`` taken from
            ``[8, 16, 32, 64]`` and ``S`` from ``[64, 32, 16, 8]``.

    Raises:
        KeyError: If ``test_step_name`` is not in ``fpn_test_step_names``.
        ValueError: If the step index has no FPN variant defined here.
    """
    base_size = 64
    in_channels = [8, 16, 32, 64]
    # Halve the spatial size at every level: [64, 32, 16, 8].
    feat_sizes = [base_size // 2**level for level in range(4)]
    out_channels = 8

    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # Only the arguments that differ between variants are listed here; the
    # shared ones are passed once in the single FPN(...) call below.
    variant_kwargs = {
        0: dict(add_extra_convs=True),
        1: dict(add_extra_convs=False),
        2: dict(
            add_extra_convs=True,
            no_norm_on_lateral=False,
            norm_cfg=dict(type='BN', requires_grad=True)),
        3: dict(
            add_extra_convs=True,
            upsample_cfg=dict(mode='bilinear', align_corners=True)),
        4: dict(add_extra_convs=True, upsample_cfg=dict(scale_factor=2)),
        5: dict(add_extra_convs='on_input'),
        6: dict(add_extra_convs='on_lateral'),
        7: dict(add_extra_convs='on_output'),
    }

    step = fpn_test_step_names[test_step_name]
    if step not in variant_kwargs:
        # Fail loudly when a step was added to the name table without a
        # matching variant, instead of hitting an unbound local at return.
        raise ValueError(
            'no FPN variant defined for test step {!r} (index {!r})'.format(
                test_step_name, step))

    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        num_outs=5,
        **variant_kwargs[step])
    return fpn_model, feats


def yolo_neck_config(test_step_name):
    """Build the YOLOv3 neck for a test step and load its input features.

    Args:
        test_step_name (str): Key of ``yolo_test_step_names`` selecting the
            neck variant to construct.

    Returns:
        tuple: ``(yolo_model, feats)``, where ``yolo_model`` is the
            configured ``YOLOV3Neck`` and ``feats`` is the content of
            ``yolov3_neck.pkl`` loaded from ``data_path``.
    """
    neck_kwargs = dict(
        in_channels=[16, 8, 4], out_channels=[8, 4, 2], num_scales=3)

    # yolov3_neck.pkl stores a list of torch.Tensor objects generated by
    # torch.rand, with sizes
    # (1, 4, 64, 64), (1, 8, 32, 32) and (1, 16, 16, 16).
    feats = mmcv.load(osp.join(data_path, 'yolov3_neck.pkl'))

    if yolo_test_step_names[test_step_name] == 0:
        yolo_model = YOLOV3Neck(**neck_kwargs)
    return yolo_model, feats


def test_fpn_normal():
    """Run ``ort_validate`` on the 'fpn_normal' FPN configuration."""
    model, feats = fpn_neck_config('fpn_normal')
    ort_validate(model, feats)


def test_fpn_wo_extra_convs():
    """Run ``ort_validate`` on the 'fpn_wo_extra_convs' FPN configuration."""
    model, feats = fpn_neck_config('fpn_wo_extra_convs')
    ort_validate(model, feats)


def test_fpn_lateral_bns():
    """Run ``ort_validate`` on the 'fpn_lateral_bns' FPN configuration."""
    model, feats = fpn_neck_config('fpn_lateral_bns')
    ort_validate(model, feats)


def test_fpn_bilinear_upsample():
    """Run ``ort_validate`` on the 'fpn_bilinear_upsample' configuration."""
    model, feats = fpn_neck_config('fpn_bilinear_upsample')
    ort_validate(model, feats)


def test_fpn_scale_factor():
    """Run ``ort_validate`` on the 'fpn_scale_factor' FPN configuration."""
    model, feats = fpn_neck_config('fpn_scale_factor')
    ort_validate(model, feats)


def test_fpn_extra_convs_inputs():
    """Run ``ort_validate`` on the 'fpn_extra_convs_inputs' configuration."""
    model, feats = fpn_neck_config('fpn_extra_convs_inputs')
    ort_validate(model, feats)


def test_fpn_extra_convs_laterals():
    """Run ``ort_validate`` on the 'fpn_extra_convs_laterals' configuration."""
    model, feats = fpn_neck_config('fpn_extra_convs_laterals')
    ort_validate(model, feats)


def test_fpn_extra_convs_outputs():
    """Run ``ort_validate`` on the 'fpn_extra_convs_outputs' configuration."""
    model, feats = fpn_neck_config('fpn_extra_convs_outputs')
    ort_validate(model, feats)


def test_yolo_normal():
    """Run ``ort_validate`` on the 'yolo_normal' YOLOv3 neck configuration."""
    model, feats = yolo_neck_config('yolo_normal')
    ort_validate(model, feats)
