from argparse import ArgumentParser
import os
from shutil import copyfile, rmtree
from pathlib import Path


def _select_best_checkpoint(directory, min_files):
    """Return the file name to keep from *directory*, or ``None`` to skip it.

    ``None`` means *directory* has ``min_files`` entries or fewer and is
    treated as not worth copying from. Otherwise the lexicographically last
    entry whose name ends with ``'best.pth'`` is returned.

    Raises:
        FileNotFoundError: if *directory* has more than ``min_files`` entries
            but none of them ends with ``'best.pth'``.
    """
    files = sorted(os.listdir(directory))
    if len(files) <= min_files:
        return None
    best = [f for f in files if f.endswith('best.pth')]
    if not best:
        raise FileNotFoundError(f"no '*best.pth' checkpoint found in {directory}")
    # NOTE(review): this is a plain string sort, so e.g. 'epoch_9_best.pth'
    # sorts after 'epoch_10_best.pth' -- confirm the naming scheme makes the
    # last match the intended checkpoint.
    return best[-1]


def extract_ensemble_checkpoint(folder, clean, *, min_files=10):
    """Copy one ``*best.pth`` checkpoint per sub-directory into ``folder``.

    Every immediate sub-directory of ``folder`` is visited in reverse name
    order. A sub-directory holding more than ``min_files`` entries has its
    selected best checkpoint copied to ``folder / f'model_{i}.pth'``, where
    ``i`` counts up from 0 over the copied sub-directories. Sub-directories
    with ``min_files`` entries or fewer are not copied from.

    All sub-directories are inspected before anything is copied or deleted.
    A malformed sub-directory therefore aborts the run without leaving
    ``folder`` half-cleaned.

    Args:
        folder: directory (``str`` or path-like) whose sub-directories hold
            the per-member checkpoints.
        clean: if true, delete every visited sub-directory after processing
            it. This includes sub-directories that were skipped for having
            too few entries.
        min_files: a sub-directory is copied from only when it has strictly
            more than this many entries. Defaults to 10.

    Raises:
        FileNotFoundError: if a sub-directory with more than ``min_files``
            entries contains no file ending in ``'best.pth'``. In that case
            nothing is copied or deleted.
    """
    folder = Path(folder)  # accept str as well as Path for the `/` joins below
    # iterdir() yields entries in arbitrary, filesystem-dependent order; sort
    # so the model_{i} numbering is reproducible across runs and machines.
    dirs = sorted((d for d in folder.iterdir() if d.is_dir()), reverse=True)

    # Plan first (may raise) so no directory is deleted before validation.
    plan = [(d, _select_best_checkpoint(d, min_files)) for d in dirs]

    i = 0
    for directory, best in plan:
        print(directory)
        if best is not None:
            copyfile(directory / best, folder / f'model_{i}.pth')
            i += 1
        if clean:
            rmtree(directory)


if __name__ == '__main__':
    # Command-line entry point: locate one experiment's checkpoint folder
    # from the OOD split name, network and data seed, then collect its
    # per-member checkpoints.
    cli = ArgumentParser()
    cli.add_argument('--ood-name', type=str, default='vehicles')
    cli.add_argument('--net', type=str, default='resnet50')
    cli.add_argument('--data-seed', type=int, default=42)
    cli.add_argument(
        '--clean',
        action='store_true',
        default=False,
        help='delete source directory or not',
    )
    args = cli.parse_args()

    # Resolved relative to this file's grandparent directory.
    base_dir = Path(__file__).parent.parent
    ckpt_dir = base_dir / f'../experiments/checkpoint/{args.net}/{args.ood_name}_{args.data_seed}'
    extract_ensemble_checkpoint(ckpt_dir, args.clean)
