import os
import glob
import cv2
import numpy as np
import torch
import math
import random
import random
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from mmseg.structures import SegDataSample
from segearthov3_segmentor import SegEarthOV3Segmentation

# ================= Configuration =================
SEARCH_ROOT = 'MAR20/train_sub_split/images'
LABEL_ROOT = 'MAR20/train_sub_split/labelTxt'
OUTPUT_DIR = 'MAR20/train_sub_split/generated_labels'

# How many images (with at least one successful placement) to collect per class.
IMAGES_PER_CLASS = 150
# Maximum number of boxes placed for the current class in a single image.
MAX_ITEMS_PER_IMAGE = 5

# Per-class (A1..A20) length relative to the reference aircraft; index 4 (A5) is 1.00.
length_ratios_rel_to_f16 = [
    1.45, 1.98, 3.52, 5.00, 1.00, 3.59, 3.09, 3.22, 2.36, 2.95,
    3.09, 2.82, 1.29, 2.76, 1.25, 1.22, 3.29, 3.67, 1.55, 1.63
]
# Per-class length-to-wingspan ratio (wingspan is computed as length / ratio).
aspect_ratios = [
    1.43, 0.74, 1.02, 1.11, 1.51, 0.97, 1.05, 0.86, 1.17, 1.06,
    1.05, 1.24, 1.49, 1.04, 1.39, 1.34, 0.97, 1.10, 1.59, 1.39
]
plane_names = [f"A{i}" for i in range(1, 21)]
plane_name_to_idx = {name: i for i, name in enumerate(plane_names)}

# Class-name file written for the open-vocabulary segmentor (one entry per line).
NAME_LIST_PATH = './configs/mar20_names.txt'
NAME_LIST = [
    'background',           # 0
    'runway, taxiway, apron,parking_lot,concrete,',    # 1: apron / paved area (placement target)
    'grass',                # 2: grass
    'terminal,building',    # 3: terminal building
    'airplane,aircraft',    # 4: aircraft (obstacle)
    'car'                   # 5: car
]
# NOTE(review): entry 1 ends with a trailing comma; if the segmentor splits
# names on ',', this may yield an empty synonym — confirm.
# Indices into NAME_LIST whose predicted pixels count as placeable area.
TARGET_CLS_INDICES = [1]

def setup_model():
    """Write NAME_LIST to NAME_LIST_PATH and build the SAM3-based segmentor.

    The class-name file is (re)written on every call, one name per line, and
    its path is handed to ``SegEarthOV3Segmentation`` as ``classname_path``.

    Returns:
        The constructed ``SegEarthOV3Segmentation`` instance.
    """
    name_dir = os.path.dirname(NAME_LIST_PATH)
    # os.makedirs('') raises, so only create a directory when the path has one.
    if name_dir:
        os.makedirs(name_dir, exist_ok=True)
    with open(NAME_LIST_PATH, 'w', encoding='utf-8') as f:
        f.writelines(f"{name}\n" for name in NAME_LIST)

    model = SegEarthOV3Segmentation(
        type='SegEarthOV3Segmentation',
        model_type='SAM3',
        classname_path=NAME_LIST_PATH,
        prob_thd=0.1,
        confidence_threshold=0.1,
        slide_stride=512,
        slide_crop=512,
        checkpoint_path='models/sam3.pt'
    )
    return model

def calc_scale_for_image(label_path, *, ratios=None, name_to_idx=None):
    """Estimate the pixel length of the reference aircraft (ratio 1.00) in one image.

    Each label line is expected to hold 8 polygon coordinates followed by a
    class name (presumably DOTA-style ``x1 y1 ... x4 y4 class difficulty`` —
    confirm). Lines with fewer than 9 tokens or non-numeric coordinates are
    skipped. An object's length is the longer of its first two polygon edges.

    Args:
        label_path: Path to the label text file.
        ratios: Per-class length ratios relative to the reference aircraft.
            Defaults to ``length_ratios_rel_to_f16``.
        name_to_idx: Maps class names to indices into ``ratios``.
            Defaults to ``plane_name_to_idx``.

    Returns:
        The mean of ``length / ratio`` over objects of known class. If no
        known class is present, the median length of all parsed objects
        divided by the mean ratio. ``None`` if the file is missing or no
        object could be parsed.
    """
    if ratios is None:
        ratios = length_ratios_rel_to_f16
    if name_to_idx is None:
        name_to_idx = plane_name_to_idx
    if not os.path.exists(label_path):
        return None

    base_lengths = []
    all_lengths = []
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 9:
                continue
            try:
                poly = [float(v) for v in parts[:8]]
            except ValueError:
                # Malformed/metadata line: skip it rather than abort the image.
                continue
            cls_name = parts[8]
            d1 = math.hypot(poly[0] - poly[2], poly[1] - poly[3])
            d2 = math.hypot(poly[2] - poly[4], poly[3] - poly[5])
            length = max(d1, d2)

            idx = name_to_idx.get(cls_name)
            if idx is not None:
                base_lengths.append(length / ratios[idx])
            all_lengths.append(length)

    if base_lengths:
        return np.mean(base_lengths)
    if all_lengths:
        # No known class: fall back to an "average aircraft" scale.
        return np.median(all_lengths) / np.mean(ratios)
    return None

def get_plane_dims_dynamic(idx, base_f16_px):
    """Return the ``(wingspan, length)`` box size in pixels for class ``idx``.

    The length is the class's length ratio times the reference length
    ``base_f16_px``. The wingspan is that length divided by the class's
    aspect ratio. Both values are truncated to ``int``.
    """
    fuselage_px = base_f16_px * length_ratios_rel_to_f16[idx]
    span_px = fuselage_px / aspect_ratios[idx]
    return int(span_px), int(fuselage_px)

def _local_gradient_angle(dist_map, cx, cy, win=5):
    """Return the direction (degrees) of the mean Sobel gradient of ``dist_map``
    in a window around ``(cx, cy)``, or ``None`` if the window is empty."""
    h, w = dist_map.shape
    x1, y1 = max(0, cx - win), max(0, cy - win)
    x2, y2 = min(w, cx + win), min(h, cy + win)
    patch = dist_map[y1:y2, x1:x2]
    if patch.size == 0:
        return None
    avg_gx = np.mean(cv2.Sobel(patch, cv2.CV_64F, 1, 0, ksize=5))
    avg_gy = np.mean(cv2.Sobel(patch, cv2.CV_64F, 0, 1, ksize=5))
    return np.degrees(np.arctan2(avg_gy, avg_gx))


def _evaluate_box(raw_mask, occupied_mask, dist_map, cx, cy, box_w, box_h, angle):
    """Score one rotated box and return ``(score, int32 corners)``, or ``None`` if invalid.

    A box is invalid when any corner lies outside the image, it overlaps a
    non-zero pixel of ``occupied_mask``, its rasterisation is empty, or fewer
    than 95% of its pixels are non-zero in ``raw_mask``. The score is the
    negative standard deviation of ``dist_map`` inside the box, so flatter
    regions score higher.
    """
    h, w = raw_mask.shape
    corners = cv2.boxPoints(((float(cx), float(cy)), (float(box_w), float(box_h)), float(angle)))
    # cv2.fillPoly clips to the image. Without this check a box hanging off the
    # border would be judged only on its visible part and could pass coverage.
    if (corners[:, 0].min() < 0 or corners[:, 1].min() < 0
            or corners[:, 0].max() > w - 1 or corners[:, 1].max() > h - 1):
        return None
    points = np.int32(corners)

    # Rasterise inside the box's bounding ROI only. An integer translation
    # produces the same pixel set as drawing on a full-size mask, but avoids
    # allocating and scanning whole-image masks for every tested angle.
    origin = points.min(axis=0)
    x0, y0 = origin
    x1, y1 = points.max(axis=0) + 1
    roi = np.zeros((y1 - y0, x1 - x0), dtype=np.uint8)
    cv2.fillPoly(roi, [points - origin], 1)
    inside = roi.astype(bool)

    n_pixels = np.count_nonzero(inside)
    if n_pixels == 0:
        return None
    if np.any(occupied_mask[y0:y1, x0:x1][inside]):
        return None
    covered = np.count_nonzero(raw_mask[y0:y1, x0:x1][inside])
    if covered < n_pixels * 0.95:
        return None
    return -np.std(dist_map[y0:y1, x0:x1][inside]), points


def find_placements_on_mask(raw_mask, global_occupied_mask, box_w, box_h, max_count):
    """Greedily choose up to ``max_count`` non-overlapping rotated boxes on a mask.

    Candidate centres are local maxima of the distance transform of
    ``raw_mask``, tried in descending distance order (only the top
    ``max_count * 20`` are examined). For each centre, four orientations are
    tested: the local distance-gradient direction plus 90/180/270 degrees. The
    valid orientation (see ``_evaluate_box``) with the best score is kept.

    Args:
        raw_mask: 2-D uint8 mask (non-zero = placeable), as required by
            ``cv2.distanceTransform``. Modified in place: accepted boxes are
            erased from it.
        global_occupied_mask: uint8 mask of the same shape (non-zero =
            occupied). Modified in place: accepted boxes are stamped with 1.
        box_w, box_h: Box size in pixels, passed to ``cv2.boxPoints`` as
            (width, height).
        max_count: Maximum number of placements to return.

    Returns:
        A list of ``(cx, cy, angle_deg)`` tuples.
    """
    results = []
    dist_map = cv2.distanceTransform(raw_mask, cv2.DIST_L2, 5)
    # Non-maximum-suppression window scales with the box size.
    kernel_size = max(3, min(box_w, box_h) // 2)
    dilated_dist = cv2.dilate(dist_map, np.ones((kernel_size, kernel_size)))
    local_max_mask = (dist_map == dilated_dist) & (dist_map > 0)
    anchors_y, anchors_x = np.where(local_max_mask)
    if len(anchors_x) == 0:
        return []
    anchor_scores = dist_map[anchors_y, anchors_x]
    sorted_indices = np.argsort(anchor_scores)[::-1]
    candidate_indices = sorted_indices[:max_count * 20]

    for idx in candidate_indices:
        if len(results) >= max_count:
            break
        cx, cy = int(anchors_x[idx]), int(anchors_y[idx])
        if global_occupied_mask[cy, cx] > 0:
            continue
        grad_angle = _local_gradient_angle(dist_map, cx, cy)
        if grad_angle is None:
            continue

        best_angle = 0
        best_score = -float('inf')
        best_points = None
        for angle in (grad_angle, grad_angle + 90, grad_angle + 180, grad_angle + 270):
            evaluated = _evaluate_box(raw_mask, global_occupied_mask, dist_map,
                                      cx, cy, box_w, box_h, angle)
            if evaluated is None:
                continue
            score, points = evaluated
            if score > best_score:
                best_score, best_angle, best_points = score, angle, points
        if best_points is None:
            continue

        results.append((float(cx), float(cy), best_angle))
        # Stamp immediately so later candidates in this call cannot overlap it.
        cv2.fillPoly(global_occupied_mask, [best_points], 1)
        cv2.fillPoly(raw_mask, [best_points], 0)
        # dist_map is intentionally not recomputed. Overlap is already prevented
        # via global_occupied_mask, and coverage is checked against the updated
        # raw_mask, so recomputing would only refine the candidate ordering.

    return results

def _read_label_polygons(label_path):
    """Return each label line's first 8 numbers as a (4, 2) int32 polygon.

    Lines with fewer than 8 tokens or non-numeric coordinates are skipped.
    A missing file yields an empty list.
    """
    if not os.path.exists(label_path):
        return []
    polys = []
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 8:
                continue
            try:
                coords = [float(v) for v in parts[:8]]
            except ValueError:
                # A malformed line must not abort the whole preprocessing run.
                continue
            polys.append(np.array(coords).reshape(-1, 2).astype(np.int32))
    return polys


def _segment_target_mask(model, img_path, to_tensor):
    """Segment one image and return a uint8 0/1 mask of TARGET_CLS_INDICES pixels."""
    with Image.open(img_path) as im:
        img = im.convert('RGB')
    w_img, h_img = img.size
    img_tensor = to_tensor(img).unsqueeze(0).to('cuda')
    data_sample = SegDataSample()
    data_sample.set_metainfo({'img_path': img_path, 'ori_shape': (h_img, w_img)})
    # Inference only: do not record autograd graphs for every image.
    with torch.no_grad():
        seg_pred = model.predict(img_tensor, data_samples=[data_sample])
    seg_mask = seg_pred[0].pred_sem_seg.data.cpu().numpy().squeeze(0)
    return np.isin(seg_mask, TARGET_CLS_INDICES).astype(np.uint8)


def _preprocess_images(model, img_files):
    """Phase 1: compute scale, target mask and occupied mask for each usable image.

    An image is skipped when no scale can be derived from its labels, or when
    fewer than 2000 target pixels are predicted. Existing label polygons are
    marked in the occupied mask and removed from the placeable mask.

    Returns:
        ``(image_cache, valid_img_files)``, where ``image_cache`` maps an image
        path to ``{'scale', 'raw_mask', 'occupied_mask'}``.
    """
    image_cache = {}
    valid_img_files = []
    to_tensor = transforms.ToTensor()

    for img_path in tqdm(img_files):
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(LABEL_ROOT, base_name + '.txt')

        current_base_px = calc_scale_for_image(label_path)
        if current_base_px is None:
            continue

        raw_mask = _segment_target_mask(model, img_path, to_tensor)
        if np.sum(raw_mask) < 2000:
            continue

        occupied_mask = np.zeros_like(raw_mask)
        for poly in _read_label_polygons(label_path):
            cv2.fillPoly(occupied_mask, [poly], 1)
            cv2.fillPoly(raw_mask, [poly], 0)

        image_cache[img_path] = {
            'scale': current_base_px,
            'raw_mask': raw_mask,
            'occupied_mask': occupied_mask
        }
        valid_img_files.append(img_path)

    return image_cache, valid_img_files


def _write_class_placements(idx, name, image_cache, valid_img_files):
    """Phase 2 for one class: place boxes on shuffled images and write a CSV file.

    ``valid_img_files`` is shuffled in place. Every image starts from its
    cached masks (copies are passed), so placements of different classes are
    independent of each other. Processing stops once IMAGES_PER_CLASS images
    received at least one placement.

    Returns:
        ``(class_out_path, images_processed_count)``.
    """
    class_out_path = os.path.join(OUTPUT_DIR, f"{name}.txt")
    images_processed_count = 0
    with open(class_out_path, 'w', encoding='utf-8') as f_cls:
        f_cls.write("image_path,class_name,cx,cy,w,h,angle\n")
        random.shuffle(valid_img_files)

        for img_path in tqdm(valid_img_files, desc=f"Class {name}"):
            if images_processed_count >= IMAGES_PER_CLASS:
                break
            cache = image_cache[img_path]
            w, h = get_plane_dims_dynamic(idx, cache['scale'])
            # find_placements_on_mask mutates its masks, so hand it copies.
            placements = find_placements_on_mask(
                cache['raw_mask'].copy(),
                cache['occupied_mask'].copy(),
                w, h,
                max_count=MAX_ITEMS_PER_IMAGE
            )
            if not placements:
                continue
            for cx, cy, angle in placements:
                f_cls.write(f"{img_path},{name},{cx:.2f},{cy:.2f},{w},{h},{angle:.2f}\n")
            f_cls.flush()
            images_processed_count += 1

    return class_out_path, images_processed_count


def main():
    """Segment all images once, then write one placement CSV per class A1..A20."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    img_files = glob.glob(os.path.join(SEARCH_ROOT, '*.png')) + \
                glob.glob(os.path.join(SEARCH_ROOT, '*.jpg'))

    if not img_files:
        print(f"No images found in {SEARCH_ROOT}")
        return

    print(f"Found {len(img_files)} images in {SEARCH_ROOT}")
    model = setup_model()

    print("Phase 1: Pre-processing images (Segmentation & Initialization)...")
    image_cache, valid_img_files = _preprocess_images(model, img_files)
    print(f"Pre-processing done. Valid images: {len(valid_img_files)}")

    print("Phase 2: Generating placements by class...")
    for idx, name in enumerate(plane_names):
        print(f"Processing Class {name}...")
        class_out_path, images_processed_count = _write_class_placements(
            idx, name, image_cache, valid_img_files
        )
        print(f"Class {name}: Inserted into {images_processed_count} images. Saved to {class_out_path}")

    print("All done!")


if __name__ == '__main__':
    main()