import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from transformers.image_utils import load_image
from PIL import Image, ImageDraw, ImageFont
from typing import List, Optional, Union
import os
import time 

class LLMDet:
    """Zero-shot object detector built on a Hugging Face grounded-detection checkpoint.

    The class loads an ``AutoProcessor`` / ``AutoModelForZeroShotObjectDetection``
    pair from ``model_path``, runs text-prompted detection on a single image,
    optionally filters overlapping boxes with NMS, and can draw the result.

    Every prediction method returns a one-element list holding a dict with the
    keys ``"boxes"``, ``"scores"``, ``"labels"`` and ``"text_labels"``.
    """

    # Seconds to wait for a remote image before giving up.
    URL_TIMEOUT_SECONDS = 30

    def __init__(
            self,
            model_path: str = "/home/liu/Desktop/code/sgg/llmdet",
            device: Optional[str] = None,
            threshold: float = 0.3,
            nms_threshold: float = 0.5
    ):
        """Load the processor, the model and a drawing font.

        Args:
            model_path: Directory or hub id passed to ``from_pretrained``.
            device: Torch device string. When ``None`` a default is chosen by
                :meth:`_default_device`.
            threshold: Default score threshold used by :meth:`predict`.
            nms_threshold: Default IoU threshold used for NMS.
        """
        if device is None:
            device = self._default_device()
        self.device = device
        self.threshold = threshold
        self.nms_threshold = nms_threshold
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_path).to(self.device)

        # Arial is not installed everywhere; fall back to PIL's built-in font.
        try:
            self.font = ImageFont.truetype("arial.ttf", 18)
        except IOError:
            self.font = ImageFont.load_default()

    @staticmethod
    def _default_device() -> str:
        """Pick the device used when the caller does not specify one.

        GPU index 7 stays the preferred device, but only when that index
        actually exists; otherwise the current CUDA device (or the CPU) is used
        instead of failing while moving the model.
        """
        if not torch.cuda.is_available():
            return "cpu"
        return "cuda:7" if torch.cuda.device_count() > 7 else "cuda"

    @staticmethod
    def _build_text_prompt(labels: Union[str, List[str]]) -> str:
        """Turn a list of labels into the ``"a . b ."`` prompt; pass strings through."""
        if isinstance(labels, list):
            return " . ".join(labels) + " ."
        return labels

    @staticmethod
    def _text_labels(result: dict):
        """Return ``result["text_labels"]`` when present, else ``result["labels"]``."""
        if "text_labels" in result:
            return result["text_labels"]
        return result["labels"]

    def load_image(self, image: Union[str, "Image.Image"]) -> "Image.Image":
        """Return ``image`` as an RGB ``PIL.Image``.

        Args:
            image: A PIL image, a local file path, or an ``http(s)://`` URL.

        Returns:
            The image, converted to RGB when it is in another mode.

        Raises:
            ValueError: If ``image`` is neither a string nor a PIL image.
        """
        if isinstance(image, Image.Image):
            pil_image = image
        elif isinstance(image, str):
            if image.startswith(("http://", "https://")):
                import io

                import requests

                # Bound the wait and surface HTTP errors instead of handing an
                # error page to the image decoder.
                response = requests.get(image, timeout=self.URL_TIMEOUT_SECONDS)
                response.raise_for_status()
                pil_image = Image.open(io.BytesIO(response.content))
            else:
                pil_image = Image.open(image)
        else:
            raise ValueError("Unsupported image input type. Provide file path, url or PIL.Image.")

        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        return pil_image

    def non_max_suppression(
            self,
            boxes: torch.Tensor,
            scores: torch.Tensor,
            labels: List[str],
            texts: List[str],
            iou_threshold: float
    ) -> tuple:
        """Filter overlapping boxes with ``torchvision.ops.nms``.

        Suppression ignores the labels: a box can be suppressed by a
        higher-scoring box that carries a different label.

        Args:
            boxes: Box tensor indexed along its first dimension.
            scores: One score per box.
            labels: One label per box, aligned with ``boxes``.
            texts: One text label per box, aligned with ``boxes``.
            iou_threshold: Boxes overlapping a kept box by more than this IoU
                are discarded.

        Returns:
            ``(boxes, scores, labels, texts)`` restricted to the kept
            detections, ordered by decreasing score. Empty input yields empty
            tensors and empty lists, always as a 4-tuple.
        """
        if not isinstance(boxes, torch.Tensor) or boxes.numel() == 0:
            return torch.tensor([]), torch.tensor([]), [], []

        import torchvision

        # nms already returns the kept indices sorted by decreasing score, so
        # no separate sort is needed.
        keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
        kept = keep_indices.tolist()

        boxes_out = boxes[keep_indices]
        scores_out = scores[keep_indices]
        labels_out = [labels[i] for i in kept]
        texts_out = [texts[i] for i in kept]

        return boxes_out, scores_out, labels_out, texts_out

    def _detect(self, pil_image, labels: Union[str, List[str]], threshold: float) -> dict:
        """Run one forward pass on ``pil_image`` and return the post-processed result."""
        text_prompt = self._build_text_prompt(labels)
        inputs = self.processor(images=pil_image, text=text_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            threshold=threshold,
            target_sizes=[(pil_image.height, pil_image.width)]
        )
        return results[0]

    def predict(
            self,
            image: Union[str, "Image.Image"],
            labels: List[str],
            apply_nms: bool = True
    ) -> List[dict]:
        """Detect ``labels`` in ``image`` with a single prompt.

        Args:
            image: Anything accepted by :meth:`load_image`.
            labels: Labels to look for; a list is joined into one prompt, a
                string is used as the prompt unchanged.
            apply_nms: When true and at least one box was found, filter the
                detections with :meth:`non_max_suppression` using
                ``self.nms_threshold``.

        Returns:
            A one-element list with the result dict.
        """
        pil_image = self.load_image(image)
        result = self._detect(pil_image, labels, self.threshold)

        if apply_nms and result["boxes"].numel() > 0:
            boxes, scores, kept_labels, kept_texts = self.non_max_suppression(
                result["boxes"],
                result["scores"],
                result["labels"],
                self._text_labels(result),
                self.nms_threshold
            )
            result = {
                "boxes": boxes,
                "scores": scores,
                "labels": kept_labels,
                "text_labels": kept_texts
            }

        return [result]

    def predict_in_chunks(
            self,
            image: Union[str, "Image.Image"],
            labels: List[str],
            threshold: Optional[float] = None,
            nms_threshold: Optional[float] = None,
            chunk_size: int = 5
    ) -> List[dict]:
        """Detect many labels by prompting the model with ``chunk_size`` labels at a time.

        The detections of all chunks are merged and filtered with one NMS pass.

        Args:
            image: Anything accepted by :meth:`load_image`.
            labels: All labels to look for.
            threshold: Score threshold for this call; ``None`` uses
                ``self.threshold``. The instance default is not modified.
            nms_threshold: IoU threshold for this call; ``None`` uses
                ``self.nms_threshold``. The instance default is not modified.
            chunk_size: Number of labels per prompt; must be at least 1.

        Returns:
            A one-element list with the merged result dict.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer.")
        if threshold is None:
            threshold = self.threshold
        if nms_threshold is None:
            nms_threshold = self.nms_threshold

        # Decode the image once and reuse it for every chunk.
        pil_image = self.load_image(image)

        all_boxes = []
        all_scores = []
        all_labels = []
        all_texts = []

        for start in range(0, len(labels), chunk_size):
            result = self._detect(pil_image, labels[start:start + chunk_size], threshold)
            if result["boxes"].numel() > 0:
                all_boxes.append(result["boxes"])
                all_scores.append(result["scores"])
                all_labels.extend(result["labels"])
                all_texts.extend(self._text_labels(result))

        if not all_boxes:
            return [{"boxes": torch.tensor([]), "scores": torch.tensor([]), "labels": [], "text_labels": []}]

        final_boxes, final_scores, final_labels, final_texts = self.non_max_suppression(
            torch.cat(all_boxes, dim=0),
            torch.cat(all_scores, dim=0),
            all_labels,
            all_texts,
            nms_threshold
        )

        return [{
            "boxes": final_boxes,
            "scores": final_scores,
            "labels": final_labels,
            "text_labels": final_texts
        }]

    def visualize(
            self,
            image: Union[str, "Image.Image"],
            results: List[dict],
            save_path: Optional[str] = None
    ) -> "Image.Image":
        """Draw the detections of ``results[0]`` on a copy of ``image``.

        Args:
            image: Anything accepted by :meth:`load_image`.
            results: Prediction output; only the first element is drawn.
            save_path: When given, the annotated image is saved there and the
                parent directory is created if needed.

        Returns:
            The annotated copy; the input image is left untouched.
        """
        image_to_draw = self.load_image(image).copy()
        draw = ImageDraw.Draw(image_to_draw)
        result = results[0]
        names = self._text_labels(result)

        # One colour per distinct label, assigned in sorted order so the
        # mapping is stable for a given label set.
        colors = ["#ff0000", "#00ff00", "#0000ff", "#ffff00", "#ff00ff", "#00ffff"]
        label_colors = {name: colors[i % len(colors)] for i, name in enumerate(sorted(set(names)))}

        for box, score, label in zip(result["boxes"], result["scores"], names):
            xmin, ymin, xmax, ymax = (int(coord) for coord in box.tolist())

            color = label_colors.get(label, "#808080")
            draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

            label_text = f"{label}: {float(score):.2f}"

            text_bbox = draw.textbbox((0, 0), label_text, font=self.font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Put the caption above the box, or just inside it when there is
            # no room above.
            text_origin = (xmin, ymin - text_height - 5)
            if text_origin[1] < 0:
                text_origin = (xmin, ymin + 5)

            draw.rectangle(
                [text_origin[0], text_origin[1], text_origin[0] + text_width + 4, text_origin[1] + text_height + 4],
                fill=color
            )
            draw.text(
                (text_origin[0] + 2, text_origin[1] + 2),
                label_text,
                fill="black",
                font=self.font
            )

        if save_path:
            target_dir = os.path.dirname(save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            image_to_draw.save(save_path)

        return image_to_draw

if __name__ == "__main__":
    # Demo: detect a fixed label set in one image, print the detections and
    # save/show the annotated picture.
    import random
    import numpy as np

    # Seed every RNG in play so repeated runs are comparable.
    np.random.seed(42)
    random.seed(42)
    torch.manual_seed(42)

    # Alternative, larger label set; not used by the run below.
    labels =['frame', 'button', 'hand', ' monitor', ' window', 'switch', 'surface', 'glass', 'hair', 'foot', 'necklace', 'watch', 'monitor', 'ear', 'seat', 'cap', 'leg', 'logo', 'blind', 'armrest', 'paint', 'drive', 'ceiling', 'curtain', 'light', 'cover', 'man', 'pane', 'poster', 'wheel', 'ink', ' mousepad', ' chair', 'nose', 'backrest', 'shirt', ' wall', 'image', 'eye', 'plug', 'mouth', 'picture', 'furniture', 'clip', 'socket', 'pants', ' pen', ' mouse', 'sill', ' room', 'screen', 'folder', ' desk', 'cushion', 'pattern', ' keyboard', 'drawer', 'reflection', 'software', 'door', 'cable', 'tower', 'drawing', 'speaker', 'floor', 'clock', 'shoes', 'diagram', 'glasses', ' paper', ' outlet', 'text', ' computer', 'shelf', 'arm', 'key', 'head', 'finger']
    labels2 =['nose', 'lamp', 'mouse pad', 'desk', 'keyboard', 'watch', 'mouse', 'man', 'pen', 'pencil', 'arm', 'glasses', 't - shirt', 'paper', 'window', 'computer monitor', 'head', 'cord', 'stapler', 'blinds']

    def calculate_iou(box1, box2):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0 when the union area is zero.
        """
        x1_i = max(box1[0], box2[0])
        y1_i = max(box1[1], box2[1])
        x2_i = min(box1[2], box2[2])
        y2_i = min(box1[3], box2[3])

        # Clamp at zero so disjoint boxes contribute no intersection.
        inter_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)

        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        union_area = box1_area + box2_area - inter_area

        if union_area == 0:
            return 0

        return inter_area / union_area

    def calculate_union_box(box1, box2):
        """Return the smallest ``[x1, y1, x2, y2]`` box enclosing both inputs."""
        x1 = min(box1[0], box2[0])
        y1 = min(box1[1], box2[1])
        x2 = max(box1[2], box2[2])
        y2 = max(box1[3], box2[3])
        return [x1, y1, x2, y2]

    score_threshold = 0.3
    nms_iou_threshold = 0.5
    det = LLMDet(threshold=score_threshold, nms_threshold=nms_iou_threshold)

    unique_labels = sorted(set(labels2))

    image_path = "/home/liu/Desktop/code/sgg/1396.jpg"

    start_time = time.time()
    # Pass both thresholds explicitly so the call is valid whether or not
    # predict_in_chunks provides defaults for them.
    results = det.predict_in_chunks(
        image=image_path,
        labels=unique_labels,
        threshold=score_threshold,
        nms_threshold=nms_iou_threshold,
        chunk_size=10,
    )
    elapsed = time.time() - start_time

    result = results[0]
    count = 0
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        box_coords = [round(c.item(), 2) for c in box]
        print(f"{label}: {float(score):.2f} {box_coords}")
        count += 1

    print(f"{count} detections in {elapsed:.2f}s")

    output_path = "/home/liu/Desktop/code/sgg/1396_detected_total_chunked2.jpg"
    visualized_image = det.visualize(image=image_path, results=results, save_path=output_path)

    visualized_image.show()



  