# -*- coding: utf-8 -*-
import argparse
import os


def get_args():
    """Parse command-line arguments for the checkpoint cleanup script.

    Returns:
        argparse.Namespace with:
          * ``data_dir`` -- root directory expected to contain one
            sub-directory per rank, each holding ``checkpoint_*`` files.
          * ``keep`` -- comma-separated name patterns of checkpoints to keep.

    Exits with an argparse usage error (SystemExit) if ``--data_dir`` is
    missing. ``required=True`` is used instead of an ``assert`` because
    asserts are stripped when Python runs with ``-O``.
    """
    parser = argparse.ArgumentParser(
        description='remove checkpoint files except the ones to keep.')

    # define arguments.
    parser.add_argument(
        '--data_dir', required=True,
        help='root directory containing one sub-directory per rank.')
    parser.add_argument(
        '--keep', default='', type=str,
        help='comma-separated patterns; a checkpoint whose file name contains '
             '"<pattern>." is kept (e.g. --keep best,last).')

    # parse args (argparse itself reports a missing --data_dir).
    return parser.parse_args()


def main(args):
    """Delete checkpoint files under ``args.data_dir`` except the ones to keep.

    Each sub-directory of ``args.data_dir`` is scanned. Within it, every
    regular file whose name contains ``checkpoint_`` is a deletion candidate.
    A candidate is kept when its *file name* contains ``<pattern>.`` for at
    least one comma-separated pattern in ``args.keep``. For example,
    ``--keep best,last`` keeps ``checkpoint_best.pt`` and
    ``checkpoint_last.pt``. All other candidates are removed.

    Blank patterns are ignored. If no non-empty pattern is given, nothing is
    deleted: without patterns there is nothing to keep, and silently wiping
    every checkpoint is too destructive for a default.

    Args:
        args: namespace with ``data_dir`` (str) and ``keep`` (str)
            attributes, as produced by ``get_args``.
    """
    root = args.data_dir
    # Strip so that "best, last" works; drop empties produced by "" or "a,,b".
    keep_patterns = [p.strip() for p in args.keep.split(',') if p.strip()]
    if not keep_patterns:
        print('no --keep patterns given; nothing removed.')
        return

    # Collect everything first, then delete, so listing is not affected by
    # removals in progress.
    to_remove = [
        path for path in _collect_checkpoint_files(root)
        if not _should_keep(os.path.basename(path), keep_patterns)
    ]

    # remove files.
    for path in to_remove:
        print('remove file from path: {}'.format(path))
        os.remove(path)


def _collect_checkpoint_files(root):
    """Return sorted paths of ``checkpoint_*`` files inside each sub-dir of root."""
    files = []
    for rank in sorted(os.listdir(root)):
        rank_path = os.path.join(root, rank)
        # Skip stray files (logs, configs) sitting directly in root.
        if not os.path.isdir(rank_path):
            continue
        for name in sorted(os.listdir(rank_path)):
            path = os.path.join(rank_path, name)
            # Only regular files; os.remove cannot delete directories.
            if 'checkpoint_' in name and os.path.isfile(path):
                files.append(path)
    return files


def _should_keep(name, patterns):
    """Return True if file name contains ``<pattern>.`` for any of patterns."""
    return any(pattern + '.' in name for pattern in patterns)


if __name__ == '__main__':
    # Script entry point: parse the command line and run the cleanup.
    main(get_args())
