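# Score model outputs stored in a CSV with the ArmoRM-Llama3-8B reward model.
# Each row's response(s) are scored against the row's 'prompt'; the script
# writes the average score and a per-attribute reward breakdown (as JSON) to a
# new CSV.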


import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd
import argparse
import tqdm
import ast
import json

# --- Setup ---
# All inference runs on the GPU; inputs are moved here before each forward pass.
device = "cuda"

# Hugging Face hub id of the reward model, shared by the tokenizer and model.
MODEL_PATH = "RLHFlow/ArmoRM-Llama3-8B-v0.1"

# trust_remote_code=True lets transformers run the model repository's own
# modeling code; the scoring loop relies on its `score` / `rewards` outputs.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
).to(device)

# Names of the reward dimensions, in the order they are zipped with the
# model's per-attribute reward vector (assumed to have exactly these 19
# entries). Grouped here by name prefix for readability; order is significant.
attributes = [
    # helpsteer-*
    'helpsteer-helpfulness',
    'helpsteer-correctness',
    'helpsteer-coherence',
    'helpsteer-complexity',
    'helpsteer-verbosity',
    # ultrafeedback-*
    'ultrafeedback-overall_score',
    'ultrafeedback-instruction_following',
    'ultrafeedback-truthfulness',
    'ultrafeedback-honesty',
    'ultrafeedback-helpfulness',
    # single-entry groups
    'beavertails-is_safe',
    'prometheus-score',
    # argilla-*
    'argilla-overall_quality',
    'argilla-judge_lm',
    # code-*
    'code-complexity',
    'code-style',
    'code-explanation',
    'code-instruction-following',
    'code-readability',
]


def str2bool(v):
    """Convert a command-line flag value to a bool (argparse ``type=`` hook).

    Booleans are returned unchanged. Strings are matched case-insensitively
    against yes/true/t/1 (True) and no/false/f/0 (False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    spellings = {
        'yes': True, 'true': True, 't': True, '1': True,
        'no': False, 'false': False, 'f': False, '0': False,
    }
    parsed = spellings.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed


# --- Argparse ---
parser = argparse.ArgumentParser(
    description="Score model responses in a CSV with the ArmoRM-Llama3-8B reward model."
)
parser.add_argument('--output', type=str, default="scored_output.csv")
parser.add_argument('--input', type=str, default="input.csv")
parser.add_argument(
    '--number', type=int, default=1,
    help="Number of responses to score per prompt (only used when --single_response is false)",
)
# Required: rows are indexed by this column, so leaving it unset would only
# surface later as `KeyError: None` on the first row.
parser.add_argument(
    '--column', type=str, required=True,
    help="Name of the CSV column holding the response(s) to score",
)
parser.add_argument(
    '--single_response', type=str2bool, default=True,
    help="If true, the column holds one response per row; otherwise a Python-literal list of responses",
)
args = parser.parse_args()

# A non-positive slice bound would silently score nothing ([:0]) or drop the
# trailing responses ([:-k]) instead of limiting how many are scored.
if args.number < 1:
    parser.error("--number must be a positive integer")

df = pd.read_csv(args.input)


def _parse_generations(idx, cell):
    """Return the candidate responses stored in one CSV cell.

    In single-response mode the cell itself is the only response. Otherwise
    the cell must hold a Python-literal list/tuple (e.g. "['a', 'b']"), of
    which the first ``args.number`` entries are kept. Empty (NaN) or
    unparseable cells yield an empty list, so the row ends up as NaN instead
    of aborting the whole run.
    """
    if args.single_response:
        return [cell]
    if not isinstance(cell, str):
        # pandas reads empty cells as float NaN; literal_eval would raise on it.
        return []
    try:
        # literal_eval only evaluates Python literals, never arbitrary code.
        parsed = ast.literal_eval(cell)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        print(f"[Error parsing generations in row {idx}]: {e}")
        return []
    if isinstance(parsed, str):
        # A bare string literal is a single response; slicing it would
        # otherwise score individual characters.
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple)):
        print(f"[Error parsing generations in row {idx}]: expected a list, got {type(parsed).__name__}")
        return []
    return list(parsed[:args.number])


def _score_generation(prompt, response):
    """Run the reward model on one (prompt, response) conversation.

    Returns:
        (score, rewards): ``output.score`` as a Python float, and
        ``output.rewards`` moved to CPU as float32 with the leading batch
        dimension squeezed (expected to have one entry per name in
        ``attributes`` — verified by the caller).
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(input_ids)
    score = output.score.cpu().float().item()
    rewards = output.rewards.cpu().float().squeeze(0)
    return score, rewards


# --- Scoring Loop ---
average_rewards = []
reward_breakdowns = []

for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
    prompt = row['prompt']
    generations = _parse_generations(idx, row[args.column])

    total_score = 0.0
    reward_sums = torch.zeros(len(attributes))
    valid_count = 0

    for gen in generations:
        # Skip NaN / non-string / whitespace-only responses.
        if not isinstance(gen, str) or not gen.strip():
            continue
        try:
            score, rewards = _score_generation(prompt, gen.strip())
            # Validate before touching the accumulators, so a failure can never
            # leave total_score updated while valid_count is not.
            if rewards.shape != reward_sums.shape:
                raise ValueError(
                    f"expected {len(attributes)} reward dimensions, got shape {tuple(rewards.shape)}"
                )
        except Exception as e:
            # Best effort: one failing response must not abort the whole run.
            print(f"[Error scoring row {idx}]: {e}")
            continue
        total_score += score
        reward_sums += rewards
        valid_count += 1

    if valid_count > 0:
        avg_score = total_score / valid_count
        reward_dict = dict(zip(attributes, (reward_sums / valid_count).tolist()))
    else:
        # No scorable response in this row.
        avg_score = float('nan')
        reward_dict = {}

    average_rewards.append(avg_score)
    reward_breakdowns.append(json.dumps(reward_dict))  # store as stringified JSON

# --- Save results ---
df["average_reward"] = average_rewards
df["reward_breakdown"] = reward_breakdowns
df.to_csv(args.output, index=False)
print(f"Saved to {args.output}")
