from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch

# Local path of the example chart. The same variable is used both as the
# download destination and as the file opened later, so the two cannot drift
# apart (a relative path resolves against the current working directory).
image_path = "chart_example_1.png"

# Fetch the example chart from the ChartQA validation set.
torch.hub.download_url_to_file(
    "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png",
    image_path,
)

# Question passed to the model together with the chart image.
input_text = "program of thought: what is the sum of Facebook Messenger and Whatsapp values in the 18-29 age group?"

# Load Model
# Pick the device first so the weight precision can be chosen to match it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only used on GPU; on CPU fall back to float32, since
# float16 CPU kernels are frequently missing or very slow.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma", torch_dtype=model_dtype)
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")

model = model.to(device)

# Process Inputs
# Open the chart inside a context manager so the underlying file handle is
# released deterministically; convert() returns an independent RGB copy that
# remains valid after the file is closed.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

# Build the model inputs from the text prompt and the image.
inputs = processor(text=input_text, images=image, return_tensors="pt")

# Number of prompt tokens (second dimension of input_ids); used later to
# slice the prompt off the generated sequence.
prompt_length = inputs["input_ids"].shape[1]

# Move every tensor produced by the processor to the model's device.
inputs = {k: v.to(device) for k, v in inputs.items()}


# Generate
# Decoding settings: beam search over 4 beams, at most 512 new tokens.
generation_kwargs = {"num_beams": 4, "max_new_tokens": 512}
generate_ids = model.generate(**inputs, **generation_kwargs)

# Keep only the tokens that come after the prompt.
new_token_ids = generate_ids[:, prompt_length:]

# Decode the batch and take its first (only) entry as the answer.
decoded_batch = processor.batch_decode(
    new_token_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
output_text = decoded_batch[0]
print(output_text)
