
import os
from typing import List, Dict, Any, Optional, Tuple

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info

from utils.hierarchy_tree_builder import TreeNode


class AttributeGenerator:
    """Generate short attribute phrases for detected objects with a vision-language model.

    Each object is described from an image file (``image_path``) and a class
    ``label``. When a hierarchy of ``TreeNode`` objects is supplied, nodes are
    described bottom-up so that a parent's prompt can include the descriptions
    already produced for its children.

    The ``model`` must expose ``generate(**inputs, ...)`` and the ``processor``
    must expose ``apply_chat_template``, ``__call__`` and ``batch_decode``.
    """

    PROMPT_DESCRIBE = """You are given an image. Briefly describe the object marked with a red box as a short phrase (<=15 words). 
The object is {label}. 
Focus on category and salient attributes (color, material, state, part). 
Do not write full sentences, explanations, or include other objects."""

    PROMPT_DESCRIBE_WITH_CONTEXT = """You are given an image. Briefly describe the object marked with a red box as a short phrase (<=15 words). 
The object is {label}. 
Focus on category and salient attributes (color, material, state, part). 
Do not write full sentences, explanations, or include other objects.

Here are descriptions of its child parts/related elements: 
{context}

Use the context when helpful."""

    # Upper bound on generated tokens per description.
    MAX_NEW_TOKENS = 64
    # Absolute per-coordinate tolerance used when matching boxes approximately.
    BOX_MATCH_TOL = 1e-3

    def __init__(self, model, processor, device: str = "cuda:6", temperature: float = 0.1):
        """Store the model, processor and generation settings.

        Args:
            model: Generative model; only its ``generate`` method is used.
            processor: Processor providing chat templating, tensor
                preparation and decoding.
            device: Device that processor outputs are moved to before
                generation. The model itself is not moved by this class, so
                it is presumably already on this device -- confirm at the
                call site.
            temperature: Sampling temperature. A value that is ``None`` or
                not strictly positive selects greedy decoding.
        """
        self.device = device
        self.model = model
        self.processor = processor
        self.temperature = temperature

    def generate_attributes(
        self,
        objects_data: Dict[str, Any],
        cropped_objects: List[Dict[str, Any]],
        tree_roots: Optional[List["TreeNode"]] = None
    ) -> Dict[int, str]:
        """Describe objects, using the hierarchy for context when available.

        Args:
            objects_data: Detection data; its ``'boxes'`` entry is read when
                no hierarchy is given.
            cropped_objects: Dicts with ``'box'``, ``'image_path'`` and
                ``'label'`` entries; only used when no hierarchy is given.
            tree_roots: Roots of the object hierarchy. When empty or ``None``
                every cropped object is described independently.

        Returns:
            Mapping from object index to its description. With a hierarchy
            the keys are ``node.idx`` values; otherwise they are positions in
            ``objects_data['boxes']``.
        """
        if not tree_roots:
            return self._describe_all_objects(cropped_objects, objects_data)

        # Fill in ``node.description`` for every node, children before parents.
        for root in tree_roots:
            self._postorder_describe(root)

        descriptions: Dict[int, str] = {}

        def collect(node) -> None:
            # Pre-order walk; nodes with an empty description are left out.
            if node.description:
                descriptions[node.idx] = node.description
            for child in node.children:
                collect(child)

        for root in tree_roots:
            collect(root)

        # NOTE: objects that are not part of any tree receive no description
        # in this mode; ``objects_data`` and ``cropped_objects`` are not
        # consulted here.
        return descriptions

    def _postorder_describe(self, node: "TreeNode"):
        """Describe ``node`` after all of its descendants.

        The non-empty descriptions of the direct children are passed as
        context for the node's own prompt, and the result is stored on
        ``node.description``.
        """
        child_descriptions = []
        for child in node.children:
            self._postorder_describe(child)
            if child.description:
                child_descriptions.append(child.description)

        node.description = self._describe_object(
            node.image_path,
            node.label,
            child_descriptions or None
        )

    def _describe_batch(self, nodes: List[Tuple[str, str, Optional[List[str]]]]) -> List[str]:
        """Describe several objects with a single ``generate`` call.

        Args:
            nodes: ``(image_path, label, context)`` tuples, where ``context``
                is an optional list of child descriptions.

        Returns:
            One cleaned description per input tuple, in input order. An empty
            input yields an empty list without touching the model.
        """
        if not nodes:
            return []

        texts: List[str] = []
        images: List[Any] = []
        videos: List[Any] = []

        for image_path, label, context in nodes:
            messages = self._build_messages(image_path, label, context)
            texts.append(
                self.processor.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            )
            image_inputs, video_inputs = process_vision_info(messages)
            # Either element may be None when that modality is absent.
            if image_inputs:
                images.extend(image_inputs)
            if video_inputs:
                videos.extend(video_inputs)

        # NOTE(review): batched generation with padded prompts usually needs
        # left padding on the tokenizer -- confirm the processor's setting.
        return self._generate_text(texts, images or None, videos or None)

    def _describe_object(
        self,
        image_path: str,
        label: str,
        context: Optional[List[str]] = None
    ) -> str:
        """Describe a single object.

        Args:
            image_path: Image reference placed in the chat message.
            label: Class label inserted into the prompt.
            context: Optional child descriptions appended to the prompt.

        Returns:
            The cleaned description produced by the model.
        """
        messages = self._build_messages(image_path, label, context)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        return self._generate_text([text], image_inputs, video_inputs)[0]

    def _describe_all_objects(
        self,
        cropped_objects: List[Dict],
        objects_data: Dict
    ) -> Dict[int, str]:
        """Describe every cropped object independently (no hierarchy context).

        Objects whose box cannot be located in ``objects_data['boxes']`` are
        skipped.

        Returns:
            Mapping from box index to description.
        """
        descriptions = {}

        for obj in cropped_objects:
            idx = self._find_box_index(obj['box'], objects_data['boxes'])
            if idx is not None:
                descriptions[idx] = self._describe_object(
                    obj['image_path'],
                    obj['label'],
                    context=None
                )

        return descriptions

    def _find_box_index(
        self,
        target_box: List[float],
        boxes: List[List[float]]
    ) -> Optional[int]:
        """Return the position of ``target_box`` in ``boxes``.

        An exact match is tried first; failing that, the first box whose
        first four coordinates each differ by less than ``BOX_MATCH_TOL`` is
        used.

        Returns:
            The matching index, or ``None`` when no box matches.
        """
        try:
            return boxes.index(target_box)
        except (ValueError, AttributeError):
            # ValueError: no exact match. AttributeError: the container has
            # no ``index`` method. Both fall through to the tolerance scan.
            pass

        tol = self.BOX_MATCH_TOL if self is not None else AttributeGenerator.BOX_MATCH_TOL
        for i, box in enumerate(boxes):
            if all(abs(box[k] - target_box[k]) < tol for k in range(4)):
                return i
        return None

    @classmethod
    def _build_prompt(cls, label: str, context: Optional[List[str]] = None) -> str:
        """Fill the prompt template, adding child context lines when present."""
        if context:
            return cls.PROMPT_DESCRIBE_WITH_CONTEXT.format(
                label=label,
                context="\n".join(context)
            )
        return cls.PROMPT_DESCRIBE.format(label=label)

    def _build_messages(
        self,
        image_path: str,
        label: str,
        context: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Build the single-turn user message holding the image and prompt."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": self._build_prompt(label, context)}
                ]
            }
        ]

    def _generation_kwargs(self) -> Dict[str, Any]:
        """Return keyword arguments for ``model.generate``.

        Sampling is only enabled for a strictly positive temperature;
        otherwise greedy decoding is requested so that a temperature of 0
        does not make generation fail.
        """
        kwargs: Dict[str, Any] = {"max_new_tokens": self.MAX_NEW_TOKENS}
        if self.temperature is not None and self.temperature > 0:
            kwargs["do_sample"] = True
            kwargs["temperature"] = self.temperature
        else:
            kwargs["do_sample"] = False
        return kwargs

    def _generate_text(self, texts: List[str], images, videos) -> List[str]:
        """Run the processor and model on prepared prompts and decode the result.

        Args:
            texts: Chat-templated prompt strings, one per sample.
            images: Image inputs for all samples, or ``None``.
            videos: Video inputs for all samples, or ``None``.

        Returns:
            One cleaned description per prompt.
        """
        inputs = self.processor(
            text=texts,
            images=images,
            videos=videos,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **self._generation_kwargs())

        # Drop the prompt tokens so that only newly generated ids are decoded.
        trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        decoded = self.processor.batch_decode(
            trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return [self._clean_output(text) for text in decoded]

    @staticmethod
    def _clean_output(raw: str) -> str:
        """Reduce raw model output to a single short phrase.

        Newlines become spaces, everything from the first ``.`` onwards is
        dropped, and a result of more than 15 words is cut at its first comma.
        """
        text = raw.strip().replace("\n", " ").strip()
        if "." in text:
            text = text.split(".")[0].strip()
        if "," in text and len(text.split()) > 15:
            text = text.split(",")[0].strip()
        return text
