# train_dpo.py
import functools
import inspect

import fire
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_config, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

def log_args(func):
    """Decorate *func* so that every call prints its bound arguments first.

    Arguments are bound against ``func``'s signature with defaults applied, so
    parameters the caller omitted are printed with their default values.

    ``functools.wraps`` copies ``func``'s name/docstring onto the wrapper and
    sets ``__wrapped__``, so ``inspect.signature`` (and CLI builders that
    introspect it, e.g. ``fire``) see ``func``'s real parameters rather than
    ``(*args, **kwargs)``.

    Raises:
        TypeError: at call time, from ``Signature.bind``, when the arguments
            do not match ``func``'s signature (``func`` is not invoked).
    """
    # The signature of func never changes, so compute it once, not per call.
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        print("----------------------------------------")
        print("Parameters Value：")
        for name, value in bound_args.arguments.items():
            print(f" {name}: {value}")
        print("----------------------------------------")
        return func(*args, **kwargs)

    return wrapper

@log_args
def main(model_path="./alpaca-native", output_path="./alpaca_dpo", device="cuda:0",
         dataset_path="./data/safe_rlhf_dpo"):
    """Fine-tune a causal LM with DPO on a 4-bit quantized base plus LoRA adapters.

    Args:
        model_path: Path (or hub id) used to load both the model and tokenizer.
        output_path: Passed to ``DPOConfig(output_dir=...)``.
        device: Forwarded as ``device_map`` when loading the model.
        dataset_path: Passed to ``load_dataset``; its ``train`` split is used.
            Presumably a preference dataset in the format DPOTrainer expects
            (prompt/chosen/rejected) — confirm against the data.
    """
    # 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    # Padding needs a pad token; fall back to EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoraConfig(
        r=8,
        task_type="CAUSAL_LM",
    )
    # get_peft_model injects the LoRA layers into `model` in place, so from
    # here on `model` and `peft_model` share the same adapter-carrying modules.
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()

    train_dataset = load_dataset(dataset_path, split="train")

    training_args = DPOConfig(
        output_dir=output_path,
        per_device_train_batch_size=4,
        gradient_checkpointing=True,
        num_train_epochs=1,
        # "none" disables logging integrations; in transformers 4.x a value of
        # None is resolved to "all" installed integrations instead.
        report_to="none",
    )
    print('***' * 20)
    print(training_args)
    print('***' * 20)
    # With a PEFT model, ref_model must be None: per the TRL docs the trainer
    # then computes reference log-probs with the adapters disabled. Passing
    # `model` would reuse the policy's own adapter weights as the reference,
    # so the reference would move with every optimizer step.
    trainer = DPOTrainer(
        model=peft_model,
        ref_model=None,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )
    trainer.train()

if __name__ == "__main__":
    # Script entry point: hand main to python-fire, which builds the CLI from it.
    fire.Fire(component=main)