import os
import json
import torch
from PIL import Image
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

print("Loading base model...")
base_model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Optional PEFT/LoRA adapter directory. Leave empty to run the plain base model.
lora_checkpoint_path = ""
if lora_checkpoint_path:
    try:
        print("Loading LoRA weights...")
        model = PeftModel.from_pretrained(model, lora_checkpoint_path)
        print("LoRA weights loaded successfully.")
    except Exception as e:
        # Deliberate best-effort: a bad adapter should not abort the run,
        # so fall back to the already-loaded base model and say so loudly.
        print(f"Failed to load LoRA weights, using base model instead: {e}")
else:
    print("No LoRA checkpoint configured, using base model.")

model.eval()
processor = AutoProcessor.from_pretrained(base_model_name)

# Module-level target device for input tensors (the inference code moves its
# inputs to `device`; previously this name was never defined -> NameError).
# `model.device` is the device of the model's parameters as reported by
# transformers. NOTE(review): with device_map="auto" sharding across GPUs,
# presumably this is where the input embeddings live -- confirm.
device = model.device

# Instruction sent as the system turn for every sample. Hoisted out of the
# function because it never changes between calls.
SYSTEM_PROMPT = (
    "Based on the image and the question provided below, please provide your response in the following format:\n\n"
    "<Reasoning>A comprehensive and detailed explanation of your reasoning process.</Reasoning>\n\n"
    "<Answer>Your final answer to the question.</Answer>\n\n"
    "When the question pertains to spatial relationships, describe them explicitly using clear "
    # Trailing space: adjacent literals are concatenated verbatim, and the
    # next sentence must not run into "'adjacent to'.".
    "directional terms such as 'up', 'down', 'left', 'right', 'front', 'back', 'covering', or 'adjacent to'. "
    "Remember to always include both the <Reasoning> and <Answer> tags in your response."
)


def get_qwen2vl_answer(image_path, question, max_new_tokens=1024):
    """Ask the loaded vision-language model a question about one image.

    Uses the module-level ``model`` and ``processor``. Decoding is greedy
    (``do_sample=False``), so results are deterministic for a given model.

    Args:
        image_path: Filesystem path of the image to analyse.
        question: User question text placed after the image in the user turn.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded, whitespace-stripped model output (prompt tokens removed).
        This function does not raise for per-sample problems. Instead it
        returns an error string starting with "Image not found",
        "Image loading error", "Input preparation failed" or
        "Generation failed", so a batch loop can record it and continue.
    """
    try:
        if not os.path.exists(image_path):
            return f"Image not found: {image_path}"
        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory copy that outlives the `with`.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except Exception as e:
        return f"Image loading error: {str(e)}"

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                # Placeholder only; the pixels are passed via `images=` below.
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
    ]

    try:
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(
            text=[text_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        # Resolve the device from the model itself rather than relying on a
        # global that may not exist.
        target_device = model.device
        inputs = {
            k: v.to(target_device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
    except Exception as e:
        # Keep the "return an error string" contract so one malformed sample
        # does not abort a long batch run.
        return f"Input preparation failed: {str(e)}"

    try:
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_length = len(inputs["input_ids"][0])
        generated_ids = output_ids[0][prompt_length:]
        output_text = processor.decode(generated_ids, skip_special_tokens=True)

        return output_text.strip()
    except Exception as e:
        return f"Generation failed: {str(e)}"

def _save_records(records, output_file):
    """Write *records* as pretty-printed JSON to *output_file*, replacing it.

    The data goes to a sibling ``.tmp`` file first and is then moved over the
    target with ``os.replace``. An interruption mid-dump therefore never
    leaves a truncated checkpoint behind (the replace is atomic on POSIX).
    Returns True on success. On failure, prints the error and returns False.
    """
    tmp_path = f"{output_file}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(records, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, output_file)
        return True
    except Exception as e:
        print(f"Failed to save intermediate results: {e}")
        # Best-effort cleanup of the partial temp file; the original error
        # has already been reported above.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False


def main(dataset_file="", output_file="", save_interval=5):
    """Run the model over a JSON dataset and checkpoint the answers.

    Args:
        dataset_file: Path to a JSON file containing a list of objects. Each
            object's ``"question"`` and ``"image"`` (image path) keys are read.
        output_file: Destination JSON file. It receives a list of
            ``{"image_path", "question", "model_reasoning_output"}`` records.
        save_interval: Rewrite the output file after every this many samples
            and after the last one. Must be >= 1.

    Raises:
        ValueError: If ``save_interval`` is less than 1.
    """
    if save_interval < 1:
        raise ValueError(f"save_interval must be >= 1, got {save_interval}")

    if not os.path.exists(dataset_file):
        print(f"Dataset file not found: {dataset_file}")
        return
    # Fail fast: without a destination, the whole inference run would be
    # computed and then discarded.
    if not output_file:
        print("Output file path is not set; nothing would be saved.")
        return

    print(f"Reading dataset: {dataset_file}")
    with open(dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    output_records = []
    total = len(data)
    print(f"Processing {total} samples...")

    saved = True
    for idx, item in enumerate(data):
        question = item.get("question", "")
        image_path = item.get("image", "")

        print(f"[{idx+1}/{total}] Processing: {os.path.basename(image_path)}")
        print(f"Question: {question}")

        pred_answer = get_qwen2vl_answer(image_path, question)

        preview = pred_answer[:100] + "..." if len(pred_answer) > 100 else pred_answer
        print(f"Answer: {preview}\n")

        output_records.append({
            "image_path": image_path,
            "question": question,
            "model_reasoning_output": pred_answer,
        })

        done = idx + 1
        if done % save_interval == 0 or done == total:
            # Each save rewrites the full record list, so the most recent
            # successful save always holds every result so far.
            saved = _save_records(output_records, output_file)
            if saved:
                print(f"Progress saved to {output_file} ({done}/{total})")

            torch.cuda.empty_cache()

    # An empty dataset never enters the loop; still produce a (empty) result
    # file so downstream consumers find the expected output.
    if total == 0:
        saved = _save_records(output_records, output_file)

    if saved:
        print(f"Processing complete. Results saved to {output_file}")
    else:
        print(f"Processing complete, but saving results to {output_file} failed.")


if __name__ == "__main__":
    main()

