import torch
from transformers import AutoModelForSequenceClassification
from peft import AutoPeftModelForCausalLM


def load_trained_reward_model(rm_adapter_path, sft_model_path, device="cuda:0", pad_token_id=None):
    """Assemble a single-logit reward model from an SFT checkpoint and a reward adapter.

    Steps:
      1. Load the PEFT SFT checkpoint at ``sft_model_path`` in bfloat16 and merge
         its adapter into the base weights.
      2. Instantiate a sequence-classification model (``num_labels=1``) from the
         SFT model's ``config.name_or_path`` on ``device``.
      3. Copy the merged SFT weights into it, non-strictly, so keys that exist in
         only one of the two model classes are skipped.
      4. Load the reward-model adapter from ``rm_adapter_path`` and switch the
         whole model to eval mode.

    Args:
        rm_adapter_path: Location of the trained reward-model adapter, passed to
            ``model.load_adapter``.
        sft_model_path: Location of the PEFT SFT checkpoint.
        device: Value forwarded as ``device_map`` for the classification model.
        pad_token_id: Optional pad token id to set on the classification model's
            config. When ``None`` the value from the checkpoint's own config is
            left untouched.

    Returns:
        The sequence-classification model with the adapter loaded, in eval mode.
    """
    sft_model = AutoPeftModelForCausalLM.from_pretrained(sft_model_path, torch_dtype=torch.bfloat16)
    sft_model = sft_model.merge_and_unload()

    classifier_kwargs = {
        "num_labels": 1,
        "device_map": device,
        "torch_dtype": torch.bfloat16,
    }
    # Only override the pad token when the caller supplied one: forwarding an
    # explicit None would replace whatever pad_token_id the checkpoint's config
    # already defines.
    if pad_token_id is not None:
        classifier_kwargs["pad_token_id"] = pad_token_id
    model = AutoModelForSequenceClassification.from_pretrained(
        sft_model.config.name_or_path, **classifier_kwargs
    )

    # strict=False: the causal-LM and classification models are different
    # classes, so keys present in only one of them are skipped, not an error.
    model.load_state_dict(sft_model.state_dict(), strict=False)
    # The merged SFT weights have been copied; drop the reference so the memory
    # can be reclaimed before the adapter is loaded.
    del sft_model

    model.load_adapter(rm_adapter_path)
    # Call eval() last so that modules added while loading the adapter are also
    # put in inference mode.
    return model.eval()