"""
 Copyright 2023 [Anonymized]
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """

from typing import Any, Dict, Union, Callable

import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras import layers
from tensorflow.keras.layers import Layer

from tensorflow_similarity.layers import GeneralizedMeanPooling1D


@tf.keras.utils.register_keras_serializable(package="retsim")
class SequencePooling(Layer):
    """Learnable, attention-weighted pooling over a sequence.

    A single-unit dense layer scores every position, the scores are
    normalized with a softmax over axis 1, and the inputs are combined
    using those weights.

    Adapted from:
    https://github.com/SHI-Labs/Compact-Transformers
    """

    def __init__(self, activation: Union[str, None] = None, **kwargs) -> None:
        """Initialize a SequencePooling layer.

        Args:
            activation: Activation used by the scoring dense layer, applied
                before the softmax. Defaults to None.

            **kwargs: Standard Keras `Layer` keyword arguments
                (e.g. `name`, `dtype`, `trainable`).
        """
        # Forward kwargs to the base class so `name`, `dtype`, etc. are
        # honored and can be restored from `get_config()`.
        super().__init__(**kwargs)
        self.activation = activation
        self.Dense = layers.Dense(1, activation=activation)

    def call(self, inputs: Tensor, training: bool) -> Tensor:
        """Pool `inputs` into a single weighted representation.

        Args:
            inputs: Sequence tensor; presumably (batch, sequence, features),
                since the softmax runs over axis 1 - TODO confirm with caller.

            training: Unused by this layer; accepted for Keras compatibility.

        Returns:
            The attention-weighted sum of `inputs`, with the second-to-last
            (size 1) axis squeezed away.
        """
        # One score per position, normalized along axis 1.
        attention = self.Dense(inputs)
        attention = tf.nn.softmax(attention, axis=1)

        # attention^T @ inputs collapses the sequence axis to size 1.
        weighted_representation = tf.matmul(attention, inputs, transpose_a=True)
        weighted_representation = tf.squeeze(weighted_representation, -2)

        return weighted_representation

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the base `Layer` config."""
        config = super().get_config()
        config.update({"activation": self.activation})
        return config


def pooling(
    inputs: Tensor,
    pooling_type: str,
    activation: str = None,
    gem_pooling_p: float = 3.0,
    output_dim: int = 0,
) -> Tensor:
    """Apply the pooling strategy named by `pooling_type` to `inputs`.

    Args:
        inputs: Tensor to pool.

        pooling_type: One of "GEM", "max", "avg", "flatten" or "dense"
            (matched case-sensitively).

        activation: Activation of the dense projection; only used when
            `pooling_type` is "dense". Defaults to None.

        gem_pooling_p: Power `p` passed to the generalized mean pooling
            layer; only used when `pooling_type` is "GEM". Defaults to 3.0.

        output_dim: Units of the dense projection; only used when
            `pooling_type` is "dense". Defaults to 0.

    Returns:
        The pooled tensor. `inputs` is returned unchanged when
        `pooling_type` is not recognized, or when it is "dense" but
        `output_dim` is not positive.
    """
    if pooling_type == "GEM":
        return GeneralizedMeanPooling1D(p=gem_pooling_p)(inputs)

    if pooling_type == "max":
        return layers.GlobalMaxPool1D()(inputs)

    if pooling_type == "avg":
        return layers.GlobalAveragePooling1D()(inputs)

    if pooling_type == "flatten":
        return layers.Flatten()(inputs)

    if pooling_type == "dense" and output_dim > 0:
        # Flatten first, then project to the requested dimensionality.
        flattened = layers.Flatten()(inputs)
        projection = layers.Dense(output_dim, activation=activation)
        return projection(flattened)

    # No matching strategy: pass the input through untouched.
    return inputs


@tf.keras.utils.register_keras_serializable(package="retsim")
class ConvNextBlock(Layer):
    """ConvNeXt block.

    Adapted from A ConvNet for the 2020s (https://arxiv.org/pdf/2201.03545.pdf)

    The block applies, in order: a depthwise 1D convolution, layer
    normalization, a dense expansion to `hidden_dim`, dropout and a dense
    projection to `filters`, optionally followed by a residual addition of
    the block input.

    This layer is compatible with existing TensorFlow.js supported ops,
    which means that models built using this layer can be converted to
    javascript using the TensorFlow.js converter. For more info, visit
    https://www.tensorflow.org/js/guide/conversion.
    """

    def __init__(
        self,
        kernel_size: int,
        depth: int,
        filters: int,
        hidden_dim: int,
        dropout_rate: float = 0,
        epsilon: float = 1e-10,
        activation: Union[str, Callable] = "gelu",
        strides: int = 1,
        residual: bool = True,
        **kwargs,
    ) -> None:
        """Initialize a ConvNextBlock.

        Args:
            kernel_size: Kernel size for convolution.

            depth: Depth multiplier for depthwise 1D convolution.

            filters: Number of convolution filters.

            hidden_dim: Hidden dim of block.

            dropout_rate: Feature dropout rate. Defaults to 0.

            epsilon: Layer norm epsilon. Defaults to 1e-10.

            activation: Layer activation. Defaults to 'gelu'.

            strides: Strides to apply convolution. Defaults to 1.

            residual: Whether to add residual connection. Defaults to True.

            **kwargs: Standard Keras `Layer` keyword arguments.
        """
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.depth = depth
        self.filters = filters
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        # Keep the raw identifier for `get_config()`; the Activation layer
        # lives under its own attribute so it cannot clobber this value.
        self.activation = activation
        self.strides = strides
        self.residual = residual
        self.depthconv = layers.DepthwiseConv1D(
            kernel_size=kernel_size,
            strides=strides,
            depth_multiplier=depth,
            padding="same",
        )
        self.norm = layers.LayerNormalization(epsilon=epsilon)
        self.hidden = layers.Dense(hidden_dim)
        self.activation_layer = layers.Activation(activation)
        self.drop = layers.Dropout(dropout_rate)
        self.out = layers.Dense(filters)

    def call(self, inputs: Tensor, training: bool) -> Tensor:
        """Run the block on `inputs`.

        Args:
            inputs: Input tensor.

            training: Whether the layer runs in training mode; forwarded to
                the normalization and dropout layers.

        Returns:
            The block output, with `inputs` added to it when `residual` is
            True (the addition requires shape-compatible tensors).
        """
        residual = inputs
        x = self.depthconv(inputs)
        x = self.norm(x, training=training)
        x = self.hidden(x)
        # NOTE(review): the configured activation is never applied between
        # `hidden` and `drop`, although ConvNeXt places a non-linearity
        # there. Left unchanged on purpose because applying it would alter
        # the outputs of already-trained weights - confirm intent before
        # inserting `x = self.activation_layer(x)` here.
        x = self.drop(x, training=training)
        x = self.out(x)
        if self.residual:
            x = x + residual
        return x

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the base `Layer` config."""
        activation = self.activation
        if callable(activation):
            # Callables are not JSON-serializable; store their Keras
            # identifier instead.
            activation = tf.keras.activations.serialize(activation)

        config = super().get_config()
        config.update(
            {
                "kernel_size": self.kernel_size,
                "depth": self.depth,
                "filters": self.filters,
                "hidden_dim": self.hidden_dim,
                "dropout_rate": self.dropout_rate,
                "epsilon": self.epsilon,
                "activation": activation,
                "strides": self.strides,
                "residual": self.residual,
            }
        )
        return config

@tf.keras.utils.register_keras_serializable(package="retsim")
class FFN(Layer):
    """Standard two-layer feed-forward network.

    Computes `out(dropout(activation(hidden(inputs))))`, where `hidden` and
    `out` are bias-free dense layers.
    """

    def __init__(
        self,
        hidden_size: int,
        out_size: int,
        activation: str,
        dropout_rate: float = 0,
        **kwargs,
    ) -> None:
        """Construct a standard FFN layer.

        Args:
            hidden_size: Units of the hidden dense layer.

            out_size: Units of the output dense layer.

            activation: Activation applied after the hidden dense layer.

            dropout_rate: Dropout rate applied after the activation.
                Defaults to 0.

            **kwargs: Standard Keras `Layer` keyword arguments.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Keep the raw identifier for `get_config()`; the Activation layer
        # lives under its own attribute so it cannot clobber this value.
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.hidden = layers.Dense(hidden_size, use_bias=False)
        self.activation_layer = layers.Activation(activation)
        self.out = layers.Dense(out_size, use_bias=False)
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs: Tensor, training: bool = False) -> Tensor:
        """Apply the feed-forward network to `inputs`.

        Args:
            inputs: Input tensor.

            training: Whether the layer runs in training mode; forwarded to
                the dropout layer. Defaults to False.

        Returns:
            The output of the final dense layer.
        """
        inputs = self.hidden(inputs)
        inputs = self.activation_layer(inputs)
        inputs = self.dropout(inputs, training=training)
        inputs = self.out(inputs)
        return inputs

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the base `Layer` config."""
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "out_size": self.out_size,
                "activation": self.activation,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config


@tf.keras.utils.register_keras_serializable(package="retsim")
class GatedFFN(Layer):
    """Gated feed-forward network.

    Computes `Out(Dropout(Hidden(inputs) * Activation(Gate(inputs))))`,
    where `Hidden`, `Gate` and `Out` are bias-free dense layers.
    """

    def __init__(
        self,
        hidden_size: int,
        out_size: int,
        activation: str,
        dropout_rate: float = 0,
        **kwargs,
    ) -> None:
        """Implements Gated FFN based off https://arxiv.org/pdf/2002.05202.pdf

        Note:
        - to be size equivalent, the hidden_dim should be about 2/3 of the
        standard FeedForward network
        - Swish activated gate and GELU activated gate seems to perform
        the best.

        Args:
            hidden_size: Units of both the hidden and the gate dense layers.

            out_size: Units of the output dense layer.

            activation: Activation applied to the gate branch.

            dropout_rate: Dropout rate applied to the gated hidden
                representation. Defaults to 0.

            **kwargs: Standard Keras `Layer` keyword arguments.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.Hidden = layers.Dense(hidden_size, use_bias=False)
        self.Gate = layers.Dense(hidden_size, use_bias=False)
        self.Activation = layers.Activation(activation)
        self.Out = layers.Dense(out_size, use_bias=False)
        self.Dropout = layers.Dropout(dropout_rate)

    def call(self, inputs: Tensor, training: bool = False) -> Tensor:
        """Apply the gated feed-forward network to `inputs`.

        Args:
            inputs: Input tensor.

            training: Whether the layer runs in training mode; forwarded to
                the dropout layer. Defaults to False.

        Returns:
            The output of the final dense layer.
        """
        # compute gate
        gate = self.Gate(inputs)
        gate = self.Activation(gate)

        # expand & gate
        hidden = self.Hidden(inputs)
        hidden = hidden * gate  # apply gate

        # drop & compress
        hidden = self.Dropout(hidden, training=training)
        hidden = self.Out(hidden)
        return hidden

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the base `Layer` config."""
        # Start from the base config so `name`, `dtype` and `trainable`
        # survive a `get_config()` / `from_config()` round trip.
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "out_size": self.out_size,
                "activation": self.activation,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config
