# PYTHONPATH=. srun -p llm-safety --quotatype=${QUOTATYPE:-auto} --gres=gpu:1 --cpus-per-task=10 python scripts/imdb/test.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from inference_time_alignment.scorer import ImplicitRewardScorer
from inference_time_alignment.decoder import BeamTuningPosthocGenerationMixin
from inference_time_alignment.utils import set_seeds


def get_scorer(
    beta: float = -1.0,
    *,
    model_path: str = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb-dpo",
    ref_model_path: str = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb",
) -> ImplicitRewardScorer:
    """Build an ``ImplicitRewardScorer`` from a tuned / reference model pair.

    The default checkpoints are the IMDB GPT-2 pair used by this script
    (``gpt2-imdb-dpo`` as the tuned model, ``gpt2-imdb`` as the reference).

    NOTE(review): the script-level comment in this file states that
    ``beta > 0`` steers towards positive movie feedback and ``beta < 0``
    towards negative feedback, so the default ``beta=-1.0`` presumably selects
    the *negative* direction -- confirm against ``ImplicitRewardScorer``.

    Args:
        beta: Value forwarded unchanged to ``ImplicitRewardScorer``.
        model_path: Checkpoint for the tuned model; the tokenizer is loaded
            from the same location.
        ref_model_path: Checkpoint for the reference model.

    Returns:
        The constructed ``ImplicitRewardScorer``.
    """
    # Both models are loaded with identical dtype / placement settings.
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, **load_kwargs)

    # The tokenizer comes from the tuned checkpoint; reuse EOS as the pad
    # token and pad on the right.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return ImplicitRewardScorer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        beta=beta,
    )

set_seeds(1)

# Steer the base Llama-2-7b model (the non-chat "-hf" checkpoint) to give
# positive (beta > 0) / negative (beta < 0) movie feedback.
BASE_MODEL_NAME = "meta-llama/Llama-2-7b-hf"
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

bt_model = BeamTuningPosthocGenerationMixin(base, tokenizer)
scorer = get_scorer()

raw_prompt = "I think this movie is "
# Place the inputs on whatever device `device_map="auto"` selected for the
# model instead of unconditionally assuming CUDA.
prompt_tokenized = tokenizer(raw_prompt, return_tensors="pt").to(base.device)

result = bt_model.bon_beam_sample(
    input_ids=prompt_tokenized["input_ids"],
    attention_mask=prompt_tokenized["attention_mask"],
    # The scorer is given the untokenized prompt; the return value of
    # `set_raw_prompt` is what gets passed as `scorer`.
    scorer=scorer.set_raw_prompt(raw_prompt),
    temperature=1.0,
    max_new_tokens=50,
    return_dict_in_generate=True,
)
print(tokenizer.decode(result["output_ids"][0]))

# Drop into the debugger so `result` can be inspected interactively.
# Set PYTHONBREAKPOINT=0 in the environment to skip this in batch runs.
breakpoint()
