import math
from typing import Any, List, Optional

import tensorflow as tf

from .QDense import QBlockDense
from .QSoftmax import QSoftmax
from .I_Bert.I_Sqrt import compute_i_sqrt


@tf.keras.utils.register_keras_serializable()
class QMHSADeiT(tf.keras.layers.Layer):
    """Multi-head self-attention built from four ``QBlockDense`` projections.

    The query, key, value and output projections are ``QBlockDense`` layers
    that all share the quantization settings given to this layer. The
    attention scores themselves are computed with regular TensorFlow ops
    (``tf.matmul``, ``tf.nn.softmax``).

    Args:
        name: Prefix used to name the four sub-layers (``<name>/query``,
            ``<name>/key``, ``<name>/value``, ``<name>/out``). It is stored
            as ``core_name`` and is *not* forwarded to the Keras base class.
        num_heads: Number of attention heads; must divide the size of the
            last input axis.
        power_exponent: Forwarded unchanged to every ``QBlockDense``.
        activation_bits: Forwarded unchanged to every ``QBlockDense``.
        weight_bits: Forwarded unchanged to every ``QBlockDense``.
        accumulator_bits: Forwarded unchanged to every ``QBlockDense``.
        per_channel: Forwarded unchanged to every ``QBlockDense``.
        deterministic: Forwarded unchanged to every ``QBlockDense``.
        ste_overflow: Forwarded unchanged to every ``QBlockDense``.
        cyclical_alpha: Forwarded unchanged to every ``QBlockDense``.
        weight_initializers: Weight initializers in the order query, key,
            value, out. ``None`` selects four ``"glorot_uniform"`` entries.
        bias_initializers: Bias initializers in the same order. ``None``
            selects four ``"zeros"`` entries.
        *args: Forwarded to ``tf.keras.layers.Layer.__init__``.
        **kwds: Forwarded to ``tf.keras.layers.Layer.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or fewer than four
            initializers are supplied in either list.
    """

    # Parameter names that ``update_parameter`` forwards to the sub-layers.
    _UPDATABLE_PARAMETERS = ("ema", "alpha")
    # One initializer per projection: query, key, value, out.
    _NUM_PROJECTIONS = 4

    def __init__(
        self,
        name: str,
        num_heads: int,
        power_exponent: float = 1.0,
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        weight_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        bias_initializers: Optional[List[tf.keras.initializers.Initializer]] = None,
        *args: Any,
        **kwds: Any,
    ) -> None:
        super().__init__(*args, **kwds)
        if num_heads <= 0:
            raise ValueError(f"num_heads must be a positive integer, got {num_heads}")

        # Build the defaults per instance: list literals in the signature
        # would be shared (and mutable) across every layer ever created.
        if weight_initializers is None:
            weight_initializers = ["glorot_uniform"] * self._NUM_PROJECTIONS
        if bias_initializers is None:
            bias_initializers = ["zeros"] * self._NUM_PROJECTIONS
        weight_initializers = list(weight_initializers)
        bias_initializers = list(bias_initializers)
        for label, values in (
            ("weight_initializers", weight_initializers),
            ("bias_initializers", bias_initializers),
        ):
            if len(values) < self._NUM_PROJECTIONS:
                raise ValueError(
                    f"{label} needs at least {self._NUM_PROJECTIONS} entries "
                    f"(query, key, value, out), got {len(values)}"
                )

        self.num_heads = num_heads
        self.core_name = name
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.weight_initializers = weight_initializers
        self.bias_initializers = bias_initializers
        self.num_attention_heads = self.num_heads

        # Initializers that carry a concrete ``weight_tensor`` let the head
        # size be known before ``build``. Plain initializers (e.g. the default
        # strings) have no such attribute, so the head size is then derived
        # from the input shape in ``build`` instead of failing here.
        query_weight = getattr(self.weight_initializers[0], "weight_tensor", None)
        if query_weight is not None:
            self._set_head_size(int(query_weight.shape[-2] / self.num_heads))
        else:
            self.attention_head_size = None
            self.all_head_size = None
            self.sqrt_att_head_size = None

        # Rate 0.0 makes this a pass-through; kept so attention dropout has a
        # single place to be enabled.
        self.dropout = tf.keras.layers.Dropout(rate=0.0)

    def _set_head_size(self, head_size: int) -> None:
        """Store the per-head width and the quantities derived from it."""
        self.attention_head_size = head_size
        self.all_head_size = self.num_attention_heads * head_size
        self.sqrt_att_head_size = math.sqrt(head_size)

    def _make_projection(self, suffix: str, index: int) -> QBlockDense:
        """Create one ``QBlockDense`` projection named ``<core_name>/<suffix>``.

        ``index`` selects the entry of ``weight_initializers`` and
        ``bias_initializers`` used for this projection.
        """
        return QBlockDense(
            units=self.hidden_size,
            name=f"{self.core_name}/{suffix}",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            use_bias=True,
            activation="linear",
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializer=self.weight_initializers[index],
            bias_initializer=self.bias_initializers[index],
        )

    def build(self, input_shape: tuple):
        """Create the four projection layers for the given input shape.

        Raises:
            ValueError: If the last input dimension is not divisible by
                ``num_heads``.
        """
        hidden_size = input_shape[-1]
        num_heads = self.num_heads
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {hidden_size} should be divisible by number of heads = {num_heads}"
            )
        self.hidden_size = hidden_size
        self.projection_dim = hidden_size // num_heads
        if self.attention_head_size is None:
            # Not derivable from the initializers in __init__; use the input.
            self._set_head_size(self.projection_dim)

        self.query_dense = self._make_projection("query", 0)
        self.key_dense = self._make_projection("key", 1)
        self.value_dense = self._make_projection("value", 2)
        self.combine_heads = self._make_projection("out", 3)
        # NOTE: attention probabilities use the float ``tf.nn.softmax`` in
        # ``call``; the imported ``QSoftmax`` is not wired into this layer.
        super().build(input_shape)

    def update_parameter(self, param_name: str, new_value: Any):
        """Forward a parameter update to all four projection layers.

        Args:
            param_name: Either ``"ema"`` or ``"alpha"``.
            new_value: Value handed to each sub-layer's ``update_parameter``.

        Raises:
            NotImplementedError: If ``param_name`` is not one of the
                supported names.
        """
        if param_name not in self._UPDATABLE_PARAMETERS:
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )
        for projection in (
            self.query_dense,
            self.key_dense,
            self.value_dense,
            self.combine_heads,
        ):
            projection.update_parameter(param_name=param_name, new_value=new_value)

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        """Split the last axis into heads and move the head axis forward.

        Reshapes ``tensor`` to
        ``(batch_size, -1, num_attention_heads, attention_head_size)`` and
        returns it transposed to ``(batch, heads, positions, head_size)``.
        """
        tensor = tf.reshape(
            tensor=tensor,
            shape=(
                batch_size,
                -1,
                self.num_attention_heads,
                self.attention_head_size,
            ),
        )
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(self, inputs: tf.Tensor, training=None, *args: Any, **kwds: Any) -> Any:
        """Apply scaled dot-product self-attention to ``inputs``.

        Args:
            inputs: Input tensor; the leading axis is used as the batch size
                and the last axis must match the size seen in ``build``
                (presumably ``(batch, tokens, hidden)`` - TODO confirm).
            training: Optional training flag. When ``None`` the Keras
                learning phase is used.

        Returns:
            A tuple ``(attention_output, attention_probs)``:
            ``attention_output`` is the output projection of the merged heads
            with shape ``(batch, -1, all_head_size)`` before projection, and
            ``attention_probs`` is the softmax over the last axis of the
            scaled query-key scores, one matrix per head.
        """
        if training is not None:
            training = tf.cast(x=training, dtype=tf.bool)
        else:
            training = tf.cast(x=tf.keras.backend.learning_phase(), dtype=tf.bool)
        batch_size = tf.shape(inputs)[0]

        mixed_query_layer = self.query_dense(inputs, training)
        mixed_key_layer = self.key_dense(inputs, training)
        mixed_value_layer = self.value_dense(inputs, training)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Scores are scaled by sqrt(head_size) before the softmax.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
        attention_probs = self.dropout(inputs=attention_probs, training=training)
        attention_output = tf.matmul(attention_probs, value_layer)
        # Undo transpose_for_scores: back to (batch, positions, heads, size),
        # then merge the two trailing axes.
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(
            tensor=attention_output, shape=(batch_size, -1, self.all_head_size)
        )
        # Pass the training flag explicitly, exactly as for the query/key/value
        # projections, so all four QBlockDense layers see the same value.
        attention_output = self.combine_heads(attention_output, training)
        return attention_output, attention_probs

    def get_config(self):
        """Return the constructor arguments needed by ``from_config``."""
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                # ``name`` is a required constructor argument holding the
                # sub-layer prefix, so it replaces the Keras layer name here.
                "name": self.core_name,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "power_exponent": self.power_exponent,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                # Copies, so editing the config cannot alter this layer.
                "weight_initializers": list(self.weight_initializers),
                "bias_initializers": list(self.bias_initializers),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a layer from the dictionary produced by ``get_config``."""
        return cls(**config)
