"""
simple script to evaluate the model with the headline date probe
python safety_techniques/eval.py
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
from date_probing import get_steered_model, sample_date, model_args

model = AutoModelForCausalLM.from_pretrained(model_args.model_id).to(model_args.device)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_id)

probe_path = './natural_probes/probe_12.pth'
for multiplier in [-0.5, 0.5]:
    model, hook = get_steered_model(model, probe_path, multiplier, 12)
    print("Multiplier:", multiplier)
    print(sample_date(model))
    print("-" * 12)
    hook.remove()