import os
import shutil
from pathlib import Path

from tqdm import tqdm

from data.flowers.main_raw import load_flowers_data
from hyperparams.load import get_config

# Module-level configuration, evaluated once when this module is imported;
# main() reads config.dirs['data'] as the root of its output directory.
config = get_config()


def main(split):
    """Copy every image file of one Flowers split into a flat output directory.

    The destination is ``<config.dirs['data']>/flowers_images/jpg_64_fid/<split>``
    and is created (with parents) if it does not exist. Each file keeps its
    base name; ``shutil.copy`` overwrites any same-named file already there.

    Args:
        split: Split name. It is forwarded as ``mode`` to ``load_flowers_data``
            and used as the last component of the destination directory.
    """
    target_dir = Path(config.dirs['data']) / 'flowers_images/jpg_64_fid' / split
    target_dir.mkdir(parents=True, exist_ok=True)

    # Only the first return value is used; the second is discarded.
    # NOTE(review): batch_size=64 appears irrelevant here since only file
    # paths are read — confirm against load_flowers_data.
    dataset, _unused = load_flowers_data(mode=split, batch_size=64)

    # Assumed from usage: dataset.x[0] is a sequence of image file paths.
    image_paths = dataset.x[0]
    for image_path in tqdm(image_paths, desc='Iterating over images'):
        shutil.copy(image_path, target_dir / Path(image_path).name)
    print('Finished copying')


if __name__ == '__main__':
    # Script entry point: export the images of the 'train' split only.
    main('train')
