from typing import Any, List, Optional

import tensorflow as tf

from .I_Bert.I_Sqrt import compute_i_sqrt
from .QDense import QBlockDense
from .QSoftmax import QSoftmax


@tf.keras.utils.register_keras_serializable()
class QMultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention whose projections are ``QBlockDense`` layers.

    The input is projected to queries, keys and values by three
    ``QBlockDense`` layers, split into ``num_heads`` heads, combined with a
    softmax-weighted dot-product attention, merged back together and passed
    through a fourth ``QBlockDense`` output projection.

    NOTE: the attention weights are computed with a regular float softmax;
    the quantized ``QSoftmax`` layer is currently not used here.

    Args:
        name: Prefix used to name the four sub-layers (``{name}/query``,
            ``{name}/key``, ``{name}/value``, ``{name}/out``). It is stored as
            ``core_name`` and is not forwarded to the base ``Layer``
            constructor.
        num_heads: Number of attention heads; must divide the size of the
            last input dimension (checked in ``build``).
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            Forwarded unchanged to every ``QBlockDense``; see that class for
            their meaning.
        weight_initializers: Kernel initializers for the query, key, value
            and output projections, in that order. ``None`` selects
            ``"glorot_uniform"`` for all four. Entries beyond the fourth are
            ignored.
        bias_initializers: Bias initializers, in the same order. ``None``
            selects ``"zeros"`` for all four.
        *args, **kwds: Forwarded to ``tf.keras.layers.Layer.__init__``.

    Raises:
        ValueError: If fewer than four weight or bias initializers are given.
    """

    # One initializer is needed per projection: query, key, value, output.
    _NUM_PROJECTIONS = 4

    def __init__(
        self,
        name: str,
        num_heads: int,
        power_exponent: float = 1.0,
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        weight_initializers: Optional[List[Any]] = None,
        bias_initializers: Optional[List[Any]] = None,
        *args: Any,
        **kwds: Any,
    ) -> None:
        super().__init__(*args, **kwds)
        self.num_heads = num_heads
        self.core_name = name
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        # Each instance gets its own list so that no state is shared through
        # a default argument or through a caller-owned sequence.
        self.weight_initializers = self._resolve_initializers(
            weight_initializers, "glorot_uniform", "weight_initializers"
        )
        self.bias_initializers = self._resolve_initializers(
            bias_initializers, "zeros", "bias_initializers"
        )

    @classmethod
    def _resolve_initializers(cls, initializers, default, arg_name):
        """Return a fresh list of per-projection initializers.

        ``None`` expands to ``default`` repeated once per projection; any
        other value is copied into a new list and checked for length.
        """
        if initializers is None:
            return [default] * cls._NUM_PROJECTIONS
        resolved = list(initializers)
        if len(resolved) < cls._NUM_PROJECTIONS:
            raise ValueError(
                f"{arg_name} must provide {cls._NUM_PROJECTIONS} initializers "
                f"(query, key, value, out), got {len(resolved)}"
            )
        return resolved

    def _make_projection(self, suffix: str, index: int):
        """Create one ``QBlockDense`` projection named ``{core_name}/{suffix}``.

        ``index`` selects the weight/bias initializer pair for the projection.
        """
        return QBlockDense(
            units=self.hidden_size,
            name=f"{self.core_name}/{suffix}",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            use_bias=True,
            activation="linear",
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializer=self.weight_initializers[index],
            bias_initializer=self.bias_initializers[index],
        )

    def build(self, input_shape: tuple):
        """Create the query, key, value and output projections.

        Args:
            input_shape: Shape of the layer input; its last entry is used as
                the hidden size of all four projections.

        Raises:
            ValueError: If the last input dimension is unknown or is not
                divisible by ``num_heads``.
        """
        hidden_size = input_shape[-1]
        num_heads = self.num_heads
        if hidden_size is None:
            raise ValueError(
                "the last dimension of the input must be defined to build "
                "QMultiHeadSelfAttention, got None"
            )
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {hidden_size} should be divisible by number of heads = {num_heads}"
            )
        self.hidden_size = hidden_size
        self.projection_dim = hidden_size // num_heads
        self.query_dense = self._make_projection("query", 0)
        self.key_dense = self._make_projection("key", 1)
        self.value_dense = self._make_projection("value", 2)
        self.combine_heads = self._make_projection("out", 3)
        # Marks the layer as built so a direct ``layer.build(...)`` call is
        # not followed by a second build on the first ``__call__``.
        super().build(input_shape)

    def update_parameter(self, param_name: str, new_value: Any):
        """Forward a parameter update to all four projections.

        Args:
            param_name: Either ``"ema"`` or ``"alpha"``.
            new_value: Value passed on to each projection's
                ``update_parameter``.

        Raises:
            NotImplementedError: If ``param_name`` is not ``"ema"`` or
                ``"alpha"``; no projection is touched in that case.
        """
        if param_name not in ("ema", "alpha"):
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )
        projections = (
            self.query_dense,
            self.key_dense,
            self.value_dense,
            self.combine_heads,
        )
        for projection in projections:
            projection.update_parameter(param_name=param_name, new_value=new_value)

    def attention(self, query, key, value):
        """Compute softmax-weighted dot-product attention.

        The raw scores ``query @ key^T`` are divided by
        ``compute_i_sqrt(depth)``, where ``depth`` is the size of the last
        axis of ``key`` (presumably an I-BERT style square root -- confirm
        against ``compute_i_sqrt``), then normalized with a softmax over the
        last axis.

        Returns:
            A tuple ``(output, weights)`` with the attended values and the
            attention weights.
        """
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / compute_i_sqrt(dim_key)
        # Stateless op: avoids creating a new Keras Softmax layer per call.
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        """Split the last axis of ``x`` into heads.

        ``x`` is reshaped to ``(batch_size, -1, num_heads, projection_dim)``
        and transposed so that the head axis comes right after the batch axis.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs: tf.Tensor, training=None, *args: Any, **kwds: Any) -> Any:
        """Apply self-attention to ``inputs``.

        Args:
            inputs: Input tensor; the first axis is treated as the batch and
                the last axis must have size ``hidden_size`` (presumably
                shaped ``(batch, seq_len, hidden_size)`` -- confirm with
                callers).
            training: Training flag passed to every projection. When ``None``
                the Keras learning phase is used instead.

        Returns:
            A tuple ``(output, weights)``: the output of the final projection
            and the attention weights returned by ``attention``.
        """
        if training is not None:
            training = tf.cast(x=training, dtype=tf.bool)
        else:
            training = tf.cast(x=tf.keras.backend.learning_phase(), dtype=tf.bool)
        batch_size = tf.shape(inputs)[0]
        query = self.separate_heads(self.query_dense(inputs, training), batch_size)
        key = self.separate_heads(self.key_dense(inputs, training), batch_size)
        value = self.separate_heads(self.value_dense(inputs, training), batch_size)

        attention, weights = self.attention(query, key, value)
        # Undo the head split: move heads back next to the feature axis and
        # merge them into a single axis of size ``hidden_size``.
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (batch_size, -1, self.hidden_size))
        output = self.combine_heads(concat_attention, training)
        return output, weights

    def get_config(self):
        """Return the constructor arguments needed to recreate the layer.

        ``name`` is set to ``core_name`` so that ``from_config`` passes the
        same prefix back to ``__init__``. The initializer lists are copied so
        that editing the returned config cannot alter this layer.
        """
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "name": self.core_name,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "power_exponent": self.power_exponent,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "weight_initializers": list(self.weight_initializers),
                "bias_initializers": list(self.bias_initializers),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from a dictionary produced by ``get_config``."""
        return cls(**config)
