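'''
Attempt at harmonizing composite images with CDTNet from the libcom library.
For each image in --dir_path, the matching object mask is loaded and dilated,
and the harmonized result is written to --save_dir_path.

TL;DR: this approach did not work.
'''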

import os
import cv2
import argparse
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm

from libcom import ImageHarmonizationModel

# Command-line interface:
#   --dir_path       directory of composite images to harmonize (required)
#   --masks_dir      directory of per-image masks; derived from the scene name
#                    (grandparent directory of --dir_path) when omitted
#   --save_dir_path  output directory; defaults to a "harmonized_CDTNet"
#                    sibling of --dir_path
parser = argparse.ArgumentParser(
    description='Harmonize composite images with libcom CDTNet.')
parser.add_argument('--dir_path', type=str, required=True,
                    help='Directory containing the composite images.')
parser.add_argument('--masks_dir', type=str, default=None,
                    help='Directory containing mask_<idx>0001.png files. '
                         'Defaults to a scene-specific path derived from '
                         '--dir_path.')
parser.add_argument('--save_dir_path', type=str, default=None,
                    help='Output directory (default: '
                         '<parent of --dir_path>/harmonized_CDTNet).')

args = parser.parse_args()

# Fail fast with a usage message instead of a traceback later in the script.
if not os.path.isdir(args.dir_path):
    parser.error(f'--dir_path is not a directory: {args.dir_path}')

if args.save_dir_path is None:
    args.save_dir_path = str(Path(args.dir_path).parent / "harmonized_CDTNet")

if args.masks_dir is None:
    # Scene name = grandparent directory name of dir_path, up to the first '-'.
    # NOTE(review): the base path below is specific to one cluster account;
    # pass --masks_dir explicitly on any other machine.
    scene_name = Path(args.dir_path).parent.parent.name.split('-')[0]
    args.masks_dir = f'/scratch/izar/skorokho/blender_attmpt_2/{scene_name}/obj_scene'

# Every image needs a mask from this directory, so a missing directory would
# otherwise fail on the first image.
if not os.path.isdir(args.masks_dir):
    parser.error(f'--masks_dir is not a directory: {args.masks_dir}')

os.makedirs(args.save_dir_path, exist_ok=True)

# libcom harmonization model (CDTNet).
# NOTE(review): device=0 presumably selects the first CUDA GPU; confirm
# against libcom's ImageHarmonizationModel documentation.
harmonizer = ImageHarmonizationModel(device=0, model_type='CDTNet')

# 5x5 dilation kernel, identical for every image, so build it once.
dilate_kernel = np.ones((5, 5), np.uint8)

images_names = sorted(os.listdir(args.dir_path))
for im_name in tqdm(images_names):
    # The image index is what remains after dropping a 6-character prefix and
    # a 4-character extension. NOTE(review): assumes names like
    # 'image_<idx>.png'; confirm the naming scheme.
    try:
        image_idx = int(im_name[6:-4])
    except ValueError:
        tqdm.write(f'Skipping {im_name}: cannot parse image index from name')
        continue
    mask_path = os.path.join(args.masks_dir, f'mask_{image_idx}0001.png')

    # cv2.imread returns None (it does not raise) for a missing/unreadable file.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        tqdm.write(f'Skipping {im_name}: could not read mask {mask_path}')
        continue
    # An all-zero mask has no foreground region to harmonize.
    if mask.max() == 0:
        continue

    # Grow the mask by a few pixels, presumably so the harmonized region also
    # covers the object's boundary.
    mask = cv2.dilate(mask, dilate_kernel, iterations=1)

    harmonized = harmonizer(
        os.path.join(args.dir_path, im_name),
        mask
    )

    result_path = os.path.join(args.save_dir_path, im_name)
    # cv2.imwrite reports failure through its boolean return, not an exception.
    if not cv2.imwrite(result_path, harmonized):
        tqdm.write(f'Failed to write {result_path}')
