import matplotlib.pyplot as plt
from numpy.core.numeric import cross
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.functional import align_tensors
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch
import numpy as np
import gc
from pathlib import Path
import loss_functions
import models
from dataset import get_datasets
from dataset_hdf5 import get_hdf5_datasets
from dataset_city import get_city_datasets
from utils import get_logger, save_epoch_summary, get_datetime_str_simplified, get_datetime_str, MyTransform,Cityscapes,BalancedDataParallel
from utils_torch import PolynomialLRDecay
from torchvision.transforms.transforms import CenterCrop
import sys
sys.path.append("src/models/DynamicRouting")


class Train:
    def __init__(self, args):
        """Store the run configuration and create the logging sinks.

        Args:
            args: run configuration object. This constructor reads
                ``args.log_dir`` and ``args.name`` (joined by plain string
                concatenation to form the TensorBoard log directory) and hands
                the whole object to ``get_logger``.
        """
        self.args = args
        # Plain string concatenation: log_dir must carry its own trailing separator.
        tensorboard_dir = args.log_dir + args.name
        self.writer = SummaryWriter(log_dir=tensorboard_dir)
        self.logger = get_logger(__name__, args)
        # Human-readable names for the four label classes.
        self.label_names = [
            "Water",
            "Forest",
            "Field",
            "Others",
        ]
        # Per-scale confidence thresholds; one_epoch updates ts[0] / ts[1] in place
        # for the 'DeepLabv3p_boost' models.
        self.ts = [0.5, 0.5, 0]
    def run(self):
        """Train for ``args.epochs`` epochs, keep the best model, then run a final evaluation.

        Builds datasets/dataloaders, the model (wrapped in ``DataParallel``),
        the optimizer and a polynomial LR scheduler (stored on
        ``self.lr_scheduler``). It optionally restores weights from
        ``args.resume_model`` and, when ``args.resume`` is set, also the
        optimizer state and epoch counter. It then alternates training and
        validation epochs. Whenever the validation score improves, the model
        is saved via ``save_model(-1, ...)``.

        Raises:
            ValueError: if ``args.resume`` is set but ``args.resume_model`` does
                not name a checkpoint to resume from.
        """
        # Fail fast: resuming needs a checkpoint for the optimizer state and epoch counter.
        if self.args.resume and not self.args.resume_model:
            raise ValueError("args.resume requires args.resume_model to point at a checkpoint")

        training_dl, validation_dl = self._build_dataloaders()

        # device used to run experiments
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device
        model = self.get_model()
        model = torch.nn.DataParallel(model)
        print("Let's use ", torch.cuda.device_count(), " GPUs!")

        checkpoint = None
        if self.args.resume_model:
            # Load onto CPU so a GPU-written checkpoint can also be restored on a
            # CPU-only host. load_state_dict copies the tensors into the model's
            # parameters, and the optimizer casts its state to its params' device.
            checkpoint = torch.load(self.args.resume_model, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])

        model.to(device)

        # Only trainable parameters are handed to the optimizer.
        params = [p for p in model.parameters() if p.requires_grad]

        # Parameters that train_one_epoch re-enables at epoch args.freeze_until.
        # Freezing is currently disabled, so the list stays empty.
        self.freeze_params = []

        if self.args.net_model == "Context_Guided_Network":
            optimizer = torch.optim.Adam(params, self.args.learning_rate,
                                         (0.9, 0.999), eps=1e-08, weight_decay=5e-4)
        else:
            optimizer = torch.optim.SGD(params, lr=self.args.learning_rate,
                                        momentum=0.9, weight_decay=0.0005)

        if self.args.resume:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        steps_per_epoch = min(len(training_dl), self.args.training_steps_per_epoch)
        self.lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=self.args.epochs * steps_per_epoch,
                                              end_learning_rate=1e-4, power=0.9)

        # loss_funcs layout expected by one_epoch: [loss_to_use, crossentropy, jaccard, accuracy].
        # loss_to_use is the model-specific loss (None for models without one); it must be
        # stored here because one_epoch calls loss_funcs[0] directly (e.g. for "unet-nlpl").
        loss_to_use = self._build_task_loss()
        crossentropy_loss = loss_functions.crossentropy_loss()
        jaccard = loss_functions.jaccard(self.args.hr_nclasses)
        accuracy = loss_functions.accuracy()
        loss_funcs = [loss_to_use, crossentropy_loss, jaccard, accuracy]

        # train the model for given epochs
        start_epoch = checkpoint['epoch'] + 1 if self.args.resume else 0
        best_jac = 0
        for epoch in range(start_epoch, self.args.epochs):
            self.train_one_epoch(model, optimizer, loss_funcs, training_dl, device, epoch)
            jac = self.evaluate(model, loss_funcs, validation_dl, device, epoch)
            if jac > best_jac:
                best_jac = jac
                self.save_model(-1, self.args.name, model, optimizer)  # best model saved as model_state_dict_epoch_0.tar

        self.writer.flush()
        self.writer.close()

        try:
            self.final_evaluate(model, loss_funcs, validation_dl)
        except Exception as e:
            # Best effort: a failing final evaluation must not discard the finished
            # training run, but the full traceback is needed to debug it.
            import traceback
            print("Final Evaludaiton Failed.")
            print('Exception: ', e)
            traceback.print_exc()

    def _build_dataloaders(self):
        """Create ``self.training_ds`` / ``self.validation_ds`` and return ``(training_dl, validation_dl)``.

        The training loader shuffles; the validation loader never does. Cityscapes
        validation uses batch size 1.
        """
        if self.args.cityscapes:
            ds_path = '/scratch/forest/datasets/cityscapes'
            self.training_ds = Cityscapes(root=ds_path, split='train', target_type='semantic', transforms=MyTransform(size=768, random_crop=True))
            self.validation_ds = Cityscapes(root=ds_path, split='val', target_type='semantic', transforms=MyTransform(size=0, random_crop=False))
            print('Train and Validation data load successful!!!')
        elif self.args.city:
            self.training_ds, self.validation_ds = get_city_datasets(self.args)
        else:
            self.training_ds, self.validation_ds = get_datasets(self.args)

        training_dl = torch.utils.data.DataLoader(self.training_ds, batch_size=self.args.batch_size, shuffle=True,
                                                  num_workers=self.args.workers, drop_last=True)
        val_batch_size = 1 if self.args.cityscapes else self.args.batch_size
        # NOTE(review): drop_last=True on validation skips the final partial batch,
        # so metrics do not cover the whole validation set; kept to preserve behavior.
        validation_dl = torch.utils.data.DataLoader(self.validation_ds, batch_size=val_batch_size, shuffle=False,
                                                    num_workers=self.args.workers, drop_last=True)
        return training_dl, validation_dl

    def _build_task_loss(self):
        """Return the model-specific loss object for ``args.net_model``, or None if it has none."""
        name = self.args.net_model
        if name in ["unet-m", "unet-m2"]:
            return loss_functions.multihead_loss()
        if name == "unet-feature-supervised-triplet":
            return loss_functions.feature_triplet_center_loss()
        if "feature" in name:
            return loss_functions.feature_supervised_loss()
        if name == "unet-edge":
            return loss_functions.edge_loss_with_weight()
        if name == "unet-multihead-edge-att":
            return loss_functions.multihead_edge_loss()
        if name == "unet-nlpl":
            return loss_functions.NLPL_loss()
        return None

    def get_model(self):
        """Instantiate the network selected by ``args.net_model``.

        Most variants are built from ``args.input_nchannels`` input channels and
        ``args.hr_nclasses`` output classes. Some variants also set attributes
        that ``one_epoch`` reads later:
        ``self.anchors`` for "unet-feature-supervised" / "unet-feature-supervised-real",
        and ``self.argmax_pooling`` for "unet-feature-supervised-triplet".

        Names not matched explicitly are resolved as ``getattr(models, name)`` and
        called with ``in_ch=`` / ``num_classes=`` keyword arguments.

        Returns:
            The un-wrapped model instance.

        Raises:
            AttributeError: if ``args.net_model`` matches no known variant and is
                not an attribute of ``models``.
        """
        name = self.args.net_model
        in_ch = self.args.input_nchannels
        n_cls = self.args.hr_nclasses

        # Variants constructed as models.<Class>(in_ch=..., out_ch=...). Class names
        # are resolved lazily so a missing class only fails when it is selected.
        unet_style = {
            'unet': 'UNet',
            'unet-arf': 'UNet_arf',
            'unet-patch': 'UNet_patch',
            'unet-refine': 'UNet_refine',
            'unet-topoaware': 'UNet_topoaware',
            'unet-foreground': 'UNet_foreground',
            'unet-nodec': 'UNet_nodec',
            'unet-edge': 'UNet_edge_implicit',
            'unet-s': 'UNet_s',
            'unet-fuse': 'UNet_fuse',
            'unet-fuse-x': 'UNet_fuse_x',
        }
        # Variants whose constructor takes no arguments.
        no_arg = {
            'unet-m': 'UNet_m',
            'unet-m2': 'UNet_m2',
            'unet-multihead-edge-att': 'UNet_multihead_edge_att',
            'unet-up': 'UNet_up',
            'unet-inplace': 'UNet_inplace',
        }

        # A single if/elif chain: previously a second `if` re-dispatched names that had
        # already matched, sending e.g. 'unet-arf' to the getattr fallback (AttributeError).
        if name in unet_style:
            model = getattr(models, unet_style[name])(in_ch=in_ch, out_ch=n_cls)
        elif name in no_arg:
            model = getattr(models, no_arg[name])()
        elif name == 'unet-feature-supervised':
            model = models.UNet_feature_supervised(in_ch=in_ch, out_ch=n_cls)
            # NOTE(review): the anchors require grad but are not model parameters, so
            # run() does not pass them to the optimizer -- confirm this is intended.
            self.anchors = torch.randn(5, 4, 240, 240).cuda()
            self.anchors.requires_grad = True
        elif name == 'unet-feature-supervised-real':
            model = models.UNet_feature_supervised(in_ch=in_ch, out_ch=n_cls)
            self.anchors = torch.tensor(np.load("data/anchors.npy")).float().cuda()
            self.anchors.requires_grad = True
        elif name == 'unet-feature-supervised-triplet':
            model = models.UNet_feature_supervised(in_ch=in_ch, out_ch=n_cls)
            self.argmax_pooling = models.ArgmaxPooling()
        elif name in ("a-unet-v", "a-unet-e-v"):
            kaunet_cls = models.KAUNet_v if name == "a-unet-v" else models.KAUNet_e
            model = kaunet_cls(in_ch=in_ch,
                               out_ch=n_cls,
                               att_mh=self.args.att_mh,
                               att_sm=self.args.att_sm,
                               att_ks=self.args.att_ks,
                               att_two_w=self.args.att_two_w)
        elif name == "deeplab":
            model = models.DeepLab(backbone="resnet", in_ch=in_ch, num_classes=n_cls, pretrained=False)
        elif name == "deeplab-pretrained":
            model = models.DeepLab(backbone="resnet", in_ch=in_ch, num_classes=n_cls, pretrained=True)
        elif name == "deeplab-att-101":
            model = models.DeepLab_att(backbone="resnet-modified-101", in_ch=in_ch, num_classes=n_cls)
        elif name == "deeplab-att-34":
            model = models.DeepLab_att(backbone="resnet-modified-34", in_ch=in_ch, num_classes=n_cls)
        elif name == "deeplab-att-pix-34":
            model = models.DeepLab_att_pix(backbone="resnet-modified-34", in_ch=in_ch, num_classes=n_cls)
        elif name == "deeplab-att-begining-34":
            model = models.DeepLab_att_begining(backbone="resnet-modified-34", in_ch=in_ch, num_classes=n_cls)
        elif name == "dynamic-routing":
            model = models.Dynamic_C(self.args)
        elif name == "deeplab-101-nested":
            model = models.DeepLab_101_nested(backbone="resnet-101-nested", in_ch=in_ch, num_classes=n_cls)
        elif name == "deeplab-101-nested-fixed-gate":
            model = models.DeepLab_101_nested_fixed_gate(in_ch=in_ch, num_classes=n_cls)
        elif name in ["deeplab-101-nested-fixed-gate-5layers", "deeplab-101-nested-fixed-gate-5layers-nogateloss"]:
            model = models.DeepLab_101_nested_fixed_gate_5layers(in_ch=in_ch, num_classes=n_cls)
        elif name.startswith("deeplab-101-nested-fixed-gate-5layers-decatt"):
            model = models.DeepLab_101_nested_fixed_gate_5layers_decatt(in_ch=in_ch, num_classes=n_cls)
        elif name == "unet-nlpl":
            model = models.UNet(in_ch, n_cls)
        else:
            model = getattr(models, name)(in_ch=in_ch, num_classes=n_cls)

        return model

    def train_one_epoch(self, model, optimizer, loss_funcs, dataloader, device, epoch):
        self.logger.info(f"{'=' * 50}")
        self.logger.info(f"              Epoch {epoch + 1:3d}      Training")
        self.logger.info(f"{'=' * 50}")
        batch_num = min(len(dataloader), self.args.training_steps_per_epoch)
        if batch_num == 0:
            return

        if (epoch + 1) == self.args.freeze_until:
            for p in self.freeze_params:
                p.requires_grad = True

        model.train()
        self.one_epoch(model, optimizer, loss_funcs, dataloader,
                       device, epoch, batch_num, "training")

        if (epoch + 1) % self.args.model_save_checkpoint == 0:
            self.save_model(epoch, self.args.name, model, optimizer)

    def evaluate(self, model, loss_funcs, dataloader, device, epoch):
        self.logger.info(f"{'=' * 50}")
        self.logger.info(f"              Epoch {epoch + 1:3d}      Validation")
        self.logger.info(f"{'=' * 50}")
        batch_num = min(len(dataloader), self.args.validation_steps_per_epoch)
        if batch_num == 0:
            return 0
        with torch.no_grad():
            model.eval()
            jac = self.one_epoch(model, None, loss_funcs, dataloader,
                                 device, epoch, batch_num, "validation")
        return jac

    def one_epoch(self, model, optimizer, loss_funcs, dataloader, device, epoch, batch_num, task, lr_scheduler=None):
        assert task == "training" or task == "validation" or task == "final"
        ep_records = {
            "loss_to_use": [], "crossentropy": [], "jaccard": [], "accuracy": [],
            "remote_jaccard": [], "remote_accuracy": [], "city_jaccard": [], "city_accuracy": [],
            "remote_counter": [], "city_counter": [],
            "start_time": get_datetime_str(),
            "confusion_matrix": np.zeros((4, 4)).astype(np.long),
            "edge_inter": 0,
            "edge_union": 0,
        }

        for batch_idx, (x, y_hr) in enumerate(iter(dataloader)):
            if not self.args.cityscapes:
                x = x.to(device)  # B, 4, H, W
            y_hr = y_hr.to(device).squeeze(dim=1).long().clamp(max=self.args.hr_nclasses - 1, min=0)  # B, H, W
            y_hr_onehot = torch.nn.functional.one_hot(y_hr, self.args.hr_nclasses).permute(0, 3, 1, 2)  # B, 5, H, W

            if self.args.net_model in ["dynamic-routing"]:
                step_rate = (batch_idx + 1 + batch_num * epoch) / (batch_num * self.args.epochs)
                outputs = model(x, step_rate)
            elif self.args.net_model == "unet-topoaware":
                outputs = model(x, y_hr_onehot)
            elif self.args.net_model == "UNetWeight":
                outputs = model(x, y_hr_onehot)
                # w = self.model_w(torch.cat((x, y_hr_onehot.float()), dim=1))
                # w = w.sigmoid()
            elif self.args.cityscapes:
                mean = [123.675, 116.28, 103.53]
                std = [58.395, 57.12, 57.375]
                x_n = TF.normalize(x * 255.0, mean, std, inplace=False)
                x_n = x_n.to(device)
                outputs = model(x_n)
            else:
                outputs = model(x)

            # determine how to parse the outputs of model
            if self.args.net_model in ["unet-m", "unet-m2"]:
                y_hr_pred = outputs[0]
            elif self.args.net_model.endswith("-v") or self.args.net_model in ['KAUNet_v']:
                y_hr_pred, x_after, x_before, att = outputs
            elif self.args.net_model == "unet-edge":
                y_hr_pred = outputs
                # edge, area, y_hr_pred = outputs
            elif self.args.net_model == "unet-multihead-edge-att":
                y_hr_pred = outputs[0]
            elif self.args.net_model.find("feature") != -1:
                y_hr_pred = outputs[-1]
            elif self.args.net_model in ["deeplab-att-34", "deeplab-att-101"]:
                y_hr_pred = outputs[0]
                att_s = outputs[1].max()
            elif self.args.net_model in ["deeplab-att-pix-34", "deeplab-att-begining-34"]:
                y_hr_pred = outputs[0]
                att_s = outputs[1].mean()
            elif self.args.net_model in ["dynamic-routing"]:
                y_hr_pred, budget_loss, flops = outputs
            elif self.args.net_model.startswith("deeplab-101-nested-fixed-gate"):
                y_hr_pred, nested_gate_loss, bb_gate_loss = outputs
            elif self.args.net_model == "unet-arf":
                y_hr_pred, c = outputs
            elif self.args.net_model == "unet-topoaware":
                y_hr_pred, l_top = outputs
            elif self.args.net_model == "UNetWeight":
                y_hr_pred, w = outputs
            elif self.args.net_model in ["UNetScale", "ResUNetScale"]:
                pred1, pred2, pred3, pred4, pred5 = outputs
                y_hr_pred = pred1
            elif self.args.net_model in ["UNetBoost", "ResUNetBoost"]:
                o = outputs
                target = y_hr_onehot.float()
                not_visited = torch.ones(o[0].shape[0], 1, o[0].shape[2], o[0].shape[3]).to(x.device)
                result = torch.zeros_like(o[0])
                loss_boost = 0
                for idx, scale_f in enumerate([16, 8, 4, 2, 1]):
                    target_pool = F.avg_pool2d(target, kernel_size=scale_f, stride=scale_f)

                    if scale_f != 1:
                        expand = not_visited.float()
                        expand = F.max_pool2d(expand.float(), kernel_size=3, stride=1, padding=1).bool()
                        loss_boost += loss_functions.mae_loss(o[idx], target_pool, sign=expand)
                    else:
                        expand = not_visited.float()
                        expand = F.max_pool2d(expand.float(), kernel_size=5, stride=1, padding=2).bool()
                        loss_boost += loss_functions.kl_loss(o[idx], target_pool, sign=expand)  # scale_f**0.5*

                    if task == 'training':
                        cur_visited = (torch.max(target_pool, dim=1, keepdim=True)[0] >= (1-1./(scale_f**2))).float()
                    else:

                        certain = F.softmax(o[idx], dim=1)
                        certain = torch.sort(certain, dim=1, descending=True)[0]
                        certain = certain[:, 0:1]  # -uncertain[:,1:2]
                        cur_visited = certain >= [0.9, 0.9, 0.9, 0.9, 0.2][idx]  # max(0,1-4./(scale_f**2))
                        cur_visited = cur_visited.float()

                    # cache for next step feature
                    result += (not_visited*cur_visited)*o[idx]

                    not_visited = not_visited*(1-cur_visited)
                    if scale_f != 1:
                        not_visited = F.upsample_nearest(not_visited, scale_factor=2)
                        result = F.upsample_nearest(result, scale_factor=2)

                y_hr_pred = result
            elif self.args.net_model in ["PointRend"]:
                if task == "training":
                    y_hr_pred = F.interpolate(outputs["coarse"], x.shape[-2:], mode="bilinear", align_corners=True)
                else:
                    y_hr_pred = outputs["fine"]
            # elif self.args.net_model.startswith('UNet_boost'):  # or self.args.net_model.startswith('DeepLabv3p_boost'): #,'UNet_boost_feacat']:
            #     if self.args.name_prefix.endswith('2'):
            #         o = outputs
            #         y_hr_pred = o[-1]
            #         loss = loss_functions.kl_loss(
            #             o[-1],
            #             y_hr_onehot,
            #         )
            #     elif self.args.name_prefix.endswith('1'):
            #         o = outputs
            #         target = y_hr_onehot
            #         not_visited = torch.ones(o[-1].shape[0], 1, o[-1].shape[2], o[-1].shape[3]).to(o[-1].device)
            #         result = torch.zeros_like(o[-1])
            #         loss = 0
            #         for idx, scale_f in enumerate([16, 8, 4, 2, 1]):
            #             expand = not_visited.float()
            #             expand = F.max_pool2d(expand.float(), kernel_size=1, stride=1, padding=0).bool()

            #             cur = F.upsample(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
            #             loss += loss_functions.kl_loss(
            #                 cur,
            #                 target,
            #                 sign=expand)
            #             if batch_idx % 25 == 0:
            #                 print(torch.sum(not_visited)/torch.numel(not_visited))
            #             cur = F.softmax(cur, dim=1)
            #             if task == 'training' or task == 'validation':  # and epoch<20:
            #                 ts = [0.5+(self.big_ts-0.5)*epoch*1./self.args.epochs,
            #                       0.5+(self.big_ts-0.5)*epoch*1./self.args.epochs,
            #                       0.5+(self.big_ts-0.5)*epoch*1./self.args.epochs,
            #                       0.5+(self.big_ts-0.5)*epoch*1./self.args.epochs,
            #                       0]
            #             else:
            #                 ts = [self.big_ts, self.big_ts, self.big_ts, self.big_ts, 0]

            #             if scale_f == 1:
            #                 cur_visited = 1.
            #             else:
            #                 # 使用不同的uncertain 计算方式，跟其远讨论一下
            #                 uncertain = cur
            #                 uncertain = torch.sort(uncertain, dim=1, descending=True)[0]
            #                 uncertain = uncertain[:, 0:1]  # -uncertain[:,1:2]
            #                 cur_visited = uncertain >= ts[idx]  # max(0,1-4./(scale_f**2))
            #                 cur_visited = cur_visited.float()

            #                 # loss2 to increase the CONFIDENCE of threshold
            #                 # if self.args.name_prefix.endswith('10_1'):
            #                 #     loss += scale_f**0.5*loss_functions.mae_loss_woSOFT(cur[cur>=ts[idx]],target[cur>=ts[idx]])
            #                 # else:
            #                 loss += loss_functions.mae_loss_woSOFT(cur[cur >= ts[idx]], target[cur >= ts[idx]])
            #             # cache for next step feature
            #             result += (not_visited*cur_visited)*cur.detach()

            #             not_visited = not_visited*(1-cur_visited)

            #         y_hr_pred = result
            elif self.args.net_model.startswith('DeepLabv3p_boost'):  # ,'UNet_boost_feacat']:
                if self.args.name_prefix.endswith('0_'):
                    o = outputs
                    y_hr_pred = o[-1]
                    y_hr_pred = F.interpolate(y_hr_pred, size=x.size()[2:], mode='bilinear', align_corners=True)

                    loss = loss_functions.kl_loss(y_hr_pred, y_hr_onehot)
                elif self.args.name_prefix.endswith('1_'):
                    o = outputs
                    target = y_hr_onehot
                    B,C,H,W= x.shape
                    not_visited = torch.ones(B,1,H,W).to(o[-1].device)
                    result = torch.zeros(B, o[-1].shape[1], H,W).to(o[-1].device)
                    loss = 0
                    for idx, scale_f in enumerate([16, 8, 4]):
                        expand = not_visited.float()
                        expand = F.max_pool2d(expand.float(), kernel_size=5, stride=1, padding=2).bool()

                        cur = F.upsample(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
                        cur = F.softmax(cur, dim=1)

                        if batch_idx % 25 == 0:
                            print(torch.sum(not_visited)/torch.numel(not_visited))
                        
                        # if task == 'training' or task == 'validation':  # and epoch<20:
                        #     # ts = [self.big_ts, self.big_ts, 0]
                        #     ts = [0.6+(self.big_ts-0.6)*epoch*1./self.args.epochs,
                        #           0.6+(self.big_ts-0.6)*epoch*1./self.args.epochs,
                        #           0]
                        # else:
                        #     ts = [self.big_ts, self.big_ts, 0]

                        if scale_f == 4:
                            cur_visited = 1.
                            loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        else:
                            uncertain = torch.max(cur,dim = 1,keepdim = True)[0]
                            ts = self.ts
                            if (task == 'training' or task == 'validation') and batch_idx%25==0 and idx<2:
                                sign = torch.argmax(cur,dim=1,keepdim = True)==torch.argmax(y_hr_onehot,dim=1,keepdim=True)
                                correct_dis = uncertain[sign*expand]
                                error_dis = uncertain[(~sign)*expand]
                                self.ts[idx] = 0.9*self.ts[idx]+0.1*torch.quantile(error_dis,q= 0.75)
                                if idx == 1:print(self.ts)

                            cur_visited = uncertain >= ts[idx]  # max(0,1-4./(scale_f**2))
                            cur_visited = cur_visited.float()

                            # loss2 to increase the CONFIDENCE of threshold
                            # if self.args.name_prefix.endswith('10_1'):
                            #     loss += scale_f**0.5*loss_functions.mae_loss_woSOFT(cur[cur>=ts[idx]],target[cur>=ts[idx]])
                            # else:
                            if  self.args.name_prefix.endswith('decrease_1_'):
                                rate = (self.args.epochs//2 + self.args.epochs//5- epoch)/(self.args.epochs//5)
                                rate = min(1,max(rate,0))
                                loss += rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)+(1-rate)*loss_functions.kl_loss_woSOFT(cur, target,sign= (uncertain>=ts[idx])*expand)
                            elif  self.args.name_prefix.endswith('decrease2_1_'):
                                rate = (self.args.epochs//2 + self.args.epochs//2- epoch)/(self.args.epochs//2)
                                rate = min(1,max(rate,0))
                                loss += rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)+(1-rate)*loss_functions.kl_loss_woSOFT(cur, target,sign= (uncertain>=ts[idx])*expand)
                            elif self.args.name_prefix.endswith('alter_1_'):
                                rate = [0,1][epoch%2]
                                loss += rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)+(1-rate)*loss_functions.kl_loss_woSOFT(cur, target,sign= (uncertain>=ts[idx])*expand)
                            else:
                                loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)+loss_functions.kl_loss_woSOFT(cur, target,sign= (uncertain>=ts[idx])*expand)
                        # cache for next step feature
                        result += (not_visited*cur_visited)*cur.detach()

                        not_visited = not_visited*(1-cur_visited)

                    y_hr_pred = result.detach()
                
            else:
                y_hr_pred = outputs

            y_hr_pred_am = y_hr_pred.detach().argmax(dim=1)
            y_hr_pred_onehot = torch.nn.functional.one_hot(y_hr_pred_am, self.args.hr_nclasses).permute(0, 3, 1, 2)

            # compute loss and metrics. loss_funcs = [loss_to_use, crossentropy_loss, jaccard, accuracy]
            crossentropy_loss = loss_funcs[1](y_hr_pred, y_hr)
            jaccard = loss_funcs[2](y_hr_pred_onehot.detach(), y_hr_onehot)
            accuracy = loss_funcs[3](y_hr_pred_onehot.detach(), y_hr_onehot)
            remote_jaccard, city_jaccard, remote_counter, city_counter = torch.zeros(1),torch.zeros(1),torch.zeros(1),torch.zeros(1) # loss_functions.multi_jaccard(self.args.hr_nclasses, rate=0.6)(y_hr_pred_onehot, y_hr_onehot)
            remote_accuracy, city_accuracy, c1, c2 =torch.zeros(1),torch.zeros(1),torch.zeros(1),torch.zeros(1) # loss_functions.multi_accuracy(self.args.hr_nclasses, rate=0.6)(y_hr_pred_onehot, y_hr_onehot)
            inter, union = torch.zeros(1),torch.zeros(1) #loss_functions.edge_inter_union()(y_hr_pred_onehot, y_hr_onehot)

            # determine the loss to be optimized
            if self.args.net_model in ["unet-m", "unet-m2"]:
                _loss_func = loss_functions.multihead_loss()
                loss_to_use = _loss_func(outputs, y_hr, self.writer, epoch, task == "validation" and batch_idx == 0) * 0.25 + crossentropy_loss * 1
            elif self.args.net_model == "unet-feature-supervised-triplet":
                y_aps = self.argmax_pooling(y_hr)
                _loss_func = loss_functions.feature_triplet_center_loss()
                loss_to_use = _loss_func(outputs, y_aps, self.writer, epoch, task == "validation" and batch_idx == 0) * 0.2 + \
                    crossentropy_loss * 1
            elif self.args.net_model.find("feature") != -1:
                a_outputs = model(self.anchors)
                _loss_func = loss_functions.feature_supervised_loss()
                loss_to_use = _loss_func(outputs, a_outputs, y_hr, self.writer, epoch, task == "validation" and batch_idx == 0)
            elif self.args.net_model in ["NestedUNet", "unet-edge"]:
                # _loss_func = loss_functions.edge_loss()
                # loss_to_use = _loss_func(y_hr_pred, y_hr) *10+ crossentropy_loss * 1
                loss_to_use = crossentropy_loss * 1  # + loss_functions.edge_loss()(y_hr_pred, y_hr)*10
            elif self.args.net_model == "unet-multihead-edge-att":
                _loss_func = loss_functions.multihead_edge_loss()
                loss_to_use = _loss_func(outputs, y_hr)
            elif self.args.net_model in ["deeplab-att-101"]:
                # loss_to_use = crossentropy_loss
                loss_to_use = crossentropy_loss + att_s / 33 * 0.1
            elif self.args.net_model in ["deeplab-att-34"]:
                # loss_to_use = crossentropy_loss
                loss_to_use = crossentropy_loss + att_s / 16 * 0.1
            elif self.args.net_model in ["deeplab-att-pix-34"]:
                loss_to_use = crossentropy_loss + att_s / 16 * 0.1
            elif self.args.net_model in ["deeplab-att-begining-34"]:
                loss_to_use = crossentropy_loss + att_s * 0.1
            elif self.args.net_model in ["dynamic-routing"]:
                loss_to_use = crossentropy_loss + budget_loss.mean()
            elif self.args.net_model.startswith("deeplab-101-nested-fixed-gate"):
                loss_to_use = crossentropy_loss + \
                    (
                        (nested_gate_loss.mean() - self.args.nested_budget).clamp(min=0) ** 2 +
                        (bb_gate_loss.mean() - self.args.backbone_budget).clamp(min=0) ** 2
                    ) * self.args.budget_loss_ratio * (1 if (epoch + 1 >= self.args.freeze_until) else 0)
            elif self.args.net_model == "unet-topoaware":
                loss_to_use = crossentropy_loss + 0.1 * l_top.mean()
            elif self.args.net_model == "unet-nlpl":
                loss_to_use = loss_funcs[0](y_hr_pred, y_hr, epoch)
            elif self.args.net_model == "UNetScale":
                # l2 = F.cross_entropy(F.interpolate(pred2, scale_factor=2, mode='bilinear', align_corners=True), y_hr)
                # l3 = F.cross_entropy(F.interpolate(pred3, scale_factor=4, mode='bilinear', align_corners=True), y_hr)
                # l4 = F.cross_entropy(F.interpolate(pred4, scale_factor=8, mode='bilinear', align_corners=True), y_hr)
                # l5 = F.cross_entropy(F.interpolate(pred5, scale_factor=16, mode='bilinear', align_corners=True), y_hr)
                # loss_to_use = crossentropy_loss * 1 + (l2 + l3 + l4 + l5) * 0.2
                # y = F.avg_pool2d(y_hr_onehot.float(), kernel_size=2, stride=2)
                y = F.avg_pool2d(y_hr_onehot.float(), kernel_size=2, stride=2)
                l2 = F.kl_div(F.log_softmax(pred2, dim=1), y, reduction='none').sum(dim=1).mean()
                y = F.avg_pool2d(y, kernel_size=2, stride=2)
                l3 = F.kl_div(F.log_softmax(pred3, dim=1), y, reduction='none').sum(dim=1).mean()
                y = F.avg_pool2d(y, kernel_size=2, stride=2)
                l4 = F.kl_div(F.log_softmax(pred4, dim=1), y, reduction='none').sum(dim=1).mean()
                y = F.avg_pool2d(y, kernel_size=2, stride=2)
                l5 = F.kl_div(F.log_softmax(pred5, dim=1), y, reduction='none').sum(dim=1).mean()
                loss_to_use = crossentropy_loss * 1 + (l2 + l3 + l4 + l5) * 0.2
            elif self.args.net_model == "UNetBoost":
                loss_to_use = loss_boost
                (
                    (nested_gate_loss.mean() - self.args.nested_budget).clamp(min=0) ** 2 +
                    (bb_gate_loss.mean() - self.args.backbone_budget).clamp(min=0) ** 2
                ) * self.args.budget_loss_ratio * (1 if (epoch + 1 >= self.args.freeze_until) else 0)
            elif self.args.net_model in ["PointRend"]:
                if task == 'training':
                    seg_loss = crossentropy_loss
                    gt = y_hr
                    gt_points = models.point_sample(
                        gt.float().unsqueeze(1),
                        outputs["points"],
                        mode="nearest",
                        align_corners=False
                    ).squeeze_(1).long()

                    points_loss = F.cross_entropy(outputs["rend"], gt_points, ignore_index=255)

                    loss_to_use = seg_loss + points_loss
                else:
                    loss_to_use = crossentropy_loss
            elif self.args.net_model.startswith('UNet_boost') or self.args.net_model.startswith('DeepLabv3p_boost'):
                loss_to_use = loss
            elif self.args.net_model.startswith('UNet_gate'):
                loss_to_use = loss
            else:
                # loss_to_use = crossentropy_loss #loss_functions.focal_cross_entropy(y_hr_pred, y_hr,gamma=2)
                loss_to_use = crossentropy_loss  # + loss_functions.focal_cross_entropy(y_hr_pred, y_hr,gamma=2)    #+loss_functions.edge_loss()(y_hr_pred, y_hr)*10
                # loss_to_use = loss_functions.layer_edge_cross_entropy(y_hr_pred, y_hr,ks = 5,rate = 0.5)
            # _loss_func = loss_functions.adaptive_edge_loss()
            # loss_to_use = _loss_func(x,y_hr_pred, y_hr) #+crossentropy_loss

            if task == "training":
                if self.args.net_model == "UNetWeight":
                    w_m = 4
                    w_0 = 0.8
                    budget = 0.1
                    pretrain = 10
                    if epoch + 1 <= pretrain:
                        optimizer.zero_grad()
                        crossentropy_loss.backward()
                        optimizer.step()
                        y_hr_pred, w = model(x, y_hr_onehot)
                        # w = self.model_w(torch.cat((x, y_hr_onehot.float()), dim=1))
                        # w = w.sigmoid()
                        loss_to_use = (F.cross_entropy(y_hr_pred, y_hr, reduction='none') * (w * w_m + w_0).squeeze(dim=1)).mean() - (w.mean() - budget) ** 2 * w_m
                        self.opt_w.zero_grad()
                        loss_to_use.backward()
                        self.opt_w.step()
                    elif epoch + 1 > pretrain:
                        optimizer.zero_grad()
                        self.opt_w.zero_grad()
                        loss_to_use = (F.cross_entropy(y_hr_pred, y_hr, reduction='none') * (w * w_m + w_0).squeeze(dim=1)).mean() - (w.mean() - budget) ** 2 * w_m
                        loss_to_use.backward()
                        optimizer.step()
                        self.opt_w.step()
                        # y_hr_pred = model(x)
                        # w = self.model_w(torch.cat((x, y_hr_onehot.float()), dim=1)).sigmoid()
                        # loss_to_use = (F.cross_entropy(y_hr_pred, y_hr, reduction='none') * (w * w_m + w_0).squeeze(dim=1)).mean() - (budget - w.mean())**2
                        # self.opt_w.zero_grad()
                        # (loss_to_use * -1).backward()
                        # self.opt_w.step()

                else:
                    optimizer.zero_grad()
                    loss_to_use.backward()
                    optimizer.step()
                    self.lr_scheduler.step()

            # print batch log in console
            log_str = f"Batch: {(batch_idx + 1):4d} / {batch_num:4d}    " \
                f"Loss: {loss_to_use.item():.3f}    " \
                f"CrossEntropy Loss: {crossentropy_loss.item():.3f}    " \
                f"Jaccard: {jaccard.item():.3f}    " \
                f"Accuracy: {accuracy.item():.3f}   " \
                f"City-Jaccard: {city_jaccard.item():.3f}    " \
                f"City-Accuracy: {city_accuracy.item():.3f}   " \
                f"Remote-Jaccard: {remote_jaccard.item():.3f}    " \
                f"Remote-Accuracy: {remote_accuracy.item():.3f}"

            # if task != "final":
            self.logger.info(log_str)

            # store loss records
            ep_records["loss_to_use"].append(loss_to_use.cpu().item())
            ep_records["crossentropy"].append(crossentropy_loss.cpu().item())
            ep_records["jaccard"].append(jaccard.cpu().item())
            ep_records["accuracy"].append(accuracy.cpu().item())
            ep_records["remote_jaccard"].append(remote_jaccard.item())
            ep_records["city_jaccard"].append(city_jaccard.item())
            ep_records["remote_accuracy"].append(remote_accuracy.item())
            ep_records["city_accuracy"].append(city_accuracy.item())
            ep_records["remote_counter"].append(remote_counter.item())
            ep_records["city_counter"].append(city_counter.item())
            ep_records["edge_inter"] += inter.cpu().item()
            ep_records["edge_union"] += union.cpu().item()
            # if task == "validation":
            #     cm = confusion_matrix(y_hr.flatten().cpu(), y_hr_pred_am.flatten().cpu(), labels=[1, 2, 3, 4]).astype(np.long)
            #     ep_records["confusion_matrix"] += cm

            # # training minibatch tensorboard log
            # if task == "training":
            #     global_step = epoch * self.args.training_steps_per_epoch + batch_idx + 1
            #     self.save_tensorboard_scalars(loss_to_use.item(), crossentropy_loss.item(), superres_loss.item(),
            #                                   jaccard.item(), accuracy.item(), f"{task}-batch", global_step)

            # save tensorboard images in the first batch of each evaluation epoch
            if batch_idx == 0 and epoch % (max((self.args.epochs)//10, 1)) == 0:  # task == "validation" and
                if self.args.net_model == "unet-edge-implicit":
                    self.save_tensorboard_images_edge(x, y_hr, y_hr_pred, edge, area, epoch)
                elif self.args.net_model == "unet-arf":
                    self.save_tensorboard_images_arf(x, y_hr, y_hr_pred, c, epoch)
                elif self.args.net_model == "UNetWeight":
                    self.save_tensorboard_images_weight(x, y_hr, y_hr_pred, w, epoch)
                elif self.args.net_model in ["UNetScale", "ResUNetScale"]:
                    self.save_tensorboard_images_scale(x, y_hr, y_hr_pred, [pred2, pred3, pred4, pred5], epoch)
                elif self.args.net_model in ["UNetBoost", "ResUNetBoost"]:
                    self.save_tensorboard_images_scale(x, y_hr, y_hr_pred, [o[3], o[2], o[1], o[0]], epoch)
                else:
                    self.save_tensorboard_images(x, y_hr, y_hr_pred, epoch, task)
            # # visualize the attention in the first batch of each evaluation epoch
            # if self.args.net_model.endswith("-v") and task == "validation" and batch_idx == 0:
            #     # self.save_tensorboard_attention(x, y_hr, x_before, x_after, att, epoch)

            # terminate the epoch if reach the end of dataset or the given maximum steps
            if batch_idx + 1 >= batch_num:
                break

        # summarize epoch records
        ep_records["end_time"] = get_datetime_str()
        m_loss_to_use = np.mean(ep_records["loss_to_use"])
        m_crossentropy = np.mean(ep_records["crossentropy"])
        m_jaccard = np.mean(ep_records["jaccard"])
        m_accuracy = np.mean(ep_records["accuracy"])
        s_city_counter = np.sum(ep_records["city_counter"])
        s_remote_counter = np.sum(ep_records["remote_counter"])
        m_remote_jaccard = np.sum(np.array(ep_records["remote_jaccard"])*np.array(ep_records["remote_counter"]))/np.sum(ep_records["remote_counter"])
        m_city_jaccard = np.sum(np.array(ep_records["city_jaccard"])*np.array(ep_records["city_counter"]))/np.sum(ep_records["city_counter"])
        m_remote_accuracy = np.sum(np.array(ep_records["remote_accuracy"])*np.array(ep_records["remote_counter"]))/np.sum(ep_records["remote_counter"])
        m_city_accuracy = np.sum(np.array(ep_records["city_accuracy"])*np.array(ep_records["city_counter"]))/np.sum(ep_records["city_counter"])
        print(ep_records["edge_inter"], ep_records["edge_union"])

        log_str_1 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
            f'''Start at {ep_records["start_time"]}    ''' \
            f'''End at {ep_records["end_time"]}'''
        log_str_2 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
            f"Loss: {m_loss_to_use:.4f}    " \
            f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
            f"Jaccard: {m_jaccard:.4f}    " \
            f"Accuracy: {m_accuracy:.4f}   " \
            f"m_city_jaccard: {m_city_jaccard:.4f}    " \
            f"m_city_accuracy: {m_city_accuracy:.4f}   " \
            f"m_remote_jaccard: {m_remote_jaccard:.4f}    " \
            f"m_remote_accuracy: {m_remote_accuracy:.4f}   "

        self.logger.info(log_str_1)
        self.logger.info(log_str_2)
        save_str = f'''Epoch {epoch + 1:3d} - {ep_records["start_time"]} - {ep_records["end_time"]} - ''' \
            f"Loss: {m_loss_to_use:.4f}    " \
            f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
            f"Jaccard: {m_jaccard:.4f}    " \
            f"Accuracy: {m_accuracy:.4f}    "\
            f"Remote Jaccard: {m_remote_jaccard:.4f}    " \
            f"City Jaccard: {m_city_jaccard:.4f}    " \
            f"Remote Accuracy: {m_remote_accuracy:.4f}    " \
            f"City Accuracy: {m_city_accuracy:.4f}\n"
        save_epoch_summary(self.args, save_str)

        # save epoch summary to tensorboard
#         if task != 'final':
        self.save_tensorboard_scalars(m_loss_to_use, m_crossentropy, m_jaccard, m_accuracy,
                                      m_remote_jaccard, m_city_jaccard, m_remote_accuracy, m_city_accuracy,
                                      f"{task}-epoch",
                                      epoch + 1)

        # if task == "validation":
        #     self.save_tensorboard_cm(ep_records["confusion_matrix"], epoch)
        if hasattr(model.module, "save_tensorboard"):
            model.module.save_tensorboard = False

        return m_jaccard

    def save_tensorboard_scalars(self, loss_to_use, crossentropy_loss, jaccard, accuracy,
                                 remote_jaccard, city_jaccard, remote_accuracy, city_accuracy,
                                 task, step):
        self.writer.add_scalar(f"{task}/loss", loss_to_use, step)
        self.writer.add_scalar(f"{task}/crossentropy_loss", crossentropy_loss, step)
        self.writer.add_scalar(f"{task}/jaccard", jaccard, step)
        self.writer.add_scalar(f"{task}/accuracy", accuracy, step)

        self.writer.add_scalar(f"{task}/3remote_jaccard", remote_jaccard, step)
        self.writer.add_scalar(f"{task}/3city_jaccard", city_jaccard, step)
        self.writer.add_scalar(f"{task}/3remote_accuracy", remote_accuracy, step)
        self.writer.add_scalar(f"{task}/3city_accuracy", city_accuracy, step)

    def save_tensorboard_images(self, x, y, y_pred, epoch, task='validation'):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
        else:
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error), axis=1)  # B*4*240*240*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*240*4*240*3
        img_concat = img_concat.reshape((-1, 4*W, 3))  # 240xB * 240x4 * 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_scale(self, x, y, y_pred, preds, epoch):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
        else:
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        preds = [i.softmax(dim=1).topk(k=2, dim=1)[0] for i in preds]
        preds = [i[:, 0:1, :, :].repeat(1, 3, 1, 1) for i in preds]
        # preds = [(i[:, 0:1, :, :] - i[:, 1:2, :, :]).repeat(1,3,1,1) for i in preds]
        pred2, pred3, pred4, pred5 = preds
        pred2 = F.interpolate(pred2, scale_factor=2, mode='nearest')
        pred3 = F.interpolate(pred3, scale_factor=4, mode='nearest')
        pred4 = F.interpolate(pred4, scale_factor=8, mode='nearest')
        pred5 = F.interpolate(pred5, scale_factor=16, mode='nearest')
        c = [pred2, pred3, pred4, pred5]

        RF = [i.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for i in c]
        # c = torch.cat((c[0], c[1], c[2], c[3]), dim=0)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, RF[3], RF[2], RF[1], RF[0]), axis=1)  # B*8*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*8*W*3
        img_concat = img_concat.reshape((-1, 8*W, 3))  # 240xB * 240x8 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')
        # self.writer.add_histogram(f"validation-RF-value", c, epoch + 1)
        # self.writer.add_histogram(f"validation-RF-argmax", c.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_arf(self, x, y, y_pred, c, epoch):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
        else:
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        RF = [i.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for i in c]
        c = [i.unsqueeze(dim=0) for i in c]
        c = torch.cat((c[0], c[1], c[2], c[3]), dim=0)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, RF[0], RF[1], RF[2], RF[3]), axis=1)  # B*8*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*8*W*3
        img_concat = img_concat.reshape((-1, 8*W, 3))  # 240xB * 240x8 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')
        self.writer.add_histogram(f"validation-RF-value", c, epoch + 1)
        self.writer.add_histogram(f"validation-RF-argmax", c.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_weight(self, x, y, y_pred, w, epoch):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
        else:
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        w = w.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).unsqueeze(dim=1)
        w = w.detach().cpu().numpy()  # bs, 1, H, W

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, w), axis=1)  # B*5*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*5*W*3
        img_concat = img_concat.reshape((-1, 5*W, 3))  # 240xB * 240x5 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_edge(self, x, y, y_pred, edge, area, epoch):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred[:, 1:, :, :].argmax(dim=1).detach().cpu()
        edge = F.sigmoid(edge.detach()).cpu()
        area = area[:, 1:, :, :].argmax(dim=1).detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        label_color_map = np.array([color_map[i - 1] for i in y])
        pred_color_map = np.array([color_map[i] for i in y_pred])
        edge = edge.squeeze().numpy()[:, :, :, np.newaxis].repeat(3, axis=-1)
        area_color_map = np.array([color_map[i] for i in area])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        error_mask[:, :, :, 0][y_pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)
        edge = np.expand_dims(edge, axis=1)
        area_color_map = np.expand_dims(area_color_map, axis=1)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, edge, area_color_map), axis=1)  # B*6*240*240*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*240*6*240*3
        img_concat = img_concat.reshape((-1, 6*W, 3))  # 240xB * 240x4 * 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_layer(self, xs, epoch, task='validation'):
        '''
        xs: list
        '''
        xs = [i[:, :3] for i in xs]
        x = torch.stack(xs, dim=1).cpu().numpy()
        B, L, C, H, W = x.shape
        x = (x-np.min(x, axis=(3, 4), keepdims=True))
        x = (x-np.max(x, axis=(3, 4), keepdims=True))

        x = x.transpose(0, 2, 3, 1, 4).reshape(B*C*H, L*W, 1)
        self.writer.add_image(f"{task}-layer visualize", x, epoch + 1, dataformats='HWC')

    def save_tensorboard_onlyatt(self, att, epoch, task='validation'):
        # visualize attention histogram
        '''
        att: B*mh*ks2*H*W
        '''
        B, mh, _, H, W = att.shape
        att = att[:, :1].detach().cpu().numpy()  # select the first attention column
        att_max = np.argmax(att, axis=2)
        ks = int(att.shape[2]**0.5)
        att = att.reshape(B, ks, ks, H, W)
        att = att.transpose(0, 3, 1, 4, 2)
        att = att.reshape(B*H*ks, W*ks, 1)

        self.writer.add_histogram(f"{task}-attention-value", att, epoch + 1)
        self.writer.add_histogram(f"{task}-attention-argmax", att_max + 1, epoch + 1, bins='auto')
        self.writer.add_image(f"{task}-attention-kernel", att, epoch + 1, dataformats='HWC')

    def save_tensorboard_attention(self, x, y, x_before, x_after, att, epoch, task='validation'):

        def normalize(t):
            t = t - t.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0]
            t = t / t.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]

            return t

        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).unsqueeze(dim=1).cpu().numpy()  # B*1*240*240*3
        label_color_map = np.array([color_map[i - 1] for i in y.unsqueeze(dim=1).cpu()])  # B*1*240*240*3
        img_label = np.concatenate((images, label_color_map), axis=2)  # B*1*480*240*3

        x_before = normalize(x_before).unsqueeze(dim=-1)  # B*4*240*240*1
        x_after = normalize(x_after).unsqueeze(dim=-1)  # B*4*240*240*1
        x_before_after = torch.cat((x_before, x_after), dim=2).repeat(1, 1, 1, 1, 3).cpu().numpy()  # B*4*480*240*3

        img_concat = np.concatenate((img_label, x_before_after), axis=1)  # B*5*480*240*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*480*5*240*3
        img_concat = img_concat.reshape((-1, 5*W, 3))  # 480xB * 240x5 * 3

        self.writer.add_image(f"{task}-attention", img_concat, epoch + 1, dataformats='HWC')

        # visualize attention histogram
        self.save_tensorboard_onlyatt(att, epoch=epoch, task=task)

    def save_tensorboard_cm(self, cm, epoch):
        """Plot a row-normalised confusion matrix and log it as a figure.

        Each row of ``cm`` is divided by its sum plus 1e-6, so an all-zero
        row stays zero instead of becoming NaN. The result is drawn with
        ``ConfusionMatrixDisplay`` using ``self.label_names`` as tick labels
        and written as ``"validation-confusion-matrix"`` at step
        ``epoch + 1``.
        """
        row_totals = np.sum(cm, axis=1, keepdims=True) + 1e-6
        display = ConfusionMatrixDisplay(confusion_matrix=cm / row_totals,
                                         display_labels=self.label_names)
        display.plot(cmap=plt.cm.Blues, values_format=".3g")
        figure = display.figure_
        figure.set_dpi(150)
        figure.set_size_inches(4, 4)
        self.writer.add_figure(f"validation-confusion-matrix", figure, epoch + 1)

    def save_model(self, epoch, name, model, optimizer):
        # args.output_path is a Path object
        path = (self.args.output_path /
                Path(f"model_state_dict_epoch_{epoch + 1}.tar"))
        d = {
            'epoch': epoch,
            'name': name,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        if hasattr(self, "anchors"):
            d['anchors'] = self.anchors
        torch.save(d, path.__str__())
        self.logger.info(f"Model checkpoint saved at {str(path)}")

    def final_evaluate(self, model, loss_funcs, dataloader):
        """Rebuild the model from a saved checkpoint and run one "final" pass.

        The incoming ``model`` is discarded. A new network from
        ``self.get_model()`` is wrapped in ``DataParallel``, loaded from
        ``<output_path>/model_state_dict_epoch_0.tar`` and moved to
        ``self.device``. ``self.argmax_pooling``, if present, is also wrapped
        and moved. ``one_epoch`` then runs once over ``dataloader`` in eval
        mode under ``torch.no_grad()``, with no optimizer, epoch ``-1`` and
        task ``"final"``. Returns None.

        NOTE(review): the checkpoint name is hard-coded to ``epoch_0``, while
        ``save_model`` writes ``epoch_{epoch + 1}``. Presumably the chosen
        model is saved elsewhere with ``epoch=-1``; confirm.
        """
        # Drop this frame's reference to the old model before building a new one.
        del model

        restored = torch.nn.DataParallel(self.get_model())
        ckpt_path = self.args.output_path / Path("model_state_dict_epoch_0.tar")
        checkpoint = torch.load(ckpt_path)
        restored.load_state_dict(checkpoint["model_state_dict"])
        restored.to(self.device)

        if hasattr(self, "argmax_pooling"):
            self.argmax_pooling = torch.nn.DataParallel(self.argmax_pooling)
            self.argmax_pooling.to(self.device)

        banner = '=' * 50
        self.logger.info(banner)
        self.logger.info(f"              Final      Evaluation")
        self.logger.info(banner)

        batch_num = len(dataloader)
        with torch.no_grad():
            restored.eval()
            self.one_epoch(restored, None, loss_funcs, dataloader,
                           self.device, -1, batch_num, "final")
