import os
import json
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Half precision on the GPU, full precision on the CPU.
if device == "cuda":
    _weights_dtype = torch.float16
else:
    _weights_dtype = torch.float32

# Load the checkpoint, move it to the selected device and switch it to
# evaluation mode.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=_weights_dtype,
).to(device)
model.eval()

# Processor loaded from the same checkpoint name as the model.
processor = AutoProcessor.from_pretrained(model_name)

def get_qwen2vl_answer(image_path: str, question: str) -> str:
    """Ask the module-level Qwen2-VL model a question about one image.

    The image is loaded as RGB, paired with a fixed system prompt and the
    question in a chat-template prompt, and passed to ``model.generate``.
    Only the newly generated tokens (not the prompt) are decoded.

    Args:
        image_path: Path of the image file to load.
        question: Question text sent as the user turn alongside the image.

    Returns:
        The decoded generation (at most 256 new tokens) with surrounding
        whitespace stripped, or ``""`` if the image could not be read (the
        error is printed, not raised).
    """
    try:
        # convert() returns a converted copy, so the file-backed image can be
        # closed immediately instead of keeping its file handle open.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
    except Exception as e:
        # Deliberate best effort: an unreadable image yields an empty answer
        # so that the caller can continue with the next record.
        print(f"Error reading image {image_path}: {e}")
        return ""

    system_prompt = (
        "Based on the image and the question provided below, please provide a comprehensive and detailed explanation of your reasoning process without only providing a final answer. "
        "When the question pertains to spatial relationships, describe them explicitly using clear directional terms such as 'up', 'down', 'left', 'right', 'front', 'back', 'covering', or 'adjacent to'. "
    )

    conversation = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": [
                # Placeholder entry; the pixels are supplied separately
                # through the processor's ``images`` argument below.
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
    ]

    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Batch of one prompt/image pair, moved to the model's device in one call.
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
    )

    # Each output row starts with the prompt tokens; slice them off so that
    # only the generated continuation is decoded.
    generated_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs["input_ids"], output_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return output_text[0].strip()

def main(dataset_file: str = "", output_file: str = "results.json") -> None:
    """Run the model over every record of a JSON dataset and save the answers.

    The dataset is parsed with ``json.load`` and iterated. Each record is
    read through ``.get`` for a ``"question"`` and an ``"image"`` entry (each
    falls back to ``""`` when missing) and passed to ``get_qwen2vl_answer``.
    Progress and model output are printed, and all results are written as one
    JSON list once the loop has finished.

    Args:
        dataset_file: Path of the JSON dataset to read. The default is an
            empty placeholder and is rejected, so callers must supply a path.
        output_file: Path the result list is written to.

    Raises:
        FileNotFoundError: If ``dataset_file`` is empty or does not exist.
    """
    if not dataset_file:
        # open("") raises this same exception type, but with a message that
        # does not explain that no dataset path was configured.
        raise FileNotFoundError(
            "No dataset file given: pass dataset_file=<path to JSON dataset> to main()."
        )

    with open(dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    total = len(data)
    output_records = []

    for idx, item in enumerate(data, start=1):
        question = item.get("question", "")
        image_path = item.get("image", "")

        print(f"[{idx}/{total}] Processing image: {image_path} | Question: {question}")
        pred_answer = get_qwen2vl_answer(image_path, question)
        print(f"Model output: {pred_answer}\n")

        output_records.append({
            "image_path": image_path,
            "question": question,
            "model_reasoning_output": pred_answer,
        })

    with open(output_file, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(output_records, f, ensure_ascii=False, indent=4)

    print(f"Inference finished. Results saved to {output_file}")

# Run the batch inference only when this file is executed as a script.
if __name__ == "__main__":
    main()