from models import load_model, generate_response, get_logit_from_llm


# Demo multiple-choice prompt. The original literal carried stray trailing
# spaces after "D. 4", which were fed to the tokenizer as part of the prompt.
_DEFAULT_PROMPT = (
    "What's the correct answer for 1+2?\n"
    "A. 1\n"
    "B. 2\n"
    "C. 3\n"
    "D. 4\n"
)


def main(
    prompt: str = _DEFAULT_PROMPT,
    *,
    model_name: str = "Qwen/Qwen2-1.5B",
    ans_options=("A", "B", "C", "D"),
    device: str = "cuda",
):
    """Score the answer options of a multiple-choice prompt and print the result.

    Calls ``get_logit_from_llm`` with the given prompt, model and answer
    options, prints what it returns, and also returns it so callers can use
    the scores programmatically.

    Args:
        prompt: Question text followed by lettered options.
        model_name: Model identifier passed through to ``get_logit_from_llm``
            (presumably a Hugging Face hub id -- confirm in ``models``).
        ans_options: Option labels whose scores are requested.
        device: Device string passed through to ``get_logit_from_llm``.

    Returns:
        Whatever ``get_logit_from_llm`` returns (presumably one score per
        entry of ``ans_options`` -- TODO confirm against ``models``).
    """
    # NOTE(review): the prompt ends right after the options with no
    # "Answer:" cue. Unless get_logit_from_llm appends one itself, a base
    # model's next-token distribution may not concentrate on the option
    # letters -- confirm in models.get_logit_from_llm.
    logits = get_logit_from_llm(
        prompt,
        model_name=model_name,
        # Pass a fresh list, as the original did, so the callee never sees
        # (or mutates) the shared tuple default.
        ans_options=list(ans_options),
        device=device,
    )
    print(logits)
    return logits

if __name__ == "__main__":
    # Script entry point: run the demo, which prints the values returned by
    # get_logit_from_llm, e.g. [0.1, 0.02, ...]
    main()