from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf

from .QMultiHeadSelfAttention import QMultiHeadSelfAttention
from .QDense import QBlockDense
from .QGeLU import QGeLU
from .QInitializers import (
    QWeightInitializer,
    QBiasInitializer,
    QBetaInitializer,
    QGammaInitializer,
)


@tf.keras.utils.register_keras_serializable()
class QTransformerBlock(tf.keras.layers.Layer):
    """Quantized Transformer encoder block.

    Data flow (see ``call``):

    * ``LayerNorm_0`` -> quantized multi-head self-attention -> dropout ->
      residual add with the block input;
    * ``LayerNorm_2`` -> quantized MLP (Dense -> GeLU -> Dropout -> Dense ->
      Dropout) -> residual add.

    ``call`` returns ``(output, attention_weights)``, where the second element
    is whatever the attention sub-layer returns as its second output.

    Args:
        num_heads: number of attention heads.
        mlp_dim: width of the hidden layer of the MLP sub-block.
        dropout: dropout rate, used after attention and twice inside the MLP.
        name: base name used to prefix the MLP dense-layer names. It is not
            forwarded to ``tf.keras.layers.Layer.__init__``. It is stored as
            ``core_name`` and written back as ``"name"`` by ``get_config``.
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            quantization settings, forwarded unchanged to the quantized
            sub-layers.
        weight_initializers: kernel initializers. The whole list is passed
            to the attention layer (presumably it consumes entries 0-3 —
            confirm in ``QMultiHeadSelfAttention``). Entries 4 and 5 are used
            by the two MLP dense layers. Defaults to six ``"glorot_uniform"``.
        bias_initializers: bias initializers, indexed like
            ``weight_initializers``. Defaults to six ``"zeros"``.
        beta_initializers: LayerNorm beta initializers. Index 0 is used by
            ``LayerNorm_0`` and index 1 by ``LayerNorm_2``. Defaults to two
            ``"zeros"``.
        gamma_initializers: LayerNorm gamma initializers, indexed like
            ``beta_initializers``. Defaults to two ``"ones"``.
    """

    def __init__(
        self,
        num_heads: int,
        mlp_dim: int,
        dropout: float,
        name: str,
        power_exponent: float = 1.0,
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        weight_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        bias_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        beta_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        gamma_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        *args: Any,
        **kwds: Any,
    ):
        # `name` is not passed to the base class; the Keras layer name is
        # therefore whatever the base class chooses (or what `kwds` carries).
        super().__init__(*args, **kwds)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.core_name = name
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        # Resolve `None` to a fresh list per instance. List literals used as
        # default values would be built once and shared by every layer.
        self.weight_initializers = self._own_list(
            weight_initializers, "glorot_uniform", 6
        )
        self.bias_initializers = self._own_list(bias_initializers, "zeros", 6)
        self.gamma_initializers = self._own_list(gamma_initializers, "ones", 2)
        self.beta_initializers = self._own_list(beta_initializers, "zeros", 2)

    @staticmethod
    def _own_list(values, default_item, count: int) -> list:
        """Return a new list: a copy of ``values``, or ``count`` copies of ``default_item``."""
        if values is None:
            return [default_item] * count
        return list(values)

    def build(self, input_shape):
        """Create all sub-layers.

        The second MLP dense layer projects back to ``input_shape[-1]`` so the
        residual add in ``call`` lines up.
        """
        self.att = QMultiHeadSelfAttention(
            num_heads=self.num_heads,
            name="MultiHeadDotProductAttention_1",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializers=self.weight_initializers,
            bias_initializers=self.bias_initializers,
        )
        # Layer order matters: update_parameter addresses the dense layers
        # by index (0 and 3).
        self.mlpblock = tf.keras.Sequential(
            [
                self._mlp_dense(units=self.mlp_dim, index=0),
                QGeLU(num_bits=self.activation_bits),
                tf.keras.layers.Dropout(self.dropout),
                self._mlp_dense(units=input_shape[-1], index=1),
                tf.keras.layers.Dropout(self.dropout),
            ],
            name="MlpBlock_3",
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_0",
            beta_initializer=self.beta_initializers[0],
            gamma_initializer=self.gamma_initializers[0],
        )
        self.layernorm2 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_2",
            beta_initializer=self.beta_initializers[1],
            gamma_initializer=self.gamma_initializers[1],
        )
        self.dropout_layer = tf.keras.layers.Dropout(self.dropout)

    def _mlp_dense(self, units, index: int):
        """Build MLP dense layer ``Dense_<index>`` using initializer slot ``4 + index``."""
        return QBlockDense(
            units=units,
            name=f"{self.core_name}/Dense_{index}",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            use_bias=True,
            activation="linear",
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializer=self.weight_initializers[4 + index],
            bias_initializer=self.bias_initializers[4 + index],
        )

    def update_parameter(self, param_name: str, new_value: Any):
        """Set ``param_name`` (``"ema"`` or ``"alpha"``) on every quantized sub-layer.

        The value is forwarded to the attention layer and to both MLP dense
        layers. The sub-layers are created in ``build``, so this must be
        called after the block has been built.

        Raises:
            NotImplementedError: if ``param_name`` is neither ``"ema"`` nor
                ``"alpha"``.
        """
        if param_name not in ("ema", "alpha"):
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )
        # mlpblock is [Dense_0, QGeLU, Dropout, Dense_1, Dropout].
        for sublayer in (self.att, self.mlpblock.layers[0], self.mlpblock.layers[3]):
            sublayer.update_parameter(param_name=param_name, new_value=new_value)

    def call(self, inputs, training=None):
        """Apply the block.

        Returns:
            ``(output, weights)``, where ``weights`` is the second value
            returned by the attention sub-layer.
        """
        # Attention sub-block: pre-norm, attention, dropout, residual.
        normed = self.layernorm1(inputs)
        attended, weights = self.att(normed)
        attended = self.dropout_layer(attended, training=training)
        x = attended + inputs
        # MLP sub-block: pre-norm, MLP, residual. Pass `training` explicitly
        # so the Dropout layers inside the Sequential see the same mode.
        y = self.mlpblock(self.layernorm2(x), training=training)
        return x + y, weights

    def get_config(self):
        """Return the constructor arguments; ``"name"`` is the ``core_name``."""
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "dropout": self.dropout,
                "name": self.core_name,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "power_exponent": self.power_exponent,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "weight_initializers": self.weight_initializers,
                "bias_initializers": self.bias_initializers,
                "gamma_initializers": self.gamma_initializers,
                "beta_initializers": self.beta_initializers,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from the dict produced by ``get_config``."""
        return cls(**config)


def get_weight_initializers(
    core_layer: tf.keras.layers.Layer, config: Dict[str, Any]
) -> Tuple[
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
]:
    """
    Wrap the floating-point weights of ``core_layer`` as initializers for the
    quantized network.

    Each variable in ``core_layer.weights`` is classified by a substring test
    on its name. The tests are tried in the order ``"kernel"``, ``"bias"``,
    ``"beta"``, ``"gamma"``. The first match wins, and variables matching none
    are skipped. Within each category the variables keep their original order.

    Args:
        core_layer: float layer whose current variable values (read via
            ``.numpy()``) seed the initializers.
        config: must contain ``"operator"`` and ``"weights bits"`` whenever
            ``core_layer`` has a kernel variable. Both values are forwarded
            to ``QWeightInitializer``, and every kernel is tagged with
            ``layer_type="Dense"``.

    Returns:
        ``(weight_initializers, bias_initializers, gamma_initializers,
        beta_initializers)``. Note that gamma precedes beta here, whereas
        ``QTransformerBlock.__init__`` takes ``beta_initializers`` before
        ``gamma_initializers``.
    """
    kernels = []
    biases = []
    gammas = []
    betas = []

    # (name token, destination list, initializer class, tensor keyword),
    # checked in order after the kernel test.
    simple_kinds = (
        ("bias", biases, QBiasInitializer, "bias_tensor"),
        ("beta", betas, QBetaInitializer, "beta_tensor"),
        ("gamma", gammas, QGammaInitializer, "gamma_tensor"),
    )

    for variable in core_layer.weights:
        var_name = variable.name
        if "kernel" in var_name:
            kernels.append(
                QWeightInitializer(
                    layer_name=core_layer.name,
                    weight_tensor=variable.numpy(),
                    operator=config["operator"],
                    layer_type="Dense",
                    bits=config["weights bits"],
                )
            )
            continue
        for token, bucket, factory, tensor_kw in simple_kinds:
            if token in var_name:
                bucket.append(
                    factory(layer_name=core_layer.name, **{tensor_kw: variable.numpy()})
                )
                break

    return kernels, biases, gammas, betas
