import os
import logging
from peft import (
    TaskType, 
    LoraConfig, 
    get_peft_model, 
    PeftModel
)
from transformers.trainer import (
    PREFIX_CHECKPOINT_DIR,
    get_last_checkpoint as __get_last_checkpoint,
)


def get_last_checkpoint_path(path):
    """Resolve ``path`` to a concrete checkpoint directory.

    If the text of ``path`` contains ``PREFIX_CHECKPOINT_DIR`` (imported from
    ``transformers.trainer``), ``path`` is returned unchanged. Otherwise
    ``path`` is treated as an output directory, and transformers'
    ``get_last_checkpoint`` is used to locate a checkpoint sub-directory
    inside it.

    Args:
        path: Directory path, as ``str`` or any ``os.PathLike``.

    Returns:
        The resolved checkpoint path.

    Raises:
        AssertionError: If no checkpoint is found under ``path``. The
            exception type is kept for backward compatibility, but it is
            raised explicitly so the check still runs under ``python -O``.
    """
    # Accept pathlib.Path and friends; the substring test below needs a str.
    path = os.fspath(path)

    # NOTE: this is a plain substring test on the whole path, so a path whose
    # parent directories contain the prefix text is also returned unchanged.
    if PREFIX_CHECKPOINT_DIR in path:
        return path

    checkpoint = __get_last_checkpoint(path)
    if checkpoint is None:
        # Report the directory that was searched, not the (None) lookup result.
        raise AssertionError(f"No checkpoint found in '{path}'.")
    return checkpoint


def get_peft_model_from_checkpoint(
    model,
    peft_id_or_dir,
    is_trainable=True,
    adapter_name="default",
    **config_args,
):
    """Load a PEFT adapter checkpoint onto ``model``.

    If ``peft_id_or_dir`` is an existing local directory, it is first resolved
    to a concrete checkpoint with ``get_last_checkpoint_path``. Any other
    value (presumably a Hub repo id) is passed through to PEFT untouched.

    Args:
        model: Either an existing ``PeftModel``, which gets the adapter loaded
            into it via ``load_adapter``, or a base model, which is wrapped
            via ``PeftModel.from_pretrained``.
        peft_id_or_dir: Local directory / path-like, or an adapter id.
        is_trainable: Forwarded to PEFT as ``is_trainable``.
        adapter_name: Name under which the adapter is registered.
        **config_args: Extra keyword arguments forwarded verbatim to
            ``load_adapter`` / ``from_pretrained``.

    Returns:
        The model carrying the adapter. For a ``PeftModel`` input this is the
        same object that was passed in.
    """
    # Normalise path-like objects (e.g. pathlib.Path) to str: checkpoint
    # resolution performs substring checks on the path.
    if isinstance(peft_id_or_dir, os.PathLike):
        peft_id_or_dir = os.fspath(peft_id_or_dir)

    # A local directory may hold several checkpoint sub-directories; pick one.
    if os.path.isdir(peft_id_or_dir):
        peft_id_or_dir = get_last_checkpoint_path(peft_id_or_dir)

    # Both loading paths take exactly the same keyword arguments.
    load_kwargs = dict(
        is_trainable=is_trainable,
        adapter_name=adapter_name,
        **config_args,
    )

    if isinstance(model, PeftModel):
        # load_adapter's return value is ignored; the same model is returned.
        # NOTE(review): per the PEFT docs, load_adapter does not make the new
        # adapter the active one; confirm callers call set_adapter if needed.
        model.load_adapter(peft_id_or_dir, **load_kwargs)
    else:
        model = PeftModel.from_pretrained(model, peft_id_or_dir, **load_kwargs)

    # Module logger with lazy %-formatting. This avoids the implicit
    # basicConfig() that the root-level logging.info() performs.
    logging.getLogger(__name__).info(
        "Loaded PEFT adapter '%s' checkpoint from '%s'",
        adapter_name,
        peft_id_or_dir,
    )

    return model


def get_lora_model(
    model,
    peft_id_or_dir=None,
    lora_rank=8,
    lora_alpha=32,
    lora_dropout=0.1,
    is_trainable=False,
    adapter_name="default",
    target_modules=None,
    **config_args,
):
    """Return ``model`` equipped with a LoRA adapter.

    When ``peft_id_or_dir`` is given, the adapter is loaded from that
    checkpoint via ``get_peft_model_from_checkpoint``. Otherwise a new LoRA
    adapter is created with ``LoraConfig`` + ``get_peft_model``.

    Args:
        model: Base model, or an existing ``PeftModel`` when loading.
        peft_id_or_dir: Checkpoint directory / adapter id to load, or ``None``
            to create a new adapter.
        lora_rank: LoRA rank ``r`` for a new adapter.
        lora_alpha: LoRA ``lora_alpha`` for a new adapter.
        lora_dropout: LoRA ``lora_dropout`` for a new adapter.
        is_trainable: When loading, forwarded as ``is_trainable``. When
            creating, the new config gets ``inference_mode=not is_trainable``.
        adapter_name: Name under which the adapter is registered.
        target_modules: ``LoraConfig.target_modules`` for a new adapter; also
            forwarded on the loading path.
        **config_args: When loading, forwarded verbatim to the checkpoint
            loader. When creating, keys that are ``LoraConfig`` init fields
            (e.g. ``modules_to_save``, ``bias``, ``task_type``) are applied to
            the new config; other keys are ignored.

    Returns:
        The PEFT-wrapped model.
    """
    if peft_id_or_dir is not None:
        # lora_rank / lora_alpha / lora_dropout are not used on this path.
        # NOTE(review): target_modules is presumably overridden by the saved
        # adapter config; it is forwarded only to keep prior behavior.
        return get_peft_model_from_checkpoint(
            model,
            peft_id_or_dir,
            is_trainable=is_trainable,
            adapter_name=adapter_name,
            target_modules=target_modules,
            **config_args,
        )

    import dataclasses

    # Keep only extra kwargs that LoraConfig actually accepts, so unrelated
    # options (e.g. Hub download settings meant for the loading path) are
    # still ignored instead of raising TypeError.
    lora_fields = {f.name for f in dataclasses.fields(LoraConfig) if f.init}
    extra_config = {k: v for k, v in config_args.items() if k in lora_fields}

    lora_kwargs = {
        # Defaults that config_args may override.
        "task_type": TaskType.CAUSAL_LM,
        "bias": "none",
        **extra_config,
        # This function's explicit parameters always take precedence.
        "r": lora_rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "inference_mode": not is_trainable,
        "target_modules": target_modules,
    }
    peft_config = LoraConfig(**lora_kwargs)

    return get_peft_model(model, peft_config, adapter_name=adapter_name)