# NSP-project/model_src/hf_model.py
from __future__ import annotations
from typing import Optional
from transformers import GPT2Config, GPT2LMHeadModel

def build_causal_lm(
    vocab_size: int,
    max_seq_len: int,
    n_layer: int = 4,
    n_head: int = 4,
    n_embd: int = 256,
    dropout: float = 0.0,
    eos_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
) -> "tuple[GPT2LMHeadModel, GPT2Config]":
    """Build a GPT-2 language-model-head model from explicit hyperparameters.

    A ``GPT2Config`` is assembled from the arguments and a ``GPT2LMHeadModel``
    is constructed from it; no pretrained weights are loaded here.

    Args:
        vocab_size: Passed to the config as ``vocab_size``.
        max_seq_len: Passed to the config as both ``n_positions`` and
            ``n_ctx``.
        n_layer: Passed to the config as ``n_layer``.
        n_head: Passed to the config as ``n_head``. Must be positive and
            divide ``n_embd`` evenly.
        n_embd: Passed to the config as ``n_embd``.
        dropout: Single rate applied to ``resid_pdrop``, ``embd_pdrop`` and
            ``attn_pdrop``. Must lie in ``[0, 1]``.
        eos_token_id: Forwarded to the config unchanged (may be ``None``).
        bos_token_id: Forwarded to the config unchanged (may be ``None``).
        pad_token_id: Forwarded to the config unchanged (may be ``None``).

    Returns:
        A ``(model, cfg)`` tuple: the constructed model and the config it was
        built from.

    Raises:
        ValueError: If ``n_head`` is not positive, ``n_embd`` is not divisible
            by ``n_head``, or ``dropout`` is outside ``[0, 1]``.
    """
    # Fail fast with messages phrased in terms of this function's parameters,
    # instead of letting the problem surface during model construction.
    if n_head <= 0 or n_embd % n_head != 0:
        raise ValueError(
            f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
        )
    if dropout < 0.0 or dropout > 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")

    cfg = GPT2Config(
        vocab_size=vocab_size,
        n_positions=max_seq_len,
        # NOTE(review): n_ctx is kept for compatibility; newer transformers
        # releases may rely on n_positions alone -- confirm before removing.
        n_ctx=max_seq_len,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        resid_pdrop=dropout,
        embd_pdrop=dropout,
        attn_pdrop=dropout,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Tag the config with a loss identifier before the model is built.
    # NOTE(review): confirm "ForCausalLMLoss" is a loss key recognised by the
    # installed transformers version.
    cfg.loss_type = "ForCausalLMLoss"
    model = GPT2LMHeadModel(cfg)
    return model, cfg
