# Common
import os
import wandb
import numpy as np
import time
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import MinkowskiEngine as ME

# my module
import os.path as osp

from dataset.get_dataloader import get_TV_dl
from network.lr_adjust import adjust_learning_rate, adjust_learning_rate_D
from utils import common as com

from network.domain_mix import laserMix
from validate_train import validater

import datetime

# from utils.mean_std_proto import MultiFeaturePrototypeEMA

import math

# Integer tags for the two domains: source = 0, target = 1.
# NOTE(review): presumably the domain targets fed to the discriminator loss;
# they are not referenced in the visible part of this file — confirm usage.
source_label, target_label = 0, 1

def my_worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG with a seed offset by ``worker_id``.

    The new seed is the first word of the current global RNG state plus
    ``worker_id``, so callers that share the same starting state end up with
    different (but reproducible) NumPy streams.

    NOTE(review): the name suggests this is passed as ``worker_init_fn`` to a
    DataLoader; it is not referenced in the visible part of this file.

    Args:
        worker_id: Integer offset added to the base seed.
    """
    # Convert to a Python int first: the state word is a NumPy uint32 and the
    # sum can exceed 2**32 - 1, which np.random.seed rejects (or which wraps
    # with an overflow warning, depending on NumPy's promotion rules).
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2 ** 32))

class stage1_ours_Trainer:
    def __init__(self,
                 cfg,
                 net_G, ema_G, net_D,
                 G_optim, D_optim,
                 logger, tf_writer, device):
        """Store models/optimizers and derive all trainer settings from ``cfg``.

        Args:
            cfg: Experiment configuration, read through attribute access
                (``MODEL_D``, ``MODEL_G``, ``TRAIN``, ``MEAN_TEACHER``,
                ``HYPERPARAMETER``, ``TGT_LOSS``).
            net_G: Network stepped by ``G_optim``.
            ema_G: Teacher network; initialised from ``net_G`` through
                ``create_ema_model`` when ``cfg.MEAN_TEACHER.use_mt`` is set.
            net_D: Discriminator network stepped by ``D_optim``.
            G_optim: Optimizer for ``net_G``.
            D_optim: Optimizer for ``net_D``.
            logger: Logger object, stored as-is.
            tf_writer: Summary writer object, stored as-is.
            device: Device used when allocating tensors in this class.
        """
        import copy  # local import: only needed for the discriminator snapshot

        self.start_iter = 0
        self.ml_info = {'bt_tgt_spIoU': 0}
        self.cfg = cfg
        self.logger = logger
        self.tf_writer = tf_writer
        self.device = device

        self.net_G = net_G
        self.ema_G = ema_G
        self.net_D = net_D
        # state_dict() hands back references to the live parameter tensors, so
        # the snapshot must be deep-copied; otherwise it follows every
        # optimizer step and reset_discriminator() restores nothing.
        self.net_D_original_state_dict = copy.deepcopy(net_D.state_dict())
        self.reset_discriminator_iter = getattr(
            self.cfg.MODEL_D, 'RESET_DISCRIMINATOR_ITER', 2000)

        self.G_optim = G_optim
        self.D_optim = D_optim

        # ETA calculation: sliding window of recent iteration durations.
        self.iter_times = []
        self.max_time_buffer = 50

        # Loss functions.
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # seg loss, label 0 ignored
        self.intensity_criterion = nn.MSELoss()  # intensity loss
        # GAN loss; NOTE: criterionGAN stays undefined for any other GAN_MODE.
        if self.cfg.MODEL_D.GAN_MODE == 'ls_gan':
            self.criterionGAN = nn.MSELoss(reduction='none')
        elif self.cfg.MODEL_D.GAN_MODE == 'vanilla_gan':
            self.criterionGAN = nn.BCEWithLogitsLoss(reduction='none')

        # Datasets / dataloaders and validation periods.
        self.init_dataloader()
        self.t_val_iter = self.cfg.TRAIN.T_VAL_ITER
        self.s_val_iter = self.cfg.TRAIN.S_VAL_ITER

        # Other training bookkeeping.
        self.c_iter = 0  # current iteration
        self.round = 0  # current round
        self.best_IoU_iter = 0
        self.best_IoU_after_saveIter = 0

        if self.cfg.MEAN_TEACHER.use_mt:
            self.create_ema_model(self.ema_G, self.net_G)

        model_g_cfg = self.cfg.MODEL_G
        self.use_entropy_reduction = getattr(model_g_cfg, 'use_entropy_reduction', False)
        self.use_inter_proto = getattr(model_g_cfg, 'use_inter_proto', False)
        self.use_intra_proto = getattr(model_g_cfg, 'use_intra_proto', False)

        # Layer names may be configured as a comma-separated string or as a
        # list/tuple; both are normalised to a list of strings.
        raw_layers = getattr(model_g_cfg, 'mean_std_const_layers', [])
        if isinstance(raw_layers, str):
            raw_layers = raw_layers.split(",")
        self.mean_std_const_layers = [str(layer) for layer in raw_layers]

        self.voxel_size = self.cfg.HYPERPARAMETER.VOXEL_SIZE
        self.max_range = self.cfg.HYPERPARAMETER.MAX_RANGE
        self.bin_size = self.cfg.HYPERPARAMETER.BIN_SIZE
        self.num_bins = int(math.ceil(self.max_range / self.bin_size))

        self.entropy_regularization = getattr(model_g_cfg, 'entropy_regularization', 0)

        # Intensity learning parameters.
        self.use_intensity_learning = getattr(model_g_cfg, 'use_intensity_learning', False)
        self.intensity_lambda = getattr(model_g_cfg, 'intensity_lambda', 1.0)
        self.intensity_pseudo_lambda = getattr(model_g_cfg, 'intensity_pseudo_lambda', 0.5)

        # ==============================================================
        #  Adversarial class-wise statistics (no memory bank)
        # ==============================================================
        self.num_classes = self.cfg.MODEL_G.NUM_CLASSES
        self.last_layer_key = 'out_b4p16'  # final feature map name
        self.feat_dim = 256                # dim of this layer
        # Falls back to the generic adversarial weight when no class-wise
        # weight is configured (LAMBDA_ADV must therefore always exist).
        self.lambda_adv_cls = getattr(self.cfg.TGT_LOSS, 'LAMBDA_ADV_CLS',
                                      self.cfg.TGT_LOSS.LAMBDA_ADV)

        # ==============================================================
        #  Hierarchical class definitions for hierarchical adversarial loss
        # ==============================================================
        self.use_hierarchical_adv = getattr(self.cfg.TGT_LOSS, 'USE_HIERARCHICAL_ADV', False)
        if self.use_hierarchical_adv:
            self.init_hierarchical_classes()
            # Hierarchical statistics slots, filled in later.
            self.src_hier_mean_cat = None
            self.src_hier_std_cat = None
            self.t2s_hier_mean_cat = None
            self.t2s_hier_std_cat = None
            self.tgt_hier_mean_cat = None
            self.tgt_hier_std_cat = None

        self.label_interpolator = ME.MinkowskiInterpolation()

    def init_hierarchical_classes(self):
        # ----------------------------------------------------------
        # 1) hierarchy setting
        # ----------------------------------------------------------
        HIERARCHY_CONFIG = {
            "SemanticPOSS": {
                "Qwen3-235B-A22B-2507": dict( # best value
                    things_vehicle           = [],
                    things_two_wheelers      = [],
                    things_person            = [1, 13],          # rider, person
                    things_traffic_elements  = [5, 6],           # traffic-sign, pole
                    things_objects           = [],               # trashcan, cone_stone 등
                    stuff_pavement           = [],               # ground
                    stuff_natural            = [3, 4],           # trunk, plant
                    stuff_structures         = [8, 10],          # building, fence
                ),
                "DeepSeekR1": dict( # best value
                    things_vehicle           = [],
                    things_two_wheelers      = [],
                    things_person            = [1, 2, 11, 13],   # rider, car, bike, person # dynamic obejcts
                    things_traffic_elements  = [5, 6, 7, 9],     # traffic-sign, pole
                    things_objects           = [],               # trashcan, cone_stone 등
                    stuff_pavement           = [],               # ground
                    stuff_natural            = [3, 4],           # trunk, plant
                    stuff_structures         = [8, 10],          # building, fence
                ),
                4: dict( # best value
                    things_vehicle           = [],
                    things_two_wheelers      = [],
                    things_person            = [1, 13],          # rider, person
                    things_traffic_elements  = [5, 6],           # traffic-sign, pole
                    things_objects           = [],               # trashcan, cone_stone 등
                    stuff_pavement           = [],               # ground
                    stuff_natural            = [3, 4],           # trunk, plant
                    stuff_structures         = [8, 10],          # building, fence
                ),
                3: dict(
                    things_vehicle           = [],
                    things_person            = [1, 13],
                    things_traffic_elements  = [5, 6],
                    things_objects           = [7, 9],           # trashcan, cone_stone
                    stuff_pavement           = [],
                    stuff_natural            = [3, 4],
                    stuff_structures         = [8, 10],
                ),
                2: dict(
                    things_vehicle           = [2],              # car
                    things_person            = [1, 11, 13],      # rider, bike, person
                    things_traffic_elements  = [5, 6],           # traffic-sign, pole
                    stuff_pavement           = [12],             # ground
                    stuff_natural            = [3, 4],
                    stuff_structures         = [8, 10],
                ),
                1: dict(  # 기본값
                    things_vehicle           = [2, 11],          # car, bike
                    things_person            = [1, 13],          # rider, person
                    things_traffic_elements  = [5, 6, 7, 9],     # traffic-sign, pole, # trashcan, cone_stone
                    stuff_pavement           = [12],             # ground
                    stuff_natural            = [3, 4],           # trunk, plant
                    stuff_structures         = [8, 10],          # building, fence
                ),
            },

            "SemanticKITTI": {
                "things_stuff_only": dict(
                    things_vehicle           = [1,2,3,4,5,6,7,8,18,19],
                    things_person            = [],
                    things_traffic_elements  = [],
                    stuff_pavement           = [9,10,11,12,13,14,15,16,17],
                    stuff_natural            = [],
                    stuff_structures         = [],
                ),
                "Qwen3-235B-A22B-2507": dict(
                    things_vehicle           = [1, 2, 3, 4, 5],
                    things_person            = [6, 7, 8],
                    things_traffic_elements  = [],
                    stuff_pavement           = [9,10,11,12],
                    stuff_natural            = [15, 16, 17],
                    stuff_structures         = [13, 14, 18, 19],
                ),
                "DeepSeekR1": dict(
                    things_vehicle           = [1, 2, 3, 4, 5, 6, 7, 8], ## dynamic objects 
                    things_person            = [],
                    things_traffic_elements  = [],
                    stuff_pavement           = [9,10,11,12],
                    stuff_natural            = [15, 16, 17],
                    stuff_structures         = [13, 14, 18, 19],
                ),
                3: dict(
                    things_vehicle           = [1, 4, 5],        # car, truck, other-vehicle
                    things_two_wheelers      = [2, 3],           # bicycle, motorcycle
                    things_person            = [6, 7, 8],        # person, bicyclist, motorcyclist
                    things_traffic_elements  = [18, 19],         # pole, traffic-sign
                    stuff_pavement           = [9,10,11,12],  # road, parking, sidewalk..., terrain
                    stuff_natural            = [15,16,17],
                    stuff_structures         = [13, 14],         # building, fence
                ),
                2: dict(
                    things_vehicle           = [1, 4, 5],        # car, truck, other-vehicle
                    things_two_wheelers      = [2, 3],           # bicycle, motorcycle
                    things_person            = [6, 7, 8],        # person, bicyclist, motorcyclist
                    things_traffic_elements  = [18, 19],         # pole, traffic-sign
                    stuff_pavement           = [9,10,11,12,17],  # road, parking, sidewalk..., terrain
                    stuff_natural            = [15, 16],         # vegetation, trunk
                    stuff_structures         = [13, 14],         # building, fence
                ),
                1: dict(  # 기본값, best value
                    things_vehicle           = [1,2,3,4,5],
                    things_person            = [6,7,8],
                    things_traffic_elements  = [18,19],
                    stuff_pavement           = [9,10,11,12],
                    stuff_natural            = [15,16,17],
                    stuff_structures         = [13,14],
                ),
            },
        }
        dataset_type  = getattr(self.cfg.DATASET_TARGET, 'TYPE', 'SemanticKITTI')
        version       = getattr(self.cfg.TGT_LOSS, 'HIERARCHY_VERSION', 1)

        try:
            hier = HIERARCHY_CONFIG[dataset_type][version]
        except KeyError:
            raise ValueError(f"[Hierarchy] 정의되지 않은 조합: "
                            f"{dataset_type} (version={version})")

        for name, cls_list in hier.items():
            setattr(self, f"{name}_classes", cls_list)

        # things / stuff grouping
        self.things_classes = (
            hier.get("things_vehicle", []) +
            hier.get("things_two_wheelers", []) +
            hier.get("things_person", []) +
            hier.get("things_traffic_elements", []) +
            hier.get("things_objects", [])
        )
        self.stuff_classes = (
            hier.get("stuff_pavement", []) +
            hier.get("stuff_natural", []) +
            hier.get("stuff_structures", [])
        )

        print(f"[Hierarchy] {dataset_type}-v{version} loaded:")
        print("  things:", self.things_classes)
        print("  stuff :", self.stuff_classes)

    def reset_discriminator(self):
        self.net_D.load_state_dict(self.net_D_original_state_dict)
        # self.net_D.train()

    def train(self):
        ## for validation before training
        _ = self.src_valer.rolling_predict(self.net_G, self.ema_G, self.c_iter, domain='src', sanity_check=True)
        _ = self.tgt_valer.rolling_predict(self.net_G, self.ema_G, self.c_iter, domain='tgt', sanity_check=True)

        for epoch in range(self.cfg.TRAIN.MAX_EPOCHS):
            self.train_one_epoch()
    
    def train_one_epoch(self):
        """Run one pass over the target training loader.

        For every target batch a source batch is drawn from ``self.src_TraDL``
        (the source iterator is restarted when exhausted). Each iteration then:

        1. runs the teacher ``ema_G`` on the target batch without gradients
           and, when intensity learning is enabled, on the source batch too;
        2. turns the teacher logits into pseudo labels, setting points whose
           max softmax confidence is not above ``cfg.PSEUDO_LABEL.threshold``
           to label 0;
        3. builds ``self.masked_batch`` with ``laserMix``;
        4. sums the generator losses (source, target, optional entropy
           reduction, adversarial), backpropagates and steps ``G_optim``;
        5. computes the discriminator loss, backpropagates and steps
           ``D_optim``;
        6. optionally updates the EMA teacher, then handles ETA bookkeeping,
           logging, validation, discriminator reset and CUDA cache clearing.

        When ``self.c_iter`` reaches ``cfg.TRAIN.MAX_ITERS`` a final
        validation is run (if not already done this iteration) and the
        process exits through ``quit()``.
        """
        src_iter = iter(self.src_TraDL)
        for tgt_BData in self.tgt_train_loader:
            # Per-iteration log dictionary; filled by the loss routines.
            self.wb_dict = {}
            self.c_iter += 1
            start_t = time.time()
            self.set_lr()
            self.set_zero_grad()

            # send data to GPU
            # The source loader is cycled independently of the target loader.
            try:
                src_BData = next(src_iter)
            except StopIteration:
                src_iter = iter(self.src_TraDL)
                src_BData = next(src_iter)
            self.src_BData = self.send_data2GPU(src_BData)
            self.tgt_BData = self.send_data2GPU(tgt_BData)

            # 1. use teacher model to generate pseudo label
            with torch.no_grad():  # old-model generate pseudo-label
                tgt_G_in = ME.SparseTensor(self.tgt_BData['feats_mink'], self.tgt_BData['coords_mink'])
                self.tgt_o_logits = self.ema_G(tgt_G_in, is_train=False)
                if self.use_intensity_learning:
                    # Teacher also predicts intensity on the (optionally augmented) source batch.
                    if self.cfg.DATASET_SOURCE.use_aug_for_laserMix:
                        # src_labels = self.src_BData["aug_labels_mink"].cuda()
                        src_G_in = ME.SparseTensor(coordinates=self.src_BData["aug_coords_mink"].int(),
                                                    features=self.src_BData["aug_feats_mink"])
                    else:
                        # src_labels = self.src_BData["labels_mink"].cuda()
                        src_G_in = ME.SparseTensor(coordinates=self.src_BData["coords_mink"].int(),
                                                    features=self.src_BData["feats_mink"])
                        # src_G_in = ME.SparseTensor(self.src_BData['feats_mink'], self.src_BData['coords_mink'])
                    ema_output = self.ema_G(src_G_in, is_train=False, intensity_out=True)

                    # Handle different return types from ema_G
                    if isinstance(ema_output, tuple) and len(ema_output) >= 2:
                        # If ema_G returns a tuple with intensity prediction
                        # NOTE(review): this unpacking requires exactly two elements even
                        # though the guard accepts longer tuples — confirm ema_G's return arity.
                        _, self.src_intensity_pseudo = ema_output
                    else:
                        # If ema_G only returns segmentation output (no intensity head)
                        print("Warning: ema_G does not support intensity output, skipping intensity pseudo labeling for source")
                        self.src_intensity_pseudo = None
            # 2. filter pseudo label with confidence
            target_confidence_th = self.cfg.PSEUDO_LABEL.threshold
            target_pseudo = self.tgt_o_logits.F
            target_pseudo = F.softmax(target_pseudo, dim=-1)
            target_conf, target_pseudo = target_pseudo.max(dim=-1)
            # Low-confidence points keep the zero-initialised label 0.
            filtered_target_pseudo = torch.zeros_like(target_pseudo)
            valid_idx = target_conf > target_confidence_th
            filtered_target_pseudo[valid_idx] = target_pseudo[valid_idx]
            target_pseudo = filtered_target_pseudo.long()
            self.tgt_BData['pseudo_label'] = target_pseudo
            if self.use_intensity_learning and self.src_intensity_pseudo is not None:
                # Use intensity pseudo label for source data
                # This overwrites the intensity entry of the source batch in place.
                src_intensity_pseudo_labels = self.src_intensity_pseudo.F.squeeze().detach()
                if self.cfg.DATASET_SOURCE.use_aug_for_laserMix:
                    self.src_BData["aug_sp_remis"] = src_intensity_pseudo_labels
                else:
                    self.src_BData["sp_remis"] = src_intensity_pseudo_labels
            # mask data
            # NOTE(review): laserMix presumably mixes source and target scans — confirm in network.domain_mix.
            self.masked_batch = laserMix(self.cfg, self.src_BData, self.tgt_BData)

            # update G
            src_loss = self.train_source()
            tgt_loss = self.train_target()
            all_loss = src_loss + tgt_loss
            if self.use_entropy_reduction:
                entropy_loss = self.entropy_reduction()
                all_loss = all_loss + entropy_loss

            # Add all adversarial losses
            adv_loss = self.train_t2s_adv()
            all_loss = all_loss + adv_loss
            all_loss.backward()
            self.G_optim.step()

            # update D
            d_loss = self.train_net_D()
            d_loss.backward()
            self.D_optim.step()

            # EMA teacher update, only every `update_every` iterations.
            if self.cfg.MEAN_TEACHER.use_mt and \
                self.cfg.MEAN_TEACHER.alpha_ema > 0 and \
                    self.c_iter % self.cfg.MEAN_TEACHER.update_every == 0:
                self.update_ema_variables(self.ema_G, self.net_G)

            ############################################################################
            # ETA estimate from the mean of the last `max_time_buffer` iteration times.
            iteration_time = time.time() - start_t

            self.iter_times.append(iteration_time)
            if len(self.iter_times) > self.max_time_buffer:
                self.iter_times.pop(0)

            avg_iter_time = sum(self.iter_times) / len(self.iter_times)

            total_iter = self.cfg.TRAIN.MAX_ITERS
            iters_left = total_iter - self.c_iter

            eta_seconds = int(avg_iter_time * iters_left)

            eta_str = str(datetime.timedelta(seconds=eta_seconds))
            ############################################################################

            if self.c_iter % self.cfg.TRAIN.LOG_PERIOD == 0:
                # 'netG/seg_Loss' is written by train_source; 'netG/pse_seg_loss' is
                # expected from a routine outside this view (presumably train_target).
                print_str = ('iter:{0:6d}, '
                            'seg_Ls:{1:.4f}, '
                            'pse_seg_loss:{2:.4f}, '
                            'itr:{3:.3f}, '
                            'ETA:{4}, '
                            'Exp:{5}').format(self.c_iter,
                                             self.wb_dict['netG/seg_Loss'],
                                             self.wb_dict['netG/pse_seg_loss'],
                                             iteration_time,
                                             eta_str,
                                             self.cfg.TRAIN.EXP_NAME)

                if self.use_intensity_learning:
                    print_str += ', intensity_loss:{:.4f}'.format(self.wb_dict['netG/intensity_loss'])
                    if 'netG/t2s_intensity_loss' in self.wb_dict:
                        print_str += ', t2s_intensity_loss:{:.4f}'.format(self.wb_dict['netG/t2s_intensity_loss'])
                    if 'netG/tgt_intensity_loss' in self.wb_dict:
                        print_str += ', tgt_intensity_loss:{:.4f}'.format(self.wb_dict['netG/tgt_intensity_loss'])

                print(print_str)
                self.save_log()  # save logs

            if self.c_iter % self.t_val_iter == 0: # Target domain val.
                self.valid_and_save()

            if self.c_iter % self.s_val_iter == 0: # Source domain val.
                _  = self.src_valer.rolling_predict(self.net_G, self.ema_G, self.c_iter, domain='src')

            # Periodically reload the discriminator's stored initial state.
            if self.c_iter % self.reset_discriminator_iter == 0:
                self.reset_discriminator()

            if self.c_iter % 100 == 0:
                torch.cuda.empty_cache()

            if self.c_iter == self.cfg.TRAIN.MAX_ITERS:
                # Avoid validating twice when the last iteration is also a validation iteration.
                if self.c_iter % self.t_val_iter != 0:
                    self.valid_and_save()
                print("Finish training, this is max iter: {}".format(self.c_iter))
                # NOTE(review): quit() terminates the whole process from inside the loop;
                # sys.exit() is the conventional choice outside interactive sessions.
                quit()

        torch.cuda.empty_cache()

    def train_source(self):
        """Forward the source batch through ``net_G`` and build the source loss.

        Uses the augmented tensors of ``self.src_BData`` when
        ``cfg.DATASET_SOURCE.use_aug_for_laserMix`` is set, the plain ones
        otherwise. The loss is the cross-entropy segmentation loss, plus
        ``intensity_lambda`` times the intensity MSE when intensity learning
        is enabled, minus the scaled mean prediction entropy when
        ``entropy_regularization > 0``.

        Side effects: sets ``self.src_logits``, ``self.src_encodedFt_dict``
        (restricted to ``self.mean_std_const_layers``) and, with intensity
        learning, ``self.src_intensity_pred``; writes ``netG/*`` entries into
        ``self.wb_dict``.

        Returns:
            The total source loss.
        """
        # ----- select inputs -------------------------------------------------
        if self.cfg.DATASET_SOURCE.use_aug_for_laserMix:
            label_key, coord_key, feat_key = "aug_labels_mink", "aug_coords_mink", "aug_feats_mink"
        else:
            label_key, coord_key, feat_key = "labels_mink", "coords_mink", "feats_mink"

        src_labels = self.src_BData[label_key].cuda()
        src_G_in = ME.SparseTensor(coordinates=self.src_BData[coord_key].int(),
                                   features=self.src_BData[feat_key])

        # ----- forward pass (+ optional intensity loss) ----------------------
        if not self.use_intensity_learning:
            self.src_logits, _, self.src_encodedFt_dict = self.net_G(src_G_in, feature_out=True)
        else:
            (self.src_logits, _, self.src_encodedFt_dict,
             self.src_intensity_pred) = self.net_G(src_G_in, feature_out=True, intensity_out=True)

            teacher_intensity = getattr(self, 'src_intensity_pseudo', None)
            if teacher_intensity is not None:
                # Regress towards the teacher's intensity prediction.
                intensity_target = teacher_intensity.F.squeeze().detach()
            else:
                print("No intensity pseudo label for source data")
                # Fall back to the intensity stored in the source batch.
                if self.cfg.DATASET_SOURCE.use_aug_for_laserMix:
                    intensity_target = self.src_BData["aug_sp_remis"]
                else:
                    intensity_target = self.src_BData["sp_remis"]

            src_intensity_loss = self.intensity_criterion(
                self.src_intensity_pred.F.squeeze(), intensity_target.float())
            self.wb_dict['netG/intensity_loss'] = src_intensity_loss.item()

        # Keep only the encoder features that the statistics code consumes.
        self.src_encodedFt_dict = {
            name: feat
            for name, feat in self.src_encodedFt_dict.items()
            if name in self.mean_std_const_layers
        }

        # ----- loss assembly -------------------------------------------------
        # loss 1. main classifier CE loss
        src_seg_loss = self.criterion(self.src_logits.F, src_labels)
        all_src_loss = 0. + src_seg_loss

        if self.use_intensity_learning:
            all_src_loss = all_src_loss + self.intensity_lambda * src_intensity_loss

        # Logged before the entropy term is subtracted.
        self.wb_dict['netG/seg_Loss'] = src_seg_loss.mean()
        self.wb_dict['netG/all_src_loss'] = all_src_loss.mean()

        # entropy regularization: the scaled mean entropy is subtracted from the loss
        if self.entropy_regularization > 0:
            logits = self.src_logits.F
            point_entropy = (-1.0 * F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1)).sum(dim=1)
            src_out_entropy = point_entropy.mean() * self.entropy_regularization
            all_src_loss = all_src_loss - src_out_entropy
            self.wb_dict['netG/src_entropy'] = src_out_entropy

        return all_src_loss

    def compute_classwise_matrix(self, encoded_tensor, coords, labels):
        """Compute per-class feature mean/std over all layers in ``encoded_tensor``.

        Labels are interpolated onto the coordinates of the ``'out_p1'``
        layer; every layer's features are interpolated onto the same
        coordinates and reduced per class (label 0 is excluded) through
        ``compute_classwise_stats_interpolated``.

        Args:
            encoded_tensor: Mapping of layer name to sparse feature tensor;
                must contain ``'out_p1'``.
            coords: Coordinates of the labelled points.
            labels: One label per row of ``coords``.

        Returns:
            ``(mean_cat, std_cat)``: the per-layer statistics concatenated
            along dim 1 in ``self.mean_std_const_layers`` order, or
            ``(None, None)`` when the number of layers that produced
            statistics differs from ``len(self.mean_std_const_layers)``.

        Raises:
            KeyError: if a configured layer has no statistics when
                concatenating.
        """
        # Reference layer whose coordinates everything is interpolated onto.
        ref_feat = encoded_tensor['out_p1']
        ref_coords = ref_feat.C.to(dtype=torch.float32)

        # Buffer for the labels at the reference coordinates.
        ref_labels = torch.zeros(
            (ref_feat.F.shape[0], 1), dtype=torch.float32, device=self.device)

        # Original labels as a float sparse tensor on its own coordinate manager.
        label_field = ME.SparseTensor(
            features=labels.unsqueeze(1).to(dtype=torch.float32),
            coordinates=coords.to(dtype=torch.float32),
            coordinate_manager=ME.CoordinateManager(D=ref_feat.C.shape[1] - 1),
            device=self.device,
        )

        # Interpolate labels to the reference coordinates.
        labels_at_ref = self.label_interpolator(label_field, ref_coords)
        ref_labels[:, 0] = labels_at_ref[:, 0].to(dtype=torch.float32)

        # Labels are read back from a sparse tensor sharing the reference
        # layer's coordinate manager.
        ref_labels_st = ME.SparseTensor(
            features=ref_labels,
            coordinates=ref_coords,
            coordinate_manager=ref_feat.coordinate_manager,
            device=self.device,
        )

        # Classes present at the reference resolution, without unlabeled (0).
        present = torch.unique(ref_labels_st.F[:, 0].long())
        present = present[present != 0]

        # Per-layer statistics live on the instance while this method runs.
        self.layer_mean_dict = {}
        self.layer_std_dict = {}

        for name, layer_feat in encoded_tensor.items():
            layer_st = ME.SparseTensor(
                features=layer_feat.F.to(dtype=torch.float32),
                coordinates=layer_feat.C.to(dtype=torch.float32),
                coordinate_manager=layer_feat.coordinate_manager,
                device=self.device,
            )

            # Bring this layer's features to the reference coordinates.
            feats_at_ref = self.label_interpolator(layer_st, ref_coords)

            mean, std, counts = self.compute_classwise_stats_interpolated(
                feats_at_ref,
                ref_labels_st.F[:, 0].long(),
                present,
            )

            populated = counts > 0
            if populated.any():
                # Keep only rows of classes that actually have points: [K, D].
                self.layer_mean_dict[name] = mean[populated]
                self.layer_std_dict[name] = std[populated]

        # Statistics are only returned when every configured layer contributed.
        if len(self.layer_mean_dict) != len(self.mean_std_const_layers):
            return None, None

        try:
            mean_cat = torch.cat(
                [self.layer_mean_dict[key] for key in self.mean_std_const_layers], dim=1)
            std_cat = torch.cat(
                [self.layer_std_dict[key] for key in self.mean_std_const_layers], dim=1)
        except KeyError as e:
            raise KeyError(f"Error concatenating stats: {e}, Available layers: {list(self.layer_mean_dict.keys())}")

        # Clear the dictionaries for the next call.
        self.layer_mean_dict.clear()
        self.layer_std_dict.clear()
        return mean_cat, std_cat

    def compute_classwise_stats_interpolated(self, interpolated_feats, labels, unique_classes):
        """Return per-class mean, std, count for interpolated features, excluding unlabeled class."""
        feats = interpolated_feats  # [N, C]
        feat_dim = feats.shape[1]
        num_classes = len(unique_classes)
        
        means = torch.zeros(num_classes, feat_dim, device=feats.device)
        stds = torch.zeros_like(means)
        counts = torch.zeros(num_classes, device=feats.device, dtype=torch.long)
        
        # Create a mapping from original class indices to new indices (excluding unlabeled)
        class_mapping = {c.item(): i for i, c in enumerate(unique_classes)}
        
        for c in unique_classes:
            mask = labels == c
            if mask.any():
                sel = feats[mask]
                idx = class_mapping[c.item()]
                counts[idx] = sel.size(0)
                means[idx] = sel.mean(0)
                stds[idx] = sel.std(0, unbiased=False)
        
        return means, stds, counts

    def compute_hierarchical_matrix(self, encoded_tensor, coords, labels):
        """Compute hierarchical-group feature mean/std over all layers.

        Labels are interpolated onto the coordinates of the ``'out_p1'``
        layer; every layer's features are interpolated onto the same
        coordinates and reduced through
        ``compute_hierarchical_stats_interpolated``.

        Args:
            encoded_tensor: Mapping of layer name to sparse feature tensor;
                must contain ``'out_p1'``.
            coords: Coordinates of the labelled points.
            labels: One label per row of ``coords``.

        Returns:
            ``(mean_cat, std_cat)``: the per-layer group statistics
            concatenated along dim 1 in ``self.mean_std_const_layers`` order,
            or ``(None, None)`` when the number of layers that produced
            statistics differs from ``len(self.mean_std_const_layers)``.

        Raises:
            KeyError: if a configured layer has no statistics when
                concatenating.
        """
        # Reference layer whose coordinates everything is interpolated onto.
        ref_feat = encoded_tensor['out_p1']
        ref_coords = ref_feat.C.to(dtype=torch.float32)

        # Buffer for the labels at the reference coordinates.
        ref_labels = torch.zeros(
            (ref_feat.F.shape[0], 1), dtype=torch.float32, device=self.device)

        # Original labels as a float sparse tensor on its own coordinate manager.
        label_field = ME.SparseTensor(
            features=labels.unsqueeze(1).to(dtype=torch.float32),
            coordinates=coords.to(dtype=torch.float32),
            coordinate_manager=ME.CoordinateManager(D=ref_feat.C.shape[1] - 1),
            device=self.device,
        )

        # Interpolate labels to the reference coordinates.
        labels_at_ref = self.label_interpolator(label_field, ref_coords)
        ref_labels[:, 0] = labels_at_ref[:, 0].to(dtype=torch.float32)

        # Labels are read back from a sparse tensor sharing the reference
        # layer's coordinate manager.
        ref_labels_st = ME.SparseTensor(
            features=ref_labels,
            coordinates=ref_coords,
            coordinate_manager=ref_feat.coordinate_manager,
            device=self.device,
        )

        # Per-layer statistics live on the instance while this method runs.
        self.hier_layer_mean_dict = {}
        self.hier_layer_std_dict = {}

        for name, layer_feat in encoded_tensor.items():
            layer_st = ME.SparseTensor(
                features=layer_feat.F.to(dtype=torch.float32),
                coordinates=layer_feat.C.to(dtype=torch.float32),
                coordinate_manager=layer_feat.coordinate_manager,
                device=self.device,
            )

            # Bring this layer's features to the reference coordinates.
            feats_at_ref = self.label_interpolator(layer_st, ref_coords)

            mean, std, counts = self.compute_hierarchical_stats_interpolated(
                feats_at_ref,
                ref_labels_st.F[:, 0].long(),
            )

            populated = counts > 0
            if populated.any():
                # Keep only rows of groups that actually have points: [K, D].
                self.hier_layer_mean_dict[name] = mean[populated]
                self.hier_layer_std_dict[name] = std[populated]

        # Statistics are only returned when every configured layer contributed.
        if len(self.hier_layer_mean_dict) != len(self.mean_std_const_layers):
            return None, None

        try:
            mean_cat = torch.cat(
                [self.hier_layer_mean_dict[key] for key in self.mean_std_const_layers], dim=1)
            std_cat = torch.cat(
                [self.hier_layer_std_dict[key] for key in self.mean_std_const_layers], dim=1)
        except KeyError as e:
            raise KeyError(f"Error concatenating hierarchical stats: {e}, Available layers: {list(self.hier_layer_mean_dict.keys())}")

        # Clear the dictionaries for the next call.
        self.hier_layer_mean_dict.clear()
        self.hier_layer_std_dict.clear()
        return mean_cat, std_cat

    def compute_hierarchical_stats_interpolated(self, interpolated_feats, labels):
        """Compute per-group feature mean, std and point count.

        Args:
            interpolated_feats: 2-D feature tensor; rows are points, columns
                are feature channels.
            labels: 1-D tensor of class ids, one per row of
                ``interpolated_feats``.

        Returns:
            ``(means, stds, counts)`` with one row per hierarchical group, in
            the order: vehicle, person, traffic_elements, pavement, natural,
            structures and - unless ``cfg.TGT_LOSS.DROP_ROOT_GROUPS`` is
            truthy - things, stuff. Groups without any point keep an all-zero
            mean/std row and a count of 0.
        """
        channel_count = interpolated_feats.shape[1]

        # Leaf-level groups are always evaluated.
        group_to_classes = {
            'vehicle': self.things_vehicle_classes,
            'person': self.things_person_classes,
            'traffic_elements': self.things_traffic_elements_classes,
            'pavement': self.stuff_pavement_classes,
            'natural': self.stuff_natural_classes,
            'structures': self.stuff_structures_classes,
        }
        # The two root-level groups are appended only when the config does
        # not ask for them to be dropped (the flag defaults to False).
        drop_roots = getattr(self.cfg.TGT_LOSS, 'DROP_ROOT_GROUPS', False)
        if not drop_roots:
            group_to_classes['things'] = self.things_classes
            group_to_classes['stuff'] = self.stuff_classes

        group_total = len(group_to_classes)
        means = torch.zeros(group_total, channel_count, device=interpolated_feats.device)
        stds = torch.zeros_like(means)
        counts = torch.zeros(group_total, device=interpolated_feats.device, dtype=torch.long)

        for row, class_ids in enumerate(group_to_classes.values()):
            # A point belongs to the group if its label is any of the
            # group's class ids.
            member_mask = torch.zeros_like(labels, dtype=torch.bool)
            for class_id in class_ids:
                member_mask = member_mask | (labels == class_id)

            if not member_mask.any():
                continue

            members = interpolated_feats[member_mask]
            counts[row] = members.size(0)
            means[row] = members.mean(0)
            # Population (biased) standard deviation.
            stds[row] = members.std(0, unbiased=False)

        return means, stds, counts

    def train_target(self):
        """Assemble the generator-side loss for the mixed (t2s) and target batches.

        The returned loss is the sum of:

        * cross-entropy between the ``net_G`` output on the masked/mixed batch
          (``self.masked_batch``) and ``masked_source_labels``;
        * when ``self.use_intensity_learning`` is set, MSE intensity losses on
          the mixed batch and on the raw target batch, each scaled by
          ``self.intensity_lambda`` and only added when the corresponding
          ground-truth key is present in the batch;
        * when ``cfg.TGT_LOSS.lambda_sac > 0``, a KL term between the student
          (``net_G``) prediction on the augmented target scan and the teacher
          (``ema_G``) prediction on the raw target scan, scaled by
          ``lambda_sac``.

        Side effects: stores ``self.t2s_out`` and ``self.t2s_encodedFt_dict``
        (filtered to ``self.mean_std_const_layers``); with intensity learning
        also ``self.t2s_intensity_pred``, ``self.tgt_out``,
        ``self.tgt_encodedFt_dict`` and ``self.tgt_intensity_pred``. Scalar
        values are written into ``self.wb_dict`` for logging.

        Coordinate convention noted by the original author:
        ``original_coords = quantized_coords.float() * voxel_size``.

        Returns:
            The accumulated loss (no backward pass is run here).
        """
        all_tgt_loss = 0.
        t2s_stensor = ME.SparseTensor(coordinates=self.masked_batch["masked_source_pts"].int(),
                                      features=self.masked_batch["masked_source_features"])
        if self.use_intensity_learning:
            self.t2s_out, _, self.t2s_encodedFt_dict, self.t2s_intensity_pred = self.net_G(t2s_stensor, feature_out=True, intensity_out=True)

            # Use real intensity values for t2s data (no pseudo labeling)
            if "masked_source_intensities" in self.masked_batch:
                # Use real intensity values from the masked batch
                t2s_intensity_gt = self.masked_batch["masked_source_intensities"].float()
                t2s_intensity_loss = self.intensity_criterion(self.t2s_intensity_pred.F.squeeze(), t2s_intensity_gt)
                all_tgt_loss = all_tgt_loss + self.intensity_lambda * t2s_intensity_loss
                self.wb_dict['netG/t2s_intensity_loss'] = t2s_intensity_loss.item()
        else:
            self.t2s_out, _, self.t2s_encodedFt_dict = self.net_G(t2s_stensor, feature_out=True)

        # Keep only the layers that take part in the mean/std statistics.
        self.t2s_encodedFt_dict = {k: v for k, v in self.t2s_encodedFt_dict.items() if k in self.mean_std_const_layers}
        t2s_labels = self.masked_batch["masked_source_labels"].cuda()
        t2s_loss = self.criterion(self.t2s_out.F, t2s_labels.long())
        all_tgt_loss = all_tgt_loss + t2s_loss
        # Log a plain float (like the other wb_dict entries) so the log dict
        # does not keep a reference to the autograd graph.
        self.wb_dict['netG/pse_seg_loss'] = t2s_loss.item()

        # Train pure target domain with intensity learning
        if self.use_intensity_learning:
            tgt_G_in = ME.SparseTensor(self.tgt_BData['feats_mink'], self.tgt_BData['coords_mink'])
            self.tgt_out, _, self.tgt_encodedFt_dict, self.tgt_intensity_pred = self.net_G(tgt_G_in, feature_out=True, intensity_out=True)

            # Use real intensity values for target data
            if "sp_remis" in self.tgt_BData:
                tgt_intensity_gt = self.tgt_BData["sp_remis"].float()
                tgt_intensity_loss = self.intensity_criterion(self.tgt_intensity_pred.F.squeeze(), tgt_intensity_gt)
                all_tgt_loss = all_tgt_loss + self.intensity_lambda * tgt_intensity_loss
                self.wb_dict['netG/tgt_intensity_loss'] = tgt_intensity_loss.item()

        # kl loss
        if self.cfg.TGT_LOSS.lambda_sac > 0:
            with torch.no_grad():  # teacher (EMA) model generates the reference prediction
                tea_tgt_G_in = ME.SparseTensor(self.tgt_BData['feats_mink'], self.tgt_BData['coords_mink'])
                tea_raw_tgt_logit = self.ema_G(tea_tgt_G_in)
                del_tgt_out = tea_raw_tgt_logit.F[self.tgt_BData['aug_del_mask']]  # drop nothing for no DGT
                # Re-order the teacher rows to match the DGT-processed (augmented) branch.
                raw2aug_tgt_out = del_tgt_out[self.tgt_BData['aug_unique_map']]
            tgt_G_in = ME.SparseTensor(features=self.tgt_BData['aug_feats_mink'], coordinates=self.tgt_BData['aug_coords_mink'].int())
            stu_aug_tgt_logit = self.net_G(tgt_G_in)

            # NOTE(review): F.kl_div runs with its default reduction ('mean'),
            # which averages over every element; PyTorch documents 'batchmean'
            # as the reduction matching the mathematical KL definition. Left
            # unchanged because switching would rescale the loss relative to
            # lambda_sac.
            sac_loss = F.kl_div(F.log_softmax(stu_aug_tgt_logit.F, dim=1),
                                F.softmax(raw2aug_tgt_out.detach(), dim=1))
            sac_loss = sac_loss * self.cfg.TGT_LOSS.lambda_sac
            all_tgt_loss = all_tgt_loss + sac_loss
            self.wb_dict['netG/sac_loss'] = sac_loss.item()

        return all_tgt_loss

    def entropy_reduction(self):
        """Return the mean per-row Shannon entropy of ``self.t2s_out.F``.

        The rows of ``self.t2s_out.F`` are treated as logits: softmax over
        dim 1 gives the probabilities, the entropy ``-sum(p * log p)`` is
        taken per row and then averaged over all rows.
        """
        logits = self.t2s_out.F
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)
        per_row_entropy = (-1.0 * probs * log_probs).sum(dim=1)
        return per_row_entropy.mean()

    def compute_feature_stats_vectorized(self, encoded_ft_dict, voxel_size, max_range, bin_size, num_bins, device):
        """Compute range-binned feature mean/std for every layer and concatenate them.

        For each entry of ``encoded_ft_dict`` the points are binned by their
        Euclidean distance from the origin (``coords[:, 1:] * voxel_size``),
        keeping only points closer than ``max_range``. Per bin, the feature
        mean and population standard deviation (with ``1e-6`` added under the
        square root) are computed; bins without any point are set to zero.

        Args:
            encoded_ft_dict: mapping layer name -> object exposing ``.C``
                (coordinates, first column skipped) and ``.F`` (2-D features).
                Must contain at least one layer, otherwise ``torch.cat`` fails.
            voxel_size: factor converting quantized coordinates to distances.
            max_range: points at or beyond this distance are ignored.
            bin_size: width of one distance bin.
            num_bins: number of bins; farther bins are clamped into the last.
            device: device on which the accumulators are allocated.

        Returns:
            ``(mean_cat, std_cat)``, each of shape
            ``(num_bins, sum of per-layer feature dims)``, layers concatenated
            along dim 1 in dictionary order.
        """
        computed_mean_mat_dict = {}
        computed_std_mat_dict = {}

        for key, tensor in encoded_ft_dict.items():
            coords = tensor.C
            feats = tensor.F
            feature_dim = feats.shape[1]
            # Column 0 of the coordinates is skipped; the rest are spatial.
            spatial_coords = coords[:, 1:].float() * voxel_size
            distances = torch.norm(spatial_coords, dim=1)
            valid_mask = distances < max_range
            if not valid_mask.any():
                # No in-range point for this layer: every bin is empty. Emit
                # zero rows (the same value empty bins get below) instead of
                # dropping the layer, so the concatenated width stays fixed.
                computed_mean_mat_dict[key] = torch.zeros((num_bins, feature_dim), device=device)
                computed_std_mat_dict[key] = torch.zeros((num_bins, feature_dim), device=device)
                continue

            valid_distances = distances[valid_mask]
            valid_feats = feats[valid_mask]
            # Replace NaN with 0 and +/-inf with +/-1e6 so they cannot poison the sums.
            valid_feats = torch.nan_to_num(valid_feats, nan=0.0, posinf=1e6, neginf=-1e6)

            bin_indices = (valid_distances / bin_size).floor().long()
            bin_indices = torch.clamp(bin_indices, max=num_bins-1)

            # First pass: per-bin sums and counts -> means.
            sums = torch.zeros((num_bins, feature_dim), device=device)
            counts = torch.zeros((num_bins, 1), device=device)
            sums = sums.index_add(0, bin_indices, valid_feats)
            counts = counts.index_add(0, bin_indices, torch.ones(valid_feats.size(0), 1, device=device))
            means = sums / counts.clamp_min(1.0)

            # Second pass: per-bin squared deviations -> population variance.
            means_per_elem = means[bin_indices]  # [N_valid, feature_dim]
            diff = valid_feats - means_per_elem
            sq_diff = diff ** 2

            sq_diff_sum = torch.zeros((num_bins, feature_dim), device=device)
            sq_diff_sum = sq_diff_sum.index_add(0, bin_indices, sq_diff)
            variances = sq_diff_sum / counts.clamp_min(1.0)
            stds = torch.sqrt(torch.clamp(variances, min=0.0) + 1e-6)

            # Empty bins would otherwise report std = sqrt(1e-6); zero them.
            zero_mask = (counts.squeeze(1) == 0).unsqueeze(1)
            means = torch.where(zero_mask, torch.zeros_like(means), means)
            stds = torch.where(zero_mask, torch.zeros_like(stds), stds)

            computed_mean_mat_dict[key] = means
            computed_std_mat_dict[key] = stds

        mean_cat = torch.cat(list(computed_mean_mat_dict.values()), dim=1)
        std_cat = torch.cat(list(computed_std_mat_dict.values()), dim=1)
        return mean_cat, std_cat

        
    def train_t2s_adv(self):
        """Compute the generator-side adversarial loss against ``net_D``.

        Three families of statistics are pushed through the discriminator,
        each as a (mean matrix, std matrix) pair for the target batch and for
        the mixed (t2s) batch:

        1. distance-binned statistics (weight ``cfg.TGT_LOSS.LAMBDA_ADV``);
        2. class-wise statistics (weight ``self.lambda_adv_cls``);
        3. hierarchical-group statistics (weight ``self.lambda_adv_cls``),
           only when ``self.use_hierarchical_adv`` is set.

        Every term is scored against all-zero labels, i.e. the label that
        ``train_net_D`` assigns to source statistics.

        Side effects: the source/target/t2s statistics of every family are
        stored on ``self`` for the later discriminator update, and one scalar
        per term is written into ``self.wb_dict``.

        Returns:
            A 1-element tensor holding the summed adversarial loss.
        """
        device = torch.device('cuda')
        total_loss = torch.tensor([0.], device=device)

        def fool_loss(mean_mat, std_mat):
            # Discriminator is queried on the mean matrix first, then on the
            # std matrix; both outputs are compared with all-zero labels.
            mean_logits = self.net_D(mean_mat)
            mean_term = self.criterionGAN(mean_logits, torch.zeros_like(mean_logits)).mean()
            std_logits = self.net_D(std_mat)
            std_term = self.criterionGAN(std_logits, torch.zeros_like(std_logits)).mean()
            return mean_term + std_term

        def source_coords_and_labels():
            # Pick the augmented source tensors when laserMix was fed with them.
            if self.cfg.DATASET_SOURCE.use_aug_for_laserMix:
                return (self.src_BData["aug_coords_mink"].int(),
                        self.src_BData["aug_labels_mink"].cuda())
            return (self.src_BData["coords_mink"].int(),
                    self.src_BData["labels_mink"].cuda())

        # ------------------------------------------------------------------
        # 1. Distance-based adversarial loss
        # ------------------------------------------------------------------
        # Fresh forward pass on the raw target batch to get its encoder features.
        tgt_G_in = ME.SparseTensor(self.tgt_BData['feats_mink'], self.tgt_BData['coords_mink'])
        _, _, tgt_encodedFt_dict = self.net_G(tgt_G_in, feature_out=True)
        tgt_encodedFt_dict = {name: feat for name, feat in tgt_encodedFt_dict.items()
                              if name in self.mean_std_const_layers}

        binning_args = (self.voxel_size, self.max_range, self.bin_size, self.num_bins, device)
        # Order of evaluation: target, t2s, source.
        self.tgt_computed_mean_mat_cat, self.tgt_computed_std_mat_cat = \
            self.compute_feature_stats_vectorized(tgt_encodedFt_dict, *binning_args)
        self.t2s_computed_mean_mat_cat, self.t2s_computed_std_mat_cat = \
            self.compute_feature_stats_vectorized(self.t2s_encodedFt_dict, *binning_args)
        self.src_computed_mean_mat_cat, self.src_computed_std_mat_cat = \
            self.compute_feature_stats_vectorized(self.src_encodedFt_dict, *binning_args)

        tgt_dist_loss = fool_loss(self.tgt_computed_mean_mat_cat,
                                  self.tgt_computed_std_mat_cat) * self.cfg.TGT_LOSS.LAMBDA_ADV
        self.wb_dict['netG/tgt_adv_Loss'] = tgt_dist_loss.item()
        total_loss = total_loss + tgt_dist_loss

        t2s_dist_loss = fool_loss(self.t2s_computed_mean_mat_cat,
                                  self.t2s_computed_std_mat_cat) * self.cfg.TGT_LOSS.LAMBDA_ADV
        self.wb_dict['netG/t2s_adv_Loss'] = t2s_dist_loss.item()
        total_loss = total_loss + t2s_dist_loss

        # ------------------------------------------------------------------
        # 2. Class-wise adversarial loss
        # ------------------------------------------------------------------
        src_coords, src_labels = source_coords_and_labels()
        cls_src_mean, cls_src_std = self.compute_classwise_matrix(
            self.src_encodedFt_dict, src_coords, src_labels)
        cls_tgt_mean, cls_tgt_std = self.compute_classwise_matrix(
            tgt_encodedFt_dict,
            self.tgt_BData['coords_mink'].int(),
            self.tgt_BData['pseudo_label'])
        cls_t2s_mean, cls_t2s_std = self.compute_classwise_matrix(
            self.t2s_encodedFt_dict,
            self.masked_batch["masked_source_pts"].int(),
            self.masked_batch["masked_source_labels"].cuda())

        # Kept on self for the discriminator update.
        self.src_classwise_mean, self.src_classwise_std = cls_src_mean, cls_src_std
        self.tgt_classwise_mean, self.tgt_classwise_std = cls_tgt_mean, cls_tgt_std
        self.t2s_classwise_mean, self.t2s_classwise_std = cls_t2s_mean, cls_t2s_std

        tgt_cls_loss = fool_loss(cls_tgt_mean, cls_tgt_std) * self.lambda_adv_cls
        self.wb_dict['netG/tgt_classwise_adv_Loss'] = tgt_cls_loss.item()
        total_loss = total_loss + tgt_cls_loss

        t2s_cls_loss = fool_loss(cls_t2s_mean, cls_t2s_std) * self.lambda_adv_cls
        self.wb_dict['netG/t2s_classwise_adv_Loss'] = t2s_cls_loss.item()
        total_loss = total_loss + t2s_cls_loss

        if not self.use_hierarchical_adv:
            return total_loss

        # ------------------------------------------------------------------
        # 3. Hierarchical adversarial loss
        # ------------------------------------------------------------------
        src_coords, src_labels = source_coords_and_labels()
        hier_src_mean, hier_src_std = self.compute_hierarchical_matrix(
            self.src_encodedFt_dict, src_coords, src_labels)
        hier_tgt_mean, hier_tgt_std = self.compute_hierarchical_matrix(
            tgt_encodedFt_dict,
            self.tgt_BData['coords_mink'].int(),
            self.tgt_BData['pseudo_label'])
        hier_t2s_mean, hier_t2s_std = self.compute_hierarchical_matrix(
            self.t2s_encodedFt_dict,
            self.masked_batch["masked_source_pts"].int(),
            self.masked_batch["masked_source_labels"].cuda())

        # Kept on self for the discriminator update.
        self.src_hier_mean, self.src_hier_std = hier_src_mean, hier_src_std
        self.tgt_hier_mean, self.tgt_hier_std = hier_tgt_mean, hier_tgt_std
        self.t2s_hier_mean, self.t2s_hier_std = hier_t2s_mean, hier_t2s_std

        tgt_hier_loss = fool_loss(hier_tgt_mean, hier_tgt_std) * self.lambda_adv_cls
        self.wb_dict['netG/tgt_hier_adv_Loss'] = tgt_hier_loss.item()
        total_loss = total_loss + tgt_hier_loss

        t2s_hier_loss = fool_loss(hier_t2s_mean, hier_t2s_std) * self.lambda_adv_cls
        self.wb_dict['netG/t2s_hier_adv_Loss'] = t2s_hier_loss.item()
        total_loss = total_loss + t2s_hier_loss

        return total_loss

    def train_net_D(self):     # ===========train D================
        """Compute the discriminator loss on the stored (detached) statistics.

        Gradients of ``net_D`` are re-enabled and ``D_optim`` is zeroed first.
        For each statistics family (distance-based, class-wise and, when
        ``self.use_hierarchical_adv`` is set, hierarchical) the discriminator
        is scored on source statistics against all-zero labels and on t2s and
        target statistics against all-one labels. Distance-based terms are
        weighted by ``cfg.TGT_LOSS.LAMBDA_ADV``, the others by
        ``self.lambda_adv_cls``. One scalar per term goes into ``self.wb_dict``.

        Returns:
            A 1-element tensor with the summed loss (no backward/step here).
        """
        for d_param in self.net_D.parameters():  # undo the freeze applied for the G step
            d_param.requires_grad = True
        self.D_optim.zero_grad()

        total_loss = torch.tensor([0.], device=self.device)

        def pair_loss(mean_mat, std_mat, make_labels):
            # Mean matrix is scored first, then the std matrix; inputs are
            # detached so no gradient flows back into the generator.
            mean_logits = self.net_D(mean_mat.detach())
            mean_term = self.criterionGAN(mean_logits, make_labels(mean_logits)).mean()
            std_logits = self.net_D(std_mat.detach())
            std_term = self.criterionGAN(std_logits, make_labels(std_logits)).mean()
            return mean_term + std_term

        as_source = torch.zeros_like   # label 0
        as_target = torch.ones_like    # label 1

        # 1. Distance-based statistics
        adv_loss_src = pair_loss(self.src_computed_mean_mat_cat,
                                 self.src_computed_std_mat_cat, as_source) * self.cfg.TGT_LOSS.LAMBDA_ADV
        adv_loss_t2s = pair_loss(self.t2s_computed_mean_mat_cat,
                                 self.t2s_computed_std_mat_cat, as_target) * self.cfg.TGT_LOSS.LAMBDA_ADV
        adv_loss_tgt = pair_loss(self.tgt_computed_mean_mat_cat,
                                 self.tgt_computed_std_mat_cat, as_target) * self.cfg.TGT_LOSS.LAMBDA_ADV

        total_loss = total_loss + adv_loss_src + adv_loss_t2s + adv_loss_tgt
        self.wb_dict['netD/adv_loss_src'] = adv_loss_src.item()
        self.wb_dict['netD/adv_loss_t2s'] = adv_loss_t2s.item()
        self.wb_dict['netD/adv_loss_tgt'] = adv_loss_tgt.item()

        # 2. Class-wise statistics
        adv_loss_src = pair_loss(self.src_classwise_mean,
                                 self.src_classwise_std, as_source) * self.lambda_adv_cls
        adv_loss_t2s = pair_loss(self.t2s_classwise_mean,
                                 self.t2s_classwise_std, as_target) * self.lambda_adv_cls
        adv_loss_tgt = pair_loss(self.tgt_classwise_mean,
                                 self.tgt_classwise_std, as_target) * self.lambda_adv_cls

        total_loss = total_loss + adv_loss_src + adv_loss_t2s + adv_loss_tgt
        self.wb_dict['netD/classwise_adv_loss_src'] = adv_loss_src.item()
        self.wb_dict['netD/classwise_adv_loss_t2s'] = adv_loss_t2s.item()
        self.wb_dict['netD/classwise_adv_loss_tgt'] = adv_loss_tgt.item()

        # 3. Hierarchical statistics (optional)
        if self.use_hierarchical_adv:
            adv_loss_src = pair_loss(self.src_hier_mean,
                                     self.src_hier_std, as_source) * self.lambda_adv_cls
            adv_loss_t2s = pair_loss(self.t2s_hier_mean,
                                     self.t2s_hier_std, as_target) * self.lambda_adv_cls
            adv_loss_tgt = pair_loss(self.tgt_hier_mean,
                                     self.tgt_hier_std, as_target) * self.lambda_adv_cls

            total_loss = total_loss + adv_loss_src + adv_loss_t2s + adv_loss_tgt
            self.wb_dict['netD/hier_adv_loss_src'] = adv_loss_src.item()
            self.wb_dict['netD/hier_adv_loss_t2s'] = adv_loss_t2s.item()
            self.wb_dict['netD/hier_adv_loss_tgt'] = adv_loss_tgt.item()

        return total_loss

    def valid_and_save(self):
        """Checkpoint the current state, validate on target, keep the best model.

        Steps, in order:
          1. always overwrite ``cp_current.tar`` in ``cfg.TRAIN.MODEL_DIR``;
          2. if ``cfg.TGT_LOSS.CAL_out`` is set, save ``self.out_class_center``;
          3. if past ``cfg.TRAIN.SAVE_ITER`` and ``cfg.TRAIN.SAVE_MORE_ITER`` is
             set, also write a per-iteration checkpoint;
          4. run ``self.tgt_valer.rolling_predict`` and, when its result beats
             the post-SAVE_ITER best or the overall best, record it and call
             ``com.save_best_check``;
          5. release cached CUDA memory.
        """
        model_dir = self.cfg.TRAIN.MODEL_DIR
        self.fast_save_CP(os.path.join(model_dir, 'cp_current.tar'))

        if self.cfg.TGT_LOSS.CAL_out:
            self.out_class_center.save(os.path.join(model_dir, f'cp_out_iter_{self.c_iter}.tar'))

        # Extra per-iteration checkpoints are opt-in via cfg.TRAIN.SAVE_MORE_ITER.
        if self.c_iter > self.cfg.TRAIN.SAVE_ITER and self.cfg.TRAIN.SAVE_MORE_ITER:
            self.fast_save_CP(os.path.join(model_dir, f'cp_{self.c_iter}_iter.tar'))

        tgt_sp_iou = self.tgt_valer.rolling_predict(self.net_G, self.ema_G, self.c_iter, domain='tgt')

        beats_post_save_best = (tgt_sp_iou > self.best_IoU_after_saveIter
                                and self.c_iter > self.cfg.TRAIN.SAVE_ITER)
        if beats_post_save_best or tgt_sp_iou > self.ml_info['bt_tgt_spIoU']:
            if beats_post_save_best:
                self.best_IoU_after_saveIter = tgt_sp_iou
                s_name = 'target_Sp_After'
            else:
                s_name = 'target_Sp'

            # NOTE(review): the overall-best record is overwritten whenever
            # either condition fires, so a post-SAVE_ITER improvement that is
            # still below the earlier overall best lowers 'bt_tgt_spIoU' -
            # confirm this is intended.
            self.best_IoU_iter = self.c_iter
            self.ml_info['bt_tgt_spIoU'] = tgt_sp_iou
            wandb.run.summary["bt_tgt_spIoU"] = tgt_sp_iou

            com.save_best_check(self.net_G, None,
                                self.G_optim, None, None,
                                self.c_iter, self.logger,
                                model_dir, name=s_name,
                                iou=tgt_sp_iou)

        torch.cuda.empty_cache()

    def save_log(self):
        """Log every entry of ``self.wb_dict`` at the current iteration.

        The generator learning rate (first parameter group of ``G_optim``) is
        added under ``'lr/lr_G'`` first; each entry is then sent to
        ``self.tf_writer.add_scalar`` and to ``wandb.log`` with
        ``self.c_iter`` as the step.
        """
        first_group = self.G_optim.state_dict()['param_groups'][0]
        self.wb_dict['lr/lr_G'] = first_group['lr']

        step = self.c_iter
        for tag, value in self.wb_dict.items():
            self.tf_writer.add_scalar(tag, value, step)
            wandb.log({tag: value}, step=step)

    def set_zero_grad(self):
        """Prepare the networks for a generator update step.

        Puts ``net_G`` and ``net_D`` in training mode, clears the generator
        optimizer's gradients, puts the EMA teacher in eval mode and freezes
        every discriminator parameter (``requires_grad = False``).
        """
        for trainable in (self.net_G, self.net_D):
            trainable.train()

        self.G_optim.zero_grad()

        self.ema_G.eval()

        # Discriminator weights must not collect gradients during the G step.
        for d_param in self.net_D.parameters():
            d_param.requires_grad_(False)
     
    def set_lr(self):
        """Apply the scheduled learning rates for the current iteration.

        ``adjust_learning_rate`` / ``adjust_learning_rate_D`` are evaluated
        with the base rates from ``cfg.OPTIMIZER`` plus ``self.c_iter``,
        ``cfg.TRAIN.MAX_ITERS`` and ``cfg.TRAIN.PREHEAT_STEPS``; the results
        are written into every parameter group of ``G_optim`` and ``D_optim``.
        """
        optim_cfg = self.cfg.OPTIMIZER
        train_cfg = self.cfg.TRAIN

        lr_for_G = adjust_learning_rate(optim_cfg.LEARNING_RATE_G, self.c_iter,
                                        train_cfg.MAX_ITERS, train_cfg.PREHEAT_STEPS)
        lr_for_D = adjust_learning_rate_D(optim_cfg.LEARNING_RATE_D, self.c_iter,
                                          train_cfg.MAX_ITERS, train_cfg.PREHEAT_STEPS)

        for group in self.G_optim.param_groups:
            group['lr'] = lr_for_G
        for group in self.D_optim.param_groups:
            group['lr'] = lr_for_D
     
    def update_ema_variables(self, ema_net, net):
        """Blend ``net`` into ``ema_net`` in place (exponential moving average).

        The decay is ``min(1 - 1 / (c_iter + 1), cfg.MEAN_TEACHER.alpha_ema)``
        and is also stored in ``self.cur_alpha_teacher``. Parameters and
        buffers are updated as ``ema = decay * ema + (1 - decay) * net``;
        buffers of dtype ``torch.int64`` are left untouched.
        """
        warmup_decay = 1 - 1 / (self.c_iter + 1)
        alpha_teacher = min(warmup_decay, self.cfg.MEAN_TEACHER.alpha_ema)
        self.cur_alpha_teacher = alpha_teacher

        for teacher_param, student_param in zip(ema_net.parameters(), net.parameters()):
            teacher_param.data.mul_(alpha_teacher).add_(student_param.data, alpha=1 - alpha_teacher)

        for teacher_buf, student_buf in zip(ema_net.buffers(), net.buffers()):
            if teacher_buf.dtype == torch.int64:
                # int64 buffers are skipped rather than averaged.
                continue
            teacher_buf.data.mul_(alpha_teacher).add_(student_buf.data, alpha=1 - alpha_teacher)

    def create_ema_model(self, ema, net):
        """Initialise ``ema`` as a frozen copy of ``net``.

        Every parameter and buffer of ``ema`` receives a clone of the matching
        tensor in ``net``; ``ema`` is then switched to eval mode and its
        parameters are detached with gradients disabled.
        """
        print(f'create_ema_model G to current iter {self.c_iter}')

        for source_param, teacher_param in zip(net.parameters(), ema.parameters()):
            teacher_param.data = source_param.data.clone()
        for source_buf, teacher_buf in zip(net.buffers(), ema.buffers()):
            teacher_buf.data = source_buf.data.clone()

        ema.eval()
        # The teacher is never optimised directly.
        for teacher_param in ema.parameters():
            teacher_param.requires_grad_(False)
            teacher_param.detach_()

    @staticmethod
    def send_data2GPU(batch_data):
        for key in batch_data:  # send data to gpu
            batch_data[key] = batch_data[key].cuda(non_blocking=True)
        return batch_data

    def fast_save_CP(self, checkpoint_file):
        """Write a checkpoint of the generator to ``checkpoint_file``.

        Delegates to ``com.save_checkpoint`` with ``net_G``, ``G_optim`` and
        the current iteration; the remaining positional slots are passed as
        ``None``.
        """
        generator = self.net_G
        generator_optim = self.G_optim
        current_iter = self.c_iter
        com.save_checkpoint(checkpoint_file, generator, None, generator_optim, None, None, current_iter)
    
    def init_dataloader(self):
        """Build the source/target dataloaders and the two validators.

        The dataset classes are selected by ``cfg.DATASET_SOURCE.TYPE`` and
        ``cfg.DATASET_TARGET.TYPE`` and imported lazily, so only the modules
        that are actually needed get loaded.

        Sets ``self.src_TraDL``, ``self.src_ValDL``, ``self.src_sampler``,
        ``self.tgt_train_loader``, ``self.tgt_sampler``, ``self.src_valer``
        and ``self.tgt_valer``.

        Raises:
            ValueError: if either dataset type is not one of the supported
                names (previously this surfaced later as an
                ``UnboundLocalError``).
        """
        # init source dataloader
        source_type = self.cfg.DATASET_SOURCE.TYPE
        if source_type == "SynLiDAR":
            from dataset.SynLiDAR_trainSet import SynLiDAR_Dataset as source_dataset_cls
        elif source_type == "SynLiDAR_sample1":
            from dataset.SynLiDAR_trainSet_sample1 import SynLiDAR_Dataset_Sample1 as source_dataset_cls
        elif source_type == "SynLiDAR_sample2":
            from dataset.SynLiDAR_trainSet_sample2 import SynLiDAR_Dataset_Sample2 as source_dataset_cls
        elif source_type == "SynLiDAR_sample3":
            from dataset.SynLiDAR_trainSet_sample3 import SynLiDAR_Dataset_Sample3 as source_dataset_cls
        else:
            raise ValueError(f"Unsupported DATASET_SOURCE.TYPE: {source_type!r}")
        src_tra_dset = source_dataset_cls(self.cfg, 'training')
        src_val_dset = source_dataset_cls(self.cfg, 'validation')

        self.src_TraDL, self.src_ValDL, self.src_sampler = get_TV_dl(self.cfg, src_tra_dset, src_val_dset)

        # init target dataloader
        target_type = self.cfg.DATASET_TARGET.TYPE
        if target_type == "SemanticKITTI":
            from dataset.semkitti_trainSet import SemanticKITTI as target_dataset_cls
        elif target_type == "SemanticPOSS":
            from dataset.SemanticPoss_trainSet import semPoss_Dataset as target_dataset_cls
        else:
            raise ValueError(f"Unsupported DATASET_TARGET.TYPE: {target_type!r}")
        t_tra_dset = target_dataset_cls(self.cfg, 'training')
        t_val_dset = target_dataset_cls(self.cfg, 'validation')

        # The target validation loader returned here is not kept.
        self.tgt_train_loader, _, self.tgt_sampler = get_TV_dl(self.cfg, t_tra_dset, t_val_dset, domain='target')

        # init validater
        self.src_valer = validater(self.cfg, self.cfg.DATASET_SOURCE.TYPE, 'source', self.criterion, self.tf_writer, self.logger)
        self.tgt_valer = validater(self.cfg, self.cfg.DATASET_TARGET.TYPE, 'target', self.criterion, self.tf_writer, self.logger)
        
        