from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig

# NOTE(review): `device` is not used by the code below; inputs are moved to
# `model.device` instead (device placement is decided by device_map="auto").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reward model loaded in 4-bit NF4 with double quantization; matmuls run in fp16.
model_id = "Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Scoring pads each batch (padding=True). That raises if the tokenizer defines no
# pad token, which is common for Mistral/Llama tokenizers. Fall back to EOS only
# when no pad token is set, so tokenizers that already have one are unaffected.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# SECURITY NOTE: trust_remote_code=True executes Python shipped in the model repo.
# Consider pinning `revision=` to a reviewed commit.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Sequence-classification heads use config.pad_token_id to find each row's last
# real token, and refuse batches larger than 1 when it is unset.
# NOTE(review): if pad == eos and chat_text itself contains EOS tokens, confirm the
# installed transformers version still selects the intended final token.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id
model.eval()

def get_rewards(text_list, batch_size=4, *, max_length=1024):
    """Score every text with the module-level reward ``model``, preserving order.

    Texts are processed ``batch_size`` at a time. Each batch is tokenized with
    padding to the longest text in the batch and truncation to ``max_length``
    tokens. The batch is moved to ``model.device`` and run without gradient
    tracking. The resulting logits are squeezed on their last axis and copied to
    the CPU as float32.

    Args:
        text_list: Sequence of strings supporting ``len`` and slicing.
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Maximum tokens kept per text.
            NOTE(review): with the tokenizer's default truncation side, over-long
            texts presumably lose their ending. Confirm this is acceptable for a
            head that reads the final token.

    Returns:
        A list with one entry per input text, in input order. Each entry is that
        text's ``logits.squeeze(-1)`` row as float32, which is a scalar for a
        single-output head.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # range() would silently yield nothing for negative steps and raise an opaque
    # error for 0, so reject bad sizes explicitly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    scores = []
    num_batches = (len(text_list) + batch_size - 1) // batch_size
    for start in tqdm(range(0, len(text_list), batch_size), total=num_batches, desc="Scoring with reward model"):
        batch = text_list[start:start + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Upcast before leaving torch. numpy cannot represent bfloat16, and fp16
        # scores are prone to overflow and precision loss in downstream arithmetic.
        scores.extend(logits.squeeze(-1).float().cpu().numpy())
    return scores

# Each job reads a TSV of chat texts and writes a score matrix of shape (-1, group),
# where `group` is the number of consecutive rows that form one group.
datasets = [
    dict(input_path=source, output_path=target, reshape=(-1, group))
    for source, target, group in (
        (
            "./data/original/msmarco_X5000_N100.tsv",
            "./output/msmarco_mistral_score_5000.npy",
            100,
        ),
        (
            "./data/original/nectar_X10000_N7.tsv",
            "./output/nectar_mistral_score_10000.npy",
            7,
        ),
    )
]

# For each dataset: read the TSV, score the "chat_text" column, reshape the scores
# to the configured shape, and save them as a .npy file.
for data in datasets:
    output_path = Path(data["output_path"])
    target_shape = tuple(data["reshape"])

    df = pd.read_csv(data["input_path"], sep="\t", engine="python")
    texts = df["chat_text"].astype(str).tolist()

    # Check that the rows can be reshaped *before* the slow scoring pass, so a
    # malformed input fails immediately rather than after every row is scored.
    fixed_size = int(np.prod([dim for dim in target_shape if dim != -1]))
    if -1 in target_shape:
        fits = fixed_size > 0 and len(texts) % fixed_size == 0
    else:
        fits = len(texts) == fixed_size
    if not fits:
        raise ValueError(
            f"{data['input_path']}: {len(texts)} rows cannot be reshaped to {target_shape}"
        )

    # np.save does not create missing parent directories. Create them up front so
    # the run cannot crash at the very end, after all the scoring work is done.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    df["mistral_score"] = get_rewards(texts)
    reshaped_scores = df["mistral_score"].to_numpy().reshape(target_shape)
    np.save(data["output_path"], reshaped_scores)
    print(f"Completed: {data['output_path']}")
