import torch
import argparse

# Built-in locations; each one is used as the default of the matching
# command-line flag below and is then overwritten by the parsed value.
path_to_sketch_gan = "models/sketch_gan.t7"
data_path = "datasets/CelebAMask-HQ/CelebA-HQ-img/"
output_dir = "datasets/CelebAMask-HQ/CelebA-HQ-sketch/"

parser = argparse.ArgumentParser(description='Create sketchs from images.')
parser.add_argument(
    '--model_path',
    default=path_to_sketch_gan,
    type=str,
    help='path to sketch gan',
)
parser.add_argument(
    '--data_path',
    default=data_path,
    type=str,
    help='path to images',
)
parser.add_argument(
    '--output_dir',
    default=output_dir,
    type=str,
    help='path to output',
)

# Parse once, then rebind the three module-level names to whatever the user
# supplied (or to the defaults above when a flag was omitted).
args = parser.parse_args()
path_to_sketch_gan, data_path, output_dir = (
    args.model_path,
    args.data_path,
    args.output_dir,
)


from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.serialization import load_lua
import os
import cv2
import numpy as np

"""
NOTE: Requires torch==0.4.1 and torchvision==0.2.1, because this script loads the
model with torch.utils.serialization.load_lua, which later PyTorch releases removed.
The sketch simplification model (sketch_gan.t7) by Simo-Serra et al. can be downloaded from their official implementation:
    https://github.com/bobbens/sketch_simplification

Environment setup (Python 3.7):
conda install pytorch=0.4.1 torchvision=0.2.1 cuda90 -c pytorch
conda install -c conda-forge opencv
pip install "pillow<7"
"""


def sobel(img):
    """Combine the two first-order 3x3 Sobel responses of ``img``.

    One response is taken with derivative orders (dx=0, dy=1) and the other
    with (dx=1, dy=0); both are requested at ``cv2.CV_8U`` output depth and
    merged with a per-pixel bitwise OR.

    NOTE(review): at 8-bit depth negative responses presumably clip to 0,
    which would explain why ``sketch`` also runs this on the inverted image.
    Confirm against the OpenCV documentation.
    """
    # (dx, dy) pairs, in the same order the two responses are computed.
    derivative_orders = ((0, 1), (1, 0))
    first, second = [
        cv2.Sobel(img, cv2.CV_8U, dx, dy, ksize=3)
        for dx, dy in derivative_orders
    ]
    return cv2.bitwise_or(first, second)


def sketch(frame):
    """Turn an image into a dark-lines-on-light edge drawing.

    The input is blurred with a 3x3 Gaussian, edges are extracted from both
    the blurred image and its ``255 - x`` inverse, the two edge maps are
    blended with weight 0.75 each, and the blend is inverted again.

    ``frame`` is assumed to be an 8-bit image, since the inversion is done
    with ``255 - x``. TODO confirm against callers.
    """
    blurred = cv2.GaussianBlur(frame, (3, 3), 0)

    edges_direct = sobel(blurred)
    edges_inverted = sobel(255 - blurred)

    blended = cv2.addWeighted(edges_direct, 0.75, edges_inverted, 0.75, 0)
    return 255 - blended


def get_sketch_image(image_path):
    """Load an image from disk and return its edge sketch with a channel axis.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.

    Returns:
        The result of ``sketch`` on the grayscale image, with a trailing
        axis of length 1 appended.

    Raises:
        IOError: If OpenCV cannot read ``image_path`` (missing file,
            directory, or unsupported/corrupt image data).
    """
    original = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque cvtColor assertion without the path.
    if original is None:
        raise IOError("Could not read image: {}".format(image_path))
    gray = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY)
    sketch_image = sketch(gray)
    return sketch_image[:, :, np.newaxis]


# Use the GPU only when one is actually present, so the CPU branch below is
# reachable on machines without CUDA instead of crashing in model.cuda().
use_cuda = torch.cuda.is_available()

# The .t7 archive bundles the network together with the mean/std used to
# normalise its input (see the normalisation in the loop below).
cache = load_lua(path_to_sketch_gan, long_size=8)
model = cache.model
immean = cache.mean
imstd = cache.std
model.evaluate()
if use_cuda:
    model = model.cuda()

# Only regular files are processed (a sub-directory would make cv2.imread
# fail); sorting makes the progress output and processing order reproducible.
images = sorted(
    path
    for path in (os.path.join(data_path, f) for f in os.listdir(data_path))
    if os.path.isfile(path)
)

# exist_ok avoids the check-then-create race of a separate os.path.exists().
os.makedirs(output_dir, exist_ok=True)

# Loop-invariant: build the converter once instead of once per image.
to_tensor = transforms.ToTensor()

for idx, image_path in enumerate(images):
    if idx % 50 == 0:
        print("{} out of {}".format(idx, len(images)))
    data = get_sketch_image(image_path)
    # Normalise with the statistics shipped in the model file and add a
    # leading batch axis.
    data = ((to_tensor(data) - immean) / imstd).unsqueeze(0)
    if use_cuda:
        data = data.cuda()
    pred = model.forward(data).float()
    # basename/splitext is separator-independent and strips only the final
    # extension, so "a.b.jpg" and "a.c.jpg" no longer both map to "a.jpg".
    stem = os.path.splitext(os.path.basename(image_path))[0]
    save_image(pred[0], os.path.join(output_dir, "{}.jpg".format(stem)))
