# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch

from fla.layers import GatedLinearAttention as FLAGatedLinearAttention  # type: ignore

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class GatedLinearAttention(FLAGatedLinearAttention):
    """Thin wrapper around ``fla.layers.GatedLinearAttention``.

    The base-class ``forward`` returns a 3-tuple. This subclass unpacks it and
    returns only the first element (the layer output). Callers can then treat
    the module as ``module(hidden_states) -> tensor``.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Construct the layer by forwarding every argument to the fla base class.

        Only keyword arguments are accepted. Positional arguments are rejected
        by this signature.
        """
        super().__init__(**kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> torch.Tensor:
        """Run the base-class forward pass and return only its output tensor.

        Args:
            hidden_states: Input to the layer. The expected shape is defined by
                the fla base class (presumably (batch, seq_len, hidden_size) —
                TODO confirm against fla).
            attention_mask: Optional mask, forwarded unchanged.
            past_key_values: Optional cache object, forwarded unchanged.
            use_cache: Forwarded unchanged to the base class.
            output_attentions: Forwarded unchanged to the base class.
            **kwargs: Any extra keyword arguments, forwarded unchanged.

        Returns:
            The first element of the tuple produced by the base-class forward.
            The other two elements are discarded.
        """
        # Forward the optional arguments by keyword so a difference in the
        # parent's positional order cannot silently misroute them (e.g. a bool
        # landing in ``attention_mask``).
        o, _, _ = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )

        # NOTE(review): the cache returned by the parent is dropped here. With
        # use_cache=True this relies on the parent updating ``past_key_values``
        # in place — confirm against the fla implementation.
        return o
