from __future__ import print_function
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.parallel
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os, glob
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
from GAN import Attention_GAN
from torchvision import datasets, utils, transforms
from GAN.dataset_wavelet import wavelet_transform, wavelet_inverse
import argparse
from GAN.args import add_dict_to_argparser
from GAN.image_datasets import load_data, load_pair_data

from model.archs.NAFNet_arch import NAFNet
import matplotlib.pyplot as plt


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Args:
        seed: Integer seed handed to Python's ``random`` module, NumPy's
            global generator, and PyTorch (``torch.manual_seed`` plus
            ``torch.cuda.manual_seed_all``).
    """
    # The four generators are independent, so the call order is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)


def min_max_norm(img, vmin, vmax):
    """Clip ``img`` to ``[vmin, vmax]`` and rescale it linearly to ``[0, 1]``.

    Args:
        img: ``numpy.ndarray`` or ``torch.Tensor`` to normalise. Any other
            type is not clipped; only the linear rescale is applied to it.
        vmin: Lower bound. For tensor input this may be a Python number or a
            tensor (tensor bounds are moved to ``img``'s device).
        vmax: Upper bound, same accepted forms as ``vmin``.

    Returns:
        ``(clipped_img - vmin) / (vmax - vmin)``, with the same container
        type as ``img``.
    """
    if isinstance(img, np.ndarray):
        img = np.clip(img, vmin, vmax)
    elif isinstance(img, torch.Tensor):
        both_numbers = (
            isinstance(vmin, (int, float)) and isinstance(vmax, (int, float))
        )
        if not both_numbers:
            # torch.clamp cannot mix a tensor bound with a number bound, so
            # promote both to tensors living on img's device.  For bounds that
            # already are tensors this is the same as ``bound.to(img.device)``.
            vmin = torch.as_tensor(vmin, device=img.device)
            vmax = torch.as_tensor(vmax, device=img.device)
        # Plain numbers are passed straight through; previously they crashed
        # on ``vmin.to(...)`` although the NumPy branch accepted them.
        img = torch.clamp(img, vmin, vmax)
    return (img - vmin) / (vmax - vmin)


def load_npy_data(data_dir):
    """Load the first array of the first ``.npz`` archive found in ``data_dir``.

    The array is cast to float32, multiplied by 4 and clipped to ``[-4, 4]``,
    then its axes are reordered with ``permute(0, 3, 1, 2)``
    (presumably NHWC -> NCHW; confirm against the producer of the archive).

    Args:
        data_dir: Directory that contains at least one ``*.npz`` file.

    Returns:
        A float32 ``torch.Tensor`` with the permuted axes.

    Raises:
        FileNotFoundError: If ``data_dir`` contains no ``*.npz`` file.
    """
    # Sort so the chosen archive does not depend on filesystem listing order.
    candidates = sorted(glob.glob(os.path.join(data_dir, '*.npz')))
    if not candidates:
        raise FileNotFoundError("no .npz file found in %r" % (data_dir,))

    # NpzFile keeps the archive open; the context manager closes it once the
    # array has been copied into memory by ``astype``.
    with np.load(candidates[0]) as archive:
        inp = archive[archive.files[0]].astype('float32')

    inp = np.clip(4*inp, -4, 4)
    # Alternative scalings, currently disabled:
    # inp = 4 * (inp - 127.5) / 127.5
    # inp = (inp - inp.mean()) / inp.std()
    return torch.from_numpy(inp).permute(0,3,1,2)


def create_argparser():
    """Build the command-line parser for this test script.

    A dictionary of default settings is assembled and handed to
    ``add_dict_to_argparser`` together with a fresh ``ArgumentParser``
    (presumably one option per key; see ``GAN.args`` to confirm).

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    # Model / image geometry settings.
    defaults = {
        "batch_size": 5,
        "image_size": 256,
        "large_size": 256,
        "small_size": 64,
        "depth": 2,
        "wavelet_channel": 3,
        "n_channels": 64,
        "n_blocks": 5,
        "input_ch_num": 6,
        "output_ch_num": 9,
        "netD_ch_num": 3,
        "checkpoint": 3,
        "class_cond": False,
    }

    # Data loading settings.
    defaults.update({
        "n_threads": 8,
        "crop_size": 256,
        "max_epoch": 100,
        "epoch_len": 500,
        "max_epochs": 100,
        "data_queue_len": 10000,
        "patch_per_tile": 10,
        "color_space": "RGB",
    })

    # Paths and checkpoint selection; the empty strings must be supplied by
    # the user (edit here or pass on the command line).
    defaults.update({
        "HR_data_dir": "",   # directory of high-quality images
        "LR_data_dir": "",   # directory of low-quality images
        "data_dir": "",      # directory of input images generated by the coarse-scale BBDP
        "model_path": "",    # directory holding the model checkpoints to test
        "save_path": "output",
        "test_epoch": 150,   # epoch of the saved checkpoint to test
    })

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, defaults)
    return arg_parser


def _merge_subbands(low, high, c, size, device):
    """Tile ``low`` and three channel groups of ``high`` 2x2 and invert them.

    ``high`` is split along dim 1 at ``c // 3`` and ``2 * c // 3``. The layout
    passed to ``wavelet_inverse`` is ``[[low, g0], [g1, g2]]`` (columns joined
    on dim 3, rows on dim 2).
    """
    top = torch.cat((low, high[:, :c//3, :, :]), dim=3)
    bottom = torch.cat(
        (high[:, c//3:2*c//3, :, :], high[:, 2*c//3:, :, :]), dim=3
    )
    return wavelet_inverse(torch.cat((top, bottom), dim=2), size, device)


def _generate(netG, low, high, c, size, device):
    """Run one generator stage on the merged sub-bands plus Gaussian noise."""
    merged = _merge_subbands(low, high, c, size, device)
    # randn_like already allocates on merged's device.
    noise = torch.randn_like(merged)
    return netG(torch.cat((merged, noise), dim=1))


def _save_image(tensor, folder, name, vmin, vmax):
    """Normalise ``tensor`` from [vmin, vmax] to [0, 1] and write it as a PNG."""
    normed = min_max_norm(tensor, vmin, vmax)
    plt.imsave(
        "%s/%s.png" % (folder, name),
        normed.cpu().numpy().transpose(1,2,0),
        cmap='hot',
    )


def main():
    """Run the trained generator over the paired dataset and save PNGs.

    For every sample three images are written, all named by the running
    sample index: the low-quality input (``input/``), the generator output
    (``<save_path>/``) and the target (``target/``).
    """
    setup_seed(0)
    cont_config = create_argparser().parse_args()

    os.makedirs(cont_config.save_path, exist_ok=True)
    os.makedirs('input', exist_ok=True)
    os.makedirs('target', exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = NAFNet(in_channel=6,
        out_channel=3,
        width=64,
        enc_blk_nums=[1, 1, 1, 28],
        middle_blk_num=1,
        dec_blk_nums=[1, 1, 1, 1]
    )

    checkpoint_file = os.path.join(
        cont_config.model_path, 'netG_epoch_%d.pth' % cont_config.test_epoch
    )
    # map_location lets a checkpoint saved on a GPU load on the CPU fallback
    # selected above (and on a differently numbered GPU).
    netG_dict = torch.load(checkpoint_file, map_location=device)
    netG.load_state_dict(netG_dict)
    netG.to(device)

    ll2_all = load_npy_data(
        data_dir=cont_config.data_dir,
    )

    data_loader = load_pair_data(
        input_dir=cont_config.LR_data_dir,
        target_dir=cont_config.HR_data_dir,
        batch_size=cont_config.batch_size,
        image_size=cont_config.image_size,
        class_cond=cont_config.class_cond,
        deterministic=True
    )

    netG.eval()

    # Value range of the images before normalisation; built once, not per image.
    vmin, vmax = torch.Tensor([-1]), torch.Tensor([1])
    total = ll2_all.shape[0]
    j = 0  # running index of the first sample of the current batch
    with torch.no_grad():
        for batch in tqdm(data_loader):
            # batch[0] is used as the low-quality input, batch[1] as the target.
            # Never consume more samples than there are coarse inputs: this
            # stops cleanly if the loader yields more data than ll2_all holds
            # (e.g. a loader that cycles forever - TODO confirm) instead of
            # failing later with a shape mismatch inside torch.cat.
            n = min(batch[0].shape[0], total - j)
            if n <= 0:
                break

            ll2 = ll2_all[j:j + n].to(device)

            # Two successive wavelet transforms of the low-quality input; each
            # call returns a pair whose first element is fed to the next level.
            small_batch = batch[0][:n].to(device)
            small_batch, small_H1 = wavelet_transform(small_batch, device)
            small_batch, small_H2 = wavelet_transform(small_batch, device)

            # The channel count of the second-level bands drives both splits,
            # as in the original code.
            c = small_H2.shape[1]

            # Stage 1: externally supplied ll2 + second-level bands -> ll1.
            ll1 = _generate(
                netG, ll2, small_H2, c, cont_config.image_size // 2, device
            )
            # Stage 2: ll1 + first-level bands -> final image.
            img = _generate(
                netG, ll1, small_H1, c, cont_config.image_size, device
            )

            for i in range(n):
                name = '%s' % (j + i)
                _save_image(batch[0][i], 'input', name, vmin, vmax)
                _save_image(img[i], cont_config.save_path, name, vmin, vmax)
                _save_image(batch[1][i], 'target', name, vmin, vmax)
            j += n

    print('Processed %d samples; outputs written to %s' % (j, cont_config.save_path))


if __name__ == "__main__":
    main()