from typing import Dict, Any, Tuple, Type
import tensorflow as tf
from .QGeLU import QGeLU
from .QSoftmax import QSoftmax
from .QAdd import QAdd
from .QCreateBasicExprLayer import *
from .QCreateTransformerLayers import *
from vit_keras.layers import TransformerBlock
import copy


def get_specific_layers(layers_list: list, layer_type: Type) -> list:
    """
    Collect every element of ``layers_list`` that is an instance of ``layer_type``.

    All matches are returned, in their original order. When more than one
    element matches, deciding what to do is left to the caller.
    """
    return [candidate for candidate in layers_list if isinstance(candidate, layer_type)]


def create_custom_block(
    core_layer: tf.keras.layers.Layer,
    activation_layer: Any,
    pooling_layer: Any,
    config: Dict[str, Any],
    batchnormalization_layers: Any,
    quantize_transformer_blocks: bool,
) -> tf.keras.layers.Layer:
    """
    Return the replacement layer for ``core_layer`` in the quantized model.

    Dispatch is by layer type and the first match wins:
    DepthwiseConv2D, Conv2D and Dense (each built together with the given
    activation / pooling / batch-norm arguments), then LayerNormalization.
    Transformer-style layers are a ``TransformerBlock`` instance, or any layer
    whose name contains "transformer_block", "sa_ffn_block" or "ca_ffn_block".
    They are rebuilt only when ``quantize_transformer_blocks`` is truthy and
    are otherwise returned unchanged.

    Raises:
        NotImplementedError: if ``core_layer`` matches none of the above.
    """
    # The DepthwiseConv2D test must stay ahead of the Conv2D test.
    if isinstance(core_layer, tf.keras.layers.DepthwiseConv2D):
        return create_custom_depthwiseconv2d_block(
            core_layer=core_layer,
            activation_layer=activation_layer,
            pooling_layer=pooling_layer,
            batchnormalization_layers=batchnormalization_layers,
            config=config,
        )
    if isinstance(core_layer, tf.keras.layers.Conv2D):
        return create_custom_conv2d_block(
            core_layer=core_layer,
            activation_layer=activation_layer,
            pooling_layer=pooling_layer,
            batchnormalization_layers=batchnormalization_layers,
            config=config,
        )
    if isinstance(core_layer, tf.keras.layers.Dense):
        return create_custom_dense_block(
            core_layer=core_layer,
            activation_layer=activation_layer,
            pooling_layer=pooling_layer,
            batchnormalization_layers=batchnormalization_layers,
            config=config,
        )
    if isinstance(core_layer, tf.keras.layers.LayerNormalization):
        return create_custom_layer_normalization(core_layer=core_layer, config=config)

    name_markers = ("transformer_block", "sa_ffn_block", "ca_ffn_block")
    is_transformer_like = isinstance(core_layer, TransformerBlock) or any(
        marker in core_layer.name for marker in name_markers
    )
    if not is_transformer_like:
        raise NotImplementedError(
            f"the requested layer to edit ({type(core_layer)}) is not supported yet"
        )
    if quantize_transformer_blocks:
        return create_custom_transformer_block(core_layer=core_layer, config=config)
    # Transformer quantization disabled: hand back the original layer untouched.
    return core_layer


def _keep_layer_in_model(layer: tf.keras.layers.Layer, layers_to_remove: list) -> None:
    """Take ``layer`` back out of ``layers_to_remove`` (in place) so it stays in the model."""
    # Pooling layers used to be removed by object and activation layers by
    # name; only one of those can match the list's actual contents. Accept
    # either: prefer the name, else fall back to the object. list.remove still
    # raises ValueError if neither is present, as before.
    key = layer.name if layer.name in layers_to_remove else layer
    layers_to_remove.remove(key)


def _single_or_none(candidates: list, layers_to_remove: list) -> Any:
    """Return the sole element of ``candidates``; otherwise return None.

    When there are several candidates, each of them is first taken back out of
    ``layers_to_remove``, so none of them is fused into the quantized block.
    """
    if len(candidates) == 1:
        return candidates[0]
    for candidate in candidates:
        _keep_layer_in_model(candidate, layers_to_remove)
    return None


def create_QLayers(
    model: tf.keras.Model,
    layers_dict: Dict[str, list],
    config: Dict[str, Any],
    layers_to_remove: list,
    quantize_transformer_blocks: bool,
) -> Tuple[Dict[str, tf.keras.layers.Layer], list]:
    """
    Build the layers that serve as building blocks for the quantized network.

    For every ``core layer name -> associated layers`` entry of ``layers_dict``:

    * The core layer is fetched with ``model.get_layer``.
    * It is passed to ``create_custom_block`` together with at most one pooling
      layer (``GlobalMaxPool2D`` / ``MaxPool2D``) and at most one ``Activation``
      layer found among the associated layers.
    * If a cluster holds several pooling (or activation) layers, none of them is
      fused. Instead they are taken back out of ``layers_to_remove``.

    When ``config["first and last layer W8/A8"]`` is truthy, the first and last
    clusters (in ``layers_dict`` iteration order) are built from a copy of
    ``config`` with ``"weights bits"`` and ``"activation bits"`` forced to 8.

    Every ``tf.keras.layers.Add`` in ``model`` is also registered in the result,
    mapped to itself.

    Args:
        model: source Keras model.
        layers_dict: maps a core layer name to the layer objects grouped with it.
        config: quantization configuration. Keys read here:
            ``"first and last layer W8/A8"`` and ``"bn in qblock"``.
        layers_to_remove: layers scheduled for removal. Mutated in place.
        quantize_transformer_blocks: forwarded to ``create_custom_block``.

    Returns:
        ``(layers_to_replace, layers_to_remove)``. ``layers_to_replace`` maps a
        layer name to its replacement. ``layers_to_remove`` is the same list
        object that was passed in.

    Raises:
        ValueError: if a layer that must be kept is not in ``layers_to_remove``
            (by name or by object).
    """
    layers_to_replace = {}
    last_index = len(layers_dict) - 1
    for cpt, (core_layer_name, associated_layers) in enumerate(layers_dict.items()):
        # Each block gets its own copy so per-layer overrides cannot leak.
        config_to_use = copy.deepcopy(config)
        if config["first and last layer W8/A8"] and cpt in (0, last_index):
            config_to_use["weights bits"] = 8
            config_to_use["activation bits"] = 8
        core_layer = model.get_layer(core_layer_name)
        pooling_layers = get_specific_layers(
            layers_list=associated_layers, layer_type=tf.keras.layers.GlobalMaxPool2D
        ) + get_specific_layers(
            layers_list=associated_layers, layer_type=tf.keras.layers.MaxPool2D
        )
        activation_layers = get_specific_layers(
            layers_list=associated_layers, layer_type=tf.keras.layers.Activation
        )
        pooling_layer = _single_or_none(pooling_layers, layers_to_remove)
        activation_layer = _single_or_none(activation_layers, layers_to_remove)
        layers_to_replace[core_layer_name] = create_custom_block(
            core_layer=core_layer,
            activation_layer=activation_layer,
            pooling_layer=pooling_layer,
            batchnormalization_layers=config["bn in qblock"],
            config=config_to_use,
            quantize_transformer_blocks=quantize_transformer_blocks,
        )
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Add):
            # Add layers are passed through unchanged. To quantize them,
            # substitute QAdd(accumulator_bits=config["accumulator bits"],
            # ste_overflow=config["ste overflow"],
            # cyclical_alpha=config["cyclical alpha"]) here.
            layers_to_replace[layer.name] = layer
    return layers_to_replace, layers_to_remove
