import asyncio
import threading

import torch
from fastapi import FastAPI, HTTPException, Request
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Imported only for its side effects (never referenced by name below);
# presumably it registers the custom reward-model classes with transformers'
# Auto* factories — TODO confirm. Must stay before the from_pretrained() call.
import model_training.models.reward_model  # noqa: F401

app = FastAPI()

# Hugging Face Hub id shared by the reward model and its tokenizer.
RM_MODEL_NAME = "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 halves the weights' memory on GPU. PyTorch's CPU float16 support for
# matmul-type kernels has historically been missing or very slow (e.g.
# "addmm_impl_cpu_ not implemented for 'Half'"), so load float32 when no GPU
# is available. Otherwise the first forward pass would fail or crawl.
rm_dtype = torch.float16 if device.type == "cuda" else torch.float32

rm_tokenizer = AutoTokenizer.from_pretrained(RM_MODEL_NAME)
rm_model = AutoModelForSequenceClassification.from_pretrained(
    RM_MODEL_NAME,
    torch_dtype=rm_dtype,
)
rm_model.to(device)

# Chat template: each message is labeled purely by its position in the list.
# Even indices (0, 2, ...) become `<|prompter|>` turns and odd indices become
# `<|assistant|>` turns. Every turn is terminated with the tokenizer's
# eos_token. Only a message's "content" is read; any "role" field is ignored.
rm_tokenizer.chat_template = """{% set loop_messages = messages %}{% for message in loop_messages %}{% if loop.index0 % 2 == 0 %}{{ '<|prompter|>' + message['content'] + eos_token }}{% else %}{{ '<|assistant|>' + message['content'] + eos_token }}{% endif %}{% endfor %}"""


# Serializes forward passes. Inference runs in a worker thread so the event
# loop stays responsive. The original handler blocked the loop and therefore
# ran one forward pass at a time; this lock keeps that property, which bounds
# peak GPU memory under concurrent requests.
_inference_lock = threading.Lock()


def _score_texts(texts):
    """Run the reward model on *texts* and return its squeezed logits as Python data.

    Blocking (tokenization + forward pass); call it off the event loop.
    """
    inputs = rm_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with _inference_lock, torch.no_grad():
        outputs = rm_model(**inputs)

    # squeeze() drops every size-1 dimension. If the head emits a single
    # logit, one text therefore comes back as a bare float rather than a
    # one-element list. That shape is kept because existing clients receive it.
    return outputs.logits.squeeze().tolist()


def _is_valid_chat(chat):
    """Return True if *chat* is a non-empty list of dicts whose "content" is a str."""
    return (
        isinstance(chat, list)
        and len(chat) > 0
        and all(isinstance(m, dict) and isinstance(m.get("content"), str) for m in chat)
    )


@app.post("/get_reward")
async def get_reward(request: Request):
    """Score one chat transcript with the reward model.

    Request body: a JSON object whose ``"chat"`` key is a non-empty list of
    message objects, each with a string ``"content"``. The messages are
    rendered into a single string with ``rm_tokenizer.apply_chat_template``
    and scored in one forward pass.

    Returns:
        ``{"rewards": ...}``: the model's logits after ``squeeze()`` and
        ``tolist()``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or its ``"chat"`` is missing, empty, or malformed.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None

    chat = data.get("chat", []) if isinstance(data, dict) else None
    if not _is_valid_chat(chat):
        # An empty chat renders to "" and a zero-length input fails inside the
        # model. Non-string content fails inside the template. Reject both early.
        raise HTTPException(
            status_code=400,
            detail='"chat" must be a non-empty list of objects with a string "content" field.',
        )

    test_texts = [rm_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)]

    # Tokenization and the forward pass are blocking. Running them inline in
    # this coroutine would stall every other request on the event loop.
    rewards = await asyncio.to_thread(_score_texts, test_texts)

    return {"rewards": rewards}
