# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Union

import torch
from torch import Tensor

from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MambaLayerSubmodules:
    norm: Union[ModuleSpec, type] = IdentityOp
    mixer: Union[ModuleSpec, type] = IdentityOp
    mamba_bda: Union[ModuleSpec, type] = IdentityOp


class MambaLayer(MegatronModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaLayerSubmodules,
        mamba_ssm_ngroups=8,
        layer_number: int = 1,
        residual_in_fp32=False,
    ):
        """
        Top level Mamba Layer
        """
        super().__init__(config)
        self.config = config
        self.layer_number = layer_number
        self.residual_in_fp32 = residual_in_fp32
        self.hidden_dropout = config.hidden_dropout
        self.mixer = build_module(
            submodules.mixer,
            self.config,
            d_model=self.config.hidden_size,
            ngroups=mamba_ssm_ngroups,
            layer_number=layer_number,
        )
        self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)
        self.mamba_bda = build_module(submodules.mamba_bda)
        self.bias_dropout_add_exec_handler = torch.enable_grad

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,  # Not used in MambaLayer
        inference_params=None,
        rotary_pos_emb: Tensor = None,  # Not used in MambaLayer
    ):

        residual = hidden_states
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = hidden_states.to(dtype=self.config.params_dtype)
        hidden_states = self.norm(hidden_states)

        mixer_out_with_bias = self.mixer(hidden_states, inference_params=inference_params)

        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mamba_bda(self.training, self.config.bias_dropout_fusion)(
                mixer_out_with_bias, residual, self.hidden_dropout
            )

        return hidden_states

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
