from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import os
import torch


def initialize_model(
    hf_model_name,
    checkpoint_filename,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_bias,
    lora_task_type,
    lora_target_modules,
):
    """Load a causal-LM base model, optionally restore a checkpoint, and wrap it with LoRA.

    Args:
        hf_model_name: Model id or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        checkpoint_filename: Optional path to a file readable by ``torch.load``
            holding a state dict for the base model. ``None`` skips loading.
        lora_r: LoRA rank, forwarded to ``LoraConfig.r``.
        lora_alpha: LoRA scaling factor, forwarded to ``LoraConfig.lora_alpha``.
        lora_dropout: Dropout probability, forwarded to ``LoraConfig.lora_dropout``.
        lora_bias: Bias handling mode, forwarded to ``LoraConfig.bias``.
        lora_task_type: Task type, forwarded to ``LoraConfig.task_type``.
        lora_target_modules: Module names to adapt, forwarded to
            ``LoraConfig.target_modules``.

    Returns:
        The model returned by ``get_peft_model`` wrapping the base model.

    Raises:
        FileNotFoundError: If ``checkpoint_filename`` is given but is not an
            existing regular file. This is checked before the base model is
            loaded.
    """
    # Validate the checkpoint path up front: failing here is cheap, whereas
    # failing after from_pretrained wastes a (possibly large) model load.
    # An explicit raise is used instead of `assert`, which `python -O` strips.
    if checkpoint_filename is not None and not os.path.isfile(checkpoint_filename):
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_filename}"
        )

    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        load_in_8bit=False,
        device_map="auto",
    )

    if checkpoint_filename is not None:
        print(f"Loading checkpoint: {checkpoint_filename}")
        # map_location="cpu": load_state_dict copies values into the parameters
        # already placed by device_map, so staging on CPU avoids an extra copy
        # on the device the checkpoint was saved from (and still works when
        # that device is not available).
        # SECURITY: torch.load may unpickle arbitrary objects unless
        # weights_only=True is in effect; only load trusted checkpoints.
        # NOTE(review): if device_map="auto" offloads some weights (CPU/disk),
        # load_state_dict may not update them as expected — confirm.
        state_dict = torch.load(checkpoint_filename, map_location="cpu")
        model.load_state_dict(state_dict)

    config = LoraConfig(
        r=lora_r,  # LoRA rank (dimension of the low-rank update matrices)
        lora_alpha=lora_alpha,  # alpha scaling
        lora_dropout=lora_dropout,
        bias=lora_bias,
        task_type=lora_task_type,  # set this for CLM or Seq2Seq
        target_modules=lora_target_modules,
    )

    return get_peft_model(model, config)
