# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
    xFormerBlockConfig,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

logger = logging.getLogger("xformers")


@dataclass(init=False)
class xFormerConfig:
    """
    The configuration structure to define a full Transformer.
    This can include a stack of encoder layers, and a stack of decoder layers.

    It is optionally possible to share the embedding weights in between
    the encoder and decoder positional encoding, as proposed for instance by
    `Using the Output Embedding to Improve Language Models`, Press et al.

    A full config example is for instance as follows:

    ::

        xformer_config = [
            {
                "reversible": False,  # Turn on to test the effect of using reversible layers
                "block_type": "encoder",
                "num_layers": LAYERS,
                "dim_model": EMB,
                "residual_norm_style": "pre",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": CONTEXT,
                    "vocab_size": VOCAB_SIZE,
                },
                "multi_head_config": {
                    "num_heads": NUM_HEADS,
                    "residual_dropout": RES_DROP,
                    "use_rotary_embeddings": True,
                    "attention": {
                        "name": ATTENTION_MECHANISM_STR,
                        "dropout": ATTN_DROP,
                        "causal": True,
                        "seq_len": CONTEXT,
                    },
                },
                "feedforward_config": {
                    "name": "FusedMLP",  # Use MLP if Triton is not available
                    "dropout": MLP_DROP,
                    "activation": "gelu",
                    "hidden_layer_multiplier": MLP_MULTIPLIER,
                },
            }
        ]


    .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
    """

    stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
    tie_embedding_weights: bool = False
    weight_init: xFormerWeightInit = xFormerWeightInit.ViT

    def __init__(
        self,
        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        # Type all the configurations. Possible typos are caught here
        if isinstance(stack_configs, dict):
            self.stack_configs = {}
            for k, config in stack_configs.items():
                if config["block_type"] == "encoder":
                    self.stack_configs[k] = xFormerEncoderConfig(**config)
                else:
                    self.stack_configs[k] = xFormerDecoderConfig(**config)
        else:
            self.stack_configs = []
            for config in stack_configs:
                if config["block_type"] == "encoder":
                    self.stack_configs.append(xFormerEncoderConfig(**config))
                else:
                    self.stack_configs.append(xFormerDecoderConfig(**config))

        self.tie_embedding_weights = tie_embedding_weights
        self.weight_init = weight_init


class xFormer(torch.nn.Module):
    def __init__(
        self,
        stack_configs: Union[
            xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
        ],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Given a serialized configuration, generate the corresponding model.
        This is only a helper and can easily be bypassed
        """
        super().__init__()

        if isinstance(stack_configs, Dict):
            stack_configs = list(stack_configs.values())

        # Convenience, users can pass either a list of configs or a single one
        if not isinstance(stack_configs, List):
            stack_configs = [stack_configs]

        # Sanity checks, some config combinations do not make sense
        self._verify_reversible(stack_configs)
        self._verify_deepnorm(stack_configs)

        encoders: List[torch.nn.Module] = []
        decoders: List[torch.nn.Module] = []

        self.reversible_encoder = False
        self.rev_enc_pose_encoding = None

        # Unroll the configs and build the model
        for config in stack_configs:
            # Handle either Encoder or Decoder stacks
            builder = (
                xFormerEncoderBlock.from_config
                if isinstance(config, xFormerEncoderConfig)
                else xFormerDecoderBlock.from_config
            )
            recipient = (
                encoders if isinstance(config, xFormerEncoderConfig) else decoders
            )

            # Build up the stack
            for i in range(config.num_layers):
                # Label where this layer is in the stack
                # (for instance useful for the positional encoding, or late layer norm)
                if len(recipient) > 0:
                    config.layer_position.mark_not_first()

                if config != stack_configs[-1] or i < config.num_layers - 1:
                    config.layer_position.mark_not_last()

                block = builder(config)  # type: ignore

                # If reversible: extract the reversible sub-parts, else append the block as-is
                if config.reversible:
                    # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
                    assert isinstance(config, xFormerEncoderConfig)
                    if block.pose_encoding is not None:
                        self.rev_enc_pose_encoding = block.pose_encoding
                    self.reversible_encoder = True

                    f, g = xFormerEncoderBlock.get_reversible_layer(config)
                    recipient.append(torch.nn.ModuleList([f, g]))
                else:
                    recipient.append(block)  # type: ignore

        # Tie embedding weights, if requested and possible
        assert (
            not tie_embedding_weights or not self.reversible_encoder
        ), "Reversible layers and  tied embeddings is not supported for now"

        if (
            tie_embedding_weights
            and encoders
            and encoders[0].pose_encoding
            and decoders
            and decoders[0].pose_encoding
            and not config.reversible
        ):
            logger.info("Tying encoder and decoder embeddings, as requested")
            encoders[0].pose_encoding = decoders[0].pose_encoding

        self.encoders: torch.nn.Module = (
            rv.ReversibleSequence(torch.nn.ModuleList(encoders))
            if self.reversible_encoder
            else torch.nn.ModuleList(encoders)
        )
        self.decoders = torch.nn.ModuleList(decoders)

        use_deepnorm = (
            stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
        )

        assert (
            not use_deepnorm or not self.reversible_encoder
        ), "Reversible layers and deepnorm is not supported for now"

        self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)

    @classmethod
    def from_config(cls, config: xFormerConfig):
        return cls(
            config.stack_configs, config.tie_embedding_weights, config.weight_init
        )

    def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
        reversible = [
            c.reversible
            for c in filter(lambda x: x.block_type == "encoder", stack_configs)
        ]

        assert all(reversible) or not any(reversible), (
            "All layers need to have the same reversibility setting. "
            + f"Currently {reversible}"
        )

    def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
        deepnorm = [
            c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
        ]

        assert all(deepnorm) or not any(deepnorm), (
            "All layers need to have the same deepnorm setting. "
            + f"Currently {deepnorm}"
        )

    def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
        # The deepnorm weight initialization method requires different gain factors for the encoder
        # and decoder, depending on the general model structure (number of respective layers)
        if use_deep_norm:
            encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
                encoder_layers=len(self.encoders), decoder_layers=len(self.decoders)  # type: ignore
            )
        else:
            encoder_coefficients, decoder_coefficients = None, None

        encoder_gain = (
            encoder_coefficients.beta if encoder_coefficients is not None else 1.0
        )
        decoder_gain = (
            decoder_coefficients.beta if decoder_coefficients is not None else 1.0
        )

        # Pick the desired init function
        init_fn = get_weight_init_fn(weight_init)

        # Initialize all the encoder weights
        for name, module in self.encoders.named_children():
            init_fn(module=module, name=name, gain=encoder_gain)

        for name, module in self.decoders.named_children():
            init_fn(module=module, name=name, gain=decoder_gain)

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        encoder_input_mask: Optional[torch.Tensor] = None,
        decoder_input_mask: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:

        # Encode to latent space if encoder is present
        if len(list(self.encoders.parameters())) > 0:
            encoders = self.encoders
            memory = src.clone()
            if isinstance(encoders, torch.nn.ModuleList):
                for encoder in encoders:
                    memory = encoder(memory, input_mask=encoder_input_mask)
            else:
                if self.rev_enc_pose_encoding:
                    memory = self.rev_enc_pose_encoding(src)

                # Reversible Encoder
                x = torch.cat([memory, memory], dim=-1)

                # Apply the optional input masking
                if encoder_input_mask is not None:
                    if x.dim() - encoder_input_mask.dim() > 1:
                        encoder_input_mask.unsqueeze(0)
                    x += encoder_input_mask.unsqueeze(-1)

                x = encoders(x)
                memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

            if not self.decoders:
                return memory

        # If decoder: either use the encoder ouput, or just decode, both options are possible
        if len(self.decoders) > 0:
            tgt = src.clone() if tgt is None else tgt

            for decoder in self.decoders:
                tgt = decoder(
                    target=tgt,
                    # pyre-fixme[61]: `memory` is not always initialized here.
                    memory=memory,
                    input_mask=decoder_input_mask,
                )

            return tgt

        return None
