from layers.registry import registered_model_compression_batchnorm_strategy, registered_model_compression_conv_strategy
from layers.registry import registered_model_compression_activation_strategy 

class ModuleInjection:
    """Class-level dispatcher that converts layers using a registered compression strategy.

    All state lives on the class itself (the methods are static and no
    instances are created here): the currently selected strategy key and the
    list of batchnorm layers produced by ``get_conv_bn_act``. Strategy
    implementations are looked up by that key in the
    ``registered_model_compression_{conv,batchnorm,activation}_strategy``
    registries imported from ``layers.registry``.
    """

    # Key used to index every registered_model_compression_*_strategy
    # registry; set through update_model_compression_strategy().
    model_compression_strategy = None
    # Batchnorm layers returned by get_conv_bn_act, in call order. Shared at
    # class level; rebound to a fresh list by a successful
    # update_model_compression_strategy().
    prunable_modules = []

    @staticmethod
    def get_conv_bn_act(conv_layer, batchnorm_layer, is_skip_connection=False, is_first_layer=False, keep_full_precision=False):
        """Convert a conv/batchnorm pair with the active strategy and build its activation.

        Args:
            conv_layer: Original convolution layer, handed to the conv
                strategy's ``get_from_conv``.
            batchnorm_layer: Original batchnorm layer, handed to the batchnorm
                strategy's ``get_from_batchnorm``.
            is_skip_connection: Forwarded to the conv, batchnorm and
                activation strategies.
            is_first_layer: Forwarded to the conv, batchnorm and activation
                strategies.
            keep_full_precision: Forwarded to the conv strategy only.

        Returns:
            Tuple ``(conv_layer, batchnorm_layer, activation_layer)``: the
            objects returned by the conv, batchnorm and activation strategies
            respectively.

        Side effects:
            Appends the returned batchnorm layer to
            ``ModuleInjection.prunable_modules``.

        Raises:
            Whatever the registries raise for an unknown key (presumably
            ``KeyError`` if they are plain dicts — confirm in layers.registry),
            e.g. when no valid strategy has been set yet.
        """
        strategy = ModuleInjection.model_compression_strategy

        conv_layer = registered_model_compression_conv_strategy[strategy].get_from_conv(
            conv_layer,
            is_skip_connection=is_skip_connection,
            is_first_layer=is_first_layer,
            keep_full_precision=keep_full_precision,
        )
        # The batchnorm strategy receives the *converted* conv layer, not the
        # original one passed in by the caller.
        batchnorm_layer = registered_model_compression_batchnorm_strategy[strategy].get_from_batchnorm(
            batchnorm_layer,
            conv_layer,
            is_skip_connection=is_skip_connection,
            is_first_layer=is_first_layer,
        )
        # Likewise, the activation is built from the converted batchnorm and
        # conv layers.
        activation_layer = registered_model_compression_activation_strategy[strategy].get_from_bn_conv(
            batchnorm_layer,
            conv_layer,
            is_skip_connection=is_skip_connection,
            is_first_layer=is_first_layer,
        )
        ModuleInjection.prunable_modules.append(batchnorm_layer)
        return conv_layer, batchnorm_layer, activation_layer

    @staticmethod
    def is_valid_model_compression_strategy():
        """Check that the current strategy is registered for every layer kind.

        ``get_conv_bn_act`` indexes the conv, batchnorm and activation
        registries with the current strategy, so all three are checked here.
        This keeps an incompletely registered strategy from passing
        validation and then failing during layer construction.

        Returns:
            True when the strategy is present in all three registries.

        Raises:
            ValueError: If the strategy is missing from any registry. Checks
                run in the order batchnorm, conv, activation, and the first
                failure is reported.
        """
        strategy = ModuleInjection.model_compression_strategy

        registries = (
            ('batchnorm', registered_model_compression_batchnorm_strategy),
            ('conv', registered_model_compression_conv_strategy),
            ('activation', registered_model_compression_activation_strategy),
        )
        for layer_kind, registry in registries:
            if strategy not in registry:
                raise ValueError('Invalid {} strategy: {}'.format(layer_kind, strategy))

        return True

    @staticmethod
    def update_model_compression_strategy(model_compression_strategy):
        """Select a new compression strategy and start a fresh prunable-module list.

        The update is all-or-nothing. If the strategy fails validation, the
        previous strategy is restored and ``prunable_modules`` is left
        untouched. Callers therefore never observe an invalid strategy
        installed after the ``ValueError``.

        Args:
            model_compression_strategy: Key to look up in the
                registered_model_compression_*_strategy registries.

        Raises:
            ValueError: If the strategy is not registered for every layer kind.
        """
        previous_strategy = ModuleInjection.model_compression_strategy
        # Validation reads the class attribute, so the candidate must be
        # installed before checking it.
        ModuleInjection.model_compression_strategy = model_compression_strategy
        try:
            ModuleInjection.is_valid_model_compression_strategy()
        except ValueError:
            ModuleInjection.model_compression_strategy = previous_strategy
            raise
        # Rebind rather than clear(), so any list handed out earlier keeps
        # the modules collected under the previous strategy.
        ModuleInjection.prunable_modules = []
        ModuleInjection.is_valid_model_compression_strategy()

