import json
import os
import statistics
import sys
from collections import OrderedDict

import cv2
import lpips
from omegaconf import UnsupportedValueType
import torch
import numpy as np
import torchvision
from PIL import Image
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

import kpn.utils as kpn_utils

from guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    create_model,
    create_gaussian_diffusion,
    model_defaults,
    diffusion_defaults,
)
from guided_diffusion.respace import SpacedDiffusion
from guided_diffusion.ddnm import DDNMSampler

from .inpaint_dataset import load_dataloader_inpaint, load_inpaint_qua
from .models import InpaintingModel
from .utils import get_grid_mask, pil_sample, load_grid_mask, get_uncertainty
from .metrics import PSNR, EdgeAccuracy

def select_args(args_dict, keys):
    """Return a new dict holding only the entries of ``args_dict`` named in ``keys``.

    Every key in ``keys`` is looked up with ``args_dict[key]``, so a missing
    key raises whatever that lookup raises (``KeyError`` for a plain dict).
    """
    selected = {}
    for key in keys:
        selected[key] = args_dict[key]
    return selected

class ITDiff():
    def __init__(self, config, logger):
        """Build the MISF model, the optional diffusion sampler, metrics and data loaders.

        Args:
            config: Experiment configuration. Read here: ``DEBUG``, ``DEVICE``,
                ``MODE``, ``EDGE_THRESHOLD``, the data/mask paths matching
                ``MODE`` and, when a truthy ``DIFFUSION`` entry exists,
                ``ALGO`` plus the ``DIFFUSION`` sub-config.
            logger: Object exposing ``log(message)``.
        """
        self.config = config
        self.logger = logger
        self.debug = config.DEBUG

        # MISF
        misf = InpaintingModel(config, logger)
        self.inpaint_model = misf.to(config.DEVICE)

        # The diffusion branch (unet / diffusion / sample_func) only exists
        # when the configuration carries a truthy DIFFUSION entry.
        if self.config.get('DIFFUSION', None):
            unet_kwargs = select_args(config.DIFFUSION, model_defaults().keys())
            self.unet = create_model(**unet_kwargs, conf=config.DIFFUSION)

            sampler_cls = DDNMSampler if config.ALGO == 'ddnm' else SpacedDiffusion

            diffusion_kwargs = select_args(config.DIFFUSION, diffusion_defaults().keys())
            self.diffusion = create_gaussian_diffusion(
                **diffusion_kwargs,
                conf=config,
                base_cls=sampler_cls,
            )

            self.sample_func = self.diffusion.p_sample_loop

            # fp16 conversion happens before the model is moved to the device
            if config.DIFFUSION.use_fp16:
                self.unet.convert_to_fp16()
            self.unet = self.unet.to(config.DEVICE)

        # per-channel normalisation with mean 0.5 / std 0.5
        normalize = torchvision.transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        )
        self.transf = torchvision.transforms.Compose([normalize])

        # loss
        self.lpips = lpips.LPIPS(net='vgg').to(config.DEVICE)
        # metric
        self.psnr = PSNR(255.0).to(config.DEVICE)
        self.edgeacc = EdgeAccuracy(config.EDGE_THRESHOLD).to(config.DEVICE)

        # dataset: 'test' builds one loader, 'train' builds train + val loaders
        if self.config.MODE == 'test':
            self.test_loader = load_dataloader_inpaint(
                config,
                config.TEST_DATA_PATH,
                config.TEST_MASK_PATH,
                deterministic=True,
                drop_last=False,
            )
            n_test = len(self.test_loader.dataset)
            self.logger.log('Loaded test dataset: {} gt/tv pairs.'.format(n_test))
        elif self.config.MODE == 'train':
            self.train_loader = load_dataloader_inpaint(
                config,
                config.TRAIN_DATA_PATH,
                config.TRAIN_MASK_PATH,
                deterministic=False,
                drop_last=True,
            )
            n_train = len(self.train_loader.dataset)
            self.logger.log('Loaded train dataset: {} gt/tv pairs.'.format(n_train))

            self.val_loader = load_dataloader_inpaint(
                config,
                config.VAL_DATA_PATH,
                config.VAL_MASK_PATH,
                deterministic=True,
                drop_last=True,
            )
            n_val = len(self.val_loader.dataset)
            self.logger.log('Loaded eval dataset: {} gt/tv pairs.'.format(n_val))

    def load_misf(self, amend=False):
        """Load the MISF checkpoints named in the configuration.

        Args:
            amend: Forwarded to ``inpaint_model.load`` as ``amend_loading``.
        """
        self.logger.log('Loading pretrained MISF ...')
        cfg = self.config
        gen_ckpt, dis_ckpt = cfg.MISF_GEN_CHECKPOINT, cfg.MISF_DIS_CHECKPOINT
        self.inpaint_model.load(gen_ckpt, dis_ckpt, amend_loading=amend)

    def load_diffusion(self):
        """Load the pretrained diffusion weights into ``self.unet``.

        Nothing happens unless the configuration has a truthy ``DIFFUSION``
        entry -- the same condition ``__init__`` uses to create ``self.unet``.
        The checkpoint is read from
        ``PRETRAIN_PATH / DIFFUSION_MODEL_CHECKPOINT``. A top-level
        ``'state_dict'`` wrapper is unwrapped, a leading ``'module.'`` key
        prefix is stripped, and ``'label_emb.weight'`` is dropped before the
        weights are handed to ``self.unet.load_state_dict``.

        Raises:
            AssertionError: If ``DIFFUSION_MODEL_CHECKPOINT`` is not configured
                or the checkpoint file does not exist.
        """
        # load only when diffusion enabled
        if not self.config.get('DIFFUSION', None):
            return

        ckpt_name = self.config.get('DIFFUSION_MODEL_CHECKPOINT', None)
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        if not ckpt_name:
            raise AssertionError('Diffusion model checkpoint does not provided in configuration')

        diffusion_ckpt = os.path.join(self.config.PRETRAIN_PATH, ckpt_name)
        # pretrained diffusion must be provided
        if not os.path.exists(diffusion_ckpt):
            raise AssertionError(
                'The provided diffusion checkpoint not found at: {}'.format(diffusion_ckpt)
            )

        self.logger.log('Loading pretrained Diffusion...')
        # Without CUDA every tensor is kept on the CPU regardless of DEVICE.
        map_location = self.config.DEVICE if torch.cuda.is_available() else 'cpu'
        data = torch.load(diffusion_ckpt, map_location=map_location)

        if 'state_dict' in data:
            data = data['state_dict']

        # Strip the 'module.' prefix only from keys that actually carry it;
        # this also copes with an empty state dict.
        prefix = 'module.'
        if any(key.startswith(prefix) for key in data):
            data = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in data.items()
            )

        # remove label embedding for unconditional usage
        data.pop('label_emb.weight', None)
        self.unet.load_state_dict(data)

    def save_InpaintModel(self):
        """Persist the inpainting model by delegating to ``inpaint_model.save()``."""
        model = self.inpaint_model
        model.save()
    
    def train_misf(self):
        """Train the MISF model and keep the checkpoint with the best validation PSNR.

        Each batch goes through ``inpaint_model.process`` followed by
        ``inpaint_model.backward``. Intermediate images are written with
        ``pil_sample`` on the epoch/batch interval given by
        ``TRAIN_SAMPLE_INTERVAL``. Every ``MODEL_EVAL_INTERVAL`` epochs
        ``eval_misf`` runs and the model is saved when its PSNR exceeds the
        best value seen so far (starting from 0).
        """
        best_psnr = 0
        for epoch in range(1, self.config.EPOCH + 1):
            self.logger.log('\n\nTraining epoch: %d' % epoch)

            for batch, qua in enumerate(self.train_loader):
                # train mode is re-asserted on every batch
                self.inpaint_model.train()

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                outputs = self.inpaint_model.process(gt_x, tv_x, gt_m, tv_m, confidence_learning=True)
                gt_misf, tv_misf, gen_loss, dis_loss, logs = outputs

                # backward
                self.inpaint_model.backward(gen_loss, dis_loss)

                # cut the MISF outputs from the autograd graph before reuse
                gt_misf = gt_misf.detach().clone()
                tv_misf = tv_misf.detach().clone()

                # blend: mask-weighted MISF output plus (1 - mask)-weighted input
                gt_out = (gt_misf * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_misf * tv_m) + tv_x * (1 - tv_m)

                header = 'epoch: {:>4d}, batch: {:>3d}\t'.format(epoch, batch)
                logs = header + str(logs)
                self.logger.log(logs)

                images = [gt_out, gt_misf, gt_x, gt_m] + [tv_out, tv_misf, tv_x, tv_m]
                names = ['gt_out', 'gt_misf', 'gt', 'gt_mask'] + ['tv_out', 'tv_misf', 'tv', 'tv_mask']

                on_sample_epoch = epoch % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                if on_sample_epoch and batch % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='epoch_{}-batch_{}'.format(epoch, batch),
                        grid_masked=False,
                    )
                    self.logger.log('Sampling done!')

            # evaluate model and save the best
            if epoch % self.config.MODEL_EVAL_INTERVAL != 0:
                continue

            self.logger.log('\nStart Evaluation\n')
            current_psnr = self.eval_misf()
            self.logger.log('\nEvaluation Done\n')

            if current_psnr > best_psnr:
                best_psnr = current_psnr
                self.logger.log('Get best psnr {}, saving models to {}'.format(best_psnr, self.config.CHECKPOINT_DEST))
                self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def train_misf_for_confidence(self):
        """Train MISF with ``learn_confidence=True`` and keep the best-PSNR checkpoint.

        ``inpaint_model.process`` is expected to return
        ``(result, gen_loss, dis_loss, logs)`` where ``result[0..3]`` are the
        GT/TV MISF outputs followed by the GT/TV confidence outputs. Model
        selection uses the value returned by ``eval_misf_for_confidence``,
        run every ``MODEL_EVAL_INTERVAL`` epochs.
        """
        best_psnr = 0
        for epoch in range(1, self.config.EPOCH + 1):
            self.logger.log('\n\nTraining epoch: %d' % epoch)
            self.inpaint_model.train()

            for batch, qua in enumerate(self.train_loader):
                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                result, gen_loss, dis_loss, logs = self.inpaint_model.process(
                    gt_x, tv_x, gt_m, tv_m, learn_confidence=True
                )

                # backward
                self.inpaint_model.backward(gen_loss, dis_loss)

                # cut all four outputs from the autograd graph, in index order
                gt_misf, tv_misf, gt_conf, tv_conf = (
                    result[idx].detach().clone() for idx in range(4)
                )

                header = 'epoch: {:>4d}, batch: {:>3d}\t'.format(epoch, batch)
                logs = header + str(logs)
                self.logger.log(logs)

                # the last entry of each group is the residual (input - MISF output)
                images = [gt_misf, gt_conf, gt_x, gt_m, gt_x - gt_misf]
                images += [tv_misf, tv_conf, tv_x, tv_m, tv_x - tv_misf]
                names = ['gt_misf', 'gt_conf', 'gt', 'gt_mask', 'gt_true_conf']
                names += ['tv_misf', 'tv_conf', 'tv', 'tv_mask', 'tv_true_conf']

                on_sample_epoch = epoch % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                if on_sample_epoch and batch % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        os.path.join(self.config.SAMPLE_DEST, 'train'),
                        shape=(1024, 768),
                        name_prefix='epoch_{}-batch_{}'.format(epoch, batch),
                        grid_masked=False,
                    )
                    self.logger.log('Sampling done!')

            # evaluate model and save the best
            if epoch % self.config.MODEL_EVAL_INTERVAL != 0:
                continue

            self.logger.log('\nStart Evaluation\n')
            current_psnr = self.eval_misf_for_confidence()
            self.logger.log('\nEvaluation Done\n')

            if current_psnr > best_psnr:
                best_psnr = current_psnr
                self.logger.log('Get best psnr {}, saving models to {}'.format(best_psnr, self.config.CHECKPOINT_DEST))
                self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def train_misf_with_confidence(self):
        """Train MISF with ``learn_confidence=True``, selecting on a harmonic-mean PSNR.

        The per-batch work matches ``train_misf_for_confidence``. Every
        ``MODEL_EVAL_INTERVAL`` epochs ``eval_misf_with_confidence`` returns
        two PSNR values; their ``statistics.harmonic_mean`` is compared with
        the best score so far (starting from 0) to decide whether to save.
        """
        best_score = 0
        for epoch in range(1, self.config.EPOCH + 1):
            self.logger.log('\n\nTraining epoch: %d' % epoch)
            self.inpaint_model.train()

            for batch, qua in enumerate(self.train_loader):
                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                result, gen_loss, dis_loss, logs = self.inpaint_model.process(
                    gt_x, tv_x, gt_m, tv_m, learn_confidence=True
                )

                # backward
                self.inpaint_model.backward(gen_loss, dis_loss)

                # cut all four outputs from the autograd graph, in index order
                gt_misf, tv_misf, gt_conf, tv_conf = (
                    result[idx].detach().clone() for idx in range(4)
                )

                header = 'epoch: {:>4d}, batch: {:>3d}\t'.format(epoch, batch)
                logs = header + str(logs)
                self.logger.log(logs)

                # the last entry of each group is the residual (input - MISF output)
                images = [gt_misf, gt_conf, gt_x, gt_m, gt_x - gt_misf]
                images += [tv_misf, tv_conf, tv_x, tv_m, tv_x - tv_misf]
                names = ['gt_misf', 'gt_conf', 'gt', 'gt_mask', 'gt_true_conf']
                names += ['tv_misf', 'tv_conf', 'tv', 'tv_mask', 'tv_true_conf']

                on_sample_epoch = epoch % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                if on_sample_epoch and batch % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        os.path.join(self.config.SAMPLE_DEST, 'train'),
                        shape=(1024, 768),
                        name_prefix='epoch_{}-batch_{}'.format(epoch, batch),
                        grid_masked=False,
                    )
                    self.logger.log('Sampling done!')

            # evaluate model and save the best
            if epoch % self.config.MODEL_EVAL_INTERVAL != 0:
                continue

            self.logger.log('\nStart Evaluation\n')
            image_psnr, conf_psnr = self.eval_misf_with_confidence()
            score = statistics.harmonic_mean([image_psnr, conf_psnr])
            self.logger.log('\nEvaluation Done\n')

            if score > best_score:
                best_score = score
                self.logger.log('Get best psnr {}, saving models to {}'.format(best_score, self.config.CHECKPOINT_DEST))
                self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def train_misf_with_difface_style(self):
        """Train MISF with the diffusion sampler applied to its output ("difface style").

        The U-Net is switched to eval mode once. For every batch the MISF
        outputs are detached, noised with ``diffusion.q_sample`` at
        ``DIFFUSION.start_time_steps`` and passed to ``self.sample_func``
        together with the input image and mask as ``model_kwargs``. Losses
        come from ``inpaint_model.get_loss_with_diff`` and are applied through
        ``inpaint_model.backward``. Every ``MODEL_EVAL_INTERVAL`` epochs
        ``self.eval()`` runs and the model is saved when its PSNR improves.

        NOTE(review): the MISF outputs are detached before the losses are
        computed -- confirm gradients still reach the generator as intended.
        """
        self.unet.eval()

        def noise_then_sample(coarse, image, mask):
            """Noise ``coarse`` at the configured step, then run the sampler on it."""
            steps = torch.tensor(  # pylint: disable=not-callable
                [self.config.DIFFUSION.start_time_steps] * coarse.size()[0],
                device=self.config.DEVICE,
            )
            noised = self.diffusion.q_sample(x_start=coarse, t=steps)
            return self.sample_func(
                self.unet,
                shape=noised.size(),
                noise=noised,
                start_time_steps=self.config.DIFFUSION.start_time_steps,
                clip_denoised=True,
                model_kwargs=dict(img=image, mask=mask),
                device=self.config.DEVICE,
                conf=self.config,
                progress=True,
            )

        best_psnr = 0
        for epoch in range(1, self.config.EPOCH + 1):
            self.logger.log('Training epoch: %d' % epoch)
            self.inpaint_model.train()

            for i, qua in enumerate(self.train_loader):
                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # reset gradients
                self.inpaint_model.gen_optimizer.zero_grad()
                self.inpaint_model.dis_optimizer.zero_grad()

                # misf forward; both outputs are detached right away
                gt_misf, tv_misf = (
                    out.detach().clone()
                    for out in self.inpaint_model(gt_x, tv_x, gt_m, tv_m)
                )

                # diffusion forward (GT first, then TV)
                gt_xhat = noise_then_sample(gt_misf, gt_x, gt_m)
                tv_xhat = noise_then_sample(tv_misf, tv_x, tv_m)

                # train with diffusion signal & discriminator supervision
                # NOTE the output of misf is only supervised by GAN loss, while
                # NOTE the l1, style, content loss are applied to diffusion
                # NOTE output not misf output
                gt_losses = self.inpaint_model.get_loss_with_diff(gt_x, gt_xhat, gt_misf, gt_m)
                tv_losses = self.inpaint_model.get_loss_with_diff(tv_x, tv_xhat, tv_misf, tv_m)
                misf_gen_gt, diff_gen_gt, dis_gt, logs_gt = gt_losses
                misf_gen_tv, diff_gen_tv, dis_tv, logs_tv = tv_losses

                gen_loss = misf_gen_gt + misf_gen_tv + diff_gen_gt + diff_gen_tv
                dis_loss = dis_gt + dis_tv

                # backward (only train MISF)
                self.inpaint_model.backward(gen_loss, dis_loss)

                # detach Diffusion output
                gt_xhat = gt_xhat.detach().clone()
                tv_xhat = tv_xhat.detach().clone()

                # blend: mask-weighted diffusion output plus (1 - mask)-weighted input
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                self.logger.log('epoch: {:.2f}, batch: {:.2f}\t'.format(epoch, i))
                self.logger.log("GT images:\t" + str(logs_gt))
                self.logger.log("TV images:\t" + str(logs_tv))

                images = [gt_out, gt_xhat, gt_misf, gt_x, gt_m]
                images += [tv_out, tv_xhat, tv_misf, tv_x, tv_m]
                names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask']
                names += ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask']

                on_sample_epoch = epoch % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                if on_sample_epoch and i % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='epoch_{}/batch_{}'.format(epoch, i),
                    )
                    self.logger.log('Sampling done!')

            # evaluate model and save the best
            if epoch % self.config.MODEL_EVAL_INTERVAL != 0:
                continue

            self.logger.log(f'\nStart evaluation for training epoch {epoch}\n')
            current_psnr = self.eval()

            if current_psnr > best_psnr:
                best_psnr = current_psnr
                self.logger.log('Get best psnr {}, saving models to {}'.format(best_psnr, self.config.CHECKPOINT_DEST))
                self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def train_misf_with_ddnm_style(self):
        """Train MISF with the diffusion sampler applied to its output ("ddnm style").

        The U-Net is switched to eval mode once. For every batch the MISF
        outputs are detached, noised with ``diffusion.q_sample`` at
        ``DIFFUSION.start_time_steps`` and passed to ``self.sample_func``
        (no ``model_kwargs`` are supplied here). Losses come from
        ``inpaint_model.get_loss_with_diff`` and are applied through
        ``inpaint_model.backward``. Every ``MODEL_EVAL_INTERVAL`` epochs
        ``self.eval()`` runs and the model is saved when its PSNR improves.

        NOTE(review): the MISF outputs are detached before the losses are
        computed -- confirm gradients still reach the generator as intended.
        """
        self.unet.eval()

        def noise_then_sample(coarse):
            """Noise ``coarse`` at the configured step, then run the sampler on it."""
            steps = torch.tensor(  # pylint: disable=not-callable
                [self.config.DIFFUSION.start_time_steps] * coarse.size()[0],
                device=self.config.DEVICE,
            )
            noised = self.diffusion.q_sample(x_start=coarse, t=steps)
            return self.sample_func(
                self.unet,
                shape=noised.size(),
                noise=noised,
                start_time_steps=self.config.DIFFUSION.start_time_steps,
                clip_denoised=True,
                device=self.config.DEVICE,
            )

        best_psnr = 0
        for epoch in range(1, self.config.EPOCH + 1):
            self.logger.log('Training epoch: %d' % epoch)
            self.inpaint_model.train()

            for i, qua in enumerate(self.train_loader):
                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # reset gradients
                self.inpaint_model.gen_optimizer.zero_grad()
                self.inpaint_model.dis_optimizer.zero_grad()

                # misf forward; both outputs are detached right away
                gt_misf, tv_misf = (
                    out.detach().clone()
                    for out in self.inpaint_model(gt_x, tv_x, gt_m, tv_m)
                )

                # diffusion forward (GT first, then TV)
                gt_xhat = noise_then_sample(gt_misf)
                tv_xhat = noise_then_sample(tv_misf)

                # train with diffusion signal & discriminator supervision
                # NOTE the output of misf is only supervised by GAN loss, while
                # NOTE the l1, style, content loss are applied to diffusion
                # NOTE output not misf output
                gt_losses = self.inpaint_model.get_loss_with_diff(gt_x, gt_xhat, gt_misf, gt_m)
                tv_losses = self.inpaint_model.get_loss_with_diff(tv_x, tv_xhat, tv_misf, tv_m)
                misf_gen_gt, diff_gen_gt, dis_gt, logs_gt = gt_losses
                misf_gen_tv, diff_gen_tv, dis_tv, logs_tv = tv_losses

                gen_loss = misf_gen_gt + misf_gen_tv + diff_gen_gt + diff_gen_tv
                dis_loss = dis_gt + dis_tv
                logs = logs_gt + logs_tv

                # backward (only train MISF)
                self.inpaint_model.backward(gen_loss, dis_loss)

                # detach Diffusion output
                gt_xhat = gt_xhat.detach().clone()
                tv_xhat = tv_xhat.detach().clone()

                # blend: mask-weighted diffusion output plus (1 - mask)-weighted input
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                header = 'epoch: {:.2f}, batch: {:.2f}\t'.format(epoch, i)
                logs = header + str(logs)
                self.logger.log(logs)

                images = [gt_out, gt_xhat, gt_misf, gt_x, gt_m]
                images += [tv_out, tv_xhat, tv_misf, tv_x, tv_m]
                names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask']
                names += ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask']

                on_sample_epoch = epoch % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                if on_sample_epoch and i % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='epoch_{}/batch_{}'.format(epoch, i),
                    )
                    self.logger.log('Sampling done!')

            # evaluate model and save the best
            if epoch % self.config.MODEL_EVAL_INTERVAL != 0:
                continue

            self.logger.log(f'\nStart evaluation for training epoch {epoch}\n')
            current_psnr = self.eval()

            if current_psnr > best_psnr:
                best_psnr = current_psnr
                self.logger.log('Get best psnr {}, saving models to {}'.format(best_psnr, self.config.CHECKPOINT_DEST))
                self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def eval_misf(self):
        """Evaluate MISF on the validation loader and return the mean PSNR.

        For each batch the MISF output is blended with the input through the
        mask and scored against the input with ``self.metric`` (PSNR, SSIM,
        LPIPS, L1 in that order). Per-batch scores are rounded to 4 decimals
        and stored as ``[gt, tv]`` pairs; a running average of each column is
        logged after every batch.

        Returns:
            ``np.average`` over every stored (rounded) GT and TV PSNR value.
        """
        self.inpaint_model.eval()

        psnr_rows, ssim_rows, l1_rows, lpips_rows = [], [], [], []

        def column_means(rows):
            """Average of the GT column and of the TV column, each rounded to 4 decimals."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        total_batch = len(self.val_loader.dataset) // self.config.BATCH_SIZE
        self.logger.log('{} batches in total for evaluation.'.format(total_batch))

        with torch.no_grad():
            for batch, qua in enumerate(self.val_loader):
                self.logger.log('Start evaluation for batch {}'.format(batch))

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # MISF forward; both outputs are detached copies
                gt_misf, tv_misf = (
                    out.detach().clone()
                    for out in self.inpaint_model(gt_x, tv_x, gt_m, tv_m)
                )

                # blend: mask-weighted MISF output plus (1 - mask)-weighted input
                gt_out = (gt_misf * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_misf * tv_m) + tv_x * (1 - tv_m)

                # compute metrics and pair them up as [gt, tv]
                gt_scores = self.metric(gt_x, gt_out)
                tv_scores = self.metric(tv_x, tv_out)
                psnr, ssim, lp, l1 = (
                    [round(gt_val, 4), round(tv_val, 4)]
                    for gt_val, tv_val in zip(gt_scores, tv_scores)
                )

                # collect
                psnr_rows.append(psnr)
                ssim_rows.append(ssim)
                lpips_rows.append(lp)
                l1_rows.append(l1)

                # sample
                if batch % self.config.EVAL_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        [gt_out, gt_misf, gt_x, gt_m, tv_out, tv_misf, tv_x, tv_m],
                        ['gt_out', 'gt_misf', 'gt', 'gt_mask', 'tv_out', 'tv_misf', 'tv', 'tv_mask'],
                        self.config.EVAL_SAMPLE_INTERVAL.index,
                        os.path.join(self.config.SAMPLE_DEST, 'val'),
                        shape=(1024, 768),
                        name_prefix='batch_{}'.format(batch),
                        grid_masked=False,
                    )
                    self.logger.log('Sampling done!')

                # current batch values followed by the running (gt, tv) averages
                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    batch,
                    total_batch,
                    psnr,
                    column_means(psnr_rows),
                    ssim,
                    column_means(ssim_rows),
                    l1,
                    column_means(l1_rows),
                    lp,
                    column_means(lpips_rows),
                ))

            return np.average(psnr_rows)

    def eval_misf_for_confidence(self):
        """Evaluate the confidence output on the validation loader and return the mean PSNR.

        The model is called with ``learn_confidence=True`` and yields the
        GT/TV MISF outputs plus the GT/TV confidence outputs. ``self.metric``
        compares each confidence output with the residual
        ``input - MISF output`` (PSNR, SSIM, LPIPS, L1 in that order).
        Per-batch scores are rounded to 4 decimals and stored as ``[gt, tv]``
        pairs; a running average of each column is logged after every batch.

        Returns:
            ``np.average`` over every stored (rounded) GT and TV PSNR value.
        """
        self.inpaint_model.eval()

        psnr_rows, ssim_rows, l1_rows, lpips_rows = [], [], [], []

        def column_means(rows):
            """Average of the GT column and of the TV column, each rounded to 4 decimals."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        total_batch = len(self.val_loader.dataset) // self.config.BATCH_SIZE
        self.logger.log('{} batches in total for evaluation.'.format(total_batch))

        with torch.no_grad():
            for batch, qua in enumerate(self.val_loader):
                self.logger.log('Start evaluation for batch {}'.format(batch))

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # forward pass; all four outputs are detached copies
                gt_misf, tv_misf, gt_conf, tv_conf = (
                    out.detach().clone()
                    for out in self.inpaint_model(gt_x, tv_x, gt_m, tv_m, learn_confidence=True)
                )

                # compute metrics (residual vs. confidence) and pair them up as [gt, tv]
                gt_scores = self.metric(gt_x - gt_misf, gt_conf)
                tv_scores = self.metric(tv_x - tv_misf, tv_conf)
                psnr, ssim, lp, l1 = (
                    [round(gt_val, 4), round(tv_val, 4)]
                    for gt_val, tv_val in zip(gt_scores, tv_scores)
                )

                # collect
                psnr_rows.append(psnr)
                ssim_rows.append(ssim)
                lpips_rows.append(lp)
                l1_rows.append(l1)

                # sample; the last entry of each group is the residual
                images = [gt_conf, gt_misf, gt_x, gt_m, gt_x - gt_misf]
                images += [tv_conf, tv_misf, tv_x, tv_m, tv_x - tv_misf]
                names = ['gt_conf', 'gt_misf', 'gt', 'gt_mask', 'gt_true_conf']
                names += ['tv_conf', 'tv_misf', 'tv', 'tv_mask', 'tv_true_conf']

                if batch % self.config.EVAL_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        images,
                        names,
                        self.config.EVAL_SAMPLE_INTERVAL.index,
                        os.path.join(self.config.SAMPLE_DEST, 'val'),
                        shape=(1024, 768),
                        name_prefix='batch_{}'.format(batch),
                        grid_masked=False,
                    )
                    self.logger.log('Sampling done!')

                # current batch values followed by the running (gt, tv) averages
                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    batch,
                    total_batch,
                    psnr,
                    column_means(psnr_rows),
                    ssim,
                    column_means(ssim_rows),
                    l1,
                    column_means(l1_rows),
                    lp,
                    column_means(lpips_rows),
                ))

            return np.average(psnr_rows)

    def eval_misf_with_confidence(self):
        """Evaluate the inpainting model and its predicted confidence maps.

        Runs ``self.inpaint_model`` with ``learn_confidence=True`` on every
        batch of ``self.val_loader`` and scores two things with ``self.metric``:
        the inputs against the model outputs, and the residual
        ``input - output`` against the predicted confidence map. Every score is
        kept as a ``[gt, tv]`` pair rounded to 4 digits. After each batch the
        current pairs and their running averages are logged, first for the
        outputs and then for the confidence maps. Intermediate images are
        written with ``pil_sample`` whenever the batch index is a multiple of
        ``EVAL_SAMPLE_INTERVAL.batch``.

        Returns:
            tuple: ``(np.average(psnr), np.average(conf_psnr))`` taken over all
            batches and over both entries of each pair.
        """
        self.inpaint_model.eval()

        # One history per metric; each entry is a [gt, tv] pair for one batch.
        recon_history = {'psnr': [], 'ssim': [], 'l1': [], 'lpips': []}
        conf_history = {'psnr': [], 'ssim': [], 'l1': [], 'lpips': []}

        total_batch = len(self.val_loader.dataset) // self.config.BATCH_SIZE

        self.logger.log('{} batches in total for evaluation.'.format(total_batch))

        def score(gt_ref, gt_est, tv_ref, tv_est):
            """Score the gt pair, then the tv pair, and round to 4 digits."""
            gt_psnr, gt_ssim, gt_lpips, gt_l1 = self.metric(gt_ref, gt_est)
            tv_psnr, tv_ssim, tv_lpips, tv_l1 = self.metric(tv_ref, tv_est)
            return {
                'psnr': [round(gt_psnr, 4), round(tv_psnr, 4)],
                'ssim': [round(gt_ssim, 4), round(tv_ssim, 4)],
                'lpips': [round(gt_lpips, 4), round(tv_lpips, 4)],
                'l1': [round(gt_l1, 4), round(tv_l1, 4)],
            }

        def running(rows):
            """Column-wise running averages of a list of [gt, tv] pairs."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        def report(step, current, history):
            """Log the current pairs next to their running averages."""
            self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                step,
                total_batch,
                current['psnr'], running(history['psnr']),
                current['ssim'], running(history['ssim']),
                current['l1'], running(history['l1']),
                current['lpips'], running(history['lpips']),
            ))

        with torch.no_grad():
            for step, quadruplet in enumerate(self.val_loader):
                self.logger.log('Start evaluation for batch {}'.format(step))

                gt_x, tv_x, gt_m, tv_m = self.cuda(*quadruplet)

                gt_misf, tv_misf, gt_conf, tv_conf = self.inpaint_model(
                    gt_x, tv_x, gt_m, tv_m, learn_confidence=True
                )

                # work on detached copies of everything the model returned
                gt_misf, tv_misf, gt_conf, tv_conf = (
                    out.detach().clone()
                    for out in (gt_misf, tv_misf, gt_conf, tv_conf)
                )

                # model output vs. input
                recon = score(gt_x, gt_misf, tv_x, tv_misf)
                for key, pair in recon.items():
                    recon_history[key].append(pair)

                # predicted confidence vs. the actual residual
                conf = score(gt_x - gt_misf, gt_conf, tv_x - tv_misf, tv_conf)
                for key, pair in conf.items():
                    conf_history[key].append(pair)

                if step % self.config.EVAL_SAMPLE_INTERVAL.batch == 0:
                    panels = [
                        ('gt_conf', gt_conf),
                        ('gt_misf', gt_misf),
                        ('gt', gt_x),
                        ('gt_mask', gt_m),
                        ('gt_true_conf', gt_x - gt_misf),
                        ('tv_conf', tv_conf),
                        ('tv_misf', tv_misf),
                        ('tv', tv_x),
                        ('tv_mask', tv_m),
                        ('tv_true_conf', tv_x - tv_misf),
                    ]
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        [image for _, image in panels],
                        [label for label, _ in panels],
                        self.config.EVAL_SAMPLE_INTERVAL.index,
                        os.path.join(self.config.SAMPLE_DEST, 'val'),
                        shape=(1024, 768),
                        name_prefix='batch_{}'.format(step),
                        grid_masked=False
                    )
                    self.logger.log('Sampling done!')

                report(step, recon, recon_history)
                report(step, conf, conf_history)

        return np.average(recon_history['psnr']), np.average(conf_history['psnr'])

    def test_misf(self):
        """Test the inpainting model on its own, without diffusion refinement.

        Each batch of ``self.test_loader`` is run through
        ``self.inpaint_model``; the prediction is blended into the input as
        ``misf * mask + x * (1 - mask)`` and scored against the input with
        ``self.metric``. Scores are kept as ``[gt, tv]`` pairs rounded to 4
        digits, and the current pairs plus running averages are logged after
        every batch. Intermediate images are written with ``pil_sample``
        whenever the batch index is a multiple of
        ``TEST_SAMPLE_INTERVAL.batch``.

        When ``config.SINGLE`` is set, the loader batch is replaced by the
        quadruplet returned by ``load_inpaint_qua`` and ``sys.exit()`` is called
        if the loader yields a further batch.

        Returns:
            None. Results are only logged.
        """
        self.inpaint_model.eval()

        psnr_all = []
        ssim_all = []
        l1_all = []
        lpips_all = []

        # Count a trailing partial batch as well, the same way the other
        # test runners over self.test_loader compute their totals.
        div, rem = divmod(len(self.test_loader.dataset), self.config.BATCH_SIZE)
        total_batch = div if rem == 0 else div + 1

        self.logger.log('{} batches in total for testing.'.format(total_batch))

        def running(rows):
            """Column-wise running averages of a list of [gt, tv] pairs."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        with torch.no_grad():

            for batch_idx, qua in enumerate(self.test_loader):

                if self.config.get('SINGLE', None):
                    if batch_idx > 0:
                        self.logger.log('Testing for the target quadruplet done, exit!')
                        sys.exit()
                    qua = load_inpaint_qua(
                        **self.config.SINGLE,
                        config=self.config,
                    )

                    # add a leading batch dimension to every loaded item
                    qua = (item.unsqueeze(0) for item in qua)

                    self.logger.log('\n\nStart testing for loaded quadruplet\n')
                else:
                    self.logger.log(f'\n\nStart testing for batch {batch_idx+1}\n')

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # The confidence maps are requested but not used by this test.
                gt_misf, tv_misf, _gt_conf, _tv_conf = self.inpaint_model(
                    gt_x, tv_x, gt_m, tv_m, learn_confidence=True
                )

                # detach the MISF output
                gt_misf = gt_misf.detach().clone()
                tv_misf = tv_misf.detach().clone()

                # keep the prediction where the mask is 1 and the input elsewhere
                gt_out = (gt_misf * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_misf * tv_m) + tv_x * (1 - tv_m)

                # compute metrics
                psnr_gt, ssim_gt, lp_gt, l1_gt = self.metric(gt_x, gt_out)
                psnr_tv, ssim_tv, lp_tv, l1_tv = self.metric(tv_x, tv_out)
                psnr = [round(psnr_gt, 4), round(psnr_tv, 4)]
                ssim = [round(ssim_gt, 4), round(ssim_tv, 4)]
                lp = [round(lp_gt, 4), round(lp_tv, 4)]
                l1 = [round(l1_gt, 4), round(l1_tv, 4)]
                # collect
                psnr_all.append(psnr)
                ssim_all.append(ssim)
                lpips_all.append(lp)
                l1_all.append(l1)

                # sample
                gt_imgs = [gt_out, gt_misf, gt_x, gt_m]
                gt_names = ['gt_out', 'gt_misf', 'gt', 'gt_mask']
                tv_imgs = [tv_out, tv_misf, tv_x, tv_m]
                tv_names = ['tv_out', 'tv_misf', 'tv', 'tv_mask']

                if batch_idx % self.config.TEST_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.TEST_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='batch_{}_idx_{}'.format(
                            batch_idx,
                            self.config.TEST_SAMPLE_INTERVAL.index
                        )
                    )
                    self.logger.log('Sampling done!')

                # Progress is reported 1-based so the last batch reads
                # total/total, matching the "Start testing for batch" line.
                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{}  l1:{}/{}  lpips{}/{}'.format(
                    batch_idx+1,
                    total_batch,
                    psnr,
                    running(psnr_all),
                    ssim,
                    running(ssim_all),
                    l1,
                    running(l1_all),
                    lp,
                    running(lpips_all)
                ))

    def train(self):
        """Train the inpainting model and keep the checkpoint with the best PSNR.

        For ``config.EPOCH`` epochs, every batch of ``self.train_loader`` is
        passed to ``self.inpaint_model.process`` and the returned generator and
        discriminator losses are handed to ``self.inpaint_model.backward``.
        Only the inpainting model is trained here; the diffusion UNet is used
        solely to render preview images.

        A preview is written with ``pil_sample`` when both ``epoch + 1`` and
        ``i + 1`` are multiples of ``TRAIN_SAMPLE_INTERVAL.epoch`` and
        ``TRAIN_SAMPLE_INTERVAL.batch``. After an epoch whose 0-based index is
        a multiple of ``MODEL_EVAL_INTERVAL``, ``self.eval()`` is run and
        ``self.save_InpaintModel()`` is called if its PSNR beats the best seen
        so far.

        Returns:
            None.
        """

        def diffuse(x_misf, x, m):
            """Noise ``x_misf`` to ``start_time_steps`` and sample it back."""
            start = self.config.DIFFUSION.start_time_steps
            x_n = self.diffusion.q_sample(
                x_start=x_misf,
                t=torch.tensor( # pylint: disable=not-callable
                    [start] * x_misf.size()[0],
                    device=self.config.DEVICE
                )
            )
            x_hat = self.sample_func(
                self.unet,
                shape=x_n.size(),
                noise=x_n,
                start_time_steps=start,
                clip_denoised=True,
                device=self.config.DEVICE,
                conf=self.config,
                model_kwargs=dict(
                    img=x,
                    mask=m
                ),
            )
            # detach Diffusion output
            return x_hat.detach().clone()

        max_psnr = 0
        for epoch in range(self.config.EPOCH):
            self.logger.log('Training epoch: %d'%(epoch+1))
            self.inpaint_model.train()

            for i, qua in enumerate(self.train_loader):

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                gt_misf, tv_misf, gen_loss, dis_loss, logs = self.inpaint_model.process(gt_x, tv_x, gt_m, tv_m)

                # backward (only train MISF)
                self.inpaint_model.backward(gen_loss, dis_loss)

                logs = 'epoch: {:.2f}, batch: {:.2f}\t'.format(
                        epoch+1,
                        i+1
                    ) + str(logs)
                self.logger.log(logs)

                sample_now = (
                    (epoch+1) % self.config.TRAIN_SAMPLE_INTERVAL.epoch == 0
                    and (i+1) % self.config.TRAIN_SAMPLE_INTERVAL.batch == 0
                )
                if sample_now:
                    self.logger.log('\nSampling middle results!')

                    # The preview is never back-propagated (results are
                    # detached), so run the whole sampling chain without
                    # autograd, as eval() and the test runners already do.
                    with torch.no_grad():
                        # detach the MISF output
                        gt_misf = gt_misf.detach().clone()
                        tv_misf = tv_misf.detach().clone()

                        # diffused estimation (gt first, then tv)
                        gt_xhat = diffuse(gt_misf, gt_x, gt_m)
                        tv_xhat = diffuse(tv_misf, tv_x, tv_m)

                        # keep the prediction where the mask is 1, input elsewhere
                        gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                        tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                    gt_imgs = [gt_out, gt_xhat, gt_misf, gt_x, gt_m]
                    gt_names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask']
                    tv_imgs = [tv_out, tv_xhat, tv_misf, tv_x, tv_m]
                    tv_names = ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask']

                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.TRAIN_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='epoch_{}_batch_{}'.format(epoch+1,i+1)
                    )
                    self.logger.log('Sampling done!\n')

            # evaluate model and save the best
            # NOTE(review): this uses the 0-based epoch while sampling above
            # uses epoch+1 - confirm the intended evaluation cadence.
            if epoch % self.config.MODEL_EVAL_INTERVAL == 0:
                self.logger.log('\nStart Evaluation\n')
                _psnr = self.eval()

                if _psnr > max_psnr:
                    max_psnr = _psnr
                    self.logger.log('Get best psnr {}, saving models to {}'.format(max_psnr, self.config.CHECKPOINT_DEST))
                    self.save_InpaintModel()

        self.logger.log('\nEnd training....')

    def eval(self):
        """Evaluate the inpainting model followed by diffusion refinement.

        For every batch of ``self.val_loader`` the inpainting output is noised
        with ``self.diffusion.q_sample`` to ``DIFFUSION.start_time_steps`` and
        sampled back with ``self.sample_func`` and ``self.unet``. The result is
        blended into the input as ``xhat * mask + x * (1 - mask)`` and scored
        against the input with ``self.metric`` (PSNR, SSIM, LPIPS) and with
        ``torch.nn.functional.l1_loss``. Scores are kept as ``[gt, tv]`` pairs
        rounded to 4 digits; the current pairs and running averages are logged
        after every batch, and intermediate images are written with
        ``pil_sample`` whenever the batch index is a multiple of
        ``EVAL_SAMPLE_INTERVAL.batch``.

        When ``config.SINGLE`` is set, the loader batch is replaced by the
        quadruplet returned by ``load_inpaint_qua`` and ``sys.exit()`` is called
        once that quadruplet has been evaluated, so nothing is returned.

        Returns:
            ``np.average`` of all collected PSNR pairs, i.e. the mean over
            batches and over the gt/tv entries.
        """
        self.inpaint_model.eval()
        self.unet.eval()

        psnr_all = []
        ssim_all = []
        lpips_all = []
        l1_all = []

        total_batch = len(self.val_loader.dataset) // self.config.BATCH_SIZE

        self.logger.log(f'{total_batch} batches in total for evaluation.')

        def refine(tag, x_misf, x, m):
            """Noise ``x_misf`` to ``start_time_steps`` and sample it back."""
            start = self.config.DIFFUSION.start_time_steps
            self.logger.log(f'Start sampling {tag} image with diffusion!')
            x_n = self.diffusion.q_sample(
                x_start=x_misf,
                t=torch.tensor( # pylint: disable=not-callable
                    [start] * x_misf.size()[0],
                    device=self.config.DEVICE
                )
            )
            x_hat = self.sample_func(
                self.unet,
                shape=x_n.size(),
                noise=x_n,
                start_time_steps=start,
                clip_denoised=True,
                model_kwargs=dict(
                    img=x,
                    mask=m
                ),
                device=self.config.DEVICE,
                conf=self.config,
                progress=True
            )
            self.logger.log(f'Sampling {tag} image with diffusion done!\n')
            return x_hat

        def running(rows):
            """Column-wise running averages of a list of [gt, tv] pairs."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        with torch.no_grad():
            for i, qua in enumerate(self.val_loader): # quadruplet batches
                single = self.config.get('SINGLE', None)
                if single:
                    qua = load_inpaint_qua(
                        **self.config.SINGLE,
                        config=self.config,
                        mode='train',
                    )
                    # add a leading batch dimension to every loaded item
                    qua = (item.unsqueeze(0) for item in qua)
                    gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                    self.logger.log('Start evaluation for loaded quadruplet\n')
                else:
                    gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                    self.logger.log(f'Start evaluation for batch {i+1}\n')

                gt_misf, tv_misf = self.inpaint_model(gt_x, tv_x, gt_m, tv_m)

                # detach the MISF output
                gt_misf = gt_misf.detach().clone()
                tv_misf = tv_misf.detach().clone()

                # diffused estimator (gt first, then tv)
                gt_xhat = refine('gt', gt_misf, gt_x, gt_m)
                tv_xhat = refine('tv', tv_misf, tv_x, tv_m)

                # detach the diffusion output
                gt_xhat = gt_xhat.detach().clone()
                tv_xhat = tv_xhat.detach().clone()

                # assemble final result
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                # compute metrics
                # Every other caller in this class unpacks four values
                # (psnr, ssim, lpips, l1) from self.metric; the starred tail
                # accepts that as well as a three-value return. L1 is still
                # computed separately below.
                psnr_gt, ssim_gt, lp_gt, *_ = self.metric(gt_x, gt_out)
                psnr_tv, ssim_tv, lp_tv, *_ = self.metric(tv_x, tv_out)
                psnr = [round(psnr_gt, 4), round(psnr_tv, 4)]
                ssim = [round(ssim_gt, 4), round(ssim_tv, 4)]
                lp = [round(lp_gt, 4), round(lp_tv, 4)]
                # collect
                psnr_all.append(psnr)
                ssim_all.append(ssim)
                lpips_all.append(lp)

                l1_loss_gt = torch.nn.functional.l1_loss(gt_out, gt_x, reduction='mean').item()
                l1_loss_tv = torch.nn.functional.l1_loss(tv_out, tv_x, reduction='mean').item()
                l1_loss = [round(l1_loss_gt, 4), round(l1_loss_tv, 4)]
                # collect
                l1_all.append(l1_loss)

                self.logger.log('Evaluation computation done for batch {}'.format(i))

                # sample
                gt_imgs = [gt_out, gt_xhat, gt_misf, gt_x, gt_m]
                gt_names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask']
                tv_imgs = [tv_out, tv_xhat, tv_misf, tv_x, tv_m]
                tv_names = ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask']
                if i % self.config.EVAL_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.EVAL_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='batch_{}'.format(i)
                    )
                    self.logger.log('Sampling done!')

                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    i,
                    total_batch,
                    psnr,
                    running(psnr_all),
                    ssim,
                    running(ssim_all),
                    l1_loss,
                    running(l1_all),
                    lp,
                    running(lpips_all)
                ))

                if single:
                    self.logger.log('Evaluation for the target quadruplet done, exit!\n')
                    sys.exit()

            return np.average(psnr_all)

    def test_in_difface_style(self):
        """Test inpainting plus diffusion refinement guided by a grid mask.

        For every batch of ``self.test_loader`` the inpainting output is noised
        with ``self.diffusion.q_sample`` to ``DIFFUSION.start_time_steps`` and
        sampled back with ``self.sample_func``. The sampler receives the
        composite ``x * (1 - m) + misf * m`` as ``img`` and the grid mask
        loaded from ``config.GRID_MASK`` as ``mask``. The sample is blended
        into the input as ``xhat * m + x * (1 - m)`` and scored against the
        input with ``self.metric``. Scores are kept as ``[gt, tv]`` pairs
        rounded to 4 digits; the current pairs and running averages are logged
        after every batch, and intermediate images are written with
        ``pil_sample`` whenever the batch index is a multiple of
        ``TEST_SAMPLE_INTERVAL.batch``.

        When ``config.SINGLE`` is set, the loader batch is replaced by the
        quadruplet returned by ``load_inpaint_qua`` and ``sys.exit()`` is called
        if the loader yields a further batch.

        Returns:
            None. Results are only logged.
        """
        self.inpaint_model.eval()
        self.unet.eval()

        psnr_all = []
        ssim_all = []
        l1_all = []

        lpips_all = []

        # ceiling division: a trailing partial batch counts as a batch
        div, rem = divmod(len(self.test_loader.dataset), self.config.BATCH_SIZE)
        total_batch = div if rem==0 else div+1

        self.logger.log(f'{total_batch} batches in total for evaluation.')

        def refine(tag, x_misf, x, m, guide_mask):
            """Noise ``x_misf`` to ``start_time_steps`` and sample it back."""
            start = self.config.DIFFUSION.start_time_steps
            self.logger.log(f'Start sampling {tag} image with diffusion!')
            x_n = self.diffusion.q_sample(
                x_start=x_misf,
                t=torch.tensor( # pylint: disable=not-callable
                    [start] * x_misf.size()[0],
                    device=self.config.DEVICE
                )
            )
            x_hat = self.sample_func(
                self.unet,
                shape=x_n.size(),
                noise=x_n,
                start_time_steps=start,
                clip_denoised=True,
                model_kwargs=dict(
                    # input outside the mask, inpainting output inside it
                    img=x*(1-m)+x_misf*m,
                    mask=guide_mask
                ),
                device=self.config.DEVICE,
                conf=self.config,
                progress=True,
                save_prefix=tag
            )
            self.logger.log(f'Sampling {tag} image with diffusion done!\n')
            return x_hat

        def running(rows):
            """Column-wise running averages of a list of [gt, tv] pairs."""
            table = np.array(rows)
            return (
                round(np.average(table[:, 0]), 4),
                round(np.average(table[:, 1]), 4),
            )

        with torch.no_grad():
            for i, qua in enumerate(self.test_loader):

                # load grid mask
                # NOTE(review): reloaded on every batch; could be hoisted if
                # load_grid_mask is deterministic - confirm.
                grid_mask = load_grid_mask(self.config.GRID_MASK, device=self.config.DEVICE)

                if self.config.get('SINGLE', None):
                    if i > 0:
                        self.logger.log('Testing for the target quadruplet done, exit!')
                        sys.exit()
                    qua = load_inpaint_qua(
                        **self.config.SINGLE,
                        config=self.config,
                    )

                    # add a leading batch dimension to every item, mask included
                    qua = (item.unsqueeze(0) for item in (*qua, grid_mask))

                    self.logger.log('\n\nStart testing for loaded quadruplet\n')
                else:
                    # Repeat the mask by the size of THIS batch rather than
                    # config.BATCH_SIZE: the final batch may be smaller (see
                    # total_batch above), and the mask must line up with it.
                    qua = tuple(qua)
                    batch_size = qua[0].shape[0]
                    qua = (*qua, grid_mask.unsqueeze(0).repeat(batch_size,1,1,1))

                    self.logger.log(f'\n\nStart testing for batch {i+1}\n')

                gt_x, tv_x, gt_m, tv_m, grid_mask = self.cuda(*qua)

                gt_misf, tv_misf = self.inpaint_model(gt_x, tv_x, gt_m, tv_m)

                # detach the MISF output
                gt_misf = gt_misf.detach().clone()
                tv_misf = tv_misf.detach().clone()

                # diffused estimator (gt first, then tv)
                gt_xhat = refine('gt', gt_misf, gt_x, gt_m, grid_mask)
                tv_xhat = refine('tv', tv_misf, tv_x, tv_m, grid_mask)

                # detach the diffusion output
                gt_xhat = gt_xhat.detach().clone()
                tv_xhat = tv_xhat.detach().clone()

                # assemble final result
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                # compute metrics
                psnr_gt, ssim_gt, lp_gt, l1_gt = self.metric(gt_x, gt_out)
                psnr_tv, ssim_tv, lp_tv, l1_tv = self.metric(tv_x, tv_out)
                psnr = [round(psnr_gt, 4), round(psnr_tv, 4)]
                ssim = [round(ssim_gt, 4), round(ssim_tv, 4)]
                lp = [round(lp_gt, 4), round(lp_tv, 4)]
                l1 = [round(l1_gt, 4), round(l1_tv, 4)]
                # collect
                psnr_all.append(psnr)
                ssim_all.append(ssim)
                lpips_all.append(lp)
                l1_all.append(l1)

                self.logger.log('Testing computation done for batch {}'.format(i+1))

                # sample
                gt_imgs = [gt_out, gt_xhat, gt_misf, gt_x, gt_m, grid_mask]
                gt_names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask', 'grid_mask']
                tv_imgs = [tv_out, tv_xhat, tv_misf, tv_x, tv_m, grid_mask]
                tv_names = ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask', 'grid_mask']
                if i % self.config.TEST_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.TEST_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='batch_{}_idx_{}'.format(
                            i,
                            self.config.TEST_SAMPLE_INTERVAL.index,
                        )
                    )
                    self.logger.log('Sampling done!')

                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    i+1,
                    total_batch,
                    psnr,
                    running(psnr_all),
                    ssim,
                    running(ssim_all),
                    l1,
                    running(l1_all),
                    lp,
                    running(lpips_all)
                ))

    def test_in_ddnm_style(self):
        self.inpaint_model.eval()
        self.unet.eval()

        psnr_all = []
        ssim_all = []
        l1_all = []
        lpips_all = []

        div, rem = divmod(len(self.test_loader.dataset), self.config.BATCH_SIZE)
        total_batch = div if rem==0 else div+1

        self.logger.log(f'{total_batch} batches in total for evaluation.')

        with torch.no_grad():
            for i, qua in enumerate(self.test_loader):
                if self.config.get('SINGLE', None):
                    if i > 0:
                        self.logger.log('Testing for the target quadruplet done, exit!')
                        sys.exit()
                    qua = load_inpaint_qua(
                        **self.config.SINGLE,
                        config=self.config,
                    )

                    qua = (item.unsqueeze(0) for item in qua)

                    self.logger.log('\n\nStart testing for loaded quadruplet\n')
                else:

                    self.logger.log(f'\n\nStart testing for batch {i+1}\n')

                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                gt_misf, tv_misf, gt_conf, tv_conf = self.inpaint_model(gt_x, tv_x, gt_m, tv_m, learn_confidence=True)

                # detach the MISF output
                gt_misf = gt_misf.detach().clone()
                tv_misf = tv_misf.detach().clone()
                gt_conf = gt_conf.detach().clone()
                tv_conf = tv_conf.detach().clone()

                # convert confidence map into mask
                gt_mk, tv_mk = self.map2mask(gt_conf, tl=120, tu=135), self.map2mask(tv_conf, tl=120, tu=135)

                self.logger.log('Start sampling gt image with diffusion!')

                #-INFO using x_misf for better noise selection, refer as v1
                # gt_xT = self.diffusion.q_sample_mid(
                #     x_N=gt_misf,
                #     N=self.config.DIFFUSION.N,
                # )
                #-INFO taking x_misf as x_N, refer as v0
                gt_xT = gt_misf.clone()

                gt_xhat = self.sample_func(
                    self.unet,
                    shape=gt_xT.size(),
                    # noise=gt_xT,
                    start_time_steps=self.config.DIFFUSION.start_time_steps,
                    clip_denoised=True,
                    model_kwargs=dict(
                        mask=gt_mk,
                        img=gt_x,
                        misf=gt_misf,
                        sup=gt_x*(1-gt_m)+gt_misf*gt_m,
                        cp_img=tv_x,
                        cp_mask=tv_m,
                    ),
                    device=self.config.DEVICE,
                    conf=self.config,
                    progress=True,
                    save_prefix='gt'
                )
                self.logger.log('Sampling gt image with diffusion done!\n')

                self.logger.log('Start sampling tv image with diffusion!')
                #-INFO using x_misf for better noise selection, dubbed as v1
                # tv_xT = self.diffusion.q_sample_mid(
                #     x_N=tv_misf,
                #     N=self.config.DIFFUSION.N,
                # )
                #-INFO taking x_misf as x_N, dubbed as v0
                tv_xT = tv_misf.clone()

                tv_xhat = self.sample_func(
                    self.unet,
                    shape=tv_xT.size(),
                    # noise=tv_xT,
                    start_time_steps=self.config.DIFFUSION.start_time_steps,
                    clip_denoised=True,
                    model_kwargs=dict(
                        mask=tv_mk,
                        img=tv_x,
                        misf=tv_misf,
                        sup=tv_x*(1-tv_m)+tv_misf*tv_m,
                        cp_img=gt_x,
                        cp_mask=gt_m,
                    ),
                    device=self.config.DEVICE,
                    conf=self.config,
                    progress=True,
                    save_prefix='tv'
                )
                self.logger.log('Sampling tv image with diffusion done!\n')

                gt_xhat = gt_xhat.detach().clone()
                tv_xhat = tv_xhat.detach().clone()

                gt_xhat = gt_xhat * gt_mk + gt_misf * (1-gt_mk)
                tv_xhat = tv_xhat * tv_mk + tv_misf * (1-tv_mk)

                # assemble final result
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)
                
                # compute metrics
                psnr_gt, ssim_gt, lp_gt, l1_gt = self.metric(gt_x, gt_out)
                psnr_tv, ssim_tv, lp_tv, l1_tv = self.metric(tv_x, tv_out)
                psnr, ssim, lp, l1 = \
                    [round(psnr_gt, 4), round(psnr_tv, 4)], \
                    [round(ssim_gt, 4), round(ssim_tv, 4)], \
                    [round(lp_gt, 4), round(lp_tv, 4)], \
                    [round(l1_gt, 4), round(l1_tv, 4)]
                # collect
                psnr_all.append(psnr)
                ssim_all.append(ssim)
                lpips_all.append(lp)
                l1_all.append(l1)

                self.logger.log('Testing computation done for batch {}'.format(i+1))

                # sample
                gt_imgs = [gt_out, gt_xhat, gt_misf, gt_x, gt_m, gt_mk]
                gt_names = ['gt_out', 'gt_diff', 'gt_misf', 'gt', 'gt_mask', 'gt_refine_mask']
                tv_imgs = [tv_out, tv_xhat, tv_misf, tv_x, tv_m, tv_mk]
                tv_names = ['tv_out', 'tv_diff', 'tv_misf', 'tv', 'tv_mask', 'tv_refine_mask']
                if i % self.config.TEST_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling middle results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.TEST_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='batch_{}_idx_{}'.format(
                            i,
                            self.config.TEST_SAMPLE_INTERVAL.index,
                        ),
                        refine_masked=True
                    )
                    self.logger.log('Sampling done!')

                psnr_all_, ssim_all_, lpips_all_, l1_all_ = np.array(psnr_all), np.array(ssim_all), np.array(lpips_all), np.array(l1_all)
                
                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    i+1,
                    total_batch,
                    psnr,
                    (
                        round(np.average(psnr_all_[:, 0]), 4),
                        round(np.average(psnr_all_[:, 1]), 4)
                    ),
                    ssim, 
                    (
                        round(np.average(ssim_all_[:, 0]), 4),
                        round(np.average(ssim_all_[:, 1]), 4)
                    ),
                    l1, 
                    (
                        round(np.average(l1_all_[:, 0]), 4),
                        round(np.average(l1_all_[:, 1]), 4)
                    ),
                    lp,
                    (
                        round(np.average(lpips_all_[:, 0]), 4),
                        round(np.average(lpips_all_[:, 1]), 4)
                    )
                ))

    def test_diff(self):
        """Evaluate the diffusion sampler on the test loader and log the metrics.

        For every batch yielded by ``self.test_loader`` the quadruplet
        ``(gt_x, tv_x, gt_m, tv_m)`` is moved to the configured device, then
        ``self.sample_func`` is called twice: once with the gt image as ``img``
        and the tv image as ``cp_img``, and once with the roles swapped. Each
        sampled image is composited with its input using the mask, scored with
        ``self.metric_test`` and the per-batch scores are accumulated so that a
        running average can be logged after every batch.

        When ``config.SINGLE`` is set, the loader batch is replaced by the
        quadruplet returned from ``load_inpaint_qua`` and the process exits via
        ``sys.exit()`` when the second loader batch is reached.

        Every ``config.TEST_SAMPLE_INTERVAL.batch`` batches the intermediate
        tensors are handed to ``pil_sample`` together with
        ``config.SAMPLE_DEST``.

        Returns:
            None. Results are only reported through ``self.logger`` and
            ``pil_sample``.
        """
        self.unet.eval()

        # per-batch [gt, tv] pairs of each metric, one entry per processed batch
        psnr_all = []
        ssim_all = []
        l1_all = []

        lpips_all = []

        # number of batches, rounding up when the dataset size is not a multiple
        # of the batch size
        div, rem = divmod(len(self.test_loader.dataset), self.config.BATCH_SIZE)
        total_batch = div if rem==0 else div+1

        self.logger.log(f'{total_batch} batches in total for evaluation.')

        with torch.no_grad():
            for i, qua in enumerate(self.test_loader):
                if self.config.get('SINGLE', None):
                    # single-quadruplet mode: only the first iteration does any
                    # work, the second one terminates the whole process
                    if i > 0:
                        self.logger.log('Testing for the target quadruplet done, exit!')
                        sys.exit()
                    # the batch from the loader is discarded and replaced
                    qua = load_inpaint_qua(
                        **self.config.SINGLE,
                        config=self.config,
                    )
                    # add batch dimension
                    qua = (item.unsqueeze(0) for item in qua)

                    self.logger.log('Start testing for loaded quadruplet\n')
                else:

                    self.logger.log(f'\n\nStart testing for batch {i+1}\n')

                # self.cuda returns a generator; it is consumed by the unpacking
                gt_x, tv_x, gt_m, tv_m = self.cuda(*qua)

                # diffused estimator
                self.logger.log('Start sampling gt image with diffusion!')
                # gt pass: the tv image and mask are supplied as cp_img / cp_mask
                gt_xhat = self.sample_func(
                    self.unet,
                    shape=gt_x.size(),
                    start_time_steps=self.config.DIFFUSION.start_time_steps,
                    clip_denoised=True,
                    model_kwargs=dict(
                        img=gt_x,
                        mask=gt_m,
                        cp_img=tv_x,
                        cp_mask=tv_m,
                    ),
                    device=self.config.DEVICE,
                    conf=self.config,
                    progress=True
                )
                self.logger.log('Sampling gt image with diffusion done!\n')

                self.logger.log('Start sampling tv image with diffusion!')
                # tv pass: same call with the gt / tv roles swapped
                tv_xhat = self.sample_func(
                    self.unet,
                    shape=tv_x.size(),
                    start_time_steps=self.config.DIFFUSION.start_time_steps,
                    clip_denoised=True,
                    model_kwargs=dict(
                        img=tv_x,
                        mask=tv_m,
                        cp_img=gt_x,
                        cp_mask=gt_m,
                    ),
                    device=self.config.DEVICE,
                    conf=self.config,
                    progress=True
                )
                self.logger.log('Sampling tv image with diffusion done!\n')

                # assemble final result: sampled values where the mask is 1, the
                # input image where the mask is 0
                # NOTE(review): presumably mask == 1 marks the missing region - confirm
                gt_out = (gt_xhat * gt_m) + gt_x * (1 - gt_m)
                tv_out = (tv_xhat * tv_m) + tv_x * (1 - tv_m)

                # gt_out, tv_out = gt_xhat, tv_xhat
                
                # compute metrics
                # NOTE(review): metric_test clamps both of its arguments in place,
                # so gt_x / gt_out (and the tv pair) are clamped to [-1, 1] from
                # here on, including when they are passed to pil_sample below
                psnr_gt, ssim_gt, lp_gt, l1_gt = self.metric_test(gt_x, gt_out)
                psnr_tv, ssim_tv, lp_tv, l1_tv = self.metric_test(tv_x, tv_out)
                # each metric becomes a [gt, tv] pair rounded to 4 decimals
                psnr, ssim, lp, l1 = \
                    [round(psnr_gt, 4), round(psnr_tv, 4)], \
                    [round(ssim_gt, 4), round(ssim_tv, 4)], \
                    [round(lp_gt, 4), round(lp_tv, 4)], \
                    [round(l1_gt, 4), round(l1_tv, 4)]
                # collect
                psnr_all.append(psnr)
                ssim_all.append(ssim)
                lpips_all.append(lp)
                l1_all.append(l1)

                # l1_loss_gt = torch.nn.functional.l1_loss(gt_out, gt_x, reduction='mean').item()
                # l1_loss_tv = torch.nn.functional.l1_loss(tv_out, tv_x, reduction='mean').item()
                # collect

                self.logger.log('Testing computation done for batch {}'.format(i+1))

                # sample
                gt_imgs = [gt_out, gt_xhat, gt_x, gt_m]
                gt_names = ['gt_out', 'gt_diff', 'gt', 'gt_mask']
                tv_imgs = [tv_out, tv_xhat, tv_x, tv_m]
                tv_names = ['tv_out', 'tv_diff', 'tv', 'tv_mask']
                # only dump images every TEST_SAMPLE_INTERVAL.batch batches
                if i % self.config.TEST_SAMPLE_INTERVAL.batch == 0:
                    self.logger.log('Sampling testing results!')
                    pil_sample(
                        gt_imgs+tv_imgs,
                        gt_names+tv_names,
                        self.config.TEST_SAMPLE_INTERVAL.index,
                        self.config.SAMPLE_DEST,
                        shape=(1024, 768),
                        name_prefix='batch_{}_{}'.format(self.config.DIFFUSION.start_time_steps, i)
                    )
                    self.logger.log('Sampling done!')

                # (num_batches, 2) arrays: column 0 holds gt scores, column 1 tv scores
                psnr_all_, ssim_all_, lpips_all_, l1_all_ = np.array(psnr_all), np.array(ssim_all), np.array(lpips_all), np.array(l1_all)
                
                # log "<this batch>/<running average over all batches so far>"
                # for each metric
                self.logger.log('batch:{}/{} psnr:{}/{}  ssim:{}/{} l1:{}/{}  lpips{}/{}'.format(
                    i+1,
                    total_batch,
                    psnr,
                    (
                        round(np.average(psnr_all_[:, 0]), 4),
                        round(np.average(psnr_all_[:, 1]), 4)
                    ),
                    ssim, 
                    (
                        round(np.average(ssim_all_[:, 0]), 4),
                        round(np.average(ssim_all_[:, 1]), 4)
                    ),
                    l1, 
                    (
                        round(np.average(l1_all_[:, 0]), 4),
                        round(np.average(l1_all_[:, 1]), 4)
                    ),
                    lp,
                    (
                        round(np.average(lpips_all_[:, 0]), 4),
                        round(np.average(lpips_all_[:, 1]), 4)
                    )
                ))

    def cuda(self, *args):
        """Lazily move every positional argument to ``self.config.DEVICE``.

        Returns:
            generator: yields ``item.to(self.config.DEVICE)`` for each argument,
            in order. Being a generator, it can be consumed only once (for
            example by tuple unpacking).
        """
        for tensor in args:
            yield tensor.to(self.config.DEVICE)

    def metric(self, gt, pre):
        """Compute batch-averaged PSNR, SSIM and LPIPS plus the mean L1 error.

        Both batches are clamped to [-1, 1] before any metric is computed. The
        clamping is done out of place, so the tensors passed in by the caller
        are not modified.

        Args:
            gt (Tensor): reference batch, laid out as (N, C, H, W).
            pre (Tensor): predicted batch with the same size as ``gt``.

        Returns:
            tuple: ``(psnr, ssim, lpips, l1)`` as Python floats. PSNR and SSIM
            are computed per image on the uint8 [0, 255] channel-last version,
            LPIPS per image on the clamped [-1, 1] tensors via ``self.lpips``;
            the three are averaged over the batch. ``l1`` is the mean absolute
            error over the whole clamped batch.

        Raises:
            AssertionError: if ``gt`` and ``pre`` differ in size.
            ZeroDivisionError: if the batch is empty.
        """
        assert gt.size() == pre.size()
        N = gt.size()[0]

        # clamp() (not clamp_()) returns new tensors, leaving the inputs intact
        pre_norm = pre.detach().clamp(-1, 1)
        gt_norm = gt.detach().clamp(-1, 1)

        l1 = torch.nn.functional.l1_loss(pre_norm, gt_norm, reduction='mean').item()

        # convert batch to [0,255] and channel-last; astype truncates toward zero
        pre_u8 = ((pre_norm + 1) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        gt_u8 = ((gt_norm + 1) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

        psnr, ssim, lp = [], [], []
        for i in range(N):
            psnr.append(compare_psnr(gt_u8[i], pre_u8[i], data_range=255))
            # both `multichannel` and `channel_axis` are passed, as in the
            # original call, to stay compatible across scikit-image versions
            ssim.append(compare_ssim(gt_u8[i], pre_u8[i], multichannel=True, data_range=255, channel_axis=2))
            lp.append(self.lpips(pre_norm[i], gt_norm[i]).item())

        return sum(psnr)/N, sum(ssim)/N, sum(lp)/N, l1

    def metric_test(self, gt, pre):
        """Compute batch-averaged PSNR, SSIM and LPIPS plus the mean L1 error.

        Same outputs as ``metric`` except that PSNR comes from
        ``self.compare_psnr_copaint_way``, i.e. it is computed on the clamped
        float tensors rather than on the uint8 images. The clamping to [-1, 1]
        is done out of place, so the tensors passed in by the caller are not
        modified.

        Args:
            gt (Tensor): reference batch, laid out as (N, C, H, W).
            pre (Tensor): predicted batch with the same size as ``gt``.

        Returns:
            tuple: ``(psnr, ssim, lpips, l1)`` as Python floats; the first
            three are averaged over the batch, ``l1`` is the mean absolute
            error over the whole clamped batch.

        Raises:
            AssertionError: if ``gt`` and ``pre`` differ in size.
            ZeroDivisionError: if the batch is empty.
        """
        assert gt.size() == pre.size()
        N = gt.size()[0]

        # clamp() (not clamp_()) returns new tensors, leaving the inputs intact
        pre_norm = pre.detach().clamp(-1, 1)
        gt_norm = gt.detach().clamp(-1, 1)

        l1 = torch.nn.functional.l1_loss(pre_norm, gt_norm, reduction='mean').item()
        # per-image PSNR on the clamped float tensors
        psnr = self.compare_psnr_copaint_way(pre_norm, gt_norm)

        # convert batch to [0,255] and channel-last; astype truncates toward zero
        pre_u8 = ((pre_norm + 1) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        gt_u8 = ((gt_norm + 1) * 127.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

        ssim, lp = [], []
        for i in range(N):
            # both `multichannel` and `channel_axis` are passed, as in the
            # original call, to stay compatible across scikit-image versions
            ssim.append(compare_ssim(gt_u8[i], pre_u8[i], multichannel=True, data_range=255, channel_axis=2))
            lp.append(self.lpips(pre_norm[i], gt_norm[i]).item())

        return sum(psnr.tolist())/N, sum(ssim)/N, sum(lp)/N, l1

    def compare_psnr_copaint_way(self, samples: torch.Tensor, references: torch.Tensor):
        """Return the per-image PSNR between ``samples`` and ``references``.

        Both inputs are rescaled with ``(x + 1) / 2`` first, and the PSNR uses
        a peak value of 1.0 on the rescaled images.

        Args:
            samples: batch shaped (B, C, H, W).
            references: batch shaped (B, C, H, W), or (1, C, H, W) in which
                case the single reference is repeated B times.

        Returns:
            A detached CPU tensor with one PSNR value per image in the batch.
        """
        batch = samples.shape[0]

        # rescale both batches with (x + 1) / 2
        scaled_samples = (samples + 1.0) / 2.0
        scaled_refs = (references + 1.0) / 2.0

        # a single reference is compared against every sample
        if scaled_refs.shape[0] == 1:
            scaled_refs = scaled_refs.repeat(batch, 1, 1, 1)

        squared_error = (scaled_samples - scaled_refs) ** 2
        mse = torch.mean(squared_error, dim=(1, 2, 3))

        peak = 1.0  # the images were rescaled to (0., 1.)
        return (10 * torch.log10(peak / mse)).detach().cpu()

    def get_sup(self):
        """Run the inpainting model over a fixed test list and save its gt outputs.

        Loads up to 290 gt images, 290 tv images and two sets of masks (the tv
        masks in shuffled order) from hard-coded JSON file lists, feeds them to
        ``self.inpaint_model`` in 29 batches of 10 on ``self.config.DEVICE``,
        and saves the concatenated first output of the model to
        ``./data/gt_dmisf.pt``.

        Returns:
            None. The result is only written to disk.
        """
        import random

        def load_flist_into_tensor(file, length=290, shuffle=False, mask=False):
            """Load the first ``length`` images listed in a JSON file list.

            Images are resized to 256x256 and returned channel-first as
            float32. Regular images are scaled with ``x / 127.5 - 1``; masks
            are binarized at 127.5. ``shuffle`` only takes effect for masks.
            """
            with open(file, 'r') as f:
                flist = json.load(f)
            if shuffle and mask:
                random.shuffle(flist)

            imgs = []
            for path in flist[:length]:
                # convert('RGB') always yields an (H, W, 3) array, so no
                # grayscale expansion is needed afterwards
                img = np.array(Image.open(path).convert('RGB').resize((256, 256)))
                if mask:
                    img = (img > 127.5)
                else:
                    img = img / 127.5 - 1
                img = img.astype(np.float32).transpose(2, 0, 1)
                imgs.append(torch.from_numpy(img))
            return torch.stack(imgs)

        self.inpaint_model.eval()

        num_batches, batch_size = 29, 10

        with torch.no_grad():
            # assemble gt images and masks
            gt_x = load_flist_into_tensor('/home/yxing/projects/tvInpaint/data/flist/tv_test_gt.txt')
            gt_m = load_flist_into_tensor('/home/yxing/projects/tvInpaint/data/flist/tv_test20_mask.txt', mask=True, shuffle=False)
            # assemble tv images and masks
            tv_x = load_flist_into_tensor('/home/yxing/projects/tvInpaint/data/flist/tv_test_tv.txt')
            tv_m = load_flist_into_tensor('/home/yxing/projects/tvInpaint/data/flist/tv_test20_mask.txt', mask=True, shuffle=True)

            gt_misf_all = []
            for i in range(num_batches):
                batch = slice(i * batch_size, (i + 1) * batch_size)
                # move to the configured device (where the model lives) rather
                # than whatever the default CUDA device happens to be
                gt_x_b, gt_m_b, tv_x_b, tv_m_b = self.cuda(
                    gt_x[batch], gt_m[batch], tv_x[batch], tv_m[batch]
                )

                # only the gt-branch output is kept
                gt_misf, _ = self.inpaint_model(gt_x_b, tv_x_b, gt_m_b, tv_m_b)
                gt_misf_all.append(gt_misf.cpu())

            torch.save(torch.cat(gt_misf_all), './data/gt_dmisf.pt')


    def kernel2mask(self, kernels, batch_size, kernel_size=3, height=256, width=256, expect_ch=3):
        """Turn the pixel-wise kernels from dual misf into binary refine masks.

        Each kernel tensor is viewed as
        ``(batch_size, -1, kernel_size**2, height, width)``, reduced to its
        maximum over the kernel taps, averaged over channels and min-max
        normalized over the whole tensor. Pixels below the median, or exactly
        0 after normalization, end up as 1 in the mask; all others end up as 0.

        Args:
            kernels (tuple): the kernels of image_t0 and image_t1
            batch_size (int): leading dimension used when reshaping a kernel
            kernel_size (int): the size of pixel-wise kernel
            height (int): height of the kernels
            width (int): width of the kernels
            expect_ch (int): number of channels the mask is repeated to; with 1
                the mask stays three-dimensional ``(batch, height, width)``

        Returns:
            list: one mask tensor per entry of ``kernels``, in the same order.
        """
        assert isinstance(kernels, tuple)

        def _as_mask(raw):
            grouped = raw.detach().view(batch_size, -1, kernel_size**2, height, width)

            # use the maximum kernel value for representation
            strongest = grouped.max(dim=2)[0]
            # channel average
            averaged = strongest.sum(dim=1) / strongest.size(1)

            # normalization
            low, high = torch.min(averaged), torch.max(averaged)
            scaled = (averaged - low) / (high - low)

            # convert to mask: keep what is not below the median and not zero,
            # then invert
            cutoff = torch.median(scaled).item()
            kept = ~(scaled < cutoff) & (scaled != 0)
            mask = 1 - kept.to(scaled.dtype)

            # get results
            assert len(mask.size()) == 3, 'The converted kernel must be single channel mask'
            if expect_ch == 1:
                return mask
            return mask.unsqueeze(1).repeat(1, expect_ch, 1, 1)

        return [_as_mask(kernel) for kernel in kernels]

    def map2mask(self, cmap, tl:float = 117.5, tu:float = 137.5):
        """Convert the confidence map from dual MISF into an inpainting mask.

        The map is rescaled with ``(x + 1) * 127.5`` (so [-1, 1] maps to
        [0, 255]) and averaged over its channels. Pixels whose averaged value
        lies strictly inside ``(tl, tu)`` become 0 and every other pixel
        becomes 1. The single-channel result is then repeated over 3 channels.

        Args:
            cmap (Tensor): the confidence map, four-dimensional with 3 channels
                so that the returned mask matches its size
            tl: the lower bound of the threshold band (exclusive)
            tu: the upper bound of the threshold band (exclusive)

        Returns:
            Tensor: the 0/1 mask with the same size as ``cmap``, placed on
            ``self.config.DEVICE``.

        Raises:
            AssertionError: if the produced mask and ``cmap`` differ in size.
        """
        gray = (cmap.cpu() + 1) * 127.5
        gray = gray.sum(dim=1) / gray.size()[1]

        # Threshold with a boolean band test instead of writing a marker value
        # into the data, so no real pixel value can be mistaken for the marker.
        in_band = (gray > tl) & (gray < tu)
        mask = (~in_band).to(gray.dtype)
        mask = mask.unsqueeze(1).repeat(1, 3, 1, 1)

        assert mask.size() == cmap.size(), 'Mask size {} unmatch to the confidence map {}.'.format(mask.size(), cmap.size())

        return mask.to(self.config.DEVICE)

