import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import os
import json
from tqdm import tqdm
import traceback
import argparse
import sys

# --- Default Configuration (can be overridden by command line arguments) ---
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
DEFAULT_COCO_IMAGE_DIR = "./scene"
DEFAULT_INPUT_JSON_PATH = "./prompt/task.json"
DEFAULT_OUTPUT_JSON_PATH = "./output.json"
DEFAULT_CUDA_DEVICE_ID = "0"

def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of the script entry point. Passing a list makes the
            parser usable programmatically.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_id``,
        ``image_dir``, ``input_json``, ``output_json``, ``cuda_devices``
        and ``task``.
    """
    parser = argparse.ArgumentParser(description='VLM Spatial Relationship QA Processing')

    # "%(default)s" lets argparse render the default itself, so a default
    # value containing "%" cannot break help formatting.
    parser.add_argument('--model_id', type=str, default=DEFAULT_MODEL_ID,
                        help='Model ID to use (default: %(default)s)')
    parser.add_argument('--image_dir', type=str, default=DEFAULT_COCO_IMAGE_DIR,
                        help='Directory containing images (default: %(default)s)')
    parser.add_argument('--input_json', type=str, default=DEFAULT_INPUT_JSON_PATH,
                        help='Input JSON file path (default: %(default)s)')
    parser.add_argument('--output_json', type=str, default=DEFAULT_OUTPUT_JSON_PATH,
                        help='Output JSON file path (default: %(default)s)')
    parser.add_argument('--cuda_devices', type=str, default=DEFAULT_CUDA_DEVICE_ID,
                        help='CUDA device IDs (comma-separated) (default: %(default)s)')
    parser.add_argument('--task', type=str, choices=['orientation', 'spatial_relation'],
                        default='orientation',
                        help='Task type for prompt selection (default: %(default)s)')

    return parser.parse_args(argv)

# --- Helper Functions ---

def orientation_prompt(question: str) -> str:
    """Build the orientation-task prompt that wraps *question*.

    The prompt asks the model to decide which object the ground object is
    facing, to answer exactly "No" when the ground object has no clear
    orientation, and to reply in the ``<think>...</think><answer>...</answer>``
    format. The question is appended verbatim under the ``[Question]`` header.

    Args:
        question: The question text to embed in the prompt.

    Returns:
        The full prompt string, ending with the question and a newline.
    """
    # The section header below used to be misspelled "[Insturtion]"; it now
    # matches the "[Instruction]" header used by the spatial-relation prompt.
    prompt = f'''[Task]
Analyze the spatial orientation in the image to identify which object the ground object is oriented towards or facing

[Instruction]
1. Identify the ground object and determine if it has a clear intrinsic front or orientation
   - If the ground object has no clear orientation (non-fronted), answer exactly "No"
2. If the ground object has a specific orientation, determine which object it is positioned towards
3. Focus ONLY on objects that are clearly visible in the scene
4. The ground object's orientation is determined by identifying which object it is directly facing or pointing towards
5. Analyze the spatial arrangement to determine which object the ground object is positioned to face
6. Provide step-by-step reasoning before giving the final answer

[Answer Format]
Provide your reasoning first, then give your final answer in the following format:
'Based on my observation, the answer is: <think>(Replace with your reasoning here)</think><answer>(Replace with your answer here)</answer>'

[Question]
{question}
'''
    return prompt

def spatial_prompt(question: str) -> str:
    """Build the spatial-relation prompt that wraps *question*.

    The prompt is assembled line by line so that the exact text (including
    the trailing space after the "[Task]" and "[Instruction]" headers) is
    explicit. The question is placed under the "[Question]" header and the
    prompt ends with a single newline.
    """
    prompt_lines = (
        "[Task] ",
        "Analyze the spatial relationships between objects in the image to determine the relative positions of objects",
        "",
        "[Instruction] ",
        "1. Identify the ground object and determine if it has a clear intrinsic front or orientation",
        "2. If the ground object has a specific orientation, determine which object it is positioned towards",
        "3. Focus ONLY on objects that are clearly visible in the scene",
        "4. Analyze spatial relationship between objects the appropriate frame of reference:",
        "   - When the ground object has a clear orientation, describe spatial relation based on the ground object's orientation",
        "   - When the ground object has no orientation, describe spatial relation based on the camera's perspective",
        "5. Provide structured spatial analysis before giving the final answer.",
        "",
        "[Answer Format]",
        "Provide your reasoning first, then give your final answer in the following format:",
        "'Based on my observation, the answer is: <think>(Replace with your reasoning here)</think><answer>(Replace with your answer here)</answer>'",
        "",
        "[Question]",
        # Formatted rather than passed through so non-str input is rendered
        # the same way an f-string placeholder would render it.
        f"{question}",
        "",  # yields the final trailing newline
    )
    return "\n".join(prompt_lines)

def get_prompt_function(task_type: str):
    """Look up the prompt builder registered for *task_type*.

    Task types without a registered builder fall back to the orientation
    prompt builder.
    """
    registry = dict(
        orientation=orientation_prompt,
        spatial_relation=spatial_prompt,
    )
    try:
        return registry[task_type]
    except KeyError:
        # Unknown task type: use the orientation prompt as the default.
        return orientation_prompt

def load_vlm_model(model_id: str):
    """Load the processor and vision-to-sequence model for *model_id*.

    Both objects are created with ``from_pretrained`` and
    ``trust_remote_code=True``. The model is requested in float16 with
    ``device_map="auto"`` and is put into eval mode before being returned.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        Any exception raised while loading; it is printed together with a
        traceback and then re-raised unchanged.
    """
    print(f"Loading model: {model_id}")

    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }

    try:
        vlm_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        vlm_model = AutoModelForVision2Seq.from_pretrained(model_id, **model_kwargs)
        vlm_model.eval()
        print("Model loaded successfully.")
        return vlm_processor, vlm_model
    except Exception as load_error:
        print(f"Error loading model: {load_error}")
        traceback.print_exc()
        raise

def generate_vlm_answer(
    processor: AutoProcessor,
    model: AutoModelForVision2Seq,
    image_path: str,
    question: str,
    device: str,
    task_type: str = 'orientation'
) -> str:
    """Generates an answer from the VLM for a given image and question.

    Args:
        processor: Processor used to build the chat prompt, encode the
            text/image pair and decode the generated tokens.
        model: Model whose ``generate`` method produces the answer tokens.
        image_path: Path of the image to load.
        question: Question text; it is wrapped by the prompt builder selected
            through ``task_type``.
        device: Device string that the encoded inputs are moved to.
        task_type: Key passed to ``get_prompt_function`` to pick the prompt.

    Returns:
        The decoded answer text. Failures never raise: a missing image
        returns ``"Error: Image not found at ..."``, an empty decode returns
        ``"No response generated"`` and any other exception is printed with
        its traceback and returned as ``"Error during inference: ..."``.
    """
    try:
        if not os.path.exists(image_path):
            return f"Error: Image not found at {image_path}"

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Get the appropriate prompt function based on task type
        prompt_function = get_prompt_function(task_type)
        structured_prompt_text = prompt_function(question)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": structured_prompt_text}
                ]
            }
        ]
        try:
            text_for_tokenizer = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception as template_error:
            # Best-effort fallback to the raw prompt, but no longer silent so
            # that a degraded prompt format is visible in the logs.
            print(f"Warning: could not apply chat template ({template_error}); using raw prompt text.")
            text_for_tokenizer = structured_prompt_text

        inputs = processor(
            text=text_for_tokenizer,
            images=image,
            return_tensors="pt"
        )

        inputs = {k: v.to(device) if hasattr(v, 'to') and hasattr(v, 'device') else v for k, v in inputs.items()}

        temperature = 0.1
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=temperature,
                do_sample=temperature > 0,  # sample only for a positive temperature
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )

        # Decoder-only models return the prompt tokens followed by the new
        # tokens. Drop the prompt part so the decoded answer does not carry
        # the chat-template role text or the instructions. Encoder-decoder
        # models return only the generated tokens, so nothing is trimmed.
        output_ids = generated_ids[0]
        prompt_ids = inputs.get("input_ids")
        is_encoder_decoder = getattr(getattr(model, "config", None), "is_encoder_decoder", False)
        if prompt_ids is not None and not is_encoder_decoder:
            output_ids = output_ids[prompt_ids.shape[-1]:]

        response = processor.decode(output_ids, skip_special_tokens=True).strip()

        # Kept as a safety net for outputs that still echo the prompt.
        answer_marker = "[Answer]"
        if answer_marker in response:
            response = response.split(answer_marker, 1)[-1].strip()
        elif structured_prompt_text in response:
            response = response.replace(structured_prompt_text, "").strip()

        return response if response else "No response generated"

    except Exception as e:
        print(f"Error during inference for {image_path} q: '{question[:50]}...': {str(e)}")
        traceback.print_exc()
        return f"Error during inference: {str(e)}"

def validate_json_structure(data):
    """Validates the input JSON structure.

    The expected layout is a dictionary mapping an image filename to a list
    of item dictionaries. Each item must have a ``"spatial_relationship"``
    list whose entries are dictionaries containing a ``"question"`` key.

    Args:
        data: The decoded JSON document.

    Raises:
        ValueError: On the first structural problem found. Malformed
            question entries (for example a number, ``None`` or a plain
            string) are also reported as ``ValueError`` so that callers
            catching ``ValueError`` see every validation failure.
    """
    if not isinstance(data, dict):
        raise ValueError("Input JSON should be a dictionary")

    for image_filename, items in data.items():
        if not isinstance(items, list):
            raise ValueError(f"Items for {image_filename} should be a list")

        for item in items:
            if not isinstance(item, dict):
                raise ValueError("Each item should be a dictionary")

            if "spatial_relationship" not in item:
                raise ValueError(f"Missing 'spatial_relationship' in item for {image_filename}")

            qa_pairs = item["spatial_relationship"]
            if not isinstance(qa_pairs, list):
                raise ValueError("'spatial_relationship' should be a list")

            for qa in qa_pairs:
                # Without this check `"question" not in qa` raises TypeError
                # for numbers/None and does a substring match for strings.
                if not isinstance(qa, dict):
                    raise ValueError(
                        f"Each 'spatial_relationship' entry for {image_filename} should be a dictionary"
                    )
                if "question" not in qa:
                    raise ValueError("Missing 'question' in spatial_relationship")

def process_spatial_data(
    input_json_path: str,
    output_json_path: str,
    coco_image_dir: str,
    processor: AutoProcessor,
    model: AutoModelForVision2Seq,
    device: str,
    task_type: str = 'orientation'
):
    """
    Processes the input JSON, gets VLM answers, and saves to output JSON.

    Every question entry of every item is copied to the output with an added
    ``"model"`` field holding either the generated answer or an error string
    (missing image, missing question text).

    Args:
        input_json_path: Path of the JSON file to read.
        output_json_path: Path of the JSON file to write. Its parent
            directory is created when it does not exist.
        coco_image_dir: Directory that the image filenames (the top-level
            JSON keys) are resolved against.
        processor: Passed through to ``generate_vlm_answer``.
        model: Passed through to ``generate_vlm_answer``.
        device: Passed through to ``generate_vlm_answer``.
        task_type: Prompt selector passed through to ``generate_vlm_answer``.
            The default is ``'orientation'``, the prompt that the previous,
            unregistered default ``'object_recognition'`` resolved to.

    Returns:
        None. Read, validation and write problems are printed, not raised.
    """
    # Function-scope import: imported once per call instead of once per item.
    import copy

    try:
        with open(input_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input JSON file not found at {input_json_path}")
        return
    except json.JSONDecodeError as e:
        print(f"Error: Could not decode JSON from {input_json_path}: {e}")
        return

    try:
        validate_json_structure(data)
    except ValueError as e:
        print(f"JSON structure validation failed: {e}")
        return

    total_questions = sum(
        len(item_data.get("spatial_relationship", []))
        for items_in_image in data.values()
        for item_data in items_in_image
    )

    output_data = {}
    processed_questions = 0
    current_item_idx = 0

    with tqdm(total=total_questions, desc="Processing Questions",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]') as pbar:

        for image_filename, items_in_image in data.items():
            image_path = os.path.join(coco_image_dir, image_filename)
            # The path is the same for every item of this image: check once.
            image_exists = os.path.exists(image_path)
            output_items_for_image = []

            for item_data_original in items_in_image:
                current_item_idx += 1
                # Work on a copy so the loaded input data is never mutated.
                item_data = copy.deepcopy(item_data_original)

                if not image_exists:
                    print(f"Warning: Image not found at {image_path}, skipping item.")

                for qa_pair in item_data.get("spatial_relationship", []):
                    if not image_exists:
                        qa_pair["model"] = "Error: Image not found."
                        postfix = f"Item {current_item_idx}, Image Missing"
                    else:
                        question = qa_pair.get("question")
                        if question:
                            qa_pair["model"] = generate_vlm_answer(
                                processor, model, image_path, question, device, task_type
                            )
                        else:
                            qa_pair["model"] = "Error: Question not found in JSON entry."
                        postfix = f"Item {current_item_idx}"

                    processed_questions += 1
                    pbar.update(1)
                    pbar.set_postfix_str(postfix)

                output_items_for_image.append(item_data)

            output_data[image_filename] = output_items_for_image

    try:
        # Create the target directory first so finished inference results are
        # not lost just because the output folder does not exist yet.
        output_dir = os.path.dirname(output_json_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_json_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"\nComplete! Processed {processed_questions} questions -> {output_json_path}")
    except Exception as e:
        print(f"Error saving results: {e}")

# --- Main Execution ---
if __name__ == "__main__":
    # Script entry point: read the CLI options, report the configuration,
    # check the input paths, then load the model and process the questions.
    args = parse_arguments()

    # Set configuration from arguments
    MODEL_ID, COCO_IMAGE_DIR = args.model_id, args.image_dir
    INPUT_JSON_PATH, OUTPUT_JSON_PATH = args.input_json, args.output_json
    CUDA_DEVICE_ID, TASK_TYPE = args.cuda_devices, args.task

    # Restrict the visible GPUs before the first CUDA query below; inputs
    # then go to the first visible device, or to the CPU without CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_DEVICE_ID
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    print("VLM Spatial Relationship QA Processing")
    for label, value in (
        ("Model", MODEL_ID),
        ("Image Directory", COCO_IMAGE_DIR),
        ("Input JSON", INPUT_JSON_PATH),
        ("Output JSON", OUTPUT_JSON_PATH),
        ("Task Type", TASK_TYPE),
        ("CUDA Devices", CUDA_DEVICE_ID),
        ("Primary device for inputs", DEVICE),
    ):
        print(f"{label}: {value}")

    # Abort early, before the expensive model load, when an input is missing.
    for missing_message, required_path in (
        (f"Error: COCO image directory not found: {COCO_IMAGE_DIR}", COCO_IMAGE_DIR),
        (f"Error: Input JSON file not found: {INPUT_JSON_PATH}", INPUT_JSON_PATH),
    ):
        if not os.path.exists(required_path):
            print(missing_message)
            sys.exit(1)

    try:
        print("Loading VLM model...")
        vlm_processor, vlm_model = load_vlm_model(MODEL_ID)

        process_spatial_data(
            INPUT_JSON_PATH,
            OUTPUT_JSON_PATH,
            COCO_IMAGE_DIR,
            vlm_processor,
            vlm_model,
            DEVICE,
            TASK_TYPE,
        )
    except Exception as fatal_error:
        print(f"Fatal error: {fatal_error}")
        traceback.print_exc()
        sys.exit(1)