import os
import cv2
import numpy as np
import random
from tqdm import tqdm

def random_warp(image):
    """随机扭曲变形"""
    h, w = image.shape[:2]
    
    # 定义原始点（4个角点）
    pts1 = np.float32([
        [0, 0],     # 左上角
        [w, 0],     # 右上角
        [0, h],     # 左下角
        [w, h]      # 右下角
    ])
    
    # 随机生成目标点（在原始点周围±10%范围内变化）
    max_offset = min(w, h) * 0.1  # 限制偏移范围，避免点集无效
    pts2 = pts1 + np.random.uniform(-max_offset, max_offset, pts1.shape).astype(np.float32)
    
    # 确保点集形状正确（4x2）
    assert pts1.shape == (4, 2), f"pts1 shape is {pts1.shape}, expected (4, 2)"
    assert pts2.shape == (4, 2), f"pts2 shape is {pts2.shape}, expected (4, 2)"
    
    # 计算透视变换矩阵并应用
    M = cv2.getPerspectiveTransform(pts1, pts2)
    warped = cv2.warpPerspective(image, M, (w, h), flags=cv2.INTER_NEAREST)
    
    return warped

def augment_image(image, mask):
    """对图像和掩码进行相同的增强"""
    # 随机水平翻转
    if random.random() > 0.5:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)
    
    # 随机旋转（±10度）
    angle = random.uniform(-10, 10)
    h, w = image.shape[:2]
    M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
    image = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR)
    mask = cv2.warpAffine(mask, M, (w, h), flags=cv2.INTER_NEAREST)
    
    # 随机扭曲变形
    if random.random() > 0.5:
        image = random_warp(image)
        mask = random_warp(mask)
    
    return image, mask

def main():
    # 设置路径
    images_dir = 'dataset/OCT2/images'
    masks_dir = 'dataset/OCT2/masks'
    output_images_dir = 'dataset/OCT2/augmented_images'
    output_masks_dir = 'dataset/OCT2/augmented_masks'
    
    # 创建输出目录
    os.makedirs(output_images_dir, exist_ok=True)
    os.makedirs(output_masks_dir, exist_ok=True)
    
    # 获取所有图像文件名（仅.jpg文件）
    image_files = [f for f in os.listdir(images_dir) 
                  if f.lower().endswith('.jpg') and os.path.isfile(os.path.join(images_dir, f))]
    
    # 对每张图像进行增强
    for image_file in tqdm(image_files, desc='Processing images'):
        # 读取原始图像
        image_path = os.path.join(images_dir, image_file)
        image = cv2.imread(image_path)
        
        if image is None:
            print(f"Warning: Could not read image {image_file}. Skipping.")
            continue
        
        # 构造对应的掩码文件名（.jpg -> .png）
        mask_file = os.path.splitext(image_file)[0] + '.png'
        mask_path = os.path.join(masks_dir, mask_file)
        
        # 检查掩码文件是否存在
        if not os.path.exists(mask_path):
            print(f"Warning: Mask file {mask_file} not found. Skipping.")
            continue
            
        # 读取掩码（确保以彩色模式读取）
        mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
        if mask is None:
            print(f"Warning: Could not read mask {mask_file}. Skipping.")
            continue
        
        # 保存原始图像和掩码（作为增强集的一部分）
        base_name = os.path.splitext(image_file)[0]
        cv2.imwrite(os.path.join(output_images_dir, f"{base_name}_0.jpg"), image)
        cv2.imwrite(os.path.join(output_masks_dir, f"{base_name}_0.png"), mask)
        
        # 生成9个增强版本（总共10个）
        for i in range(1, 10):
            aug_image, aug_mask = augment_image(image.copy(), mask.copy())
            cv2.imwrite(os.path.join(output_images_dir, f"{base_name}_{i}.jpg"), aug_image)
            cv2.imwrite(os.path.join(output_masks_dir, f"{base_name}_{i}.png"), aug_mask)

if __name__ == '__main__':
    main()
    print("数据增强完成！")