import torch
from pathlib import Path
import argparse
import torchvision.transforms as T
from torchvision.io import read_image
from torchvision.utils import save_image

def get_mean_abs_diff(img_gt, img_inp):
    """Return the absolute difference averaged over dim 0, min-max scaled.

    The element-wise absolute difference of the two tensors is averaged over
    its first dimension (presumably the channel axis of a (C, H, W) image --
    confirm against callers), then shifted and scaled so that its values span
    [0, 1].

    Args:
        img_gt: Floating-point tensor holding the reference image.
        img_inp: Floating-point tensor compatible with ``img_gt`` under
            subtraction.

    Returns:
        A new tensor with dim 0 reduced away. When the averaged difference is
        constant (for example, identical inputs) the result is all zeros.
    """
    diff = (img_gt - img_inp).abs().mean(0)
    diff = diff - diff.min()
    peak = diff.max()
    # A constant difference map has zero range; dividing by it would fill the
    # output with NaN, so leave the (all-zero) map unscaled in that case.
    if peak > 0:
        diff = diff / peak
    return diff

def main(args):
    """Save a difference map for every ground-truth / inpainted image pair.

    Files are listed from ``<log_dir>/gt`` and ``<log_dir>/inpainted``, sorted,
    and paired by position. Each pair is passed to ``get_mean_abs_diff`` and
    the result is saved in ``<log_dir>/mean_abs_diff`` under the inpainted
    file's name.

    Args:
        args: Namespace whose ``log_dir`` attribute is a ``pathlib.Path``.
    """
    import warnings

    path_log_dir = args.log_dir
    path_out = path_log_dir / 'mean_abs_diff'
    path_out.mkdir(exist_ok=True)

    # Keep regular files only: a sub-directory is not an image to be read.
    paths_gt = sorted(p for p in (path_log_dir / 'gt').iterdir() if p.is_file())
    paths_inp = sorted(
        p for p in (path_log_dir / 'inpainted').iterdir() if p.is_file()
    )

    # zip() stops at the shorter list, so surface a mismatch instead of
    # dropping the unpaired files without notice.
    if len(paths_gt) != len(paths_inp):
        warnings.warn(
            f'Found {len(paths_gt)} ground-truth files but {len(paths_inp)} '
            'inpainted files; unpaired files are skipped and the pairing '
            'may be misaligned.'
        )

    for path_gt, path_inp in zip(paths_gt, paths_inp):
        # Dividing by 255 turns the loaded values into floats; this assumes
        # 8-bit images -- confirm for other inputs.
        img_gt = read_image(str(path_gt)) / 255
        img_inp = read_image(str(path_inp)) / 255
        diff = get_mean_abs_diff(img_gt, img_inp)
        path_out_spec = path_out / path_inp.name
        save_image(diff, str(path_out_spec))

if __name__ == '__main__':
    # Command-line entry point: parse the log directory and run main().
    parser = argparse.ArgumentParser()
    # Required: main() builds paths from log_dir, so a missing value would
    # otherwise fail later with a TypeError instead of a usage message.
    parser.add_argument(
        '--log-dir',
        type=Path,
        required=True,
        help='Log directory containing "gt" and "inpainted" image '
             'sub-directories; difference maps are written to its '
             '"mean_abs_diff" sub-directory',
    )
    args = parser.parse_args()
    main(args)