from ..block.ResidualBlock import *

def get_resnet_config(depth="18", dataset='cifar10', oper_order='cba'):
    """Return the pre-activation ResNet configuration for a depth and dataset.

    Args:
        depth: Network depth. Accepts an int or an integer string such as the
            default ``"18"``; it is normalized with ``int()`` before lookup.
        dataset: Dataset name; one of ``'cifar10'``, ``'cifar100'``,
            ``'tinyImageNet'``, ``'ImageNet'`` or ``'cub200'``.
        oper_order: Sequence of per-layer operation codes (default ``'cba'``).
            Must contain at least 3 elements because the ``'fe'`` entry reads
            index 2.

    Returns:
        A 5-tuple ``(block_depth, plane_list, stride_list, oper_info, block)``:
        the number of blocks per stage, the channel count per stage, the
        stride per stage, a dict of sub-orderings of ``oper_order`` (keys
        ``'full'``, ``'front1'``, ``'front2'``, ``'end1'``, ``'end2'``,
        ``'fe'``), and the residual block class for ``depth``.

    Raises:
        KeyError: if ``depth`` is not a supported depth or ``dataset`` is not
            a known dataset name.
        IndexError: if ``oper_order`` has fewer than 3 elements.
    """
    # Number of residual blocks in each of the four stages, keyed by depth.
    default_block_depth = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],

        # NOTE(review): placeholder entries -- no stage layout is defined yet,
        # so these depths return an empty block list.
        20: [],
        32: [],
        56: [],
        110: [],
    }

    # Residual block class per depth; the classes are presumably provided by
    # the star import of ..block.ResidualBlock at the top of the file.
    default_block_type = {
        18: PreActBlock,
        32: PreActBlock,
        50: PreActBottleneck,

        20: PreActBlock,
        34: PreActBlock,
        # NOTE(review): 56 maps to the bottleneck block here -- confirm this is
        # intended. Depth 110 has no entry, so it is rejected below.
        56: PreActBottleneck,
    }

    # Output channels for each of the four stages, per dataset.
    default_plane_list = {
        'cifar10': [64, 128, 256, 512],
        'cifar100': [64, 128, 256, 512],

        'tinyImageNet': [64, 128, 256, 512],
        'ImageNet': [64, 128, 256, 512],
        'cub200': [64, 128, 256, 512]
    }

    # Stride of the first block in each of the four stages, per dataset.
    default_stride = {
        'cifar10': [1, 2, 2, 2],
        'cifar100': [1, 2, 2, 2],

        'tinyImageNet': [1, 2, 2, 2],
        'ImageNet': [1, 2, 2, 2],
        'cub200': [1, 2, 2, 2]
    }

    ops = list(oper_order)
    oper_info = {
        'full': ops[:],
        'front1': ops[:1],
        'front2': ops[:2],
        'end1': ops[-1:],
        'end2': ops[-2:],
        # Unlike the list-valued entries above, this is the `+` of the first
        # and third elements (for the default 'cba' it is the string 'ca').
        'fe': ops[0] + ops[2],
    }

    # The default is the string "18" while the tables use int keys; normalize
    # so both spellings resolve to the same entry.
    try:
        depth = int(depth)
    except (TypeError, ValueError):
        raise KeyError(f"unsupported ResNet depth: {depth!r}") from None

    if depth not in default_block_type or depth not in default_block_depth:
        supported = sorted(default_block_type.keys() & default_block_depth.keys())
        raise KeyError(
            f"unsupported ResNet depth {depth!r}; expected one of {supported}"
        )

    block = default_block_type[depth]

    # Always a 5-tuple; dataset lookups raise KeyError for unknown names.
    return (default_block_depth[depth], default_plane_list[dataset],
            default_stride[dataset], oper_info, block)