from __future__ import annotations
from typing import Protocol


class BackendHooks(Protocol):
    """Structural interface that every model backend returned by ``get_backend`` satisfies.

    Being a ``typing.Protocol``, conformance is structural: any object that
    provides these four methods qualifies, with no inheritance required.
    The per-method descriptions are inferred from the method names only.
    NOTE(review): confirm them against the concrete SASRec / BERT4Rec
    backend implementations.
    """

    def build_ffn_gate_spec(self, model):
        """Presumably build a gate specification for *model*'s feed-forward (FFN) sublayers."""

    def build_attn_gate_spec(self, model):
        """Presumably build a gate specification for *model*'s attention sublayers."""

    def patch_with_maskable_ffn(self, model):
        """Presumably patch *model* so its FFN outputs can be masked."""

    def patch_with_maskable_attn_out(self, model):
        """Presumably patch *model* so its attention outputs can be masked."""


def get_backend(name: str | None = None) -> BackendHooks:
    """Return a fresh instance of the backend selected by *name*.

    Args:
        name: Backend identifier, matched case-insensitively after stripping
            surrounding whitespace. ``None``, ``""`` or whitespace-only input
            selects the default, ``"sasrec"``.

    Returns:
        A new ``SASRecBackend`` or ``BERT4RecBackend`` instance.

    Raises:
        ValueError: If *name* does not identify a known backend. The message
            reports the value exactly as the caller passed it.
    """
    # Normalise before matching so values such as " SASRec " coming from
    # config files or the command line are accepted.
    key = (name or "").strip().lower() or "sasrec"

    # Imports stay inside each branch so a backend's module is only loaded
    # when that backend is actually requested.
    if key == "sasrec":
        from .sasrec import SASRecBackend

        return SASRecBackend()
    if key == "bert4rec":
        from .bert4rec import BERT4RecBackend

        return BERT4RecBackend()
    raise ValueError(f"Unknown backend {name!r} (expected 'sasrec' or 'bert4rec')")
