from peft import LoraConfig, get_peft_model
def apply_lora_gpt(model, r=4, device=None):
    """
    Apply LoRA low-rank adapters to a GPT-2 model.

    Prints a summary of the trainable parameters as a side effect.

    :param model: HuggingFace GPT-2 model object
    :param r: LoRA rank
    :param device: optional torch device; if None, the model is not moved
    :return: the LoRA-wrapped model
    """
    peft_config = LoraConfig(
        r=r,
        lora_alpha=32,
        target_modules=["c_attn"],  # GPT-2 merges its attention projections into c_attn
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # Per the PEFT docs, set this when the replaced layer stores its
        # weight as (fan_in, fan_out), as GPT-2's Conv1D does.
        fan_in_fan_out=True,
    )
    lora_model = get_peft_model(model, peft_config)
    # print_trainable_parameters() prints the summary itself and returns None,
    # so wrapping it in print() would emit an extra "None" line.
    lora_model.print_trainable_parameters()
    if device is not None:
        lora_model = lora_model.to(device)
    return lora_model

# -------- LoRA ----------
def apply_lora(model, r=4, model_name=None, device=None):
    """
    Apply LoRA adapters to a causal language model.

    If ``model_name`` contains the substring ``'gpt'`` the work is delegated
    to ``apply_lora_gpt``; otherwise LoRA is attached to the ``q_proj`` and
    ``v_proj`` modules. Prints a summary of the trainable parameters as a
    side effect.

    :param model: model object to wrap
    :param r: LoRA rank
    :param model_name: optional model name used to select the GPT-specific
        configuration; when None the generic configuration is used
    :param device: optional torch device; if None, the model is not moved
    :return: the LoRA-wrapped model
    """
    # model_name defaults to None, and `'gpt' in None` raises TypeError,
    # so check for None before the substring test.
    if model_name is not None and 'gpt' in model_name:
        return apply_lora_gpt(model, r=r, device=device)
    cfg = LoraConfig(
        r=r, lora_alpha=32, target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, cfg)
    model.print_trainable_parameters()
    if device is not None:
        model = model.to(device)
    return model