# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import zipfile
import cv2
import pandas as pd
from mmengine.utils import mkdir_or_exist
from tqdm import tqdm


def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        argparse.Namespace: Parsed arguments exposing ``zip_path`` (the
        positional path to the dataset archive) and ``out_dir`` (the
        required output directory, given via ``-o``/``--out_dir``).
    """
    cli = argparse.ArgumentParser(
        description='Convert Pneumothorax dataset to mmsegmentation format')
    cli.add_argument('zip_path', help='Path to the Pneumothorax.zip file')
    cli.add_argument(
        '-o', '--out_dir', required=True, help='Output directory')
    return cli.parse_args()


def _convert_split(csv_path, images_path, masks_path, drive_path, split):
    """Convert every sample listed in one CSV file into one output split.

    Each sample needs both an image and a mask with the same file name.
    Samples whose image or mask is missing or cannot be decoded are
    reported with a warning and skipped, so no image is written without
    its annotation.

    Args:
        csv_path (str): CSV file whose ``new_filename`` column names the
            samples belonging to this split.
        images_path (str): Directory holding the source images.
        masks_path (str): Directory holding the source masks.
        drive_path (str): Root of the output directory tree.
        split (str): Output sub-directory name, e.g. ``'training'`` or
            ``'validation'``.
    """
    img_out_dir = osp.join(drive_path, 'images', split)
    ann_out_dir = osp.join(drive_path, 'annotations', split)

    df = pd.read_csv(csv_path)
    for img_name in tqdm(df['new_filename'], total=len(df)):
        img_path = osp.join(images_path, img_name)
        mask_path = osp.join(masks_path, img_name)

        if not osp.exists(img_path):
            print(f"Warning: Missing image {img_name}")
            continue
        if not osp.exists(mask_path):
            print(f"Warning: Missing mask {img_name}")
            continue

        # cv2.imread signals a decode failure by returning None instead of
        # raising, so both reads are validated before anything is written.
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Unreadable image {img_name}")
            continue
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Unreadable mask {img_name}")
            continue

        # Save the image.
        cv2.imwrite(osp.join(img_out_dir, img_name), img)

        # Save the binarized mask (pixels above 128 become 255, others 0).
        # NOTE(review): masks are stored as 0/255 rather than class indices
        # 0/1 -- confirm the consuming dataset config expects this.
        binary_mask = (mask > 128).astype('uint8') * 255
        cv2.imwrite(osp.join(ann_out_dir, img_name), binary_mask)


def main():
    """Extract the Pneumothorax archive and convert it to a DRIVE-style layout.

    The archive given on the command line is extracted into the output
    directory. Samples listed in ``stage_1_train_images.csv`` are written
    to the ``training`` split and samples listed in
    ``stage_1_test_images.csv`` to the ``validation`` split, under
    ``<out_dir>/DRIVE_format/{images,annotations}/``.
    """
    args = parse_args()
    zip_path = args.zip_path
    out_dir = args.out_dir

    print('Extracting Pneumothorax.zip...')
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(out_dir)

    # Paths inside the extracted archive.
    extracted_path = osp.join(out_dir, 'siim-acr-pneumothorax')
    images_path = osp.join(extracted_path, 'png_images')
    masks_path = osp.join(extracted_path, 'png_masks')
    train_csv = osp.join(extracted_path, 'stage_1_train_images.csv')
    test_csv = osp.join(extracted_path, 'stage_1_test_images.csv')

    # Create the DRIVE-format output directory structure.
    drive_path = osp.join(out_dir, 'DRIVE_format')
    for kind in ('images', 'annotations'):
        for split in ('training', 'validation'):
            mkdir_or_exist(osp.join(drive_path, kind, split))

    # Process the training set.
    print("Processing training data...")
    _convert_split(train_csv, images_path, masks_path, drive_path, 'training')

    # Process the validation set (taken from the stage-1 test CSV).
    print("Processing validation data...")
    _convert_split(test_csv, images_path, masks_path, drive_path,
                   'validation')

    print("Dataset conversion completed successfully.")


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
