#!/usr/bin/env python3
"""
Created on 19:57, Jul. 28th, 2023

@author: Anonymous
"""
import copy as cp
import tensorflow as tf
import tensorflow.keras as K
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, os.pardir))
    from TransformerBlock import TransformerBlock
else:
    from models.domain_adaptation.DomainContrastiveNet.DomainContrastiveTransformer.layers.TransformerBlock import TransformerBlock

# Names exported by `from ... import *`; only the layer class is public.
__all__ = ["TransformerStack"]

class TransformerStack(K.layers.Layer):
    """
    `TransformerStack` acts as an encoder or a decoder: a stack of `TransformerBlock`s
    followed by an optional pooling block.

    Fields read from `params`:
        n_blocks - Number of `TransformerBlock`s to stack.
        d_model, n_heads, d_head, attn_dropout, proj_dropout, d_ff, ff_dropout,
        norm_type, norm_first - Forwarded unchanged to every `TransformerBlock`.
        res_attn - If truthy, the attention score returned by one block is fed into the next.
        d_pooling_kernel - `MaxPool1D` pool size; pooling is added only if > 1.
        pool_dropout - `Dropout` rate after pooling; dropout is added only if > 0.
        use_bn - If truthy, `BatchNormalization` is appended to the pooling block.
    """

    def __init__(self, params, **kwargs):
        """
        Initialize `TransformerStack` object.

        Args:
            params: DotDict - Model parameters. A deep copy is stored, so later mutation
                of the caller's object does not affect this layer.
            kwargs: The arguments forwarded to `tf.keras.layers.Layer.__init__` (e.g. `name`).

        Returns:
            None
        """
        # First call super class init function to set up the `K.layers.Layer`
        # machinery and inherit its functionality.
        super().__init__(**kwargs)

        # Copy hyperparameters (e.g. network sizes) from parameter dotdict.
        self.params = cp.deepcopy(params)

        # Create trainable vars.
        self._init_trainable()

    """
    init funcs
    """
    # def _init_trainable func
    def _init_trainable(self):
        """
        Initialize trainable variables (the attention blocks and the pooling block).

        Args:
            None

        Returns:
            None
        """
        ## Construct attention blocks.
        # Initialize attention blocks.
        # attn_blocks - (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        self.attn_blocks = [TransformerBlock(d_model=self.params.d_model,
            n_heads=self.params.n_heads, d_head=self.params.d_head, attn_dropout=self.params.attn_dropout,
            proj_dropout=self.params.proj_dropout, d_ff=self.params.d_ff, ff_dropout=self.params.ff_dropout,
            norm_type=self.params.norm_type, norm_first=self.params.norm_first, name="attn-block-{:d}".format(block_idx)
        ) for block_idx in range(self.params.n_blocks)]
        ## Construct pooling block.
        # Initialize pooling block. Every sub-layer below is optional, so the block may end up empty.
        self.pool_block = K.Sequential(name="pool-block")
        # Add `MaxPool1D` layer.
        # maxpool - (batch_size, seq_len, d_model) -> (batch_size, seq_len // d_pooling_kernel, d_model)
        # (strides default to `pool_size` and padding is "valid", so the length is floored).
        if self.params.d_pooling_kernel > 1:
            self.pool_block.add(K.layers.MaxPool1D(
                # Modified `MaxPool1D` layer parameters.
                pool_size=self.params.d_pooling_kernel,
                # Default `MaxPool1D` layer parameters.
                strides=None, padding="valid", data_format="channels_last"
            ))
        # Add `Dropout` after `MaxPool1D` layer.
        if self.params.pool_dropout > 0.:
            self.pool_block.add(K.layers.Dropout(rate=self.params.pool_dropout, noise_shape=None, seed=None))
        # Add `BatchNormalization` at the last layer of encoder block.
        if self.params.use_bn:
            self.pool_block.add(K.layers.BatchNormalization(
                # Modified `BatchNormalization` parameters.
                # Default `BatchNormalization` parameters.
                axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer="zeros",
                gamma_initializer="ones", moving_mean_initializer="zeros", moving_variance_initializer="ones",
                beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None
            ))

    """
    network funcs
    """
    # def call func
    def call(self, emb, attn_score=None, attn_mask=None, key_padding_mask=None):
        """
        Forward layers in `TransformerStack` to get the mha-ffn transformed result.

        Args:
            emb: (batch_size, seq_len, d_model) - The input embedding.
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score from the previous layer.
                Only forwarded into the blocks when `params.res_attn` is truthy.
            attn_mask: (seq_len, seq_len) - The pre-defined attention mask within sequence.
            key_padding_mask: (batch_size, seq_len) - The pre-defined key mask within sequence.

        Returns:
            emb: (batch_size, seq_len // pool_size, d_model) - The mha-ffn transformed embedding
                (`seq_len` is unchanged when no pooling is configured).
            attn_weight: (batch_size, n_heads, seq_len, seq_len) - The attention weight of the last block,
                or `None` if the stack has no blocks.
            attn_score: (batch_size, n_heads, seq_len, seq_len) - The attention score of the last block,
                or the input `attn_score` if the stack has no blocks.
        """
        # Default so that an empty stack (`n_blocks == 0`) returns `None` instead of
        # hitting an unbound local at the return statement.
        attn_weight = None
        # Whether to chain attention scores across blocks is fixed per layer.
        res_attn = self.params.res_attn
        # Forward `attn_blocks` to get the transformed embedding.
        # emb - (batch_size, seq_len, d_model)
        for attn_block in self.attn_blocks:
            emb, attn_weight, attn_score = attn_block(
                emb, attn_score=attn_score if res_attn else None,
                attn_mask=attn_mask, key_padding_mask=key_padding_mask
            )
        # Forward `pool_block` to get the pooled embedding.
        # emb - (batch_size, seq_len // pool_size, d_model)
        # An empty `K.Sequential` would be an identity at best (some Keras versions refuse
        # to build a Sequential without layers), so only call it when it holds layers.
        if self.pool_block.layers:
            emb = self.pool_block(emb)
        # Return the final `emb` & `attn_weight` & `attn_score`.
        return emb, attn_weight, attn_score

if __name__ == "__main__":
    # Smoke test: build a `TransformerStack` from the project params and run one forward pass.
    # local dep
    from params.patch_transformer_params import patch_transformer_params

    # macro
    dataset = "eeg_anonymous"
    batch_size = 32
    seq_len = 80

    # Instantiate params.
    patch_transformer_params_inst = patch_transformer_params(dataset=dataset)
    d_model = patch_transformer_params_inst.model.encoder.d_model
    # Instantiate TransformerStack.
    ts_inst = TransformerStack(patch_transformer_params_inst.model.encoder)
    # Initialize input emb.
    emb = tf.random.normal((batch_size, seq_len, d_model), dtype=tf.float32)
    # Forward layers in `ts_inst`.
    emb, attn_weight, attn_score = ts_inst(emb)
    # Report output shapes so the run gives visible feedback; outputs may be `None`
    # (e.g. an empty stack), hence the `getattr` fallback.
    for out_name, out_value in (("emb", emb), ("attn_weight", attn_weight), ("attn_score", attn_score)):
        print("{}: {}".format(out_name, getattr(out_value, "shape", None)))

