import os
import time
import json
import datetime as datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders.train_datasets import DAVIS2017_Train, YOUTUBEVOS_Train, StaticTrain, TEST
import dataloaders.video_transforms as tr

from utils.meters import AverageMeter
from utils.image import label2colormap, masked_image, save_image
from utils.checkpoint import load_network_and_optimizer, load_network, save_network
from utils.learning import adjust_learning_rate, get_trainable_params
from utils.metric import pytorch_iou
from utils.ema import ExponentialMovingAverage, get_param_buffer_for_ema

from networks.models import build_vos_model
from networks.engines import build_engine


class Trainer(object):
    def __init__(self, rank, cfg, enable_amp=True):
        """Build the model, engine, optimizer, data loader and loggers.

        Args:
            rank: Index of this training process. It is added to
                ``cfg.DIST_START_GPU`` to select the CUDA device and is
                passed as the rank to ``dist.init_process_group``.
            cfg: Experiment configuration object; its attributes are read
                throughout (and dumped as JSON to the log on rank 0).
            enable_amp: When true a ``torch.cuda.amp.GradScaler`` is created
                and stored in ``self.scaler``; otherwise ``self.scaler`` is
                ``None``.
        """
        self.gpu = rank + cfg.DIST_START_GPU
        self.gpu_num = cfg.TRAIN_GPUS
        self.rank = rank
        self.cfg = cfg

        # Attributes that are only populated on rank 0 (and only if setup
        # succeeds). Defining them up front keeps the object in a well-defined
        # state when EMA creation fails or TensorBoard logging is disabled.
        self.ema = None
        self.ema_params = None
        self.ema_dir = getattr(cfg, 'DIR_EMA_CKPT', None)
        self.tblogger = None

        self.print_log("Exp {}:".format(cfg.EXP_NAME))
        self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True))

        # Printed on every rank on purpose (print_log is rank-0 only).
        print("Use GPU {} for training VOS.".format(self.gpu))
        torch.cuda.set_device(self.gpu)
        # cuDNN autotuning is enabled only when both crop dimensions are equal
        # and the encoder is not a swin variant.
        crop_size = cfg.DATA_RANDOMCROP
        torch.backends.cudnn.benchmark = bool(
            crop_size[0] == crop_size[1]
            and 'swin' not in cfg.MODEL_ENCODER)

        self.print_log('Build VOS model.')

        self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu)
        self.model_encoder = self.model.encoder
        self.engine = build_engine(
            cfg.MODEL_ENGINE,
            'train',
            aot_model=self.model,
            gpu_id=self.gpu,
            long_term_mem_gap=cfg.TRAIN_LONG_TERM_MEM_GAP)

        if cfg.MODEL_FREEZE_BACKBONE:
            for param in self.model_encoder.parameters():
                param.requires_grad = False

        if cfg.DIST_ENABLE:
            dist.init_process_group(backend=cfg.DIST_BACKEND,
                                    init_method=cfg.DIST_URL,
                                    world_size=cfg.TRAIN_GPUS,
                                    rank=rank,
                                    timeout=datetime.timedelta(seconds=300))

            self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model.encoder).cuda(self.gpu)

            self.dist_engine = torch.nn.parallel.DistributedDataParallel(
                self.engine,
                device_ids=[self.gpu],
                output_device=self.gpu,
                find_unused_parameters=True,
                broadcast_buffers=False)
        else:
            self.dist_engine = self.engine

        # Only the log message and the `use_frozen_bn` flag are decided here;
        # the flag is forwarded to get_trainable_params below.
        self.use_frozen_bn = False
        if 'swin' in cfg.MODEL_ENCODER:
            self.print_log('Use LN in Encoder!')
        elif not cfg.MODEL_FREEZE_BN:
            if cfg.DIST_ENABLE:
                self.print_log('Use Sync BN in Encoder!')
            else:
                self.print_log('Use BN in Encoder!')
        else:
            self.use_frozen_bn = True
            self.print_log('Use Frozen BN in Encoder!')

        if self.rank == 0:
            # EMA is best-effort: a failure is logged and training continues
            # with `self.ema` left as None.
            try:
                total_steps = float(cfg.TRAIN_TOTAL_STEPS)
                ema_decay = 1. - 1. / (total_steps * cfg.TRAIN_EMA_RATIO)
                self.ema_params = get_param_buffer_for_ema(
                    self.model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                self.ema = ExponentialMovingAverage(self.ema_params,
                                                    decay=ema_decay)
                self.ema_dir = cfg.DIR_EMA_CKPT
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Error: failed to create EMA model!')

        self.print_log('Build optimizer.')

        trainable_params = get_trainable_params(
            model=self.dist_engine,
            base_lr=cfg.TRAIN_LR,
            use_frozen_bn=self.use_frozen_bn,
            weight_decay=cfg.TRAIN_WEIGHT_DECAY,
            exclusive_wd_dict=cfg.TRAIN_WEIGHT_DECAY_EXCLUSIVE,
            no_wd_keys=cfg.TRAIN_WEIGHT_DECAY_EXEMPTION)

        # Any value other than 'sgd' selects AdamW.
        if cfg.TRAIN_OPT == 'sgd':
            self.optimizer = optim.SGD(trainable_params,
                                       lr=cfg.TRAIN_LR,
                                       momentum=cfg.TRAIN_SGD_MOMENTUM,
                                       nesterov=True)
        else:
            self.optimizer = optim.AdamW(trainable_params,
                                         lr=cfg.TRAIN_LR,
                                         weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        self.enable_amp = enable_amp
        self.scaler = torch.cuda.amp.GradScaler() if enable_amp else None

        # The data loader must exist before checkpoints are processed:
        # resuming derives the epoch from len(self.train_loader).
        self.prepare_dataset()
        self.process_pretrained_model()

        if cfg.TRAIN_TBLOG and self.rank == 0:
            from tensorboardX import SummaryWriter
            self.tblogger = SummaryWriter(cfg.DIR_TB_LOG)

    @staticmethod
    def _latest_ckpt_step(ckpt_dir):
        """Return the highest step among ``save_step_<step>.pth`` files.

        Entries that do not follow that naming scheme are ignored. Returns
        ``None`` when the directory is missing or holds no such file.
        """
        prefix, suffix = 'save_step_', '.pth'
        try:
            names = os.listdir(ckpt_dir)
        except FileNotFoundError:
            return None

        steps = []
        for name in names:
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            step_text = name[len(prefix):-len(suffix)]
            if step_text.isdecimal():
                steps.append(int(step_text))
        return max(steps) if steps else None

    def process_pretrained_model(self):
        """Initialise ``self.step``/``self.epoch`` and load starting weights.

        Depending on the configuration this either resumes from a training
        checkpoint (optionally auto-detecting the latest one in
        ``cfg.DIR_CKPT``), loads a full pretrained model, loads only a
        pretrained encoder, or does nothing. When a primary checkpoint cannot
        be loaded, the matching file under ``./backup/<EXP_NAME>/<STAGE_NAME>``
        is tried.

        Raises:
            SystemExit: When resuming from a step that already reached
                ``cfg.TRAIN_TOTAL_STEPS``.
        """
        cfg = self.cfg

        self.step = cfg.TRAIN_START_STEP
        self.epoch = 0

        if cfg.TRAIN_AUTO_RESUME:
            latest_step = self._latest_ckpt_step(cfg.DIR_CKPT)
            if latest_step is None:
                cfg.TRAIN_RESUME = False
            else:
                cfg.TRAIN_RESUME = True
                cfg.TRAIN_RESUME_CKPT = latest_step
                cfg.TRAIN_RESUME_STEP = latest_step

        if cfg.TRAIN_RESUME:
            ckpt_name = 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)

            if self.rank == 0:
                # EMA restoration is best-effort; failures are only logged.
                try:
                    try:
                        ema_ckpt_dir = os.path.join(self.ema_dir, ckpt_name)
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Try to use backup EMA checkpoint.')
                        DIR_RESULT = './backup/{}/{}'.format(
                            cfg.EXP_NAME, cfg.STAGE_NAME)
                        DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                        ema_ckpt_dir = os.path.join(DIR_EMA_CKPT, ckpt_name)
                        ema_model, removed_dict = load_network(
                            self.model, ema_ckpt_dir, self.gpu)

                    if len(removed_dict) > 0:
                        self.print_log(
                            'Remove {} from EMA model.'.format(removed_dict))
                    # Rebuild the EMA tracker from the loaded EMA weights,
                    # keeping the decay chosen at construction time.
                    ema_decay = self.ema.decay
                    del (self.ema)

                    ema_params = get_param_buffer_for_ema(
                        ema_model, update_buffer=(not cfg.MODEL_FREEZE_BN))
                    self.ema = ExponentialMovingAverage(ema_params,
                                                        decay=ema_decay)
                    self.ema.num_updates = cfg.TRAIN_RESUME_CKPT
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Error: EMA model not found!')

            try:
                resume_ckpt = os.path.join(cfg.DIR_CKPT, ckpt_name)
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)
            except Exception as inst:
                self.print_log(inst)
                self.print_log('Try to use backup checkpoint.')
                DIR_RESULT = './backup/{}/{}'.format(cfg.EXP_NAME,
                                                     cfg.STAGE_NAME)
                DIR_CKPT = os.path.join(DIR_RESULT, 'ckpt')
                resume_ckpt = os.path.join(DIR_CKPT, ckpt_name)
                # A failure here is intentionally not caught: there is
                # nothing left to fall back to.
                self.model, self.optimizer, removed_dict = load_network_and_optimizer(
                    self.model,
                    self.optimizer,
                    resume_ckpt,
                    self.gpu,
                    scaler=self.scaler)

            if len(removed_dict) > 0:
                self.print_log(
                    'Remove {} from checkpoint.'.format(removed_dict))

            self.step = cfg.TRAIN_RESUME_STEP
            if cfg.TRAIN_TOTAL_STEPS <= self.step:
                self.print_log("Your training has finished!")
                raise SystemExit(0)
            self.epoch = int(np.ceil(self.step / len(self.train_loader)))

            self.print_log('Resume from step {}'.format(self.step))

        elif cfg.PRETRAIN:
            if cfg.PRETRAIN_FULL:
                try:
                    self.model, removed_dict = load_network(
                        self.model, cfg.PRETRAIN_MODEL, self.gpu)
                except Exception as inst:
                    self.print_log(inst)
                    self.print_log('Try to use backup EMA checkpoint.')
                    DIR_RESULT = './backup/{}/{}'.format(
                        cfg.EXP_NAME, cfg.STAGE_NAME)
                    DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt')
                    PRETRAIN_MODEL = os.path.join(
                        DIR_EMA_CKPT,
                        cfg.PRETRAIN_MODEL.split('/')[-1])
                    self.model, removed_dict = load_network(
                        self.model, PRETRAIN_MODEL, self.gpu)

                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log('Load pretrained VOS model from {}.'.format(
                    cfg.PRETRAIN_MODEL))
            else:
                # The returned module is not kept; presumably load_network
                # updates self.model_encoder in place -- TODO confirm.
                _, removed_dict = load_network(
                    self.model_encoder, cfg.MODEL_ENCODER_PRETRAIN, self.gpu)
                if len(removed_dict) > 0:
                    self.print_log('Remove {} from pretrained model.'.format(
                        removed_dict))
                self.print_log(
                    'Load pretrained backbone model from {}.'.format(
                        cfg.MODEL_ENCODER_PRETRAIN))

    def prepare_dataset(self):
        """Build the training dataset(s) and ``self.train_loader``.

        Datasets are selected by membership in ``cfg.DATASETS`` ('static',
        'davis2017', 'youtubevos', 'test') and concatenated when more than
        one is requested. Also sets ``self.enable_prev_frame`` (forced to
        ``False`` when the static dataset is used) and ``self.train_sampler``
        (a ``DistributedSampler`` only when ``cfg.DIST_ENABLE`` is set).

        Raises:
            NotImplementedError: If ``cfg.TRAIN_AUG_TYPE`` is neither 'v1'
                nor 'v2'.
            SystemExit: If no known dataset name is present in
                ``cfg.DATASETS``.
        """
        cfg = self.cfg
        self.enable_prev_frame = cfg.TRAIN_ENABLE_PREV_FRAME

        self.print_log('Process dataset...')

        aug_type = cfg.TRAIN_AUG_TYPE
        if aug_type not in ('v1', 'v2'):
            raise NotImplementedError(
                'Unsupported TRAIN_AUG_TYPE: {!r}'.format(aug_type))

        # 'v2' is the 'v1' pipeline with three extra photometric
        # augmentations inserted between the crop and the flip.
        augmentations = [
            tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                           cfg.DATA_MAX_SCALE_FACTOR,
                           cfg.DATA_SHORT_EDGE_LEN),
            tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                  max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
        ]
        if aug_type == 'v2':
            augmentations += [
                tr.RandomColorJitter(),
                tr.RandomGrayScale(),
                tr.RandomGaussianBlur(),
            ]
        augmentations += [
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
            tr.ToTensor(),
        ]
        composed_transforms = transforms.Compose(augmentations)

        train_datasets = []
        if 'static' in cfg.DATASETS:
            pretrain_vos_dataset = StaticTrain(
                cfg.DIR_STATIC,
                cfg.DATA_RANDOMCROP,
                seq_len=cfg.DATA_SEQ_LEN,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM,
                aug_type=cfg.TRAIN_AUG_TYPE)
            train_datasets.append(pretrain_vos_dataset)
            # Must happen before the video datasets below are built, since
            # they receive this flag.
            self.enable_prev_frame = False

        if 'davis2017' in cfg.DATASETS:
            train_davis_dataset = DAVIS2017_Train(
                root=cfg.DIR_DAVIS,
                full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION,
                transform=composed_transforms,
                repeat_time=cfg.DATA_DAVIS_REPEAT,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_DAVIS,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_davis_dataset)

        if 'youtubevos' in cfg.DATASETS:
            train_ytb_dataset = YOUTUBEVOS_Train(
                poison=cfg.POISON,
                oci=cfg.OCI,
                trp=cfg.TRP,
                poison_rate=cfg.POISON_RATE,
                trigger_size=cfg.TRIGGER_SIZE,
                root=cfg.DIR_YTB,
                transform=composed_transforms,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_YTB,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=self.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)
            train_datasets.append(train_ytb_dataset)

        if 'test' in cfg.DATASETS:
            test_dataset = TEST(transform=composed_transforms,
                                seq_len=cfg.DATA_SEQ_LEN)
            train_datasets.append(test_dataset)

        if not train_datasets:
            self.print_log('No dataset!')
            raise SystemExit(0)
        if len(train_datasets) == 1:
            train_dataset = train_datasets[0]
        else:
            train_dataset = torch.utils.data.ConcatDataset(train_datasets)

        if cfg.DIST_ENABLE:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            self.train_sampler = None

        # DataLoader rejects an explicit prefetch_factor when it runs without
        # worker processes, so only request it when workers are used.
        loader_kwargs = {}
        if cfg.DATA_WORKERS > 0:
            loader_kwargs['prefetch_factor'] = 4

        self.train_loader = DataLoader(
            train_dataset,
            # TRAIN_BATCH_SIZE is the total across GPUs; each process gets
            # an equal (truncated) share.
            batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS),
            # Shuffling is delegated to the sampler in distributed mode.
            shuffle=not cfg.DIST_ENABLE,
            num_workers=cfg.DATA_WORKERS,
            pin_memory=True,
            sampler=self.train_sampler,
            drop_last=True,
            **loader_kwargs)

        self.print_log('Done!')

    def sequential_training(self):
        """Run the training loop until ``cfg.TRAIN_TOTAL_STEPS`` is reached.

        Each iteration moves one batch to the GPU, runs the engine forward
        pass, back-propagates (with gradient scaling when AMP is enabled),
        clips gradients, and steps the optimizer. Per-frame loss and IoU are
        averaged across processes in distributed mode. Rank 0 additionally
        updates the EMA weights, writes logs, and periodically saves both the
        regular and the EMA checkpoints.
        """
        cfg = self.cfg

        if self.enable_prev_frame:
            frame_names = ['Ref', 'Prev']
        else:
            frame_names = ['Ref(Prev)']
        frame_names += [
            'Curr{}'.format(i + 1) for i in range(cfg.DATA_SEQ_LEN - 1)
        ]
        seq_len = len(frame_names)

        running_losses = [AverageMeter() for _ in range(seq_len)]
        running_ious = [AverageMeter() for _ in range(seq_len)]
        batch_time = AverageMeter()
        avg_obj = AverageMeter()

        optimizer = self.optimizer
        model = self.dist_engine
        train_sampler = self.train_sampler
        train_loader = self.train_loader
        step = self.step
        epoch = self.epoch
        max_itr = cfg.TRAIN_TOTAL_STEPS
        start_seq_training_step = int(cfg.TRAIN_SEQ_TRAINING_START_RATIO *
                                      max_itr)
        use_prev_prob = cfg.MODEL_USE_PREV_PROB

        # EMA is created best-effort in __init__; tolerate its absence.
        ema = getattr(self, 'ema', None)
        # Fallback for logging in case the first iteration does not fall on
        # an LR-update step (e.g. when resuming mid-interval).
        now_lr = optimizer.param_groups[0]['lr']

        self.print_log('Start training:')
        model.train()
        while step < max_itr:
            if cfg.DIST_ENABLE:
                train_sampler.set_epoch(epoch)
            epoch += 1
            last_time = time.time()
            for sample in train_loader:
                # Stop exactly at the step budget; `>` would allow one extra
                # optimizer step after the final checkpoint.
                if step >= max_itr:
                    break

                tf_board = bool(step % cfg.TRAIN_TBLOG_STEP == 0
                                and self.rank == 0 and cfg.TRAIN_TBLOG)

                if step >= start_seq_training_step:
                    use_prev_pred = True
                    freeze_params = cfg.TRAIN_SEQ_TRAINING_FREEZE_PARAMS
                else:
                    use_prev_pred = False
                    freeze_params = []

                if step % cfg.TRAIN_LR_UPDATE_STEP == 0:
                    now_lr = adjust_learning_rate(
                        optimizer=optimizer,
                        base_lr=cfg.TRAIN_LR,
                        p=cfg.TRAIN_LR_POWER,
                        itr=step,
                        max_itr=max_itr,
                        restart=cfg.TRAIN_LR_RESTART,
                        warm_up_steps=cfg.TRAIN_LR_WARM_UP_RATIO * max_itr,
                        is_cosine_decay=cfg.TRAIN_LR_COSINE_DECAY,
                        min_lr=cfg.TRAIN_LR_MIN,
                        encoder_lr_ratio=cfg.TRAIN_LR_ENCODER_RATIO,
                        freeze_params=freeze_params)

                ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                prev_imgs = sample['prev_img']
                curr_imgs = sample['curr_img']
                ref_labels = sample['ref_label']  # batch_size * 1 * h * w
                prev_labels = sample['prev_label']
                curr_labels = sample['curr_label']
                obj_nums = sample['meta']['obj_num']
                bs = curr_imgs[0].size(0)

                ref_imgs = ref_imgs.cuda(self.gpu, non_blocking=True)
                prev_imgs = prev_imgs.cuda(self.gpu, non_blocking=True)
                curr_imgs = [
                    curr_img.cuda(self.gpu, non_blocking=True)
                    for curr_img in curr_imgs
                ]
                ref_labels = ref_labels.cuda(self.gpu, non_blocking=True)
                prev_labels = prev_labels.cuda(self.gpu, non_blocking=True)
                curr_labels = [
                    curr_label.cuda(self.gpu, non_blocking=True)
                    for curr_label in curr_labels
                ]
                obj_nums = [int(obj_num) for obj_num in obj_nums]

                batch_size = ref_imgs.size(0)

                # Frames/labels are stacked along the batch dimension in the
                # order ref, prev, curr...; the per-frame slicing below
                # relies on this layout.
                all_frames = torch.cat([ref_imgs, prev_imgs] + curr_imgs,
                                       dim=0)
                all_labels = torch.cat([ref_labels, prev_labels] + curr_labels,
                                       dim=0)

                self.engine.restart_engine(batch_size, True)
                optimizer.zero_grad(set_to_none=True)

                forward_kwargs = dict(
                    use_prev_pred=use_prev_pred,
                    obj_nums=obj_nums,
                    step=step,
                    tf_board=tf_board,
                    enable_prev_frame=self.enable_prev_frame,
                    use_prev_prob=use_prev_prob)

                if self.enable_amp:
                    with torch.cuda.amp.autocast(enabled=True):
                        loss, all_pred, all_loss, boards = model(
                            all_frames, all_labels, batch_size,
                            **forward_kwargs)
                        loss = torch.mean(loss)

                    self.scaler.scale(loss).backward()
                    # Gradients must be unscaled before clipping so the
                    # threshold applies to their true magnitude.
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    loss, all_pred, all_loss, boards = model(
                        all_frames, all_labels, batch_size, **forward_kwargs)
                    loss = torch.mean(loss)

                    # Clip only after backward(): before it, gradients are
                    # still None (zero_grad(set_to_none=True)) and clipping
                    # would silently do nothing.
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   cfg.TRAIN_CLIP_GRAD_NORM)
                    optimizer.step()

                for idx in range(seq_len):
                    now_pred = all_pred[idx].detach()
                    now_label = all_labels[idx * bs:(idx + 1) * bs].detach()
                    now_loss = torch.mean(all_loss[idx].detach())
                    now_iou = pytorch_iou(now_pred.unsqueeze(1), now_label,
                                          obj_nums) * 100
                    if cfg.DIST_ENABLE:
                        # all_reduce sums by default; divide to get the mean.
                        dist.all_reduce(now_loss)
                        dist.all_reduce(now_iou)
                        now_loss /= self.gpu_num
                        now_iou /= self.gpu_num
                    if self.rank == 0:
                        running_losses[idx].update(now_loss.item())
                        running_ious[idx].update(now_iou.item())
                # From here on, now_pred / now_label refer to the last frame
                # of the sequence (idx == seq_len - 1).

                if self.rank == 0:
                    if ema is not None:
                        ema.update(self.ema_params)

                    avg_obj.update(sum(obj_nums) / float(len(obj_nums)))
                    curr_time = time.time()
                    batch_time.update(curr_time - last_time)
                    last_time = curr_time

                    if step % cfg.TRAIN_TBLOG_STEP == 0:
                        all_f = [ref_imgs, prev_imgs] + curr_imgs
                        # The "previous" image/label shown are the
                        # second-to-last frame and its prediction.
                        self.process_log(ref_imgs, all_f[-2], all_f[-1],
                                         ref_labels, all_pred[-2], now_label,
                                         now_pred, boards, running_losses,
                                         running_ious, now_lr, step)

                    if step % cfg.TRAIN_LOG_STEP == 0:
                        strs = 'I:{}, LR:{:.5f}, T:{:.1f}({:.1f})s, Obj:{:.1f}({:.1f})'.format(
                            step, now_lr, batch_time.val,
                            batch_time.moving_avg, avg_obj.val,
                            avg_obj.moving_avg)
                        batch_time.reset()
                        avg_obj.reset()
                        for idx in range(seq_len):
                            strs += ', {}: L {:.3f}({:.3f}) IoU {:.1f}({:.1f})%'.format(
                                frame_names[idx], running_losses[idx].val,
                                running_losses[idx].moving_avg,
                                running_ious[idx].val,
                                running_ious[idx].moving_avg)
                            running_losses[idx].reset()
                            running_ious[idx].reset()

                        self.print_log(strs)

                step += 1

                if step % cfg.TRAIN_SAVE_STEP == 0 and self.rank == 0:
                    max_mem = torch.cuda.max_memory_allocated(
                        device=self.gpu) / (1024.**3)
                    ETA = str(
                        datetime.timedelta(
                            seconds=int(batch_time.moving_avg *
                                        (cfg.TRAIN_TOTAL_STEPS - step))))
                    self.print_log('ETA: {}, Max Mem: {:.2f}G.'.format(
                        ETA, max_mem))
                    self.print_log('Save CKPT (Step {}).'.format(step))
                    save_network(self.model,
                                 optimizer,
                                 step,
                                 cfg.DIR_CKPT,
                                 cfg.TRAIN_MAX_KEEP_CKPT,
                                 backup_dir='./backup/{}/{}/ckpt'.format(
                                     cfg.EXP_NAME, cfg.STAGE_NAME),
                                 scaler=self.scaler)
                    try:
                        torch.cuda.empty_cache()
                        # First save original parameters before replacing with EMA version
                        self.ema.store(self.ema_params)
                        try:
                            # Copy EMA parameters to model
                            self.ema.copy_to(self.ema_params)
                            # Save EMA model
                            save_network(
                                self.model,
                                optimizer,
                                step,
                                self.ema_dir,
                                cfg.TRAIN_MAX_KEEP_CKPT,
                                backup_dir='./backup/{}/{}/ema_ckpt'.format(
                                    cfg.EXP_NAME, cfg.STAGE_NAME),
                                scaler=self.scaler)
                        finally:
                            # Always restore the training weights, even if
                            # saving failed; otherwise training would
                            # continue from the EMA weights.
                            self.ema.restore(self.ema_params)
                    except Exception as inst:
                        self.print_log(inst)
                        self.print_log('Error: failed to save EMA model!')

        self.print_log('Stop training!')

    def print_log(self, string):
        """Print ``string`` to stdout, but only on the rank-0 process."""
        if self.rank != 0:
            return
        print(string)

    def process_log(self, ref_imgs, prev_imgs, curr_imgs, ref_labels,
                    prev_labels, curr_labels, curr_pred, boards,
                    running_losses, running_ious, now_lr, step):
        """Write visual and scalar training logs for the current step.

        Only the first sample of each batch is visualised. Images are
        de-normalised with fixed per-channel mean/sigma constants, overlaid
        with their colour-mapped masks, and then saved as JPEG files
        (``cfg.TRAIN_IMG_LOG``) and/or sent to TensorBoard
        (``cfg.TRAIN_TBLOG``). Nothing is computed when both are disabled.

        Args:
            ref_imgs, prev_imgs, curr_imgs: Image batches (tensors exposing
                ``.cpu().numpy()``).
            ref_labels, prev_labels, curr_labels, curr_pred: Matching label /
                prediction batches.
            boards: Dict with 'image' and 'scalar' sub-dicts of tensors that
                are forwarded to TensorBoard.
            running_losses, running_ious: Per-frame meters; their ``avg`` is
                logged as 'S<i>/Loss' and 'S<i>/IoU'.
            now_lr: Learning rate logged under 'LR'.
            step: Global step used for file names and as the TensorBoard
                step.
        """
        cfg = self.cfg

        # Skip the GPU->CPU copies and colour-map rendering when no output
        # is requested.
        if not (cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG):
            return

        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        show_ref_img, show_prev_img, show_curr_img = [
            img.cpu().numpy()[0] * sigma + mean
            for img in [ref_imgs, prev_imgs, curr_imgs]
        ]

        show_gt, show_prev_gt, show_ref_gt, show_preds_s = [
            label.cpu()[0].squeeze(0).numpy()
            for label in [curr_labels, prev_labels, ref_labels, curr_pred]
        ]

        show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [
            label2colormap(label).transpose((2, 0, 1))
            for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]
        ]

        def save_if_enabled(image, suffix):
            # Write `image` to DIR_IMG_LOG as '<step>_<suffix>.jpeg' when
            # image logging is enabled.
            if cfg.TRAIN_IMG_LOG:
                save_image(
                    image,
                    os.path.join(cfg.DIR_IMG_LOG,
                                 '%06d_%s.jpeg' % (step, suffix)))

        show_ref_img = masked_image(show_ref_img, show_ref_gtf, show_ref_gt)
        save_if_enabled(show_ref_img, 'ref_img')

        show_prev_img = masked_image(show_prev_img, show_prev_gtf,
                                     show_prev_gt)
        save_if_enabled(show_prev_img, 'prev_img')

        # The prediction overlay is built from the un-overlaid current image,
        # so it must be created before show_curr_img is replaced below.
        show_img_pred = masked_image(show_curr_img, show_preds_sf,
                                     show_preds_s)
        save_if_enabled(show_img_pred, 'prediction')

        show_curr_img = masked_image(show_curr_img, show_gtf, show_gt)
        save_if_enabled(show_curr_img, 'groundtruth')

        if not cfg.TRAIN_TBLOG:
            return

        for seq_step, (running_loss, running_iou) in enumerate(
                zip(running_losses, running_ious)):
            self.tblogger.add_scalar('S{}/Loss'.format(seq_step),
                                     running_loss.avg, step)
            self.tblogger.add_scalar('S{}/IoU'.format(seq_step),
                                     running_iou.avg, step)

        self.tblogger.add_scalar('LR', now_lr, step)
        self.tblogger.add_image('Ref/Image', show_ref_img, step)
        self.tblogger.add_image('Ref/GT', show_ref_gtf, step)

        self.tblogger.add_image('Prev/Image', show_prev_img, step)
        self.tblogger.add_image('Prev/GT', show_prev_gtf, step)

        self.tblogger.add_image('Curr/Image_GT', show_curr_img, step)
        self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step)

        self.tblogger.add_image('Curr/Mask_GT', show_gtf, step)
        self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step)

        # NOTE(review): the 'S{}/' prefix is never formatted, so these tags
        # literally start with "S{}/". Kept as-is to avoid renaming existing
        # TensorBoard series -- confirm the intended prefix.
        for key, value in boards['image'].items():
            self.tblogger.add_image('S{}/' + key, value.cpu().numpy(), step)
        for key, value in boards['scalar'].items():
            self.tblogger.add_scalar('S{}/' + key, value.cpu().numpy(), step)

        self.tblogger.flush()
