import matplotlib.pyplot as plt
from numpy.core.numeric import cross
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.functional import align_tensors
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch.nn.functional as F
import torchvision.transforms.functional as TF


import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn

import torch
import random
import numpy as np
import gc
from pathlib import Path
import loss_functions
import models
from dataset import get_datasets
from dataset_hdf5 import get_hdf5_datasets
from dataset_city import get_city_datasets
from dataset_isaid import get_isaid_datasets
from dataset_cityscapes import Cityscapes
from utils import get_logger, save_epoch_summary, get_datetime_str_simplified, get_datetime_str,\
     MyTransform, Sync_DataParallel,BalancedDataParallel,execute_replication_callbacks
from utils_torch import PolynomialLRDecay,confusion_matrix_gpu,IOU_cal
from torchvision.transforms.transforms import CenterCrop
from models.sync_batchnorm.replicate import patch_replication_callback
import sys
class A:
    """Bare attribute container used as the run configuration (``args``).

    The fields below are plain class attributes. Instances read them unless
    they are overridden per instance.
    """

    # Run name; unset by default.
    name = None
    # Output location; unset by default.
    output_path = None
    # Process rank; Train treats rank 0 as the process allowed to log.
    local_rank = 0


args = A()



class Train:
    """Training driver: sets up rank-aware logging and seeds the RNGs.

    Args:
        args: Run configuration object. It must expose ``local_rank``, and
            ``nprocs`` is written onto it here. It is also forwarded to
            ``get_logger``; which other fields that reads is not visible
            here — TODO confirm.
    """

    def __init__(self, args):
        self.args = args
        # Number of visible CUDA devices, stored on args for later use
        # (presumably as the DDP world size — TODO confirm).
        self.args.nprocs = torch.cuda.device_count()
        self.logger = get_logger(f"DDP train LOGING: {self.args.local_rank}", args, save_dir='.')
        # NOTE(review): these numeric markers look like leftover debug output.
        self.logger_info('111111111111')

        self.label_names = ["Water", "Forest", "Field", "Others"]
        self._fix_random_seed()
        self.logger_info('222222222222')

    def _fix_random_seed(self, seed=0):
        """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

        Args:
            seed: Seed applied to every RNG. The default of 0 matches the
                value that was previously hard-coded.
        """
        random.seed(seed)
        np.random.seed(seed)
        # torch.manual_seed also seeds the RNGs of all CUDA devices.
        torch.manual_seed(seed)
        # benchmark=True lets cuDNN auto-tune and choose possibly
        # non-deterministic kernels, which would defeat deterministic=True.
        cudnn.benchmark = False
        cudnn.deterministic = True

    def _if(self, sign=True):
        """Gate side effects (logging, saving) to the main process.

        Only process rank 0 should trigger logging or other save operations.

        Args:
            sign: Extra condition that must also hold.

        Returns:
            ``sign`` itself if it is falsy; otherwise whether
            ``self.args.local_rank == 0``.
        """
        # NOTE(review): a local_rank of -1 (non-distributed launch) is not
        # treated as the main process here — confirm that is intended.
        return sign and self.args.local_rank == 0

    def logger_info(self, *messages):
        """Log each message at INFO level, but only on the main process."""
        if self._if():
            for message in messages:
                self.logger.info(message)
        

# Module-level instantiation. Importing this file constructs a Train, which
# creates a logger via get_logger and reseeds the RNGs.
# NOTE(review): consider guarding with `if __name__ == "__main__":`.
t = Train(args=args)