
from typing import List, Dict, Any, Tuple
from collections import defaultdict
import torch
import numpy as np
from PIL import Image
import re


class TreeNode:
    """A single detected object inside the hierarchy tree.

    All attributes are plain and mutable. ``description`` starts out as
    ``None`` and ``children`` as a fresh empty list owned by this node.
    """

    def __init__(self, image_path: str, label: str, box: List[float], idx: int):
        # Crop path, display label, box coordinates and index of the object.
        self.image_path, self.label, self.box, self.idx = image_path, label, box, idx
        self.description = None
        # A new list per instance, so siblings never share children.
        self.children = []

    def add_child(self, child):
        """Append ``child`` to this node's children (in place)."""
        self.children += [child]


class HierarchyTreeBuilder:

    
    def __init__(self, depth_model_path: str = "/home//hsgg/checkpoint/depth-v3-base", device: str = "cuda:6"):

        self.label_counter = defaultdict(int)
        self.idx_to_sorted_label = {}  
        

        self._last_depth_map = None  
        
     
        self.orig_img_width = None
        self.orig_img_height = None
        

        print(f"Loading Depth Anything V3 model from {depth_model_path}...")
        import sys
        sys.path.append('/home//hsgg/depth-anything-3/src')
        from depth_anything_3.api import DepthAnything3
        
        self.depth_device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.depth_model = DepthAnything3.from_pretrained(depth_model_path)
        self.depth_model = self.depth_model.to(self.depth_device)
        self.depth_model.eval()
        print(f"Depth model loaded on {self.depth_device}")
    
    def compute_sorted_label_map(self, labels: List[str]) -> Dict[int, str]:
     
        self.label_counter = defaultdict(int)
        self.idx_to_sorted_label = {}
        self._assign_sorted_labels(labels)
        return dict(self.idx_to_sorted_label)
    
    def build_hierarchy_tree(
        self,
        objects_data: Dict[str, Any],
        cropped_objects: List[Dict[str, Any]],
        image_path: str = None
    ) -> Tuple[List[TreeNode], Dict[str, Any], Dict[int, str]]:
       
        if 'record_dict' not in objects_data or not objects_data['record_dict']:
            return [], {}, {}
        
        boxes = objects_data['boxes']
        labels = objects_data['labels']
        

        depth_map = None
        if image_path:
            depth_map = self._generate_depth_map(image_path)
        
  
        self.label_counter = defaultdict(int)
        self.idx_to_sorted_label = {}
        self._assign_sorted_labels(labels)

        box_to_path = {
            tuple(obj['box']): obj['image_path']
            for obj in cropped_objects
        }
        

        first_layer_dict = objects_data['record_dict'][0]
        tree_roots = []
        
        for root_label in first_layer_dict.keys():
            for i, label in enumerate(labels):
                if label == root_label:
                    box = boxes[i]
                    image_path = box_to_path.get(tuple(box))
                    if image_path:
                        sorted_label = self.idx_to_sorted_label[i]
                        node = TreeNode(image_path, sorted_label, box, i)
                        tree_roots.append(node)
        
      
        second_nodes = []
        second_layer_candidates = [] 
        mounted_indices = set(node.idx for node in tree_roots) 

        for root in tree_roots:
       
            base_label = re.sub(r'\d+$', '', root.label)
            child_labels = first_layer_dict.get(base_label, [])
            if not isinstance(child_labels, list):
                child_labels = [child_labels]

            for child_label in child_labels:
                for i, label in enumerate(labels):
                    if label=="windshield" and child_label=="windshield":
                        print(label)
                    if label == child_label and i not in mounted_indices:
                        box = boxes[i]
                        image_path = box_to_path.get(tuple(box))
                        if image_path:
                            sorted_label = self.idx_to_sorted_label[i]
                            child_node = TreeNode(image_path, sorted_label, box, i)
                            
                   
                            overlap_ratio = self._box_overlap_ratio(child_node.box, root.box)
                            if overlap_ratio >= 0.001:
                           
                                root.add_child(child_node)
                                second_nodes.append(child_node)
                                mounted_indices.add(i)
        

        for child_node, child_label in second_layer_candidates:
       
            if child_node.idx in mounted_indices:
                continue
            
            mounted = False
            best_parent = None
            min_depth_diff = 0.3
            
            if depth_map is not None:
                for root in tree_roots:
                    if self._box_overlap_ratio(child_node.box, root.box) >= 0.5:
                        depth_diff = self._calculate_depth_diff(child_node.box, root.box, depth_map)
                        if depth_diff < min_depth_diff:
                            min_depth_diff = depth_diff
                            best_parent = root
                
                if best_parent is not None:
                    best_parent.add_child(child_node)
                    second_nodes.append(child_node)
                    mounted_indices.add(child_node.idx)
                    mounted = True
            else:
                for root in tree_roots:
                    if self._box_overlap_ratio(child_node.box, root.box) >= 0.5:
                        root.add_child(child_node)
                        second_nodes.append(child_node)
                        mounted_indices.add(child_node.idx)
                        mounted = True
                        break
            
            if not mounted:
                tree_roots.append(child_node)
                mounted_indices.add(child_node.idx)
        
        if len(objects_data['record_dict']) > 1:
            second_layer_dict = objects_data['record_dict'][1]
            third_layer_candidates = []
            
            
            for second_node in second_nodes:
                base_label = re.sub(r'\d+$', '', second_node.label)
                child_labels = second_layer_dict.get(base_label, [])
                if not isinstance(child_labels, list):
                    child_labels = [child_labels]
                
                for child_label in child_labels:
                    for i, label in enumerate(labels):
                        if label == child_label and i not in mounted_indices:
                            box = boxes[i]
                            image_path = box_to_path.get(tuple(box))
                            if image_path:
                                sorted_label = self.idx_to_sorted_label[i]
                                child_node = TreeNode(image_path, sorted_label, box, i)
                                
                                overlap_ratio = self._box_overlap_ratio(child_node.box, second_node.box)
                                if overlap_ratio >= 0.001:
                                    second_node.add_child(child_node)
                                    mounted_indices.add(i)
            
            for child_node in third_layer_candidates:
                if child_node.idx in mounted_indices:
                    continue
                
                mounted = False
                best_parent = None
                min_depth_diff = 0.3
                
                if depth_map is not None:
                    for second_node in second_nodes:
                        if self._box_overlap_ratio(child_node.box, second_node.box) >= 0.5:
                            depth_diff = self._calculate_depth_diff(child_node.box, second_node.box, depth_map)
                            if depth_diff < min_depth_diff:
                                min_depth_diff = depth_diff
                                best_parent = second_node
                    
                    if best_parent is not None:
                        best_parent.add_child(child_node)
                        mounted_indices.add(child_node.idx)
                        mounted = True
                else:
                    for second_node in second_nodes:
                        if self._box_overlap_ratio(child_node.box, second_node.box) >= 0.5:
                            second_node.add_child(child_node)
                            mounted_indices.add(child_node.idx)
                            mounted = True
                            break
                
                if not mounted:
                    best_parent = None
                    min_depth_diff = 0.3
                    
                    if depth_map is not None:
                        for root in tree_roots:
                            if self._box_overlap_ratio(child_node.box, root.box) >= 0.5:
                                depth_diff = self._calculate_depth_diff(child_node.box, root.box, depth_map)
                                if depth_diff < min_depth_diff:
                                    min_depth_diff = depth_diff
                                    best_parent = root
                        
                        if best_parent is not None:
                            best_parent.add_child(child_node)
                            mounted_indices.add(child_node.idx)
                            mounted = True
                    else:
                        for root in tree_roots:
                            if self._box_overlap_ratio(child_node.box, root.box) >= 0.5:
                                root.add_child(child_node)
                                mounted_indices.add(child_node.idx)
                                mounted = True
                                break
                
                if not mounted:
                    tree_roots.append(child_node)
                    mounted_indices.add(child_node.idx)
        
        hierarchy_dict = self._build_hierarchy_dict(tree_roots, second_nodes)
        
        return tree_roots, hierarchy_dict, self.idx_to_sorted_label
    
    def _boxes_intersect(self, box1: List[float], box2: List[float]) -> bool:

        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2
        
        if x2_1 < x1_2 or x1_1 > x2_2 or y2_1 < y1_2 or y1_1 > y2_2:
            return False
        return True
    
    def _box_overlap_ratio(self, box1: List[float], box2: List[float]) -> float:
 
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2
        
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)
        
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        
        if box1_area == 0:
            return 0.0
        
        return intersection_area / box1_area
    
    def _assign_sorted_labels(self, labels: List[str]):
        for i, label in enumerate(labels):
            self.label_counter[label] += 1
            count = self.label_counter[label]
            
            self.idx_to_sorted_label[i] = (label, count)
        
        label_total_count = defaultdict(int)
        for label, count in self.idx_to_sorted_label.values():
            label_total_count[label] = max(label_total_count[label], count)
        
        for i in range(len(labels)):
            label, count = self.idx_to_sorted_label[i]
            if label_total_count[label] > 1:
                self.idx_to_sorted_label[i] = f"{label}{count}"
            else:
                self.idx_to_sorted_label[i] = label
    
    def _build_hierarchy_dict(self, tree_roots: List[TreeNode], second_nodes: List[TreeNode]) -> Dict[str, Any]:

        hierarchy_dict = {
            "layer1_nodes": [],
            "layer2_mapping": {},
            "layer3_mapping": {}
        }
        
        for root in tree_roots:
            hierarchy_dict["layer1_nodes"].append(root.label)
        
        for root in tree_roots:
            if root.children:
                children_labels = [child.label for child in root.children]
                if children_labels:
                    hierarchy_dict["layer2_mapping"][root.label] = children_labels
        
        for second_node in second_nodes:
            if second_node.children:
                children_labels = [child.label for child in second_node.children]
                if children_labels:
                    hierarchy_dict["layer3_mapping"][second_node.label] = children_labels
        
        return hierarchy_dict
    
    def _generate_depth_map(self, image_path: str):

        try:
            from PIL import Image
            original_img = Image.open(image_path)
            orig_width, orig_height = original_img.size
            
            self.orig_img_width = orig_width
            self.orig_img_height = orig_height
            
            with torch.no_grad():
                prediction = self.depth_model.inference(
                    image=[image_path],
                    process_res=504,
                    process_res_method="upper_bound_resize",
                    export_dir=None,
                    export_format="glb"
                )
                depth_map = prediction.depth[0]
                
                if isinstance(depth_map, torch.Tensor):
                    depth_map = depth_map.cpu().numpy()
                
                from scipy.ndimage import zoom
                depth_height, depth_width = depth_map.shape
                scale_x = orig_width / depth_width
                scale_y = orig_height / depth_height
                
                depth_map_resized = zoom(depth_map, (scale_y, scale_x), order=1)
                
                if depth_map_resized.shape != (orig_height, orig_width):
                    from PIL import Image as PILImage
                    depth_img = PILImage.fromarray(depth_map_resized.astype(np.float32))
                    depth_img_resized = depth_img.resize((orig_width, orig_height), PILImage.BILINEAR)
                    depth_map_resized = np.array(depth_img_resized)
                
                self._last_depth_map = depth_map_resized
                
                return depth_map_resized
        except Exception as e:
            print(f"Warning: Failed to generate depth map: {e}")
            self.depth_map_scale = None
            return None
    

    
    def _compute_object_distribution(self, depth_map, bbox: List[float]) -> Tuple[float, float]:
        
        x1, y1, x2, y2 = [int(x) for x in bbox]
        
        h, w = depth_map.shape
        x1 = max(0, min(x1, w - 1))
        x2 = max(0, min(x2, w))
        y1 = max(0, min(y1, h - 1))
        y2 = max(0, min(y2, h))
        
        obj_depth = depth_map[y1:y2, x1:x2]
        
        if obj_depth.size == 0:
            return 0.0, 1.0
        
        mu_i = np.mean(obj_depth)
        sigma_i = np.std(obj_depth)
        
        sigma_i = sigma_i + 1e-6
        
        return float(mu_i), float(sigma_i)
    
    def _calculate_2d_distance(self, box1: List[float], box2: List[float]) -> float:
    
        center1_x = (box1[0] + box1[2]) / 2
        center1_y = (box1[1] + box1[3]) / 2
        center2_x = (box2[0] + box2[2]) / 2
        center2_y = (box2[1] + box2[3]) / 2
        
        dist = np.sqrt((center1_x - center2_x)**2 + (center1_y - center2_y)**2)
        
        diagonal = np.sqrt(self.orig_img_width**2 + self.orig_img_height**2)
        return dist / diagonal

    
    def _compute_geometric_score(self, box1: List[float], box2: List[float], depth_map) -> float:

        mu1, sig1 = self._compute_object_distribution(depth_map, box1)
        mu2, sig2 = self._compute_object_distribution(depth_map, box2)
    
        dist_2d = self._calculate_2d_distance(box1, box2)
        

        depth_mismatch = (mu1 - mu2)**2 / (2 * (sig1**2 + sig2**2))
        

        score = np.exp(- (dist_2d + depth_mismatch))
        
        return float(score)
    
    def _calculate_depth_diff(self, box1: List[float], box2: List[float], depth_map) -> float:
 
        mu1, sig1 = self._compute_object_distribution(depth_map, box1)
        mu2, sig2 = self._compute_object_distribution(depth_map, box2)
        
        depth_diff_mean = abs(mu1 - mu2)
        
        center1_x = int((box1[0] + box1[2]) / 2)
        center1_y = int((box1[1] + box1[3]) / 2)
        center2_x = int((box2[0] + box2[2]) / 2)
        center2_y = int((box2[1] + box2[3]) / 2)
        
        h, w = depth_map.shape
        center1_y = min(max(0, center1_y), h - 1)
        center1_x = min(max(0, center1_x), w - 1)
        center2_y = min(max(0, center2_y), h - 1)
        center2_x = min(max(0, center2_x), w - 1)
        
        depth_center1 = depth_map[center1_y, center1_x]
        depth_center2 = depth_map[center2_y, center2_x]
        depth_diff_center = abs(depth_center1 - depth_center2)
        
        depth_diff = min(depth_diff_mean, depth_diff_center)
        
        return depth_diff_mean
    
    def _is_depth_close(self, box1: List[float], box2: List[float], depth_map, threshold: float = 0.1) -> bool:
     
        mu1, sig1 = self._compute_object_distribution(depth_map, box1)
        mu2, sig2 = self._compute_object_distribution(depth_map, box2)
        
        depth_diff_mean = abs(mu1 - mu2)
        
        return depth_diff_mean < threshold