from .QInitializers import get_qlayer_initializers
from .QConv2D import QBlockConv2D
from .QDense import QBlockDense
from .QUtils import get_activation
from typing import Dict, Any, Tuple, Type
import tensorflow as tf

def create_custom_conv2d_block(
    core_layer: tf.keras.layers.Conv2D,
    activation_layer: Any,
    pooling_layer: Any,
    batchnormalization_layers: Any,
    config: Dict[str, Any],
) -> tf.keras.layers.Layer:
    """
    Creates a new conv2d block for quantized model.

    Builds a ``QBlockConv2D`` with ``is_depthwise=False``. The filters, kernel
    size, strides, padding, bias usage and dilation rate are copied from
    ``core_layer.get_config()``, and the block reuses ``core_layer.name``.

    Args:
        core_layer: the Keras ``Conv2D`` layer being replaced.
        activation_layer: extra activation handed to ``get_activation``
            together with the layer's own ``activation`` setting.
        pooling_layer: passed to ``QBlockConv2D`` as ``pooling`` unchanged.
        batchnormalization_layers: passed to ``QBlockConv2D`` as
            ``batchnormalization`` unchanged.
        config: quantization settings. This function reads the keys
            "power exponent", "activation bits", "weights bits",
            "accumulator bits", "per channel", "deterministic",
            "ste overflow" and "cyclical alpha". The whole dict is also
            forwarded to ``get_qlayer_initializers``.

    Returns:
        The newly constructed ``QBlockConv2D`` layer.

    Raises:
        KeyError: if ``config`` lacks one of the keys listed above.
    """
    core_config = core_layer.get_config()
    weight_initializer, bias_initializer = get_qlayer_initializers(
        layer=core_layer, config=config
    )
    # NOTE(review): Conv2D's ``groups`` setting is not forwarded; grouped
    # convolutions are presumably unsupported by QBlockConv2D -- confirm.
    return QBlockConv2D(
        name=core_layer.name,
        filters=core_config["filters"],
        kernel_size=core_config["kernel_size"],
        stride=core_config["strides"],
        padding=core_config["padding"],
        use_bias=core_config["use_bias"],
        pooling=pooling_layer,
        batchnormalization=batchnormalization_layers,
        activation=get_activation(
            core_activation=core_config["activation"], extra_activation=activation_layer
        ),
        power_exponent=config["power exponent"],
        activation_bits=config["activation bits"],
        weight_bits=config["weights bits"],
        accumulator_bits=config["accumulator bits"],
        per_channel=config["per channel"],
        deterministic=config["deterministic"],
        ste_overflow=config["ste overflow"],
        cyclical_alpha=config["cyclical alpha"],
        is_depthwise=False,
        # A regular convolution has no depth multiplier of its own.
        depth_multiplier=1,
        # Forward the source layer's real dilation, as the depthwise builder
        # does. A hard-coded (1, 1) silently turned a dilated Conv2D into an
        # undilated one. For default layers the value is still (1, 1).
        dilation_rate=core_config["dilation_rate"],
        weight_initializer=weight_initializer,
        bias_initializer=bias_initializer,
    )


def create_custom_depthwiseconv2d_block(
    core_layer: tf.keras.layers.DepthwiseConv2D,
    activation_layer: Any,
    pooling_layer: Any,
    batchnormalization_layers: Any,
    config: Dict[str, Any],
) -> tf.keras.layers.Layer:
    """
    Creates a new depthwise conv2d block for quantized model.

    Builds a ``QBlockConv2D`` with ``is_depthwise=True``. The kernel size,
    strides, padding, bias usage, depth multiplier and dilation rate are
    copied from ``core_layer.get_config()``, and the block reuses
    ``core_layer.name``.

    Args:
        core_layer: the Keras ``DepthwiseConv2D`` layer being replaced. Its
            config must provide "depth_multiplier" and "dilation_rate".
        activation_layer: extra activation handed to ``get_activation``
            together with the layer's own ``activation`` setting.
        pooling_layer: passed to ``QBlockConv2D`` as ``pooling`` unchanged.
        batchnormalization_layers: passed to ``QBlockConv2D`` as
            ``batchnormalization`` unchanged.
        config: quantization settings. This function reads the keys
            "power exponent", "activation bits", "weights bits",
            "accumulator bits", "per channel", "deterministic",
            "ste overflow" and "cyclical alpha". The whole dict is also
            forwarded to ``get_qlayer_initializers``.

    Returns:
        The newly constructed ``QBlockConv2D`` layer.

    Raises:
        KeyError: if ``config`` (or the layer config) lacks a key read here.
    """
    core_config = core_layer.get_config()
    weight_initializer, bias_initializer = get_qlayer_initializers(
        layer=core_layer, config=config
    )
    return QBlockConv2D(
        name=core_layer.name,
        kernel_size=core_config["kernel_size"],
        stride=core_config["strides"],
        padding=core_config["padding"],
        use_bias=core_config["use_bias"],
        is_depthwise=True,
        depth_multiplier=core_config["depth_multiplier"],
        dilation_rate=core_config["dilation_rate"],
        pooling=pooling_layer,
        batchnormalization=batchnormalization_layers,
        activation=get_activation(
            core_activation=core_config["activation"], extra_activation=activation_layer
        ),
        power_exponent=config["power exponent"],
        activation_bits=config["activation bits"],
        weight_bits=config["weights bits"],
        accumulator_bits=config["accumulator bits"],
        per_channel=config["per channel"],
        deterministic=config["deterministic"],
        ste_overflow=config["ste overflow"],
        cyclical_alpha=config["cyclical alpha"],
        # ``filters`` is not taken from the layer config here; -1 appears to
        # be a placeholder. NOTE(review): confirm QBlockConv2D ignores
        # ``filters`` when ``is_depthwise=True``.
        filters=-1,
        weight_initializer=weight_initializer,
        bias_initializer=bias_initializer,
    )


def create_custom_dense_block(
    core_layer: tf.keras.layers.Dense,
    activation_layer: Any,
    pooling_layer: Any,
    batchnormalization_layers: Any,
    config: Dict[str, Any],
) -> tf.keras.layers.Layer:
    """
    Creates a new fully-connected block for quantized model.

    Builds a ``QBlockDense`` that copies ``units`` and ``use_bias`` from
    ``core_layer.get_config()`` and reuses ``core_layer.name``.

    Args:
        core_layer: the Keras ``Dense`` layer being replaced.
        activation_layer: extra activation handed to ``get_activation``
            together with the layer's own ``activation`` setting.
        pooling_layer: passed to ``QBlockDense`` as ``pooling`` unchanged.
        batchnormalization_layers: passed to ``QBlockDense`` as
            ``batchnormalization`` unchanged.
        config: quantization settings. This function reads the keys
            "power exponent", "activation bits", "weights bits",
            "accumulator bits", "per channel", "deterministic",
            "ste overflow" and "cyclical alpha". The whole dict is also
            forwarded to ``get_qlayer_initializers``.

    Returns:
        The newly constructed ``QBlockDense`` layer.
    """
    layer_cfg = core_layer.get_config()
    kernel_init, bias_init = get_qlayer_initializers(layer=core_layer, config=config)

    # Values taken from the source layer.
    block_name = core_layer.name
    n_units = layer_cfg["units"]
    has_bias = layer_cfg["use_bias"]
    combined_activation = get_activation(
        core_activation=layer_cfg["activation"],
        extra_activation=activation_layer,
    )

    # Quantization settings, mapped from config keys to constructor kwargs.
    quant_kwargs = {
        "power_exponent": config["power exponent"],
        "activation_bits": config["activation bits"],
        "weight_bits": config["weights bits"],
        "accumulator_bits": config["accumulator bits"],
        "per_channel": config["per channel"],
        "deterministic": config["deterministic"],
        "ste_overflow": config["ste overflow"],
        "cyclical_alpha": config["cyclical alpha"],
    }

    return QBlockDense(
        name=block_name,
        units=n_units,
        use_bias=has_bias,
        activation=combined_activation,
        pooling=pooling_layer,
        batchnormalization=batchnormalization_layers,
        weight_initializer=kernel_init,
        bias_initializer=bias_init,
        **quant_kwargs,
    )


def create_custom_block(
    core_layer: tf.keras.layers.Layer,
    activation_layer: Any,
    pooling_layer: Any,
    config: Dict[str, Any],
    batchnormalization_layers: Any,
) -> tf.keras.layers.Layer:
    """
    Creates a new computational block for quantized model.

    Dispatches on the type of ``core_layer``:
    ``DepthwiseConv2D`` -> depthwise conv block, ``Conv2D`` -> conv block,
    anything else -> dense block. All remaining arguments are forwarded to
    the selected builder unchanged.

    Returns:
        The quantized block produced by the selected builder.
    """
    # Ordered (type, builder) pairs. The first isinstance match wins, so
    # DepthwiseConv2D is checked before Conv2D. That keeps depthwise layers on
    # the depthwise builder even where DepthwiseConv2D subclasses Conv2D.
    typed_builders = (
        (tf.keras.layers.DepthwiseConv2D, create_custom_depthwiseconv2d_block),
        (tf.keras.layers.Conv2D, create_custom_conv2d_block),
    )

    # Any layer not matched above is treated as a fully-connected layer.
    chosen_builder = create_custom_dense_block
    for layer_type, builder in typed_builders:
        if isinstance(core_layer, layer_type):
            chosen_builder = builder
            break

    return chosen_builder(
        core_layer=core_layer,
        activation_layer=activation_layer,
        pooling_layer=pooling_layer,
        batchnormalization_layers=batchnormalization_layers,
        config=config,
    )
