"""
Image processing utilities - crop objects, generate union images, etc.
Supports image caching for improved efficiency
"""
import os
from typing import List, Dict, Any, Tuple
from functools import lru_cache

import numpy as np
import torch
from PIL import Image, ImageDraw


class ImageProcessor:
    """
    Image processing utility class.

    Crops individual objects and object-pair "union" regions out of an image
    and saves them to disk. Decoded images are kept in a small class-level
    cache (shared by every instance) so repeated calls on the same path do
    not re-read the file.
    """

    # Class-level image cache shared by every instance, keyed by image path.
    # Entries are evicted in insertion (FIFO) order once the cache is full.
    _image_cache: Dict[str, Image.Image] = {}
    _cache_max_size: int = 10  # Maximum number of cached images

    @classmethod
    def get_image(cls, image_path: str) -> Image.Image:
        """
        Load an image as RGB, using the class-level cache.

        Args:
            image_path: Path of the image file.

        Returns:
            A copy of the cached RGB image, so callers can crop or draw on it
            without modifying the cached original.
        """
        cache = cls._image_cache
        if image_path not in cache:
            # Evict the oldest inserted entries until there is room
            # (dicts preserve insertion order, so the first key is the oldest).
            while cache and len(cache) >= cls._cache_max_size:
                del cache[next(iter(cache))]
            # Close the underlying file handle right away; convert() returns a
            # new, fully loaded image that no longer needs the file.
            with Image.open(image_path) as img:
                cache[image_path] = img.convert("RGB")
        return cache[image_path].copy()

    @classmethod
    def clear_cache(cls):
        """Remove every entry from the class-level image cache."""
        cls._image_cache.clear()

    def crop_objects(
        self,
        image_path: str,
        boxes: List[List[float]],
        labels: List[str],
        output_dir: str = "./cropped_objects",
        scale_factors: Dict[str, Tuple[float, float]] = None,
        sorted_label_map: Dict[int, str] = None
    ) -> List[Dict[str, Any]]:
        """
        Crop each object with surrounding context, outline it, and save it.

        Each box is enlarged around its center by a size-dependent factor
        (smaller objects get proportionally more context), clipped to the
        image bounds, and cropped. The original box is drawn in red on the
        crop, which is saved as ``cropped_{image_stem}_{label}.png``.

        Args:
            image_path: Original image path
            boxes: Bounding boxes as [x1, y1, x2, y2] in image pixel coordinates
            labels: One label per box (paired with ``boxes`` by position)
            output_dir: Output directory (created if missing)
            scale_factors: Optional mapping from size bucket ('very_small',
                'small', 'medium', 'large') to (width_scale, height_scale).
                Buckets that are not given use the defaults from
                ``_get_default_scale_factors``.
            sorted_label_map: Optional mapping from box index to the label to
                use instead of ``labels[i]`` in filenames and results.

        Returns:
            One dict per cropped object with keys 'image_path',
            'original_image_path', 'box', 'label' and 'idx'.

        Note:
            Objects that resolve to the same label are written to the same
            file, so later crops overwrite earlier ones.
        """
        os.makedirs(output_dir, exist_ok=True)

        image = self.get_image(image_path)
        img_width, img_height = image.size
        img_base = os.path.splitext(os.path.basename(image_path))[0]

        # Default scaling strategy
        if scale_factors is None:
            scale_factors = self._get_default_scale_factors(img_width, img_height)

        cropped_objects = []

        for i, (box, label) in enumerate(zip(boxes, labels)):
            x1, y1, x2, y2 = box
            width = x2 - x1
            height = y2 - y1

            scale_w, scale_h = self._select_scale_factor(
                width, height, img_width, img_height, scale_factors
            )

            # Grow the box symmetrically around its center, then clip to the image.
            pad_w = (width * scale_w - width) / 2
            pad_h = (height * scale_h - height) / 2
            crop_box = (
                max(0, x1 - pad_w),
                max(0, y1 - pad_h),
                min(img_width, x2 + pad_w),
                min(img_height, y2 + pad_h),
            )

            cropped = image.crop(crop_box)

            # Outline the original box (in crop-local coordinates) in red.
            ImageDraw.Draw(cropped).rectangle(
                self._translate_box(box, crop_box),
                outline="red",
                width=2
            )

            used_label = self._resolve_label(i, label, sorted_label_map)

            save_path = os.path.join(output_dir, f"cropped_{img_base}_{used_label}.png")
            cropped.save(save_path)

            cropped_objects.append({
                'image_path': save_path,
                'original_image_path': image_path,
                'box': box,
                'label': used_label,
                'idx': i
            })

        return cropped_objects

    def generate_union_images(
        self,
        image_path: str,
        boxes: List[List[float]],
        labels: List[str],
        output_dir: str = "./union_objects",
        proximity_threshold: float = 0,
        max_unions: int = 100,
        sorted_label_map: Dict[int, str] = None,
        tree_roots: List = None,
        depth_map=None,
    ) -> List[str]:
        """
        Generate union-region images for related object pairs.

        Only pairs (i, j) with i < j and an explicit relationship qualify:
        1. Parent-child pairs inside a hierarchy tree (``tree_roots``) are
           always kept ('tree_internal').
        2. Pairs of tree roots are kept when the depth-aware spatial distance
           mask selects them (IoU < 0.9 and spatial distance < 0.1,
           'root_spatial_distance'). Without a depth map, root pairs are kept
           when their boxes intersect ('root_intersect').
        Without ``tree_roots`` no pair qualifies and an empty list is returned.

        If more than ``max_unions`` pairs qualify, pairs in which either box
        covers at most 0.5% or at least 45% of the image are dropped, and the
        remaining pairs with the largest union crops are kept (at most
        ``max_unions``).

        Each union crop has box i outlined in red and box j in yellow and is
        saved as ``{image_stem}_{label_i}-{i}_{label_j}-{j}.jpg``. Files that
        already exist are not re-rendered.

        Args:
            image_path: Original image path
            boxes: Bounding boxes as [x1, y1, x2, y2]
            labels: One label per box
            output_dir: Output directory (created if missing)
            proximity_threshold: Deprecated and ignored.
            max_unions: Maximum number of union images
            sorted_label_map: Optional mapping from box index to the label used
                in filenames
            tree_roots: Root nodes of the hierarchy trees; nodes expose
                ``idx`` (box index) and ``children``
            depth_map: Depth map (numpy array or torch.Tensor, H x W,
                interpolated to original image size)

        Returns:
            List of union image paths
        """
        os.makedirs(output_dir, exist_ok=True)

        image = self.get_image(image_path)
        img_width, img_height = image.size
        img_base = os.path.splitext(os.path.basename(image_path))[0]
        img_area = img_width * img_height

        tree_relations, root_indices = self._collect_tree_structure(tree_roots)

        # Root pairs selected by the depth-aware mask. None means "no depth
        # information", which switches root pairs to the intersection fallback.
        root_crop_pairs = None
        if depth_map is not None and root_indices:
            root_crop_pairs = self._compute_root_crop_pairs(
                boxes, sorted(root_indices), depth_map, img_width, img_height
            )

        # Collect all qualified union information
        union_infos = []
        num_boxes = len(boxes)
        for i in range(num_boxes):
            for j in range(i + 1, num_boxes):
                box_i, box_j = boxes[i], boxes[j]

                relation_type = None
                if (i, j) in tree_relations:
                    relation_type = 'tree_internal'
                elif i in root_indices and j in root_indices:
                    if root_crop_pairs is None:
                        if self._boxes_intersect(box_i, box_j):
                            relation_type = 'root_intersect'
                    elif (i, j) in root_crop_pairs:
                        relation_type = 'root_spatial_distance'

                if relation_type is None:
                    continue

                union_box = self._calculate_union_box(
                    box_i, box_j, img_width, img_height
                )
                union_infos.append({
                    'idx_pair': (i, j),
                    'box_i': box_i,
                    'box_j': box_j,
                    'label_i': labels[i],
                    'label_j': labels[j],
                    'union_box': union_box,
                    'crop_area': self._box_area(union_box),
                    'relation_type': relation_type
                })

        # Filter and sort only when there are too many candidates
        if len(union_infos) > max_unions:
            min_area, max_area = img_area * 0.005, img_area * 0.45
            filtered = [
                info for info in union_infos
                if min_area < self._box_area(info['box_i']) < max_area
                and min_area < self._box_area(info['box_j']) < max_area
            ]
            filtered.sort(key=lambda info: info['crop_area'], reverse=True)
            union_infos = filtered[:max_unions]

        # Generate and save union images
        union_paths = []
        for info in union_infos:
            i, j = info['idx_pair']
            used_label_i = self._resolve_label(i, info['label_i'], sorted_label_map)
            used_label_j = self._resolve_label(j, info['label_j'], sorted_label_map)
            filename = f"{img_base}_{used_label_i}-{i}_{used_label_j}-{j}.jpg"
            save_path = os.path.join(output_dir, filename)

            if not os.path.exists(save_path):
                union_box = info['union_box']
                cropped = image.crop(union_box)
                draw = ImageDraw.Draw(cropped)
                # Subject (box i) in red, object (box j) in yellow.
                draw.rectangle(self._translate_box(info['box_i'], union_box), outline="red", width=1)
                draw.rectangle(self._translate_box(info['box_j'], union_box), outline="yellow", width=1)
                cropped.save(save_path)

            union_paths.append(save_path)

        return union_paths

    @staticmethod
    def _collect_tree_structure(tree_roots) -> Tuple[set, set]:
        """
        Collect parent-child index pairs and root indices from hierarchy trees.

        Args:
            tree_roots: Root nodes (or None); each node exposes ``idx`` and
                ``children``.

        Returns:
            (tree_relations, root_indices): ``tree_relations`` holds
            (min_idx, max_idx) tuples for every parent-child edge;
            ``root_indices`` holds the ``idx`` of every root.
        """
        tree_relations = set()
        root_indices = set()
        if tree_roots is None:
            return tree_relations, root_indices

        visited = set()  # node identities, guards against shared or cyclic nodes
        for root in tree_roots:
            root_indices.add(root.idx)
            # Iterative DFS so deep trees cannot hit Python's recursion limit.
            stack = [root]
            while stack:
                node = stack.pop()
                if id(node) in visited:
                    continue
                visited.add(id(node))
                for child in node.children:
                    tree_relations.add((min(node.idx, child.idx), max(node.idx, child.idx)))
                    stack.append(child)
        return tree_relations, root_indices

    def _compute_root_crop_pairs(
        self,
        boxes: List[List[float]],
        root_indices_list: List[int],
        depth_map,
        img_width: float,
        img_height: float
    ) -> set:
        """
        Select root-node pairs to crop using the spatial distance mask.

        Args:
            boxes: All bounding boxes
            root_indices_list: Sorted indices of root boxes into ``boxes``
            depth_map: Depth map passed to ``_compute_spatial_distance_mask``
            img_width: Image width
            img_height: Image height

        Returns:
            Set of global index pairs (i, j), i < j, selected by the mask.
        """
        mask = self._compute_spatial_distance_mask(
            boxes=[boxes[k] for k in root_indices_list],
            depth_map=depth_map,
            img_width=img_width,
            img_height=img_height,
            tau=0.1,  # Spatial distance threshold
            iou_threshold=0.9  # IoU threshold
        )

        # Map local (root-list) positions back to global box indices; the
        # list is sorted, so a < b implies the global pair is ordered too.
        pairs = set()
        count = len(root_indices_list)
        for a in range(count):
            for b in range(a + 1, count):
                if mask[a, b]:
                    pairs.add((root_indices_list[a], root_indices_list[b]))
        return pairs

    @staticmethod
    def _resolve_label(idx: int, label: str, sorted_label_map: Dict[int, str] = None) -> str:
        """Return ``sorted_label_map[idx]`` if the map is non-empty and has it, else ``label``."""
        return sorted_label_map.get(idx, label) if sorted_label_map else label

    @staticmethod
    def _box_area(box: List[float]) -> float:
        """Return (x2 - x1) * (y2 - y1) for an [x1, y1, x2, y2] box (no clamping)."""
        return (box[2] - box[0]) * (box[3] - box[1])

    def _calculate_union_box(
        self,
        box1: List[float],
        box2: List[float],
        img_width: float,
        img_height: float,
        expand_ratio: float = 0.1
    ) -> List[float]:
        """
        Return the enclosing box of two boxes, enlarged and clipped to the image.

        The enclosing box grows by ``expand_ratio`` of its width and height,
        split evenly between both sides, and is clipped to
        [0, img_width] x [0, img_height].
        """
        x1 = min(box1[0], box2[0])
        y1 = min(box1[1], box2[1])
        x2 = max(box1[2], box2[2])
        y2 = max(box1[3], box2[3])

        width = x2 - x1
        height = y2 - y1

        # Expand
        width_increase = width * expand_ratio
        height_increase = height * expand_ratio

        new_x1 = max(0, x1 - width_increase / 2)
        new_y1 = max(0, y1 - height_increase / 2)
        new_x2 = min(img_width, x2 + width_increase / 2)
        new_y2 = min(img_height, y2 + height_increase / 2)

        return [new_x1, new_y1, new_x2, new_y2]

    def _boxes_intersect(self, box1: List[float], box2: List[float]) -> bool:
        """Return True if the boxes overlap; boxes that only touch count as intersecting."""
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        if x2_1 < x1_2 or x1_1 > x2_2 or y2_1 < y1_2 or y1_1 > y2_2:
            return False
        return True

    def _is_proximate(
        self,
        box1: List[float],
        box2: List[float],
        img_width: float,
        img_height: float,
        threshold: float
    ) -> bool:
        """
        Determine if two boxes are proximate.

        The horizontal and vertical gaps between the boxes (0 when they
        overlap on that axis) are normalized by image width and height. Both
        must be <= ``threshold``. Not called by the other methods of this class.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        width_diff = max(0, max(x1_2 - x2_1, x1_1 - x2_2))
        height_diff = max(0, max(y1_2 - y2_1, y1_1 - y2_2))

        metric_h = width_diff / img_width
        metric_v = height_diff / img_height

        return metric_h <= threshold and metric_v <= threshold

    def _translate_box(
        self,
        box: List[float],
        union_box: List[float]
    ) -> List[float]:
        """Shift ``box`` into the coordinate frame whose origin is the top-left of ``union_box``."""
        return [
            box[0] - union_box[0],
            box[1] - union_box[1],
            box[2] - union_box[0],
            box[3] - union_box[1]
        ]

    @staticmethod
    def _size_bucket(extent: float, full_extent: float) -> str:
        """Bucket an object extent by its fraction of the image extent."""
        if extent < full_extent * 0.05:
            return 'very_small'
        if extent < full_extent * 0.1:
            return 'small'
        if extent < full_extent * 0.2:
            return 'medium'
        return 'large'

    def _select_scale_factor(
        self,
        width: float,
        height: float,
        img_width: float,
        img_height: float,
        scale_factors: Dict[str, Tuple[float, float]] = None
    ) -> Tuple[float, float]:
        """
        Select (scale_w, scale_h) from the object's size relative to the image.

        Width and height are bucketed independently by their fraction of the
        image extent:
        - below 5%: 'very_small'
        - below 10%: 'small'
        - below 20%: 'medium'
        - otherwise: 'large'

        The width factor is element 0 of the width bucket's tuple. The height
        factor is element 1 of the height bucket's tuple.

        Args:
            width: Object width
            height: Object height
            img_width: Image width
            img_height: Image height
            scale_factors: Optional bucket -> (width_scale, height_scale)
                overrides. Buckets not present use the defaults (6.0, 3.0,
                1.5 and 1.1 for very_small..large).

        Returns:
            (scale_w, scale_h)
        """
        factors = self._get_default_scale_factors(img_width, img_height)
        if scale_factors:
            factors = {**factors, **scale_factors}

        scale_w = factors[self._size_bucket(width, img_width)][0]
        scale_h = factors[self._size_bucket(height, img_height)][1]
        return scale_w, scale_h

    def _get_default_scale_factors(
        self,
        img_width: float,
        img_height: float
    ) -> Dict[str, Tuple[float, float]]:
        """
        Return the default bucket -> (width_scale, height_scale) mapping.

        The image size arguments are currently unused.
        """
        return {
            'very_small': (6.0, 6.0),
            'small': (3.0, 3.0),
            'medium': (1.5, 1.5),
            'large': (1.1, 1.1)
        }

    @staticmethod
    def _to_numpy_depth(depth_map) -> np.ndarray:
        """Return ``depth_map`` as a numpy array (detaching and moving torch tensors to CPU)."""
        if isinstance(depth_map, torch.Tensor):
            return depth_map.detach().cpu().numpy()
        return depth_map

    @staticmethod
    def _depth_region(depth_map: np.ndarray, bbox: List[float]) -> np.ndarray:
        """
        Slice the depth values covered by ``bbox``.

        Coordinates are truncated to int and clipped to the 2-D map. The
        result is empty for degenerate or off-map boxes.
        """
        h, w = depth_map.shape
        x1, y1, x2, y2 = [int(c) for c in bbox]
        x1 = max(0, min(x1, w - 1))
        x2 = max(0, min(x2, w))
        y1 = max(0, min(y1, h - 1))
        y2 = max(0, min(y2, h))
        return depth_map[y1:y2, x1:x2]

    def _compute_object_distribution(self, depth_map, bbox: List[float]) -> Tuple[float, float]:
        """
        Compute depth distribution statistics for an object region.

        Args:
            depth_map: H x W numpy array or torch.Tensor (e.g. from Depth
                Anything V3, interpolated to the original image size)
            bbox: [x1, y1, x2, y2] in original image coordinates

        Returns:
            (mu_i, sigma_i): mean and standard deviation of the region's
            depth. 1e-6 is added to sigma to avoid division by zero. Returns
            (0.0, 1.0) when the region is empty.
        """
        obj_depth = self._depth_region(self._to_numpy_depth(depth_map), bbox)

        if obj_depth.size == 0:
            return 0.0, 1.0  # fallback

        mu_i = np.mean(obj_depth)
        sigma_i = np.std(obj_depth) + 1e-6  # prevent division by zero downstream

        return float(mu_i), float(sigma_i)

    def _calculate_2d_distance(self, box1: List[float], box2: List[float], img_width: float, img_height: float) -> float:
        """
        Compute a normalized squared distance between two box centers.

        The value is ||c1 - c2||^2 / (2 * (W^2 + H^2)). Being squared, it
        serves as the spatial term of the Gaussian-style score in
        ``_compute_soft_geometric_score``. It lies in [0, 0.5] when both
        centers are inside the image.

        Args:
            box1, box2: [x1, y1, x2, y2]
            img_width: Image width
            img_height: Image height

        Returns:
            Normalized squared center distance
        """
        center1_x = (box1[0] + box1[2]) / 2
        center1_y = (box1[1] + box1[3]) / 2
        center2_x = (box2[0] + box2[2]) / 2
        center2_y = (box2[1] + box2[3]) / 2

        # Squared Euclidean distance between centers
        dist = (center1_x - center2_x) ** 2 + (center1_y - center2_y) ** 2

        # Normalize by the squared image diagonal (eps avoids division by zero)
        diagonal_sq = img_width ** 2 + img_height ** 2 + 1e-6
        return dist / (2 * diagonal_sq)

    def _compute_soft_geometric_score(self, box1: List[float], box2: List[float], depth_map, img_width: float, img_height: float) -> float:
        """
        Compute a soft geometric relationship score between two objects.

        score = exp(-(d_2d + (mu1 - mu2)^2 / (2 * (sig1^2 + sig2^2)))), where
        d_2d comes from ``_calculate_2d_distance`` and (mu, sig) from
        ``_compute_object_distribution``. Not called by the other methods of
        this class.

        Args:
            box1, box2: Bounding boxes [x1, y1, x2, y2]
            depth_map: Depth map (numpy array or torch.Tensor, interpolated to original image size)
            img_width: Image width
            img_height: Image height

        Returns:
            score in (0, 1]. The original design treats score > 0.135 as
            indicating a triplet relationship.
        """
        mu1, sig1 = self._compute_object_distribution(depth_map, box1)
        mu2, sig2 = self._compute_object_distribution(depth_map, box2)

        dist_2d = self._calculate_2d_distance(box1, box2, img_width, img_height)

        # Depth mismatch in Z-score form. The denominator sums the depth
        # spread ("thickness") of both objects, so thick objects tolerate
        # larger depth differences.
        depth_mismatch = (mu1 - mu2) ** 2 / (2 * (sig1 ** 2 + sig2 ** 2))

        return float(np.exp(-(dist_2d + depth_mismatch)))

    def _calculate_iou(self, box1: List[float], box2: List[float]) -> float:
        """
        Calculate the IoU (Intersection over Union) of two bounding boxes.

        Args:
            box1: First bounding box [x1, y1, x2, y2]
            box2: Second bounding box [x1, y1, x2, y2]

        Returns:
            IoU value in [0, 1] for well-formed boxes. Returns 0.0 when the
            boxes do not intersect or the union area is not positive.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)

        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area

        # Degenerate boxes can produce a zero or negative union
        if union_area <= 0:
            return 0.0

        return intersection_area / union_area

    def _compute_spatial_distance_mask(
        self,
        boxes: List[List[float]],
        depth_map: np.ndarray,
        img_width: int,
        img_height: int,
        tau: float = 0.2,
        iou_threshold: float = 0.88,
        eps: float = 1e-6
    ) -> np.ndarray:
        """
        Compute which object pairs are close enough (in 2-D + depth) to crop.

        Formula (c = box center, z = median depth inside the box,
        s = box diagonal):
        - delta_ij = ||c_i - c_j||_2
        - d_raw = delta_ij / img_diag + |z_i - z_j| / depth_range
        - psi_ij = delta_ij / (delta_ij + 0.5 * (s_i + s_j))
        - d_final = d_raw * psi_ij
        - mask[i, j] = (IoU < iou_threshold) and (d_final < tau) and (i != j)

        ``depth_range`` is the spread of the median depths over the given
        boxes (1.0 if it is below ``eps``). Boxes whose depth region is empty
        get depth 0.

        Args:
            boxes: Bounding boxes, each [x1, y1, x2, y2]
            depth_map: Depth map (H x W numpy array or torch.Tensor),
                interpolated to original image size
            img_width: Image width
            img_height: Image height
            tau: Distance threshold, default 0.2
            iou_threshold: Pairs with IoU >= this are not cropped, default 0.88
            eps: Minimum value for numerical stability, default 1e-6

        Returns:
            Symmetric boolean N x N mask; True means the pair should be cropped.
        """
        N = len(boxes)
        if N == 0:
            return np.zeros((0, 0), dtype=bool)

        depth_map = self._to_numpy_depth(depth_map)
        boxes_np = np.asarray(boxes, dtype=float)  # [N, 4]

        # Pairwise IoU [N, N]; IoU is symmetric, so compute each pair once.
        iou_matrix = np.zeros((N, N))
        for i in range(N):
            for j in range(i + 1, N):
                iou_matrix[i, j] = iou_matrix[j, i] = self._calculate_iou(boxes[i], boxes[j])

        # 1. Box centers [N, 2]
        centers = np.stack([
            (boxes_np[:, 0] + boxes_np[:, 2]) / 2,  # x
            (boxes_np[:, 1] + boxes_np[:, 3]) / 2   # y
        ], axis=1)

        # 2. Box diagonal lengths [N]
        diagonals = np.sqrt(
            (boxes_np[:, 2] - boxes_np[:, 0]) ** 2 +
            (boxes_np[:, 3] - boxes_np[:, 1]) ** 2
        )

        # 3. Median depth inside each box [N] (0 for empty regions)
        depths = np.zeros(N)
        for k, box in enumerate(boxes):
            region = self._depth_region(depth_map, box)
            if region.size > 0:
                depths[k] = np.median(region)

        # 4. Image diagonal length
        img_diag = np.sqrt(img_width ** 2 + img_height ** 2)

        # 5. Depth range over these objects (guarded against division by zero)
        max_depth_range = np.max(depths) - np.min(depths)
        if max_depth_range < eps:
            max_depth_range = 1.0

        # 6. Pairwise center distances via broadcasting [N, N]
        delta_ij = np.sqrt(np.sum(
            (centers[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=2
        ))

        # 7. Raw spatial distance [N, N]
        depth_gap = np.abs(depths[:, np.newaxis] - depths[np.newaxis, :])
        d_raw = (delta_ij / img_diag) + (depth_gap / max_depth_range)

        # 8. Size-aware scaling: large objects relative to their separation
        #    shrink the distance [N, N]
        diag_sum = diagonals[:, np.newaxis] + diagonals[np.newaxis, :]
        psi_ij = delta_ij / (delta_ij + 0.5 * diag_sum + eps)

        # 9. Final metric [N, N]
        d_final = d_raw * psi_ij

        # 10. Keep pairs that are not highly overlapping, are close enough,
        #     and are not the same object.
        not_self = ~np.eye(N, dtype=bool)
        mask = (iou_matrix < iou_threshold) & (d_final < tau) & not_self

        return mask
