from PIL import Image, ImageEnhance, ImageFilter, ImageOps
import pandas as pd
import random
import os
import copy
from tqdm import tqdm
import argparse

def parse_args():
    """Parse the command-line options of the perturbation script.

    Two options are mandatory (``--input_path`` and ``--output_path``); the
    remaining three are string options with defaults.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="对parquet文件中的图片路径进行扰动处理")

    # (flag, help text) for the options that must always be supplied.
    mandatory_options = (
        ('--input_path', '输入parquet文件路径'),
        ('--output_path', '输出parquet文件路径'),
    )
    # (flag, default, help text) for the options that may be omitted.
    defaulted_options = (
        ('--image_key', 'image_paths', 'DataFrame中图片路径字段名，默认为image_paths'),
        ('--perturbed_dir_prefix', 'perturbed_images', '扰动图片保存目录前缀，默认perturbed_images'),
        ('--show_dir', 'show', '可视化图片保存目录（可选）'),
    )

    for flag, help_text in mandatory_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    for flag, fallback, help_text in defaulted_options:
        cli.add_argument(flag, type=str, default=fallback, help=help_text)

    return cli.parse_args()

# Image perturbation helper: scale, rotate, noise and invert are selectable;
# a change_bg_color branch is implemented but currently not selectable.
def perturb_image(img):
    """Apply one randomly chosen perturbation to a PIL image.

    The operation is drawn uniformly from ``operations`` below:

    * ``scale``  - shrink to 50-70% and centre on a white canvas of the
      original size.
    * ``rotate`` - rotate by a random angle in [-90, 90] degrees with
      ``expand=True`` and composite over a white background.
    * ``noise``  - add an independent random offset in [-100, 100] to every
      channel of every pixel, clamped to 0..255.
    * ``invert`` - invert the colours.

    Args:
        img: The PIL image to perturb.

    Returns:
        A ``(perturbed_image, op)`` tuple where ``op`` is the name of the
        operation that was applied.
    """
    # NOTE: 'change_bg_color' is handled further down but is not listed here,
    # so random.choice never picks it. Add it to this list to enable it.
    operations = ['scale', 'rotate', 'noise', 'invert']
    op = random.choice(operations)

    if op == 'scale':
        # Shrink well below the original size so no content gets cropped.
        scale = random.uniform(0.5, 0.7)
        # int() truncates to 0 for 1-pixel sides and resize() cannot produce
        # an empty image, so keep every side at least 1 pixel.
        new_size = (
            max(1, int(img.size[0] * scale)),
            max(1, int(img.size[1] * scale)),
        )
        scaled_img = img.resize(new_size, Image.LANCZOS)

        # White canvas with the original dimensions; paste the shrunken
        # image in its centre.
        background = Image.new("RGB", img.size, (255, 255, 255))
        paste_x = (img.size[0] - new_size[0]) // 2
        paste_y = (img.size[1] - new_size[1]) // 2
        background.paste(scaled_img, (paste_x, paste_y))
        img = background
    elif op == 'rotate':
        # Rotate around the centre with expand=True so that no content is
        # cropped, then flatten onto white.
        angle = random.uniform(-90, 90)
        img = img.convert("RGBA")
        rot = img.rotate(angle, expand=True)
        fff = Image.new('RGBA', rot.size, (255,) * 4)  # opaque white background
        img = Image.alpha_composite(fff, rot).convert("RGB")
    elif op == 'noise':
        # Heavy per-channel noise.
        img = img.convert('RGB')
        pixels = img.load()
        for i in range(img.size[0]):
            for j in range(img.size[1]):
                r, g, b = pixels[i, j]
                pixels[i, j] = (
                    min(255, max(0, r + random.randint(-100, 100))),
                    min(255, max(0, g + random.randint(-100, 100))),
                    min(255, max(0, b + random.randint(-100, 100)))
                )
    elif op == 'invert':
        # ImageOps.invert rejects modes such as RGBA, P or CMYK; convert
        # those to RGB first instead of failing on them.
        if img.mode not in ("1", "L", "RGB"):
            img = img.convert("RGB")
        img = ImageOps.invert(img)
    elif op == 'change_bg_color':
        # Replace the background colour, treating near-white pixels as
        # background and swapping them for one random colour.
        img = img.convert("RGBA")
        datas = img.getdata()
        bg_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 255)
        newData = []
        for item in datas:
            # A pixel counts as background when all of R, G, B exceed 240
            # (threshold may need tuning for other data).
            if item[0] > 240 and item[1] > 240 and item[2] > 240:
                newData.append(bg_color)
            else:
                newData.append(item)
        img.putdata(newData)
        img = img.convert("RGB")
    return img, op

def _perturbed_dir_for(path, prefix):
    """Return the output directory for the image at ``path`` under ``prefix``."""
    dir_name = os.path.dirname(path)
    if dir_name.startswith("images"):
        # Swap the leading "images" for the perturbed-directory prefix.
        return dir_name.replace("images", prefix, 1)

    print(f"原路径不符合要求！{path}")
    # os.path.join drops `prefix` when the second component is absolute,
    # which would place the outputs in the source directory. Strip the drive
    # and leading separators so the result always lives under `prefix`.
    tail = os.path.splitdrive(dir_name)[1]
    relative_dir = tail.lstrip(os.sep + (os.altsep or ""))
    return os.path.join(prefix, relative_dir)


def _perturb_and_save(path, prefix):
    """Perturb one image, save it plus a copy of the original; return the perturbed path or None."""
    if not (isinstance(path, str) and os.path.exists(path)):
        return None
    try:
        # Context manager releases the file handle as soon as the pixel data
        # has been converted.
        with Image.open(path) as src:
            img = src.convert("RGB")
        perturbed_img, _ = perturb_image(img)

        perturbed_dir = _perturbed_dir_for(path, prefix)
        os.makedirs(perturbed_dir, exist_ok=True)

        base_name = os.path.basename(path)
        out_path = os.path.join(perturbed_dir, f"perturbed_{base_name}")
        perturbed_img.save(out_path)

        # Keep a copy of the original next to the perturbed image under its
        # original name, but never overwrite the source file itself.
        original_save_path = os.path.join(perturbed_dir, base_name)
        same_as_source = (
            os.path.exists(original_save_path)
            and os.path.samefile(original_save_path, path)
        )
        if not same_as_source:
            img.save(original_save_path)
        return out_path
    except Exception as e:
        # Best effort: one bad image must not abort the whole run.
        print(f"处理图片 {path} 时出错: {e}")
        return None


def main():
    """Perturb every image referenced by a parquet file and record the new paths.

    Reads the parquet file given by ``--input_path``. For each row, the value
    in the ``--image_key`` column is expected to be a dict mapping keys to
    image paths; rows holding anything else get ``None``. Every existing
    image is perturbed and saved under the ``--perturbed_dir_prefix``
    directory, together with a copy of the original. The per-row dicts of
    perturbed paths (``None`` for missing or failed images) are stored in a
    new ``perturbed_image_paths`` column and the frame is written to
    ``--output_path``.
    """
    args = parse_args()

    # Make sure the visualisation directory exists.
    if args.show_dir:
        os.makedirs(args.show_dir, exist_ok=True)

    df = pd.read_parquet(args.input_path)

    # One entry per row: dict of perturbed paths, or None.
    perturbed_image_paths_list = []

    # Iterate lazily; tqdm gets the length from `total`, so the rows do not
    # need to be materialised into a list first.
    for _, row in tqdm(df.iterrows(), total=len(df), desc="处理样本"):
        image_dict = row[args.image_key]
        if not isinstance(image_dict, dict):
            perturbed_image_paths_list.append(None)
            continue
        perturbed_image_dict = {}
        for key, path in image_dict.items():
            perturbed_image_dict[key] = _perturb_and_save(path, args.perturbed_dir_prefix)
        perturbed_image_paths_list.append(perturbed_image_dict)

    # Attach the perturbed paths as a new column.
    df['perturbed_image_paths'] = perturbed_image_paths_list

    # Write the augmented frame.
    df.to_parquet(args.output_path, index=False)
    print(f"已保存到: {args.output_path}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
