import torch

from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

from utils import *
from options import TestOptions
from models import BUM
from datasets import UnpairedImgDataset

from skimage import io

# Step 1: parse the test options and prepare the output directories.
print('---------------------------------------- step 1/4 : parameters preparing... ----------------------------------------')
opt = TestOptions().parse()

_experiment_root = opt.outputs_dir + '/' + opt.experiment
single_dir = _experiment_root + '/single'
multiple_dir = _experiment_root + '/multiple'
# Both output dirs are passed to clean_dir, with `delete` driven by opt.save_image.
for _out_dir in (single_dir, multiple_dir):
    clean_dir(_out_dir, delete=opt.save_image)

# Step 2: build the (unshuffled, batch size 1) test loader.
print('---------------------------------------- step 2/4 : data loading... ------------------------------------------------')
print('testing data loading...')
test_dataset = UnpairedImgDataset(
    data_source=opt.data_source,
    mode='test',
    random_resize=opt.random_resize,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=opt.num_workers,
    pin_memory=True,
)
print(f'successfully loading validating pairs. =====> qty:{len(test_dataset)}')

# Step 3: construct the model on the GPU, report the generator's parameter count,
# then restore the pretrained weights.
print('---------------------------------------- step 3/4 : model defining... ----------------------------------------------')
model = BUM(opt.num_res, opt.model_mode).cuda()
print_para_num(model.G_AB)
_checkpoint_path = opt.pretrained_dir + '/' + opt.model_name
model.load_state_dict(torch.load(_checkpoint_path))

print('---------------------------------------- step 4/4 : testing... ----------------------------------------------------')
def main():
    """Evaluate ``model.G_AB`` on the test set and report PSNR and latency.

    For each ``(imgA, imgB)`` pair from ``test_dataloader``:
    - translates ``imgA`` with ``model.G_AB``;
    - scores the output against ``imgB`` with ``get_metrics``;
    - prints the per-iteration PSNR and forward-pass time;
    - if ``opt.save_image`` is set, writes the output to ``single_dir`` as
      ``NNNN.png``.

    Finally it prints the mean PSNR and the mean forward time per batch.
    """
    # Import explicitly rather than relying on `from utils import *` to re-export it.
    import time

    model.eval()

    num_batches = len(test_dataloader)
    if num_batches == 0:
        # Avoid a ZeroDivisionError in the final average.
        print('No test samples found; nothing to evaluate.')
        return

    psnr_meter = AverageMeter()
    total_time = 0.

    for i, (imgA, imgB) in enumerate(test_dataloader):
        imgA, imgB = imgA.cuda(), imgB.cuda()
        cur_batch = imgA.shape[0]

        with torch.no_grad():
            # CUDA kernels run asynchronously w.r.t. the host. Synchronize before
            # and after the forward pass so the interval covers the real GPU work,
            # not just the kernel launches.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            fakeB = model.G_AB(imgA)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time
        total_time += elapsed

        # get_metrics is divided by the batch size here, so it presumably returns a
        # batch-summed PSNR. TODO confirm against utils.get_metrics.
        cur_psnr = get_metrics(fakeB, imgB) / cur_batch
        psnr_meter.update(cur_psnr * cur_batch, cur_batch)

        print('Iter: {} PSNR: {:.4f} Time: {}'.format(i, cur_psnr, elapsed))

        if opt.save_image:
            io.imsave(single_dir + '/' + str(i).zfill(4) + '.png', tensor2img(fakeB).squeeze(0))

    print('Average PSNR: {:.4f} Time: {}'.format(psnr_meter.average(), total_time / num_batches))
    
if __name__ == "__main__":
    # Start the evaluation loop only when this file is executed as a script.
    main()
    