from torch import nn
from layers import vanilla, mixbin, mixbin_identity, mixbin_relu, mixbin_hardtanh

# Registry mapping a compression-strategy name to the BatchNorm2d layer class
# used for that strategy. 'none' selects the vanilla layer; every other key
# selects the BinaryBatchNorm2d defined in the `layers` module of the same name.
registered_model_compression_batchnorm_strategy = dict(
    none=vanilla.VanillaBatchNorm2d,
    mixbin=mixbin.BinaryBatchNorm2d,
    mixbin_identity=mixbin_identity.BinaryBatchNorm2d,
    mixbin_relu=mixbin_relu.BinaryBatchNorm2d,
    mixbin_hardtanh=mixbin_hardtanh.BinaryBatchNorm2d,
)

# Registry mapping a compression-strategy name to the Conv2d layer class used
# for that strategy. 'none' selects the vanilla layer; every other key selects
# the HardBinaryConv2d defined in the `layers` module of the same name.
registered_model_compression_conv_strategy = dict(
    none=vanilla.VanillaConv2d,
    mixbin=mixbin.HardBinaryConv2d,
    mixbin_identity=mixbin_identity.HardBinaryConv2d,
    mixbin_relu=mixbin_relu.HardBinaryConv2d,
    mixbin_hardtanh=mixbin_hardtanh.HardBinaryConv2d,
)

# Registry mapping a compression-strategy name to the activation layer class
# used for that strategy. 'none' selects the vanilla activation; every other
# key selects the BinaryActivation defined in the `layers` module of the same
# name.
registered_model_compression_activation_strategy = dict(
    none=vanilla.VanillaActivation,
    mixbin=mixbin.BinaryActivation,
    mixbin_identity=mixbin_identity.BinaryActivation,
    mixbin_relu=mixbin_relu.BinaryActivation,
    mixbin_hardtanh=mixbin_hardtanh.BinaryActivation,
)

# Every binary BatchNorm2d class (one per non-'none' strategy). Kept as a
# tuple so it can be passed directly as the class argument of isinstance().
# NOTE(review): must be kept in sync by hand with the batchnorm registry above.
BINARY_BATCHNORMS = (
    mixbin.BinaryBatchNorm2d,
    mixbin_identity.BinaryBatchNorm2d,
    mixbin_relu.BinaryBatchNorm2d,
    mixbin_hardtanh.BinaryBatchNorm2d,
)