import matplotlib.pyplot as plt
from numpy.core.numeric import cross
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.functional import align_tensors
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.nn.functional as F
import torchvision.transforms.functional as TF


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn

import torch
import random
import numpy as np
import gc
from pathlib import Path
import loss_functions
import models
from dataset import get_datasets
from dataset_hdf5 import get_hdf5_datasets
from dataset_city import get_city_datasets
from dataset_isaid import get_isaid_datasets
from dataset_cityscapes import Cityscapes
from utils import get_logger, save_epoch_summary, get_datetime_str_simplified, get_datetime_str,\
     MyTransform, Sync_DataParallel,BalancedDataParallel,execute_replication_callbacks
from utils_torch import PolynomialLRDecay,confusion_matrix_gpu,IOU_cal
from torchvision.transforms.transforms import CenterCrop
from models.sync_batchnorm.replicate import patch_replication_callback
import sys
sys.path.append("src/models/DynamicRouting")



class Train:
    def __init__(self, args):
        """Store the run configuration and set up TensorBoard, logging and seeding.

        ``args.nprocs`` is overwritten with the number of visible CUDA devices.
        """
        self.args = args
        self.args.nprocs = torch.cuda.device_count()

        # TensorBoard event files go to <log_dir><name> (plain string concatenation).
        log_directory = args.log_dir + args.name
        self.writer = SummaryWriter(log_dir=log_directory)

        # One logger per DDP process, tagged with its local rank.
        logger_name = f"DDP train LOGING: {self.args.local_rank}"
        self.logger = get_logger(logger_name, args)

        self.label_names = ["Water", "Forest", "Field", "Others"]
        self.labels = list(range(args.hr_nclasses))
        self._fix_random_seed()

    def _fix_random_seed(self, seed=0):
        """Seed all RNGs used in training and make cuDNN deterministic.

        Args:
            seed: seed for Python's ``random``, NumPy and PyTorch. Defaults to
                0, the value this method has always used.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # seeds the CPU and every CUDA device
        # cudnn.benchmark=True lets cuDNN pick convolution algorithms by timing
        # them, which can change between runs and defeats reproducibility;
        # both flags are needed for deterministic results.
        cudnn.benchmark = False
        cudnn.deterministic = True


    # Only process 0 (the main process) performs logging and other save operations.
    def _if(self, sign=True):
        """Return whether *sign* holds and this process has local rank 0.

        Equivalent to ``sign and local_rank == 0``: a falsy *sign* is returned
        unchanged, otherwise the result of the rank comparison is returned.
        """
        if not sign:
            return sign
        return self.args.local_rank == 0


    def get_model(self):
        """Build the network selected by ``args.net_model``.

        ``'unet'`` is built as ``models.UNet(in_ch=..., out_ch=...)``; any other
        name is looked up on the ``models`` package and called with
        ``in_ch`` and ``num_classes``.
        """
        name = self.args.net_model
        if name != 'unet':
            builder = getattr(models, name)
            return builder(in_ch=self.args.input_nchannels, num_classes=self.args.hr_nclasses)
        return models.UNet(in_ch=self.args.input_nchannels, out_ch=self.args.hr_nclasses)
    

    def logger_info(self, *string):
        """Log each positional argument as a separate INFO record, on rank 0 only."""
        if not self._if():
            return
        for message in string:
            self.logger.info(message)

    # Reduce (average) a tensor across all processes.
    def reduce_mean(self, tensor):
        """Return the element-wise mean of *tensor* over all processes in the group.

        The input is cloned, summed with ``all_reduce`` and divided by the
        process-group world size. ``args.nprocs`` is not used as the divisor:
        it holds ``torch.cuda.device_count()``. That differs from the number
        of participating processes when only some visible GPUs are used or
        when training spans several nodes. Without an initialised process
        group (single-process use) the clone is returned unchanged, i.e. a
        world size of one.
        """
        rt = tensor.clone()
        if not (dist.is_available() and dist.is_initialized()):
            return rt
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= dist.get_world_size()
        return rt

    # Reduce (sum) a tensor across all processes.
    def reduce_sum(self, tensor):
        """Return the element-wise sum of *tensor* over all processes in the group.

        The input is cloned before ``all_reduce`` so the caller's tensor is left
        untouched. Without an initialised process group (single-process use)
        the clone is returned as-is, i.e. a world size of one.
        """
        rt = tensor.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        return rt


    def run(self):
        """Train for ``args.epochs`` epochs with DistributedDataParallel, then run a final evaluation.

        Steps:
        1. Build the train/validation datasets.
        2. Initialise the NCCL process group.
        3. Wrap the datasets in ``DistributedSampler``-backed loaders.
        4. Convert the model to SyncBatchNorm and wrap it in DDP, optionally
           loading weights from ``args.resume_model``.
        5. Alternate training and validation epochs. Whenever validation mIoU
           improves, rank 0 saves the model via ``save_model(-1, ...)``.
        6. Call ``final_evaluate`` after the loop.

        Raises:
            ValueError: ``args.resume`` is set but no ``args.resume_model``
                checkpoint was given, so there is no epoch to resume from.
        """
        # prepare datasets for training and validation
        if self.args.cityscapes:
            ds_path = '/scratch/forest/datasets/cityscapes'
            self.training_ds = Cityscapes(root=ds_path, split='train', target_type='semantic', transforms=MyTransform(size=768, random_crop=True))
            self.validation_ds = Cityscapes(root=ds_path, split='val', target_type='semantic', transforms=MyTransform(size=0, random_crop=False))
            print('Train and Validation data load successful!!!')
        elif self.args.city:
            self.training_ds, self.validation_ds = get_city_datasets(self.args)
        elif self.args.isaid:
            self.training_ds, self.validation_ds = get_isaid_datasets(self.args)
        else:
            self.training_ds, self.validation_ds = get_datasets(self.args)

        # Step 1: initialise the process group and bind this process to its GPU.
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(self.args.local_rank)

        # Step 2: distributed samplers and dataloaders (validation is not shuffled).
        train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_ds, shuffle=True)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(self.validation_ds, shuffle=False)

        training_dl = torch.utils.data.DataLoader(self.training_ds, batch_size=self.args.batch_size, sampler=train_sampler,
                                                  num_workers=self.args.workers, pin_memory=True)
        # Cityscapes is validated one image per batch.
        val_batch_size = 1 if self.args.cityscapes else self.args.batch_size
        validation_dl = torch.utils.data.DataLoader(self.validation_ds, batch_size=val_batch_size, sampler=valid_sampler,
                                                    num_workers=self.args.workers, pin_memory=True)

        # Step 3: build the model and convert BatchNorm layers to SyncBatchNorm.
        model = self.get_model()
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

        # Step 4: DDP wrapper; local_rank is the GPU index this process drives.
        model = DistributedDataParallel(model, device_ids=[self.args.local_rank])
        checkpoint = None
        if self.args.resume_model:
            # Load onto the CPU; otherwise every rank would materialise the
            # tensors on the GPU they were saved from.
            checkpoint = torch.load(self.args.resume_model, map_location='cpu')
            model.load_state_dict(checkpoint["model_state_dict"])
            # NOTE(review): optimizer and lr-scheduler state are not restored.

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=self.args.learning_rate,  # 1e-2
                                    momentum=0.9, weight_decay=0.0005)

        self.lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=self.args.epochs*min(len(training_dl), self.args.training_steps_per_epoch),
                                              end_learning_rate=1e-4, power=0.9)

        # args.resume continues the epoch counter stored in the loaded checkpoint.
        if self.args.resume and checkpoint is None:
            raise ValueError("args.resume is set but no args.resume_model checkpoint was loaded")
        epoch = checkpoint['epoch'] + 1 if self.args.resume else 0
        best_miou = 0
        while epoch < self.args.epochs:
            # DistributedSampler derives its permutation from the epoch number;
            # without set_epoch every epoch would see the same order.
            train_sampler.set_epoch(epoch)
            self.train_one_epoch(model, optimizer, training_dl, epoch)  # if use lr_scheduler, please set in here

            miou = None  # stays None on epochs that skip validation
            if self.args.valid_all:
                # full validation pass (stable, precise score), only every 5 epochs
                if epoch % 5 == 0:
                    miou = self.evaluate(model, validation_dl, epoch, batch_num=max(len(validation_dl), self.args.validation_steps_per_epoch))
            else:
                miou = self.evaluate(model, validation_dl, epoch, batch_num=min(len(validation_dl), self.args.validation_steps_per_epoch))

            if miou is not None and self._if(miou > best_miou):
                best_miou = miou
                self.save_model(-1, self.args.name, model, optimizer)  # best model saved as model_state_dict_epoch_0.tar

            epoch += 1

        try:
            self.final_evaluate(model, validation_dl)
        except Exception as e:
            print("Final Evaludaiton Failed.")
            print('Exception: ', e)
        finally:
            # Close the writer only after final_evaluate has logged to it.
            self.writer.flush()
            self.writer.close()

    def train_one_epoch(self, model, optimizer, dataloader, epoch):
        """Train for at most ``args.training_steps_per_epoch`` batches of *dataloader*.

        Returns immediately when there are no batches to run. Rank 0 saves a
        checkpoint every ``args.model_save_checkpoint`` epochs.
        """
        steps = min(len(dataloader), self.args.training_steps_per_epoch)
        if steps == 0:
            return

        model.train()
        self.one_epoch(model, optimizer, dataloader, epoch, steps, "training")

        is_checkpoint_epoch = (epoch + 1) % self.args.model_save_checkpoint == 0
        if self._if(is_checkpoint_epoch):
            self.save_model(epoch, self.args.name, model, optimizer)

    def evaluate(self, model, dataloader, epoch, batch_num):
        """Run *batch_num* validation batches without gradients and return the mIoU.

        Returns 0 without touching the model when *batch_num* is 0.
        """
        if batch_num == 0:
            return 0
        model.eval()
        with torch.no_grad():
            return self.one_epoch(model, None, dataloader, epoch, batch_num, "validation")

    def final_evaluate(self, model, dataloader):
        """Reload the saved best model and evaluate it on every batch of *dataloader*.

        The *model* argument is not used. A fresh network is built with
        ``get_model``, converted to SyncBatchNorm and wrapped in DDP. It is
        then loaded from ``<args.output_path>/model_state_dict_epoch_0.tar``,
        the file ``run`` documents as holding the best-mIoU model.
        """
        # Fresh model with SyncBatchNorm, built the same way as in run().
        model = self.get_model()
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
        model = DistributedDataParallel(model, device_ids=[self.args.local_rank])

        checkpoint_path = Path(self.args.output_path) / "model_state_dict_epoch_0.tar"
        # Load onto the CPU; otherwise every rank would materialise the
        # tensors on the GPU they were saved from. load_state_dict copies
        # them into the (already CUDA-resident) parameters.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])

        with torch.no_grad():
            model.eval()
            self.one_epoch(model, None, dataloader, -1, len(dataloader), "final")


    def one_epoch(self, model, optimizer, dataloader, epoch, batch_num, task):
        """Run one pass over *dataloader* for training, validation or final evaluation.

        Args:
            model: DDP-wrapped network (``model.module`` is accessed).
            optimizer: stepped once per batch when ``task == "training"``;
                unused otherwise.
            dataloader: yields ``(x, y_hr)`` batches.
            epoch: 0-based epoch index used in logs and TensorBoard steps
                (``-1`` for the final evaluation).
            batch_num: number of batches to process before stopping.
            task: ``"training"``, ``"validation"`` or ``"final"``.

        Returns:
            mIoU from ``IOU_cal`` on the confusion matrix accumulated over the
            processed batches (summed across processes).

        Raises:
            ValueError: unknown *task*, or a boost model whose
                ``args.name_prefix`` does not end in ``0_``, ``1_`` or ``2_``.

        Loss for ``DeepLabv3p_boost*`` / ``UNet_boost*`` models, by
        ``args.name_prefix`` suffix:
          * ``0_``: KL loss on the last output, upsampled to the input size.
          * ``1_``: cascade over the outputs, each upsampled by its scale
            factor (16, 8, 4, 2, 1). A stage claims the still-undecided pixels
            whose max probability is >= ``model.module.ts[idx]``; the last
            stage claims the rest.
          * ``2_``: the same cascade, but the loss is computed at each output's
            own resolution against average-pooled targets.
        During training (every 10th batch) each stage threshold, except the
        last in ``2_``, is moved towards the cross-process mean 0.3-quantile
        of correct-prediction confidences. Other models use plain
        cross-entropy.
        """
        if task not in ("training", "validation", "final"):
            raise ValueError(f"task must be 'training', 'validation' or 'final', got {task!r}")
        self.logger_info(f"\n{'=' * 50}\n               Epoch {epoch + 1:3d}      {task}\n{'=' * 50}\n ")

        ep_records = {
            "loss_to_use": [], "crossentropy": [], "miou": [], "accuracy": [],
            "start_time": get_datetime_str(),
            "confusion_matrix": torch.zeros((self.args.hr_nclasses, self.args.hr_nclasses)).cuda().double(),
        }

        # Fraction of still-undecided pixels per cascade stage (boost models).
        # Stored as Python floats because np.array() cannot convert CUDA tensors.
        visited_counter = [[] for _ in range(len(getattr(model.module, "ts", [])))]
        is_boost = self.args.net_model.startswith(('DeepLabv3p_boost', 'UNet_boost'))

        for batch_idx, (x, y_hr) in enumerate(dataloader):
            if self.args.cityscapes:
                # x is scaled by 255 before per-channel normalisation
                # (presumably the dataset yields values in [0, 1] -- confirm).
                mean = [123.675, 116.28, 103.53]
                std = [58.395, 57.12, 57.375]
                x_n = TF.normalize(x * 255.0, mean, std, inplace=False)
                x_n = x_n.cuda(non_blocking=True)
                outputs = model(x_n)
            else:
                x = x.cuda(non_blocking=True)  # B, C, H, W
                outputs = model(x)
            B, C, H, W = x.shape

            # Label values are clamped into [0, hr_nclasses - 1].
            y_hr = y_hr.cuda(non_blocking=True).squeeze(dim=1).long().clamp(max=self.args.hr_nclasses - 1, min=0)  # B, H, W
            y_hr_onehot = torch.nn.functional.one_hot(y_hr, self.args.hr_nclasses).permute(0, 3, 1, 2)  # B, n_classes, H, W

            # ---------------- loss ----------------
            if is_boost:
                if self.args.name_prefix.endswith('0_'):
                    o = outputs
                    y_hr_pred = F.interpolate(o[-1], size=x.size()[2:], mode='bilinear', align_corners=True)
                    loss = loss_functions.kl_loss(y_hr_pred, y_hr_onehot)

                elif self.args.name_prefix.endswith('1_'):
                    o = outputs
                    target = y_hr_onehot
                    not_visited = torch.ones(B, 1, H, W).cuda()
                    # channel count (number of classes) of the last output
                    result = torch.zeros(B, o[-1].shape[1], H, W).cuda()
                    loss = 0
                    end_idx = len(model.module.ts) - 1

                    for idx, scale_f in enumerate([16, 8, 4, 2, 1][:end_idx + 1]):
                        # pixels not yet claimed by an earlier stage
                        expand = not_visited.bool()

                        cur = F.interpolate(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
                        cur = F.softmax(cur, dim=1)
                        undecided = torch.sum(not_visited) / torch.numel(not_visited)
                        visited_counter[idx].append(undecided.item())
                        if self._if(batch_idx % 25 == 0):
                            self.logger_info(undecided)

                        uncertain = torch.max(cur, dim=1, keepdim=True)[0]  # per-pixel max probability
                        ts = model.module.ts
                        if batch_idx % 10 == 0:
                            sign = torch.argmax(cur, dim=1, keepdim=True) == torch.argmax(y_hr_onehot, dim=1, keepdim=True)
                            correct_dis = uncertain[sign * expand]
                            error_dis = uncertain[(~sign) * expand]
                            if task == 'training':
                                if len(correct_dis) > 10:
                                    quantile_point = torch.quantile(correct_dis, q=0.3)
                                else:
                                    # Too few samples: keep the current threshold. Move it
                                    # to the GPU/dtype of the quantile so every rank hands
                                    # NCCL a matching CUDA tensor.
                                    quantile_point = model.module.ts[idx].to(device=uncertain.device, dtype=uncertain.dtype)
                                reduce_q_p = self.reduce_mean(quantile_point)
                                model.module.ts[idx] = 0.9*model.module.ts[idx]+0.1*reduce_q_p.detach().cpu().numpy()
                                if idx == 0:
                                    self.logger_info(f'The threshold is: {model.module.ts.cpu().numpy()}')
                            if self._if(epoch > 10 and len(correct_dis) > 10 and len(error_dis) > 10):
                                self.writer.add_histogram(f"{task}{idx}-correct_dis", correct_dis, epoch + 1, bins='auto')
                                self.writer.add_histogram(f"{task}{idx}-error_dis", error_dis, epoch + 1, bins='auto')

                        loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        if idx == end_idx:
                            cur_visited = 1.  # last stage claims every remaining pixel
                        else:
                            cur_visited = (uncertain >= ts[idx].to(uncertain.device)).float()
                        # cache this stage's decided probabilities for the prediction
                        result += (not_visited * cur_visited) * cur.detach()
                        if self._if(batch_idx == 0 and epoch % 10 == 0):
                            self.save_tensorboard_onehot_images(result, f'{task}:Idx {idx}', epoch)
                        not_visited = not_visited * (1 - cur_visited)
                    y_hr_pred = result

                elif self.args.name_prefix.endswith('2_'):
                    o = outputs
                    target = y_hr_onehot
                    not_visited = torch.ones(B, 1, H, W).cuda()
                    result = torch.zeros(B, o[-1].shape[1], H, W).cuda()
                    loss = 0
                    end_idx = len(model.module.ts) - 1

                    for idx, scale_f in enumerate([16, 8, 4, 2, 1][:end_idx + 1]):
                        loss_rate = 1
                        # A low-resolution cell is active if any pixel in its
                        # scale_f x scale_f window is still undecided.
                        expand = F.avg_pool2d(not_visited.float(), kernel_size=scale_f, stride=scale_f).bool()
                        cur_target = F.avg_pool2d(target.float(), kernel_size=scale_f, stride=scale_f)
                        cur = F.softmax(o[idx], dim=1)
                        undecided = torch.sum(not_visited) / torch.numel(not_visited)
                        visited_counter[idx].append(undecided.item())
                        if self._if(batch_idx % 25 == 0):
                            self.logger_info(undecided)

                        uncertain = torch.max(cur, dim=1, keepdim=True)[0]
                        ts = model.module.ts
                        if batch_idx % 10 == 0:
                            # correct only against cells whose pooled target is (almost) pure
                            sign = torch.argmax(cur, dim=1, keepdim=True) == torch.argmax((cur_target > 0.99).float(), dim=1, keepdim=True)
                            correct_dis = uncertain[sign * expand]
                            error_dis = uncertain[(~sign) * expand]

                            if task == 'training' and idx < end_idx:
                                if len(correct_dis) > 10:
                                    quantile_point = torch.quantile(correct_dis, q=0.3)
                                else:
                                    # too few samples (torch.quantile rejects empty
                                    # input): keep the current threshold
                                    quantile_point = model.module.ts[idx].to(device=uncertain.device, dtype=uncertain.dtype)
                                reduce_q_p = self.reduce_mean(quantile_point)
                                model.module.ts[idx] = 0.9*model.module.ts[idx]+0.1*reduce_q_p.detach().cpu().numpy()
                                if idx == 0:
                                    self.logger_info(f'The threshold is: {model.module.ts.cpu().numpy()}')
                            if self._if(epoch > 10 and len(correct_dis) > 10 and len(error_dis) > 10):
                                self.writer.add_histogram(f"{task}{idx}-correct_dis", correct_dis, epoch + 1, bins='auto')
                                self.writer.add_histogram(f"{task}{idx}-error_dis", error_dis, epoch + 1, bins='auto')

                        loss += loss_rate * loss_functions.kl_loss_woSOFT(cur, cur_target, sign=expand)
                        if idx == end_idx:
                            cur_visited = 1
                        else:
                            cur_visited = (uncertain >= ts[idx].to(uncertain.device)).float()
                            cur_visited = F.interpolate(cur_visited, mode='nearest', scale_factor=scale_f)
                        # cache for next step: nearest-upsampled logits of the newly decided pixels
                        cur = F.interpolate(o[idx], mode='nearest', scale_factor=scale_f)
                        result += ((not_visited * cur_visited).cuda()) * (cur.detach().cuda())

                        if self._if(batch_idx == 0 and epoch % 10 == 0):
                            self.save_tensorboard_onehot_images(result, f'{task}:Idx {idx}', epoch)
                        not_visited = not_visited * (1 - cur_visited)
                    y_hr_pred = result

                else:
                    raise ValueError(f"Unsupported name_prefix {self.args.name_prefix!r} for boost model "
                                     f"{self.args.net_model!r}; expected it to end with '0_', '1_' or '2_'")
            else:
                y_hr_pred = outputs
                loss = loss_functions.crossentropy_loss(hard_mining=False)(y_hr_pred, y_hr)

            # synchronise all processes before the optimisation step
            torch.distributed.barrier()

            if task == "training":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            # ---------------- metrics (no gradients needed) ----------------
            with torch.no_grad():
                y_hr = y_hr.cuda()
                y_hr_onehot = y_hr_onehot.cuda()

                y_hr_pred = y_hr_pred.cuda()
                y_hr_pred_am = y_hr_pred.argmax(dim=1)
                y_hr_pred_onehot = torch.nn.functional.one_hot(y_hr_pred_am, self.args.hr_nclasses).permute(0, 3, 1, 2)

                cf = confusion_matrix_gpu(y_hr.view(-1), y_hr_pred_am.view(-1), num_classes=self.args.hr_nclasses)
                crossentropy = loss_functions.crossentropy_loss(hard_mining=False)(y_hr_pred, y_hr)
                miou = loss_functions.miou(self.args.hr_nclasses)(y_hr_pred_onehot, y_hr_onehot)
                accuracy = loss_functions.accuracy()(y_hr_pred_onehot, y_hr_onehot)

                # combine per-process metrics: confusion matrices are summed,
                # scalar metrics are averaged
                reduce_cf = self.reduce_sum(cf)
                ep_records['confusion_matrix'] += reduce_cf
                reduce_crossentropy = self.reduce_mean(crossentropy)
                reduce_miou = self.reduce_mean(miou)
                reduce_accuracy = self.reduce_mean(accuracy)

            if batch_idx % 10 == 0:
                log_str = f"Batch: {(batch_idx + 1):4d} / {batch_num:4d}    " \
                    f"Loss: {loss.item():.3f}    " \
                    f"CrossEntropy Loss: {reduce_crossentropy.item():.3f}    " \
                    f"mIOU: {reduce_miou.item():.3f}    " \
                    f"Accuracy: {reduce_accuracy.item():.3f}"

                self.logger_info(log_str)

            # store loss records (metrics already reduced over all processes)
            ep_records["loss_to_use"].append(loss.item())
            ep_records["crossentropy"].append(reduce_crossentropy.item())
            ep_records["miou"].append(reduce_miou.item())
            ep_records["accuracy"].append(reduce_accuracy.item())

            # save tensorboard images for the first batch of selected epochs
            if self._if(batch_idx == 0 and epoch % (max((self.args.epochs)//20, 1)) == 0):
                self.save_tensorboard_images(x, y_hr, y_hr_pred, epoch, task)
            if batch_idx + 1 >= batch_num:
                break

        # ---------------- epoch summary ----------------
        ep_records["end_time"] = get_datetime_str()
        m_loss_to_use = np.mean(ep_records["loss_to_use"])
        m_crossentropy = np.mean(ep_records["crossentropy"])

        cf_matrix = ep_records['confusion_matrix']
        iou, m_iou = IOU_cal(cf_matrix)
        # "fg" metrics drop class 0 from both axes of the confusion matrix
        fg_iou, fg_m_iou = IOU_cal(cf_matrix[1:, 1:])
        m_accuracy = (cf_matrix.diag().sum() / cf_matrix.sum()).item()

        # mean undecided fraction per stage (nan for stages never recorded)
        visited_means = np.array([np.mean(v) if v else np.nan for v in visited_counter])

        log_str_1 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
                    f"Start at {ep_records['start_time']}    " \
                    f"End at {ep_records['end_time']} \n"\
                    f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
                    f"Loss: {m_loss_to_use:.4f}    " \
                    f"CrossEntropy Loss: {m_crossentropy:.4f}    \n" \
                    f"Accuracy: {m_accuracy:.4f}    \n" \
                    f"mIOU: {m_iou:.4f}    " \
                    f"IOU for each category: {iou.data}\n"\
                    f"fg_mIOU: {fg_m_iou:.4f}    " \
                    f"fg_IOU for each category: {fg_iou.data} \n"\
                    f"visited counter in different stage: {visited_means}"

        self.logger_info(log_str_1)

        save_str = f'''Epoch {epoch + 1:3d} - {ep_records["start_time"]} - {ep_records["end_time"]} - ''' \
                   f"Loss: {m_loss_to_use:.4f}    " \
                   f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
                   f"mIOU: {m_iou:.4f}    " \
                   f"Accuracy: {m_accuracy:.4f}\n"

        if self._if():
            save_epoch_summary(self.args, save_str)
            self.save_tensorboard_scalars(m_loss_to_use, m_crossentropy, m_iou, m_accuracy, task=f"{task}-epoch", step=epoch + 1)

        return m_iou
    
    
    def save_tensorboard_scalars(self, loss_to_use, crossentropy_loss, miou, accuracy, task, step):
        
        self.writer.add_scalar(f"{task}/loss", loss_to_use, step)
        self.writer.add_scalar(f"{task}/crossentropy_loss", crossentropy_loss, step)
        self.writer.add_scalar(f"{task}/miou", miou, step)
        self.writer.add_scalar(f"{task}/accuracy", accuracy, step)

        # self.writer.add_scalar(f"{task}/3remote_jaccard", remote_jaccard, step)
        # self.writer.add_scalar(f"{task}/3city_jaccard", city_jaccard, step)
        # self.writer.add_scalar(f"{task}/3remote_accuracy", remote_accuracy, step)
        # self.writer.add_scalar(f"{task}/3city_accuracy", city_accuracy, step)

    def save_tensorboard_onehot_images(self, y, name, epoch):
        """Colour the per-pixel argmax of *y* and log the batch as one tall HWC image.

        Args:
            y: tensor indexed as (B, C, H, W); the argmax over C picks the colour.
            name: TensorBoard image tag.
            epoch: global step passed to ``add_image``.

        Five-channel inputs use a fixed palette. Any other channel count gets
        colours spaced ``255**3 / 20`` apart on a base-255 ramp.
        """
        batch, n_channels, height, width = y.shape
        if n_channels == 5:
            palette = np.array([
                [255, 255, 255],
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1],
            ]) / 255.0
        else:
            spacing = 255*255*255/20
            ramp = [spacing * k for k in range(0, n_channels)]
            palette = np.array([[v // (255*255), (v // 255) % 255, v % 255] for v in ramp]) / 255.0

        onehot = F.one_hot(torch.argmax(y, dim=1), len(palette)).cpu().float()
        palette_t = torch.from_numpy(palette).float()
        coloured = np.array(torch.einsum('bhwc,cp->bhwp', onehot, palette_t))
        stacked = coloured.reshape(batch * height, width, -1)
        self.writer.add_image(name, stacked, epoch, dataformats='HWC')


    def save_tensorboard_images(self, x, y, y_pred, epoch, task='validation'):
        """Log an input/label/prediction/error panel for each sample of the batch.

        Each sample becomes one row ``[input RGB | label colours | prediction
        colours | input with wrongly predicted pixels tinted red]``. Rows are
        stacked vertically and logged under ``<task>-label-prediction`` at
        step ``epoch + 1``.

        For Cityscapes/iSAID, labels index the palette directly and the
        prediction is the argmax over all channels. Otherwise labels are
        1-based (label ``k`` uses colour ``k - 1``) and the prediction is the
        argmax over channels ``1:``.
        """
        def ramp_palette(divisor):
            # 20 colours spaced 255**3 / divisor apart on a base-255 ramp
            spacing = 255*255*255 / divisor
            ramp = [spacing * i for i in range(0, 20)]
            return np.array([[n // (255*255), (n // 255) % 255, n % 255] for n in ramp]) / 255.0

        if self.args.cityscapes:
            color_map = ramp_palette(20)
        elif self.args.isaid:
            color_map = ramp_palette(16)
        else:
            color_map = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, width = x.shape
        x, y, y_pred = x.detach().cpu(), y.detach().cpu(), y_pred.detach().cpu()

        rgb = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes or self.args.isaid:
            label_rgb = np.array([color_map[sample] for sample in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
            reference = y
        else:
            label_rgb = np.array([color_map[sample - 1] for sample in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
            reference = y - 1
        pred_rgb = np.array([color_map[sample] for sample in pred])

        # red channel marks the wrongly predicted pixels
        wrong = torch.zeros(rgb.shape)
        wrong[:, :, :, 0][pred != reference] = 1
        error = rgb * 0.7 + wrong.numpy() * 0.3

        panels = [np.expand_dims(p, axis=1) for p in (rgb, label_rgb, pred_rgb, error)]
        grid = np.concatenate(panels, axis=1)          # B, 4, H, W, 3
        grid = grid.transpose((0, 2, 1, 3, 4))         # B, H, 4, W, 3
        grid = grid.reshape((-1, 4 * width, 3))        # B*H, 4*W, 3

        self.writer.add_image(f"{task}-label-prediction", grid, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_scale(self, x, y, y_pred, preds, epoch):
        
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
            # (self.args.epochs) // 10, 1)) == 0:  # task == "validation" and
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        preds = [i.softmax(dim=1).topk(k=2, dim=1)[0] for i in preds]
        preds = [i[:, 0:1, :, :].repeat(1, 3, 1, 1) for i in preds]
        # preds = [(i[:, 0:1, :, :] - i[:, 1:2, :, :]).repeat(1,3,1,1) for i in preds]
        pred2, pred3, pred4, pred5 = preds
        pred2 = F.interpolate(pred2, scale_factor=2, mode='nearest')
        pred3 = F.interpolate(pred3, scale_factor=4, mode='nearest')
        pred4 = F.interpolate(pred4, scale_factor=8, mode='nearest')
        pred5 = F.interpolate(pred5, scale_factor=16, mode='nearest')
        c = [pred2, pred3, pred4, pred5]

        RF = [i.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for i in c]
        # c = torch.cat((c[0], c[1], c[2], c[3]), dim=0)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, RF[3], RF[2], RF[1], RF[0]), axis=1)  # B*8*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*8*W*3
        img_concat = img_concat.reshape((-1, 8*W, 3))  # 240xB * 240x8 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')
        # self.writer.add_histogram(f"validation-RF-value", c, epoch + 1)
        # self.writer.add_histogram(f"validation-RF-argmax", c.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_arf(self, x, y, y_pred, c, epoch):
        """Log prediction panels plus four extra RGB maps and their histograms.

        Each sample becomes one row of eight panels:
        ``[input RGB | label colours | prediction colours | input with wrong
        pixels tinted red | c[0] | c[1] | c[2] | c[3]]``. Each ``c[k]`` is
        permuted from (B, 3, H, W) to HWC for display. The four maps are also
        stacked and logged as a value histogram and as a histogram of
        ``argmax over their channel axis + 1``.

        With ``args.cityscapes`` labels index the palette directly and the
        prediction is the argmax over all channels. Otherwise labels are
        1-based (label ``k`` uses colour ``k - 1``) and the prediction is the
        argmax over channels ``1:``.
        """
        if self.args.cityscapes:
            spacing = 255*255*255/20
            ramp = [spacing * i for i in range(0, 20)]
            palette = np.array([[n // (255*255), (n // 255) % 255, n % 255] for n in ramp]) / 255.0
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, width = x.shape
        x, y, y_pred = x.detach().cpu(), y.detach().cpu(), y_pred.detach().cpu()

        rgb = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_rgb = np.array([palette[sample] for sample in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
            reference = y
        else:
            label_rgb = np.array([palette[sample - 1] for sample in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
            reference = y - 1
        pred_rgb = np.array([palette[sample] for sample in pred])

        wrong = torch.zeros(rgb.shape)
        wrong[:, :, :, 0][pred != reference] = 1
        error = rgb * 0.7 + wrong.numpy() * 0.3

        rf_panels = [m.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for m in c]
        stacked = torch.stack((c[0], c[1], c[2], c[3]), dim=0)

        base_panels = [np.expand_dims(p, axis=1) for p in (rgb, label_rgb, pred_rgb, error)]
        grid = np.concatenate(base_panels + [rf_panels[0], rf_panels[1], rf_panels[2], rf_panels[3]], axis=1)  # B, 8, H, W, 3
        grid = grid.transpose((0, 2, 1, 3, 4)).reshape((-1, 8 * width, 3))                                   # B*H, 8*W, 3

        self.writer.add_image(f"validation-label-prediction", grid, epoch + 1, dataformats='HWC')
        self.writer.add_histogram(f"validation-RF-value", stacked, epoch + 1)
        self.writer.add_histogram(f"validation-RF-argmax", stacked.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_weight(self, x, y, y_pred, w, epoch):
        
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        if self.args.cityscapes:
            color_map = []
            N = 255*255*255/20
            for i in range(0, 20):
                n = N * i
                color_map.append([n//(255*255), (n//255) % 255, n % 255])
            color_map = np.array(color_map) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if self.args.cityscapes:
            label_color_map = np.array([color_map[i] for i in y])
            pred = torch.argmax(y_pred[:, :, :, :], dim=1)
        else:
            label_color_map = np.array([color_map[i - 1] for i in y])
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        pred_color_map = np.array([color_map[i] for i in pred])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        if self.args.cityscapes:
            error_mask[:, :, :, 0][pred != y] = 1
        else:
            error_mask[:, :, :, 0][pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)

        w = w.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).unsqueeze(dim=1)
        w = w.detach().cpu().numpy()  # bs, 1, H, W

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, w), axis=1)  # B*5*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*5*W*3
        img_concat = img_concat.reshape((-1, 5*W, 3))  # 240xB * 240x5 * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_edge(self, x, y, y_pred, edge, area, epoch):
        
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred[:, 1:, :, :].argmax(dim=1).detach().cpu()
        edge = F.sigmoid(edge.detach()).cpu()
        area = area[:, 1:, :, :].argmax(dim=1).detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        label_color_map = np.array([color_map[i - 1] for i in y])
        pred_color_map = np.array([color_map[i] for i in y_pred])
        edge = edge.squeeze().numpy()[:, :, :, np.newaxis].repeat(3, axis=-1)
        area_color_map = np.array([color_map[i] for i in area])

        error_mask = torch.zeros(x[:, 0:3, :, :].permute(0, 2, 3, 1).shape).float()
        error_mask[:, :, :, 0][y_pred != (y-1)] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        images = np.expand_dims(images, axis=1)
        label_color_map = np.expand_dims(label_color_map, axis=1)
        pred_color_map = np.expand_dims(pred_color_map, axis=1)
        error = np.expand_dims(error, axis=1)
        edge = np.expand_dims(edge, axis=1)
        area_color_map = np.expand_dims(area_color_map, axis=1)

        img_concat = np.concatenate((images, label_color_map, pred_color_map, error, edge, area_color_map), axis=1)  # B*6*240*240*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*240*6*240*3
        img_concat = img_concat.reshape((-1, 6*W, 3))  # 240xB * 240x4 * 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_layer(self, xs, epoch, task='validation'):
        '''
        xs: list
        '''
        
        xs = [i[:, :3] for i in xs]
        x = torch.stack(xs, dim=1).cpu().numpy()
        B, L, C, H, W = x.shape
        x = (x-np.min(x, axis=(3, 4), keepdims=True))
        x = (x-np.max(x, axis=(3, 4), keepdims=True))

        x = x.transpose(0, 2, 3, 1, 4).reshape(B*C*H, L*W, 1)
        self.writer.add_image(f"{task}-layer visualize", x, epoch + 1, dataformats='HWC')

    def save_tensorboard_onlyatt(self, att, epoch, task='validation'):
        # visualize attention histogram
        '''
        att: B*mh*ks2*H*W
        '''
        
        B, mh, _, H, W = att.shape
        att = att[:, :1].detach().cpu().numpy()  # select the first attention column
        att_max = np.argmax(att, axis=2)
        ks = int(att.shape[2]**0.5)
        att = att.reshape(B, ks, ks, H, W)
        att = att.transpose(0, 3, 1, 4, 2)
        att = att.reshape(B*H*ks, W*ks, 1)

        self.writer.add_histogram(f"{task}-attention-value", att, epoch + 1)
        self.writer.add_histogram(f"{task}-attention-argmax", att_max + 1, epoch + 1, bins='auto')
        self.writer.add_image(f"{task}-attention-kernel", att, epoch + 1, dataformats='HWC')

    def save_tensorboard_attention(self, x, y, x_before, x_after, att, epoch, task='validation'):
        

        def normalize(t):
            t = t - t.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0]
            t = t / t.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]

            return t

        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, C, H, W = x.shape

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).unsqueeze(dim=1).cpu().numpy()  # B*1*240*240*3
        label_color_map = np.array([color_map[i - 1] for i in y.unsqueeze(dim=1).cpu()])  # B*1*240*240*3
        img_label = np.concatenate((images, label_color_map), axis=2)  # B*1*480*240*3

        x_before = normalize(x_before).unsqueeze(dim=-1)  # B*4*240*240*1
        x_after = normalize(x_after).unsqueeze(dim=-1)  # B*4*240*240*1
        x_before_after = torch.cat((x_before, x_after), dim=2).repeat(1, 1, 1, 1, 3).cpu().numpy()  # B*4*480*240*3

        img_concat = np.concatenate((img_label, x_before_after), axis=1)  # B*5*480*240*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*480*5*240*3
        img_concat = img_concat.reshape((-1, 5*W, 3))  # 480xB * 240x5 * 3

        self.writer.add_image(f"{task}-attention", img_concat, epoch + 1, dataformats='HWC')

        # visualize attention histogram
        self.save_tensorboard_onlyatt(att, epoch=epoch, task=task)

    def save_tensorboard_cm(self, cm, epoch):
        """Plot a row-normalised confusion matrix to TensorBoard.

        Args:
            cm: square confusion-matrix array; each row is divided by its sum
                (assumed rows = true classes, as sklearn's
                ConfusionMatrixDisplay draws them -- confirm against caller).
            epoch: zero-based epoch; logged at step ``epoch + 1``.
        """
        # The 1e-6 keeps rows of classes with no samples at 0 instead of 0/0.
        cm_s = np.sum(cm, axis=1, keepdims=True) + 1e-6
        cm = cm / cm_s

        # self.label_names names only the 4 land-cover classes. For any other
        # class count (e.g. the 20-colour cityscapes setup) fall back to numeric
        # tick labels instead of failing on a tick/label count mismatch.
        n_classes = cm.shape[0]
        display_labels = self.label_names if len(self.label_names) == n_classes else None

        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
        disp.plot(cmap=plt.cm.Blues, values_format=".3g")
        fig = disp.figure_
        fig.set_dpi(150)
        fig.set_size_inches(4, 4)
        self.writer.add_figure(f"validation-confusion-matrix", fig, epoch + 1)

    def save_model(self, epoch, name, model, optimizer):
        # args.output_path is a Path object
        

        path = (self.args.output_path /
                Path(f"model_state_dict_epoch_{epoch + 1}.tar"))
        d = {
            'epoch': epoch,
            'name': name,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        if hasattr(self, "anchors"):
            d['anchors'] = self.anchors
        torch.save(d, path.__str__())
        self.logger.info(f"Model checkpoint saved at {str(path)}")

    


# import argparse
# import os
# import random
# import shutil
# import time
# import warnings

# import torch
# import torch.nn as nn
# import torch.nn.parallel
# import torch.backends.cudnn as cudnn
# import torch.distributed as dist
# import torch.optim
# import torch.multiprocessing as mp
# import torch.utils.data
# import torch.utils.data.distributed
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models

# # from apex import amp
# # from apex.parallel import DistributedDataParallel
# # from apex.parallel import convert_syncbn_model

# cudnn.benchmark = True
# cudnn.deterministic = True
# random.seed(0)
# torch.manual_seed(0)

# parser = argparse.ArgumentParser()
# parser.add_argument('--local_rank', default=-1, type=int,
#                     help='node rank for distributed training')
# args = parser.parse_args()


# def reduce_mean(tensor, nprocs):
#     rt = tensor.clone()
#     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
#     rt /= nprocs
#     return rt

# class Dataset(torch.utils.data.Dataset):
#     def __init__(self,):
#         super().__init__()
#         self.X = torch.arange(100).reshape(50,2,1).float()
#         self.Y = torch.zeros(50,2)
        
#     def __getitem__(self,idx):
#         return self.X[idx],self.Y[idx]
    
#     def __len__(self,):
#         return len(self.X)
    

# dist.init_process_group(backend='nccl')
# torch.cuda.set_device(args.local_rank)

# train_dataset = Dataset()
# train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle = False)
# train_loader = torch.utils.data.DataLoader(train_dataset,
#                                            batch_size=2,
#                                            shuffle=(train_sampler is None),
#                                            num_workers=2,
#                                            pin_memory=True,
#                                            sampler=train_sampler)




# model = nn.Sequential(nn.Linear(1,50),nn.BatchNorm1d(2),nn.ReLU(),nn.Linear(50,2))
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

# criterion = nn.L1Loss().cuda()
# # model = DistributedDataParallel(model)
# model = torch.nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank])
# optimizer = torch.optim.SGD(model.parameters(), 1e-3)


# # model, optimizer = amp.initialize(model, optimizer,opt_level="O0")
# res = 0
# model.train()
# for epoch in range(10):
#     for batch_idx, (images, target) in enumerate(train_loader):
#         print(images)
#         images = images.cuda(non_blocking=True)
#         target = target.cuda(non_blocking=True)

#         output = model(images)
#         loss = criterion(output, target)
#         print(loss.item())
#         torch.distributed.barrier()
        
#         optimizer.zero_grad()
# #         with amp.scale_loss(loss, optimizer) as scaled_loss:
#         loss.backward()
#         optimizer.step()
#         res += 1
# print(res)