from peft import LoraConfig
from transformers import PreTrainedTokenizerBase
from trl import DataCollatorForCompletionOnlyLM

# Default LoRA adapter settings for causal-LM fine-tuning.
# NOTE(review): this is one shared module-level instance; if a consumer
# (e.g. PEFT's model wrapping) mutates the config it receives, the change
# is visible to every later user of this object — confirm, or copy before use.
default_lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter rank and scaling numerator; PEFT scales the low-rank update
    # by lora_alpha / r (16 / 8 = 2 here).
    r=8,
    lora_alpha=16,
    # Dropout applied on the LoRA branch during training.
    lora_dropout=0.1,
    # Module names that receive adapters; presumably the attention query/key
    # projections of the base model — names depend on its architecture.
    # NOTE(review): PEFT's built-in default for Llama-family models is
    # ["q_proj", "v_proj"]; confirm k_proj is intended.
    target_modules=["q_proj", "k_proj"],
    # Leave all bias parameters frozen.
    bias="none",
)


def default_data_collator_for_lm(
    tokenizer: PreTrainedTokenizerBase,
    *,
    response_template: str = "<|start_header_id|>assistant<|end_header_id|>",
    **collator_kwargs,
) -> DataCollatorForCompletionOnlyLM:
    """Build a completion-only data collator for causal-LM fine-tuning.

    Wraps TRL's ``DataCollatorForCompletionOnlyLM``, which (per TRL's
    documentation) masks the labels of every token up to and including
    ``response_template`` so that only the response contributes to the loss.

    Args:
        tokenizer: Tokenizer handed to the collator; TRL uses it to encode
            ``response_template`` when locating responses in each example.
        response_template: Marker text that precedes each response. The
            default is the Llama-3-style assistant header, which preserves
            this function's original behavior; override it for other chat
            formats.
        **collator_kwargs: Extra keyword arguments forwarded unchanged to
            ``DataCollatorForCompletionOnlyLM`` (for example
            ``instruction_template`` or ``ignore_index``).

    Returns:
        A configured ``DataCollatorForCompletionOnlyLM`` instance.
    """
    # NOTE(review): Llama 3's chat template emits "\n\n" right after the
    # assistant header; the default template stops before it, so those
    # newline tokens remain part of the supervised response. Confirm intended.
    return DataCollatorForCompletionOnlyLM(
        response_template=response_template,
        tokenizer=tokenizer,
        **collator_kwargs,
    )
