from collections import namedtuple

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from torch import nn


class MambaWithDropout(MambaLMHeadModel):
    """``MambaLMHeadModel`` with dropout between the backbone and the LM head.

    The parent constructor builds the model unchanged; this subclass only
    registers one extra ``nn.Dropout`` module and overrides ``forward`` so the
    backbone output passes through it before ``lm_head``.
    """

    # Result type returned by ``forward``. Built once at class-definition time
    # instead of on every forward call; the type name and its single
    # ``logits`` field are the same as before.
    _CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(self, config: MambaConfig, dropout_rate: float = 0.1,
                 initializer_cfg=None, device=None, dtype=None):
        """Build the parent LM-head model, then attach the dropout layer.

        Args:
            config: Model configuration, forwarded to the parent constructor.
            dropout_rate: Probability handed to ``nn.Dropout`` (which performs
                its own range validation).
            initializer_cfg: Forwarded unchanged to the parent constructor.
            device: Forwarded unchanged to the parent constructor.
            dtype: Forwarded unchanged to the parent constructor.
        """
        # Build the standard LM-head model first so ``backbone`` and
        # ``lm_head`` exist before the extra module is registered.
        super().__init__(config, initializer_cfg=initializer_cfg, device=device, dtype=dtype)
        # A single dropout module shared by every position of the backbone
        # output. ``nn.Dropout`` is only active in training mode, so
        # ``model.eval()`` turns it into a pass-through.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids, position_ids=None, inference_params=None,
                num_last_tokens=0, **mixer_kwargs):
        """Run the backbone, apply dropout, and project to logits.

        Args:
            input_ids: Passed as the first argument to ``self.backbone``.
            position_ids: Accepted but never read by this method.
                NOTE(review): presumably kept only for signature compatibility
                with the parent ``forward`` -- confirm.
            inference_params: Forwarded to ``self.backbone``.
            num_last_tokens: If greater than 0, only the last
                ``num_last_tokens`` entries along dimension 1 of the hidden
                states are passed to ``lm_head``.
            **mixer_kwargs: Extra keyword arguments forwarded to
                ``self.backbone``.

        Returns:
            A ``CausalLMOutput`` namedtuple whose only field, ``logits``, is
            the output of ``self.lm_head``.
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
        # Dropout is deliberately applied before slicing, exactly as the
        # original did, so the random mask drawn in training mode is unchanged.
        hidden_states = self.dropout(hidden_states)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        return self._CausalLMOutput(logits=lm_logits)
