import argparse
import os
import os.path as osp
import random
import re
import shutil
import zipfile

import mmcv
import numpy as np
from mmengine.utils import mkdir_or_exist

def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with attributes ``zip_file`` and
        ``output_dir`` (positional strings), ``val_ratio`` (float, default
        0.2) and ``manual_suffix`` (str, default ``'_manual1'``).
    """
    cli = argparse.ArgumentParser(
        description=('Convert COVID-19 CT segmentation dataset from zip file '
                     'to mmsegmentation format with train/val split'))

    # Positional inputs: where to read the archive and where to write output.
    cli.add_argument('zip_file', help='Path to the input dataset zip file')
    cli.add_argument('output_dir', help='Path to the output dataset folder')

    # Optional knobs for the split and for mask file naming.
    cli.add_argument(
        '--val_ratio',
        type=float,
        default=0.2,
        help='Validation set ratio (default: 0.2)')
    cli.add_argument(
        '--manual_suffix',
        type=str,
        default='_manual1',
        help='Suffix to add to mask filenames (default: _manual1)')

    return cli.parse_args()

def convert_mask_to_binary(mask):
    """Binarise a label mask.

    Every pixel whose value is strictly greater than zero becomes 1
    (foreground / lesion); every other pixel (zero or negative) becomes 0
    (background).

    Args:
        mask (np.ndarray): Mask array of any shape.

    Returns:
        np.ndarray: Array with the same shape as ``mask`` and dtype ``uint8``.
    """
    # np.asarray keeps a 0-d array result even if the comparison yields a
    # plain Python bool (scalar input).
    is_foreground = np.asarray(mask > 0)
    return is_foreground.astype(np.uint8)

def _list_files(directory):
    """Return the sorted names of the regular, non-hidden files in ``directory``.

    Hidden entries (e.g. ``.DS_Store``) and sub-directories are skipped so they
    are not mistaken for images or masks.
    """
    return sorted(
        name for name in os.listdir(directory)
        if not name.startswith('.') and osp.isfile(osp.join(directory, name)))


def _find_mask(img_name, mask_by_base):
    """Return the mask basename that belongs to ``img_name``, or ``None``.

    An exact basename match always wins. Otherwise the first mask basename (in
    ``mask_by_base`` iteration order) that contains ``img_name`` is used. An
    occurrence that would split a number is rejected: if ``img_name`` starts or
    ends with a digit, the neighbouring character in the mask name must not be
    a digit. This stops ``img_1`` from matching ``img_10_mask`` while still
    matching ``img_1_mask``.

    Args:
        img_name (str): Image filename without extension.
        mask_by_base (dict[str, str]): Mask basename -> mask filename.

    Returns:
        str | None: The matching key of ``mask_by_base``, or ``None``.
    """
    if img_name in mask_by_base:
        return img_name
    prefix = r'(?<!\d)' if img_name[:1].isdigit() else ''
    suffix = r'(?!\d)' if img_name[-1:].isdigit() else ''
    pattern = re.compile(prefix + re.escape(img_name) + suffix)
    for mask_name in mask_by_base:
        if pattern.search(mask_name):
            return mask_name
    return None


def _export_split(names, samples, split, output_dir, manual_suffix):
    """Write the images and binarised masks of one split as PNG files.

    Args:
        names (list[str]): Image basenames belonging to this split.
        samples (dict[str, tuple[str, str]]): Basename -> (image path, mask path).
        split (str): Output sub-directory, ``'training'`` or ``'validation'``.
        output_dir (str): Dataset root holding ``images/`` and ``annotations/``.
        manual_suffix (str): Suffix appended to every mask's output filename.

    Raises:
        IOError: If an image or mask could not be read.
    """
    print(f'Processing {split} dataset...')
    img_out_dir = osp.join(output_dir, 'images', split)
    ann_out_dir = osp.join(output_dir, 'annotations', split)
    for img_name in names:
        img_path, mask_path = samples[img_name]
        img = mmcv.imread(img_path)
        mask = mmcv.imread(mask_path, flag='grayscale')
        # NOTE(review): mmcv.imread may return None for undecodable files
        # (backend-dependent); fail with a clear message instead of a
        # TypeError further down.
        if img is None or mask is None:
            raise IOError(
                f'Failed to read image/mask pair: {img_path}, {mask_path}')
        # Output names always use .png, whatever the input extension was.
        mmcv.imwrite(img, osp.join(img_out_dir, f'{img_name}.png'))
        mmcv.imwrite(
            convert_mask_to_binary(mask),
            osp.join(ann_out_dir, f'{img_name}{manual_suffix}.png'))


def main():
    """Convert the zipped dataset into the mmsegmentation directory layout.

    Steps:
        1. Extract the zip into ``<output_dir>/temp_unzip``.
        2. Pair every file in ``frames/`` with a file in ``masks/``.
        3. Randomly split the pairs into training and validation sets.
        4. Write images to ``images/{training,validation}/<name>.png``.
        5. Write binarised masks to
           ``annotations/{training,validation}/<name><manual_suffix>.png``.

    The temporary directory is removed even if conversion fails.

    Raises:
        ValueError: If ``--val_ratio`` is outside [0, 1] or some frame has no
            matching mask.
        IOError: If an image or mask could not be read.
    """
    args = parse_args()
    zip_file = args.zip_file
    output_dir = args.output_dir
    val_ratio = args.val_ratio
    manual_suffix = args.manual_suffix

    # Validate before the (potentially slow) extraction.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f'--val_ratio must be within [0, 1], got {val_ratio}')

    # Temporary extraction directory; removed in `finally` even on failure.
    temp_dir = osp.join(output_dir, 'temp_unzip')
    mkdir_or_exist(temp_dir)
    try:
        print(f'Extracting {zip_file} to temporary directory...')
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        # The archive is expected to contain top-level `frames/` and `masks/`.
        frames_dir = osp.join(temp_dir, 'frames')
        masks_dir = osp.join(temp_dir, 'masks')

        # Basename (no extension) -> real filename, so reading does not
        # assume the inputs are .png.
        img_by_base = {osp.splitext(f)[0]: f for f in _list_files(frames_dir)}
        mask_by_base = {osp.splitext(f)[0]: f for f in _list_files(masks_dir)}

        # Pair every image with its mask: basename -> (image path, mask path).
        samples = {}
        unmatched = []
        for img_name, img_file in img_by_base.items():
            mask_name = _find_mask(img_name, mask_by_base)
            if mask_name is None:
                print(f"Warning: No matching mask found for {img_name}!")
                unmatched.append(img_name)
                continue
            samples[img_name] = (osp.join(frames_dir, img_file),
                                 osp.join(masks_dir, mask_by_base[mask_name]))

        # Raise (not assert) so the check survives `python -O`.
        if unmatched:
            raise ValueError(
                'Mismatch between frames and masks filenames! '
                f'{len(unmatched)} frame(s) without a mask, '
                f'e.g. {unmatched[:5]}')

        # Random train/val split. NOTE: this function does not seed `random`,
        # so each run produces a different split.
        img_names = list(samples)
        random.shuffle(img_names)
        val_size = int(len(img_names) * val_ratio)
        val_img_names = img_names[:val_size]
        train_img_names = img_names[val_size:]

        print(f'Total images: {len(img_names)}, Training: {len(train_img_names)}, Validation: {len(val_img_names)}')

        print('Creating output directories...')
        for subdir in ('images', 'annotations'):
            for split in ('training', 'validation'):
                mkdir_or_exist(osp.join(output_dir, subdir, split))

        _export_split(train_img_names, samples, 'training', output_dir,
                      manual_suffix)
        _export_split(val_img_names, samples, 'validation', output_dir,
                      manual_suffix)
    finally:
        print('Cleaning up temporary files...')
        shutil.rmtree(temp_dir, ignore_errors=True)

    print('Dataset conversion completed!')

if __name__ == '__main__':
    # Script entry point. main() returns None, so this exits with status 0 on
    # success; importing the module does not trigger a conversion.
    raise SystemExit(main())
