import os, argparse, glob, tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.transforms import ToTensor, Normalize, Compose
from PIL import Image

from model import FocalNetDepth


def pad_img(x, patch_size):
    """Reflect-pad a batched image tensor so its spatial size is a multiple of *patch_size*.

    Padding is added only on the bottom and right edges, so the original
    content stays anchored at the top-left and can be recovered later with
    ``out[:, :, :H, :W]``.

    Args:
        x: 4-D tensor; the last two dimensions are treated as height and width
            (the four-way unpack below raises for any other rank).
        patch_size: the multiple both spatial dimensions are rounded up to.

    Returns:
        The padded tensor (same size as ``x`` when already aligned).
    """
    _, _, height, width = x.shape
    # `-n % k` is the distance from n up to the next multiple of k (0 when aligned).
    pad_bottom = -height % patch_size
    pad_right = -width % patch_size
    # F.pad lists pads from the last dim backwards: (left, right, top, bottom).
    return F.pad(x, (0, pad_right, 0, pad_bottom), mode='reflect')


def norm_zero_to_one(x):
    """Min-max normalize *x* into the range [0, 1].

    Works for any array-like object supporting ``min()``, ``max()`` and
    elementwise arithmetic (e.g. torch tensors or numpy arrays).

    Args:
        x: values to normalize.

    Returns:
        ``(x - min) / (max - min)``. When every value is equal (zero range),
        an all-zero result of the same shape is returned instead of the
        0/0 = NaN the plain formula would produce.
    """
    lo = x.min()
    value_range = x.max() - lo
    shifted = x - lo
    if value_range == 0:
        # Constant input: avoid 0/0 and map every element to 0.
        return shifted * 0.0
    return shifted / value_range


if __name__ == "__main__":
    # Inference entry point: dehaze every matching input image with
    # FocalNetDepth and write the results as PNGs into --output.

    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='', type=str,
                        help='input directory, or a glob pattern of image files')
    parser.add_argument('--output', default='results', type=str,
                        help='output directory for the dehazed images')
    parser.add_argument('--format', default='png', type=str,
                        help='extension of the input images when --input is a directory')
    parser.add_argument('--ckpt', default='DiffAD-FT.pth', type=str,
                        help='path to the model checkpoint')
    args = parser.parse_args()

    dehaze_net = FocalNetDepth()
    dehaze_net = dehaze_net.cuda()
    pytorch_total_params = sum(p.numel() for p in dehaze_net.parameters() if p.requires_grad)
    print("Total_params: ==> {}".format(pytorch_total_params))
    # Inference only: no parameter needs gradients.
    dehaze_net.requires_grad_(False)

    dehaze_net = torch.nn.DataParallel(dehaze_net)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    ckp = torch.load(args.ckpt)
    # strict=False tolerates key mismatches, but silently skipped weights would
    # leave the model (partly) randomly initialized, so report them.
    incompatible = dehaze_net.load_state_dict(ckp['state_dict'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print("Warning: loading {} left {} missing and {} unexpected keys".format(
            args.ckpt, len(incompatible.missing_keys), len(incompatible.unexpected_keys)))
    # Switch normalization/dropout layers to inference behavior.
    dehaze_net.eval()

    transform = Compose([
        ToTensor()
    ])

    if os.path.isdir(args.input):
        hazy_list = glob.glob(os.path.join(args.input, '*.{}'.format(args.format)))
    else:
        hazy_list = glob.glob(args.input)
    hazy_list.sort()
    if not hazy_list:
        print("No input images matched --input {!r}".format(args.input))

    # makedirs also handles nested paths and an already-existing directory.
    os.makedirs(args.output, exist_ok=True)

    for img_path in tqdm.tqdm(hazy_list):
        # Portable basename without extension (works for any input extension).
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        with Image.open(img_path) as pil_img:
            img = pil_img.convert('RGB')

        img = transform(img).unsqueeze(dim=0).cuda()
        _, _, H, W = img.shape

        with torch.no_grad():
            # Pad to a multiple of 16 for the network, then crop back to H x W.
            out = dehaze_net(pad_img(img, 16))
            out = out.clamp(0, 1)
            out = out[:, :, :H, :W]

        save_image(out, os.path.join(args.output, '{}.png'.format(img_name)))