import torch
from pathlib import Path
import argparse
import torchvision.transforms as T
from torchvision.io import read_image
from torchvision.utils import save_image

def main(args):
    """Soften masks with a Gaussian blur and save them to a sibling directory.

    Every regular, non-hidden file in ``args.masks_dir`` is read as an image,
    scaled from [0, 255] to [0, 1], blurred with a Gaussian filter, and written
    under the same file name into ``<masks_dir>_soft`` next to the input
    directory. Pixels that were exactly 1.0 in the original mask are restored
    to 1.0 after blurring, so the original "keep" region is preserved and only
    its surroundings are feathered.

    All masks are stacked into one batch, so they must share the same channel
    count and spatial size.

    Args:
        args: Parsed CLI namespace with ``masks_dir`` (Path), ``kernel_size``
            (int; torchvision requires it odd and positive) and ``sigma``
            (float; torchvision requires it positive).

    Raises:
        FileNotFoundError: if ``masks_dir`` contains no mask files.
    """
    # absolute() so that e.g. '.' still has a real name/parent; otherwise the
    # output dir would become '_soft' *inside* the input directory.
    path_masks_dir = args.masks_dir.absolute()

    # Collect mask files. Subdirectories and hidden files (e.g. .DS_Store) are
    # skipped because read_image cannot decode them; sorting makes the
    # processing order deterministic instead of filesystem-dependent.
    paths_masks = sorted(
        p for p in path_masks_dir.iterdir()
        if p.is_file() and not p.name.startswith('.')
    )
    if not paths_masks:
        # torch.stack([]) would otherwise fail with an opaque RuntimeError.
        raise FileNotFoundError(f'No mask files found in {path_masks_dir}')

    # Read masks as float tensors in [0, 1] and batch them.
    masks = torch.stack([read_image(str(p)) / 255 for p in paths_masks])

    # Blur the whole batch at once. kernel_size=91, sigma=90. were the
    # settings used for 'face_soft' masks.
    gaussian_filter = T.GaussianBlur(args.kernel_size, sigma=args.sigma)
    masks_soft = gaussian_filter(masks)

    # Blurring also lowers values inside the keep region near its edges;
    # restore every pixel that was fully on so only the outside is feathered.
    masks_soft[masks == 1.] = 1.

    # Write results to a sibling directory named '<masks_dir>_soft'.
    path_out = path_masks_dir.parent / f'{path_masks_dir.name}_soft'
    path_out.mkdir(parents=True, exist_ok=True)

    for path_orig, soft_mask in zip(paths_masks, masks_soft):
        save_image(soft_mask, str(path_out / path_orig.name))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Soften masks with a Gaussian blur; results are written '
                    'to a sibling directory named <masks_dir>_soft.')
    # required=True: a missing --masks-dir used to surface later as an
    # AttributeError on None instead of a usage message.
    parser.add_argument('--masks-dir', type=Path, required=True,
                        help='Directory with masks to soften')
    # Defaults are the 'face_soft' settings (kernel 91, sigma 90.); previously
    # an omitted flag passed None into GaussianBlur and crashed.
    parser.add_argument('--kernel-size', type=int, default=91,
                        help='Kernel size for gaussian filter '
                             '(odd, positive; default: %(default)s)')
    parser.add_argument('--sigma', type=float, default=90.,
                        help='Sigma for gaussian filter '
                             '(positive; default: %(default)s)')
    args = parser.parse_args()
    main(args)