import os
import json
import torch
import random
import logging
import datetime
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from peft import PeftModel
from qwen_vl_utils import process_vision_info

# One log file per run, named after the launch time, e.g. lora_QwenVL_20240101_120000.log.
_run_stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f"lora_QwenVL_{_run_stamp}.log"

# Send INFO-and-above records both to that file and to the console.
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)

# These variables only take effect if set before CUDA is initialised. `torch` is
# already imported above, so this relies on nothing having touched CUDA yet.
os.environ.update({
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    # Expose only physical GPUs 0 and 4 to this process.
    "CUDA_VISIBLE_DEVICES": "0,4",
})

logging.info("Starting script execution...")
_visible_gpus = os.environ["CUDA_VISIBLE_DEVICES"]
logging.info(f"Using GPUs: {_visible_gpus}")

# NOTE(review): this path begins with a space, so it is a relative path to a
# directory literally named " LLaMA-Factory" — confirm this is intended.
finetuned_checkpoint = " LLaMA-Factory/saves/qwen2_5_vl-7b-Instruct/lora/sft/checkpoint-570"
model_name = "Qwen/Qwen2.5-VL-7B-Instruct"

# Base model in fp16; device_map="auto" lets the loader place layers on the visible GPUs.
_base_load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, **_base_load_kwargs)

# Wrap the base model with the LoRA adapter from the fine-tuning checkpoint.
model = PeftModel.from_pretrained(base_model, finetuned_checkpoint)

# Processor and tokenizer are loaded from the same checkpoint directory.
# NOTE(review): `tokenizer` is not used elsewhere in the visible code.
processor = AutoProcessor.from_pretrained(finetuned_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(finetuned_checkpoint)

# NOTE(review): both paths begin with a space, so they resolve relative to the
# working directory under names that literally start with " " — confirm intended.
image_dir = " 2D-3D-S"  # root directory searched recursively for input images
output_json = " script/large_omniVQAdataset_reasoning.json"  # generated dataset, also read back to resume

for _label, _path in (("Image directory", image_dir), ("Output JSON file", output_json)):
    logging.info(f"{_label}: {_path}")

# Question pool for the polar regions of the images; a random subset is asked per image.
# Grouped by theme below; the concatenation order is the original list order.

# What is there?
_identification_questions = [
    "What object is present in the top polar region of the image?",
    "Which objects can be identified in the bottom polar area?",
    "Can you name the objects visible in the upper polar region?",
    "Can you name the objects visible in the down polar region?",
    "Identify the primary object at the polar and explain its function",
]

# What does it look like?
_appearance_questions = [
    "What shape does the object in the top polar region exhibit?",
    "What shape does the object in the bottom polar region exhibit?",
    "What visual features can you observe about the object in the upper polar area?",
    "What visual features can you observe about the object in the down polar area?",
]

# How are things arranged (relations, occlusion, contact, layout)?
_spatial_questions = [
    "What is the spatial relationship between the object in the polar region and the object near it?",
    "What is the spatial relationship between objects in the top polar regions?",
    "What is the spatial relationship between objects in the down polar regions?",
    "Can you determine if an object in the upper polar region is partially occluded by another object in the same area?",
    "Can you determine if an object in the lower polar region is partially occluded by another object in the same area?",
    "Are there any objects in the polar region that interact with each other (e.g., overlapping, touching, or connected)?",
    "How many objects are surrounding the polar object? Are they evenly spaced or clustered?",
]

questions = [*_identification_questions, *_appearance_questions, *_spatial_questions]

def _load_existing_vqa_data(path):
    """Return the VQA entries previously saved at *path*, or ``[]`` if there are none.

    The script later rewrites *path* with whatever is held in memory. A file
    that exists but cannot be used (unreadable, invalid JSON, or not a JSON
    list) must therefore not be silently treated as empty, because that would
    overwrite earlier results. Such a file is moved aside to a timestamped
    ``.bak_`` copy, and an empty list is returned so the run starts fresh.

    Raises:
        RuntimeError: if an unusable file exists and cannot be moved aside
            (continuing would destroy it).
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        reason = f"{type(err).__name__}: {err}"
    else:
        if isinstance(data, list):
            logging.info(f"Loaded {len(data)} existing entries from {path}")
            return data
        reason = f"expected a JSON list, got {type(data).__name__}"

    backup_path = f"{path}.bak_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logging.warning(f"Failed to load existing data from {path} ({reason}); moving it to {backup_path}")
    try:
        os.replace(path, backup_path)
    except OSError as err:
        raise RuntimeError(
            f"Could not back up unreadable {path}; refusing to continue and overwrite it"
        ) from err
    return []


vqa_data = _load_existing_vqa_data(output_json)

# Resume support: images that already have at least one saved answer are skipped later.
# Entries without an "image" key cannot identify an image, so they are ignored instead of crashing.
processed_images = {
    item["image"] for item in vqa_data if isinstance(item, dict) and "image" in item
}
logging.info(f"Already processed {len(processed_images)} images")

# File extensions (compared lower-case) treated as images when walking the dataset.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

# System prompt sent with every question. It never changes, so it is built once.
_SYSTEM_PROMPT = (
    "Based on the image and the question provided below, please provide a comprehensive and detailed explanation of your reasoning process without only providing a final answer. "
    "When the question pertains to spatial relationships, describe them explicitly using clear directional terms such as 'up', 'down', 'left', 'right', 'front', 'back', 'covering', or 'adjacent to'. "
)


def _save_vqa_data(path, data):
    """Write *data* to *path* as indented JSON without ever leaving a partial file.

    The JSON goes to a sibling ``.tmp`` file and is then swapped over *path*
    with ``os.replace``. An interruption mid-write therefore leaves the
    previous checkpoint intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4)
        json_file.flush()
        os.fsync(json_file.fileno())
    os.replace(tmp_path, path)


def _answer_question(image_path, question, max_new_tokens):
    """Run the fine-tuned model on one (image, question) pair and return the decoded answer.

    All GPU tensors are locals of this function. They become unreachable once
    it returns, or once the caller has finished handling an exception it
    raised, so the caller can then release cached CUDA memory.
    """
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": question},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def process_images_recursively(directory, questions_per_image=4, save_every=10, max_new_tokens=256):
    """Walk *directory* and generate reasoning answers for every unprocessed image.

    For each image file (matched by extension) whose path is not already in
    ``processed_images``, the function samples ``questions_per_image``
    questions from ``questions``. If the pool is smaller, it uses all of them
    in order. Each answer is appended to ``vqa_data`` as a
    ``{"question", "image", "reasoning"}`` dict. A failure on one question is
    logged and skipped. Progress is checkpointed to ``output_json`` every
    ``save_every`` images.

    Args:
        directory: Root directory to search recursively.
        questions_per_image: Number of questions sampled per image (default 4).
        save_every: Checkpoint interval in images (default 10); 0 disables checkpoints.
        max_new_tokens: Generation length limit per answer (default 256).
    """
    image_count = len(processed_images)
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.lower().endswith(_IMAGE_EXTENSIONS):
                continue
            image_path = os.path.join(root, file)
            if image_path in processed_images:
                continue
            logging.info(f"Processing image {image_count + 1}: {image_path}")
            if len(questions) >= questions_per_image:
                selected_questions = random.sample(questions, questions_per_image)
            else:
                selected_questions = questions
            for question in selected_questions:
                try:
                    answer = _answer_question(image_path, question, max_new_tokens)
                except Exception as e:  # per-sample boundary: log and move on to the next question
                    logging.error(f"Error processing image {image_path} for question '{question}': {e}")
                else:
                    vqa_data.append({
                        "question": question,
                        "image": image_path,
                        "reasoning": answer
                    })
                    logging.info(f"Generated answer for image {image_path}, question: '{question}' => {answer}")
                # Release cached GPU memory after every attempt, including failed ones
                # (e.g. CUDA OOM), so one bad sample does not starve the following ones.
                torch.cuda.empty_cache()
            processed_images.add(image_path)
            image_count += 1
            if save_every and image_count % save_every == 0:
                _save_vqa_data(output_json, vqa_data)
                logging.info(f"Intermediate results saved to {output_json} after processing {image_count} images")

logging.info("Starting to process images recursively...")
try:
    process_images_recursively(image_dir)
finally:
    # Always persist what has been generated, including on Ctrl-C or an unexpected
    # crash, so at most the work since the last checkpoint is lost. The data goes
    # to a temp file that is then swapped in, so an interruption here cannot leave
    # a truncated JSON file behind.
    _tmp_json = f"{output_json}.tmp"
    with open(_tmp_json, "w", encoding="utf-8") as json_file:
        json.dump(vqa_data, json_file, indent=4)
    os.replace(_tmp_json, output_json)

logging.info(f"VQA dataset successfully saved to {output_json} with {len(vqa_data)} entries")
