"""
测试天空mask的边缘软化处理效果

用法：
  python test_sky_mask_processing.py --mask_path /path/to/sky_mask.png

可选参数：
  --blur_sigma: 高斯模糊的sigma值（默认15.0）
  --erode_kernel: 形态学腐蚀的核大小（默认7）
"""

import argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path


def soften_sky_mask(mask: np.ndarray, blur_sigma: float = 15.0, erode_kernel: int = 7) -> np.ndarray:
    """
    软化天空mask的边缘，让值在0-1之间平滑过渡，并略微缩小天空区域。
    
    Args:
        mask: 二值mask，值为0或255
        blur_sigma: 高斯模糊的sigma值，越大边缘越平滑
        erode_kernel: 形态学腐蚀的核大小，用于略微缩小天空区域
    
    Returns:
        软化后的mask，值在0-1之间
    """
    # 归一化到0-1
    mask_normalized = mask.astype(np.float32) / 255.0
    
    # 可选的形态学腐蚀，让天空区域略微缩小
    if erode_kernel > 0:
        kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
        mask_normalized = cv2.erode(mask_normalized, kernel, iterations=1)
    
    # 高斯模糊软化边缘
    mask_softened = cv2.GaussianBlur(mask_normalized, (0, 0), blur_sigma)
    
    return mask_softened


def main():
    parser = argparse.ArgumentParser(description='测试天空mask边缘软化处理')
    parser.add_argument('--mask_path', type=str, required=True, help='输入的mask文件路径')
    parser.add_argument('--blur_sigma', type=float, default=9.0, help='高斯模糊sigma值（默认15.0）')
    parser.add_argument('--erode_kernel', type=int, default=3, help='腐蚀核大小（默认7）')
    parser.add_argument('--output_dir', type=str, default='test_mask_outputs', help='输出目录（默认test_mask_outputs）')
    
    args = parser.parse_args()
    
    # 读取mask
    mask_path = Path(args.mask_path)
    if not mask_path.exists():
        raise FileNotFoundError(f'找不到mask文件: {args.mask_path}')
    
    mask_original = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask_original is None:
        raise ValueError(f'无法读取mask文件: {args.mask_path}')
    
    print(f'原始mask尺寸: {mask_original.shape}')
    print(f'参数设置:')
    print(f'  - 高斯模糊sigma: {args.blur_sigma}')
    print(f'  - 腐蚀核大小: {args.erode_kernel}')
    
    # 处理mask
    mask_softened = soften_sky_mask(mask_original, 
                                    blur_sigma=args.blur_sigma, 
                                    erode_kernel=args.erode_kernel)
    
    # 创建输出目录
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # 可视化
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'天空Mask边缘软化处理效果\n高斯模糊sigma={args.blur_sigma}, 腐蚀核={args.erode_kernel}', fontsize=14)
    
    # 原始mask
    axes[0, 0].imshow(mask_original, cmap='gray')
    axes[0, 0].set_title('原始Mask')
    axes[0, 0].axis('off')
    
    # 原始mask热力图
    im1 = axes[0, 1].imshow(mask_original, cmap='jet')
    axes[0, 1].set_title('原始Mask (热力图)')
    axes[0, 1].axis('off')
    plt.colorbar(im1, ax=axes[0, 1])
    
    # 原始mask边缘（使用Sobel）
    gray = mask_original.astype(np.float32)
    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    edges_original = np.sqrt(grad_x**2 + grad_y**2)
    im2 = axes[0, 2].imshow(edges_original, cmap='hot')
    axes[0, 2].set_title('原始Mask边缘')
    axes[0, 2].axis('off')
    plt.colorbar(im2, ax=axes[0, 2])
    
    # 软化后的mask
    axes[1, 0].imshow(mask_softened, cmap='gray', vmin=0, vmax=1)
    axes[1, 0].set_title('软化后的Mask')
    axes[1, 0].axis('off')
    
    # 软化后的mask热力图
    im3 = axes[1, 1].imshow(mask_softened, cmap='jet', vmin=0, vmax=1)
    axes[1, 1].set_title('软化后的Mask (热力图)')
    axes[1, 1].axis('off')
    plt.colorbar(im3, ax=axes[1, 1])
    
    # 软化后的mask边缘
    edges_softened = cv2.GaussianBlur(mask_softened, (0, 0), 3)
    grad_x = cv2.Sobel(edges_softened, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(edges_softened, cv2.CV_64F, 0, 1, ksize=3)
    edges_softened_grad = np.sqrt(grad_x**2 + grad_y**2)
    im4 = axes[1, 2].imshow(edges_softened_grad, cmap='hot')
    axes[1, 2].set_title('软化后的Mask边缘')
    axes[1, 2].axis('off')
    plt.colorbar(im4, ax=axes[1, 2])
    
    # 保存图像
    output_path = output_dir / 'mask_processing_comparison.png'
    plt.tight_layout()
    plt.savefig(str(output_path), dpi=150, bbox_inches='tight')
    print(f'\n可视化结果已保存到: {output_path}')
    
    # 保存处理后的mask
    output_mask_path = output_dir / 'mask_softened.png'
    mask_softened_uint8 = (mask_softened * 255).astype(np.uint8)
    cv2.imwrite(str(output_mask_path), mask_softened_uint8)
    print(f'处理后的mask已保存到: {output_mask_path}')
    
    # 统计信息
    print(f'\n统计信息:')
    print(f'  原始mask天空区域像素数: {np.count_nonzero(mask_original)} ({np.count_nonzero(mask_original)/mask_original.size*100:.2f}%)')
    print(f'  软化后mask>0.5的像素数: {np.count_nonzero(mask_softened > 0.5)} ({np.count_nonzero(mask_softened > 0.5)/mask_softened.size*100:.2f}%)')
    print(f'  软化后mask>0.8的像素数: {np.count_nonzero(mask_softened > 0.8)} ({np.count_nonzero(mask_softened > 0.8)/mask_softened.size*100:.2f}%)')
    print(f'  软化后mask的最大值: {mask_softened.max():.4f}')
    print(f'  软化后mask的最小值: {mask_softened.min():.4f}')
    print(f'  软化后mask的均值: {mask_softened.mean():.4f}')
    
    plt.show()


if __name__ == '__main__':
    main()

