# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import multiprocessing
import os
from pathlib import Path

from mmengine import (Config, DictAction, track_parallel_progress,
                      track_progress)

from mmpretrain.datasets import build_dataset
from mmpretrain.registry import TRANSFORMS

# Module-level lock used by ``DatasetValidator.valid_idx`` to serialize
# appends to the broken-file log when validation runs in several processes.
# NOTE(review): this only synchronizes workers that inherit this very object
# (e.g. the ``fork`` start method); under ``spawn`` each worker re-imports
# the module and creates its own lock — confirm the intended platforms.
file_lock = multiprocessing.Lock()


def parse_args():
    """Parse and validate the command-line arguments of the verify tool.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``out_path``, ``phase``, ``num_process`` and ``cfg_options``.

    Exits (via ``parser.error``) when ``--out-path`` is empty or
    ``--num-process`` is not positive.
    """
    parser = argparse.ArgumentParser(description='Verify Dataset')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--out-path',
        type=str,
        default='brokenfiles.log',
        help='output path of all the broken files. If the specified path '
        'already exists, delete the previous file ')
    parser.add_argument(
        '--phase',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to verify, accept "train" "test" and "val".')
    parser.add_argument(
        '--num-process', type=int, default=1, help='number of process to use')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    # Validate with parser.error instead of assert: asserts are stripped
    # under ``python -O``, and parser.error prints a usage message.
    # An empty path would resolve to the current directory, which main()
    # would then try to delete as if it were the old log file.
    if not args.out_path:
        parser.error('--out-path must be a non-empty file path.')
    if args.num_process <= 0:
        parser.error('--num-process must be a positive integer.')
    return args


class DatasetValidator:
    """Check every sample of a dataset for image files that fail to load.

    Only the ``LoadImageFromFile`` transform of the configured pipeline is
    kept. Each sample's data info is passed through it, and the image path
    of any sample whose loading raises is appended to a log file.

    Args:
        dataset_cfg: Dataset config with a ``pipeline`` field. Its
            ``pipeline`` is replaced with ``[]`` in place before the dataset
            is built, so indexing the dataset yields the un-transformed item.
        log_file_path: Path of the file that broken image paths are
            appended to, one per line.

    Raises:
        ValueError: If the pipeline contains no ``LoadImageFromFile``.
    """

    def __init__(self, dataset_cfg, log_file_path):
        super().__init__()
        # keep only LoadImageFromFile pipeline
        from mmpretrain.datasets import get_transform_idx

        load_idx = get_transform_idx(dataset_cfg.pipeline, 'LoadImageFromFile')
        # Explicit check instead of assert: under ``python -O`` the assert
        # would vanish and a negative index would silently pick the wrong
        # transform.
        if load_idx < 0:
            raise ValueError(
                'This tool is only for datasets needs to load image from '
                'files.')
        self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[load_idx])
        dataset_cfg.pipeline = []
        dataset = build_dataset(dataset_cfg)

        self.dataset = dataset
        self.log_file_path = log_file_path

    def valid_idx(self, idx):
        """Try to load sample ``idx`` and log its path if loading fails.

        Args:
            idx (int): Index of the sample in ``self.dataset``.
        """
        item = self.dataset[idx]
        try:
            self.pipeline(item)
        except Exception:
            # Deliberate best-effort check: any loading failure marks the
            # file as broken and validation continues with the next sample.
            filepath = str(Path(item['img_path']))
            # Hold the lock across open/write/close: the buffered line is
            # only flushed when the file is closed, so releasing the lock
            # before closing would leave the actual write unprotected.
            with file_lock:
                with open(self.log_file_path, 'a') as f:
                    f.write(filepath + '\n')
            print(f'{filepath} cannot be read correctly, please check it.')

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)


def print_info(log_file_path):
    """Print a summary of the broken-file log and remove it if empty.

    Args:
        log_file_path (str | os.PathLike): Path of the log written by
            ``DatasetValidator``. The file is deleted when, after stripping
            surrounding whitespace, it has no content.
    """
    print()
    with open(log_file_path, 'r') as f:
        content = f.read().strip()
    # Decide and (possibly) delete only after the file is closed: removing
    # a file that is still open fails on Windows.
    if content == '':
        print('There is no broken file found.')
        os.remove(log_file_path)
    else:
        num_file = len(content.split('\n'))
        print(f'{num_file} broken files found, name list save in file:'
              f'{log_file_path}')
    print()


def main():
    """Verify that every image of the selected dataset split can be loaded.

    Reads the config (plus ``--cfg-options`` overrides), recreates the
    ``--out-path`` log file, runs ``DatasetValidator`` over all sample
    indices (in parallel when ``--num-process`` > 1) and prints a summary.

    Raises:
        FileNotFoundError: If the parent directory of ``--out-path`` does
            not exist.
    """
    # parse cfg and args
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # touch output file to save broken files list.
    output_path = Path(args.out_path)
    if not output_path.parent.exists():
        # FileNotFoundError subclasses Exception, so existing handlers that
        # caught the former bare Exception still work.
        raise FileNotFoundError(
            "Path '--out-path' parent directory not found.")
    if output_path.exists():
        os.remove(output_path)
    output_path.touch()

    if args.phase == 'train':
        dataset_cfg = cfg.train_dataloader.dataset
    elif args.phase == 'val':
        dataset_cfg = cfg.val_dataloader.dataset
    elif args.phase == 'test':
        dataset_cfg = cfg.test_dataloader.dataset
    else:
        raise ValueError("'--phase' only support 'train', 'val' and 'test'.")

    # do validate
    validator = DatasetValidator(dataset_cfg, output_path)
    num_samples = len(validator)
    tasks = list(range(num_samples))

    if args.num_process > 1:
        # The default chunksize calculation method of Pool.map, clamped to
        # at least 1: an empty dataset would otherwise yield 0, which the
        # pool's imap/imap_unordered rejects with a ValueError.
        chunksize, extra = divmod(num_samples, args.num_process * 8)
        if extra:
            chunksize += 1
        chunksize = max(chunksize, 1)

        track_parallel_progress(
            validator.valid_idx,
            tasks,
            args.num_process,
            chunksize=chunksize,
            keep_order=False)
    else:
        track_progress(validator.valid_idx, tasks)

    print_info(output_path)


if __name__ == '__main__':
    # Run the verification only when executed as a script, not on import.
    main()
