import argparse
import os
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``input_path``, ``output_dir``,
        ``image_size`` (two ints), ``model_name`` and ``precision``.
    """
    cli = argparse.ArgumentParser(
        description='Remove background from images using RMBG-2.0'
    )
    cli.add_argument(
        '--input_path',
        type=str,
        default=None,
        help='Path to input image or directory of images',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default='output',
        help='Output directory for processed images',
    )
    cli.add_argument(
        '--image_size',
        type=int,
        nargs=2,
        default=[512, 512],
        help='Image size for processing (width height)',
    )
    cli.add_argument(
        '--model_name',
        type=str,
        default="/home/swu/cyh/MasterGraduationProject/pretrained_models/RMBG-2.0",
        help='Model to use for segmentation',
    )
    # Forwarded to torch.set_float32_matmul_precision by main().
    cli.add_argument(
        '--precision',
        type=str,
        default='high',
        choices=['high', 'highest'],
        help='Float32 matmul precision',
    )
    namespace = cli.parse_args()
    return namespace


def get_transform(image_size):
    """Build the preprocessing pipeline applied to images before inference.

    Args:
        image_size: Forwarded unchanged to ``transforms.Resize``. A sequence is
            interpreted by torchvision as ``(height, width)``.

    Returns:
        transforms.Compose: resize -> tensor conversion -> per-channel normalization.
    """
    # Per-channel mean/std; the values equal the usual ImageNet normalization constants.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(pipeline_steps)


def process_single_image(model, transform_image, image_path, output_dir):
    """Remove the background of one image and save it as an RGBA PNG.

    Args:
        model: Segmentation model (as loaded by
            ``AutoModelForImageSegmentation.from_pretrained``), already in eval
            mode. Its last output is passed through ``sigmoid`` to obtain the mask.
        transform_image: Callable mapping a PIL RGB image to a tensor
            (see ``get_transform``).
        image_path (str): Path of the input image.
        output_dir (str): Existing directory that receives
            ``<basename>_nobg.png``.

    Returns:
        str: Path of the saved PNG. Callers that ignore the return value are
        unaffected.
    """
    # Base filename without extension.
    base_name = os.path.splitext(os.path.basename(image_path))[0]

    # Send the input to whichever device the model lives on rather than
    # assuming CUDA.
    device = next(model.parameters()).device

    # Use a context manager so the underlying file handle is closed.
    # convert() returns an independent, fully loaded copy, so the image stays
    # usable after the file is closed.
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    input_images = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()

    # Resize the predicted mask back to the original resolution and use it as alpha.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image.size)
    image.putalpha(mask)

    output_path = os.path.join(output_dir, f"{base_name}_nobg.png")
    image.save(output_path)
    print(f"✓ Saved processed image to: {output_path}")
    return output_path


def single_image_rmbg(input_path, output_path, only_return=False, *,
                      model_path="/home/swu/cyh/MasterGraduationProject/pretrained_models/RMBG-2.0",
                      image_size=(512, 512), device='cuda'):
    """Load RMBG-2.0, remove the background of one image and save or return it.

    The model is loaded on every call. For many images, prefer
    ``folder_image_rmbg`` or ``process_single_image`` with a shared model.

    Args:
        input_path (str): Path of the input image.
        output_path (str): Where to save the RGBA result. Parent directories
            are created if needed. Ignored when ``only_return`` is true.
        only_return (bool): If true, skip saving and return ``(image, mask)``.
        model_path (str): Path or hub id of the RMBG-2.0 weights.
        image_size: Size passed to ``get_transform`` / ``transforms.Resize``.
        device: Torch device for the model and input tensor.

    Returns:
        tuple[PIL.Image.Image, PIL.Image.Image] | None: ``(rgba_image, mask)``
        when ``only_return`` is true, otherwise ``None``.
    """
    torch.set_float32_matmul_precision('high')

    print("Initializing RMBG-2.0 model...")
    model = AutoModelForImageSegmentation.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    model.to(device)
    model.eval()

    transform_image = get_transform(tuple(image_size))

    # Context manager closes the file; convert() yields an independent copy.
    with Image.open(input_path) as src:
        image = src.convert('RGB')
    input_images = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()

    # Resize the mask back to the original resolution and attach it as alpha.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image.size)
    image.putalpha(mask)

    if only_return:
        return image, mask

    # dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    image.save(output_path)
    print(f"✓ Saved processed image to: {output_path}")

# Remove the background of every image in a folder.
def folder_image_rmbg(input_dir, output_dir, image_size=(512, 512), *,
                      model_path="/home/swu/cyh/MasterGraduationProject/pretrained_models/RMBG-2.0"):
    """Batch-remove the background of all supported images in a folder.

    Every regular file in ``input_dir`` (non-recursive) whose extension is
    supported is written to ``output_dir`` as ``<name>_nobg.png``, with the
    predicted mask as alpha channel. A failure on one file is reported and
    the remaining files are still processed.

    Args:
        input_dir (str): Input folder path.
        output_dir (str): Output folder path; created if it does not exist.
        image_size (tuple): Size passed to ``get_transform`` /
            ``transforms.Resize``; defaults to (512, 512).
        model_path (str): Path or hub id of the RMBG-2.0 weights.

    Returns:
        None. Progress and errors are printed.
    """
    # Validate the input folder.
    if not os.path.exists(input_dir):
        print(f"错误：输入文件夹 {input_dir} 不存在！")
        return

    if not os.path.isdir(input_dir):
        print(f"错误：{input_dir} 不是一个文件夹！")
        return

    # Supported image extensions (matched case-insensitively).
    supported_formats = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tiff')

    # Collect candidate images. Sorting gives a deterministic order
    # (os.listdir order is arbitrary). The isfile() check skips
    # sub-directories that merely have an image-like name.
    image_files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(supported_formats)
        and os.path.isfile(os.path.join(input_dir, f))
    )

    if not image_files:
        print(f"在文件夹 {input_dir} 中没有找到支持的图片文件！")
        return

    total = len(image_files)
    print(f"输入文件夹: {input_dir}")
    print(f"输出文件夹: {output_dir}")
    print(f"找到 {total} 个图片文件")
    print("=" * 50)

    os.makedirs(output_dir, exist_ok=True)

    torch.set_float32_matmul_precision('high')

    # Load the model once and reuse it for every image.
    print("正在初始化 RMBG-2.0 模型...")
    model = AutoModelForImageSegmentation.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    model.to('cuda')
    model.eval()

    # Same preprocessing as the rest of the file.
    transform_image = get_transform(tuple(image_size))
    to_pil = transforms.ToPILImage()

    success_count = 0
    for i, filename in enumerate(image_files, 1):
        try:
            print(f"[{i}/{total}] 正在处理: {filename}")

            input_path = os.path.join(input_dir, filename)
            base_name = os.path.splitext(filename)[0]

            # Context manager closes the file; convert() yields an independent copy.
            with Image.open(input_path) as src:
                image = src.convert('RGB')
            input_images = transform_image(image).unsqueeze(0).to('cuda')

            with torch.no_grad():
                preds = model(input_images)[-1].sigmoid().cpu()

            # Resize the mask back to the original resolution and use it as alpha.
            pred = preds[0].squeeze()
            mask = to_pil(pred).resize(image.size)
            image.putalpha(mask)

            output_path = os.path.join(output_dir, f"{base_name}_nobg.png")
            image.save(output_path)
            print(f"✓ 已保存: {output_path}")
            success_count += 1

        except Exception as e:
            # Best-effort batch: report the failure and continue with the next file.
            print(f"✗ 处理 {filename} 时出错: {str(e)}")

    print("=" * 50)
    print(f"处理完成！成功处理 {success_count}/{total} 个图片文件")
    print(f"所有处理后的图片已保存到: {output_dir} 文件夹")



def main():
    """Command-line entry point: remove backgrounds from a file or a folder.

    Options come from ``parse_args``. The input path is validated before the
    (slow) model load and before the output directory is created.

    ``--image_size`` is documented as ``width height``, but
    ``transforms.Resize`` expects ``(height, width)``, so the pair is swapped
    before building the transform.

    In directory mode, a failure on one image is reported and the remaining
    images are still processed.
    """
    args = parse_args()

    if args.input_path is None:
        print("Error: Input path is required!")
        return

    # Fail fast on a bad path instead of first loading the model.
    is_file = os.path.isfile(args.input_path)
    is_dir = os.path.isdir(args.input_path)
    if not (is_file or is_dir):
        print(f"Error: Input path {args.input_path} does not exist!")
        return

    print(f"Input Path: {args.input_path}")
    print(f"Output Directory: {args.output_dir}")
    print(f"Image Size: {args.image_size}")
    print(f"Model: {args.model_name}")
    print(f"Precision: {args.precision}")
    print("================================\n")

    os.makedirs(args.output_dir, exist_ok=True)

    torch.set_float32_matmul_precision(args.precision)

    print("Initializing RMBG-2.0 model...")
    model = AutoModelForImageSegmentation.from_pretrained(
        args.model_name,
        trust_remote_code=True
    )
    model.to('cuda')
    model.eval()

    # CLI order is (width, height); torchvision Resize wants (height, width).
    width, height = args.image_size
    transform_image = get_transform((height, width))

    if is_file:
        print(f"Processing single image: {args.input_path}")
        process_single_image(model, transform_image, args.input_path, args.output_dir)
    else:
        print(f"Processing directory: {args.input_path}")
        # Same extension set as folder_image_rmbg.
        supported_formats = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tiff')
        failures = 0
        # Sorted for a deterministic processing order.
        for filename in sorted(os.listdir(args.input_path)):
            image_path = os.path.join(args.input_path, filename)
            if not filename.lower().endswith(supported_formats):
                continue
            if not os.path.isfile(image_path):
                continue
            print(f"\nProcessing: {filename}")
            try:
                process_single_image(model, transform_image, image_path, args.output_dir)
            except Exception as e:
                # One unreadable image should not abort the whole batch.
                failures += 1
                print(f"✗ Failed to process {filename}: {e}")
        if failures:
            print(f"\nProcessing completed with {failures} failure(s).")
            return

    print("\nProcessing completed successfully!")


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        # Any command-line arguments: use the argparse-driven entry point.
        main()
    else:
        # Example single-image usage. The output name is derived with
        # os.path.splitext so a .png input is not overwritten in place:
        # input_path = "/home/swu/szp/MaterialTransfer/material_transfer/data/reference_image/bag/iter_0.png"
        # single_image_rmbg(input_path, os.path.splitext(input_path)[0] + "_nobg.png")

        # Default run without arguments: batch-process a fixed folder.
        input_folder = "/home/swu/szp/MaterialTransfer/paper_images/test"
        output_folder = "/home/swu/szp/MaterialTransfer/paper_images/nob/test"
        folder_image_rmbg(input_folder, output_folder)