import cv2
import numpy as np
from ultralytics import YOLO


class YOLOSegmenter:
    """Wrapper around an Ultralytics YOLO segmentation model.

    Runs segmentation on a single RGB image and returns an annotated RGB
    visualization plus a dict of per-class binary masks sized like the input.
    """

    # Weights loaded when no ``weights_path`` is given.
    # NOTE(review): machine-specific absolute path; override via ``weights_path``.
    DEFAULT_WEIGHTS = '/home/zyj/zyj/semexp/FOCUS/zzexer/yolo11x-seg.pt'

    def __init__(self, weights_path=None, conf_thresh=0.75, mask_thresh=0.75, binary_thresh=0.5):
        """
        Initialize YOLO segmentation model
        Args:
            weights_path: Path to model weights. If falsy (None or ""), DEFAULT_WEIGHTS is loaded
            conf_thresh: Default confidence threshold passed to the model's ``conf`` argument
            mask_thresh: Default minimum box confidence a detection needs for its mask
                to be included in the returned mask dict (a second confidence gate)
            binary_thresh: Default cutoff for converting mask values to a 0/1 binary mask
        """
        self.model = YOLO(weights_path if weights_path else self.DEFAULT_WEIGHTS)

        # Per-instance defaults; each can be overridden per call in seg_image().
        self.conf_thresh = conf_thresh
        self.mask_thresh = mask_thresh
        self.binary_thresh = binary_thresh

    def seg_image(self, image, conf_threshold=None, mask_threshold=None, binary_threshold=None):
        """
        Process a single RGB image
        Args:
            image: RGB numpy array of shape (height, width, 3)
            conf_threshold: Confidence threshold passed to the model (optional, uses instance default if None)
            mask_threshold: Minimum box confidence for a detection's mask to be kept
                (optional, uses instance default if None)
            binary_threshold: Cutoff for converting mask values to a binary mask
                (optional, uses instance default if None)
        Returns:
            tuple: (visualization_image, mask_dict)
                - visualization_image: RGB numpy array with visualized results
                - mask_dict: Dictionary mapping class names to uint8 0/1 masks of shape
                  (height, width); masks of the same class are OR-merged
        """
        # Use instance thresholds if parameters not provided
        if conf_threshold is None:
            conf_threshold = self.conf_thresh
        if mask_threshold is None:
            mask_threshold = self.mask_thresh
        if binary_threshold is None:
            binary_threshold = self.binary_thresh

        original_h, original_w = image.shape[:2]

        # Ultralytics treats numpy image inputs as BGR (OpenCV channel order).
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # retina_masks=True makes masks.data match the original image size.
        # Without it, masks come at the letterboxed inference size (padding
        # included), and a plain resize to (H, W) misaligns them with the image.
        results = self.model(image_bgr, conf=conf_threshold, retina_masks=True)
        result = results[0]

        mask_dict = {}
        if result.masks is not None:
            for box, mask in zip(result.boxes, result.masks.data):
                conf = box.conf[0].item()
                # Second confidence gate, independent of the model's `conf` filter.
                if conf < mask_threshold:
                    continue

                class_name = result.names[int(box.cls[0])]
                binary_mask = (mask.cpu().numpy() > binary_threshold).astype(np.uint8)

                # Safety net only: with retina_masks=True the shapes should already match.
                if binary_mask.shape != (original_h, original_w):
                    binary_mask = cv2.resize(binary_mask, (original_w, original_h),
                                             interpolation=cv2.INTER_NEAREST)

                # Merge masks of the same class. In-place OR keeps dtype uint8
                # (np.logical_or would silently switch the value to bool).
                if class_name in mask_dict:
                    mask_dict[class_name] |= binary_mask
                else:
                    mask_dict[class_name] = binary_mask

        # result.plot() returns a BGR array; convert back to RGB for callers.
        plotted_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

        return plotted_rgb, mask_dict



