import os
import sys

sys.path.append('..')

import torch
import transformers
from peft import get_peft_model, LoraConfig, TaskType
from src_hf.utils import jpath, read_json
from transformers import [anonymous]LMHeadModel, [anonymous]Config

def main():
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, 
        inference_mode=False, 
        r=4, 
        lora_alpha=16, 
        lora_dropout=0.1,
        target_modules=['v_proj'], # try 3 settings: v, qv, qkv
    )

    # Initialize [anonymous]
    pt_ckpt = '/data1/[anonymous]/[anonymous]_data/pretrained_models/1b/model'
    config_fp = jpath(pt_ckpt, 'config.json')
    config = read_json(config_fp)
    config = [anonymous]Config.from_pretrained(config_fp)
    model = [anonymous]LMHeadModel.from_pretrained(pt_ckpt, config=config)

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

# Entry point: run only when executed as a script, never on import.
if __name__ == '__main__':
    main()