from inspect import getmembers
import as_models

# Model names that model_choice_all() skips: per the variable name, these
# models have no pretrained weight file available.
# NOTE: names are matched AFTER model_choice_all() renames as_models members
# to their downloaded-file names (e.g. 'ShuffleNetV2_swish' -> 'ShuffleNetV2'),
# so entries here must use the file-side names.
models_has_no_pretrained_files = [
    'EfficientNet', 'HRNet_W60_C', 'MobileNetV1_x1_0', 'MobileNetV2_x1_0',
    'MobileNetV3_large_x0_25', 'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75', 'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25',
    'MobileNetV3_small_x0_25', 'MobileNetV3_small_x0_5', 'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_25',
    'Res2Net101_26w_4s', 'Res2Net152_26w_4s', 'Res2Net152_vd_26w_4s', 'Res2Net200_vd_26w_4s',
    'Res2Net50_26w_6s', 'Res2Net50_26w_8s', 'Res2Net50_48w_2s',
    'Res2Net50_vd_14w_8s', 'Res2Net50_vd_26w_6s', 'Res2Net50_vd_26w_8s', 'Res2Net50_vd_48w_2s',
    'ResNeXt101_64x4d', 'ResNet101_ACNet', 'ResNet101_vc', 'ResNet152_ACNet', 'ResNet152_vc',
    'ResNet18_ACNet', 'ResNet34_ACNet',
    'SE_HRNet_W18_C', 'SE_HRNet_W30_C', 'SE_HRNet_W32_C', 'SE_HRNet_W40_C', 'SE_HRNet_W44_C', 'SE_HRNet_W48_C',
    'SE_HRNet_W60_C', 'SE_HRNet_W64_C', 'SE_ResNeXt101_vd_32x4d', 'SE_ResNeXt152_32x4d', 'SE_ResNet101_vd', 'SE_ResNet152_vd', 'SE_ResNet200_vd',
    'ShuffleNetV2_x0_5_swish', 'ShuffleNetV2_x1_0_swish', 'ShuffleNetV2_x1_5_swish', 'ShuffleNetV2_x2_0_swish',
    'Xception71_deeplab']


def model_choice_all():
    """Collect every usable model exported by ``as_models``.

    Each member name is first translated to the name used by the downloaded
    trained-model file. The member is then dropped if its translated name:

    * ends with ``'__'`` or equals its own lowercase form (presumably
      modules/helper functions rather than model classes);
    * appears in ``models_has_no_pretrained_files``;
    * is ``'MobileNetV3_small_x1_0'`` (very low accuracy).

    Returns:
        dict: translated model name -> the corresponding ``as_models`` member,
        in the order produced by ``inspect.getmembers``.
    """
    # Name in as_models -> name in the downloaded trained model file.
    file_names = {
        'DARTS_4M': 'AutoDL_4M',
        'DARTS_6M': 'AutoDL_6M',
        'DarkNet53': 'DarkNet53_ImageNet1k',
        'GoogLeNet': 'GoogleNet',
        'ShuffleNetV2': 'ShuffleNetV2_x1_0',
        'ShuffleNetV2_swish': 'ShuffleNetV2',
    }
    # Some problem with this one: very low accuracy, so it is skipped too.
    low_accuracy_model = 'MobileNetV3_small_x1_0'

    selected = {}
    for member_name, member in getmembers(as_models):
        name = file_names.get(member_name, member_name)
        not_a_model = name.endswith('__') or name == name.lower()
        unusable = name in models_has_no_pretrained_files or name == low_accuracy_model
        if not_a_model or unusable:
            continue
        selected[name] = member
    return selected


def model_choice_best_in_each_structure():
    """Pick the best model of each architecture family.

    Returns:
        dict: model name -> model for the 23 hand-picked names below, in the
        order they are listed. Raises ``KeyError`` if any of them is missing
        from ``model_choice_all()``.
    """
    available = model_choice_all()

    best_per_family = (
        'AlexNet', 'AutoDL_6M', 'DPN131', 'DarkNet53_ImageNet1k',
        'DenseNet264', 'EfficientNetB7',
        'Fix_ResNeXt101_32x48d_wsl', 'GoogleNet', 'HRNet_W64_C', 'InceptionV4',
        'MobileNetV1', 'MobileNetV2',
        'Res2Net101_vd_26w_4s',
        'ResNeXt152_64x4d', 'ResNeXt152_vd_64x4d',
        'ResNet152', 'ResNet200_vd',
        'SENet154_vd', 'SE_ResNeXt101_32x4d',
        'ShuffleNetV2_x1_0',
        'SqueezeNet1_1',
        'VGG19',
        'Xception71',
    )

    return {name: available[name] for name in best_per_family}


def model_choice_res_family():
    """Select the Res- family of models.

    A model belongs to the family when its name contains ``"res"``
    (case-insensitive). This covers, e.g., ResNet, ResNeXt, Res2Net and the
    ``SE_``/``Fix_`` prefixed variants.

    Returns:
        dict: model name -> model, a subset of ``model_choice_all()`` in the
        same order.
    """
    # No debug print of the count here: selection helpers should not write
    # to stdout, matching the other model_choice_* functions.
    return {
        name: model
        for name, model in model_choice_all().items()
        if 'res' in name.lower()
    }


def model_choice_efficientnets():
    """Select the EfficientNet family of models.

    A model belongs to the family when its name contains ``"efficientnet"``
    (case-insensitive).

    Returns:
        dict: model name -> model, a subset of ``model_choice_all()`` in the
        same order.
    """
    # No debug print of the count here: selection helpers should not write
    # to stdout, matching the other model_choice_* functions.
    return {
        name: model
        for name, model in model_choice_all().items()
        if 'efficientnet' in name.lower()
    }


def model_choice_input224():
    """Select the models that use the default input size (224, per the name).

    Models listed below are known to use a different input size and are
    excluded: Inception/Xception variants, DarkNet53, the Fix_ResNeXt WSL
    model and EfficientNetB1-B7.

    Returns:
        dict: model name -> model, a subset of ``model_choice_all()`` in the
        same order.
    """
    candidates = model_choice_all()

    other_input_size = {
        "InceptionV4", "Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab", "Xception71",
        "DarkNet53_ImageNet1k", "Fix_ResNeXt101_32x48d_wsl",
        "EfficientNetB1", "EfficientNetB2", "EfficientNetB3",
        "EfficientNetB4", "EfficientNetB5", "EfficientNetB6", "EfficientNetB7",
    }

    return {
        name: model
        for name, model in candidates.items()
        if name not in other_input_size
    }


def model_choice_3():
    """Select the models that can receive inputs of different sizes.

    Not implemented yet: no selection criterion has been defined.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "model_choice_3 (models accepting different input sizes) is not implemented yet."
    )

# all_models = all_models()
# choice_1 = choice_1()


# if __name__ == '__main__':
#     model_dict = model_choice_all()
#     for k in model_dict.keys():
#         print(k, model_dict[k])
