from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch

from fla.layers import DeltaNet as FLADeltaNet

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class DeltaNet(FLADeltaNet):
    """Adapter around ``fla.layers.DeltaNet`` with a project-local interface.

    Differences from the parent class, as implemented here:

    * the constructor takes ``dim`` / ``layernorm_eps`` and forwards them to
      the parent as ``hidden_size`` / ``norm_eps``;
    * a ``max_length`` keyword argument is accepted but discarded;
    * ``forward`` returns only the first element of the parent's output
      instead of the full 3-tuple.
    """

    def __init__(
        self,
        dim: int = 1024,
        layernorm_eps: float = 1e-5,
        **kwargs,
    ) -> None:
        """Build the layer.

        Args:
            dim: Passed to the parent constructor as ``hidden_size``.
            layernorm_eps: Passed to the parent constructor as ``norm_eps``.
            **kwargs: Forwarded unchanged to the parent constructor, except
                for ``max_length``, which is removed first.
        """
        # ``max_length`` is dropped before delegating -- presumably because
        # the parent constructor does not accept it (TODO confirm). A default
        # is supplied so callers that never pass ``max_length`` do not hit a
        # KeyError for an argument that is ignored anyway.
        kwargs.pop("max_length", None)
        super().__init__(hidden_size=dim, norm_eps=layernorm_eps, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> torch.Tensor:
        """Run the parent layer and return only its first output.

        All arguments are handed to ``FLADeltaNet.forward`` unchanged and in
        the same positional order as this signature.

        Returns:
            The first element of the parent's 3-tuple result. The remaining
            two elements (presumably attention weights and the updated cache
            -- verify against the parent class) are discarded.
        """
        o, _, _ = super().forward(
            hidden_states,
            attention_mask,
            past_key_values,
            use_cache,
            output_attentions,
            **kwargs,
        )

        return o
