from .QDense import QBlockDense
import tensorflow as tf
from typing import Any, Dict, List


@tf.keras.utils.register_keras_serializable()
class QClassAttn(tf.keras.layers.Layer):
    """Quantized class-attention layer.

    The embedding at token position 0 of the input (the class/cls token) is
    projected to the query. All tokens, including position 0, are projected to
    keys and values. The attended result is passed through a final quantized
    projection. Every projection is a ``QBlockDense`` with ``projection_dim``
    units, built from this layer's quantization settings.

    Args:
        num_heads: Number of attention heads. Must divide ``projection_dim``.
        projection_dim: Width of the q/k/v/proj projections.
        name: Base name. It prefixes the sublayer names (``{name}/q`` and so
            on) and is stored in the config. It is NOT passed to the Keras
            base ``Layer``, which therefore picks its own layer name.
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            Forwarded unchanged to every ``QBlockDense`` sublayer.
        weight_initializers: Four initializers, used in order for the q, k, v
            and proj sublayers.
        bias_initializers: Four initializers, used in order for the q, k, v
            and proj sublayers.

    Raises:
        ValueError: If ``projection_dim`` is not a positive multiple of a
            positive ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        projection_dim: int,
        name: str,
        power_exponent: float,
        activation_bits: int,
        weight_bits: int,
        accumulator_bits: int,
        per_channel: bool,
        deterministic: bool,
        ste_overflow: bool,
        cyclical_alpha: float,
        weight_initializers: List[tf.keras.initializers.Initializer],
        bias_initializers: List[tf.keras.initializers.Initializer],
        *args,
        **kwargs,
    ) -> None:
        # Fail early with a clear message. The old code failed later, either
        # with a ZeroDivisionError (head_dim == 0) or an opaque reshape error.
        if num_heads <= 0 or projection_dim <= 0 or projection_dim % num_heads != 0:
            raise ValueError(
                f"projection_dim ({projection_dim}) must be a positive multiple "
                f"of num_heads ({num_heads})."
            )

        super().__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.core_name = name
        self.power_exponent = power_exponent
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.weight_initializers = weight_initializers
        self.bias_initializers = bias_initializers
        self.projection_dim = projection_dim
        # Per-head width. Every head reshape in `call` uses it, and so does
        # the 1/sqrt(head_dim) attention scaling.
        self.head_dim = projection_dim // num_heads
        self.scale = self.head_dim**-0.5

    def _make_dense(self, suffix: str, index: int) -> QBlockDense:
        """Create one quantized projection sublayer named ``{core_name}/{suffix}``.

        ``index`` selects the entry of ``weight_initializers`` and
        ``bias_initializers`` to use.
        """
        return QBlockDense(
            units=self.projection_dim,
            name=f"{self.core_name}/{suffix}",
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            use_bias=True,
            activation="linear",
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
            weight_initializer=self.weight_initializers[index],
            bias_initializer=self.bias_initializers[index],
        )

    def build(self, input_shape: tuple):
        """Create the q, k, v and proj sublayers.

        Their weights are created lazily on first call.
        """
        # Attribute names and creation order match the previous version, so
        # weight tracking and checkpoint keys are unchanged.
        self.q = self._make_dense("q", 0)
        self.k = self._make_dense("k", 1)
        self.v = self._make_dense("v", 2)
        self.proj = self._make_dense("proj", 3)
        super().build(input_shape)

    def _split_heads(self, x: tf.Tensor, batch_size, num_tokens) -> tf.Tensor:
        """Reshape (batch, tokens, projection_dim) to (batch, heads, tokens, head_dim)."""
        x = tf.reshape(x, (batch_size, num_tokens, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(
        self, inputs: tf.Tensor, training=None, *args: Any, **kwargs: Any
    ) -> Any:
        """Attend from the position-0 token to all tokens.

        Args:
            inputs: Token embeddings. Assumed to be rank 3, i.e.
                (batch, num_tokens, channels); TODO confirm against callers.
            training: Accepted for the Keras call convention. It is not read
                here; the code relies on Keras propagating the outer training
                flag to the nested sublayers.

        Returns:
            A tuple ``(x_cls, attn)``:
            - ``x_cls`` has shape (batch, 1, projection_dim) and is the
              projected attention output.
            - ``attn`` has shape (batch, num_heads, 1, num_tokens) and holds
              the softmax attention weights.
        """
        shape = tf.shape(inputs)
        batch_size, num_tokens = shape[0], shape[1]

        # Query projection. Only the position-0 (cls) embedding is a query.
        q = tf.expand_dims(self.q(inputs[:, 0]), axis=1)
        q = self._split_heads(q, batch_size, 1)
        q = q * tf.cast(self.scale, dtype=q.dtype)

        # Key and value projections. The cls token and the patch tokens all
        # serve as keys and as values.
        k = self._split_heads(self.k(inputs), batch_size, num_tokens)
        v = self._split_heads(self.v(inputs), batch_size, num_tokens)

        # Attention of the cls query over every token.
        attn = tf.matmul(q, k, transpose_b=True)
        attn = tf.nn.softmax(attn, axis=-1)

        # Merge the heads back to projection_dim. The old code reshaped to the
        # input width C here, which broke whenever C != projection_dim.
        x_cls = tf.matmul(attn, v)
        x_cls = tf.transpose(x_cls, perm=[0, 2, 1, 3])
        x_cls = tf.reshape(x_cls, (batch_size, 1, self.projection_dim))
        x_cls = self.proj(x_cls)

        return x_cls, attn

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this layer.

        ``"name"`` is set to the base name given to ``__init__``, overriding
        the Keras-generated layer name from the base config.
        """
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "projection_dim": self.projection_dim,
                "name": self.core_name,
                "power_exponent": self.power_exponent,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "weight_initializers": self.weight_initializers,
                "bias_initializers": self.bias_initializers,
            }
        )

        return config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QClassAttn":
        """Rebuild the layer from ``get_config`` output.

        After a JSON/.keras round-trip the initializer lists contain serialized
        dicts. ``tf.keras.initializers.get`` turns those dicts back into
        initializer objects and returns existing instances or strings as-is.
        """
        config = dict(config)
        for key in ("weight_initializers", "bias_initializers"):
            initializers = config.get(key)
            if initializers is not None:
                config[key] = [tf.keras.initializers.get(init) for init in initializers]
        return cls(**config)
