import torch
from peft import PeftModel
import transformers
import os, time

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    DataCollatorWithPadding,
)


BASE_MODEL = "ROOT/saved_llms/Meta-Llama-3.1-8B-Instruct"
LORA_WEIGHTS = "ROOT/APLOT/reward_models/my_outputs/Meta-Llama-3.1-8B-Instruct_DB_Difference-1.0_from-SK-v0.2_Debug_difference-1.0_len4096_fulltrain_2e-06_dataSkywork-Reward-Preference-80K-v0.2/checkpoint-601"


def merged_output_dir(lora_weights):
    """Return the ``merged`` directory that sits next to a LoRA checkpoint dir.

    ``.../<run>/checkpoint-601`` -> ``.../<run>/merged``. A trailing slash on
    *lora_weights* is ignored so it does not change which parent is used.
    """
    return os.path.join(os.path.dirname(lora_weights.rstrip("/")), "merged")


# Derived from LORA_WEIGHTS instead of repeating the long run path by hand, so
# pointing at a different checkpoint cannot silently write into another run.
MERGED_OUTPUT = merged_output_dir(LORA_WEIGHTS)


def merge_lora(base_model=BASE_MODEL, lora_weights=LORA_WEIGHTS,
               output_dir=MERGED_OUTPUT, torch_dtype=None):
    """Merge a LoRA adapter into a single-logit classification model and save it.

    Loads *base_model* with a sequence-classification head (``num_labels=1``),
    applies the adapter in *lora_weights*, folds the adapter into the base
    weights with ``merge_and_unload()``, and writes both the merged model and
    the tokenizer to *output_dir* so that directory can be loaded on its own.

    Args:
        base_model: Path or hub id of the base model (also used for the tokenizer).
        lora_weights: Path to the PEFT adapter checkpoint directory.
        output_dir: Directory the merged model and tokenizer are written to.
        torch_dtype: dtype to load weights in; ``None`` means ``torch.float16``.

    Returns:
        ``(model, tokenizer)``: the merged (non-PEFT) model and the tokenizer.
    """
    if torch_dtype is None:
        torch_dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=1,
        device_map="auto",
        torch_dtype=torch_dtype,
    )
    # NOTE(review): the `score` head created above is freshly initialized; the
    # merged reward model is only meaningful if the adapter checkpoint also
    # stores the trained head (e.g. via modules_to_save) — confirm.
    model = PeftModel.from_pretrained(
        model,
        lora_weights,
        torch_dtype=torch_dtype,
        device_map="auto",
    )
    model = model.merge_and_unload()

    # NOTE(review): if training set a pad token (Llama 3.1 may not define one),
    # confirm the saved config's pad_token_id matches before batched scoring.
    model.save_pretrained(output_dir)
    # Save the tokenizer next to the weights so output_dir is self-contained
    # and AutoTokenizer.from_pretrained(output_dir) works.
    tokenizer.save_pretrained(output_dir)
    return model, tokenizer


if __name__ == "__main__":
    model, tokenizer = merge_lora()