from .QTransformerBlock import QTransformerBlock, get_weight_initializers
from .QDeiTBlock import QDeiTBlock, get_deit_weight_initializers
from .QSA_FFN_Block import (
    QSA_FFN_Block,
    get_initializers_for_sa_att_block,
)
from .QCA_FFN_Block import (
    QCA_FFN_Block,
    get_initializers_for_ca_att_block,
)
from .QConfig import quantization_local_config
import tensorflow as tf
from typing import Dict, Any
from vit_keras.layers import TransformerBlock


def _default_initializers():
    """Return fresh default (weight, bias, gamma, beta) initializer lists."""
    return (
        ["glorot_uniform"] * 6,
        ["zeros"] * 6,
        ["ones"] * 2,
        ["zeros"] * 2,
    )


def _quantization_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
    """Map the quantization entries of ``config`` to Q-block keyword arguments."""
    return {
        "power_exponent": config["power exponent"],
        "activation_bits": config["activation bits"],
        "weight_bits": config["weights bits"],
        "accumulator_bits": config["accumulator bits"],
        "per_channel": config["per channel"],
        "deterministic": config["deterministic"],
        "ste_overflow": config["ste overflow"],
        "cyclical_alpha": config["cyclical alpha"],
    }


def create_custom_transformer_block(
    core_layer: tf.keras.layers.Layer, config: Dict[str, Any]
) -> QTransformerBlock:
    """
    this function creates a quantized transformer block based on the config and the provided core layer
    to get num heads for DeiT check : https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/deit.py

    The kind of block that is built depends on ``core_layer``:

    * not a ``TransformerBlock`` and ``"transformer_block"`` in its name -> ``QDeiTBlock``
    * not a ``TransformerBlock`` and ``"sa_ffn_block"`` in its name -> ``QSA_FFN_Block``
    * not a ``TransformerBlock`` and ``"ca_ffn_block"`` in its name -> ``QCA_FFN_Block``
    * anything else -> ``QTransformerBlock`` configured from ``core_layer.get_config()``

    Args:
        core_layer: the layer to replace by its quantized counterpart.
        config: settings dictionary. Always read: ``"fp weights init"``,
            ``"power exponent"``, ``"activation bits"``, ``"weights bits"``,
            ``"accumulator bits"``, ``"per channel"``, ``"deterministic"``,
            ``"ste overflow"`` and ``"cyclical alpha"``. ``"model name"`` is
            read only when a ``QDeiTBlock`` is built.

    Returns:
        The newly constructed quantized block.

    Raises:
        KeyError: if a required ``config`` entry is missing, or if a
            ``ca_ffn_block`` is converted while ``quantization_local_config``
            has no ``"num heads"`` entry.
    """
    is_vit_block = isinstance(core_layer, TransformerBlock)
    layer_name = core_layer.name
    use_fp_init = config["fp weights init"]

    # Initializers for the ViT / DeiT style blocks; the sa/ca branches below
    # replace them with the values returned by their own helper.
    if use_fp_init and is_vit_block:
        initializers = get_weight_initializers(core_layer=core_layer, config=config)
    elif use_fp_init and "transformer_block" in layer_name:
        initializers = get_deit_weight_initializers(
            core_layer=core_layer, config=config
        )
    else:
        initializers = _default_initializers()
    (
        weight_initializers,
        bias_initializers,
        gamma_initializers,
        beta_initializers,
    ) = initializers

    if not is_vit_block and "transformer_block" in layer_name:
        # The head count is derived from the model name: 3 unless the name
        # contains "deit_s" (6) or "deit_b" (12).
        model_name = config["model name"].lower()
        if "deit_s" in model_name:
            num_heads = 6
        elif "deit_b" in model_name:
            num_heads = 12
        else:
            num_heads = 3
        return QDeiTBlock(
            num_heads=num_heads,
            # Leading dimension of the second-to-last weight tensor --
            # presumably the last MLP dense kernel; TODO confirm.
            mlp_dim=core_layer.weights[-2].shape[0],
            # "rate" of the second-to-last sub-layer -- presumably a Dropout
            # layer; TODO confirm.
            dropout=core_layer.layers[-2].get_config()["rate"],
            name=layer_name,
            **_quantization_kwargs(config),
            weight_initializers=weight_initializers,
            bias_initializers=bias_initializers,
            gamma_initializers=gamma_initializers,
            beta_initializers=beta_initializers,
        )

    if not is_vit_block and "sa_ffn_block" in layer_name:
        (
            weight_initializers,
            bias_initializers,
            gamma_initializers,
            beta_initializers,
            layerscale_initializers,
            projection_dim,
            mlp_ratio,
            num_heads,
        ) = get_initializers_for_sa_att_block(sa_att_block=core_layer, config=config)
        # Remember the head count: the ca_ffn_block branch reads it back from
        # this shared dictionary.
        quantization_local_config["num heads"] = num_heads
        return QSA_FFN_Block(
            name=layer_name,
            num_heads=num_heads,
            projection_dim=projection_dim,
            mlp_ratio=mlp_ratio,
            **_quantization_kwargs(config),
            layerscale_initializers=layerscale_initializers,
            weight_initializers=weight_initializers,
            bias_initializers=bias_initializers,
            gamma_initializers=gamma_initializers,
            beta_initializers=beta_initializers,
        )

    if not is_vit_block and "ca_ffn_block" in layer_name:
        (
            weight_initializers,
            bias_initializers,
            gamma_initializers,
            beta_initializers,
            layerscale_initializers,
            projection_dim,
            mlp_ratio,
        ) = get_initializers_for_ca_att_block(ca_att_block=core_layer, config=config)
        try:
            num_heads = quantization_local_config["num heads"]
        except KeyError as err:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                '"num heads" is missing from quantization_local_config; it is '
                "recorded when an sa_ffn_block is converted, so convert those "
                "before any ca_ffn_block"
            ) from err
        return QCA_FFN_Block(
            name=layer_name,
            num_heads=num_heads,
            projection_dim=projection_dim,
            mlp_ratio=mlp_ratio,
            **_quantization_kwargs(config),
            layerscale_initializers=layerscale_initializers,
            weight_initializers=weight_initializers,
            bias_initializers=bias_initializers,
            gamma_initializers=gamma_initializers,
            beta_initializers=beta_initializers,
        )

    # Fallback: mirror the core layer's own configuration.
    core_config = core_layer.get_config()
    return QTransformerBlock(
        num_heads=core_config["num_heads"],
        mlp_dim=core_config["mlp_dim"],
        dropout=core_config["dropout"],
        name=core_config["name"],
        **_quantization_kwargs(config),
        weight_initializers=weight_initializers,
        bias_initializers=bias_initializers,
        gamma_initializers=gamma_initializers,
        beta_initializers=beta_initializers,
    )


def create_custom_layer_normalization(
    core_layer: tf.keras.layers.Layer,
    config: Dict[str, Any],
) -> tf.keras.layers.LayerNormalization:
    """
    this function hands the provided core layer back unchanged

    Args:
        core_layer: the layer to pass through.
        config: settings dictionary; not read by this function.

    Returns:
        The very same ``core_layer`` object that was passed in.
    """
    # No replacement is built here: the input layer itself is the result.
    return core_layer
