from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from mmseg.structures import SegDataSample
from segearthov3_segmentor import SegEarthOV3Segmentation
import numpy as np
img_path = '195.png'

# Class list tailored to the airport scene. The position of each entry is the
# label value it is expected to receive in the segmentation output (TODO
# confirm against SegEarthOV3Segmentation). Comma-separated words inside one
# entry are presumably treated as synonyms of that class by the segmentor's
# class-name parser; avoid trailing commas, which would presumably yield an
# empty synonym.
name_list = [
    'background',                                   # 0
    'runway, taxiway, apron,parking_lot,concrete',  # 1: apron / paved area (placement target)
    'grass',                                        # 2: grass
    'terminal,building',                            # 3: terminal building
    'airplane,aircraft',                            # 4: airplane (obstacle)
    'car',                                          # 5: car
]

# One class entry per line, no trailing newline after the last one.
with open('./configs/my_name.txt', 'w', encoding='utf-8') as writer:
    writer.write('\n'.join(name_list))


img = Image.open(img_path)
# Placeholder only: the actual image data is read inside the model from
# img_path (to be optimized).
img_tensor = transforms.ToTensor()(img).unsqueeze(0).to('cuda')

data_sample = SegDataSample()
# mmseg metainfo stores ori_shape as (height, width); PIL's Image.size is
# (width, height), so build the tuple explicitly to avoid swapping the axes of
# non-square images.
data_sample.set_metainfo({
    'img_path': img_path,
    'ori_shape': (img.height, img.width),
})


model = SegEarthOV3Segmentation(
    type='SegEarthOV3Segmentation',
    model_type='SAM3',
    classname_path='./configs/my_name.txt',
    prob_thd=0.1,
    confidence_threshold=0.1,
    slide_stride=512,
    slide_crop=512,
    checkpoint_path='models/sam3.pt',
)

from scipy.ndimage import rotate

import cv2

def _rotated_rect_mask(shape, cx, cy, box_w, box_h, angle):
    """Rasterise one rotated rectangle into a freshly allocated uint8 mask.

    Args:
        shape: ``(height, width)`` of the mask to allocate.
        cx, cy: rectangle centre in pixels (x = column, y = row).
        box_w, box_h: rectangle size in pixels.
        angle: rotation in degrees, passed unchanged to ``cv2.boxPoints``.

    Returns:
        ``(corners, mask)``: the four corners as an int32 ``(4, 2)`` array of
        ``(x, y)`` points (truncated from ``cv2.boxPoints``), and a uint8 mask
        of ``shape`` whose filled rectangle is set to 1.
    """
    rect = ((float(cx), float(cy)), (float(box_w), float(box_h)), float(angle))
    corners = np.int32(cv2.boxPoints(rect))
    mask = np.zeros(shape, dtype=np.uint8)
    cv2.fillPoly(mask, [corners], 1)
    return corners, mask


def _corners_inside(corners, width, height):
    """Return True when every ``(x, y)`` corner lies on a pixel of a width x height image."""
    xs, ys = corners[:, 0], corners[:, 1]
    return bool(xs.min() >= 0 and ys.min() >= 0 and xs.max() < width and ys.max() < height)


def find_multiple_placements(seg_pred, target_class_indices, box_w, box_h, max_count=5, angles=None,
                             min_coverage=0.98, dist_map_path='distance_map.png'):
    """Find up to ``max_count`` non-overlapping rotated rectangles inside the target classes.

    Candidate centres ("anchors") are the local maxima of the distance
    transform of the target-class mask, i.e. points on the ridge / centre line
    of the region, visited from the largest distance value down. At each anchor
    every angle in ``angles`` is tried. An angle is valid when the rectangle
    lies fully inside the image, does not overlap an already placed rectangle,
    and at least ``min_coverage`` of its pixels belong to the remaining target
    mask. Among valid angles, the one with the smallest standard deviation of
    distance values under the rectangle wins (the box that best follows the
    region). After a placement, the rectangle and its outline (drawn with
    thickness 5) are removed from the target mask and the distance field is
    recomputed; the recomputed field affects the angle score of later anchors,
    while the anchor set itself is computed once from the initial field.

    Args:
        seg_pred: 2-D array of per-pixel class indices.
        target_class_indices: one class index (Python or NumPy integer) or a
            sequence of indices; their union is the allowed placement area.
        box_w, box_h: rectangle size in pixels.
        max_count: maximum number of rectangles to return.
        angles: iterable of candidate angles in degrees (OpenCV
            ``RotatedRect`` convention); defaults to ``range(0, 180, 5)``.
        min_coverage: minimum fraction of a rectangle's pixels that must lie on
            the remaining target mask.
        dist_map_path: file to which a JET-coloured visualisation of the
            initial distance field is written; ``None`` disables it.

    Returns:
        List of ``(cx, cy, angle)`` tuples: centre as floats in pixel
        coordinates and the chosen angle (an element of ``angles``, OpenCV
        convention).
    """
    if angles is None:
        angles = range(0, 180, 5)
    # Materialise once: the angles are iterated again for every anchor.
    angles = list(angles)

    # Union mask of all target classes; np.atleast_1d accepts a plain int, a
    # NumPy integer, or any sequence of indices.
    class_ids = np.atleast_1d(np.asarray(target_class_indices)).ravel()
    raw_mask = np.isin(seg_pred, class_ids).astype(np.uint8)

    h_orig, w_orig = raw_mask.shape
    # Union of all placed rectangles, used for exact collision tests.
    occupied = np.zeros_like(raw_mask)
    results = []

    # Distance of each target pixel to the nearest non-target pixel.
    dist_map = cv2.distanceTransform(raw_mask, cv2.DIST_L2, 5)

    if dist_map_path is not None:
        dist_vis = cv2.normalize(dist_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        cv2.imwrite(dist_map_path, cv2.applyColorMap(dist_vis, cv2.COLORMAP_JET))
        print(f"Distance map saved to {dist_map_path}")

    # Anchors: pixels equal to the max of their neighbourhood (found by
    # dilation) with a positive distance. The neighbourhood scales with the box.
    kernel_size = max(3, int(min(box_w, box_h)) // 2)
    dilated = cv2.dilate(dist_map, np.ones((kernel_size, kernel_size), dtype=np.uint8))
    anchors_y, anchors_x = np.where((dist_map == dilated) & (dist_map > 0))
    if len(anchors_x) == 0:
        return []
    order = np.argsort(dist_map[anchors_y, anchors_x])[::-1]

    # A single pass over the anchors is enough: the constraints only ever
    # tighten (raw_mask shrinks, occupied grows), so an anchor that fails once
    # can never succeed later.
    for idx in order:
        if len(results) >= max_count:
            break
        cx, cy = int(anchors_x[idx]), int(anchors_y[idx])
        if occupied[cy, cx] > 0:
            continue

        best_angle, best_score = None, None
        for angle in angles:
            corners, rect_mask = _rotated_rect_mask(raw_mask.shape, cx, cy, box_w, box_h, angle)
            # fillPoly silently clips to the image, so reject out-of-bounds
            # rectangles explicitly instead of letting them pass the coverage test.
            if not _corners_inside(corners, w_orig, h_orig):
                continue
            if np.any(cv2.bitwise_and(occupied, rect_mask)):
                continue
            area = int(np.sum(rect_mask))
            if area == 0:  # degenerate rectangle: nothing to cover or score
                continue
            covered = int(np.sum(cv2.bitwise_and(raw_mask, rect_mask)))
            if covered < area * min_coverage:
                continue
            # Lower spread of distance values means the box runs along the region.
            score = -np.std(dist_map[rect_mask == 1])
            if best_score is None or score > best_score:
                best_angle, best_score = angle, score

        if best_angle is None:
            continue

        results.append((float(cx), float(cy), best_angle))

        corners, rect_mask = _rotated_rect_mask(raw_mask.shape, cx, cy, box_w, box_h, best_angle)
        cv2.bitwise_or(occupied, rect_mask, occupied)
        # Carve the placed box (plus its outline as a margin) out of the target area.
        cv2.fillPoly(raw_mask, [corners], 0)
        cv2.polylines(raw_mask, [corners], isClosed=True, color=0, thickness=5)
        dist_map = cv2.distanceTransform(raw_mask, cv2.DIST_L2, 5)

    return results

# === Run segmentation and plan placements ===
import matplotlib.patches as patches
from matplotlib.transforms import Affine2D

seg_pred = model.predict(img_tensor, data_samples=[data_sample])
seg_pred = seg_pred[0].pred_sem_seg.data.cpu().numpy().squeeze(0)

# === Find several rectangles ===
target_cls = 1  # label 1 = the runway/taxiway/apron entry of name_list
box_w, box_h = 100, 80
placements = find_multiple_placements(seg_pred, target_cls, box_w, box_h, max_count=5)

print(f"在类别 {target_cls} 中找到了 {len(placements)} 个位置")
for i, (cx, cy, angle) in enumerate(placements):
    print(f"  框 {i+1}: 中心({cx:.1f}, {cy:.1f}), 角度 {angle}°")


def _draw_placements(ax, placements, box_w, box_h, color):
    """Overlay numbered, dashed rotated rectangles on ``ax``.

    Each placement is ``(cx, cy, angle)`` in pixel coordinates with the
    angle in OpenCV's convention. ``Affine2D.rotate_deg`` applies the same
    rotation matrix as ``cv2.boxPoints`` in pixel coordinates, so under
    ``imshow`` the drawn box matches the rasterised one.
    """
    for number, (cx, cy, angle) in enumerate(placements, start=1):
        rect = patches.Rectangle(
            xy=(-box_w / 2, -box_h / 2), width=box_w, height=box_h,
            linewidth=2, edgecolor=color, facecolor='none', linestyle='--',
        )
        # Build the box centred at the origin, rotate it, then move it to (cx, cy).
        rect.set_transform(Affine2D().rotate_deg(angle).translate(cx, cy) + ax.transData)
        ax.add_patch(rect)
        ax.text(cx, cy, str(number), color=color, ha='center', va='center', fontsize=12, weight='bold')


# === Visualise and save each result separately ===
# Label maps use nearest-neighbour interpolation so that resampling never
# blends class indices into colours that belong to no class.

# 1. Original image + rectangles
fig1, ax1 = plt.subplots(figsize=(10, 10))
ax1.imshow(img)
ax1.axis('off')
_draw_placements(ax1, placements, box_w, box_h, 'yellow')
fig1.tight_layout()
fig1.savefig('result_original.png', bbox_inches='tight', pad_inches=0)
plt.close(fig1)

# 2. Segmentation map + rectangles
fig2, ax2 = plt.subplots(figsize=(10, 10))
ax2.imshow(seg_pred, cmap='viridis', interpolation='nearest')
ax2.axis('off')
_draw_placements(ax2, placements, box_w, box_h, 'red')
fig2.tight_layout()
fig2.savefig('result_segmentation.png', bbox_inches='tight', pad_inches=0)
plt.close(fig2)

# 3. Raw segmentation map (no rectangles)
fig3, ax3 = plt.subplots(figsize=(10, 10))
ax3.imshow(seg_pred, cmap='viridis', interpolation='nearest')
ax3.axis('off')
fig3.tight_layout()
fig3.savefig('segmentation_raw.png', bbox_inches='tight', pad_inches=0)
plt.close(fig3)

print("结果已分别保存为 result_original.png, result_segmentation.png, segmentation_raw.png 和 distance_map.png")