import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load MedGemma 27B (text-only, instruction-tuned) from a local checkpoint,
# continue a short prompt by a few tokens, and print the decoded result.

# Single source of truth for the checkpoint location (tokenizer and model must match).
MODEL_ID = "/data/hf/google/MedGemma-27b-text-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # 2 bytes/param instead of fp32's 4
    device_map="auto",  # let accelerate place (and possibly shard) layers across devices
    attn_implementation="sdpa",  # use PyTorch scaled_dot_product_attention
)

input_text = "LLMs generate text through a process known as"
# NOTE(review): the "-it" checkpoint is instruction-tuned; this raw string is fed
# as a plain continuation prompt. For chat-style use, consider
# tokenizer.apply_chat_template(...) instead.
#
# Despite its name, `input_ids` is a BatchEncoding (input_ids + attention_mask)
# that is unpacked into generate() below. Move it to the model's device rather
# than a hard-coded "cuda": with device_map="auto" the model may not sit on the
# default CUDA device, and "cuda" does not exist on CPU-only hosts.
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# "static" asks generate() to use a pre-allocated, fixed-size KV cache instead
# of one that grows each step.
outputs = model.generate(**input_ids, max_new_tokens=5, cache_implementation="static")
# For a causal LM, outputs[0] holds the prompt tokens followed by the new tokens,
# so the printed text includes the original prompt.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))