import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset


# Checkpoint directory (presumably a GPT-2 Large summarization model tuned with
# DPO, going by the directory name). Defined once so the model and tokenizer are
# always loaded from the same place if the path is edited.
MODEL_PATH = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-large-summarize-dpo"

# Load the weights in bfloat16 and let device_map="auto" decide where each part
# of the model is placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# The tokenizer must come from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Build the generation prompt from the first example of the test split, with the
# "TL;DR: " suffix appended (presumably the cue the model was tuned to answer
# with a summary — confirm against the training format).
dataset = load_dataset("anonymized_for_nips/openai_summarize_comparisons_relabel", split="test")
# Index the row first: dataset["prompt"] would materialize the whole column just
# to read a single element.
text = dataset[0]["prompt"] + "TL;DR: "

# Tokenize the prompt and move every tensor (input_ids and attention_mask) to the
# model's device. The model was placed with device_map="auto", so a hard-coded
# .cuda() may not match where the weights live (or fail without a GPU).
tokenized_text = tokenizer(text, return_tensors="pt").to(model.device)

# Unpacking the encoding forwards the attention mask to generate(). GPT-2
# tokenizers typically define no pad token, so pass EOS explicitly rather than
# relying on generate()'s warning fallback.
output = model.generate(
    **tokenized_text,
    max_new_tokens=200,
    pad_token_id=tokenizer.eos_token_id,
)

# Full decoded sequence; for a decoder-only model generate() returns the prompt
# ids followed by the newly generated ids.
output_text = tokenizer.decode(output[0])

# Decode only the continuation after the prompt (the model's summary) and report
# it instead of stopping in the debugger, so the script can run non-interactively.
prompt_length = tokenized_text["input_ids"].shape[1]
summary = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)
print(summary)