from .QDense import QBlockDense
import tensorflow as tf
from typing import Any, Dict, List


@tf.keras.utils.register_keras_serializable()
class QTalkingHeadAtt(tf.keras.layers.Layer):
    """Quantized talking-heads multi-head self-attention.

    The input is projected to queries, keys and values by one
    ``QBlockDense`` (``qkv``) and split into ``num_heads`` heads of width
    ``projection_dim // num_heads``. Scaled dot-product logits are computed
    per head. Two extra ``QBlockDense`` layers then mix information across
    the head axis: ``proj_l`` acts on the logits before the softmax, and
    ``proj_w`` acts on the weights after it. The weighted values are merged
    back into a ``projection_dim``-wide feature and passed through ``proj``.

    Args:
        num_heads: Number of attention heads. Must be positive and divide
            ``projection_dim``.
        projection_dim: Total width of the q/k/v projections and of the
            tensor fed to the output projection.
        name: Prefix for the sub-layer names (``"{name}/qkv"`` etc.). It is
            consumed here and NOT forwarded to the base ``Layer``, so the
            Keras layer name itself is auto-generated.
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            Forwarded unchanged to every ``QBlockDense`` sub-layer.
        weight_initializers: Four kernel initializers, in order for
            ``qkv``, ``proj``, ``proj_l`` and ``proj_w``.
        bias_initializers: Four bias initializers, in the same order.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``projection_dim``.
    """

    def __init__(
        self,
        num_heads: int,
        projection_dim: int,
        name: str,
        power_exponent: float,
        activation_bits: int,
        weight_bits: int,
        accumulator_bits: int,
        per_channel: bool,
        deterministic: bool,
        ste_overflow: bool,
        cyclical_alpha: float,
        weight_initializers: List[tf.keras.initializers.Initializer],
        bias_initializers: List[tf.keras.initializers.Initializer],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Fail early: a non-divisible width can never be reshaped into heads
        # in ``call``.
        if num_heads <= 0 or projection_dim % num_heads != 0:
            raise ValueError(
                "num_heads must be positive and divide projection_dim; got "
                f"num_heads={num_heads}, projection_dim={projection_dim}."
            )

        self.num_heads = num_heads
        self.projection_dim = projection_dim
        self.core_name = name
        self.power_exponent = power_exponent
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.weight_initializers = weight_initializers
        self.bias_initializers = bias_initializers

        # Width of a single head. Both the reshapes in ``call`` and the logit
        # scale are derived from it, so they always agree.
        self.head_dim = projection_dim // num_heads
        # 1 / sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim**-0.5

    def _make_dense(self, units: int, suffix: str, index: int) -> "QBlockDense":
        """Build one ``QBlockDense`` with this layer's shared quantization settings.

        ``index`` selects the matching entry of ``weight_initializers`` and
        ``bias_initializers``.
        """
        return QBlockDense(
            units=units,
            name=f"{self.core_name}/{suffix}",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            use_bias=True,
            activation="linear",
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializer=self.weight_initializers[index],
            bias_initializer=self.bias_initializers[index],
        )

    def build(self, input_shape: tuple) -> None:
        """Create the four quantized projections used by ``call``.

        ``input_shape`` is not inspected here; it is only passed on to the
        base ``Layer.build``. Creation order (qkv, proj, proj_l, proj_w) is
        kept stable so weight ordering does not change.
        """
        # Joint query/key/value projection: 3 * projection_dim outputs.
        self.qkv = self._make_dense(self.projection_dim * 3, "qkv", 0)
        # Output projection applied after the heads are merged back.
        self.proj = self._make_dense(self.projection_dim, "proj", 1)
        # Talking-heads mixers over the head axis (num_heads -> num_heads),
        # applied before (proj_l) and after (proj_w) the softmax.
        self.proj_l = self._make_dense(self.num_heads, "proj_l", 2)
        self.proj_w = self._make_dense(self.num_heads, "proj_w", 3)
        super().build(input_shape)

    @staticmethod
    def _mix_heads(layer: tf.keras.layers.Layer, attn: tf.Tensor) -> tf.Tensor:
        """Apply ``layer`` across the head axis (axis 1) of ``attn``.

        Heads are moved to the last axis so ``layer`` sees them as features,
        then moved back to axis 1. Assumes ``layer`` acts on the last axis
        like a Dense layer (TODO confirm for QBlockDense).
        """
        mixed = layer(tf.transpose(attn, perm=[0, 2, 3, 1]))
        return tf.transpose(mixed, perm=[0, 3, 1, 2])

    def call(
        self, inputs: tf.Tensor, training=None, *args: Any, **kwargs: Any
    ) -> Any:
        """Run talking-heads attention over ``inputs``.

        Args:
            inputs: Tensor expected to be (batch, tokens, channels). The
                channel count may differ from ``projection_dim``.
            training: Accepted for the Keras ``call`` signature and not used
                directly; Keras propagates it to the nested sub-layer calls.

        Returns:
            Tuple ``(x, attn)``:
            - ``x`` is the output of ``proj``.
            - ``attn`` is the attention-weight tensor after ``proj_w`` mixing,
              laid out as (batch, heads, tokens, tokens).
        """
        shape = tf.shape(inputs)
        batch, tokens = shape[0], shape[1]

        qkv = self.qkv(inputs)
        # (B, N, 3*P) -> (B, N, 3, H, D) -> (3, B, H, N, D).
        # Uses head_dim (from projection_dim), not the input channel count.
        qkv = tf.reshape(qkv, (batch, tokens, 3, self.num_heads, self.head_dim))
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])
        scale = tf.cast(self.scale, dtype=qkv.dtype)
        q, k, v = qkv[0] * scale, qkv[1], qkv[2]

        # Per-head logits: (B, H, N, N).
        attn = tf.matmul(q, k, transpose_b=True)
        attn = self._mix_heads(self.proj_l, attn)
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self._mix_heads(self.proj_w, attn)

        # Weighted values (B, H, N, D) -> (B, N, H, D) -> (B, N, P).
        x = tf.matmul(attn, v)
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        x = tf.reshape(x, (batch, tokens, self.projection_dim))

        x = self.proj(x)

        return (x, attn)

    @staticmethod
    def _serialize_initializers(initializers: List[Any]) -> List[Any]:
        """Serialize a list of initializers into Keras config form."""
        # Calling ``get`` first also accepts string identifiers
        # (e.g. "zeros"), not only Initializer instances.
        return [
            tf.keras.initializers.serialize(tf.keras.initializers.get(init))
            for init in initializers
        ]

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return the layer config with initializers in serialized form.

        Pair this with ``from_config`` to rebuild the layer. The ``"name"``
        entry holds the sub-layer prefix (``core_name``) and overrides the
        base Layer's name.
        """
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "projection_dim": self.projection_dim,
                "name": self.core_name,
                "power_exponent": self.power_exponent,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "weight_initializers": self._serialize_initializers(
                    self.weight_initializers
                ),
                "bias_initializers": self._serialize_initializers(
                    self.bias_initializers
                ),
            }
        )

        return config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QTalkingHeadAtt":
        """Rebuild the layer from ``get_config`` output.

        Serialized initializer entries are turned back into ``Initializer``
        objects via ``tf.keras.initializers.get``. Entries that are already
        instances pass through unchanged, so older configs still load.
        """
        config = dict(config)
        for key in ("weight_initializers", "bias_initializers"):
            if key in config:
                config[key] = [
                    tf.keras.initializers.get(init) for init in config[key]
                ]
        return super().from_config(config)
