from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import tensorflow as tf

from .QDense import QBlockDense
from .QGeLU import QGeLU
from .QTalkingHeadAtt import QTalkingHeadAtt
from .QLayerScale import QLayerScale
from .QInitializers import (
    QWeightInitializer,
    QBiasInitializer,
    QBetaInitializer,
    QGammaInitializer,
)


@tf.keras.utils.register_keras_serializable()
class QSA_FFN_Block(tf.keras.layers.Layer):
    """Quantized self-attention + feed-forward block with layer scaling.

    The forward pass is::

        x2  = inputs + layerscale1(TalkingHeadAttn(layernorm1(inputs)))
        out = x2     + layerscale2(mlpblock(layernorm2(x2)))

    and the layer returns ``(out, attn_scores)``.

    Args:
        name: Stored as ``core_name`` and used as the prefix of the sub-layer
            names. It is not forwarded to the base ``Layer`` constructor, so
            the Keras layer name itself comes from ``kwargs`` or is
            auto-generated.
        num_heads: Forwarded to ``QTalkingHeadAtt``.
        projection_dim: Output width of the second MLP dense layer; also
            forwarded to ``QTalkingHeadAtt`` and ``QLayerScale``.
        mlp_ratio: The first MLP dense layer has
            ``projection_dim * mlp_ratio`` units.
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            Quantization settings forwarded unchanged to ``QBlockDense`` and
            ``QTalkingHeadAtt`` (``activation_bits`` is also passed to
            ``QGeLU`` as ``num_bits``).
        layerscale_initializers: Entries 0 and 1 are passed as ``init_values``
            to the two ``QLayerScale`` layers.
        weight_initializers: The whole list is forwarded to
            ``QTalkingHeadAtt``; entries 4 and 5 initialise the two MLP dense
            layers, so at least six entries are required.
        bias_initializers: Same layout as ``weight_initializers``.
        beta_initializers: Entry 0 is used by ``LayerNorm_0`` and entry 1 by
            ``LayerNorm_2``. Defaults to ``["zeros", "zeros"]``.
        gamma_initializers: Same layout as ``beta_initializers``. Defaults to
            ``["ones", "ones"]``.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer.__init__``.
    """

    def __init__(
        self,
        name: str,
        num_heads: int,
        projection_dim: int,
        mlp_ratio: int,
        power_exponent: float,
        activation_bits: int,
        weight_bits: int,
        accumulator_bits: int,
        per_channel: bool,
        deterministic: bool,
        ste_overflow: bool,
        cyclical_alpha: float,
        layerscale_initializers: List[np.ndarray],
        weight_initializers: List[tf.keras.initializers.Initializer],
        bias_initializers: List[tf.keras.initializers.Initializer],
        beta_initializers: Optional[
            List[tf.keras.initializers.Initializer]
        ] = None,
        gamma_initializers: Optional[
            List[tf.keras.initializers.Initializer]
        ] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # A fresh list is built per instance so that no default list object
        # is shared (and possibly mutated) across instances.
        if beta_initializers is None:
            beta_initializers = ["zeros", "zeros"]
        if gamma_initializers is None:
            gamma_initializers = ["ones", "ones"]

        self.num_heads = num_heads
        self.core_name = name
        self.projection_dim = projection_dim
        self.mlp_ratio = mlp_ratio
        self.power_exponent = power_exponent
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.layerscale_initializers = layerscale_initializers
        self.weight_initializers = weight_initializers
        self.bias_initializers = bias_initializers
        self.beta_initializers = beta_initializers
        self.gamma_initializers = gamma_initializers

    def build(self, input_shape: tuple):
        """Create the sub-layers; ``input_shape`` itself is not inspected."""
        # Feed-forward part: Dense -> QGeLU -> Dense. Initializer entries 4
        # and 5 belong to these two dense layers; the full lists are also
        # handed to the attention layer below.
        self.mlpblock = tf.keras.Sequential(
            [
                QBlockDense(
                    units=self.projection_dim * self.mlp_ratio,
                    name=f"{self.core_name}/Dense_0",
                    power_exponent=self.power_exponent,
                    activation_bits=self.activation_bits,
                    weight_bits=self.weight_bits,
                    accumulator_bits=self.accumulator_bits,
                    use_bias=True,
                    activation="linear",
                    per_channel=self.per_channel,
                    deterministic=self.deterministic,
                    ste_overflow=self.ste_overflow,
                    cyclical_alpha=self.cyclical_alpha,
                    weight_initializer=self.weight_initializers[4],
                    bias_initializer=self.bias_initializers[4],
                ),
                QGeLU(num_bits=self.activation_bits),
                QBlockDense(
                    units=self.projection_dim,
                    name=f"{self.core_name}/Dense_1",
                    power_exponent=self.power_exponent,
                    activation_bits=self.activation_bits,
                    weight_bits=self.weight_bits,
                    accumulator_bits=self.accumulator_bits,
                    use_bias=True,
                    activation="linear",
                    per_channel=self.per_channel,
                    deterministic=self.deterministic,
                    ste_overflow=self.ste_overflow,
                    cyclical_alpha=self.cyclical_alpha,
                    weight_initializer=self.weight_initializers[5],
                    bias_initializer=self.bias_initializers[5],
                ),
            ],
            name="MlpBlock_3",
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_0",
            beta_initializer=self.beta_initializers[0],
            gamma_initializer=self.gamma_initializers[0],
        )
        self.layernorm2 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_2",
            beta_initializer=self.beta_initializers[1],
            gamma_initializer=self.gamma_initializers[1],
        )
        self.TalkingHeadAttn = QTalkingHeadAtt(
            num_heads=self.num_heads,
            projection_dim=self.projection_dim,
            name=f"{self.core_name}_TalkingHeadAttn",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializers=self.weight_initializers,
            bias_initializers=self.bias_initializers,
        )
        self.layerscale1 = QLayerScale(
            projection_dim=self.projection_dim,
            init_values=self.layerscale_initializers[0],
        )
        self.layerscale2 = QLayerScale(
            projection_dim=self.projection_dim,
            init_values=self.layerscale_initializers[1],
        )
        # The residual additions are created once here instead of
        # instantiating a new Add layer on every forward pass.
        self.residual_add1 = tf.keras.layers.Add()
        self.residual_add2 = tf.keras.layers.Add()

    def call(self, inputs: tf.Variable, training=None, *args: Any, **kwargs: Any) -> Any:
        """Run the attention and feed-forward sub-blocks.

        Args:
            inputs: Tensor fed to the first layer normalization and to the
                first residual addition.
            training: Accepted for Keras API compatibility. It is not passed
                explicitly to any sub-layer.
            *args, **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(outputs, attn_scores)`` where ``attn_scores`` is the
            second value returned by the attention layer.
        """
        # Self-attention with layer scaling and a residual connection.
        x1 = self.layernorm1(inputs)
        attn_output, attn_scores = self.TalkingHeadAttn(x1)
        attn_output = self.layerscale1(attn_output)
        x2 = self.residual_add1([inputs, attn_output])

        # FFN.
        x3 = self.layernorm2(x2)
        x4 = self.mlpblock(x3)
        x4 = self.layerscale2(x4)
        outputs = self.residual_add2([x2, x4])
        return outputs, attn_scores

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return the base layer config extended with the constructor arguments.

        The ``"name"`` entry of the base config is overwritten with
        ``core_name`` so that it matches the ``name`` constructor parameter.
        """
        # NOTE(review): the initializer lists and layer-scale arrays are
        # returned exactly as stored; whether they serialize cleanly depends
        # on those objects -- confirm before relying on JSON export.
        config = super().get_config()
        config.update(
            {
                "name": self.core_name,
                "num_heads": self.num_heads,
                "projection_dim": self.projection_dim,
                "mlp_ratio": self.mlp_ratio,
                "power_exponent": self.power_exponent,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "layerscale_initializers": self.layerscale_initializers,
                "weight_initializers": self.weight_initializers,
                "bias_initializers": self.bias_initializers,
                "gamma_initializers": self.gamma_initializers,
                "beta_initializers": self.beta_initializers,
            }
        )
        return config


def get_initializers_for_sa_att_block(
    sa_att_block: tf.keras.Model,
    config: Dict[str, Any],
) -> Tuple[
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[np.ndarray],
    int,
    int,
    int,
]:
    """Extract initializers and dimensions from an existing attention block.

    The layers of ``sa_att_block`` are addressed by fixed position:

    * ``layers[2]`` supplies the ``qkv``, ``proj``, ``proj_l`` and ``proj_w``
      sub-layers (weight/bias initializers 0-3, in that order);
    * ``layers[3]`` and ``layers[10]`` supply the ``gamma`` arrays used as
      layer-scale initial values;
    * ``layers[1]`` and ``layers[5]`` supply gamma (``weights[0]``) and beta
      (``weights[1]``) initializers;
    * ``layers[6]`` and ``layers[8]`` supply weight/bias initializers 4 and 5.

    Args:
        sa_att_block: Model whose layer list follows the layout above.
        config: Must contain the keys ``"operator"`` and ``"weights bits"``,
            which are passed to ``QWeightInitializer``.

    Returns:
        ``(weight_initializers, bias_initializers, gamma_initializers,
        beta_initializers, layerscale_initializers, projection_dim,
        mlp_ratio, num_heads)``. ``projection_dim`` is the last dimension of
        ``layers[8].weights[1]``, ``mlp_ratio`` is the last dimension of
        ``layers[6].weights[1]`` divided by ``projection_dim`` (truncated to
        int), and ``num_heads`` is the last dimension of
        ``layers[2].proj_l.weights[1]``.
    """
    block_layers = sa_att_block.layers
    attention = block_layers[2]

    weight_initializers = []
    bias_initializers = []
    gamma_initializers = []
    beta_initializers = []

    def _collect_dense(dense_layer) -> None:
        # Record the weight initializer first, then the bias initializer.
        weight_initializers.append(
            QWeightInitializer(
                layer_name=dense_layer.name,
                weight_tensor=dense_layer.weights[0].numpy(),
                operator=config["operator"],
                layer_type="Dense",
                bits=config["weights bits"],
            )
        )
        bias_initializers.append(
            QBiasInitializer(
                layer_name=dense_layer.name,
                bias_tensor=dense_layer.weights[1].numpy(),
            )
        )

    def _collect_norm(norm_layer) -> None:
        # weights[0] is used as gamma and weights[1] as beta.
        gamma_initializers.append(
            QGammaInitializer(
                layer_name=norm_layer.name,
                gamma_tensor=norm_layer.weights[0].numpy(),
            )
        )
        beta_initializers.append(
            QBetaInitializer(
                layer_name=norm_layer.name,
                beta_tensor=norm_layer.weights[1].numpy(),
            )
        )

    # Attention projections occupy initializer slots 0-3.
    attention_projections = (
        attention.qkv,
        attention.proj,
        attention.proj_l,
        attention.proj_w,
    )
    for projection in attention_projections:
        _collect_dense(projection)

    layerscale_initializers = [
        block_layers[index].gamma.numpy() for index in (3, 10)
    ]

    for norm_layer in (block_layers[1], block_layers[5]):
        _collect_norm(norm_layer)

    # The two dense layers at positions 6 and 8 occupy slots 4 and 5.
    for dense_layer in (block_layers[6], block_layers[8]):
        _collect_dense(dense_layer)

    # Dimensions are read back from the bias shapes.
    projection_dim = block_layers[8].weights[1].numpy().shape[-1]
    hidden_dim = block_layers[6].weights[1].numpy().shape[-1]
    mlp_ratio = int(hidden_dim / projection_dim)
    num_heads = attention.proj_l.weights[1].numpy().shape[-1]

    return (
        weight_initializers,
        bias_initializers,
        gamma_initializers,
        beta_initializers,
        layerscale_initializers,
        projection_dim,
        mlp_ratio,
        num_heads,
    )
