from __future__ import print_function
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.parallel
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import random
from GAN import Attention_GAN
from torchvision import datasets, utils, transforms
from GAN.dataset_wavelet import wavelet_transform, wavelet_inverse
import argparse
from GAN.args import add_dict_to_argparser
from GAN.image_datasets import load_data, load_pair_data

from model.archs.NAFNet_arch import NAFNet


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def create_argparser():
    defaults = dict(
        batch_size=20,
        image_size=256,
        large_size = 256,
        small_size = 64,
        depth=2,
        wavelet_channel=3,
        n_channels=64,
        n_blocks=5,
        input_ch_num=6,
        output_ch_num=9,
        netD_ch_num=3,
        checkpoint=3,
        class_cond=False,

        # data loading
        n_threads=8,
        crop_size=256,
        max_epoch=100,
        epoch_len=500,
        max_epochs=100,
        data_queue_len=10000,
        patch_per_tile=10,
        color_space="RGB",

        HR_data_dir="",  # ENTER YOUR SAMPLING TARGET IMAGE DIRECTORY HERE
        LR_data_dir="",  # ENTER YOUR SAMPLING INPUT IMAGE DIRECTORY HERE
        model_path="",  # ENTER THE MODEL PATH TO TEST
        save_path="output",
        test_epoch=150,  # ENTER THE EPOCH OF SAVED CHECKPOINT TO TEST
    )
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":

    setup_seed(0)
    cont_config = create_argparser().parse_args()

    if not os.path.exists(cont_config.save_path):
        os.makedirs(cont_config.save_path)
    os.makedirs('input', exist_ok=True)
    os.makedirs('target', exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = NAFNet(in_channel=6,
        out_channel=3,
        width=64,
        enc_blk_nums=[1, 1, 1, 28],
        middle_blk_num=1,
        dec_blk_nums=[1, 1, 1, 1]
    )

    netG_dict = torch.load(os.path.join(cont_config.model_path, 'netG_epoch_%d.pth'%cont_config.test_epoch))
    netG.load_state_dict(netG_dict)
    netG.to(device)

    data_loader = load_pair_data(
        input_dir=cont_config.LR_data_dir,
        target_dir=cont_config.HR_data_dir,
        batch_size=cont_config.batch_size,
        image_size=cont_config.image_size,
        class_cond=cont_config.class_cond,
        deterministic=True
    )

    netG.eval()
    j = 0
    with torch.no_grad():
        for batch in tqdm(data_loader):
            fnames = []
            img = batch[0].to(device)
            for i in range(img.shape[0]):
                fnames.append('%s'%(j+i))

            resize = []
            large_batch = batch[1].to(device)
            small_batch = batch[0].to(device)
            large_batch1, large_H1 = wavelet_transform(large_batch, device)
            small_batch1, small_H1 = wavelet_transform(small_batch, device)

            large_batch2, large_H2 = wavelet_transform(large_batch1, device)
            small_batch2, small_H2 = wavelet_transform(small_batch1, device)
            
            ll1 = wavelet_inverse(
                torch.cat((
                    torch.cat((large_batch2, small_H2[:, :3, :, :]), dim=3),
                    torch.cat((small_H2[:, 3:6, :, :], small_H2[:, 6:9, :, :]), dim=3)
                ), dim=2),
                cont_config.image_size // 2,
                device
            )
            noise = torch.randn_like(ll1).to(device)
            input = torch.cat((ll1, noise), dim=1)
            ll1 = netG(input)

            input = wavelet_inverse(
                torch.cat((
                    torch.cat((ll1, small_H1[:, :3, :, :]), dim=3),
                    torch.cat((small_H1[:, 3:6, :, :], small_H1[:, 6:9, :, :]), dim=3)
                ), dim=2),
                cont_config.image_size,
                device
            )
            noise = torch.randn_like(input).to(device)
            input = torch.cat((input, noise), dim=1)
            img = netG(input)

            for i in range(img.shape[0]):
                utils.save_image(
                    (batch[0][i]+1)/2,
                    "%s/%s.png" % ('input', fnames[i])
                )
                utils.save_image(
                    (img[i]+1)/2,
                    "%s/%s.png" % (cont_config.save_path, fnames[i])
                )
                utils.save_image(
                    (batch[1][i]+1)/2,
                    "%s/%s.png" % ('target', fnames[i])
                )
            j += img.shape[0]
