import jax.numpy as jnp
from jax import jit, vmap
from jax.nn import relu


# 4: Inference
def inference(prompt, tokenizer, model, v, w):
    """
    Perform inference on a prompt (or batch of prompts) using the trained convex model.

    The prompt is tokenized, encoded by ``model``, mean-pooled over its real
    (non-padding) tokens, and scored with the linear head ``embedding @ v + w``.
    A strictly positive score is labelled "Preferred"; anything else
    (including NaN) is labelled "Not Preferred".

    Args:
        prompt: Text passed straight to ``tokenizer``; a single string or a
            list of strings.
        tokenizer: Callable invoked as ``tokenizer(prompt, return_tensors="jax",
            padding=True, truncation=True)``. Its output mapping is unpacked
            into ``model``. If it contains an ``attention_mask`` entry, that
            mask is used to exclude padding positions from pooling.
        model: Callable whose result exposes ``last_hidden_state``; presumably
            shaped (batch, seq_len, hidden) — TODO confirm against the encoder.
        v: Weight vector of the linear head; assumed to have shape (hidden,).
        w: Bias added to the score.

    Returns:
        "Preferred" or "Not Preferred" when a single score is produced;
        otherwise a list of those labels, one per score (i.e. per prompt).
    """
    tokens = tokenizer(prompt, return_tensors="jax", padding=True, truncation=True)
    hidden = model(**tokens).last_hidden_state

    mask = tokens.get("attention_mask")
    if mask is None:
        # No mask available: fall back to a plain mean over the token axis.
        embedding = hidden.mean(axis=1)
    else:
        # Mean-pool only over real tokens. With padding=True, shorter prompts
        # in a batch are padded, and a plain mean would dilute their embedding
        # with padding states.
        weights = jnp.asarray(mask, dtype=hidden.dtype)
        summed = (hidden * weights[..., None]).sum(axis=1)
        # Clamp so a row with no real tokens cannot divide by zero.
        counts = jnp.maximum(weights.sum(axis=1, keepdims=True), 1)
        embedding = summed / counts

    # Compute preference score using convex NN weights
    scores = jnp.dot(embedding, v) + w  # Linear combination

    # Label each score individually; `if array > 0` would raise for batches.
    labels = [
        "Preferred" if s > 0 else "Not Preferred"
        for s in jnp.ravel(scores).tolist()
    ]
    # A single score keeps the original scalar-string return value.
    return labels[0] if len(labels) == 1 else labels