import inspect

import math

from dataclasses import dataclass, field

from functools import partial

from typing import Optional, Tuple, Union, Dict, Any, List, Any

from einops import repeat, rearrange

import torch

from mamba_ssm.modules.mamba_simple import Mamba

try:

    from mamba_ssm.modules.mamba_simple import Block

except ImportError:

    from mamba_ssm.modules.block import Block

from torch import nn

from torch.nn import functional as F

from transformers import PreTrainedModel, PretrainedConfig

from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput

from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn







from .chunking import (

    Downsampler,

    RoutingModule,

    ste_boundary

)



def reverse_padded(x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    """
    Vectorized reverse over the valid (non-pad) prefix per batch element.
    Pads remain on the right and are zero-filled in the output.

    The valid length of each row is the number of ``False`` entries in
    ``pad_mask``; the first ``length`` positions are reversed and every other
    position is left in place.  This is only a true "reverse the real tokens"
    when padding is right-aligned, i.e. the valid tokens form a prefix.

    x:        [B, L, ...]  (any trailing dims, including none)
    pad_mask: [B, L] True where PAD
    returns:  [B, L, ...]
    """
    B, L = pad_mask.shape
    device = x.device

    # Singleton axes that let a [B, L] tensor broadcast against x's trailing
    # dims.  Using this for BOTH the gather index and the final mask is what
    # makes inputs with more than one trailing dim work.
    trailing = (1,) * (x.dim() - 2)

    lengths = (~pad_mask).sum(dim=1)
    idx = torch.arange(L, device=device)

    # Inside the valid prefix read from the mirrored position; elsewhere read
    # the position itself (those entries are zeroed below anyway).
    in_prefix = idx[None, :] < lengths[:, None]
    src = torch.where(in_prefix, lengths[:, None] - 1 - idx[None, :], idx[None, :])

    src_expanded = src.view((B, L) + trailing).expand_as(x)
    out = x.gather(dim=1, index=src_expanded)

    # Zero the padded positions regardless of how many trailing dims x has.
    return out.masked_fill(pad_mask.reshape((B, L) + trailing), 0)



@dataclass
class DNAChunkerConfig(PretrainedConfig):
    """Configuration for the DNAChunker model.

    Declared as a dataclass on top of ``PretrainedConfig``: every annotated
    attribute below is a dataclass field with the default shown, and the
    order of the fields is the positional order of the dataclass-generated
    constructor.  Append new fields rather than inserting them.
    """

    # --- vocabulary / width -------------------------------------------------
    vocab_size: int = 50257
    d_model: int = 768

    # --- depth: the three counts are summed by the `n_layer` property -------
    n_enc_layer: int = 4
    n_main_layer: int = 16
    n_dec_layer: int = 4

    # --- transformer sub-blocks ---------------------------------------------
    transformer_n_head: int = 12
    transformer_mlp_mult: int = 4

    # --- normalisation / residual handling ----------------------------------
    norm_epsilon: float = 1e-5
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True

    # --- bidirectional mixer options ----------------------------------------
    bidirectional: bool = True
    bidirectional_strategy: str = "add"
    bidirectional_weight_tie: bool = True

    # --- chunking / routing hyper-parameters --------------------------------
    # NOTE(review): presumably the target fraction of positions kept as chunk
    # boundaries at each stage -- confirm against the loss code.
    target_ratio_stage1: float = 0.5
    target_ratio_stage2: float = 0.3
    gumbel_temperature: float = 1.0
    entropy_weight: float = 0.1
    ratio_loss_weight: float = 1.0

    dropout: float = 0.0

    # --- convolutional encoder block ----------------------------------------
    encoder_conv_expansion: int = 2
    encoder_conv_kernel: int = 5

    pad_vocab_size_multiple: int = 8
    tokenizer_type: str = "default"

    # --- Mamba / SSM options ------------------------------------------------
    ssm_cfg: Optional[Dict[str, Any]] = field(default_factory=dict)
    mamba_block_size: int = 256
    mamba_headdim: int = 32

    # String name of the model dtype (e.g. "float32").
    torch_dtype: Optional[str] = "float32"
    initializer_cfg: Optional[Dict[str, Any]] = field(default_factory=dict)
    pad_token_id: int = 8

    @property
    def n_layer(self):
        """Total layer count: encoder + main + decoder layers."""
        return self.n_enc_layer + self.n_main_layer + self.n_dec_layer

    # NOTE(review): the fields below carry names that `PretrainedConfig`
    # normally manages itself; presumably they are redeclared here because the
    # dataclass-generated __init__ is used instead of the parent's -- confirm.
    use_return_dict: Optional[bool] = True
    output_attentions: Optional[bool] = False
    output_hidden_states: Optional[bool] = False
    torchscript: bool = False
    is_encoder_decoder: bool = False
    pruned_heads: Dict[str, Dict[int, List[int]]] = field(default_factory=dict)

    # Because it is annotated, this is a dataclass field as well as a class
    # attribute default.
    model_type: str = "DNAChunker"

    @property
    def _attn_implementation(self):
        """Return the stored attention implementation, or "eager" if unset."""
        return getattr(self, "_attn_implementation_internal", "eager")

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        # Stored under a separate name so the getter can supply a default.
        self._attn_implementation_internal = value



@dataclass
class DNAChunkerTransformerModelOutput(BaseModelOutputWithNoAttention):
    """Backbone output: ``BaseModelOutputWithNoAttention`` plus optional extras.

    Every field added here is optional and defaults to ``None``.  The names
    suggest auxiliary losses followed by per-stage (stage 1 / stage 2)
    diagnostics; their exact definitions live in the code that fills them in,
    which is not part of this class.
    """

    # Auxiliary loss terms.
    ratio_loss: Optional[torch.FloatTensor] = None
    ratio_loss_s1: Optional[torch.FloatTensor] = None
    ratio_loss_s2: Optional[torch.FloatTensor] = None
    motif_loss: Optional[torch.FloatTensor] = None

    # Per-stage compression ratios.
    compression_ratio_stage1: Optional[torch.FloatTensor] = None
    compression_ratio_stage2: Optional[torch.FloatTensor] = None

    # Per-stage boundary-probability statistics (mean / std).
    boundary_prob_mean_stage1: Optional[torch.FloatTensor] = None
    boundary_prob_std_stage1: Optional[torch.FloatTensor] = None
    boundary_prob_mean_stage2: Optional[torch.FloatTensor] = None
    boundary_prob_std_stage2: Optional[torch.FloatTensor] = None

    # Per-stage confidence statistics (mean / std).
    confidence_mean_stage1: Optional[torch.FloatTensor] = None
    confidence_std_stage1: Optional[torch.FloatTensor] = None
    confidence_mean_stage2: Optional[torch.FloatTensor] = None
    confidence_std_stage2: Optional[torch.FloatTensor] = None

    # Per-stage segment-length statistics (mean / std).
    segment_length_mean_stage1: Optional[torch.FloatTensor] = None
    segment_length_std_stage1: Optional[torch.FloatTensor] = None
    segment_length_mean_stage2: Optional[torch.FloatTensor] = None
    segment_length_std_stage2: Optional[torch.FloatTensor] = None





@dataclass
class DNAChunkerMaskedLMOutput(MaskedLMOutput):
    """Masked-LM output: ``MaskedLMOutput`` plus optional extras.

    Every field added here is optional and defaults to ``None``.  The names
    suggest auxiliary losses (including a separate ``mlm_loss`` entry)
    followed by per-stage (stage 1 / stage 2) diagnostics; their exact
    definitions live in the code that fills them in, which is not part of
    this class.
    """

    # Auxiliary loss terms.
    ratio_loss: Optional[torch.FloatTensor] = None
    ratio_loss_s1: Optional[torch.FloatTensor] = None
    ratio_loss_s2: Optional[torch.FloatTensor] = None
    mlm_loss: Optional[torch.FloatTensor] = None
    motif_loss: Optional[torch.FloatTensor] = None

    # Per-stage compression ratios.
    compression_ratio_stage1: Optional[torch.FloatTensor] = None
    compression_ratio_stage2: Optional[torch.FloatTensor] = None

    # Per-stage boundary-probability statistics (mean / std).
    boundary_prob_mean_stage1: Optional[torch.FloatTensor] = None
    boundary_prob_std_stage1: Optional[torch.FloatTensor] = None
    boundary_prob_mean_stage2: Optional[torch.FloatTensor] = None
    boundary_prob_std_stage2: Optional[torch.FloatTensor] = None

    # Per-stage confidence statistics (mean / std).
    confidence_mean_stage1: Optional[torch.FloatTensor] = None
    confidence_std_stage1: Optional[torch.FloatTensor] = None
    confidence_mean_stage2: Optional[torch.FloatTensor] = None
    confidence_std_stage2: Optional[torch.FloatTensor] = None

    # Per-stage segment-length statistics (mean / std).
    segment_length_mean_stage1: Optional[torch.FloatTensor] = None
    segment_length_std_stage1: Optional[torch.FloatTensor] = None
    segment_length_mean_stage2: Optional[torch.FloatTensor] = None
    segment_length_std_stage2: Optional[torch.FloatTensor] = None











class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1(x)) * w2(x))``.

    The inner width is ``int(2 * d_ffn / 3)``; ``bias`` is forwarded to all
    three linear projections.
    """

    def __init__(self, d_model, d_ffn, bias=False):
        super().__init__()
        inner_width = int(2 * d_ffn / 3)
        # Created in the order w1 -> w2 -> w3 so parameter initialisation
        # draws from the RNG in a fixed sequence.
        self.w1 = nn.Linear(d_model, inner_width, bias=bias)
        self.w2 = nn.Linear(d_model, inner_width, bias=bias)
        self.w3 = nn.Linear(inner_width, d_model, bias=bias)

    def forward(self, x):
        # SiLU-activated branch gates the linear branch element-wise.
        activated = F.silu(self.w1(x))
        return self.w3(activated * self.w2(x))



def create_attention_mask(padding_mask: torch.Tensor, query_tensor: torch.Tensor) -> Optional[torch.Tensor]:
    """Convert a padding mask into an additive attention bias.

    Args:
        padding_mask: Mask that is truthy at positions to block, or ``None``.
            A 2-D mask gets two singleton axes inserted after its first axis;
            any other rank is used as given.
        query_tensor: Supplies the dtype and device of the result.

    Returns:
        ``None`` if ``padding_mask`` is ``None``; otherwise a tensor that is 0
        at open positions and the most negative finite value of
        ``query_tensor.dtype`` at blocked positions.
    """
    if padding_mask is None:
        return None

    blocked = padding_mask.to(torch.bool)
    if blocked.dim() == 2:
        blocked = blocked[:, None, None, :]

    bias = torch.zeros(blocked.shape, device=query_tensor.device, dtype=query_tensor.dtype)
    most_negative = torch.finfo(query_tensor.dtype).min
    return bias.masked_fill_(blocked, most_negative)



def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last axis is split at ``size // 2``; the result is
    ``concat(-second_part, first_part)`` along that axis.
    """
    split_at = x.shape[-1] // 2
    first_part = x[..., :split_at]
    second_part = x[..., split_at:]
    return torch.cat((-second_part, first_part), dim=-1)







class RoutingModule(nn.Module):
    """Score chunk boundaries from the dissimilarity of adjacent positions.

    For every position ``t >= 1`` the boundary probability is
    ``0.5 * (1 - cos(w_q(x_t), w_k(x_{t-1})))``; position 0 always gets
    probability 1.  Padded positions get probability 0, and positions flagged
    in ``mask_locations`` (plus the position right after each of them) are
    forced to be boundaries.

    NOTE: this file also imports a ``RoutingModule`` from ``.chunking``; this
    class definition comes later in the module and therefore replaces it.
    """

    def __init__(self, hid_size):
        super().__init__()
        self.w_q = nn.Linear(hid_size, hid_size, bias=False)
        self.w_k = nn.Linear(hid_size, hid_size, bias=False)

    def _get_protection_boundaries(self, mask_locations):
        """Return ``mask_locations`` OR-ed with itself shifted right by one."""
        shifted = torch.cat([torch.zeros_like(mask_locations[:, :1]), mask_locations[:, :-1]], dim=1)
        return mask_locations | shifted

    def forward(self, x, mask_locations, pad_locations, temperature=1.0):
        """Compute boundary probabilities and hard boundary indicators.

        Args:
            x: Input of shape (B, L, hid_size).
            mask_locations: (B, L) mask of protected positions; combined with
                ``|`` so it must be a boolean/integer tensor.
            pad_locations: (B, L) boolean mask, True at padding.
            temperature: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(p_final, b_final)``, both (B, L): the boundary
            probabilities (1 at protected positions) and the float 0/1 hard
            boundary indicators (``p > 0.5`` OR protected).
        """
        B = x.shape[0]
        q_aligned = self.w_q(x[:, 1:, :])
        k_aligned = self.w_k(x[:, :-1, :])

        # Cosine similarity between each position and its predecessor, with
        # the norm product clamped to avoid division by zero.
        dot = torch.sum(q_aligned * k_aligned, dim=-1)
        q_norm = torch.linalg.vector_norm(q_aligned, dim=-1)
        k_norm = torch.linalg.vector_norm(k_aligned, dim=-1)
        eps = torch.finfo(q_norm.dtype).eps
        norm_prod = (q_norm * k_norm).clamp(min=eps)
        sim = (dot / norm_prod).clamp(min=-1.0, max=1.0)
        p_vals = 0.5 * (1 - sim)

        # The first position has no predecessor and is always a boundary.
        first_p = torch.ones(B, 1, device=x.device, dtype=p_vals.dtype)
        p = torch.cat([first_p, p_vals], dim=1)
        p = p * (~pad_locations)

        # (A previously computed-but-unused logit transform of `p` was dropped
        # here; it had no effect on the returned values.)
        special_b = self._get_protection_boundaries(mask_locations).float()
        b_final = ((p > 0.5) | special_b.bool()).float()
        p_final = torch.max(p, special_b)
        return p_final, b_final



class Downsampler(nn.Module):
    """Parameter-free compaction of boundary positions into a right-padded sequence.

    NOTE: this file also imports a ``Downsampler`` from ``.chunking``; this
    class definition comes later in the module and therefore replaces it.
    """

    def forward(self, x, p_original, boundaries, mask_locations, pad_locations,
                input_original_indices_l0=None):
        """
        Downsamples by selecting vectors at hard boundary locations (b_t=1) and
        weighting them by their original probability (p_t).

        Critically, it calculates and propagates the *original L0 token indices*
        for use in Rotary Position Embeddings (RoPE).

        Args:
            x (torch.Tensor): Input tensor (B, L, D) from current stage.
            p_original (torch.Tensor): Original boundary probabilities (B, L) from router.
            boundaries (torch.Tensor): Hard boundary indicators (B, L).
            mask_locations (torch.Tensor): Special mask tokens (B, L).
            pad_locations (torch.Tensor): Padding mask (B, L), True if padded.
            input_original_indices_l0 (torch.Tensor, optional):
                (B, L) tensor of *original L0 indices* corresponding to the input `x`.
                If None (i.e., Stage 1), indices are generated from scratch.
                Any numeric dtype is accepted; values are converted to float32.

        Returns:
            Tuple: (selected_chunks_weighted, chunk_lengths, comp_mask_loc, comp_pad_loc, output_original_indices_l0)
                   - selected_chunks_weighted: (B, max_chunks, D) Weighted selected vectors.
                   - chunk_lengths: (B,) Int tensor of non-padded boundaries per item.
                   - comp_mask_loc: (B, max_chunks) Boolean mask locations in the compressed sequence.
                   - comp_pad_loc: (B, max_chunks) Boolean padding mask for the compressed sequence.
                   - output_original_indices_l0: (B, max_chunks) Float tensor of the *original L0 token index* for each selected boundary.

            ``max_chunks`` is the largest per-item boundary count, but at least
            1 for a non-empty batch and 0 for an empty batch.
        """
        B, L, D = x.shape
        device = x.device
        target_dtype = x.dtype

        # Padded positions contribute neither values nor boundaries.
        x_masked = x * (~pad_locations).unsqueeze(-1).to(target_dtype)
        boundaries_masked = boundaries.bool() & (~pad_locations)
        chunk_lengths = torch.sum(boundaries_masked.int(), dim=1)

        # Keep at least one output slot for a non-empty batch so downstream
        # code never sees a zero-length sequence; an empty batch has no slots.
        if B > 0:
            max_chunks = max(int(chunk_lengths.max().item()), 1)
        else:
            max_chunks = 0

        selected_chunks_weighted = torch.zeros(B, max_chunks, D, device=device, dtype=target_dtype)
        output_original_indices_l0 = torch.zeros(B, max_chunks, device=device, dtype=torch.float32)
        comp_mask_loc_float = torch.zeros(B, max_chunks, device=device, dtype=torch.float32)

        batch_indices, time_indices = torch.where(boundaries_masked)

        if batch_indices.numel() > 0:
            selected_x = x_masked[batch_indices, time_indices]
            selected_p = p_original[batch_indices, time_indices]
            weighted_x = selected_x * selected_p.unsqueeze(-1)

            # The k-th boundary of an item (1-based running count) lands in
            # output slot k-1 of that item.
            cumsum_boundaries_per_item = torch.cumsum(boundaries_masked.int(), dim=1)
            dest_indices_within_item = cumsum_boundaries_per_item[batch_indices, time_indices].long() - 1
            dest_indices_within_item = dest_indices_within_item.clamp(min=0, max=max_chunks - 1)

            # Flattened (item, slot) index shared by all three scatters below.
            lin_idx = batch_indices * max_chunks + dest_indices_within_item
            weighted_x_casted = weighted_x.to(target_dtype)
            selected_chunks_weighted.view(-1, D).scatter_(0, lin_idx.unsqueeze(-1).expand(-1, D), weighted_x_casted)

            if input_original_indices_l0 is None:
                original_positions_tensor = torch.arange(L, device=device, dtype=torch.float32)
                gathered_positions = original_positions_tensor[time_indices]
            else:
                gathered_positions = input_original_indices_l0[batch_indices, time_indices]

            # scatter_ requires source and destination dtypes to match, so
            # integer (or half-precision) index tensors are converted first.
            gathered_positions = gathered_positions.to(output_original_indices_l0.dtype)
            output_original_indices_l0.view(-1).scatter_(0, lin_idx, gathered_positions)

            gathered_mask_locs = mask_locations[batch_indices, time_indices]
            comp_mask_loc_float.view(-1).scatter_add_(0, lin_idx, gathered_mask_locs.float())

        comp_mask_loc = comp_mask_loc_float.bool()

        # Slots at or beyond an item's chunk count are padding.
        comp_pad_loc = torch.arange(max_chunks, device=device).unsqueeze(0) >= chunk_lengths.unsqueeze(1)

        selected_chunks_weighted = selected_chunks_weighted * (~comp_pad_loc).unsqueeze(-1).to(target_dtype)
        output_original_indices_l0 = output_original_indices_l0 * (~comp_pad_loc).to(output_original_indices_l0.dtype)
        comp_mask_loc = comp_mask_loc & (~comp_pad_loc)

        return (selected_chunks_weighted, chunk_lengths, comp_mask_loc, comp_pad_loc,
                output_original_indices_l0)



def create_block(
        d_model,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        rms_norm=False,
        residual_in_fp32=False,
        fused_add_norm=False,
        layer_idx=None,
        bidirectional=True,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=False,
        device=None,
        dtype=None,
):
    """Create DNAChunker block with padding mask support.

    Builds a ``Block`` whose mixer is a ``BiMambaWrapper`` and replaces the
    block's ``forward`` with a version that accepts an extra ``padding_mask``
    keyword and forwards it to the mixer.

    Args:
        d_model: Model width passed to the block.
        ssm_cfg: Extra keyword arguments for the mixer (``None`` -> ``{}``).
        norm_epsilon: ``eps`` for the normalisation layer.
        rms_norm: Use ``RMSNorm`` when True, ``nn.LayerNorm`` otherwise.
        residual_in_fp32, fused_add_norm: Passed through to ``Block``.
        layer_idx: Passed to the mixer and stored on the block.
        bidirectional, bidirectional_strategy, bidirectional_weight_tie:
            Passed through to ``BiMambaWrapper``.
        rcps: Accepted for interface compatibility; not used here.
        device, dtype: Factory kwargs for the mixer and the norm.

    Returns:
        The constructed block, with ``forward(hidden_states, residual=None,
        inference_params=None, padding_mask=None)``.

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
    """
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    bidirectional_kwargs = {
        "bidirectional": bidirectional,
        "bidirectional_strategy": bidirectional_strategy,
        "bidirectional_weight_tie": bidirectional_weight_tie,
    }
    mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block_cls = Block

    # Some Block versions take an `mlp_cls` argument; when present, pass
    # nn.Identity for it.
    block_kwargs = {
        "norm_cls": norm_cls,
        "fused_add_norm": fused_add_norm,
        "residual_in_fp32": residual_in_fp32,
    }
    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
        block_kwargs["mlp_cls"] = nn.Identity
    block = block_cls(d_model, mixer_cls, **block_kwargs)
    block.layer_idx = layer_idx

    original_forward = block.forward

    def forward_with_padding_mask(hidden_states, residual=None, inference_params=None, padding_mask=None):
        """Run the block's forward with ``padding_mask`` routed to the mixer."""
        if not (hasattr(block, 'mixer') and hasattr(block.mixer, 'forward')):
            return original_forward(hidden_states, residual, inference_params)

        original_mixer_forward = block.mixer.forward

        def mixer_forward_wrapper(hidden_states, inference_params=None):
            return original_mixer_forward(hidden_states, inference_params, padding_mask)

        # The block calls its mixer without a padding mask, so the mixer's
        # forward is swapped for the duration of this call.  The swap mutates
        # shared module state, hence it is not safe for concurrent calls on
        # the same block.
        block.mixer.forward = mixer_forward_wrapper
        try:
            return original_forward(hidden_states, residual, inference_params)
        finally:
            # Always undo the swap, even if the forward pass raised; otherwise
            # the mixer would keep a closure over a stale padding mask.
            block.mixer.forward = original_mixer_forward

    block.forward = forward_with_padding_mask
    return block





class ConvEncoderBlock(nn.Module):
    """Depthwise-separable bottleneck block for encoder-only local context.

    Pipeline: pointwise expand -> SiLU -> depthwise conv -> SiLU -> pointwise
    project, followed by a residual add and an RMSNorm.  The ``rms_norm``
    constructor argument is accepted but not consulted; RMSNorm is always used.
    """

    def __init__(
        self,
        d_model: int,
        expansion: int = 2,
        kernel_size: int = 5,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        placement = dict(device=device, dtype=dtype)
        width = d_model * expansion

        self.pw_expand = nn.Conv1d(d_model, width, kernel_size=1, **placement)
        self.act = nn.SiLU()
        # groups == channels makes this a depthwise convolution.
        self.dw_conv = nn.Conv1d(
            width, width, kernel_size=kernel_size, padding=kernel_size // 2, groups=width, **placement
        )
        self.pw_project = nn.Conv1d(width, d_model, kernel_size=1, **placement)

        if RMSNorm is None:
            raise RuntimeError("RMSNorm is required for ConvEncoderBlock stability but is not available.")
        self.norm = RMSNorm(d_model, eps=norm_epsilon, **placement)

    def forward(self, hidden_states, residual=None, inference_params=None, padding_mask=None):
        """Apply the block; returns ``(output, None)``.

        ``residual`` and ``inference_params`` are accepted but unused.  When
        ``padding_mask`` is given (True at padding), padded positions are
        zeroed both before the convolutions and before the residual add.
        """

        def zero_padding(t):
            if padding_mask is None:
                return t
            return t * (~padding_mask).unsqueeze(-1).to(t.dtype)

        skip = zero_padding(hidden_states)

        # Conv1d convolves over the last axis, so swap axes 1 and 2 around it.
        y = skip.transpose(1, 2)
        for layer in (self.pw_expand, self.act, self.dw_conv, self.act, self.pw_project):
            y = layer(y)
        y = zero_padding(y.transpose(1, 2))

        return self.norm(y + skip), None





def compute_relative_positions(boundaries: torch.Tensor, pad_locations: torch.Tensor) -> torch.Tensor:
    """Return each position's offset from the most recent boundary.

    A position counts as a boundary when ``boundaries > 0.5``.  Positions
    before the first boundary are measured from index 0.  Padded positions
    (``pad_locations`` True) are set to 0.
    """
    n_rows, n_cols = boundaries.shape
    positions = torch.arange(n_cols, device=boundaries.device)[None, :].expand(n_rows, -1)

    # Index of each boundary (0 elsewhere); the running maximum then yields
    # the index of the latest boundary seen so far.
    is_start = boundaries > 0.5
    start_index = torch.where(is_start, positions, torch.zeros_like(positions))
    latest_start, _ = torch.cummax(start_index, dim=1)

    offsets = positions - latest_start
    not_pad = (~pad_locations).to(offsets.dtype)
    return offsets * not_pad





def splat_upsample(
    z_chunks: torch.Tensor,
    boundaries: torch.Tensor,
    pad_locations: torch.Tensor,
) -> torch.Tensor:
    """Repeat chunk representations across their segments using hard boundaries.

    Each non-padded position is assigned the chunk whose index equals the
    number of non-padded boundaries seen so far minus one (clamped into the
    available chunk range, so positions before the first boundary use chunk
    0).  Padded positions are zeroed in the result.
    """
    n_batch, n_tokens = boundaries.shape
    _, n_chunks, width = z_chunks.shape

    valid = ~pad_locations
    boundary_hits = (boundaries > 0.5).float() * valid.float()

    # Running boundary count -> zero-based chunk index.
    running_count = torch.cumsum(boundary_hits, dim=1)
    chunk_index = (running_count.clamp(min=1) - 1).long().clamp(min=0, max=n_chunks - 1)

    gather_index = chunk_index.unsqueeze(-1).expand(n_batch, n_tokens, width)
    upsampled = torch.gather(z_chunks, 1, gather_index)
    return upsampled * valid.unsqueeze(-1).to(upsampled.dtype)





def compute_residual_gate(
    token_to_chunk: torch.Tensor,
    chunk_mask_locations: torch.Tensor,
    pad_locations: torch.Tensor,
) -> torch.Tensor:
    """Gate residuals to zero for tokens whose assigned chunk contains a mask.

    The per-token mask score is the batched product of ``token_to_chunk``
    with the float chunk-mask vector; the gate is 1 where that score is
    ``<= 0.5`` and the token is not padding, 0 otherwise.
    """
    # Column vector per batch item so it can be batch-multiplied.
    chunk_flags = chunk_mask_locations.float().unsqueeze(-1)
    mask_score = torch.bmm(token_to_chunk, chunk_flags).squeeze(-1)

    out_dtype = mask_score.dtype
    is_open = (mask_score <= 0.5).to(out_dtype)
    not_pad = (~pad_locations).to(out_dtype)
    return is_open * not_pad





class BiMambaWrapper(nn.Module):
    """
    Bidirectional Mamba mixer that handles padding in the reverse pass.

    Runs one Mamba over the sequence as given and a second Mamba over the
    sequence reversed, then combines the two outputs by addition or by
    element-wise multiplication.
    """

    def __init__(
            self,
            d_model: int,
            bidirectional: bool = True,
            bidirectional_strategy: Optional[str] = "add",
            bidirectional_weight_tie: bool = True,
            **mamba_kwargs,
    ):
        super().__init__()
        if bidirectional_strategy not in ("add", "ew_multiply"):
            raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")

        # NOTE: the `bidirectional` argument is not consulted -- the flag is
        # always set to True and both directions are always built.
        self.bidirectional = True
        self.bidirectional_strategy = bidirectional_strategy

        self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)
        self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)

        if bidirectional_weight_tie:
            # Share the input/output projections between the two directions.
            for proj_name in ("in_proj", "out_proj"):
                shared = getattr(self.mamba_fwd, proj_name)
                tied = getattr(self.mamba_rev, proj_name)
                tied.weight = shared.weight
                tied.bias = shared.bias

    def forward(self, hidden_states, inference_params=None, padding_mask=None):
        """
        Bidirectional forward pass with padding handling.

        Args:
            hidden_states (torch.Tensor): (B, L, D) input tensor.
            inference_params: Passed through unchanged to both Mamba modules.
            padding_mask (torch.Tensor): (B, L) boolean mask where True indicates padding.
                The reverse pass reverses the first ``n_valid`` positions of each
                row, so it assumes padding sits at the end of the sequence.
        Returns:
            torch.Tensor: Output tensor of the same shape as hidden_states.
        """
        if padding_mask is None:
            out_fwd = self.mamba_fwd(hidden_states, inference_params=inference_params)
            flipped = hidden_states.flip(dims=(1,))
            out_rev = self.mamba_rev(flipped, inference_params=inference_params).flip(dims=(1,))
        else:
            keep = (~padding_mask).unsqueeze(-1)
            masked_input = hidden_states * keep
            out_fwd = self.mamba_fwd(masked_input, inference_params=inference_params)

            _, seq_len, _ = hidden_states.shape
            positions = torch.arange(seq_len, device=hidden_states.device)
            n_valid = (~padding_mask).sum(dim=1).unsqueeze(1)

            # Mirror indices inside the valid prefix, identity elsewhere.
            mirror = torch.where(positions < n_valid, n_valid - 1 - positions, positions)
            mirror = mirror.unsqueeze(-1).expand_as(hidden_states)

            reversed_input = torch.gather(masked_input, 1, mirror)
            reversed_output = self.mamba_rev(reversed_input, inference_params=inference_params)

            # Scatter with the same indices to put outputs back in input order.
            out_rev = torch.zeros_like(
                hidden_states, dtype=reversed_output.dtype, device=reversed_output.device
            )
            out_rev.scatter_(1, mirror, reversed_output)

            # Zero whatever the mixers produced at padded positions.
            out_fwd = out_fwd * keep
            out_rev = out_rev * keep

        if self.bidirectional_strategy == "add":
            return out_fwd + out_rev
        if self.bidirectional_strategy == "ew_multiply":
            return out_fwd * out_rev
        raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")





class DNAChunkerEmbeddings(nn.Module):
    """Token embedding table sized ``config.vocab_size`` x ``config.d_model``."""

    def __init__(
            self,
            config: DNAChunkerConfig,
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.d_model,
            device=device,
            dtype=dtype,
        )

    def forward(self, input_ids):
        """
            input_ids: (batch, seqlen)

            Returns the embedding vectors looked up for ``input_ids``.
        """
        token_vectors = self.word_embeddings(input_ids)
        return token_vectors





class RotaryEmbedding(nn.Module):
    """Rotary position embedding with base 10000 and optional custom positions."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))

    def forward(self, q, k, positions=None):
        """
        Rotate queries and keys by position-dependent angles.

        Args:
            q (torch.Tensor): Query tensor (batch, n_heads, seq_len, head_dim)
            k (torch.Tensor): Key tensor (batch, n_heads, seq_len, head_dim)
            positions (torch.Tensor, optional): Custom positions (batch, seq_len).
                                                 If None, standard sequential positions are used.

        Returns:
            Tuple ``(q_rot, k_rot)`` of the rotated tensors.
        """
        freq_dtype = self.inv_freq.dtype

        if positions is None:
            # Shared sequential positions: angles are (seq_len, head_dim),
            # broadcast over batch and heads.
            t = torch.arange(q.shape[2], device=q.device, dtype=freq_dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            angles = torch.cat((freqs, freqs), dim=-1)[None, None]
        else:
            # Per-item positions: angles are (batch, seq_len, head_dim),
            # broadcast over heads only.
            t = positions.to(dtype=freq_dtype)
            freqs = torch.einsum('bi,j->bij', t, self.inv_freq)
            angles = torch.cat((freqs, freqs), dim=-1)[:, None]

        cos = angles.cos()
        sin = angles.sin()

        q_rot = (q * cos) + (rotate_half(q) * sin)
        k_rot = (k * cos) + (rotate_half(k) * sin)
        return q_rot, k_rot



class RotarySelfAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary embeddings and QK-norm.

    Queries and keys are rotated by ``rotary_emb`` and then RMS-normalised per
    head before scaled dot-product attention.
    """

    def __init__(self, config: DNAChunkerConfig, rotary_emb: RotaryEmbedding):
        super().__init__()
        self.d_model = config.d_model
        self.n_head = config.transformer_n_head
        self.d_head = self.d_model // self.n_head
        assert self.d_model % self.n_head == 0, "d_model must be divisible by n_head"

        self.rotary_emb = rotary_emb
        self.qkv_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.dropout_p = float(getattr(config, "dropout", 0.0))

        # Per-head normalisation of queries and keys.
        self.q_norm = RMSNorm(hidden_size=self.d_head)
        self.k_norm = RMSNorm(hidden_size=self.d_head)

    def forward(self, x, attention_mask=None, positions=None):
        """Apply self-attention.

        Args:
            x: (B, L, D) input.
            attention_mask: Optional padding mask, truthy at positions to
                block; 2-D masks are expanded to (B, 1, 1, L).
            positions: Optional (B, L) custom positions for the rotary
                embedding; sequential positions are used when None.

        Returns:
            (B, L, D) tensor after the output projection.
        """
        B, L, D = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # (B, L, D) -> (B, n_head, L, d_head)
        q = q.view(B, L, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, L, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, L, self.n_head, self.d_head).transpose(1, 2)

        q, k = self.rotary_emb(q, k, positions=positions)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Boolean padding mask -> additive bias in q's dtype.
            attention_mask = create_attention_mask(attention_mask, q)

        # The functional SDPA applies dropout whenever dropout_p > 0, with no
        # knowledge of module mode, so it must be disabled explicitly in eval.
        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask.to(q.dtype) if attention_mask is not None else None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False,
            scale=None,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(attn_output)



class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: rotary self-attention followed by a SwiGLU FFN.

    Both sub-layers are wrapped in a residual connection and preceded by a
    LayerNorm.
    """

    def __init__(self, config: DNAChunkerConfig, rotary_emb: RotaryEmbedding):
        super().__init__()
        width = config.d_model

        self.norm1 = nn.LayerNorm(width)
        self.attn = RotarySelfAttention(config, rotary_emb)
        self.norm2 = nn.LayerNorm(width)

        # FFN size is a multiple of the model width.
        self.ffn = SwiGLU(width, width * config.transformer_mlp_mult, bias=False)

    def forward(self, x, attention_mask=None, positions=None):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_update = self.attn(self.norm1(x), attention_mask=attention_mask, positions=positions)
        x = x + attn_update

        ffn_update = self.ffn(self.norm2(x))
        return x + ffn_update



class BidirectionalGatedSmoothing(nn.Module):
    """Average of a forward and a backward gated selective scan.

    The only learnable parameter is ``A`` (shape ``(d_model, 1)``, initialised
    to -1).  The gate, clamped to [0, 1], is fed to the scan as its step size.
    """

    def __init__(self, d_model: int):
        super().__init__()
        if selective_scan_fn is None:
            raise RuntimeError("selective_scan_fn is required for BidirectionalGatedSmoothing but was not found.")
        self.d_model = d_model
        self.A = nn.Parameter(-torch.ones(d_model, 1))

    def _apply_scan(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Run one selective scan over ``x`` (B, L, D) with per-position gate ``p`` (B, L)."""
        n_batch, n_steps, n_chan = x.shape

        # Same step size for every channel of a position.
        step = p.to(x.dtype).clamp(min=0.0, max=1.0).unsqueeze(-1).expand(n_batch, n_steps, n_chan)

        # Constant all-ones B and C matrices.
        b_mat = torch.ones(n_chan, 1, device=x.device, dtype=torch.float32)
        c_mat = torch.ones(n_chan, 1, device=x.device, dtype=torch.float32)

        # The scan is called with axes 1 and 2 swapped; swap back afterwards.
        scanned = selective_scan_fn(
            x.transpose(1, 2).contiguous(),
            step.transpose(1, 2).contiguous(),
            self.A.to(dtype=torch.float32),
            b_mat,
            c_mat,
            None,
            None,
            None,
            delta_softplus=True,
            return_last_state=False,
        )
        return scanned.transpose(1, 2)

    def forward(self, z_main_upsampled, c, pad_mask=None):
        """Smooth ``z_main_upsampled`` in both directions and average.

        Args:
            z_main_upsampled: (B, L, D) input.
            c: (B, L) gate values; clamped to [0, 1].
            pad_mask: Optional (B, L) boolean mask, True at padding.  When
                given, padded inputs are zeroed, the reverse scan runs over
                the reversed valid prefix, and padded outputs are zeroed.
        """
        z = z_main_upsampled
        gate = c.to(z.dtype).clamp(0.0, 1.0)

        if pad_mask is not None:
            z = z.masked_fill(pad_mask.unsqueeze(-1), 0.0)
            gate = gate.masked_fill(pad_mask, 1.0)

            h_fwd = self._apply_scan(z, gate)

            # Reverse the valid prefix, scan, then reverse the result back.
            z_rev = reverse_padded(z, pad_mask)
            g_rev = reverse_padded(gate, pad_mask)
            h_rev = reverse_padded(self._apply_scan(z_rev, g_rev), pad_mask)

            averaged = 0.5 * (h_fwd + h_rev)
            return averaged.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        h_fwd = self._apply_scan(z, gate)
        h_rev = self._apply_scan(z.flip(1), gate.flip(1)).flip(1)
        return 0.5 * (h_fwd + h_rev)







class DNAChunkerMixerModel(nn.Module):
    """Two-stage hierarchical chunking backbone.

    The forward pass runs, in order:

    1. token embedding followed by the stage-1 ``ConvEncoderBlock`` stack;
    2. stage-1 routing and downsampling, the stage-2 ``ConvEncoderBlock``
       stack, then stage-2 routing and downsampling;
    3. the ``TransformerBlock`` stack on the twice-downsampled sequence;
    4. two rounds of ``splat_upsample`` + ``BidirectionalGatedSmoothing`` +
       decoder blocks, each with a gated skip connection from the encoder
       output of the matching resolution;
    5. a final normalisation (fused or unfused, per ``config.fused_add_norm``).

    Besides the hidden states, ``forward`` returns per-stage ratio losses,
    compression ratios, padding masks and routing statistics.
    """

    # Token ids compared against ``input_ids`` to locate masked / padded positions.
    MASK_TOKEN_ID = 9
    PAD_TOKEN_ID = 8

    def __init__(self, config: DNAChunkerConfig, device=None, dtype=None, **kwargs) -> None:
        """Build all sub-modules from ``config``.

        Args:
            config: model configuration.
            device: forwarded to sub-module constructors that accept it.
            dtype: forwarded to sub-module constructors that accept it.
            **kwargs: accepted for call-site compatibility; not used here.

        Raises:
            ValueError: if ``config.tokenizer_type`` is not ``"default"``.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.config = config
        self.fused_add_norm = config.fused_add_norm
        self.residual_in_fp32 = config.residual_in_fp32

        self.tokenizer_type = config.tokenizer_type
        if config.tokenizer_type != "default":
            raise ValueError(f"Invalid tokenizer type: {config.tokenizer_type}")
        self.embeddings = DNAChunkerEmbeddings(config, **factory_kwargs)

        self.routing_module_stage1 = RoutingModule(config.d_model)
        self.routing_module_stage2 = RoutingModule(config.d_model)

        self.downsampler_s1 = Downsampler()
        self.downsampler_s2 = Downsampler()

        # dechunker2 is applied first in forward() (coarsest -> middle
        # resolution); dechunker1 second (middle -> full resolution).
        self.dechunker1 = BidirectionalGatedSmoothing(config.d_model)
        self.dechunker2 = BidirectionalGatedSmoothing(config.d_model)

        self.target_ratio_stage1 = getattr(config, "target_ratio_stage1", 0.5)
        self.target_ratio_stage2 = getattr(config, "target_ratio_stage2", 0.3)
        self.gumbel_temperature = 1.0

        # Encoder / decoder layers are split between the two stages; when the
        # count is odd the second stack of each pair gets the extra layer.
        n_enc1 = config.n_enc_layer // 2
        n_enc2 = config.n_enc_layer - n_enc1
        n_dec1 = config.n_dec_layer // 2
        n_dec2 = config.n_dec_layer - n_dec1

        conv_expansion = int(getattr(config, "encoder_conv_expansion", 2))
        conv_kernel = int(getattr(config, "encoder_conv_kernel", 5))

        def make_encoder_block():
            return ConvEncoderBlock(
                config.d_model,
                expansion=conv_expansion,
                kernel_size=conv_kernel,
                norm_epsilon=config.norm_epsilon,
                rms_norm=config.rms_norm,
                **factory_kwargs,
            )

        self.encoder1_layers = nn.ModuleList([make_encoder_block() for _ in range(n_enc1)])
        self.encoder2_layers = nn.ModuleList([make_encoder_block() for _ in range(n_enc2)])

        # A single rotary embedding instance is shared by every transformer block.
        d_head = config.d_model // config.transformer_n_head
        rotary_emb = RotaryEmbedding(dim=d_head)
        self.main_model = nn.ModuleList(
            [TransformerBlock(config, rotary_emb) for _ in range(config.n_main_layer)]
        )

        def make_decoder_block(layer_idx):
            return create_block(
                config.d_model,
                ssm_cfg=config.ssm_cfg,
                norm_epsilon=config.norm_epsilon,
                rms_norm=config.rms_norm,
                residual_in_fp32=config.residual_in_fp32,
                fused_add_norm=config.fused_add_norm,
                layer_idx=layer_idx,
                bidirectional=config.bidirectional,
                bidirectional_strategy=config.bidirectional_strategy,
                bidirectional_weight_tie=config.bidirectional_weight_tie,
                **factory_kwargs,
            )

        # Decoder layer indices continue after the encoder and main layers.
        first_decoder_idx = config.n_enc_layer + config.n_main_layer
        self.decoder1_layers = nn.ModuleList(
            [make_decoder_block(first_decoder_idx + i) for i in range(n_dec1)]
        )
        self.decoder2_layers = nn.ModuleList(
            [make_decoder_block(first_decoder_idx + n_dec1 + i) for i in range(n_dec2)]
        )

        norm_layer = RMSNorm if RMSNorm is not None and config.rms_norm else nn.LayerNorm
        self.norm_f = norm_layer(config.d_model, eps=config.norm_epsilon, **factory_kwargs)

    @staticmethod
    def _raise_if_routing_not_finite(stage, probs, hard_boundaries):
        """Raise ValueError if the routing outputs of ``stage`` contain NaN or Inf."""
        p_nan, b_nan = torch.isnan(probs).any(), torch.isnan(hard_boundaries).any()
        if p_nan or b_nan:
            raise ValueError(f"Stage {stage} routing produced NaN! p_stage{stage} NaN: {p_nan}, "
                             f"b_stage{stage} NaN: {b_nan}")
        p_inf, b_inf = torch.isinf(probs).any(), torch.isinf(hard_boundaries).any()
        if p_inf or b_inf:
            raise ValueError(f"Stage {stage} routing produced Inf! p_stage{stage} Inf: {p_inf}, "
                             f"b_stage{stage} Inf: {b_inf}")

    @staticmethod
    def _dechunk(dechunker, z_coarse, probs, hard_boundaries, pad_mask):
        """Upsample ``z_coarse`` to the finer resolution and smooth it with ``dechunker``."""
        upsampled = splat_upsample(z_coarse, hard_boundaries, pad_mask)

        # Probability assigned to the decision actually taken at each position;
        # padded positions are set to 1.
        confidence = torch.where(hard_boundaries.bool(), probs, 1.0 - probs)
        confidence = confidence.masked_fill(pad_mask, 1.0)

        # Equals 1 in the forward pass (c + (1 - c)); only the first term
        # carries gradient, so the routing probabilities still receive one.
        straight_through = confidence + (1.0 - confidence).detach()
        upsampled = upsampled * straight_through.unsqueeze(-1)
        upsampled = upsampled.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        return dechunker(upsampled, confidence, pad_mask=pad_mask)

    @staticmethod
    def _run_gated_decoder(layers, hidden, skip, gate, pad_mask):
        """Run ``layers`` while multiplying the residual stream by ``gate`` before each layer.

        The residual stream starts as ``(hidden + skip) * gate``. Returns the
        final ``(hidden, residual)``; ``residual`` is None when ``layers`` is empty.
        """
        residual = None
        for layer in layers:
            if residual is None:
                residual = (hidden + skip) * gate
            else:
                residual = residual * gate
            hidden, residual = layer(hidden, residual, padding_mask=pad_mask)
        return hidden, residual

    @staticmethod
    def _ratio_loss(hard_boundaries, probs, pad_mask, target_ratio, eps):
        """Compute the ratio loss for one stage.

        With ``F`` the per-sequence fraction of selected boundaries, ``G`` the
        per-sequence mean boundary probability (both over non-pad positions)
        and ``alpha`` the target ratio clamped to ``[eps, 1 - eps]``::

            loss = 1 / (1 - alpha) * ((1 - alpha) / alpha * F * G + (1 - F) * (1 - G))

        Returns:
            ``(loss averaged over the batch, per-sequence non-pad length)``.
        """
        valid = (~pad_mask).float()
        unpadded_length = valid.sum(dim=1).clamp(min=eps)
        frac_selected = (hard_boundaries * valid).sum(dim=1) / unpadded_length
        mean_prob = (probs * valid).sum(dim=1) / unpadded_length
        alpha = torch.as_tensor(
            target_ratio, device=mean_prob.device, dtype=mean_prob.dtype
        ).clamp(min=eps, max=1 - eps)
        loss = (1.0 / (1.0 - alpha)) * (
            ((1.0 - alpha) / alpha) * frac_selected * mean_prob
            + (1.0 - frac_selected) * (1.0 - mean_prob)
        )
        return loss.mean(), unpadded_length

    @staticmethod
    def _masked_mean_std(values: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mean and population standard deviation of ``values[mask]``; zeros if nothing is selected."""
        selected = values[mask]
        if selected.numel() == 0:
            zero = values.new_zeros(())
            return zero, zero
        return selected.mean(), torch.sqrt(selected.var(unbiased=False))

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False, boundaries=None):
        """DNAChunker forward pass using hard downsampling and DNAChunker-style dechunking.

        Args:
            input_ids: token ids. Always required, because the mask and pad
                locations are derived from them (``MASK_TOKEN_ID`` /
                ``PAD_TOKEN_ID``) even when ``inputs_embeds`` is supplied.
            inputs_embeds: optional precomputed embeddings used instead of
                ``self.embeddings(input_ids)``.
            output_hidden_states: when True, detached copies of eight
                intermediate tensors are collected (see Returns).
            boundaries: accepted for interface compatibility; not used.

        Returns:
            A 5-tuple ``(hidden_states, all_hidden_states, losses, padding_masks,
            chunk_stats)`` where

            * ``hidden_states`` is the normalised full-resolution output;
            * ``all_hidden_states`` is None, or a tuple holding (in order) the
              stage-1 encoder output, the stage-1 downsampled sequence, the
              stage-2 encoder output, the stage-2 downsampled sequence, the main
              model output, the first decoder output, the stage-1 dechunked
              sequence and the final normalised output;
            * ``losses`` is ``(ratio_loss, ratio_loss_s1, ratio_loss_s2,
              compression_ratio_stage1, compression_ratio_stage2)``;
            * ``padding_masks`` maps ``'l0'`` / ``'l2'`` to the pad masks of the
              full-resolution and twice-downsampled sequences;
            * ``chunk_stats`` maps statistic names to scalar tensors.

        Raises:
            ValueError: if ``input_ids`` is None, or if NaN/Inf values are found
                in the stage-1 encoder output or in either stage's routing outputs.
        """
        if input_ids is None:
            # Previously `None == 9` silently produced Python bools here and the
            # failure surfaced much later with an unrelated error.
            raise ValueError(
                "input_ids is required: mask and pad locations are derived from it, "
                "even when inputs_embeds is provided."
            )

        all_hidden_states = []

        def record(tensor):
            if output_hidden_states:
                all_hidden_states.append(tensor.clone().detach())

        hidden_states = inputs_embeds if inputs_embeds is not None else self.embeddings(input_ids)

        mask_locations = input_ids == self.MASK_TOKEN_ID
        pad_locations = input_ids == self.PAD_TOKEN_ID

        # ---- Stage-1 encoder (full resolution) ----
        residual_after_enc1 = None
        for layer in self.encoder1_layers:
            hidden_states, residual_after_enc1 = layer(
                hidden_states, residual_after_enc1, padding_mask=pad_locations
            )
        x_hat_enc1 = hidden_states

        if torch.isnan(x_hat_enc1).any():
            raise ValueError("Encoder1 produced NaN! Check model initialization or gradients.")
        if torch.isinf(x_hat_enc1).any():
            raise ValueError("Encoder1 produced Inf! Check for numerical overflow.")
        record(x_hat_enc1)

        # ---- Stage-1 routing + downsampling ----
        p_stage1, b_stage1, _ = self.routing_module_stage1(
            x_hat_enc1, mask_locations, pad_locations,
            temperature=self.gumbel_temperature, return_entropy=False,
            enforce_mask_boundaries=False
        )
        self._raise_if_routing_not_finite(1, p_stage1, b_stage1)

        (x_s1, chunk_lengths_s1, mask_locations_s1, pad_locations_s1,
         original_indices_l0_s1, token_to_chunk_s1, segment_confidence_s1) = self.downsampler_s1(
            x_hat_enc1, p_stage1, b_stage1, mask_locations, pad_locations
        )
        record(x_s1)

        # ---- Stage-2 encoder (once-downsampled resolution) ----
        hidden_states_s1 = x_s1
        residual_after_enc2 = None
        for layer in self.encoder2_layers:
            hidden_states_s1, residual_after_enc2 = layer(
                hidden_states_s1, residual_after_enc2, padding_mask=pad_locations_s1
            )
        x_hat_enc2 = hidden_states_s1
        record(x_hat_enc2)

        # ---- Stage-2 routing + downsampling ----
        p_stage2, b_stage2, _ = self.routing_module_stage2(
            x_hat_enc2, mask_locations_s1, pad_locations_s1,
            temperature=self.gumbel_temperature, return_entropy=False,
            enforce_mask_boundaries=False
        )
        self._raise_if_routing_not_finite(2, p_stage2, b_stage2)

        (x_s2, chunk_lengths_s2, mask_locations_s2, pad_locations_s2,
         original_indices_l0_s2, token_to_chunk_s2, segment_confidence_s2) = self.downsampler_s2(
            x_hat_enc2, p_stage2, b_stage2, mask_locations_s1, pad_locations_s1,
            input_original_indices_l0=original_indices_l0_s1
        )
        record(x_s2)

        # ---- Main transformer (twice-downsampled resolution) ----
        main_hidden_states = x_s2
        for layer in self.main_model:
            main_hidden_states = layer(
                main_hidden_states,
                attention_mask=pad_locations_s2,
                positions=original_indices_l0_s2,
            )
        z_hat_s2 = main_hidden_states
        record(z_hat_s2)

        # ---- Dechunk stage 2 -> stage 1 resolution, then first decoder ----
        z_dechunked_s1 = self._dechunk(self.dechunker2, z_hat_s2, p_stage2, b_stage2, pad_locations_s1)

        residual_gate_s1 = compute_residual_gate(
            token_to_chunk_s2, mask_locations_s2, pad_locations_s1
        ).unsqueeze(-1)
        residual_gate_s1 = residual_gate_s1.to(z_dechunked_s1.dtype)
        z_dechunked_s1, _ = self._run_gated_decoder(
            self.decoder1_layers, z_dechunked_s1, x_hat_enc2, residual_gate_s1, pad_locations_s1
        )
        record(z_dechunked_s1)

        # ---- Dechunk stage 1 -> full resolution, then second decoder ----
        z_dechunked = self._dechunk(self.dechunker1, z_dechunked_s1, p_stage1, b_stage1, pad_locations)
        record(z_dechunked)

        residual_gate_s0 = compute_residual_gate(
            token_to_chunk_s1, mask_locations_s1, pad_locations
        ).unsqueeze(-1)
        residual_gate_s0 = residual_gate_s0.to(z_dechunked.dtype)
        hidden_states, residual = self._run_gated_decoder(
            self.decoder2_layers, z_dechunked, x_hat_enc1, residual_gate_s0, pad_locations
        )
        record(hidden_states)

        # ---- Ratio losses ----
        eps = torch.finfo(hidden_states.dtype).eps
        ratio_loss_stage1, L0_unpadded_length = self._ratio_loss(
            b_stage1, p_stage1, pad_locations, self.target_ratio_stage1, eps
        )
        ratio_loss_stage2, L1_unpadded_length = self._ratio_loss(
            b_stage2, p_stage2, pad_locations_s1, self.target_ratio_stage2, eps
        )

        ratio_loss_weight = getattr(self.config, 'ratio_loss_weight', 1.0)
        ratio_loss_s1 = ratio_loss_weight * ratio_loss_stage1
        ratio_loss_s2 = ratio_loss_weight * ratio_loss_stage2
        ratio_loss = ratio_loss_s1 + ratio_loss_s2

        # ---- Final normalisation ----
        if not self.fused_add_norm:
            final_residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(final_residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # The last recorded entry is replaced by its normalised version.
        if output_hidden_states:
            all_hidden_states[-1] = hidden_states.clone().detach()

        # ---- Routing statistics (no gradient) ----
        with torch.no_grad():
            valid_mask_s1 = ~pad_locations
            valid_mask_s2 = ~pad_locations_s1
            p_stage1_f = p_stage1.float()
            p_stage2_f = p_stage2.float()

            boundary_prob_mean_stage1, boundary_prob_std_stage1 = self._masked_mean_std(p_stage1_f, valid_mask_s1)
            boundary_prob_mean_stage2, boundary_prob_std_stage2 = self._masked_mean_std(p_stage2_f, valid_mask_s2)

            # (2p - 1)^2 is 0 at p = 0.5 and 1 at p in {0, 1}.
            confidence_stage1 = (2 * p_stage1_f - 1).pow(2)
            confidence_stage2 = (2 * p_stage2_f - 1).pow(2)
            confidence_mean_stage1, confidence_std_stage1 = self._masked_mean_std(confidence_stage1, valid_mask_s1)
            confidence_mean_stage2, confidence_std_stage2 = self._masked_mean_std(confidence_stage2, valid_mask_s2)

        compression_ratio_stage1 = chunk_lengths_s1.float().mean() / L0_unpadded_length.mean()
        compression_ratio_stage2 = chunk_lengths_s2.float().mean() / L1_unpadded_length.mean()

        final_all_hidden_states = tuple(all_hidden_states) if output_hidden_states else None

        padding_masks = {
            'l0': pad_locations,
            'l2': pad_locations_s2,
        }

        chunk_stats = {
            "boundary_prob_mean_stage1": boundary_prob_mean_stage1,
            "boundary_prob_std_stage1": boundary_prob_std_stage1,
            "boundary_prob_mean_stage2": boundary_prob_mean_stage2,
            "boundary_prob_std_stage2": boundary_prob_std_stage2,
            "confidence_mean_stage1": confidence_mean_stage1,
            "confidence_std_stage1": confidence_std_stage1,
            "confidence_mean_stage2": confidence_mean_stage2,
            "confidence_std_stage2": confidence_std_stage2,
            "segment_confidence_mean_stage1": segment_confidence_s1.mean(),
            "segment_confidence_mean_stage2": segment_confidence_s2.mean(),
        }

        return (
            hidden_states,
            final_all_hidden_states,
            (ratio_loss, ratio_loss_s1, ratio_loss_s2, compression_ratio_stage1, compression_ratio_stage2),
            padding_masks,
            chunk_stats,
        )



def cross_entropy(logits, y, ignore_index=-100):
    """Mean cross-entropy over every position whose label is not ``ignore_index``.

    Args:
        logits: tensor whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        y: integer class labels, one per flattened logit row.
        ignore_index: label value excluded from the loss. ``None`` is treated
            as -100.

    Returns:
        Scalar tensor with the mean loss over the non-ignored positions.
    """
    if ignore_index is None:
        ignore_index = -100
    # reshape() rather than view(): view() raises on non-contiguous inputs
    # (e.g. transposed or sliced tensors), reshape() copies only when needed.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = y.reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index)





def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).

    Per-position losses are combined with ``loss_weights`` normalised to sum
    to one; positions whose label equals ``ignore_index`` get weight zero.

    Args:
        logits: tensor whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        y: integer class labels, one per flattened logit row.
        loss_weights: per-position weights with as many elements as ``y``.
            The caller's tensor is left unmodified.
        ignore_index: label value excluded from the loss. ``None`` is treated
            as -100, matching ``cross_entropy``.

    Returns:
        Scalar tensor with the weighted loss. When the remaining weights sum
        to zero (e.g. every label is ignored) the normaliser is replaced by 1,
        so an all-ignored batch yields 0 instead of NaN.
    """
    if ignore_index is None:
        ignore_index = -100
    # reshape() rather than view() so non-contiguous inputs are accepted.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = y.reshape(-1)
    per_position = F.cross_entropy(
        flat_logits, flat_labels, ignore_index=ignore_index, reduction="none"
    )

    # Out-of-place masked_fill: an in-place write on the flattened view would
    # silently zero entries of the caller's tensor.
    weights = loss_weights.reshape(-1).masked_fill(flat_labels == ignore_index, 0.0)

    total = weights.sum()
    # Guard the 0 / 0 case.
    total = torch.where(total == 0, torch.ones_like(total), total)

    return (per_position * (weights / total)).sum()





class DNAChunkerPreTrainedModel(PreTrainedModel):
    """Common HF ``PreTrainedModel`` base for the DNAChunker models.

    Declares the config class and base-model prefix and provides the weight
    initialisation used by subclasses.
    """

    config_class = DNAChunkerConfig
    base_model_prefix = "DNAChunker_transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = ["BiMambaWrapper", "TransformerBlock"]

    def _init_weights(
            self,
            module,
            initializer_range=0.02,
            **kwargs,
    ):
        """Initialise ``module``.

        Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

        * ``nn.Linear`` biases are zeroed unless flagged with ``_no_reinit``.
        * ``nn.Embedding`` weights are drawn from a normal distribution with
          standard deviation ``initializer_range``.
        * Unless disabled, parameters of ``module`` named exactly
          ``out_proj.weight``, ``fc2.weight`` or ``mlp.2.weight`` (names are
          relative to ``module``) are re-initialised with Kaiming-uniform and
          divided by ``sqrt(n_residuals_per_layer * n_layer)``.

        ``config.initializer_cfg`` may override ``rescale_prenorm_residual``,
        ``initializer_range`` and ``n_residuals_per_layer``.

        Args:
            module: the sub-module to initialise.
            initializer_range: default std for embedding weights.
            **kwargs: ignored.
        """
        # NOTE(review): reads config.n_layer on every call - confirm the config defines it.
        n_layer = self.config.n_layer

        overrides = self.config.initializer_cfg
        if overrides is None:
            overrides = {}
        rescale_prenorm_residual = overrides.get("rescale_prenorm_residual", True)
        initializer_range = overrides.get("initializer_range", initializer_range)
        n_residuals_per_layer = overrides.get("n_residuals_per_layer", 1)

        if isinstance(module, nn.Linear):
            bias = module.bias
            if bias is not None and not getattr(bias, "_no_reinit", False):
                nn.init.zeros_(bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        if not rescale_prenorm_residual:
            return

        rescaled_names = ("out_proj.weight", "fc2.weight", "mlp.2.weight")
        for param_name, param in module.named_parameters():
            if param_name not in rescaled_names:
                continue
            nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            with torch.no_grad():
                param /= math.sqrt(n_residuals_per_layer * n_layer)





class DNAChunkerTransformer(DNAChunkerPreTrainedModel):
    """HF-style wrapper that owns the ``DNAChunkerMixerModel`` backbone."""

    def __init__(self, config: DNAChunkerConfig, device=None, dtype=None, **kwargs):
        """Round ``config.vocab_size`` up to the padding multiple and build the backbone.

        Note that ``config.vocab_size`` is updated in place on the passed config.
        """
        super().__init__(config)

        # Round the vocabulary size up to the next multiple of pad_vocab_size_multiple.
        multiple = config.pad_vocab_size_multiple
        leftover = config.vocab_size % multiple
        if leftover != 0:
            config.vocab_size += multiple - leftover

        self.config = config
        self.backbone = DNAChunkerMixerModel(config, device=device, dtype=dtype, **kwargs)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            boundaries: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, Tuple, DNAChunkerTransformerModelOutput]:
        """HF-compatible forward method.

        Args:
            input_ids: token ids passed to the backbone.
            inputs_embeds: optional precomputed embeddings passed to the backbone.
            output_hidden_states: falls back to ``config.output_hidden_states`` when None.
            return_dict: falls back to ``config.use_return_dict`` when None.
            boundaries: passed through to the backbone.

        Returns:
            With ``return_dict``: a ``DNAChunkerTransformerModelOutput``.
            Otherwise a tuple ``(hidden_states, [all_hidden_states,] ratio_loss,
            ratio_loss_s1, ratio_loss_s2, compression_ratio_stage1,
            compression_ratio_stage2, padding_masks)``, where the bracketed
            entry is present only when ``output_hidden_states`` is true.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            boundaries=boundaries,
        )
        hidden_states, all_hidden_states, loss_terms, padding_masks, chunk_stats = backbone_out
        (ratio_loss, ratio_loss_s1, ratio_loss_s2,
         compression_ratio_stage1, compression_ratio_stage2) = loss_terms

        if not return_dict:
            result = (hidden_states,)
            if output_hidden_states:
                result = result + (all_hidden_states,)
            return result + (
                ratio_loss,
                ratio_loss_s1,
                ratio_loss_s2,
                compression_ratio_stage1,
                compression_ratio_stage2,
                padding_masks,
            )

        # Statistics are looked up with .get(), so any name the backbone does
        # not provide is forwarded as None.
        # NOTE(review): the segment_length_* names are not among the keys of
        # the chunk_stats dict built in DNAChunkerMixerModel.forward - confirm
        # whether they are meant to be populated.
        stat_names = (
            "boundary_prob_mean_stage1",
            "boundary_prob_std_stage1",
            "boundary_prob_mean_stage2",
            "boundary_prob_std_stage2",
            "confidence_mean_stage1",
            "confidence_std_stage1",
            "confidence_mean_stage2",
            "confidence_std_stage2",
            "segment_length_mean_stage1",
            "segment_length_std_stage1",
            "segment_length_mean_stage2",
            "segment_length_std_stage2",
        )
        stats = {stat: chunk_stats.get(stat) for stat in stat_names}

        return DNAChunkerTransformerModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states if output_hidden_states else None,
            ratio_loss=ratio_loss,
            ratio_loss_s1=ratio_loss_s1,
            ratio_loss_s2=ratio_loss_s2,
            compression_ratio_stage1=compression_ratio_stage1,
            compression_ratio_stage2=compression_ratio_stage2,
            **stats,
        )





class DNAChunkerTransformerForMaskedLM(DNAChunkerPreTrainedModel):
    """HF-compatible DNAChunker model for masked language modeling.

    Wraps a ``DNAChunkerTransformer`` (``self.net``) and adds a bias-free
    linear head projecting ``d_model`` to ``vocab_size``.
    """

    def __init__(self, config: DNAChunkerConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        self.net = DNAChunkerTransformer(config, device=device, dtype=dtype, **kwargs)
        self.lm_head = nn.Linear(
            config.d_model,
            self.config.vocab_size,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's word-embedding module."""
        return self.net.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's word-embedding module."""
        self.net.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Override the default tie_weights behavior. Only tie weights if
        we are NOT using the composite STFT embedding.
        """
        if self.net.backbone.tokenizer_type != "stft":
            super().tie_weights()
            return
        print("Weight tying is disabled for STFT embeddings.")

    def get_decoder(self):
        """Return the wrapped ``DNAChunkerTransformer``."""
        return self.net

    def set_decoder(self, decoder):
        """Replace the wrapped ``DNAChunkerTransformer``."""
        self.net = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        boundaries: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, DNAChunkerMaskedLMOutput]:
        """HF-compatible forward method.

        Args:
            input_ids: token ids passed to the wrapped model.
            inputs_embeds: optional precomputed embeddings passed to the wrapped model.
            labels: optional targets; positions equal to -100 are ignored by the loss.
            loss_weights: optional per-position weights; when given, the
                weighted cross-entropy is used instead of the plain one.
            output_hidden_states: falls back to ``config.output_hidden_states`` when None.
            return_dict: falls back to ``config.use_return_dict`` when None.
            boundaries: passed through to the wrapped model.

        Returns:
            With ``return_dict``: a ``DNAChunkerMaskedLMOutput`` whose ``loss``
            and ``mlm_loss`` both hold the MLM loss (None without labels).
            Otherwise ``(logits, *rest)`` where ``rest`` is the wrapped model's
            tuple output minus its first entry, prefixed with the loss when
            labels were given.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.net(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            boundaries=boundaries,
        )

        hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
        # Logits are cast to float32 before the loss is computed.
        logits = self.lm_head(hidden_states).float()

        mlm_loss = None
        if labels is not None:
            if loss_weights is None:
                mlm_loss = cross_entropy(logits, labels, ignore_index=-100)
            else:
                mlm_loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=-100)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if mlm_loss is None:
                return output
            return (mlm_loss,) + output

        # Fields copied unchanged from the wrapped model's output.
        passthrough_fields = (
            "hidden_states",
            "ratio_loss",
            "ratio_loss_s1",
            "ratio_loss_s2",
            "compression_ratio_stage1",
            "compression_ratio_stage2",
            "boundary_prob_mean_stage1",
            "boundary_prob_std_stage1",
            "boundary_prob_mean_stage2",
            "boundary_prob_std_stage2",
            "confidence_mean_stage1",
            "confidence_std_stage1",
            "confidence_mean_stage2",
            "confidence_std_stage2",
            "segment_length_mean_stage1",
            "segment_length_std_stage1",
            "segment_length_mean_stage2",
            "segment_length_std_stage2",
        )
        passthrough = {field_name: getattr(outputs, field_name) for field_name in passthrough_fields}

        return DNAChunkerMaskedLMOutput(
            loss=mlm_loss,
            mlm_loss=mlm_loss,
            logits=logits,
            **passthrough,
        )

