import matplotlib.pyplot as plt
from numpy.core.numeric import cross
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.functional import align_tensors
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.nn.functional as F
import torchvision.transforms.functional as TF


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn

import torch
import random
import numpy as np
import gc
from pathlib import Path
import loss_functions
import models
from dataset import get_datasets
from dataset_hdf5 import get_hdf5_datasets
from dataset_city import get_city_datasets
from dataset_isaid import get_isaid_datasets
from dataset_cityscapes import Cityscapes
from dataset_isprs import Vaihingen_Dataset,PotsDam_Dataset

from utils import get_logger, save_epoch_summary, get_datetime_str_simplified, get_datetime_str,\
     Sync_DataParallel,BalancedDataParallel,execute_replication_callbacks,Logger_Lin
import ext_transforms as et
from utils_torch import PolynomialLRDecay,confusion_matrix_gpu,IOU_cal

from torchvision.transforms.transforms import CenterCrop
from models.sync_batchnorm.replicate import patch_replication_callback
import sys
sys.path.append("src/models/DynamicRouting")


class Train:
    def __init__(self, args):
        """Keep the run configuration and create the logging helpers.

        Args:
            args: run configuration object. ``log_dir``, ``name`` and
                ``hr_nclasses`` are read here; ``nprocs`` is written with the
                number of CUDA devices visible to this process.
        """
        # Record the visible GPU count on the shared args object before it is
        # handed to the logger.
        args.nprocs = torch.cuda.device_count()
        self.args = args

        tensorboard_dir = args.log_dir + args.name
        self.writer = SummaryWriter(log_dir=tensorboard_dir)
        self.logger = Logger_Lin("DDP train: ", args)

        self.label_names = ["Water", "Forest", "Field", "Others"]
        self.labels = list(range(args.hr_nclasses))

        self._fix_random_seed()

    def _fix_random_seed(self,):
        """Seed the Python, NumPy and PyTorch random number generators.

        All three generators are seeded with the same constant so that any
        randomness drawn through ``random``, ``numpy.random`` or ``torch``
        starts from a known state. The cuDNN flags are set as before.
        """
        seed = 0
        random.seed(seed)
        # NumPy keeps its own global RNG state, independent of `random`/`torch`.
        np.random.seed(seed)
        torch.manual_seed(seed)
        # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm
        # choice, which may differ between runs; kept as-is for speed —
        # confirm whether strict run-to-run reproducibility is required.
        cudnn.benchmark = True
        cudnn.deterministic = True

    # Logging and save operations are triggered only in process 0 (the main process).

    def _if(self, sign=True):
        """Gate a condition on this being the main process (local rank 0).

        Args:
            sign: condition to gate; defaults to ``True``.

        Returns:
            ``sign`` itself when it is falsy, otherwise whether
            ``args.local_rank`` equals 0.
        """
        if not sign:
            # A falsy condition is handed back unchanged.
            return sign
        return self.args.local_rank == 0  # or self.args.local_rank ==-1)

    def get_model(self):
        """Instantiate the network named by ``args.net_model``.

        ``'unet'`` maps to ``models.UNet_ori`` (built with ``out_ch``); any
        other name is looked up as an attribute of the ``models`` package and
        built with ``num_classes``.

        Returns:
            The freshly constructed model.
        """
        net_name = self.args.net_model
        in_channels = self.args.input_nchannels
        n_classes = self.args.hr_nclasses

        if net_name == 'unet':
            return models.UNet_ori(in_ch=in_channels, out_ch=n_classes)

        factory = getattr(models, net_name)
        return factory(in_ch=in_channels, num_classes=n_classes)

    def logger_info(self, *string):
        """Log each positional argument at INFO level, on the main process only.

        Args:
            *string: messages to forward, one ``logger.info`` call each.
        """
        if not self._if():
            return
        for message in string:
            self.logger.info(message)  # print(message) #

    # Reduce (average) the information gathered from the child processes.
    def reduce_mean(self, tensor):
        """Average ``tensor`` element-wise over every process in the default group.

        The divisor is the size of the initialised process group, not the
        number of locally visible GPUs (``args.nprocs``). The two differ for
        multi-node runs or when fewer processes than GPUs are launched, and
        dividing by the GPU count would then give a wrong mean.

        Args:
            tensor: tensor to average; it is not modified.

        Returns:
            A new tensor holding the cross-process mean.
        """
        world_size = dist.get_world_size()
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= world_size
        return rt

    # Reduce (sum) the information gathered from the child processes.
    def reduce_sum(self, tensor):
        """Sum ``tensor`` element-wise over every process in the default group.

        Args:
            tensor: tensor to reduce; it is not modified.

        Returns:
            A new tensor holding the cross-process sum.
        """
        total = tensor.clone()
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        return total

    def run(self):
        """Prepare the data, build the distributed model and run the whole schedule.

        In order: build the training/validation datasets selected by the
        ``args`` flags, initialise the NCCL process group, create distributed
        samplers and dataloaders, convert the model to SyncBatchNorm and wrap
        it in ``DistributedDataParallel`` (optionally loading
        ``args.resume_model``), build the SGD optimizer and the polynomial LR
        schedule, then alternate training and validation until
        ``args.epochs`` is reached. Whenever the validation mIoU improves, the
        main process saves the model through ``save_model``. After the loop a
        last validation pass and ``final_evaluate`` are run.

        The tensorboard writer is flushed and closed only after those final
        evaluations, because they still log scalars.

        Raises:
            ValueError: if ``args.resume`` is set but ``args.resume_model`` is
                empty, so there is no checkpoint to take the epoch from.
        """
        # prepare datasets and dataloaders for training and validation
        # disable shuffle in validation dataloader
        if self.args.cityscapes:
            ds_path = '/scratch/forest/datasets/cityscapes'
            self.training_ds = Cityscapes(root=ds_path, split='train', target_type='semantic',
                                            transforms=et.ExtCompose([
                                                et.ExtRandomScale((0.75, 2.0)),
                                                # et.ExtRandomGaussianBlur(radius=5),
                                                et.ExtRandomCrop(size=(768,768), padding=255,pad_if_needed=True),
                                                et.ExtRandomHorizontalFlip(),
                                                # et.ExtColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
                                                et.ExtToTensor(),
                                             ])
                                        )
            self.validation_ds = Cityscapes(root=ds_path, split='val', target_type='semantic',
                                                transforms=et.ExtCompose([et.ExtToTensor(),])
                                            )
            print('Train and Validation data load successful!!!')
        elif self.args.city:
            self.training_ds, self.validation_ds = get_city_datasets(self.args)
        elif self.args.isaid:
            self.training_ds, self.validation_ds = get_isaid_datasets(self.args)
        elif self.args.vaihingen:
            self.training_ds = Vaihingen_Dataset(img_ids = ['1','3','5','7','11','13','15','17','21','23','26','28','30','32','34','37'],
                patch=(768,768),stride=512,mode='train')
            self.validation_ds = Vaihingen_Dataset(img_ids = ['2','4','6','8','10','12','14','16','20','22','24','27','29','31','33','35','38'],
                patch=(768,768),stride=768,mode='valid')
        elif self.args.potsdam:
            self.training_ds = PotsDam_Dataset(img_ids = ['2_10','2_11','2_12','3_10','3_11','3_12','4_10','4_11','4_12','5_10','5_11','5_12','6_7','6_8','6_9','6_10','6_11','6_12','7_7','7_8','7_9','7_10','7_11','7_12'],
                patch=(896,896),stride=512,mode='train')
            self.validation_ds = PotsDam_Dataset(img_ids = ['2_13','2_14','3_13','3_14','4_13','4_14','4_15','5_13','5_14','5_15','6_13','6_14','6_15','7_13'],
                patch=(896,896),stride=896,mode='valid')
        else:
            self.training_ds, self.validation_ds = get_datasets(self.args)

        ##########################
        # Distributed Training
        # Step1: initialize the process group and bind this process to its GPU
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(self.args.local_rank)

        ##########################
        # Step2: Distributed Sampler and Dataloader
        train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_ds, shuffle=True)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(self.validation_ds, shuffle=False)

        training_dl = torch.utils.data.DataLoader(self.training_ds, batch_size=self.args.batch_size, shuffle=False, sampler=train_sampler,
                                                  num_workers=self.args.workers, pin_memory=True,)
        if self.args.cityscapes:
            # Cityscapes validation runs one image per batch.
            validation_dl = torch.utils.data.DataLoader(self.validation_ds, batch_size=1, sampler=valid_sampler,
                                                        num_workers=self.args.workers, pin_memory=True,)
        else:
            validation_dl = torch.utils.data.DataLoader(self.validation_ds, batch_size=self.args.batch_size, shuffle=False, sampler=valid_sampler,
                                                        num_workers=self.args.workers, pin_memory=True,)

        ##########################
        # Step3: Model load and convert to Sync BN
        model = self.get_model()
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

        ##########################
        # Step4: Distributed Model
        # local rank identifies which GPU this child process runs on
        model = DistributedDataParallel(model, device_ids=[self.args.local_rank])
        checkpoint = None
        if self.args.resume_model:
            checkpoint = torch.load(self.args.resume_model, map_location='cpu')
            model.load_state_dict(checkpoint["model_state_dict"])

        if not self.args.slowbackbone:
            params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.SGD(params, lr=self.args.learning_rate,  # 1e-2
                                        momentum=0.9, weight_decay=self.args.weight_decay)  # 0.0005)
        else:
            # Train the `resnet` backbone with a 10x smaller learning rate than the rest.
            backbone = [id(i) for i in model.module.resnet.parameters()]
            optimizer = torch.optim.SGD(
                                    [{'params': filter(lambda p: id(p) not in backbone,model.module.parameters()), 'lr': self.args.learning_rate},
                                    {'params': model.module.resnet.parameters(), 'lr': 0.1*self.args.learning_rate}],
                                    momentum=0.9, weight_decay=self.args.weight_decay)


        self.lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=self.args.epochs*min(len(training_dl), self.args.training_steps_per_epoch),
                                              end_learning_rate=1e-4, power=0.9) #None

        # Resuming continues from the epoch after the one stored in the checkpoint.
        if self.args.resume:
            if checkpoint is None:
                raise ValueError(
                    "args.resume is set but args.resume_model is empty; "
                    "there is no checkpoint to resume the epoch counter from.")
            epoch = checkpoint['epoch'] + 1
        else:
            epoch = 0

        best_miou = 0
        try:
            # train the model for given epochs
            while epoch < self.args.epochs:
                #### notice here!!! every epoch need to re-shuffle the data.
                train_sampler.set_epoch(epoch)
                self.train_one_epoch(model, optimizer, training_dl, epoch)  # if use lr_scheduler, please set in here

                # valid all to get a stable and precise/accurate score
                if self.args.valid_all:
                    # Full validation only every 5th epoch and on the last one.
                    miou = 0
                    if epoch%5 == 0 or epoch==self.args.epochs-1:
                        miou = self.evaluate(model, validation_dl, epoch,batch_num = max(len(validation_dl), self.args.validation_steps_per_epoch))

                else:
                    miou = self.evaluate(model, validation_dl, epoch, batch_num=min(len(validation_dl), self.args.validation_steps_per_epoch))

                if self._if(miou > best_miou):
                    best_miou = miou
                    self.save_model(-1, self.args.name, model, optimizer)  # best model saved as model_state_dict_epoch_0.tar


                epoch += 1

            self.evaluate(model, validation_dl, epoch, batch_num=min(len(validation_dl), self.args.validation_steps_per_epoch))
            torch.distributed.barrier()
            try:
                self.final_evaluate(model, validation_dl)
            except Exception as e:
                print("Final Evaludaiton Failed.")
                print('Exception: ', e)
        finally:
            # Close the writer last: the evaluations above still log scalars.
            self.writer.flush()
            self.writer.close()

    def train_one_epoch(self, model, optimizer, dataloader, epoch):
        batch_num = min(len(dataloader), self.args.training_steps_per_epoch)
        if batch_num == 0:
            return

        model.train()
        self.one_epoch(model, optimizer, dataloader,
                       epoch, batch_num, "training")

        if (epoch + 1) % self.args.model_save_checkpoint == 0 or epoch == 0:
            path = (self.args.output_path / Path(f"model_state_dict_epoch_{epoch + 1}.tar"))
            if self._if():
                self.save_model(epoch, self.args.name, model, optimizer,path)
            # # sync model 
            # torch.distributed.barrier()
            # checkpoint = torch.load(path, map_location='cpu')
            # model.load_state_dict(checkpoint["model_state_dict"])
            # model.cuda()


    def evaluate(self, model, dataloader, epoch, batch_num):
        """Run a validation pass without gradients and return its score.

        Args:
            model: network to evaluate; switched to eval mode here.
            dataloader: validation dataloader.
            epoch: zero-based epoch index, used for logging.
            batch_num: number of batches to evaluate.

        Returns:
            0 when ``batch_num`` is 0, otherwise the value returned by
            ``one_epoch`` for the ``"validation"`` task.
        """
        if batch_num == 0:
            return 0

        with torch.no_grad():
            model.eval()
            return self.one_epoch(model, None, dataloader, epoch, batch_num, "validation")

    def final_evaluate(self, model, dataloader):
        """Reload ``model_state_dict_epoch_0.tar`` and evaluate it on the full dataloader.

        The checkpoint is read from ``args.output_path``, loaded into
        ``model`` and evaluated over every batch of ``dataloader`` as the
        ``"final"`` task.

        Args:
            model: network that receives the checkpoint weights.
            dataloader: dataloader to evaluate on.
        """
        best_ckpt_path = self.args.output_path / Path("model_state_dict_epoch_0.tar")
        state = torch.load(best_ckpt_path, map_location='cpu')
        model.load_state_dict(state["model_state_dict"])
        model.cuda()

        total_batches = len(dataloader)
        with torch.no_grad():
            model.eval()
            self.one_epoch(model, None, dataloader, -1, total_batches, "final")


    def one_epoch(self, model, optimizer, dataloader, epoch, batch_num, task):
        """Run one pass over ``dataloader`` for training, validation or final evaluation.

        For every batch the input is normalized and forwarded through
        ``model``, and a loss is computed; which loss branch runs depends on
        ``args.net_model`` and ``args.name_prefix``. When ``task`` is
        ``"training"`` the optimizer (and ``self.lr_scheduler`` if set) is
        stepped. Per-batch metrics and a confusion matrix are all-reduced
        across processes and accumulated. At the end an epoch summary is
        logged and, on the main process, saved to disk and tensorboard.

        Args:
            model: network exposing ``.module`` (callers pass a
                ``DistributedDataParallel`` wrapper); for 'boost' models
                ``model.module.ts`` holds the per-stage thresholds.
            optimizer: used only when ``task == "training"``.
            dataloader: iterable yielding ``(x, y_hr)`` pairs.
            epoch: zero-based epoch index used for logging and tensorboard steps.
            batch_num: number of batches after which the loop stops.
            task: one of ``"training"``, ``"validation"`` or ``"final"``.

        Returns:
            The mean IoU computed by ``IOU_cal`` from the accumulated
            confusion matrix.
        """
        assert task == "training" or task == "validation" or task == "final"
        self.logger_info(f"\n{'=' * 50}\n               Epoch {epoch + 1:3d}      {task}\n{'=' * 50}\n ")

        # Per-epoch accumulators; the confusion matrix lives on the GPU in double precision.
        ep_records = {
            "loss_to_use": [], "crossentropy": [], "miou": [], "accuracy": [],
            "start_time": get_datetime_str(),
            "confusion_matrix": torch.zeros((self.args.hr_nclasses, self.args.hr_nclasses)).cuda().double(),
        }

        # One list per stage for 'boost' models, otherwise a fixed placeholder.
        visited_counter = [[] for _ in range(len(model.module.ts))] if 'boost' in self.args.net_model else [[0], [0]]
        FG_counter = [[] for _ in range(len(model.module.ts))] if 'boost' in self.args.net_model else [[0], [0]]

        for batch_idx, (x, y_hr) in enumerate(iter(dataloader)):
            # if self.args.cityscapes or self.args.isaid:
            # Rescale by 255 and normalize with fixed per-channel mean/std
            # (assumes x is a 3-channel image in [0, 1] — TODO confirm).
            x_n = TF.normalize(x * 255.0, mean = [123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], inplace=False)
            x_n = x_n.cuda(non_blocking=True) 

            outputs = model(x_n)
            # elif self.args.net_model == 'FarSeg':
            #     outputs = model(x, y_hr)
            # else:
            #     x = x.cuda(non_blocking=True)  # B, 4, H, W
            #     outputs = model(x)
            B, C, H, W = x.shape
            y_hr = y_hr.cuda(non_blocking=True)
            # Label 255 marks pixels to ignore; the mask gets a channel dim for broadcasting.
            y_ignore = (y_hr==255).unsqueeze(dim=1) #self.args.ignore_idx
            # Clamp labels into the valid class range so one_hot accepts them;
            # the ignored pixels are excluded through y_ignore.
            y_hr = y_hr.long().clamp(max=self.args.hr_nclasses - 1, min=0)  # B, H, W
            y_hr_onehot = torch.nn.functional.one_hot(y_hr, self.args.hr_nclasses).permute(0, 3, 1, 2)  # B, 5, H, W

            # Loss calculation, for model training.
            if self.args.net_model.startswith('DeepLabv3p_boost') or self.args.net_model.startswith('Semantic_FPN'):  # ,'UNet_boost_feacat']:
                if self.args.name_prefix.endswith('0_'):
                    # Baseline: supervise only the last output, upsampled to the input size.
                    o = outputs[-1]
                    y_hr_pred = o
                    y_hr_pred = F.interpolate(y_hr_pred, size=x.size()[2:], mode='bilinear', align_corners=True)
                    loss = loss_functions.kl_loss(y_hr_pred, y_hr_onehot,~y_ignore)
                elif self.args.name_prefix.endswith('1_'):
                    # Staged prediction with learned thresholds (model.module.ts):
                    # each stage claims the pixels whose confidence reaches its
                    # threshold; the remaining pixels go on to the next stage.
                    o = outputs
                    target = y_hr_onehot
                    not_visited = (~y_ignore)*torch.ones(B, 1, H, W).cuda()
                    result = torch.zeros(B, o[-1].shape[1], H, W).cuda()
                    loss = 0
                    end_idx = len(model.module.ts)-1

                    for idx, scale_f in enumerate([16, 8, 4, 2, 1][:end_idx+1]):
                        # Stage weight: stagerate raised to the stage index.
                        loss_rate = self.args.stagerate**idx
                        expand = not_visited.float()
                        # kernel_size=1/stride=1 max-pooling leaves the mask values unchanged.
                        expand = F.max_pool2d(expand.float(), kernel_size=1, stride=1, padding=0).bool()

                        # cur = F.upsample(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
                        cur = F.interpolate(o[idx],size = (H,W),mode='bilinear', align_corners=True) #scale_factor=scale_f
                        cur = F.softmax(cur, dim=1)
                        # Fraction of pixels still not visited when entering this stage.
                        visited_counter[idx].append(torch.sum(not_visited)/torch.numel(not_visited))
                        _a = target.permute(0,2,3,1)
                        _b = expand.permute(0,2,3,1).squeeze(dim = -1).bool()
                        # Fraction of the remaining pixels whose target class is not 0.
                        if len(_a[_b])<5:
                            FG_counter[idx].append(0*torch.sum(_b))
                        else:
                            FG_counter[idx].append(torch.sum(torch.argmax(_a[_b],dim = 1)!=0) /torch.sum(_b))
                        if self._if((batch_idx+1) % 10 == 0):
                            self.logger_info(f"Proportion of not visited at scale {scale_f}: {torch.sum(not_visited)/torch.numel(not_visited)}")
                            # self.logger_info(f"Proportion of FG at scale {scale_f}: {torch.sum(torch.argmax(_a[_b],dim = 1)!=0) /torch.sum(_b)}")

                        # Despite the name, this is the maximum softmax probability per pixel.
                        uncertain = torch.max(cur, dim=1, keepdim=True)[0]
                        ts = model.module.ts
                        #########################################################
                        #### update rate =1
                        # Every 10th batch: refresh the stage threshold from the
                        # confidence distribution of correctly classified pixels.
                        # NOTE(review): sign/correct_dis/error_dis/all_dis are only
                        # bound inside this block.
                        if (batch_idx+1) % 10 == 0:
                            sign = torch.argmax(cur, dim=1, keepdim=True) == torch.argmax(y_hr_onehot, dim=1, keepdim=True)
                            correct_dis = uncertain[sign*expand]
                            error_dis = uncertain[(~sign)*expand]
                            all_dis = uncertain[expand]
                            # loss += 10*F.l1_loss(correct_dis, torch.ones_like(correct_dis))+ 10*F.l1_loss(error_dis, torch.zeros_like(error_dis)) #*1.0/cur.shape[1]
                            #F.l1_loss(error_dis, torch.ones_like(error_dis)*1.0/cur.shape[1])
                            if (task == 'training'):
                                if len(correct_dis) > 10:
                                    quantile_point = torch.quantile(correct_dis, q=self.args.quantile)
                                else:
                                    # Too few samples: fall back to the current threshold.
                                    quantile_point = model.module.ts[idx]
                                # if len(error_dis) > 10:
                                #     quantile_point = torch.quantile(error_dis, q=self.args.quantile)
                                # else:
                                #     quantile_point = model.module.ts[idx]
                                # if len(all_dis) > 10:
                                #     quantile_point = torch.quantile(all_dis, q=self.args.quantile)
                                # else:
                                #     quantile_point = model.module.ts[idx]
                                # Average the quantile over processes, then blend it into
                                # the threshold as a moving average weighted by args.gamma.
                                reduce_q_p = self.reduce_mean(quantile_point)
                                model.module.ts[idx] = self.args.gamma*model.module.ts[idx]+(1-self.args.gamma)*reduce_q_p.detach().cpu().numpy()
                            
                            if self._if(batch_idx%20==0 and epoch > 10 and len(correct_dis) > 10 and len(error_dis) > 10):
                                self.writer.add_histogram(f"{task}{idx}-correct_dis", correct_dis, epoch + 1, bins='auto')
                                self.writer.add_histogram(f"{task}{idx}-error_dis", error_dis, epoch + 1, bins='auto')
                        if batch_idx==0 and idx == 0:
                                self.logger_info(f'The threshold is: {model.module.ts.cpu().numpy()}')
                            # loss += 100*torch.mean(uncertain[(~sign)*expand*( uncertain >= ts[idx] )]- 1.0/cur.shape[1])

                        if idx == end_idx:
                            # Last stage takes every pixel that is still unvisited.
                            cur_visited = 1.
                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        elif idx < end_idx:
                            # Pixels whose confidence reaches this stage's threshold are claimed here.
                            cur_visited = uncertain >= (ts[idx].to(uncertain.device))  # max(0,1-4./(scale_f**2))
                            cur_visited = cur_visited.float()
                            # loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)

                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)  # +0.5*loss_functions.kl_loss_woSOFT(cur, target,sign= cur_visited*expand)
                        # cache for next step feature
                        # Write this stage's (detached) prediction into the newly claimed pixels.
                        result += (not_visited*cur_visited)*cur.detach()
                        if self._if(batch_idx == 0 and epoch % 10 == 0):
                            self.save_tensorboard_onehot_images(result, f'{task}:Idx {idx}', epoch)
                        not_visited = not_visited*(1-cur_visited)
                        if idx == end_idx:
                            break
                    y_hr_pred = result
                elif self.args.name_prefix.endswith('topk_'):
                    # Same staged scheme, but the threshold is the per-batch quantile
                    # of the confidences instead of the stored model.module.ts values.
                    o = outputs
                    target = y_hr_onehot
                    not_visited = (~y_ignore)*torch.ones(B, 1, H, W).cuda()
                    result = torch.zeros(B, o[-1].shape[1], H, W).cuda()
                    loss = 0
                    end_idx = len(model.module.ts)-1

                    for idx, scale_f in enumerate([16, 8, 4, 2, 1][:end_idx+1]):
                        loss_rate = self.args.stagerate**idx
                        expand = not_visited.float()
                        # kernel_size=1/stride=1 max-pooling leaves the mask values unchanged.
                        expand = F.max_pool2d(expand.float(), kernel_size=1, stride=1, padding=0).bool()

                        # cur = F.upsample(o[idx], mode='bilinear', align_corners=True, scale_factor=scale_f)
                        cur = F.interpolate(o[idx],size = (H,W),mode='bilinear', align_corners=True) #scale_factor=scale_f
                        cur = F.softmax(cur, dim=1)
                        visited_counter[idx].append(torch.sum(not_visited)/torch.numel(not_visited))
                        _a = target.permute(0,2,3,1)
                        _b = expand.permute(0,2,3,1).squeeze(dim = -1).bool()
                        if len(_a[_b])<5:
                            FG_counter[idx].append(0*torch.sum(_b))
                        else:
                            FG_counter[idx].append(torch.sum(torch.argmax(_a[_b],dim = 1)!=0) /torch.sum(_b))

                        
                        if self._if((batch_idx+1) % 10 == 0):
                            self.logger_info(f"Proportion of not visited at scale {scale_f}: {torch.sum(not_visited)/torch.numel(not_visited)}")

                        uncertain = torch.max(cur, dim=1, keepdim=True)[0]
                        # Threshold = args.quantile-quantile of the confidences of the remaining pixels.
                        ts = torch.quantile(uncertain[expand],q=self.args.quantile)
                        
                        if idx == end_idx:
                            cur_visited = 1.
                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)
                        elif idx < end_idx:
                            cur_visited = uncertain >= ts
                            cur_visited = cur_visited.float()
                            # loss += loss_functions.kl_loss_woSOFT(cur, target, sign=expand)

                            loss += loss_rate*loss_functions.kl_loss_woSOFT(cur, target, sign=expand)  # +0.5*loss_functions.kl_loss_woSOFT(cur, target,sign= cur_visited*expand)
                        # cache for next step feature
                        result += (not_visited*cur_visited)*cur.detach()
                        if self._if(batch_idx == 0 and epoch % 10 == 0):
                            self.save_tensorboard_onehot_images(result, f'{task}:Idx {idx}', epoch)
                        not_visited = not_visited*(1-cur_visited)
                        if idx == end_idx:
                            break
                    y_hr_pred = result
                # NOTE(review): a disabled experimental branch for name_prefix ending in '2_'
                # (per-scale targets built with avg_pool2d) used to sit here as commented-out code.
                # With the branches above, any other name_prefix leaves `loss` and `y_hr_pred`
                # unassigned for these models — confirm whether a fallback is needed.

            else:
                # Single-output models: supervise the raw output directly.
                y_hr_pred = outputs
                loss = loss_functions.kl_loss(y_hr_pred, y_hr_onehot,~y_ignore) #loss_functions.crossentropy_loss(hard_mining=False)(y_hr_pred, y_hr)

            # SYNC multiple process!!!!
            torch.distributed.barrier()

            if task == "training":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # The LR schedule advances once per optimizer step.
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            ##############################################################################################
            ### Evaluation metrics (gradients disabled, also in train mode, to save time) ################
            ##############################################################################################
            with torch.no_grad():
                y_hr_pred_am = y_hr_pred.argmax(dim=1)
                y_hr_pred_onehot = torch.nn.functional.one_hot(y_hr_pred_am, self.args.hr_nclasses).permute(0, 3, 1, 2)
                
                ### remove ignore index
                cf = confusion_matrix_gpu(y_hr.view(-1)[~y_ignore.view(-1)], y_hr_pred_am.view(-1)[~y_ignore.view(-1)], num_classes=self.args.hr_nclasses)
                # debug: print(y_ignore.sum(),cf.sum()/(H*W))
                ### !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                ### !!!! not an accurate evaluation metric, for reference only !!!!
                # These per-batch metrics are computed without the ignore mask.
                crossentropy = loss_functions.crossentropy_loss(hard_mining=False)(y_hr_pred, y_hr)
                miou = loss_functions.miou(self.args.hr_nclasses)(y_hr_pred_onehot, y_hr_onehot)
                accuracy = loss_functions.accuracy()(y_hr_pred_onehot, y_hr_onehot)

                # reduce the metrics from all processes
                reduce_cf = self.reduce_sum(cf)
                ep_records['confusion_matrix'] += reduce_cf
                reduce_crossentropy = self.reduce_mean(crossentropy)
                reduce_miou = self.reduce_mean(miou)
                reduce_accuracy = self.reduce_mean(accuracy)

            if (batch_idx + 1) % 10 == 0:
                # print batch log in console
                log_str = f"Batch: {(batch_idx + 1):4d} / {batch_num:4d}    " \
                    f"Loss: {loss.item():.3f}    " \
                    f"CrossEntropy Loss: {reduce_crossentropy.item():.3f}    " \
                    f"mIOU: {reduce_miou.item():.3f}    " \
                    f"Accuracy: {reduce_accuracy.item():.3f}"

                self.logger_info(log_str)

            # store loss records
            # IOU, accuracy and cross-entropy are reduced over all processes so they cover ALL data;
            # the loss itself is this process's local value.
            ep_records["loss_to_use"].append(loss.item())
            ep_records["crossentropy"].append(reduce_crossentropy.item())
            ep_records["miou"].append(reduce_miou.item())
            ep_records["accuracy"].append(reduce_accuracy.item())

            # save tensorboard images in the first batch of each evaluation epoch
            if self._if(batch_idx == 0 and epoch % (max((self.args.epochs)//20, 1)) == 0):  # task == "validation" and
                self.save_tensorboard_images(x, y_hr, y_hr_pred, epoch, task)
            # Stop after batch_num batches even if the dataloader has more.
            if batch_idx + 1 >= batch_num:
                break

        # summarize epoch records
        ep_records["end_time"] = get_datetime_str()
        m_loss_to_use = np.mean(ep_records["loss_to_use"])
        m_crossentropy = np.mean(ep_records["crossentropy"])

        # Epoch-level IoU/accuracy come from the accumulated confusion matrix.
        cf_matrix = ep_records['confusion_matrix']
        iou, m_iou = IOU_cal(cf_matrix)
        # "fg" metrics drop the first row and column (class 0) of the matrix.
        fg_iou, fg_m_iou = IOU_cal(cf_matrix[1:, 1:])
        m_accuracy = (cf_matrix.diag().sum() / cf_matrix.sum()).item()

        # NOTE(review): in the staged branches these lists hold tensors created on the GPU;
        # confirm that the NumPy conversions here and in np.mean below handle them.
        visited_counter = np.array(visited_counter)

        log_str_1 = f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
                    f"Start at {ep_records['start_time']}    " \
                    f"End at {ep_records['end_time']} \n"\
                    f"Epoch {epoch + 1:3d} {task.upper()} SUMMARY    " \
                    f"Loss: {m_loss_to_use:.4f}    " \
                    f"CrossEntropy Loss: {m_crossentropy:.4f}    \n" \
                    f"Accuracy: {m_accuracy:.4f}    \n" \
                    f"mIOU: {m_iou:.4f}    " \
                    f"IOU for each category: {iou.data}\n"\
                    f"fg_mIOU: {fg_m_iou:.4f}    " \
                    f"fg_IOU for each category: {fg_iou.data} \n"\
                    f"visited counter in different stage: {np.mean(visited_counter, axis=1)}" \
                    f"FG counter in different stage: {np.mean(FG_counter, axis=1)}"
        
        self.logger_info(log_str_1)

        save_str = f'''Epoch {epoch + 1:3d} - {ep_records["start_time"]} - {ep_records["end_time"]} - ''' \
            f"Loss: {m_loss_to_use:.4f}    " \
            f"CrossEntropy Loss: {m_crossentropy:.4f}    " \
            f"mIOU: {m_iou:.4f}    " \
            f"Accuracy: {m_accuracy:.4f}\n"

        # Only the main process writes the summary file and the tensorboard scalars.
        if self._if():
            save_epoch_summary(self.args, save_str)
            # save epoch summary to tensorboard
            self.save_tensorboard_scalars(m_loss_to_use, m_crossentropy, m_iou, m_accuracy, task=f"{task}-epoch", step=epoch + 1)

        return m_iou

    def save_tensorboard_scalars(self, loss_to_use, crossentropy_loss, miou, accuracy, task, step):

        self.writer.add_scalar(f"{task}/loss", loss_to_use, step)
        self.writer.add_scalar(f"{task}/crossentropy_loss", crossentropy_loss, step)
        self.writer.add_scalar(f"{task}/miou", miou, step)
        self.writer.add_scalar(f"{task}/accuracy", accuracy, step)

        # self.writer.add_scalar(f"{task}/3remote_jaccard", remote_jaccard, step)
        # self.writer.add_scalar(f"{task}/3city_jaccard", city_jaccard, step)
        # self.writer.add_scalar(f"{task}/3remote_accuracy", remote_accuracy, step)
        # self.writer.add_scalar(f"{task}/3city_accuracy", city_accuracy, step)

    def save_tensorboard_onehot_images(self, y, name, epoch):
        """Colourise the per-pixel argmax of ``y`` and log it as one image.

        ``y`` is unpacked as B x C x H x W.  Each pixel's winning channel is
        mapped to an RGB colour and the B samples are stacked vertically into
        a single (B*H) x W image, written under tag ``name`` at step ``epoch``.
        """
        n_batch, n_channels, height, width = y.shape

        if n_channels == 5:
            # Fixed five-colour palette.
            palette = np.array([
                [255, 255, 255],
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0
        else:
            # One colour per channel, stepping through the colour cube in
            # increments of 255^3 / 20.
            stride = 255*255*255/20
            offsets = [stride * k for k in range(0, n_channels)]
            palette = np.array(
                [[v//(255*255), (v//255) % 255, v % 255] for v in offsets]
            ) / 255.0

        winners = torch.argmax(y, dim=1)
        one_hot = F.one_hot(winners, len(palette)).cpu().float()
        palette_t = torch.from_numpy(palette).float()
        # A one-hot x palette product selects exactly one colour per pixel.
        coloured = torch.einsum('bhwc,cp->bhwp', one_hot, palette_t)
        im = np.array(coloured).reshape(n_batch*height, width, -1)
        self.writer.add_image(name, im, epoch, dataformats='HWC')

    def save_tensorboard_images(self, x, y, y_pred, epoch, task='validation'):
        """Log an input / label / prediction / error strip for every sample.

        Args:
            x: input batch, unpacked as B x C x H x W; channels 0..2 are shown
                as the RGB image.
            y: integer label map (B x H x W) indexing directly into the palette.
            y_pred: class scores (B x K x H x W); the argmax over dim 1 is the
                predicted class.
            epoch: zero-based epoch; the image is written at step ``epoch + 1``.
            task: tag prefix; the image tag is ``"{task}-label-prediction"``.

        The palette is chosen from ``self.args`` (cityscapes, isaid,
        potsdam/vaihingen, otherwise a five-colour default).  Mismatching
        pixels are highlighted in the red channel of the error panel.
        """

        def spaced_palette(divisor, count):
            # `count` colours stepping through the colour cube by 255^3 / divisor.
            stride = 255*255*255 / divisor
            return np.array([
                [(stride * k)//(255*255), ((stride * k)//255) % 255, (stride * k) % 255]
                for k in range(count)
            ]) / 255.0

        if self.args.cityscapes:
            color_map = spaced_palette(20, 20)
        elif self.args.isaid:
            # NOTE(review): 20 entries are generated with a 1/16 stride, so the
            # entries above index 16 have a first channel > 1.0.  Kept as-is
            # because labels 0..15 are unaffected -- confirm the class count.
            color_map = spaced_palette(16, 20)
        elif self.args.potsdam or self.args.vaihingen:
            # ISPRS colours.  The sixth entry used to repeat the fifth (yellow),
            # which made classes 4 and 5 indistinguishable; it is now red.
            color_map = np.array([
                [255, 255, 255],
                [0, 0, 255],
                [0, 255, 255],
                [0, 255, 0],
                [255, 255, 0],
                [255, 0, 0],
            ]) / 255.0
        else:
            color_map = np.array([
                [255, 255, 255],
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()  # B*H*W*3
        pred = torch.argmax(y_pred, dim=1)

        # Index with explicit NumPy arrays so the lookup is always B*H*W*3,
        # whatever the spatial size.
        label_color_map = color_map[y.numpy()]
        pred_color_map = color_map[pred.numpy()]

        # Red overlay on every pixel where prediction and label disagree.
        error_mask = torch.zeros(images.shape).float()
        error_mask[:, :, :, 0][pred != y] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        img_concat = np.stack((images, label_color_map, pred_color_map, error), axis=1)  # B*4*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*4*W*3
        img_concat = img_concat.reshape((-1, 4*W, 3))  # (B*H) * (4*W) * 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_scale(self, x, y, y_pred, preds, epoch):
        """Log image / label / prediction / error plus four per-scale confidence maps.

        Args:
            x: input batch, unpacked as B x C x H x W; channels 0..2 are shown.
            y: integer label map (B x H x W).  With ``args.cityscapes`` it
                indexes the palette directly; otherwise ``y - 1`` is used.
            y_pred: class scores (B x K x H x W).  Without ``args.cityscapes``
                channel 0 is skipped before the argmax.
            preds: exactly four score tensors.  Their top softmax probability
                is upsampled by 2, 4, 8 and 16 respectively and must then match
                H x W so the panels can be concatenated.
            epoch: zero-based epoch; the image is written at step ``epoch + 1``.
        """
        cityscapes = self.args.cityscapes

        if cityscapes:
            stride = 255*255*255/20
            color_map = np.array([
                [(stride * k)//(255*255), ((stride * k)//255) % 255, (stride * k) % 255]
                for k in range(0, 20)
            ]) / 255.0
        else:
            color_map = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, W = x.shape

        x = x.detach().cpu()
        y = y.detach().cpu()
        y_pred = y_pred.detach().cpu()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()

        # The two label conventions are mutually exclusive.  The `else` had been
        # lost, which left `pred` unbound without cityscapes and overwrote the
        # cityscapes results with the shifted variant.
        if cityscapes:
            target = y
            pred = torch.argmax(y_pred, dim=1)
        else:
            target = y - 1
            pred = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        label_color_map = color_map[target.numpy()]
        pred_color_map = color_map[pred.numpy()]

        # Red overlay on every pixel where prediction and label disagree.
        error_mask = torch.zeros(images.shape).float()
        error_mask[:, :, :, 0][pred != target] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        # Highest softmax probability per pixel, replicated to three channels.
        top_prob = [p.softmax(dim=1).topk(k=2, dim=1)[0] for p in preds]
        top_prob = [p[:, 0:1, :, :].repeat(1, 3, 1, 1) for p in top_prob]
        pred2, pred3, pred4, pred5 = top_prob
        upsampled = [
            F.interpolate(p, scale_factor=factor, mode='nearest')
            for p, factor in ((pred2, 2), (pred3, 4), (pred4, 8), (pred5, 16))
        ]
        RF = [u.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for u in upsampled]

        panels = (
            images[:, np.newaxis],
            label_color_map[:, np.newaxis],
            pred_color_map[:, np.newaxis],
            error[:, np.newaxis],
            RF[3], RF[2], RF[1], RF[0],  # coarsest scale first
        )
        img_concat = np.concatenate(panels, axis=1)  # B*8*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*8*W*3
        img_concat = img_concat.reshape((-1, 8*W, 3))  # (B*H) * (8*W) * 3

        self.writer.add_image(f"validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')
        # self.writer.add_histogram(f"validation-RF-value", c, epoch + 1)
        # self.writer.add_histogram(f"validation-RF-argmax", c.argmax(dim=2) + 1, epoch + 1, bins='auto')

    def save_tensorboard_images_arf(self, x, y, y_pred, c, epoch):
        """Log image / label / prediction / error plus the four maps in ``c``.

        ``x`` is unpacked as B x C x H x W and its first three channels are
        shown.  With ``args.cityscapes`` the labels index a 20-colour palette
        directly; otherwise a 4-colour palette is indexed with ``y - 1`` and
        channel 0 of ``y_pred`` is skipped before the argmax.  The first four
        entries of ``c`` are also stacked and written as two histograms
        (raw values, and argmax over dim 2 plus one).  Everything is logged
        at step ``epoch + 1``.
        """
        on_cityscapes = self.args.cityscapes

        if on_cityscapes:
            stride = 255*255*255/20
            palette = np.array([
                [(stride * k)//(255*255), ((stride * k)//255) % 255, (stride * k) % 255]
                for k in range(0, 20)
            ]) / 255.0
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, width = x.shape

        x_cpu = x.detach().cpu()
        y_cpu = y.detach().cpu()
        scores = y_pred.detach().cpu()

        images = x_cpu[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if on_cityscapes:
            target = y_cpu
            pred = torch.argmax(scores[:, :, :, :], dim=1)
        else:
            target = y_cpu - 1
            pred = torch.argmax(scores[:, 1:, :, :], dim=1)
        label_rgb = np.array([palette[t] for t in target])
        pred_rgb = np.array([palette[p] for p in pred])

        # Red overlay wherever prediction and label disagree.
        mismatch = torch.zeros(images.shape).float()
        mismatch[:, :, :, 0][pred != target] = 1
        error = images * 0.7 + mismatch.numpy() * 0.3

        rf_panels = [t.detach().cpu().permute(0, 2, 3, 1).numpy()[:, np.newaxis, :, :] for t in c]
        rf_stack = torch.cat(
            [t.unsqueeze(dim=0) for t in (c[0], c[1], c[2], c[3])], dim=0)

        panels = (
            images[:, np.newaxis],
            label_rgb[:, np.newaxis],
            pred_rgb[:, np.newaxis],
            error[:, np.newaxis],
            rf_panels[0], rf_panels[1], rf_panels[2], rf_panels[3],
        )
        # B*8*H*W*3 -> B*H*8*W*3 -> (B*H) * (8*W) * 3
        img_concat = np.concatenate(panels, axis=1)
        img_concat = img_concat.transpose((0, 2, 1, 3, 4)).reshape((-1, 8*width, 3))

        step = epoch + 1
        self.writer.add_image("validation-label-prediction", img_concat, step, dataformats='HWC')
        self.writer.add_histogram("validation-RF-value", rf_stack, step)
        self.writer.add_histogram("validation-RF-argmax", rf_stack.argmax(dim=2) + 1, step, bins='auto')

    def save_tensorboard_images_weight(self, x, y, y_pred, w, epoch):
        """Log image / label / prediction / error plus a weight map per sample.

        ``x`` is unpacked as B x C x H x W and its first three channels are
        shown.  With ``args.cityscapes`` the labels index a 20-colour palette
        directly; otherwise a 4-colour palette is indexed with ``y - 1`` and
        channel 0 of ``y_pred`` is skipped before the argmax.  ``w`` is
        repeated three times along dim 1 to form the fifth panel.  The strip
        is logged at step ``epoch + 1``.
        """
        on_cityscapes = self.args.cityscapes

        if on_cityscapes:
            stride = 255*255*255/20
            palette = np.array([
                [(stride * k)//(255*255), ((stride * k)//255) % 255, (stride * k) % 255]
                for k in range(0, 20)
            ]) / 255.0
        else:
            palette = np.array([
                [1, 39, 255],
                [51, 160, 44],
                [126, 237, 14],
                [255, 1, 1]
            ]) / 255.0

        _, _, _, width = x.shape

        x_cpu = x.detach().cpu()
        y_cpu = y.detach().cpu()
        scores = y_pred.detach().cpu()

        images = x_cpu[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        if on_cityscapes:
            target = y_cpu
            pred = torch.argmax(scores[:, :, :, :], dim=1)
        else:
            target = y_cpu - 1
            pred = torch.argmax(scores[:, 1:, :, :], dim=1)
        label_rgb = np.array([palette[t] for t in target])
        pred_rgb = np.array([palette[p] for p in pred])

        # Red overlay wherever prediction and label disagree.
        mismatch = torch.zeros(images.shape).float()
        mismatch[:, :, :, 0][pred != target] = 1
        error = images * 0.7 + mismatch.numpy() * 0.3

        # Replicate the weight map to three channels and add the panel axis.
        weight_panel = w.repeat(1, 3, 1, 1).permute(0, 2, 3, 1).unsqueeze(dim=1)
        weight_panel = weight_panel.detach().cpu().numpy()

        panels = (
            images[:, np.newaxis],
            label_rgb[:, np.newaxis],
            pred_rgb[:, np.newaxis],
            error[:, np.newaxis],
            weight_panel,
        )
        # B*5*H*W*3 -> B*H*5*W*3 -> (B*H) * (5*W) * 3
        img_concat = np.concatenate(panels, axis=1)
        img_concat = img_concat.transpose((0, 2, 1, 3, 4)).reshape((-1, 5*width, 3))

        self.writer.add_image("validation-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_images_edge(self, x, y, y_pred, edge, area, epoch, task='validation'):
        """Log image / label / prediction / error / edge / area panels per sample.

        Args:
            x: input batch, unpacked as B x C x H x W; channels 0..2 are shown.
            y: integer label map (B x H x W); ``y - 1`` indexes the palette.
            y_pred: class scores; channel 0 is skipped before the argmax.
            edge: edge logits holding B*H*W values (e.g. B x 1 x H x W); shown
                after a sigmoid.
            area: class scores handled exactly like ``y_pred``.
            epoch: zero-based epoch; the image is written at step ``epoch + 1``.
            task: tag prefix; the image tag is ``"{task}-label-prediction"``.
                It used to be read without ever being defined, so every call
                raised NameError; it is now a parameter with a default.
        """
        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        B, _, H, W = x.shape

        x = x.detach().cpu()
        target = y.detach().cpu() - 1
        pred = y_pred[:, 1:, :, :].argmax(dim=1).detach().cpu()
        area_idx = area[:, 1:, :, :].argmax(dim=1).detach().cpu()
        # reshape rather than squeeze(): a bare squeeze() also drops the batch
        # axis when B == 1.  torch.sigmoid replaces the deprecated F.sigmoid.
        edge_prob = torch.sigmoid(edge.detach()).cpu().reshape(B, H, W).numpy()

        images = x[:, 0:3, :, :].permute(0, 2, 3, 1).numpy()
        label_color_map = color_map[target.numpy()]
        pred_color_map = color_map[pred.numpy()]
        area_color_map = color_map[area_idx.numpy()]
        edge_rgb = np.repeat(edge_prob[..., np.newaxis], 3, axis=-1)

        # Red overlay on every pixel where prediction and label disagree.
        error_mask = torch.zeros(images.shape).float()
        error_mask[:, :, :, 0][pred != target] = 1
        error = images * 0.7 + error_mask.numpy() * 0.3

        panels = (images, label_color_map, pred_color_map, error, edge_rgb, area_color_map)
        img_concat = np.stack(panels, axis=1)  # B*6*H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*H*6*W*3
        img_concat = img_concat.reshape((-1, 6*W, 3))  # (B*H) * (6*W) * 3

        self.writer.add_image(f"{task}-label-prediction", img_concat, epoch + 1, dataformats='HWC')

    def save_tensorboard_layer(self, xs, epoch, task='validation'):
        """Log the first (up to) three channels of several feature maps as one grid.

        Args:
            xs: list of tensors sharing one shape with at least two leading
                dims (batch, channel) followed by H and W.
            epoch: zero-based epoch; the image is written at step ``epoch + 1``.
            task: tag prefix; the image tag is ``"{task}-layer visualize"``.

        Every (sample, layer, channel) map is min-max scaled to [0, 1] on its
        own.  Layers are laid out left to right; samples, channels and rows
        are stacked top to bottom.
        """
        stacked = torch.stack([t[:, :3] for t in xs], dim=1)
        # detach() so tensors that still carry gradients can be converted.
        x = stacked.detach().cpu().numpy()
        B, L, C, H, W = x.shape

        # Min-max normalisation.  The second step must divide by the range;
        # subtracting the max (as before) left every value <= 0.
        x = x - np.min(x, axis=(3, 4), keepdims=True)
        span = np.max(x, axis=(3, 4), keepdims=True)
        # Constant maps have zero range: divide by 1 so they become 0, not NaN.
        x = x / np.where(span > 0, span, np.ones_like(span))

        x = x.transpose(0, 2, 3, 1, 4).reshape(B*C*H, L*W, 1)
        self.writer.add_image(f"{task}-layer visualize", x, epoch + 1, dataformats='HWC')

    def save_tensorboard_onlyatt(self, att, epoch, task='validation'):
        """Log histograms and a tiled kernel image for the first attention slice.

        ``att`` is unpacked as B x mh x ks2 x H x W; only index 0 along dim 1
        is used.  Three items are written at step ``epoch + 1``: a histogram
        of the values, a histogram of the argmax over the ks2 axis (plus one),
        and an image in which each spatial position holds its ks x ks kernel
        (``ks = int(sqrt(ks2))``).
        """
        B, _, _, H, W = att.shape
        first_slice = att[:, :1].detach().cpu().numpy()
        strongest = np.argmax(first_slice, axis=2)

        ks = int(first_slice.shape[2]**0.5)
        # (B, ks, ks, H, W) -> (B, H, ks, W, ks) -> (B*H*ks, W*ks, 1)
        tiled = (
            first_slice.reshape(B, ks, ks, H, W)
            .transpose(0, 3, 1, 4, 2)
            .reshape(B*H*ks, W*ks, 1)
        )

        step = epoch + 1
        self.writer.add_histogram(f"{task}-attention-value", tiled, step)
        self.writer.add_histogram(f"{task}-attention-argmax", strongest + 1, step, bins='auto')
        self.writer.add_image(f"{task}-attention-kernel", tiled, step, dataformats='HWC')

    def save_tensorboard_attention(self, x, y, x_before, x_after, att, epoch, task='validation'):
        """Log feature maps before/after attention next to the image and label.

        Args:
            x: input batch, unpacked as B x C x H x W; channels 0..2 are shown.
            y: integer label map (B x H x W); ``y - 1`` indexes the palette.
            x_before: feature maps (B x K x H x W) shown in the top half.
            x_after: feature maps of the same shape shown in the bottom half.
            att: attention tensor forwarded to ``save_tensorboard_onlyatt``.
            epoch: zero-based epoch; the image is written at step ``epoch + 1``.
            task: tag prefix; the image tag is ``"{task}-attention"``.

        Layout: the first column is the image above its label; each following
        column is one feature channel, before above after.
        """

        def normalize(t):
            # Min-max scale every (sample, channel) map to [0, 1].
            t = t - t.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0]
            peak = t.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]
            # Constant maps have zero range: divide by 1 so they become 0, not NaN.
            return t / torch.where(peak > 0, peak, torch.ones_like(peak))

        color_map = np.array([
            [1, 39, 255],
            [51, 160, 44],
            [126, 237, 14],
            [255, 1, 1]
        ]) / 255.0

        _, _, _, W = x.shape

        # detach() throughout so tensors that still carry gradients can be
        # converted to NumPy, as the other image helpers already do.
        images = x[:, 0:3, :, :].detach().permute(0, 2, 3, 1).unsqueeze(dim=1).cpu().numpy()  # B*1*H*W*3
        label_idx = (y.detach().cpu() - 1).numpy()
        label_color_map = color_map[label_idx][:, np.newaxis]  # B*1*H*W*3
        img_label = np.concatenate((images, label_color_map), axis=2)  # B*1*2H*W*3

        x_before = normalize(x_before.detach()).unsqueeze(dim=-1)  # B*K*H*W*1
        x_after = normalize(x_after.detach()).unsqueeze(dim=-1)  # B*K*H*W*1
        x_before_after = torch.cat((x_before, x_after), dim=2).repeat(1, 1, 1, 1, 3).cpu().numpy()  # B*K*2H*W*3

        img_concat = np.concatenate((img_label, x_before_after), axis=1)  # B*(K+1)*2H*W*3
        img_concat = img_concat.transpose((0, 2, 1, 3, 4))  # B*2H*(K+1)*W*3
        # Take the column count from the data rather than assuming K == 4.
        n_columns = img_concat.shape[2]
        img_concat = img_concat.reshape((-1, n_columns*W, 3))  # (B*2H) * ((K+1)*W) * 3

        self.writer.add_image(f"{task}-attention", img_concat, epoch + 1, dataformats='HWC')

        # visualize attention histogram
        self.save_tensorboard_onlyatt(att, epoch=epoch, task=task)

    def save_tensorboard_cm(self, cm, epoch):
        """Row-normalise a confusion matrix and log it as a figure.

        Each row of ``cm`` is divided by its sum (plus 1e-6 to avoid a
        division by zero).  The result is plotted with ``self.label_names``
        as tick labels and written under ``validation-confusion-matrix`` at
        step ``epoch + 1``.
        """
        row_totals = np.sum(cm, axis=1, keepdims=True) + 1e-6
        normalised = cm / row_totals

        display = ConfusionMatrixDisplay(
            confusion_matrix=normalised,
            display_labels=self.label_names,
        )
        display.plot(cmap=plt.cm.Blues, values_format=".3g")

        figure = display.figure_
        figure.set_dpi(150)
        figure.set_size_inches(4, 4)
        self.writer.add_figure("validation-confusion-matrix", figure, epoch + 1)

    def save_model(self, epoch, name, model, optimizer,path = None):
        """Write a checkpoint with the model and optimizer state to disk.

        Args:
            epoch: zero-based epoch stored in the checkpoint; the default file
                name uses ``epoch + 1``.
            name: stored under the ``'name'`` key.
            model: object whose ``state_dict()`` is saved.
            optimizer: object whose ``state_dict()`` is saved.
            path: destination; defaults to
                ``args.output_path / model_state_dict_epoch_{epoch + 1}.tar``.

        If the instance has an ``anchors`` attribute it is stored as well.
        """
        # args.output_path is a Path object
        if path is None:
            default_name = Path(f"model_state_dict_epoch_{epoch + 1}.tar")
            path = self.args.output_path / default_name

        checkpoint = dict(
            epoch=epoch,
            name=name,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        )
        if hasattr(self, "anchors"):
            checkpoint['anchors'] = self.anchors

        torch.save(checkpoint, str(path))
        self.logger.info(f"Model checkpoint saved at {str(path)}")


# import argparse
# import os
# import random
# import shutil
# import time
# import warnings

# import torch
# import torch.nn as nn
# import torch.nn.parallel
# import torch.backends.cudnn as cudnn
# import torch.distributed as dist
# import torch.optim
# import torch.multiprocessing as mp
# import torch.utils.data
# import torch.utils.data.distributed
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
# import torchvision.models as models

# # from apex import amp
# # from apex.parallel import DistributedDataParallel
# # from apex.parallel import convert_syncbn_model

# cudnn.benchmark = True
# cudnn.deterministic = True
# random.seed(0)
# torch.manual_seed(0)

# parser = argparse.ArgumentParser()
# parser.add_argument('--local_rank', default=-1, type=int,
#                     help='node rank for distributed training')
# args = parser.parse_args()


# def reduce_mean(tensor, nprocs):
#     rt = tensor.clone()
#     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
#     rt /= nprocs
#     return rt

# class Dataset(torch.utils.data.Dataset):
#     def __init__(self,):
#         super().__init__()
#         self.X = torch.arange(100).reshape(50,2,1).float()
#         self.Y = torch.zeros(50,2)

#     def __getitem__(self,idx):
#         return self.X[idx],self.Y[idx]

#     def __len__(self,):
#         return len(self.X)


# dist.init_process_group(backend='nccl')
# torch.cuda.set_device(args.local_rank)

# train_dataset = Dataset()
# train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle = False)
# train_loader = torch.utils.data.DataLoader(train_dataset,
#                                            batch_size=2,
#                                            shuffle=(train_sampler is None),
#                                            num_workers=2,
#                                            pin_memory=True,
#                                            sampler=train_sampler)


# model = nn.Sequential(nn.Linear(1,50),nn.BatchNorm1d(2),nn.ReLU(),nn.Linear(50,2))
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

# criterion = nn.L1Loss().cuda()
# # model = DistributedDataParallel(model)
# model = torch.nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank])
# optimizer = torch.optim.SGD(model.parameters(), 1e-3)


# # model, optimizer = amp.initialize(model, optimizer,opt_level="O0")
# res = 0
# model.train()
# for epoch in range(10):
#     for batch_idx, (images, target) in enumerate(train_loader):
#         print(images)
#         images = images.cuda(non_blocking=True)
#         target = target.cuda(non_blocking=True)

#         output = model(images)
#         loss = criterion(output, target)
#         print(loss.item())
#         torch.distributed.barrier()

#         optimizer.zero_grad()
# #         with amp.scale_loss(loss, optimizer) as scaled_loss:
#         loss.backward()
#         optimizer.step()
#         res += 1
# print(res)
