
import torch
import gc
from typing import List, Dict, Any, Optional
from transformers import AutoProcessor, AutoModelForCausalLM
from modules.llmdet import LLMDet
from llavaonevision1_5.modeling_llavaonevision1_5 import LLaVAOneVision1_5_ForConditionalGeneration
import traceback

class ObjectDetector:
    """Multi-stage object detector that combines a chat-style vision-language model with LLMDet.

    The vision-language model is asked (through ``_chat_with_image``) to name
    candidate objects in an image; the candidate names are then passed as
    labels to ``self.llmdet.predict_in_chunks``. ``detect`` repeats this for
    three stages: independent objects, objects related to those, and objects
    related to the second-stage detections.

    The three ``PROMPT_*`` class attributes below are the prompt texts sent to
    the vision-language model. They are runtime strings: editing them changes
    what the model is asked.
    """

    # Prompt for the first stage: asks for a comma-separated list of whole,
    # independent objects. Sent as-is (no placeholders) by
    # `_detect_independent_objects`.
    PROMPT_INDEPENDENT = """
### OBJECTIVE: 
Your task is to identify and list every distinct, independent object within the provided image. 

### Core Instructions & Rules 

1. Identify Independent Objects ONLY: 
   An independent object is a whole, complete entity. 
   Examples: person, car, tree, building, bench, dog. 
   
   A dependent object is a part of an independent object. 
   Examples: a person's hat, a car's wheel, a tree's leaf, a building's door. 
   
   CRITICAL RULE: Do NOT list dependent objects.

2. Be Specific and Common: 
   Use the most specific category possible for each object. 
   Express the category label with a single, common word.
   Instead of person, use man, woman, boy, or girl. 

3. Scan the Entire Scene: 
   Carefully examine the foreground, midground, and background. 
   Include large-scale environmental features as independent objects. 

### Output Format 
   Return a single, comma-separated list of the identified objects. 
   DO NOT use bullet points, numbering, or any descriptive text. 
   Output the object in lowercase singular form.

### Example Output: 
   man, woman, sedan, traffic light, road, crosswalk, building, skyscraper, sky, cloud,street.
"""

    # Follow-up prompt used by `_detect_independent_objects` when multi-round
    # mode is enabled: asks for objects missing from the list collected so far,
    # or "None". Filled with `str.format`, so `{INDEPENDENT_OBJECT_LIST}` is a
    # placeholder and any literal braces added to this text must be doubled.
    PROMPT_COMPLETION = """
### OBJECTIVE: 
Observe the image from three levels: foreground, midground, and background. 
Are there any objects that have not been mentioned in the independent object list? 
If they exist, provide all the unmentioned objects; if not, reply with "None".

### Core Instructions & Rules 
1. Identify Independent Objects ONLY (same as before)

### Output Format 
   Return a single, comma-separated list of the identified objects. 
   Output the name of the object in lowercase singular form.

### Example Output:
   independent object list: man, girl, paper
   output: pen, window, pencil, desk, chair, floor.

   independent object list: car, road, building, street, truck.
   output: None 

#### Your Task Now:
independent object list: {INDEPENDENT_OBJECT_LIST}
output:
"""

    # Prompt asking, per input object, for related objects in
    # "parent: child, child" lines (the format `_parse_related_response`
    # reads). Used by both `_detect_related_objects` and
    # `_detect_small_objects`. Filled with `str.format`, so
    # `{INPUT_INDEPENDENT_OBJECT_LIST}` is a placeholder and any literal braces
    # added to this text must be doubled.
    PROMPT_RELATED = """
### Objective:
For each object in the input list, list its "related objects"—accessories, components, 
or decorations that are typically attached to or associated with the object.

If the object typically does not have any separable related objects, output "None".

### Definition of "Related Objects":
- Physical Attachment Type: e.g., hat worn on head, tires on car,
- Decorative Association Type: e.g., patterns on T-shirt, camera on building
- Functional Association Type: e.g., remote control and television

### Notes
1. Strictly follow the format.
2. Objects must potentially exist in the image.
3. Do not list synonyms or word form variations (e.g., singular/plural) of the input object.
4. Cannot list objects already in the input list.
5. Provide category only, no description.
6. Express with single, common word.
7. List no more than five related objects per object.

When analyzing a human subject, prioritize the description of clothing and accessories (e.g., shirt, pants, footwear, eyewear) first. 


### Correct Examples:
Input: man, car, building, tree, T-shirt, computer

Output:
    man: watch, glasses, pants, T-shirt.
    car: tire, windshield wiper
    building: window, camera.
    tree: none.
    T-shirt: design, strap.
    computer: cable, screen.
    screen: logo, model.

### Your Current Task:
Input: {INPUT_INDEPENDENT_OBJECT_LIST}

Output:
"""

    def __init__(
        self,
        llmdet_model_path: str,
        llm_model,
        llm_processor,
        device: str = "cuda:7",
        threshold: float = 0.3,
        nms_threshold: float = 0.5
    ):
        """Store the chat model/processor and build the LLMDet detector.

        Args:
            llmdet_model_path: Forwarded to ``LLMDet`` as ``model_path``.
            llm_model: Already-loaded generative model; ``_chat_with_image``
                calls its ``generate`` method.
            llm_processor: Processor paired with ``llm_model``; used for the
                chat template, input encoding and decoding.
            device: Device string handed to ``LLMDet`` and used to place the
                chat-model inputs.
            threshold: Forwarded to ``LLMDet`` as ``threshold``.
            nms_threshold: Forwarded to ``LLMDet`` as ``nms_threshold``.
        """
        # Collaborators supplied by the caller are stored untouched.
        self.llm_model = llm_model
        self.llm_processor = llm_processor
        self.device = device

        # The detector is the only object this constructor creates itself.
        detector_options = dict(
            model_path=llmdet_model_path,
            device=device,
            threshold=threshold,
            nms_threshold=nms_threshold,
        )
        self.llmdet = LLMDet(**detector_options)
    
    def detect(
        self,
        image_path: str,
        enable_multi_round: bool = False
    ) -> Dict[str, Any]:
        """Run the three-stage detection pipeline on one image.

        Stage 1 asks the chat model for independent objects and runs the
        detector on them. Stage 2 asks for objects related to the stage-1
        detections, and stage 3 asks for objects related to the stage-2
        detections; each candidate list is again confirmed by the detector.

        Compared with a plain ``list(set(...))`` de-duplication, label lists
        here keep first-seen order, so repeated runs on the same model outputs
        build the same chunks and return the same ordering. A stage whose
        candidate list is empty skips the detector call and contributes
        nothing to the output.

        Args:
            image_path: Path of the image; passed to both the chat model and
                the detector.
            enable_multi_round: Forwarded to ``_detect_independent_objects``.

        Returns:
            A dict with keys ``image_path``, ``boxes`` (coordinates rounded to
            3 decimals), ``labels``, ``scores`` (rounded to 3 decimals),
            ``record_list`` (the unique detected labels of stages 1-3) and
            ``record_dict`` (parent -> detected children for stages 2 and 3).
        """
        def normalize(label: str) -> str:
            # Stage 2 has always rewritten "T-shirt" like this before matching
            # against detector labels; presumably the detector reports the
            # label in this spaced form (TODO confirm against LLMDet).
            return label.replace('T-shirt', 't - shirt')

        def unique(labels) -> List[str]:
            # Order-preserving de-duplication (dicts keep insertion order).
            return list(dict.fromkeys(labels))

        def normalize_mapping(mapping: Dict[str, List[str]]) -> Dict[str, List[str]]:
            return {
                normalize(parent): [normalize(child) for child in children]
                for parent, children in mapping.items()
            }

        def keep_detected(mapping: Dict[str, List[str]], detected: List[str]) -> Dict[str, List[str]]:
            # Drop every proposed child the detector did not confirm.
            found = set(detected)
            return {
                parent: [child for child in children if child in found]
                for parent, children in mapping.items()
            }

        def run_detector(labels: List[str], chunk_size: int, threshold: float, nms_threshold: float):
            # Nothing to look for: skip the detector instead of calling it
            # with an empty label list.
            if not labels:
                return None, []
            result = self.llmdet.predict_in_chunks(
                image=image_path,
                labels=labels,
                chunk_size=chunk_size,
                threshold=threshold,
                nms_threshold=nms_threshold
            )[0]
            return result, unique(result["labels"])

        # Stage 1: independent objects.
        independent_labels = self._detect_independent_objects(
            image_path, enable_multi_round=enable_multi_round
        )
        result1, independent_detected = run_detector(independent_labels, 10, 0.3, 0.5)

        # Stage 2: objects related to the confirmed independent objects.
        related_labels, related_dict_raw = self._detect_related_objects(
            image_path,
            independent_detected
        )
        related_dict = normalize_mapping(related_dict_raw)
        related_labels = unique(normalize(label) for label in related_labels)
        result2, related_detected = run_detector(related_labels, 15, 0.4, 0.5)
        related_dict_filtered = keep_detected(related_dict, related_detected)

        # Stage 3: objects related to the confirmed stage-2 objects. The same
        # label normalization as stage 2 is applied so that the record filter
        # below compares like with like.
        small_labels, small_dict_raw = self._detect_small_objects(
            image_path,
            related_detected,
            independent_detected
        )
        small_dict = normalize_mapping(small_dict_raw)
        small_labels = unique(normalize(label) for label in small_labels)
        result3, small_detected = run_detector(small_labels, 15, 0.5, 0.4)
        small_dict_filtered = keep_detected(small_dict, small_detected)

        # Merge the raw detections of all stages that actually ran.
        all_boxes = []
        all_labels = []
        all_scores = []
        for result in (result1, result2, result3):
            if result is None:
                continue
            all_boxes.extend(
                [round(c, 3) for c in box]
                for box in result['boxes'].cpu().numpy().tolist()
            )
            all_scores.extend(
                round(s, 3) for s in result['scores'].cpu().numpy().tolist()
            )
            all_labels.extend(result['labels'])

        return {
            'image_path': image_path,
            'boxes': all_boxes,
            'labels': all_labels,
            'scores': all_scores,
            'record_list': [
                independent_detected,
                related_detected,
                small_detected
            ],
            'record_dict': [
                related_dict_filtered,
                small_dict_filtered
            ]
        }
    
    def _detect_independent_objects(self, image_path: str, enable_multi_round: bool = False) -> List[str]:
        """Ask the chat model for the independent objects visible in the image.

        The reply is split on commas; each name is stripped of surrounding
        whitespace and trailing periods (the prompt's own example ends with a
        period, so the last name would otherwise keep it), and "none" entries
        are dropped. Names are de-duplicated in first-seen order.

        When ``enable_multi_round`` is true, up to three follow-up questions
        are asked with ``PROMPT_COMPLETION``. The follow-ups stop early once
        25 distinct objects are known or a reply adds nothing new (which
        includes a "None" / "None." reply).

        Args:
            image_path: Path of the image shown to the chat model.
            enable_multi_round: Whether to ask the follow-up questions.

        Returns:
            Distinct object names in the order the model first mentioned them.
        """
        max_rounds = 3
        enough_objects = 25

        def parse(text: str) -> List[str]:
            cleaned = (piece.strip().rstrip('.').strip() for piece in text.split(','))
            return [name for name in cleaned if name and name.lower() != 'none']

        first_reply = self._chat_with_image(self.PROMPT_INDEPENDENT, image_path)
        objects = list(dict.fromkeys(parse(first_reply)))

        if enable_multi_round:
            for _ in range(max_rounds):
                if len(objects) >= enough_objects:
                    break

                reply = self._chat_with_image(
                    self.PROMPT_COMPLETION.format(
                        INDEPENDENT_OBJECT_LIST=", ".join(objects)
                    ),
                    image_path
                )

                # Only names not seen before count as progress.
                fresh = [name for name in dict.fromkeys(parse(reply)) if name not in objects]
                if not fresh:
                    break
                objects.extend(fresh)

        return objects
    
    def _detect_related_objects(
        self,
        image_path: str,
        parent_objects: List[str]
    ) -> tuple[List[str], Dict[str, List[str]]]:
        """Ask the chat model which objects are related to each parent object.

        Parents are sent in chunks of 50 using ``PROMPT_RELATED``. Each chunk
        is written into the prompt as a comma-separated list ("man, car"),
        matching the ``Input:`` example in the prompt, rather than as a Python
        list repr. Replies are parsed with ``_parse_related_response`` with
        all parents excluded from the children.

        If the same parent shows up in more than one reply, its children are
        merged (first-seen order, no duplicates) instead of the later reply
        replacing the earlier one.

        Args:
            image_path: Path of the image shown to the chat model.
            parent_objects: Names to find related objects for.

        Returns:
            ``(labels, mapping)`` where ``mapping`` is parent -> children and
            ``labels`` is every distinct child in first-seen order.
        """
        chunk_size = 50
        related_dict: Dict[str, List[str]] = {}

        for start in range(0, len(parent_objects), chunk_size):
            chunk = parent_objects[start:start + chunk_size]
            response = self._chat_with_image(
                self.PROMPT_RELATED.format(
                    INPUT_INDEPENDENT_OBJECT_LIST=", ".join(chunk)
                ),
                image_path
            )

            chunk_dict = self._parse_related_response(response, exclude_list=parent_objects)

            for parent, children in chunk_dict.items():
                merged = related_dict.setdefault(parent, [])
                for child in children:
                    if child not in merged:
                        merged.append(child)

        all_related = list(dict.fromkeys(
            child
            for children in related_dict.values()
            for child in children
        ))
        return all_related, related_dict
    
    def _detect_small_objects(
        self,
        image_path: str,
        parent_objects: List[str],
        exclude_objects: List[str]
    ) -> tuple[List[str], Dict[str, List[str]]]:
        """Ask the chat model for objects related to ``parent_objects``.

        Works like ``_detect_related_objects`` (same ``PROMPT_RELATED``
        prompt, chunks of 50 parents) but additionally keeps everything in
        ``exclude_objects`` out of the children.

        Each chunk is written into the prompt as a comma-separated list,
        matching the ``Input:`` example in the prompt, rather than as a Python
        list repr. If the same parent shows up in more than one reply, its
        children are merged (first-seen order, no duplicates).

        Args:
            image_path: Path of the image shown to the chat model.
            parent_objects: Names to find related objects for.
            exclude_objects: Further names that must not be returned as
                children.

        Returns:
            ``(labels, mapping)`` where ``mapping`` is parent -> children and
            ``labels`` is every distinct child in first-seen order.
        """
        chunk_size = 50
        small_dict: Dict[str, List[str]] = {}

        # Parents and the extra exclusions together, without duplicates.
        exclude_list = list(dict.fromkeys([*parent_objects, *exclude_objects]))

        for start in range(0, len(parent_objects), chunk_size):
            chunk = parent_objects[start:start + chunk_size]
            response = self._chat_with_image(
                self.PROMPT_RELATED.format(
                    INPUT_INDEPENDENT_OBJECT_LIST=", ".join(chunk)
                ),
                image_path
            )

            chunk_dict = self._parse_related_response(response, exclude_list=exclude_list)

            for parent, children in chunk_dict.items():
                merged = small_dict.setdefault(parent, [])
                for child in children:
                    if child not in merged:
                        merged.append(child)

        all_small = list(dict.fromkeys(
            child
            for children in small_dict.values()
            for child in children
        ))
        return all_small, small_dict
    
  
    def _parse_related_response(self, text: str, exclude_list: Optional[list] = None) -> Dict[str, List[str]]:
        """Parse "parent: child, child" lines into a parent -> children dict.

        Lines without a colon, or with nothing before the colon, are ignored.
        For each remaining line the text after the first colon is stripped of
        one run of trailing periods and split on commas. A child is dropped
        when it is empty, is "none" in any letter case, or matches an entry of
        ``exclude_list`` in any letter case (the chat model does not reliably
        keep the case of the names it was given).

        Children are de-duplicated per parent in first-seen order, and a
        parent that occurs on several lines accumulates the children of all
        of them. Parents left without children are omitted.

        Args:
            text: Raw chat-model reply.
            exclude_list: Names that must not appear as children.

        Returns:
            Mapping of parent name to its list of child names.
        """
        excluded = {str(item).lower() for item in (exclude_list or ())}

        result: Dict[str, List[str]] = {}
        for line in text.strip().split('\n'):
            key, separator, value_str = line.partition(':')
            if not separator:
                continue

            key = key.strip()
            if not key:
                continue

            # Continue the list of an earlier line with the same parent.
            kept = result.get(key, [])
            for value in value_str.strip().rstrip('.').split(','):
                value = value.strip()
                lowered = value.lower()
                if not value or lowered == 'none' or lowered in excluded:
                    continue
                if value not in kept:
                    kept.append(value)

            if kept:
                result[key] = kept

        return result
    
    def _chat_with_image(
        self,
        prompt: str,
        image_path: Optional[str] = None
    ) -> str:
        """Send one prompt (optionally with an image) to the chat model.

        Builds a two-message conversation (a fixed system message plus the
        user message), renders it with the processor's chat template,
        generates up to 512 new tokens with greedy decoding and returns the
        decoded continuation.

        Args:
            prompt: Text of the user message.
            image_path: Image to attach to the user message; when falsy the
                message is text-only.

        Returns:
            The generated text with special tokens removed and surrounding
            whitespace stripped.
        """
        # Imported here so the module can be imported without qwen_vl_utils
        # installed; it is only needed when the chat model is actually used.
        from qwen_vl_utils import process_vision_info

        content = []
        if image_path:
            content.append({"type": "image", "image": image_path})
        content.append({"type": "text", "text": prompt})

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a meticulous Visual Analyst."}]
            },
            {"role": "user", "content": content}
        ]

        text = self.llm_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Presumably resolves the image/video entries of `messages` into
        # model-ready inputs - see qwen_vl_utils for the exact contract.
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.llm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Greedy decoding. `temperature` is a sampling-only setting, so it is
        # deliberately not passed: with do_sample=False it has no effect and
        # only triggers generation-config validation complaints.
        generated_ids = self.llm_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False
        )

        # `generate` returns prompt + continuation; keep only the continuation.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = self.llm_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        return output_text[0].strip()


if __name__ == "__main__":
    # Manual smoke test: load the chat model and the detector from local
    # checkpoints, run the pipeline on one image and print the parent -> child
    # records of stages 2 and 3.
    #
    # Every failure prints the full traceback and ends the process with exit
    # status 1. `raise SystemExit` is used instead of `exit()`, which is a
    # helper injected by the `site` module and not guaranteed to exist.

    llmdet_model_path = "/home//hsgg/llmdet"
    llm_model_path = "/home//hsgg/llavaonevision1.5-checkpoint"
    test_image_path = "/home//hsgg/image.png"
    device = "cuda:7" if torch.cuda.is_available() else "cpu"

    # 1) Chat model and its processor.
    try:
        llm_model = LLaVAOneVision1_5_ForConditionalGeneration.from_pretrained(
            llm_model_path,
            torch_dtype="auto",
            device_map=device,
            trust_remote_code=False,
            output_attentions=False,
            local_files_only=True
        )
        llm_processor = AutoProcessor.from_pretrained(
            llm_model_path,
            trust_remote_code=True
        )
    except Exception:
        traceback.print_exc()
        raise SystemExit(1)

    # 2) Detector wrapper.
    try:
        detector = ObjectDetector(
            llmdet_model_path=llmdet_model_path,
            llm_model=llm_model,
            llm_processor=llm_processor,
            device=device,
            threshold=0.3,
            nms_threshold=0.5
        )
    except Exception:
        traceback.print_exc()
        raise SystemExit(1)

    # 3) Run the pipeline and print the records.
    try:
        result = detector.detect(test_image_path)

        # record_dict holds the stage-2 and stage-3 mappings; an empty mapping
        # simply prints nothing.
        for mapping in result['record_dict']:
            for parent, children in mapping.items():
                print(f"    {parent} -> {', '.join(children)}")

        print("\n" + "=" * 60)
    except Exception:
        traceback.print_exc()
        # Report the failure to the calling shell instead of exiting with 0.
        raise SystemExit(1)
