
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import json
import argparse
import tqdm

def main(generation_file, reward_model, output_dir):
    """Score generated responses with a reward model and save them as JSON.

    Each record of ``generation_file`` is turned into a two-turn chat
    (user prompt, assistant response), run through the sequence-classification
    reward model, and the first logit is stored as the score.

    Args:
        generation_file: Path to a JSON file holding a list of records, each
            with a "prompt" and a "generated_text" entry.
        reward_model: Model name or path handed to
            ``AutoModelForSequenceClassification.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        output_dir: Path of the JSON *file* to write (despite the name it is
            opened as a file, not a directory). Missing parent directories
            are created.

    Raises:
        FileNotFoundError: If ``generation_file`` does not exist.
        KeyError: If a record lacks "prompt" or "generated_text".
    """
    import os  # function-scope import: nothing else in the module needs it

    # Read as UTF-8 explicitly so the input matches the encoding used for the
    # output file instead of depending on the platform default.
    with open(generation_file, 'r', encoding='utf8') as f:
        output_data = json.load(f)

    prompts = [record["prompt"] for record in output_data]
    responses = [record["generated_text"] for record in output_data]

    scores = []
    # Only pay for loading the reward model when there is something to score.
    if prompts:
        model = AutoModelForSequenceClassification.from_pretrained(
            reward_model,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
            attn_implementation="flash_attention_2",
            num_labels=1,
        )
        tokenizer = AutoTokenizer.from_pretrained(reward_model)

        with torch.no_grad():
            for prompt, response in tqdm.tqdm(zip(prompts, responses), total=len(prompts)):
                messages = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response.strip()},
                ]
                # Place the token ids on whatever device the model was loaded to.
                input_ids = tokenizer.apply_chat_template(
                    messages, tokenize=True, return_tensors="pt"
                ).to(model.device)
                # num_labels=1, so the single logit of the single sequence is the score.
                score = model(input_ids).logits[0][0].item()

                scores.append({
                    'instruction': prompt,
                    # The unstripped response is stored, even though the
                    # stripped one was scored.
                    'output': response,
                    'score': score,
                })

    # Make sure the destination directory exists so the scoring work is not
    # lost to a FileNotFoundError at the very end.
    parent_dir = os.path.dirname(output_dir)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_dir, 'w', encoding='utf8') as f:
        json.dump(scores, f, indent=2, ensure_ascii=False)

    print(f"Annotated outputs saved to {output_dir}")

if __name__ == '__main__':
    # Command-line entry point: parse the arguments and run the scoring.
    parser = argparse.ArgumentParser(
        description="Score generated responses with a reward model.",
    )

    # Both paths are required: without them main() would receive None and
    # fail inside open() with an unhelpful TypeError.
    parser.add_argument(
        "--generation_file",
        type=str,
        required=True,
        help='JSON file with a list of records holding "prompt" and "generated_text".',
    )
    parser.add_argument(
        "--reward_model",
        default='Skywork/Skywork-Reward-Gemma-2-27B-v0.2',
        type=str,
        help="Reward model name or path (default: %(default)s).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path of the JSON file the scores are written to.",
    )
    args = parser.parse_args()
    main(args.generation_file, args.reward_model, args.output_dir)