from typing import Any, List, Tuple, Dict
import tensorflow as tf
from .QMHSADeiT import QMHSADeiT
from .QDense import QBlockDense
from .QGeLU import QGeLU
from .QInitializers import (
    QWeightInitializer,
    QBiasInitializer,
    QBetaInitializer,
    QGammaInitializer,
)
from .QStochasticDepth import StochasticDepth


@tf.keras.utils.register_keras_serializable()
class QDeiTBlock(tf.keras.layers.Layer):
    """
    Implements a quantized DeiT encoder block.

    The block chains: layer norm -> QMHSADeiT attention -> dropout ->
    residual add -> layer norm -> MLP (QBlockDense, QGeLU, dropout,
    QBlockDense, dropout) -> residual add.

    see https://github.com/sayakpaul/deit-tf/tree/79cc91d3cb497f7abe5111fb536968fb9d9a754d

    Args:
        num_heads: forwarded to the attention layer.
        mlp_dim: number of units of the first dense layer of the MLP.
        dropout: rate used by every dropout layer of the block.
        name: stored as ``core_name`` and used as prefix for the names of
            the two MLP dense layers. It is not forwarded to the Keras base
            class constructor.
        power_exponent, activation_bits, weight_bits, accumulator_bits,
        per_channel, deterministic, ste_overflow, cyclical_alpha:
            quantization settings forwarded unchanged to the attention
            layer and to both MLP dense layers.
        drop_prop: forwarded to the two StochasticDepth layers.
        weight_initializers: six kernel initializers. The full list is
            handed to the attention layer; entries 4 and 5 are used by the
            first and second MLP dense layers.
        bias_initializers: six bias initializers, indexed like
            ``weight_initializers``.
        beta_initializers: beta initializers of the first and second layer
            normalization.
        gamma_initializers: gamma initializers of the first and second
            layer normalization.
    """

    def __init__(
        self,
        num_heads: int,
        mlp_dim: int,
        dropout: float,
        name: str,
        power_exponent: float = 1.0,
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        drop_prop: float = 0.0,
        weight_initializers: List[tf.keras.initializers.Initializer] = [
            "glorot_uniform",
            "glorot_uniform",
            "glorot_uniform",
            "glorot_uniform",
            "glorot_uniform",
            "glorot_uniform",
        ],
        bias_initializers: List[tf.keras.initializers.Initializer] = [
            "zeros",
            "zeros",
            "zeros",
            "zeros",
            "zeros",
            "zeros",
        ],
        beta_initializers: List[tf.keras.initializers.Initializer] = [
            "zeros",
            "zeros",
        ],
        gamma_initializers: List[tf.keras.initializers.Initializer] = [
            "ones",
            "ones",
        ],
        *args: Any,
        **kwds: Any,
    ):
        super().__init__(*args, **kwds)
        self.drop_prop = drop_prop
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.core_name = name
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        # The defaults are list objects created once at definition time.
        # Store copies so that no instance ever holds (and possibly mutates)
        # the list shared by every other instance built with the defaults.
        self.weight_initializers = list(weight_initializers)
        self.bias_initializers = list(bias_initializers)
        self.gamma_initializers = list(gamma_initializers)
        self.beta_initializers = list(beta_initializers)

    def _quantization_kwargs(self):
        """Returns the quantization settings shared by all quantized sub-layers."""
        return dict(
            power_exponent=self.power_exponent,
            activation_bits=self.activation_bits,
            weight_bits=self.weight_bits,
            accumulator_bits=self.accumulator_bits,
            per_channel=self.per_channel,
            deterministic=self.deterministic,
            ste_overflow=self.ste_overflow,
            cyclical_alpha=self.cyclical_alpha,
        )

    def _mlp_dense(self, units, suffix: str, initializer_index: int):
        """Creates one of the two linear QBlockDense layers of the MLP."""
        return QBlockDense(
            units=units,
            name=f"{self.core_name}/{suffix}",
            use_bias=True,
            activation="linear",
            weight_initializer=self.weight_initializers[initializer_index],
            bias_initializer=self.bias_initializers[initializer_index],
            **self._quantization_kwargs(),
        )

    def build(self, input_shape):
        """
        Creates the sub-layers. The last dimension of ``input_shape`` sets
        the width of the second MLP dense layer so that its output can be
        added back to the residual branch.
        """
        self.att = QMHSADeiT(
            num_heads=self.num_heads,
            name="MultiHeadDotProductAttention_1",
            weight_initializers=self.weight_initializers,
            bias_initializers=self.bias_initializers,
            **self._quantization_kwargs(),
        )
        # NOTE: update_parameter relies on the dense layers sitting at
        # positions 0 and 3 of this sequential model.
        self.mlpblock = tf.keras.Sequential(
            [
                self._mlp_dense(self.mlp_dim, "Dense_0", 4),
                QGeLU(num_bits=self.activation_bits),
                tf.keras.layers.Dropout(self.dropout),
                self._mlp_dense(input_shape[-1], "Dense_1", 5),
                tf.keras.layers.Dropout(self.dropout),
            ],
            name="MlpBlock_3",
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_0",
            beta_initializer=self.beta_initializers[0],
            gamma_initializer=self.gamma_initializers[0],
        )
        self.layernorm2 = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="LayerNorm_2",
            beta_initializer=self.beta_initializers[1],
            gamma_initializer=self.gamma_initializers[1],
        )
        self.dropout_layer = tf.keras.layers.Dropout(self.dropout)
        # NOTE(review): these two layers are created but never applied in
        # `call`, so `drop_prop` currently has no effect on the forward pass.
        # Confirm whether stochastic depth should wrap the two residual
        # branches before wiring them in.
        self.StochasticDepth1 = StochasticDepth(drop_prop=self.drop_prop)
        self.StochasticDepth2 = StochasticDepth(drop_prop=self.drop_prop)
        super().build(input_shape)

    def update_parameter(self, param_name: str, new_value: Any):
        """
        changes the value of a parameter given its name with the new
        provided value

        The update is forwarded to the attention layer and to both dense
        layers of the MLP. The layer must have been built beforehand.

        Raises:
            NotImplementedError: if ``param_name`` is neither 'ema' nor
                'alpha'.
        """
        if param_name not in ("ema", "alpha"):
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )
        targets = (
            self.att,
            self.mlpblock.layers[0],
            self.mlpblock.layers[3],
        )
        for target in targets:
            target.update_parameter(param_name=param_name, new_value=new_value)

    def call(self, inputs, training=None):
        """
        Runs the block on ``inputs``.

        Args:
            inputs: input tensor; its last dimension must match the one
                seen in ``build``.
            training: forwarded to the dropout layers. ``None`` lets Keras
                decide.

        Returns:
            A pair ``(output, weights)`` where ``weights`` is the second
            value returned by the attention layer.
        """
        x = self.layernorm1(inputs)
        x, weights = self.att(x)
        x = self.dropout_layer(x, training=training)
        x = x + inputs
        y = self.layernorm2(x)
        # forward `training` explicitly so the dropout layers inside the MLP
        # follow the same mode as `dropout_layer` above
        y = self.mlpblock(y, training=training)
        return x + y, weights

    def get_config(self):
        """Returns the constructor arguments needed by ``from_config``."""
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "dropout": self.dropout,
                # the constructor expects the core name, not the Keras
                # layer name written by the base class
                "name": self.core_name,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "per_channel": self.per_channel,
                "power_exponent": self.power_exponent,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "weight_initializers": self.weight_initializers,
                "bias_initializers": self.bias_initializers,
                "gamma_initializers": self.gamma_initializers,
                "beta_initializers": self.beta_initializers,
                "drop_prop": self.drop_prop,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuilds the layer from a dictionary produced by ``get_config``."""
        return cls(**config)


def _deit_kernel_initializer(layer_name: str, kernel, config: Dict[str, Any]):
    """Wraps a floating point kernel in a QWeightInitializer for a Dense layer."""
    return QWeightInitializer(
        layer_name=layer_name,
        weight_tensor=kernel,
        operator=config["operator"],
        layer_type="Dense",
        bits=config["weights bits"],
    )


def get_deit_weight_initializers(
    core_layer: tf.keras.layers.Layer,
    config: Dict[str, Any],
) -> Tuple[
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
    List[tf.keras.initializers.Initializer],
]:
    """
    get floating point values from original network to use as init for quantized network
    see https://github.com/sayakpaul/deit-tf/blob/79cc91d3cb497f7abe5111fb536968fb9d9a754d/vit/layers/mha.py#L99
    for special case on attention weights

    Args:
        core_layer: block of the original network; its ``layers`` are
            visited in order.
        config: must provide the keys ``"operator"`` and ``"weights bits"``
            (only read when at least one kernel is found).

    Returns:
        Four lists, in this order: weight, bias, gamma and beta
        initializers. Note that gamma comes before beta.
        For a sub-layer whose name contains ``tf_vi_t_attention`` the
        entries are appended in the order query, key, value, output.
    """
    weight_initializers = []
    bias_initializers = []
    gamma_initializers = []
    beta_initializers = []

    for sub_layer in core_layer.layers:
        if not hasattr(sub_layer, "weights"):
            continue

        if "tf_vi_t_attention" in sub_layer.name:
            attention = sub_layer.self_attention
            projections = (
                ("query", attention.query),
                ("key", attention.key),
                ("value", attention.value),
                ("output", sub_layer.dense_output.dense),
            )
            for suffix, dense in projections:
                # weights[0] is used as the kernel and weights[1] as the bias
                kernel, bias = dense.weights[0], dense.weights[1]
                initializer_name = sub_layer.name + suffix
                weight_initializers.append(
                    _deit_kernel_initializer(initializer_name, kernel.numpy(), config)
                )
                # Hand over a value copy, as done for the kernels and for
                # the biases of the other layers, instead of the live
                # variable of the original network.
                bias_initializers.append(
                    QBiasInitializer(
                        layer_name=initializer_name,
                        bias_tensor=bias.numpy(),
                    )
                )
            continue

        # NOTE(review): every initializer built below is labelled with the
        # name of `core_layer`, not of the sub-layer owning the weight.
        # Kept as is; confirm this is intended.
        for weight in sub_layer.weights:
            # the kind of weight is recognised from its variable name
            if "kernel" in weight.name:
                weight_initializers.append(
                    _deit_kernel_initializer(core_layer.name, weight.numpy(), config)
                )
            elif "bias" in weight.name:
                bias_initializers.append(
                    QBiasInitializer(
                        layer_name=core_layer.name, bias_tensor=weight.numpy()
                    )
                )
            elif "beta" in weight.name:
                beta_initializers.append(
                    QBetaInitializer(
                        layer_name=core_layer.name, beta_tensor=weight.numpy()
                    )
                )
            elif "gamma" in weight.name:
                gamma_initializers.append(
                    QGammaInitializer(
                        layer_name=core_layer.name, gamma_tensor=weight.numpy()
                    )
                )
    return weight_initializers, bias_initializers, gamma_initializers, beta_initializers
