import matplotlib.pyplot as plt
from numpy.core.numeric import cross
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.functional import align_tensors
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch
import numpy as np
import gc
from pathlib import Path
import loss_functions
import models
from dataset import get_datasets
from dataset_hdf5 import get_hdf5_datasets
from dataset_city import get_city_datasets
from dataset_isaid import get_isaid_datasets
from dataset_cityscapes import Cityscapes
from utils import get_logger, save_epoch_summary, get_datetime_str_simplified, get_datetime_str,\
     MyTransform, Sync_DataParallel,BalancedDataParallel,execute_replication_callbacks
from utils_torch import PolynomialLRDecay,confusion_matrix_gpu,IOU_cal
from torchvision.transforms.transforms import CenterCrop
from models.sync_batchnorm.replicate import patch_replication_callback
import sys
sys.path.append("src/models/DynamicRouting")

def confusion_matrix_gpu(y_true, y_pred, num_classes):
    """Compute a ``num_classes x num_classes`` confusion matrix on ``y_pred``'s device.

    ``cm[t, p]`` counts the elements whose true label is ``t`` and whose predicted
    label is ``p`` (rows = ground truth, columns = prediction).

    Args:
        y_true: integer tensor of ground-truth labels; flattened before counting.
        y_pred: integer tensor of predicted labels with the same number of elements
            as ``y_true``; its device determines where the matrix is built.
        num_classes: number of classes (size of each matrix dimension).

    Returns:
        A float64 tensor of counts on ``y_pred.device``. Float64 keeps counts exact
        far beyond 2**24, where float32 accumulation would start rounding. Pairs
        whose true or predicted label lies outside ``[0, num_classes)`` are ignored,
        matching how sklearn's ``confusion_matrix`` drops labels it does not know.
    """
    device = y_pred.device
    y_true = y_true.reshape(-1).to(device=device, dtype=torch.long)
    y_pred = y_pred.reshape(-1).to(dtype=torch.long)

    # Drop pairs that cannot index the matrix instead of raising / wrapping around.
    valid = (y_true >= 0) & (y_true < num_classes) & (y_pred >= 0) & (y_pred < num_classes)

    # Encode each (true, pred) pair as one flat index and count them all in one pass.
    flat_index = y_true[valid] * num_classes + y_pred[valid]
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).to(torch.float64)



class Train:
    def __init__(self, args):
        """Keep the run configuration and set up TensorBoard logging and the logger.

        Args:
            args: run configuration. ``args.log_dir + args.name`` (plain string
                concatenation) is the TensorBoard log directory, and
                ``args.hr_nclasses`` defines the list of class indices.
        """
        self.args = args
        tensorboard_dir = args.log_dir + args.name
        self.writer = SummaryWriter(log_dir=tensorboard_dir)
        self.logger = get_logger(__name__, args)
        self.label_names = ["Water", "Forest", "Field", "Others"]
        self.labels = list(range(args.hr_nclasses))

    def run(self):
        """Build data, model and optimizer, then train and validate for ``args.epochs`` epochs.

        Dataset selection follows ``args.cityscapes`` / ``args.city`` / ``args.isaid``
        (falling back to ``get_datasets``). Model weights are restored from
        ``args.resume_model`` when given; ``args.resume`` additionally restores the
        optimizer state and continues from the checkpoint's epoch. Whenever validation
        mIoU improves, the model is saved via ``save_model(-1, ...)``. After the loop
        one more validation pass and ``final_evaluate`` are run.

        Raises:
            ValueError: if ``args.resume`` is set without ``args.resume_model``, since
                there is then no checkpoint to restore the optimizer and epoch from.
        """
        # Fail fast: resuming needs the checkpoint that only resume_model provides.
        if self.args.resume and not self.args.resume_model:
            raise ValueError(
                "args.resume is set but args.resume_model is empty: a checkpoint is "
                "required to restore the optimizer state and epoch.")

        # prepare datasets and dataloaders for training and validation
        if self.args.cityscapes:
            ds_path = '/scratch/forest/datasets/cityscapes'
            self.training_ds = Cityscapes(root=ds_path, split='train', target_type='semantic', transforms=MyTransform(size=768, random_crop=True))
            self.validation_ds = Cityscapes(root=ds_path, split='val', target_type='semantic', transforms=MyTransform(size=0, random_crop=False))
            print('Train and Validation data load successful!!!')
        elif self.args.city:
            self.training_ds, self.validation_ds = get_city_datasets(self.args)
        elif self.args.isaid:
            self.training_ds, self.validation_ds = get_isaid_datasets(self.args)
        else:
            self.training_ds, self.validation_ds = get_datasets(self.args)

        training_dl = torch.utils.data.DataLoader(self.training_ds, batch_size=self.args.batch_size, shuffle=True,
                                                  num_workers=self.args.workers, drop_last=True)
        # Validation is never shuffled. Cityscapes validates one image per batch
        # (presumably because MyTransform(size=0, random_crop=False) keeps full-size images).
        val_batch_size = 1 if self.args.cityscapes else self.args.batch_size
        validation_dl = torch.utils.data.DataLoader(self.validation_ds, batch_size=val_batch_size, shuffle=False,
                                                    num_workers=self.args.workers, drop_last=True)

        # device used to run experiments
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device
        model = Sync_DataParallel(self.get_model())

        print("Let's use ", torch.cuda.device_count(), " GPUs!")

        checkpoint = None
        if self.args.resume_model:
            # map_location lets a checkpoint saved on GPU be restored on the current device.
            checkpoint = torch.load(self.args.resume_model, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])

        model.to(device)

        # construct an optimizer and a learning rate scheduler
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=self.args.learning_rate,  # 1e-2
                                    momentum=0.9, weight_decay=0.0005)

        if self.args.resume:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # The scheduler is stepped once per training batch in one_epoch, hence
        # epochs * batches-per-epoch decay steps.
        # NOTE(review): its state is not restored on resume, so decay restarts from step 0.
        self.lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=self.args.epochs*min(len(training_dl), self.args.training_steps_per_epoch),
                                              end_learning_rate=1e-4, power=0.9)

        # train the model for given epochs
        epoch = (checkpoint['epoch'] + 1) if self.args.resume else 0
        best_miou = 0
        while epoch < self.args.epochs:
            self.train_one_epoch(model, optimizer, training_dl, device, epoch)
            miou = self.evaluate(model, validation_dl, device, epoch)

            if miou > best_miou:
                best_miou = miou
                self.save_model(-1, self.args.name, model, optimizer)  # best model saved as model_state_dict_epoch_0.tar

            epoch += 1
        # One more validation pass after the last epoch (logged under epoch index `epochs`).
        self.evaluate(model, validation_dl, device, epoch)

        try:
            self.final_evaluate(model, validation_dl, device)
        except Exception as e:
            # Best effort: a failing final evaluation must not discard the trained run.
            print("Final Evaludaiton Failed.")
            print('Exception: ', e)

        # Close the writer only after the final evaluation, so anything logged there
        # still goes to this run's event file instead of a freshly reopened one.
        self.writer.flush()
        self.writer.close()

    def get_model(self):
        """Instantiate the network named by ``args.net_model``.

        ``'unet'`` maps to ``models.UNet`` (with ``out_ch``). Any other name is
        looked up as an attribute of the ``models`` package and constructed with
        ``in_ch`` / ``num_classes``.
        """
        cfg = self.args
        if cfg.net_model == 'unet':
            return models.UNet(in_ch=cfg.input_nchannels, out_ch=cfg.hr_nclasses)

        model_cls = getattr(models, cfg.net_model)
        return model_cls(in_ch=cfg.input_nchannels, num_classes=cfg.hr_nclasses)

    def train_one_epoch(self, model, optimizer, dataloader, device, epoch):
        """Train ``model`` for one epoch and save a checkpoint periodically.

        At most ``args.training_steps_per_epoch`` batches are processed. If that limit
        or the dataloader length is zero, nothing is trained and nothing is saved.
        A checkpoint is written after every ``args.model_save_checkpoint``-th epoch.
        """
        banner = '=' * 50
        for line in (banner, f"              Epoch {epoch + 1:3d}      Training", banner):
            self.logger.info(line)

        steps = min(len(dataloader), self.args.training_steps_per_epoch)
        if not steps:
            return

        model.train()
        self.one_epoch(model, optimizer, dataloader, device, epoch, steps, "training")

        is_checkpoint_epoch = (epoch + 1) % self.args.model_save_checkpoint == 0
        if is_checkpoint_epoch:
            self.save_model(epoch, self.args.name, model, optimizer)

    def evaluate(self, model, dataloader, device, epoch):
        """Run one validation epoch with gradients disabled.

        Returns:
            The value ``one_epoch`` returns for the ``"validation"`` task (its mean
            IoU), or ``0`` when ``args.validation_steps_per_epoch`` or the dataloader
            length is zero.
        """
        banner = '=' * 50
        self.logger.info(banner)
        self.logger.info(f"              Epoch {epoch + 1:3d}      Validation")
        self.logger.info(banner)

        steps = min(len(dataloader), self.args.validation_steps_per_epoch)
        if not steps:
            return 0

        model.eval()
        with torch.no_grad():
            return self.one_epoch(model, None, dataloader, device, epoch, steps, "validation")

    def one_epoch(self, model, optimizer, dataloader, device, epoch, batch_num, task):
        """Run one pass over ``dataloader`` for training, validation or final evaluation.

        For every batch this:
        - forwards the model;
        - builds the prediction, either directly (plain models) or by merging the
          per-stage outputs of ``*_boost`` models according to the
          ``args.name_prefix`` variant;
        - accumulates a confusion matrix and per-batch metrics;
        - when ``task == "training"``, runs the optimizer and LR-scheduler step.

        At the end it logs an epoch summary, passes it to ``save_epoch_summary`` and
        writes TensorBoard scalars.

        Args:
            model: wrapped network (``model.module`` is accessed). Boost models expose
                per-stage confidence thresholds as ``model.module.ts``.
            optimizer: optimizer stepped when ``task == "training"``; may be ``None``
                otherwise.
            dataloader: iterable of ``(x, y_hr)`` batches.
            device: device for labels during training. When more than one GPU is
                visible, non-training tasks switch to ``cuda:1`` instead.
            epoch: zero-based epoch index, used for logging and TensorBoard steps.
            batch_num: maximum number of batches to process.
            task: ``"training"``, ``"validation"`` or ``"final"``.

        Returns:
            Mean IoU computed by ``IOU_cal`` from the accumulated confusion matrix.

        Raises:
            ValueError: if a boost model is combined with a ``name_prefix`` that does
                not end with ``'0_'``, ``'1_'``, ``'2_'`` or ``'3_'``.
        """
        assert task == "training" or task == "validation" or task == "final"

        # With several GPUs, inputs, labels and metrics live on cuda:1 ("data device");
        # outside training the model outputs are moved there too (see below).
        if torch.cuda.device_count() > 1:
            data_device = torch.device('cuda:1')
            if task != 'training':
                device = data_device
        else:
            data_device = device

        ep_records = {
            "loss_to_use": [], "crossentropy": [], "miou": [], "accuracy": [],
            "start_time": get_datetime_str(),
            "confusion_matrix": torch.zeros((self.args.hr_nclasses, self.args.hr_nclasses)).to(data_device).double(),
        }

        is_boost = self.args.net_model.startswith(('DeepLabv3p_boost', 'UNet_boost'))
        # Per cascade stage: fraction of pixels still unresolved when the stage is reached.
        # Only boost models carry the per-stage thresholds `ts`; plain models have no
        # stages, so nothing is tracked for them. Python floats are stored (not device
        # tensors) so the epoch summary can build a NumPy array from them.
        num_stages = len(model.module.ts) if hasattr(model.module, 'ts') else 0
        visited_counter = [[] for _ in range(num_stages)]

        for batch_idx, (x, y_hr) in enumerate(dataloader):
            if self.args.cityscapes:
                # Normalize 0-255 RGB with per-channel mean/std (the common ImageNet statistics).
                mean = [123.675, 116.28, 103.53]
                std = [58.395, 57.12, 57.375]
                x_n = TF.normalize(x * 255.0, mean, std, inplace=False)
                x_n = x_n.to(data_device)
                outputs = model(x_n)
            else:
                x = x.to(data_device)  # B, C, H, W
                outputs = model(x)
            B, C, H, W = x.shape

            # If not training, all data is gathered on the data device.
            if task != 'training':
                if isinstance(outputs, tuple):
                    outputs = [out.to(data_device) for out in outputs]
                else:
                    outputs = outputs.to(data_device)
                device = data_device

            y_hr = y_hr.to(device).squeeze(dim=1).long().clamp(max=self.args.hr_nclasses - 1, min=0)  # B, H, W
            y_hr_onehot = torch.nn.functional.one_hot(y_hr, self.args.hr_nclasses).permute(0, 3, 1, 2)  # B, nclasses, H, W

            if is_boost:
                if self.args.name_prefix.endswith('0_'):
                    # Variant 0: supervise only the last output, upsampled to the input size.
                    o = outputs
                    y_hr_pred = o[-1]
                    y_hr_pred = F.interpolate(y_hr_pred, size=x.size()[2:], mode='bilinear', align_corners=True)

                    loss = loss_functions.kl_loss(y_hr_pred, y_hr_onehot)
                elif self.args.name_prefix.endswith('1_'):
                    # Variant 1: coarse-to-fine cascade evaluated at full resolution.
                    # - Stage idx is upsampled by scale_f from [16, 8, 4, 2, 1], so stages
                    #   are presumably at 1/16 ... 1/1 of the input size.
                    # - Pixels whose max softmax probability reaches ts[idx] take that
                    #   stage's probabilities; the rest are deferred to the next stage.
                    # - The last stage fills every remaining pixel.
                    o = outputs
                    target = y_hr_onehot
                    not_visited = torch.ones(B,1,H,W).to(o[-1].device)
                    result = torch.zeros(B, o[-1].shape[1], H,W).to(data_device)
                    loss = 0
                    end_idx = len(model.module.ts)-1

                    for idx, scale_f in enumerate([16,8,4,2,1][:end_idx+1]):
                        expand = not_visited.float()
                        # kernel_size=1 makes this max-pool an identity: expand == not_visited as bool.
                        expand = F.max_pool2d(expand.float(), kernel_size=1, stride=1, padding=0).bool()

                        cur = F.interpolate(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
                        cur = F.softmax(cur, dim=1)
                        remaining_frac = (torch.sum(not_visited)/torch.numel(not_visited)).item()
                        visited_counter[idx].append(remaining_frac)
                        if batch_idx % 25 == 0:
                            print(remaining_frac)

                        uncertain = torch.max(cur,dim = 1,keepdim = True)[0]
                        ts = model.module.ts
                        if batch_idx%10==0:
                            # Split confidences of still-unresolved pixels into correct / wrong predictions.
                            sign = torch.argmax(cur,dim=1,keepdim = True)==torch.argmax(y_hr_onehot,dim=1,keepdim=True)
                            correct_dis = uncertain[sign*expand]
                            error_dis = uncertain[(~sign)*expand]
                            if (task == 'training') and len(correct_dis)>10:
                                # EMA update of the stage threshold towards the 30% quantile
                                # of the confidences of correct predictions.
                                model.module.ts[idx] = 0.9*model.module.ts[idx]+0.1*torch.quantile(correct_dis,q= 0.3).detach().cpu().numpy()
                                if idx ==0:
                                    print('The threshold is: ' , model.module.ts)
                            if epoch>10 and len(correct_dis)>10 and len(error_dis)>10:
                                self.writer.add_histogram(f"{task}{idx}-correct_dis", correct_dis, epoch + 1, bins='auto')
                                self.writer.add_histogram(f"{task}{idx}-error_dis", error_dis, epoch + 1, bins='auto')

                        if idx == end_idx:
                            cur_visited = 1.
                            loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        elif  idx < end_idx:
                            cur_visited = uncertain >= (ts[idx].to(uncertain.device))
                            cur_visited = cur_visited.float()

                            loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        # cache for next step feature
                        result += ((not_visited*cur_visited).to(data_device))*(cur.detach().to(data_device))
                        if batch_idx==0 and epoch%10==0:
                            self.save_tensorboard_onehot_images(result,f'{task}:Idx {idx}',epoch)
                        not_visited = not_visited*(1-cur_visited)
                        if idx == end_idx:break
                    y_hr_pred = result
                elif self.args.name_prefix.endswith('2_'):
                    # Variant 2: like variant 1, except that:
                    # - each stage is supervised at its own resolution against an
                    #   average-pooled target;
                    # - its raw outputs (not softmax) are nearest-upsampled into the
                    #   full-resolution result.
                    o = outputs
                    target = y_hr_onehot
                    not_visited = torch.ones(B,1,H,W).to(o[-1].device)
                    result = torch.zeros(B, o[-1].shape[1], H,W).to(data_device)
                    loss = 0
                    end_idx = len(model.module.ts)-1

                    for idx, scale_f in enumerate([16,8,4,2,1][:end_idx+1]):
                        loss_rate = 1
                        expand = not_visited.float()
                        # A coarse cell stays active if any of its pixels is still unresolved.
                        expand = F.avg_pool2d(expand,kernel_size = scale_f,stride = scale_f).bool()
                        cur_target = F.avg_pool2d(target.float(),kernel_size = scale_f,stride = scale_f)
                        cur = o[idx]
                        cur = F.softmax(cur, dim=1)
                        remaining_frac = (torch.sum(not_visited)/torch.numel(not_visited)).item()
                        visited_counter[idx].append(remaining_frac)
                        if batch_idx % 25 == 0:
                            print(remaining_frac)

                        uncertain = torch.max(cur,dim = 1,keepdim = True)[0]
                        ts = model.module.ts
                        if batch_idx%10==0:
                            # Compare against the class whose pooled target exceeds 0.99;
                            # for mixed cells argmax falls back to class 0.
                            sign = torch.argmax(cur,dim=1,keepdim = True)==torch.argmax((cur_target>0.99).float(),dim=1,keepdim=True)
                            correct_dis = uncertain[sign*expand]
                            error_dis = uncertain[(~sign)*expand]

                            if (task == 'training') and len(correct_dis)>10 and idx < end_idx:
                                model.module.ts[idx] = 0.9*model.module.ts[idx]+0.1*torch.quantile(correct_dis,q= 0.3).detach().cpu().numpy()
                                if idx ==0:
                                    print('The threshold is: ' , model.module.ts)
                            if epoch>10 and len(correct_dis)>10 and len(error_dis)>10:
                                self.writer.add_histogram(f"{task}{idx}-correct_dis", correct_dis, epoch + 1, bins='auto')
                                self.writer.add_histogram(f"{task}{idx}-error_dis", error_dis, epoch + 1, bins='auto')

                        if idx == end_idx:
                            cur_visited = 1
                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, cur_target, sign=expand)
                        elif  idx < end_idx:
                            cur_visited = uncertain >= (ts[idx].to(uncertain.device))
                            cur_visited = cur_visited.float()

                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, cur_target, sign=expand)

                            cur_visited = F.interpolate(cur_visited, mode='nearest',scale_factor=scale_f)
                        # cache for next step feature
                        cur = F.interpolate(o[idx], mode='nearest',scale_factor=scale_f)
                        result += ((not_visited*cur_visited).to(data_device))*(cur.detach().to(data_device))

                        if batch_idx==0 and epoch%10==0:
                            self.save_tensorboard_onehot_images(result,f'{task}:Idx {idx}',epoch)
                        not_visited = not_visited*(1-cur_visited)
                        if idx == end_idx:break
                    y_hr_pred = result
                elif self.args.name_prefix.endswith('3_'):
                    # Variant 3: like variant 2, but the per-stage threshold is the 30%
                    # quantile of this batch's confidences instead of the learned ts.
                    o = outputs
                    target = y_hr_onehot
                    not_visited = torch.ones(B,1,H,W).to(o[-1].device)
                    result = torch.zeros(B, o[-1].shape[1], H,W).to(data_device)
                    loss = 0
                    end_idx = len(model.module.ts)-1

                    for idx, scale_f in enumerate([16, 8, 4,2,1][:end_idx+1]):
                        expand = not_visited.float()
                        expand = F.avg_pool2d(expand,kernel_size = scale_f,stride = scale_f).bool()
                        cur_target = F.avg_pool2d(target.float(),kernel_size = scale_f,stride = scale_f)
                        cur = o[idx]
                        cur = F.softmax(cur, dim=1)
                        remaining_frac = (torch.sum(not_visited)/torch.numel(not_visited)).item()
                        visited_counter[idx].append(remaining_frac)
                        if batch_idx % 25 == 0:
                            print(remaining_frac)
                        if idx == end_idx:
                            cur_visited = 1
                            loss += loss_functions.kl_loss_woSOFT(cur, cur_target, sign=expand)
                        elif  idx < end_idx:
                            uncertain = torch.max(cur,dim = 1,keepdim = True)[0]
                            cur_ts = torch.quantile(uncertain,q= 0.3)

                            cur_visited = uncertain >= cur_ts
                            cur_visited = cur_visited.float()

                            loss += loss_functions.kl_loss_woSOFT(cur, cur_target, sign=expand)

                            cur_visited = F.interpolate(cur_visited, mode='nearest',scale_factor=scale_f)
                        # cache for next step feature
                        cur = F.interpolate(o[idx], mode='nearest',scale_factor=scale_f)
                        result += ((not_visited*cur_visited).to(data_device))*(cur.detach().to(data_device))

                        if batch_idx==0 and epoch%10==0:
                            self.save_tensorboard_onehot_images(result,f'{task}:Idx {idx}',epoch)
                        not_visited = not_visited*(1-cur_visited)
                        if idx == end_idx:break
                    y_hr_pred = result
                else:
                    # Without this, `loss` / `y_hr_pred` would be unbound further down.
                    raise ValueError(
                        f"Unsupported name_prefix {self.args.name_prefix!r} for boost model "
                        f"{self.args.net_model!r}: expected it to end with '0_', '1_', '2_' or '3_'.")
            else:
                y_hr_pred = outputs

            y_hr = y_hr.to(data_device)
            y_hr_onehot = y_hr_onehot.to(data_device)

            y_hr_pred = y_hr_pred.to(data_device)
            y_hr_pred_am = y_hr_pred.argmax(dim=1)
            y_hr_pred_onehot = torch.nn.functional.one_hot(y_hr_pred_am, self.args.hr_nclasses).permute(0, 3, 1, 2)

            ep_records['confusion_matrix'] += confusion_matrix_gpu(y_hr.view(-1), y_hr_pred_am.view(-1), num_classes=self.args.hr_nclasses)
            crossentropy_loss = loss_functions.crossentropy_loss(hard_mining=False)(y_hr_pred, y_hr)
            miou = loss_functions.miou(self.args.hr_nclasses)(y_hr_pred_onehot, y_hr_onehot)
            accuracy = loss_functions.accuracy()(y_hr_pred_onehot, y_hr_onehot)

            # determine the loss to be optimized
            if is_boost:
                loss_to_use = loss
            else:
                loss_to_use = crossentropy_loss

            if task == "training":
                optimizer.zero_grad()
                loss_to_use.backward()
                optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            if batch_idx%10==0:
                # print batch log in console
                log_str = f"Batch: {(batch_idx + 1):4d} / {batch_num:4d}    " \
                    f"Loss: {loss_to_use.item():.3f}    " \
                    f"CrossEntropy Loss: {crossentropy_loss.item():.3f}    " \
                    f"mIOU: {miou.item():.3f}    " \
                    f"Accuracy: {accuracy.item():.3f}"
                self.logger.info(log_str)

            # store loss records
            ep_records["loss_to_use"].append(loss_to_use.item())
            ep_records["crossentropy"].append(crossentropy_loss.item())
            ep_records["miou"].append(miou.item())
            ep_records["accuracy"].append(accuracy.item())

            # save tensorboard images in the first batch of every (epochs // 20)-th epoch
            if batch_idx == 0 and epoch % (max((self.args.epochs)//20, 1)) == 0:
                self.save_tensorboard_images(x, y_hr, y_hr_pred, epoch, task)

            # terminate the epoch if reach the end of dataset or the given maximum steps
            if batch_idx + 1 >= batch_num:
                break

        # summarize epoch records
        ep_records["end_time"] = get_datetime_str()
        m_loss_to_use = np.mean(ep_records["loss_to_use"])
        m_crossentropy = np.mean(ep_records["crossentropy"])
        c_matrix = ep_records['confusion_matrix']

        iou,m_iou = IOU_cal(c_matrix)
        # Foreground IoU: the same statistics with class 0 removed.
        fg_iou,fg_m_iou = IOU_cal(c_matrix[1:, 1:])
        m_accuracy = (c_matrix.diag().sum() / c_matrix.sum()).item()

        # Mean fraction of still-unresolved pixels per cascade stage (only when tracked).
        if any(visited_counter):
            print(np.mean(np.array(visited_counter), axis=1))

        log_str_1 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
            f'''Start at {ep_records["start_time"]}    ''' \
            f'''End at {ep_records["end_time"]}'''
        log_str_2 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
            f"Loss: {m_loss_to_use:.4f}    " \
            f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
            f"mIOU: {m_iou:.4f}    " \
            f"Accuracy: {m_accuracy:.4f}    " \
            f"IOU for each category: {iou.data}    " \
            f"fg_m_IOU: {fg_m_iou:.4f}    " \
            f"fg_IOU for each category: {fg_iou.data}"
        self.logger.info(log_str_1)
        self.logger.info(log_str_2)

        save_str = f'''Epoch {epoch + 1:3d} - {ep_records["start_time"]} - {ep_records["end_time"]} - ''' \
            f"Loss: {m_loss_to_use:.4f}    " \
            f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
            f"mIOU: {m_iou:.4f}    " \
            f"Accuracy: {m_accuracy:.4f}\n"
        save_epoch_summary(self.args, save_str)

        # save epoch summary to tensorboard
        self.save_tensorboard_scalars(m_loss_to_use, m_crossentropy, m_iou, m_accuracy, task=f"{task}-epoch", step=epoch + 1)

        return m_iou

    def save_tensorboard_scalars(self, loss_to_use, crossentropy_loss, miou, accuracy, task, step):
        self.writer.add_scalar(f"{task}/loss", loss_to_use, step)
        self.writer.add_scalar(f"{task}/crossentropy_loss", crossentropy_loss, step)
        self.writer.add_scalar(f"{task}/miou", miou, step)
        self.writer.add_scalar(f"{task}/accuracy", accuracy, step)

        # self.writer.add_scalar(f"{task}/3remote_jaccard", remote_jaccard, step)
        # self.writer.add_scalar(f"{task}/3city_jaccard", city_jaccard, step)
        # self.writer.add_scalar(f"{task}/3remote_accuracy", remote_accuracy, step)
        # self.writer.add_scalar(f"{task}/3city_accuracy", city_accuracy, step)

    def save_tensorboard_onehot_images(self,y,name,epoch):
        """Write the per-pixel argmax of ``y`` to TensorBoard as a colour image.

        Args:
            y: score tensor indexed as ``(B, C, H, W)``; the argmax over ``C`` picks
                each pixel's colour.
            name: TensorBoard image tag.
            epoch: global step passed to ``add_image``.

        The ``B`` images are stacked vertically into one ``(B*H, W, 3)`` image.
        ``C == 5`` uses a fixed palette; any other ``C`` uses ``C`` colours built from
        evenly spaced steps of ``255**3 / 20``.
        """
        batch, channels, height, width = y.shape
        if channels == 5:
            palette = np.array([
                [255, 255, 255],
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1],
            ]) / 255.0
        else:
            step = 255*255*255/20
            rows = []
            for k in range(channels):
                code = step * k
                rows.append([code // (255*255), (code // 255) % 255, code % 255])
            palette = np.array(rows) / 255.0

        winners = torch.argmax(y, dim=1)
        onehot = F.one_hot(winners, len(palette)).cpu().float()
        rgb = torch.einsum('bhwc,cp->bhwp', onehot, torch.from_numpy(palette).float())
        image = np.array(rgb).reshape(batch * height, width, -1)
        self.writer.add_image(name, image, epoch, dataformats='HWC')


    def save_tensorboard_images(self, x, y, y_pred, epoch, task='validation'):
        """Log an ``input | label | prediction | error`` strip per batch item to TensorBoard.

        Args:
            x: input batch ``(B, C, H, W)``; the first three channels are shown.
            y: integer label map ``(B, H, W)``.
            y_pred: class scores ``(B, K, H, W)``.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
            task: tag prefix for the ``{task}-label-prediction`` image.

        For Cityscapes / iSAID labels index the palette directly. Otherwise channel 0
        of ``y_pred`` is skipped and labels are shifted by one before colouring and
        comparison. The error panel tints mismatching pixels red over a dimmed input.
        """
        def spread_palette(divisor, count):
            # `count` colours at evenly spaced steps of 255**3 / divisor, scaled to [0, 1].
            step = 255*255*255 / divisor
            rows = []
            for k in range(count):
                code = step * k
                rows.append([code // (255*255), (code // 255) % 255, code % 255])
            return np.array(rows) / 255.0

        if self.args.cityscapes:
            palette = spread_palette(20, 20)
        elif self.args.isaid:
            palette = spread_palette(16, 20)
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1],
            ]) / 255.0

        _, _, _, width = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        rgb_inputs = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        direct_labels = self.args.cityscapes or self.args.isaid
        if direct_labels:
            label_rgb = np.array([palette[lbl] for lbl in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
            mismatch = pred != y
        else:
            label_rgb = np.array([palette[lbl - 1] for lbl in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
            mismatch = pred != (y - 1)
        pred_rgb = np.array([palette[p] for p in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        error_mask[:, :, :, 0][mismatch] = 1
        error_rgb = rgb_inputs * 0.7 + error_mask.numpy() * 0.3

        # Stack the four panels side by side: (B, 4, H, W, 3) -> (B*H, 4*W, 3).
        panels = [np.expand_dims(panel, axis=1) for panel in (rgb_inputs, label_rgb, pred_rgb, error_rgb)]
        grid = np.concatenate(panels, axis=1)
        grid = grid.transpose((0, 2, 1, 3, 4)).reshape((-1, 4 * width, 3))

        self.writer.add_image(f"{task}-label-prediction", grid, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_scale(self, x, y, y_pred, preds, epoch):
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
            # (self.args.epochs) // 10, 1)) == 0:  # task == "validation" and
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        preds = [i.softmax(dim=1).topk(k=2, dim=1)[0] for i in preds]
        preds = [i[:, 0:1, :, :].repeat(1, 3, 1, 1) for i in preds]
        # preds = [(i[:, 0:1, :, :] - i[:, 1:2, :, :]).repeat(1,3,1,1) for i in preds]
        pred2, pred3, pred4, pred5 = preds
        pred2 = F.interpolate(pred2, scale_factor=2, mode='nearest')
        pred3 = F.interpolate(pred3, scale_factor=4, mode='nearest')
        pred4 = F.interpolate(pred4, scale_factor=8, mode='nearest')
        pred5 = F.interpolate(pred5, scale_factor=16, mode='nearest')
        c = [pred2, pred3, pred4, pred5]

        RF = [i.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for i in c]
        # c = torch.cat((c[0], c[1], c[2], c[3]), dim=0)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, RF[3], RF[2], RF[1], RF[0]), axis=1)  # B*8*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*8*W*3
        img_concat = img_concat.reshape((-1, 8*W, 3))  # 240xB * 240x8 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')
        # self.writer.add_histogram(f"validation-RF-value", c, epoch + 1)
        # self.writer.add_histogram(f"validation-RF-argmax", c.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_arf(self, x, y, y_pred, c, epoch):
        """Log a validation panel with four extra maps from ``c`` plus histograms of ``c``.

        Every sample becomes one row of eight W-wide tiles:
        image | label | prediction | error overlay | c[0] | c[1] | c[2] | c[3].

        Args:
            x: input batch (B*C*H*W); only channels 0-2 are displayed.
            y: integer label map (B*H*W). Outside Cityscapes mode labels are
                shifted by -1 before colouring and comparison.
            y_pred: per-class scores (B*K*H*W). Outside Cityscapes mode channel 0
                is excluded from the argmax.
            c: sequence of (at least) four tensors; each is permuted to B*H*W*C
                for display, so presumably B*3*H*W -- TODO confirm.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
        """
        cityscapes = self.args.cityscapes
        if cityscapes:
            # 20 colours from evenly spaced values split into base-255 digits.
            stride = 255*255*255/20
            palette = np.array([
                [n // (255*255), (n // 255) % 255, n % 255]
                for n in (stride * k for k in range(20))
            ]) / 255.0
            shift = 0
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0
            shift = 1

        _, _, _, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        rgb = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        # `shift` both re-bases the labels and drops the leading score channel.
        target = y - shift
        pred = torch.argmax(y_pred[:, shift:, :, :], dim=1)
        label_rgb = np.array([palette[sample] for sample in target])
        pred_rgb = np.array([palette[sample] for sample in pred])

        # Misclassified pixels get a red tint on top of the dimmed input.
        mismatch = torch.zeros(rgb.shape).float()
        mismatch[:, :, :, 0][pred != target] = 1
        error = rgb * 0.7 + mismatch.numpy() * 0.3

        rf_tiles = [t.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for t in c]
        rf_stack = torch.stack((c[0], c[1], c[2], c[3]), dim=0)

        tiles = [a[:, np.newaxis] for a in (rgb, label_rgb, pred_rgb, error)]
        tiles += [rf_tiles[0], rf_tiles[1], rf_tiles[2], rf_tiles[3]]
        panel = np.concatenate(tiles, axis=1)        # B*8*H*W*3
        panel = panel.transpose((0, 2, 1, 3, 4))      # B*H*8*W*3
        panel = panel.reshape((-1, 8*W, 3))           # (B*H) x (8*W) x 3

        step = epoch + 1
        self.writer.add_image("validation-label-prediction", panel, step, dataformats='HWC')
        self.writer.add_histogram("validation-RF-value", rf_stack, step)
        self.writer.add_histogram("validation-RF-argmax", rf_stack.argmax(dim=2) + 1, step, bins='auto')

    def save_tensorboard_images_weight(self, x, y, y_pred, w, epoch):
        """Log a five-tile validation panel: image | label | prediction | error | weight map.

        Args:
            x: input batch (B*C*H*W); only channels 0-2 are displayed.
            y: integer label map (B*H*W). Outside Cityscapes mode labels are
                shifted by -1 before colouring and comparison.
            y_pred: per-class scores (B*K*H*W). Outside Cityscapes mode channel 0
                is excluded from the argmax.
            w: weight map; its channel axis is repeated 3x to form an RGB tile,
                so presumably B*1*H*W -- TODO confirm.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
        """
        if self.args.cityscapes:
            # 20 colours from evenly spaced values split into base-255 digits.
            stride = 255*255*255/20
            palette = np.array([
                [n // (255*255), (n // 255) % 255, n % 255]
                for n in (stride * k for k in range(20))
            ]) / 255.0
            shift = 0
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0
            shift = 1

        _, _, _, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        rgb = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        # `shift` both re-bases the labels and drops the leading score channel.
        target = y - shift
        pred = torch.argmax(y_pred[:, shift:, :, :], dim=1)
        label_rgb = np.array([palette[sample] for sample in target])
        pred_rgb = np.array([palette[sample] for sample in pred])

        # Misclassified pixels get a red tint on top of the dimmed input.
        mismatch = torch.zeros(rgb.shape).float()
        mismatch[:, :, :, 0][pred != target] = 1
        error = rgb * 0.7 + mismatch.numpy() * 0.3

        weight_tile = w.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).unsqueeze(dim=1)
        weight_tile = weight_tile.detach().cpu().numpy()  # B*1*H*W*3

        tiles = [a[:, np.newaxis] for a in (rgb, label_rgb, pred_rgb, error)]
        tiles.append(weight_tile)
        panel = np.concatenate(tiles, axis=1)        # B*5*H*W*3
        panel = panel.transpose((0, 2, 1, 3, 4))      # B*H*5*W*3
        panel = panel.reshape((-1, 5*W, 3))           # (B*H) x (5*W) x 3

        self.writer.add_image("validation-label-prediction", panel, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_edge(self, x, y, y_pred, edge, area, epoch, task='validation'):
        """Log a six-tile panel per sample: image | label | prediction | error | edge | area.

        Args:
            x: input batch (B*C*H*W); only channels 0-2 are displayed.
            y: integer label map (B*H*W); labels are shifted by -1 before
                colouring, so presumably 1-based -- TODO confirm.
            y_pred: per-class scores (B*K*H*W); channel 0 is excluded from the argmax.
            edge: edge logits holding B*H*W values (e.g. B*1*H*W); shown after a sigmoid.
            area: per-class scores (B*K*H*W); channel 0 is excluded from the argmax.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
            task: prefix of the TensorBoard tag (defaults to 'validation').
        """
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred[:, 1:, :, :].argmax(dim=1).detach().cpu()
        # torch.sigmoid replaces the deprecated F.sigmoid. Reshaping to B*H*W
        # (instead of a bare squeeze()) keeps the batch axis when B == 1.
        edge = torch.sigmoid(edge.detach()).cpu().reshape(B, H, W)
        area = area[:, 1:, :, :].argmax(dim=1).detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        label_color_map = np.array([color_map[i - 1] for i in y])
        pred_color_map = np.array([color_map[i] for i in y_pred])
        edge = edge.numpy()[:, :, :, np.newaxis].repeat(3, axis=-1)
        area_color_map = np.array([color_map[i] for i in area])

        # Misclassified pixels get a red tint on top of the dimmed input.
        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        error_mask[:, :, :, 0][y_pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)
        edge = np.expand_dims(edge, axis=1)
        area_color_map = np.expand_dims(area_color_map, axis=1)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, edge, area_color_map), axis=1)  # B*6*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*6*W*3
        img_concat = img_concat.reshape((-1, 6*W, 3))  # (B*H) x (6*W) x 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_layer(self, xs, epoch, task='validation'):
        '''
        Log a grid of per-layer feature maps as a single grayscale image.

        xs: list of L tensors of identical shape B*C'*H*W; only the first
            three channels of each are shown.
        The logged image has B*C*H rows (one H-high strip per sample/channel)
        and L*W columns (one W-wide tile per layer); every map is min-max
        scaled to [0, 1] independently.
        '''
        # Detach so tensors that still require grad can be converted to numpy.
        xs = [i[:, :3].detach() for i in xs]
        x = torch.stack(xs, dim=1).cpu().numpy()  # B*L*C*H*W
        B, L, C, H, W = x.shape
        # Min-max scale each (sample, layer, channel) map: shift to 0, then
        # divide by the range (eps keeps constant maps at 0 instead of NaN).
        x = x - np.min(x, axis=(3, 4), keepdims=True)
        x = x / (np.max(x, axis=(3, 4), keepdims=True) + 1e-8)

        x = x.transpose(0, 2, 3, 1, 4).reshape(B*C*H, L*W, 1)
        self.writer.add_image(f"{task}-layer visualize", x, epoch + 1, dataformats='HWC')

    def save_tensorboard_onlyatt(self, att, epoch, task='validation'):
        '''
        Log the first slice along axis 1 of ``att`` (presumably the first
        attention head) as histograms and as a tiled kernel image.

        att: B*mh*ks2*H*W, where ks2 is a perfect square ks*ks. Each spatial
             position's ks2 weights are laid out as a ks*ks tile, giving an
             image of (B*H*ks) rows by (W*ks) columns.
        '''
        batch, _, _, height, width = att.shape
        first = att[:, :1].detach().cpu().numpy()  # B*1*ks2*H*W
        strongest = first.argmax(axis=2)           # index of the largest kernel weight
        side = int(first.shape[2] ** 0.5)
        tiles = (
            first.reshape(batch, side, side, height, width)
            .transpose(0, 3, 1, 4, 2)
            .reshape(batch * height * side, width * side, 1)
        )

        step = epoch + 1
        self.writer.add_histogram(f"{task}-attention-value", tiles, step)
        self.writer.add_histogram(f"{task}-attention-argmax", strongest + 1, step, bins='auto')
        self.writer.add_image(f"{task}-attention-kernel", tiles, step, dataformats='HWC')

    def save_tensorboard_attention(self, x, y, x_before, x_after, att, epoch, task='validation'):
        """Log feature maps before/after attention next to the input and label.

        Each sample becomes a 2H-high row of five W-wide tiles. Tile 0 stacks
        the image over the colour-coded label; tiles 1-4 stack channel k of
        ``x_before`` over channel k of ``x_after``. Both are min-max scaled per map.
        The first attention slice is then logged via ``save_tensorboard_onlyatt``.

        Args:
            x: input batch (B*C*H*W); only channels 0-2 are displayed.
            y: integer label map (B*H*W); shifted by -1 before colouring.
            x_before, x_after: feature maps, presumably B*4*H*W (four tiles are
                assumed by the 5*W reshape) -- TODO confirm.
            att: attention tensor forwarded to ``save_tensorboard_onlyatt``.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
            task: prefix of the TensorBoard tags.
        """

        def normalize(t, eps=1e-8):
            # Min-max scale over the two trailing (spatial) axes of each map;
            # eps keeps constant maps at 0 instead of producing 0/0 = NaN.
            t = t - t.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0]
            t = t / (t.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0] + eps)

            return t

        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        # detach() so tensors that still require grad can be turned into numpy.
        images = x[:, 0:3, :, :].detach().permute(0, 2, 3, 1).unsqueeze(dim=1).cpu().numpy()  # B*1*H*W*3
        label_color_map = np.array([color_map[i - 1] for i in y.detach().unsqueeze(dim=1).cpu()])  # B*1*H*W*3
        img_label = np.concatenate((images, label_color_map), axis=2)  # B*1*2H*W*3

        x_before = normalize(x_before.detach()).unsqueeze(dim=-1)  # B*4*H*W*1
        x_after = normalize(x_after.detach()).unsqueeze(dim=-1)  # B*4*H*W*1
        x_before_after = torch.cat((x_before, x_after), dim=2).repeat(1, 1, 1, 1, 3).cpu().numpy()  # B*4*2H*W*3

        img_concat = np.concatenate((img_label, x_before_after), axis=1)  # B*5*2H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*2H*5*W*3
        img_concat = img_concat.reshape((-1, 5*W, 3))  # (2H*B) x (5*W) x 3

        self.writer.add_image(f"{task}-attention", img_concat, epoch + 1, dataformats='HWC')

        # visualize attention histogram
        self.save_tensorboard_onlyatt(att, epoch=epoch, task=task)

    def save_tensorboard_cm(self, cm, epoch):
        """Plot a row-normalised confusion matrix and log the figure to TensorBoard.

        Args:
            cm: square confusion matrix (NumPy-compatible); each row is divided
                by its sum before plotting.
            epoch: zero-based epoch; logged at step ``epoch + 1``.
        """
        # Normalise each row to (approximately) sum to 1; the 1e-6 keeps rows
        # without samples at 0 instead of dividing by zero.
        cm_s = np.sum(cm, axis=1, keepdims=True) + 1e-6
        cm = cm / cm_s
        # self.label_names only names the four default classes. For any other
        # class count, let sklearn use numeric tick labels; otherwise the plot
        # would receive a label list whose length does not match the matrix.
        n_classes = np.shape(cm)[0]
        display_labels = self.label_names if len(self.label_names) == n_classes else None
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
        disp.plot(cmap=plt.cm.Blues, values_format=".3g")
        fig = disp.figure_
        fig.set_dpi(150)
        fig.set_size_inches(4, 4)
        self.writer.add_figure(f"validation-confusion-matrix", fig, epoch + 1)

    def save_model(self, epoch, name, model, optimizer):
        """Write a checkpoint to ``<output_path>/model_state_dict_epoch_{epoch + 1}.tar``.

        The saved dict holds ``epoch`` (exactly as passed), ``name``, the model's
        and optimizer's ``state_dict()``, and ``anchors`` when this trainer
        has such an attribute. The saved location is logged afterwards.
        """
        checkpoint_name = f"model_state_dict_epoch_{epoch + 1}.tar"
        # args.output_path is a Path object
        target = self.args.output_path / Path(checkpoint_name)

        payload = dict(
            epoch=epoch,
            name=name,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        )
        if hasattr(self, "anchors"):
            payload['anchors'] = self.anchors

        location = str(target)
        torch.save(payload, location)
        self.logger.info(f"Model checkpoint saved at {location}")

    def final_evaluate(self, model, dataloader, device):
        """Reload a saved checkpoint into a freshly built model and run a "final" evaluation pass.

        Args:
            model: the in-memory model; only this local reference is dropped.
                A new model from ``self.get_model()`` is evaluated instead.
            dataloader: evaluation loader; its length is passed to
                ``one_epoch`` as the batch count.
            device: not used by this method; ``self.device`` is used instead
                (assumed to be set elsewhere -- not shown here).
        """
        del model

        model = self.get_model()
        model = Sync_DataParallel(model)

        # NOTE(review): save_model() names files model_state_dict_epoch_{epoch + 1}.tar,
        # so "epoch_0" is what save_model(epoch=-1, ...) would write -- confirm this is
        # the intended checkpoint.
        path = (self.args.output_path /
                Path(f"model_state_dict_epoch_0.tar"))
        # NOTE(review): no map_location is given, so tensors are restored to the
        # device they were saved from.
        checkpoint = torch.load(path)

        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(self.device)

        self.logger.info(f"{'=' * 50}")
        self.logger.info(f"              Final      Evaluation")
        self.logger.info(f"{'=' * 50}")

        batch_num = len(dataloader)
        with torch.no_grad():
            model.eval()
            # No optimizer (None) and epoch -1 are passed; the "final" tag is
            # forwarded to one_epoch.
            self.one_epoch(model, None, dataloader, self.device, -1, batch_num, "final")
