from typing import Any
from transformers import AutoModelForCausalLM
from peft import PeftModel

from .peft_utils import attach_redflag_lm_row


def load_model_with_redflag_row(
    base_model_name_or_path: str,
    adapter_name_or_path: str,
    token_id: int = 0,
    **model_kwargs: Any,
) -> PeftModel:
    """Load a base model, attach the LM-head row-delta, then load a PEFT adapter.

    The row-delta is attached *before* the adapter is loaded, presumably so
    that the adapter checkpoint finds a matching module to load its saved
    state into. ``token_id`` is only a placeholder. The intent is that the
    adapter state restores the real token index buffer; confirm this against
    ``attach_redflag_lm_row`` and the adapter checkpoint contents.

    Args:
        base_model_name_or_path: HF model repo id or local path of the base model.
        adapter_name_or_path: HF PEFT adapter repo id or local path.
        token_id: Placeholder token id passed to ``attach_redflag_lm_row``.
        **model_kwargs: Forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained``. They are *not* forwarded
            to ``PeftModel.from_pretrained``.

    Returns:
        The base model wrapped by ``PeftModel.from_pretrained``. If attaching
        the row-delta failed, the adapter is still loaded, but the row-delta
        may be missing or only partially attached.

    Notes:
        A failure in ``attach_redflag_lm_row`` does not abort loading. It is
        logged as a warning, with traceback, on this module's logger.
    """
    import logging

    model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, **model_kwargs)
    try:
        attach_redflag_lm_row(model, token_id)
    except Exception:
        # Deliberate best-effort: keep loading even if the row-delta cannot be
        # attached. Record the failure instead of swallowing it silently, so a
        # missing row-delta is diagnosable.
        logging.getLogger(__name__).warning(
            "attach_redflag_lm_row failed for base model %r (token_id=%s); "
            "loading adapter %r without a guaranteed LM-head row-delta",
            base_model_name_or_path,
            token_id,
            adapter_name_or_path,
            exc_info=True,
        )
    return PeftModel.from_pretrained(model, adapter_name_or_path)