"""
Preprocesses images so that they all have the same square resolution
(center crop to a square, then resize).
"""

import os

import torchvision
from PIL import Image
from tqdm import tqdm

from hyperparams.load import get_config

# Project configuration, loaded once when this module is imported;
# main() reads the image root directory from config.dirs.
config = get_config()


def main(size=64):
    """Center-crop every source image to a square and resize it.

    Walks the ``jpg`` directory below ``config.dirs['flowers_images']``
    (including sub-directories) and writes each converted image, under its
    original file name, into the sibling directory
    ``transformed_jpg_<size>``. The output directory is flat, so files with
    the same name in different source sub-directories overwrite each other.

    Args:
        size: Edge length in pixels of the square output images.

    Raises:
        PIL.UnidentifiedImageError: If a file below the source directory
            cannot be read as an image.
    """
    resize = torchvision.transforms.Resize((size, size))
    data_root = config.dirs['flowers_images']
    src = os.path.join(data_root, 'jpg')
    dst = os.path.join(data_root, f'transformed_jpg_{size}')

    # Image.save() does not create missing parent directories, so make sure
    # the output directory exists before the first image is written.
    os.makedirs(dst, exist_ok=True)

    for root, _dirs, files in os.walk(src):
        for f in tqdm(files):
            # Image.open() keeps the file handle open; the context manager
            # releases it as soon as the pixel data has been converted
            # (convert() always returns an independent, fully loaded image).
            with Image.open(os.path.join(root, f)) as raw:
                im = raw.convert('RGB')
            w, h = im.size

            # Make the picture square first so resizing does not distort it.
            crop = torchvision.transforms.CenterCrop(min(w, h))
            img_new = resize(crop(im))

            img_new.save(os.path.join(dst, f))


if __name__ == '__main__':
    # When executed as a script, preprocess the images to 64x64 pixels.
    main(size=64)
