# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import cv2
import mmcv
from mmengine.utils import mkdir_or_exist


def parse_args(argv=None):
    """Parse command-line options for the DRIVE conversion script.

    Args:
        argv (list[str], optional): Argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches the
            previous behaviour of this function.

    Returns:
        argparse.Namespace: Has ``training_path`` and ``testing_path``
        (positional, required) and ``out_dir`` (``None`` unless
        ``-o/--out_dir`` is given).
    """
    parser = argparse.ArgumentParser(
        description='Convert DRIVE dataset to mmsegmentation format')
    parser.add_argument(
        'training_path',
        help='the training part of DRIVE dataset (unpacked folder)')
    parser.add_argument(
        'testing_path',
        help='the testing part of DRIVE dataset (unpacked folder)')
    parser.add_argument('-o', '--out_dir', help='output path')
    return parser.parse_args(argv)


def _convert_images(src_dir, dst_dir, strip_token):
    """Re-encode every image file in ``src_dir`` as a PNG in ``dst_dir``.

    ``strip_token`` is removed from each file stem before ``.png`` is
    appended, e.g. a stem ``'21_training'`` with token ``'_training'``
    becomes ``'21.png'``.

    Raises:
        RuntimeError: If a file in ``src_dir`` cannot be decoded as an image.
    """
    for name in sorted(os.listdir(src_dir)):
        src_path = osp.join(src_dir, name)
        img = mmcv.imread(src_path)
        # mmcv.imread may return None for files it cannot decode; fail with a
        # clear message instead of handing None to imwrite.
        if img is None:
            raise RuntimeError(f'Failed to read image: {src_path}')
        stem = osp.splitext(name)[0].replace(strip_token, '')
        mmcv.imwrite(img, osp.join(dst_dir, stem + '.png'))


def _convert_annotations(src_dir, dst_dir):
    """Binarize every annotation file in ``src_dir`` and save it as PNG.

    Each file is opened with ``cv2.VideoCapture`` and only its first frame
    is read (presumably because the manual annotations are GIFs, which
    ``cv2.imread`` does not decode -- confirm). Channel 0 of that frame is
    integer-divided by 128, which for 8-bit values maps ``>= 128`` to 1 and
    ``< 128`` to 0, so slightly non-standard masks are thresholded too. The
    output keeps the original stem, e.g. ``'21_manual1'`` -> ``'21_manual1.png'``.

    Raises:
        RuntimeError: If a file in ``src_dir`` cannot be read.
    """
    for name in sorted(os.listdir(src_dir)):
        src_path = osp.join(src_dir, name)
        cap = cv2.VideoCapture(src_path)
        try:
            ok, frame = cap.read()
        finally:
            # Always free the native capture handle, even if read() raises.
            cap.release()
        if not ok or frame is None:
            raise RuntimeError(f'Failed to read annotation: {src_path}')
        mmcv.imwrite(frame[:, :, 0] // 128,
                     osp.join(dst_dir, osp.splitext(name)[0] + '.png'))


def main():
    """Convert an unpacked DRIVE dataset into the mmsegmentation layout.

    Writes into ``<out_dir>/{images,annotations}/{training,validation}``.
    ``out_dir`` defaults to ``data/DRIVE`` when ``-o/--out_dir`` is not
    given. The DRIVE testing split is written as the ``validation`` split.
    The ``1st_manual`` and ``2nd_manual`` folders of the testing split are
    optional and are converted only if they exist.
    """
    args = parse_args()
    training_path = args.training_path  # path to the unpacked training folder
    testing_path = args.testing_path  # path to the unpacked testing folder
    if args.out_dir is None:
        out_dir = osp.join('data', 'DRIVE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    # mkdir_or_exist creates missing parents, so only the leaves are needed.
    for kind in ('images', 'annotations'):
        for split in ('training', 'validation'):
            mkdir_or_exist(osp.join(out_dir, kind, split))

    print('Generating training dataset...')
    _convert_images(
        osp.join(training_path, 'images'),
        osp.join(out_dir, 'images', 'training'), '_training')
    _convert_annotations(
        osp.join(training_path, '1st_manual'),
        osp.join(out_dir, 'annotations', 'training'))

    print('Generating validation dataset...')
    _convert_images(
        osp.join(testing_path, 'images'),
        osp.join(out_dir, 'images', 'validation'), '_test')
    for manual in ('1st_manual', '2nd_manual'):
        manual_dir = osp.join(testing_path, manual)
        if osp.exists(manual_dir):
            _convert_annotations(
                manual_dir, osp.join(out_dir, 'annotations', 'validation'))

    print('Done!')


if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()
