import torch
from fastapi import FastAPI, HTTPException, Request
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hugging Face checkpoint id shared by the tokenizer and the reward model.
_RM_CHECKPOINT = "NCSOFT/Llama-3-OffsetBias-RM-8B"

# Application object; endpoints attach themselves to it through its decorators.
app = FastAPI()

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and model are created once, at import time, and kept as module
# globals.
rm_tokenizer = AutoTokenizer.from_pretrained(_RM_CHECKPOINT)

# NOTE(review): float16 weights are requested regardless of which device was
# selected above, i.e. also on the CPU fallback path -- confirm that is intended.
rm_model = AutoModelForSequenceClassification.from_pretrained(
    _RM_CHECKPOINT,
    torch_dtype=torch.float16,
).to(device)


@app.post("/get_reward")
async def get_reward(request: Request):
    """Score a single conversation with the reward model.

    The request body must be a JSON object with a ``"chat"`` key holding a
    non-empty list of messages. The list is rendered to text with the
    tokenizer's chat template, tokenized, and passed through ``rm_model``.

    Args:
        request: Incoming HTTP request whose body carries the JSON payload.

    Returns:
        A dict ``{"rewards": value}`` where ``value`` is
        ``outputs.logits.squeeze().tolist()``.
        NOTE(review): if the model emits one logit per sequence, ``squeeze()``
        collapses the single-conversation result to a 0-d tensor and
        ``tolist()`` yields a bare float rather than a list. This shape is
        kept as-is because callers may depend on it -- confirm before changing.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or if ``"chat"`` is missing, empty, or not a list.
    """
    try:
        data = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # report a client error instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from exc

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    chat = data.get("chat")
    if not isinstance(chat, list) or not chat:
        # Without this guard a missing/empty chat would be pushed through the
        # template and model, yielding either a server error or a score that
        # does not correspond to any conversation.
        raise HTTPException(
            status_code=400,
            detail='"chat" must be a non-empty list of messages.',
        )

    # Render the conversation to plain text, then drop every occurrence of the
    # BOS token text -- presumably so the tokenizer call below does not end up
    # with a duplicated BOS; confirm against the tokenizer configuration.
    rendered = rm_tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=False
    )
    test_texts = [rendered.replace(rm_tokenizer.bos_token, "")]

    inputs = rm_tokenizer(
        test_texts,
        return_tensors="pt",
        padding=True,
    )
    # Move every tensor of the encoding onto the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = rm_model(**inputs)

    rewards = outputs.logits.squeeze().tolist()

    return {"rewards": rewards}
