
import os
import re
import sys
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
from PIL import Image

import traceback

# Put the local VCD checkout first on the import path so the `vcd_utils`
# package below resolves. The path is machine-specific.
sys.path.insert(0, '/home//hsgg/VCD')
from vcd_utils.vcd_sample import evolve_vcd_sampling
from vcd_utils.vcd_add_noise import add_diffusion_noise


# Executed once at import time, before any model is used.
# NOTE(review): presumably replaces the Hugging Face sampling routine with the
# contrastive-decoding variant -- confirm in vcd_utils/vcd_sample.py.
evolve_vcd_sampling()


def add_vcd_support_to_model(model_class):
    """Patch ``model_class`` in place so it accepts contrastive-decoding kwargs.

    Two attributes are installed on the class:

    * ``prepare_inputs_for_generation_cd`` -- substitutes any ``*_cd``
      overrides for their regular counterparts and delegates to the class's
      own ``prepare_inputs_for_generation``.
    * ``_validate_model_kwargs`` -- wrapped so the ``cd_*`` / ``*_cd`` entries
      are hidden from the original validator and put back afterwards.

    Patching the same class more than once is safe: the validator is wrapped
    only the first time.

    Args:
        model_class: The model class (not an instance) to patch. It must
            provide ``prepare_inputs_for_generation`` and
            ``_validate_model_kwargs``.
    """

    def prepare_inputs_for_generation_cd(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs
    ):
        """Build model inputs for the contrastive branch.

        Any of ``pixel_values_cd``, ``attention_mask_cd`` and
        ``image_grid_thw_cd`` found in ``kwargs`` replaces the corresponding
        regular argument; everything else is forwarded untouched.
        """
        pixel_values_cd = kwargs.pop("pixel_values_cd", None)
        # ``attention_mask`` is a named parameter and therefore can never show
        # up in ``kwargs``; the override travels under its ``_cd`` key.
        attention_mask_cd = kwargs.pop("attention_mask_cd", None)
        image_grid_thw_cd = kwargs.pop("image_grid_thw_cd", None)

        if pixel_values_cd is not None:
            pixel_values = pixel_values_cd
        if attention_mask_cd is not None:
            attention_mask = attention_mask_cd
        if image_grid_thw_cd is not None:
            image_grid_thw = image_grid_thw_cd

        return self.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            **kwargs
        )

    model_class.prepare_inputs_for_generation_cd = prepare_inputs_for_generation_cd

    original_validate = model_class._validate_model_kwargs
    if getattr(original_validate, "_vcd_wrapped", False):
        # Already patched (e.g. a second extractor for the same model class);
        # wrapping again would only nest the wrapper inside itself.
        return

    cd_keys = (
        "cd_alpha",
        "cd_beta",
        "pixel_values_cd",
        "input_ids_cd",
        "attention_mask_cd",
        "image_grid_thw_cd",
    )

    def _validate_model_kwargs_with_cd(self, model_kwargs):
        """Run the original validation with the contrastive kwargs hidden."""
        # Remove every contrastive key so the original validator does not see
        # arguments it does not know about.
        stashed = {key: model_kwargs.pop(key, None) for key in cd_keys}

        original_validate(self, model_kwargs)

        # Put back only the entries that actually carried a value.
        for key in cd_keys:
            if stashed[key] is not None:
                model_kwargs[key] = stashed[key]

    _validate_model_kwargs_with_cd._vcd_wrapped = True
    model_class._validate_model_kwargs = _validate_model_kwargs_with_cd



class RelationExtractor:
    """Predict directional predicates between object pairs with a VLM.

    For every "union" image (one image per object pair, named so that both
    labels and indices can be parsed from the file name) the model is asked
    twice -- once per direction -- which predicate links the two objects.
    When ``use_vcd`` is set, extra contrastive-decoding inputs are passed to
    ``model.generate``: either a noised copy of the image
    (``amateur_use_image=True``) or a text-only prompt built from bounding
    boxes and depth statistics.
    """

    PROMPT_EXPERT = """Look at the image and identify the relationship from {subject} (marked with {subject_color} box) to {object} (marked with {object_color} box).

IMPORTANT: The relationship has directionality - it describes how the subject relates to the object.
IMPORTANT: If there is no clear relationship, say "no relation"
For example:
- "man" wearing "glasses" (man is the subject performing the action)
- "glasses" worn by "man" (glasses is the subject being acted upon)
- "person" riding "bike" (person is the subject, bike is the object)

Output format: "<predicate>" or "no relation"
Examples:
- wearing 
- riding 
- worn by
- carried by 
- no relation

Now analyze the directional relationship from {subject} to {object} in this image.
What is the most likely relationship? Answer:
"""

    PROMPT_AMATEUR = """You are an expert at predicting what relationships a vision-language model might hallucinate based only on object detection and depth information.

## Task
Given spatial information about two objects (bounding boxes and depth values), predict the relationship that a model is MOST LIKELY to hallucinate (incorrectly guess) without actually seeing the image.

## Few-shot Examples:

Example 1:
- Subject: person, box [100, 50, 200, 300], depth median=0.45
- Object: bicycle, box [180, 150, 350, 350], depth median=0.52
Hallucinated predicate: riding
(Models often hallucinate "riding" when person and bicycle are nearby, even without visual evidence of riding pose)

Example 2:
- Subject: dog, box [50, 200, 150, 350], depth median=0.60
- Object: ball, box [300, 280, 350, 330], depth median=0.55
Hallucinated predicate: playing with
(Models often hallucinate interactions between pets and toys based on common knowledge)

## Current Objects:
{spatial_info}

## Output
Predict the most likely hallucinated predicate (or "no relation" if objects seem unrelated).
Format: "<predicate>" or "no relation"
Do NOT output the full triple, only the predicate.
Example:
- wearing 
- riding 

Hallucinated predicate:
"""

    # Fallbacks used when the constructor receives None for the matching
    # argument.
    DEFAULT_CD_ALPHA = 1.0
    DEFAULT_CD_BETA = 0.1
    DEFAULT_NOISE_STEP = 500

    # Union image stems look like "<prefix>_<subj>-<idx>_<obj>-<idx>".
    # Compiled once instead of on every loop iteration.
    _UNION_NAME_PATTERN = re.compile(r'.+_(.+)-(\d+)_(.+)-(\d+)')

    def __init__(
        self,
        model,
        processor,
        device: str = "cuda:0",
        use_vcd: bool = True,
        cd_alpha: Optional[float] = None,
        cd_beta: Optional[float] = None,
        noise_step: Optional[int] = None,
        amateur_use_image: bool = False,
        temperature: float = 0.1
    ):
        """Store the model/processor and configure contrastive decoding.

        Args:
            model: Generation-capable model; its class is patched when
                ``use_vcd`` is true.
            processor: Processor exposing ``apply_chat_template``,
                ``batch_decode`` and a ``tokenizer`` attribute.
            device: Device the prepared inputs are moved to.
            use_vcd: Whether to pass contrastive-decoding kwargs to
                ``generate``.
            cd_alpha: Contrastive weight; ``None`` selects
                ``DEFAULT_CD_ALPHA``.
            cd_beta: Plausibility cutoff; ``None`` selects
                ``DEFAULT_CD_BETA``.
            noise_step: Step passed to ``add_diffusion_noise``; ``None``
                selects ``DEFAULT_NOISE_STEP``.
            amateur_use_image: If true the contrastive branch sees a noised
                image, otherwise a text-only spatial prompt.
            temperature: Sampling temperature for ``generate``.
        """
        self.device = device
        self.model = model
        self.processor = processor

        self.use_vcd = use_vcd
        self.cd_alpha = cd_alpha if cd_alpha is not None else self.DEFAULT_CD_ALPHA
        self.cd_beta = cd_beta if cd_beta is not None else self.DEFAULT_CD_BETA
        self.noise_step = noise_step if noise_step is not None else self.DEFAULT_NOISE_STEP
        self.amateur_use_image = amateur_use_image
        self.temperature = temperature

        if self.use_vcd:
            add_vcd_support_to_model(type(model))

            # Make sure a pad token id is defined everywhere generation may
            # look for one, falling back to the EOS id.
            if self.processor.tokenizer.pad_token_id is None:
                self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id
            if self.model.config.pad_token_id is None:
                self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
            if self.model.generation_config.pad_token_id is None:
                self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id

        self._cached_depth_map = None
        self._cached_image_path = None
        self.idx_to_sorted_label = {}

        self._objects_data = None
        self._cropped_objects = None

    def extract_relations(
        self,
        image_path: str,
        objects_data: Dict[str, Any],
        cropped_objects: List[Dict],
        union_images: List[str],
        attributes: Dict[int, str],
        idx_to_sorted_label: Optional[Dict[int, str]] = None,
        depth_map: Optional[np.ndarray] = None
    ) -> List[Dict[str, Any]]:
        """Extract relations for every union image of one source image.

        Args:
            image_path: Path of the source image (only cached on the
                instance).
            objects_data: Detection result; must contain ``'boxes'`` and
                ``'labels'``.
            cropped_objects: Per-object crops (only cached on the instance).
            union_images: Paths of the per-pair union images to query.
            attributes: Accepted for interface compatibility; not used here.
            idx_to_sorted_label: Optional mapping from object index to the
                label to use in prompts. Defaults to the detector labels in
                order.
            depth_map: Optional depth map used for the text-only amateur
                prompt.

        Returns:
            A list of relation dicts as produced by
            ``_extract_single_direction_relation``.
        """
        # Both keys are required; fail early if the detector output is
        # malformed.
        objects_data['boxes']
        labels = objects_data['labels']

        if idx_to_sorted_label is None:
            idx_to_sorted_label = dict(enumerate(labels))

        self.idx_to_sorted_label = idx_to_sorted_label
        self._cached_depth_map = depth_map
        self._cached_image_path = image_path
        self._objects_data = objects_data
        self._cropped_objects = cropped_objects

        return list(self._extract_from_unions(union_images, depth_map))

    def _extract_from_unions(self, union_images: List[str], depth_map: Optional[np.ndarray] = None) -> List[Dict]:
        """Query the model for each union image and collect the relations.

        File names that do not match the expected pattern, pairs rejected by
        ``_should_skip_pair`` and missing files are skipped. Any exception
        raised while handling one image is printed and that image is skipped,
        so a single failure never aborts the batch.
        """
        relations = []

        for union_path in union_images:
            try:
                file_base = os.path.splitext(os.path.basename(union_path))[0]
                match = self._UNION_NAME_PATTERN.match(file_base)
                if not match:
                    continue

                subject_label = match.group(1)
                subject_idx = int(match.group(2))
                object_label = match.group(3)
                object_idx = int(match.group(4))

                # Prefer the caller-supplied label for an index over the one
                # parsed from the file name.
                sorted_subject_label = self.idx_to_sorted_label.get(subject_idx, subject_label)
                sorted_object_label = self.idx_to_sorted_label.get(object_idx, object_label)

                if self._should_skip_pair(sorted_subject_label, sorted_object_label):
                    continue

                if not os.path.exists(union_path):
                    continue

                subject_box = self._objects_data['boxes'][subject_idx] if self._objects_data else None
                object_box = self._objects_data['boxes'][object_idx] if self._objects_data else None

                subject_depth_info = None
                object_depth_info = None
                if depth_map is not None and subject_box is not None:
                    subject_depth_info = self._get_depth_info_for_box(depth_map, subject_box)
                if depth_map is not None and object_box is not None:
                    object_depth_info = self._get_depth_info_for_box(depth_map, object_box)

                relations.extend(
                    self._extract_relation_with_vcd(
                        union_path,
                        sorted_subject_label,
                        sorted_object_label,
                        subject_idx,
                        object_idx,
                        subject_box=subject_box,
                        object_box=object_box,
                        subject_depth_info=subject_depth_info,
                        object_depth_info=object_depth_info
                    )
                )
            except Exception:
                # Deliberate best-effort: report and move on to the next pair.
                traceback.print_exc()
                continue

        return relations

    def _get_depth_info_for_box(self, depth_map: np.ndarray, box: List[float]) -> Dict[str, float]:
        """Return min / max / median depth inside ``box``.

        Args:
            depth_map: Array indexed as ``[row, column]``.
            box: ``[x1, y1, x2, y2]``; coordinates are truncated to ints and
                clamped to the map bounds.

        Returns:
            Dict with ``depth_min``, ``depth_max`` and ``depth_median``. All
            three are 0 when the clamped region is empty or the statistics
            cannot be computed.
        """
        try:
            depth_h, depth_w = depth_map.shape[:2]

            x1, y1, x2, y2 = [int(c) for c in box]

            # The start is clamped to the last valid index, the (exclusive)
            # end to the map size.
            x1 = max(0, min(x1, depth_w - 1))
            y1 = max(0, min(y1, depth_h - 1))
            x2 = max(0, min(x2, depth_w))
            y2 = max(0, min(y2, depth_h))

            depth_region = depth_map[y1:y2, x1:x2]

            if depth_region.size == 0:
                return {"depth_min": 0, "depth_max": 0, "depth_median": 0}

            return {
                "depth_min": float(np.min(depth_region)),
                "depth_max": float(np.max(depth_region)),
                "depth_median": float(np.median(depth_region))
            }
        except Exception:
            # Depth is auxiliary information; fall back to neutral values
            # rather than losing the whole pair.
            return {"depth_min": 0, "depth_max": 0, "depth_median": 0}

    @staticmethod
    def _describe_detection(role: str, label: str, box, depth_info: Optional[Dict]) -> str:
        """Format one ``- <role>: ...`` line of the amateur prompt."""
        # Explicit None/empty test: plain truthiness raises for NumPy arrays
        # and multi-element tensors.
        if box is None or len(box) == 0:
            return f"- {role}: {label}, NOT DETECTED\n"

        line = f"- {role}: {label}, box [{box[0]:.1f}, {box[1]:.1f}, {box[2]:.1f}, {box[3]:.1f}]"
        if depth_info:
            line += f", depth median={depth_info['depth_median']:.3f}"
        return line + "\n"

    def _build_amateur_prompt(
        self,
        subject_label: str,
        object_label: str,
        subject_box: List[float],
        object_box: List[float],
        subject_depth_info: Dict,
        object_depth_info: Dict
    ) -> str:
        """Fill ``PROMPT_AMATEUR`` with the boxes and depth of both objects.

        A box that is ``None`` or empty is reported as ``NOT DETECTED``; the
        depth median is appended only when depth info is available. Boxes may
        be lists, tuples, NumPy arrays or tensors.
        """
        spatial_info = (
            self._describe_detection("Subject", subject_label, subject_box, subject_depth_info)
            + self._describe_detection("Object", object_label, object_box, object_depth_info)
        )
        return self.PROMPT_AMATEUR.format(spatial_info=spatial_info)

    def _extract_relation_with_vcd(
        self,
        image_path: str,
        subject_label: str,
        object_label: str,
        subject_idx: int,
        object_idx: int,
        subject_box: List[float] = None,
        object_box: List[float] = None,
        subject_depth_info: Dict = None,
        object_depth_info: Dict = None
    ) -> List[Dict]:
        """Query both directions (subject->object, then object->subject).

        Returns:
            Zero, one or two relation dicts; the forward direction comes
            first when both are present.
        """
        relations = []

        forward = self._extract_single_direction_relation(
            image_path,
            subject_label,
            object_label,
            subject_idx,
            object_idx,
            subject_box,
            object_box,
            subject_depth_info,
            object_depth_info,
            is_reverse=False
        )
        if forward:
            relations.append(forward)

        # Same image, roles swapped.
        backward = self._extract_single_direction_relation(
            image_path,
            object_label,
            subject_label,
            object_idx,
            subject_idx,
            object_box,
            subject_box,
            object_depth_info,
            subject_depth_info,
            is_reverse=True
        )
        if backward:
            relations.append(backward)

        return relations

    def _extract_single_direction_relation(
        self,
        image_path: str,
        subject_label: str,
        object_label: str,
        subject_idx: int,
        object_idx: int,
        subject_box: List[float],
        object_box: List[float],
        subject_depth_info: Dict,
        object_depth_info: Dict,
        is_reverse: bool = False
    ) -> Optional[Dict]:
        """Ask the model for the predicate from subject to object.

        Args:
            image_path: Union image shown to the model.
            subject_label: Label used as the subject in the prompt.
            object_label: Label used as the object in the prompt.
            subject_idx: Index of the subject, reported back in ``'idx'``.
            object_idx: Index of the object, reported back in ``'idx'``.
            subject_box: Subject box for the text-only amateur prompt.
            object_box: Object box for the text-only amateur prompt.
            subject_depth_info: Subject depth stats for the amateur prompt.
            object_depth_info: Object depth stats for the amateur prompt.
            is_reverse: True when the roles are swapped relative to the file
                name; this swaps the box colours named in the prompt.

        Returns:
            A dict with ``idx``, ``subject_label``, ``object_label``,
            ``predicate``, ``raw_output`` and ``is_reverse``, or ``None``
            when the model answers "no relation" / nothing usable.
        """
        # The forward prompt calls the subject box red and the object box
        # yellow; the colours follow the objects when the roles are swapped.
        # NOTE(review): assumes the union image draws the first-named object
        # in red -- confirm against the code that renders union images.
        subject_color = "yellow" if is_reverse else "red"
        object_color = "red" if is_reverse else "yellow"

        expert_prompt = self.PROMPT_EXPERT.format(
            subject=subject_label,
            object=object_label,
            subject_color=subject_color,
            object_color=object_color
        )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": expert_prompt}
                ]
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        try:
            from qwen_vl_utils import process_vision_info
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt"
            ).to(self.device)
        except ImportError:
            # qwen_vl_utils is unavailable: load the image directly. The
            # context manager releases the file handle once the processor
            # has consumed the pixels.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    text=text,
                    images=image,
                    return_tensors="pt",
                    padding=True
                ).to(self.device)

        input_ids = inputs['input_ids']
        pixel_values = inputs.get('pixel_values')
        if pixel_values is not None:
            pixel_values = pixel_values.to(self.model.dtype)
        image_grid_thw = inputs.get('image_grid_thw')

        generate_kwargs = {
            "input_ids": input_ids,
            "max_new_tokens": 64,
            "do_sample": True,
            "temperature": self.temperature,
            "top_p": 1.0,
            "use_cache": True
        }

        if pixel_values is not None:
            generate_kwargs["pixel_values"] = pixel_values
        if image_grid_thw is not None:
            generate_kwargs["image_grid_thw"] = image_grid_thw

        if self.use_vcd:
            if self.amateur_use_image:
                # Contrastive branch: same prompt, noised image.
                if pixel_values is not None:
                    pixel_values_cd = add_diffusion_noise(pixel_values, self.noise_step)
                    generate_kwargs["pixel_values_cd"] = pixel_values_cd
                    generate_kwargs["input_ids_cd"] = input_ids
                    if image_grid_thw is not None:
                        generate_kwargs["image_grid_thw_cd"] = image_grid_thw
            else:
                # Contrastive branch: text-only prompt built from boxes and
                # depth, no image.
                amateur_prompt = self._build_amateur_prompt(
                    subject_label,
                    object_label,
                    subject_box,
                    object_box,
                    subject_depth_info,
                    object_depth_info
                )

                inputs_amateur = self.processor.tokenizer(
                    amateur_prompt,
                    return_tensors='pt',
                    padding=True
                )
                generate_kwargs["input_ids_cd"] = inputs_amateur.input_ids.to(self.device)
                generate_kwargs["attention_mask_cd"] = inputs_amateur.attention_mask.to(self.device)

            generate_kwargs["cd_alpha"] = self.cd_alpha
            generate_kwargs["cd_beta"] = self.cd_beta

        with torch.no_grad():
            generated_ids = self.model.generate(**generate_kwargs)

        # Drop the echoed prompt tokens; only the continuation is decoded.
        input_length = input_ids.shape[1]
        decoded_text = self.processor.batch_decode(
            generated_ids[:, input_length:],
            skip_special_tokens=True
        )[0].strip().lower()

        if "no relation" in decoded_text or decoded_text == "":
            return None

        predicate = self._parse_relation_triple(decoded_text, subject_label, object_label)
        if not predicate:
            return None

        return {
            'idx': (subject_idx, object_idx),
            'subject_label': subject_label,
            'object_label': object_label,
            'predicate': predicate,
            'raw_output': decoded_text,
            'is_reverse': is_reverse
        }

    def _parse_relation_triple(self, text: str, subject: str, obj: str) -> Optional[str]:
        """Reduce the model answer to a bare predicate.

        Strategy, in order:
          1. ``<subject> <predicate> <object>`` -> the middle part.
          2. One to three words -> those words minus the two labels.
          3. Anything else that is not "no relation" -> the text as is.

        Surrounding quotes and trailing periods are removed first because the
        prompts show the answer format as ``"<predicate>"`` and the model may
        echo that punctuation.

        Returns:
            The predicate, or ``None`` if nothing usable remains.
        """
        text = text.strip().lower().strip(" \t\r\n\"'`.")
        subject_lower = subject.lower()
        obj_lower = obj.lower()

        pattern = rf"{re.escape(subject_lower)}\s+(.+?)\s+{re.escape(obj_lower)}"
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()

        words = text.split()
        if 1 <= len(words) <= 3:
            predicate_words = [w for w in words if w not in (subject_lower, obj_lower)]
            if predicate_words:
                return ' '.join(predicate_words)

        if text and "no relation" not in text:
            return text

        return None

    def _should_skip_pair(self, label1: str, label2: str) -> bool:
        """Return True for pairs involving "sky" unless the partner is a cloud."""
        cloud_labels = ("cloud", "clouds")
        return (
            (label1 == "sky" and label2 not in cloud_labels)
            or (label2 == "sky" and label1 not in cloud_labels)
        )

    def set_depth_map(self, depth_map: np.ndarray, image_path: str):
        """Cache a depth map together with the image path it belongs to."""
        self._cached_depth_map = depth_map
        self._cached_image_path = image_path


# Manual smoke test: load a local LLaVA-OneVision 1.5 checkpoint and build a
# RelationExtractor around it. Nothing is extracted here; a failure while
# loading or constructing prints the traceback and exits with status 1.
if __name__ == "__main__":

    # Make the parent directory of this file importable.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    
    from transformers import AutoProcessor


    # Configuration parameters (machine-specific paths).
    model_path = "/home//hsgg/checkpoint/llavaonevision1.5-checkpoint"
    device = "cuda:0"
    

    
    try:
        # The modeling code is imported from the local VCD experiments
        # directory, which must be on sys.path before the import below.
        sys.path.insert(0, '/home//hsgg/VCD/experiments/llavaonevision1_5')
        from modeling_llavaonevision1_5 import LLaVAOneVision1_5_ForConditionalGeneration
        
        model = LLaVAOneVision1_5_ForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map=device,
            trust_remote_code=False,
            local_files_only=True
        )
        model.eval()
 
        processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True
        )

        # Text-only amateur branch (amateur_use_image=False) with the same
        # alpha/beta values as the class defaults.
        extractor = RelationExtractor(
            model=model,
            processor=processor,
            device=device,
            use_vcd=True,
            cd_alpha=1.0,
            cd_beta=0.1,
            amateur_use_image=False  
        )

        
    except Exception as e:
        traceback.print_exc()
        sys.exit(1)
