"""
SPG: mitigate class imbalance in point cloud semantic segmentation through separate subspace prototypes.
"""
import __init__
# NOTE: Global switch for the auxiliary prior branches (True → enable, False → disable).
# Prior branches are built only when this is True AND at least one entry of cfg.prior_active is non-zero.
# ---------------------------
USE_PRIOR = True  # ← set to False to train main branch only

import argparse, yaml, os, copy, logging, time, numpy as np
 
from tqdm import tqdm
import torch, torch.nn as nn
from torch import distributed as dist, multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from torch_scatter import scatter
from torch.cuda.amp import autocast, GradScaler  # AMP support
from openpoints.utils import set_random_seed, save_checkpoint, load_checkpoint, resume_checkpoint, \
    setup_logger_dist, cal_model_parm_nums, generate_exp_directory, resume_exp_directory, EasyConfig, dist_utils, find_free_port
from openpoints.utils import AverageMeter, ConfusionMatrix, get_mious
from openpoints.dataset import build_dataloader_from_cfg, get_scene_seg_features, get_class_weights
from openpoints.dataset.data_util import voxelize
from openpoints.transforms import build_transforms_from_cfg
from openpoints.optim import build_optimizer_from_cfg
from openpoints.scheduler import build_scheduler_from_cfg
from openpoints.loss import build_criterion_from_cfg
from openpoints.models import build_model_from_cfg
from openpoints.models.layers import furthest_point_sample
from main.test_utils import write_to_csv
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from utils.consistency_loss import ConsistencyLoss




# --- Utility: Null GradScaler when AMP is disabled ---
class NullScaler:
    """No-op stand-in for ``torch.cuda.amp.GradScaler`` used when AMP is disabled.

    Mirrors the parts of the GradScaler API the training loop calls, with pass-through
    semantics: losses are returned unscaled and ``step`` just runs the optimizer step,
    forwarding any extra arguments (e.g. a ``closure``) and returning its result, as
    ``GradScaler.step`` does.
    """

    def scale(self, loss):
        """Return ``loss`` unchanged (GradScaler would multiply it by the loss scale)."""
        return loss

    def step(self, optimizer, *args, **kwargs):
        """Call ``optimizer.step(*args, **kwargs)`` and return whatever it returns."""
        return optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        """No-op: there is no loss scale to adjust. ``new_scale`` is accepted for API parity."""
        return None

    def unscale_(self, optimizer):
        """No-op: gradients were never scaled, so there is nothing to unscale."""
        return None

    def is_enabled(self):
        """Report that loss scaling is inactive, like ``GradScaler(enabled=False)``."""
        return False


def _as_bool(value):
    """Interpret a config flag that YAML may have stored as a string.

    Strings count as true only when they are '1', 'true', 'yes' or 'on' (case-insensitive),
    the same rule ``train_one_epoch`` applies to ``cfg.amp``; other values go through ``bool``.
    """
    if isinstance(value, str):
        return value.lower() in ('1', 'true', 'yes', 'on')
    return bool(value)


def _resolve_matrix_csv(f_csv, yaml_path, root=None):
    """Return an existing path for a CSV entry listed in a hierarchy-matrices YAML file.

    Absolute paths are returned unchanged. Relative paths are tried, in order, against
    ``root`` (the optional ``consistency_loss.matrices_root`` setting), the YAML file's own
    directory, and the current working directory; the first existing candidate wins.
    Set ``matrices_root`` to force a specific base directory.

    Raises:
        FileNotFoundError: no candidate location exists.
    """
    if os.path.isabs(f_csv):
        return f_csv
    bases = [root] if root else []
    bases += [os.path.dirname(os.path.abspath(yaml_path)), os.getcwd()]
    candidates = [os.path.join(base, f_csv) for base in bases]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"Could not locate hierarchy matrix CSV '{f_csv}'; tried: {candidates}")


def _load_hierarchy_matrices(cons_cfg):
    """Load the consistency-loss hierarchy matrices as a list of arrays.

    Sources, in priority order: inline ``cons_cfg.matrices`` (cast to float32), or
    ``cons_cfg.matrices_file`` pointing to a ``.npy`` file or to a YAML file whose
    ``file_list`` names comma-separated CSV files (each cast to float32).

    Raises:
        ValueError: neither source is configured, or ``matrices_file`` has an unsupported extension.
        FileNotFoundError: a CSV listed in the YAML file cannot be located.
    """
    if 'matrices' in cons_cfg:
        return [np.array(m, dtype=np.float32) for m in cons_cfg.matrices]
    if 'matrices_file' in cons_cfg:
        matrices_file_path = cons_cfg.matrices_file
        if matrices_file_path.endswith('.npy'):
            # NOTE(review): allow_pickle=True can execute pickled objects; only load trusted files.
            return list(np.load(matrices_file_path, allow_pickle=True))
        if matrices_file_path.endswith(('.yaml', '.yml')):
            with open(matrices_file_path, 'r') as f:
                file_list = (yaml.safe_load(f) or {}).get('file_list', [])
            root = cons_cfg.get('matrices_root', None)
            return [np.loadtxt(_resolve_matrix_csv(f_csv, matrices_file_path, root), delimiter=',').astype(np.float32)
                    for f_csv in file_list]
        raise ValueError(f"Unsupported extension for matrices_file: {matrices_file_path}. Use .npy or .yaml")
    raise ValueError('consistency_loss requires "matrices" or "matrices_file"')


def _build_consistency_loss(cfg):
    """Return ``(loss_fn, weight)`` for the optional hierarchy consistency loss.

    Returns ``(None, 0.0)`` when ``cfg.consistency_loss`` is absent or not enabled, or when
    its ``weight`` is not positive. A ``weight`` stored as an unparsable string is treated
    as 0.0 (with a warning). Matrices are still loaded (and validated) whenever the loss is enabled.
    """
    cons_cfg = cfg.get('consistency_loss', None)
    if not (cons_cfg and cons_cfg.get('enable', False)):
        return None, 0.0
    h_mats = _load_hierarchy_matrices(cons_cfg)
    layer_weights = cons_cfg.get('layer_weights', None)
    cons_mode = cons_cfg.get('mode', 'inter')
    cons_weight = cons_cfg.get('weight', 1.0)
    # YAML may store numbers as strings; cast robustly
    if isinstance(cons_weight, str):
        try:
            cons_weight = float(cons_weight)
        except ValueError:
            logging.warning(f"consistency_loss.weight={cons_weight!r} is not a number; disabling consistency loss")
            cons_weight = 0.0
    # Build the loss fn only when the weight is positive
    if cons_weight and cons_weight > 0:
        loss_fn = ConsistencyLoss(h_mats, device=cfg.rank, layer_weights=layer_weights, mode=cons_mode)
        return loss_fn, cons_weight
    return None, 0.0


def _report_test_metrics(all_level_metrics, cfg, writer, step, best_epoch):
    """Log per-level test metrics, then record the finest level to TensorBoard and the CSV file.

    ``all_level_metrics`` is the list returned by ``evaluate_loader_hierarchical``; each entry
    is read for the keys 'miou', 'macc', 'oa' and 'ious', and the last entry is treated as the
    finest level.
    """
    logging.info("------ Hierarchical Test Results ------")
    for i, metrics in enumerate(all_level_metrics):
        logging.info(f"--- Level {i} ---")
        logging.info(f"  mIoU: {metrics['miou']:.2f}, mAcc: {metrics['macc']:.2f}, OA: {metrics['oa']:.2f}")
        with np.printoptions(precision=2, suppress=True):
            logging.info(f"  IoUs: {metrics['ious']}")

    # For backward compatibility and reporting, use the finest level's metrics
    final_metrics = all_level_metrics[-1]
    test_miou, test_macc, test_oa, test_ious = final_metrics['miou'], final_metrics['macc'], final_metrics['oa'], final_metrics['ious']

    if writer is not None:
        writer.add_scalar('test_miou', test_miou, step)
        writer.add_scalar('test_macc', test_macc, step)
        writer.add_scalar('test_oa', test_oa, step)

    # Write the finest level's results to CSV
    write_to_csv(test_oa, test_macc, test_miou, test_ious, best_epoch, cfg, write_header=True)
    logging.info(f'Saved results for the finest level in {cfg.csv_path}')


def main(gpu, cfg):
    """Run one process: build data, models and optimizers, then train / resume / finetune / test.

    Args:
        gpu: process index. Used as ``cfg.rank`` when ``cfg.distributed`` and ``cfg.mp`` are
            set, and as the CUDA device under distributed training.
        cfg: experiment configuration (EasyConfig). Mutated in place; among others ``beta`` on
            the model configs, ``num_levels``, ``prior_active``, ``consistency_loss_fn``,
            ``consistency_weight``, ``classes`` and ``cmap`` are written here.

    Returns:
        In ``test`` mode with a ``pretrained_path``, rank 0 returns the finest-level test mIoU
        and other ranks return True. Otherwise True after training and the final evaluation.
    """
    if cfg.distributed:
        if cfg.mp:
            cfg.rank = gpu
        dist.init_process_group(backend=cfg.dist_backend,
                                init_method=cfg.dist_url,
                                world_size=cfg.world_size,
                                rank=cfg.rank)
    # logger: every rank logs (unless debugging); only rank 0 writes TensorBoard, and only while training
    if not cfg.debug:
        setup_logger_dist(cfg.log_path, cfg.rank, name=cfg.dataset.common.NAME)
    if cfg.rank == 0 and not cfg.debug and cfg.is_training:
        writer = SummaryWriter(log_dir=os.path.join(cfg.run_dir, 'tensorboard'))
    else:
        writer = None
    set_random_seed(7421, deterministic=cfg.deterministic)
    torch.backends.cudnn.enabled = True

    train_loader = build_dataloader_from_cfg(cfg.batch_size,
                                             cfg.dataset,
                                             cfg.dataloader,
                                             datatransforms_cfg=cfg.datatransforms,
                                             split='train',
                                             distributed=cfg.distributed,
                                             )
    logging.info(f"length of training dataset: {len(train_loader.dataset)}")
    # beta = 1 - 1/iterations_per_epoch is handed to every model config.
    # NOTE(review): presumably a momentum for prototype updates — confirm in the model code.
    beta = 1.0 - 1.0 / len(train_loader)
    cfg.model.beta = beta
    # support different prior configurations
    if USE_PRIOR and cfg.mode != 'test':
        for key in ('model_prior', 'model_prior_fine', 'model_prior_coarse'):
            if cfg.get(key, None) is not None:
                getattr(cfg, key).beta = beta
        # If coarse prior not specified in YAML, reuse the prior config
        if cfg.get('model_prior_coarse', None) is None:
            cfg.model_prior_coarse = copy.deepcopy(cfg.model_prior)

    # ---------------- Dynamic prior branch configuration ----------------
    # Number of hierarchy levels defined by the backbone (coarse->fine); fall back to the
    # length of num_classes_per_level, and to 1 if that is missing too.
    num_levels_cfg = cfg.model.get('hier_levels', None)
    if num_levels_cfg is None:
        num_levels_cfg = len(cfg.model.get('num_classes_per_level', [])) or 1
    cfg.num_levels = num_levels_cfg  # store for later reuse

    # prior_active: one flag per level; truthy means "build a prior branch at that level".
    # Pad with zeros or truncate so the list has exactly num_levels entries.
    prior_active = list(cfg.get('prior_active', [0] * cfg.num_levels))
    prior_active = (prior_active + [0] * cfg.num_levels)[:cfg.num_levels]
    cfg.prior_active = prior_active  # save back for global access

    # The global USE_PRIOR switch still applies; if all prior_active entries are 0 the prior
    # branches are disabled regardless of USE_PRIOR.
    USE_PRIOR_DYNAMIC = USE_PRIOR and any(prior_active)

    # ---------------- Consistency loss (optional) ----------------
    consistency_loss_fn, cons_weight = _build_consistency_loss(cfg)

    # ---------------------------------------------------------------------
    # Main model
    # ---------------------------------------------------------------------
    model = build_model_from_cfg(cfg.model).to(cfg.rank)
    model_size = cal_model_parm_nums(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
    # === Prior branches (one optional model per hierarchy level) ===
    if USE_PRIOR_DYNAMIC:
        model_prior_list = []
        for i in range(cfg.num_levels):
            if cfg.prior_active[i]:
                # Copy the prior config, set this level's class count, then build the model
                cfg_prior_i = copy.deepcopy(cfg.model_prior)
                cfg_prior_i.num_classes = cfg.model.num_classes_per_level[i]
                model_prior_list.append(build_model_from_cfg(cfg_prior_i).to(cfg.rank))
            else:
                model_prior_list.append(None)
        model_prior_size_list = [cal_model_parm_nums(p) for p in model_prior_list if p is not None]
        logging.info('Prior models params total : %.4f M' % (sum(model_prior_size_list) / 1e6))
    else:
        model_prior_list = [None] * cfg.num_levels

    # save consistency loss fn & weight into cfg so train_one_epoch can access them
    cfg.consistency_loss_fn = consistency_loss_fn
    cfg.consistency_weight = cons_weight

    if cfg.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        for i in range(cfg.num_levels):
            if model_prior_list[i] is not None:
                model_prior_list[i] = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_prior_list[i])
        logging.info('Using Synchronized BatchNorm ...')
    if cfg.distributed:
        torch.cuda.set_device(gpu)
        model = nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[cfg.rank], output_device=cfg.rank, broadcast_buffers=True
        )
        for i in range(cfg.num_levels):
            if model_prior_list[i] is not None:
                model_prior_list[i] = nn.parallel.DistributedDataParallel(
                    model_prior_list[i].cuda(), device_ids=[cfg.rank], output_device=cfg.rank, broadcast_buffers=True
                )
        logging.info('Using Distributed Data parallel ...')

    # optimizer & scheduler
    # YAML may contain numbers as strings; ensure weight_decay is a float
    if isinstance(cfg.optimizer.get('weight_decay', 0.0), str):
        try:
            cfg.optimizer['weight_decay'] = float(cfg.optimizer['weight_decay'])
        except ValueError:
            logging.warning(f"Failed to cast weight_decay={cfg.optimizer['weight_decay']} to float; using default 0.0")
            cfg.optimizer['weight_decay'] = 0.0

    optimizer = build_optimizer_from_cfg(model, lr=cfg.lr, **cfg.optimizer)
    scheduler = build_scheduler_from_cfg(cfg, optimizer)
    # AMP: gradient scaler for mixed precision. Parse the flag the same way train_one_epoch
    # does, so a YAML string such as 'false' disables AMP consistently in both places.
    amp_enabled = _as_bool(cfg.get('amp', True))
    scaler = GradScaler(enabled=amp_enabled) if amp_enabled else NullScaler()
    if USE_PRIOR_DYNAMIC:
        optimizer_prior_list = []
        scheduler_prior_list = []
        for i in range(cfg.num_levels):
            if cfg.prior_active[i]:
                opt_i = build_optimizer_from_cfg(model_prior_list[i], lr=cfg.lr, **cfg.optimizer)
                optimizer_prior_list.append(opt_i)
                scheduler_prior_list.append(build_scheduler_from_cfg(cfg, opt_i))
            else:
                optimizer_prior_list.append(None)
                scheduler_prior_list.append(None)
    else:
        optimizer_prior_list = [None] * cfg.num_levels
        scheduler_prior_list = [None] * cfg.num_levels

    # build validation dataset
    val_loader = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                           cfg.dataset,
                                           cfg.dataloader,
                                           datatransforms_cfg=cfg.datatransforms,
                                           split='test',
                                           distributed=cfg.distributed
                                           )
    logging.info(f"length of validation dataset: {len(val_loader.dataset)}")
    num_classes = getattr(val_loader.dataset, 'num_classes', None)
    if num_classes is not None:
        assert cfg.num_classes == num_classes
    logging.info(f"number of classes of the dataset: {num_classes}")
    if hasattr(val_loader.dataset, 'classes'):
        cfg.classes = val_loader.dataset.classes
    else:
        # Fall back to the configured class count when the dataset does not expose one
        cfg.classes = np.arange(num_classes if num_classes is not None else cfg.num_classes)
    cfg.cmap = np.array(val_loader.dataset.cmap) if hasattr(val_loader.dataset, 'cmap') else None

    # NOTE: test_loader is not referenced again in this function, but the scannetv2 branch
    # also mutates cfg.dataset / cfg.dataloader, which affects every loader built afterwards.
    dataset_name = cfg.dataset.common.NAME.lower()
    if ("semantic3d" in dataset_name) or ("toronto3d" in dataset_name):
        test_loader = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                                cfg.dataset,
                                                cfg.dataloader,
                                                datatransforms_cfg=cfg.datatransforms,
                                                split='test',
                                                distributed=cfg.distributed
                                                )
        logging.info(f"length of test dataset: {len(test_loader.dataset)}")
    elif "scannetv2" in dataset_name:
        cfg.test_batch_size = 1
        cfg.dataset.common.collate_fn = False  # NOTE: train_loader's collate_fn needs to be True
        cfg.dataloader.num_workers = 2
        test_loader = build_dataloader_from_cfg(cfg.test_batch_size,
                                                cfg.dataset,
                                                cfg.dataloader,
                                                datatransforms_cfg=cfg.datatransforms,
                                                split='test',
                                                distributed=False,
                                                )
        logging.info(f"length of test dataset: {len(test_loader.dataset)}")

    # optionally resume from a checkpoint
    if cfg.pretrained_path is not None:
        if cfg.mode == 'resume':
            resume_checkpoint(cfg, model, optimizer, scheduler, pretrained_path=cfg.pretrained_path)
            # Resume prior-branch checkpoints: `prior_model_path_<i>` if given for level i,
            # otherwise the backward-compatible `prior_model_path`.
            if USE_PRIOR_DYNAMIC:
                for i in range(cfg.num_levels):
                    if cfg.prior_active[i] and model_prior_list[i] is not None:
                        prior_ckpt_path = cfg.get(f'prior_model_path_{i}', cfg.get('prior_model_path', None))
                        resume_checkpoint(cfg, model_prior_list[i], optimizer_prior_list[i], scheduler_prior_list[i],
                                          pretrained_path=prior_ckpt_path)
            val_miou, val_macc_raw, val_oa, val_ious = validate(model, val_loader, cfg)
            logging.info(f'\nresume val miou is {val_miou}\n ')
            # resume the SummaryWriter instance from the crashed epoch
            if cfg.rank == 0:
                writer = SummaryWriter(log_dir=os.path.join(cfg.run_dir, 'tensorboard'), purge_step=cfg.start_epoch)
        elif cfg.mode == 'test':
            if cfg.rank == 0:
                # Generic hierarchical evaluation (replaces the deprecated S3DIS-specific dual test).
                best_epoch, _ = load_checkpoint(model, pretrained_path=cfg.pretrained_path)
                from main.test_utils import evaluate_loader_hierarchical
                generic_loader = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                                           cfg.dataset,
                                                           cfg.dataloader,
                                                           datatransforms_cfg=cfg.datatransforms,
                                                           split='test',
                                                           distributed=False)
                metrics_levels = evaluate_loader_hierarchical(model, generic_loader, cfg)
                with np.printoptions(precision=2, suppress=True):
                    for lv, m in enumerate(metrics_levels):
                        logging.info(f'[TEST] Level {lv}: mIoU {m["miou"]:.2f} OA {m["oa"]:.2f}')
                return metrics_levels[-1]['miou']
            else:
                return True
        else:
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            # DDP wraps the network; the checkpoint-loading method lives on the wrapped module.
            net = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
            net.load_model_from_ckpt(cfg.pretrained_path, only_encoder=cfg.only_encoder)
    else:
        logging.info('Training from scratch')

    # Standard criterion from cfg.criterion for all levels
    criterion = build_criterion_from_cfg(cfg.criterion)
    criterion_main_ce = criterion

    # Supervised-contrastive and L1 feature-alignment criteria are only needed with prior branches
    criterion_supcon = build_criterion_from_cfg(cfg.criterion_SupCon) if USE_PRIOR_DYNAMIC else None
    criterion_l1 = nn.SmoothL1Loss(reduction='mean') if USE_PRIOR_DYNAMIC else None

    # ===> start training
    val_miou, val_macc, val_oa, val_ious = 0., 0., 0., []
    best_val, macc_when_best, oa_when_best, ious_when_best, best_epoch = 0., 0., 0., [], 0
    epoch = cfg.start_epoch  # defined even if the loop runs zero times (used as the final TensorBoard step)
    for epoch in range(cfg.start_epoch, cfg.epochs + 1):
        if cfg.distributed:
            train_loader.sampler.set_epoch(epoch)
        if hasattr(train_loader.dataset, 'epoch'):  # some datasets set the dataset length as a fixed number of steps
            train_loader.dataset.epoch = epoch - 1
        loss_prior_l1, loss_prior_supcon, loss_main_l1, loss_main_ce, train_miou, train_macc = \
            train_one_epoch(model, train_loader, criterion_main_ce, optimizer, scheduler, epoch, cfg,
                            model_prior_list, optimizer_prior_list, scheduler_prior_list,
                            criterion_supcon, criterion_l1, USE_PRIOR_DYNAMIC, scaler)

        is_best = False
        if epoch % cfg.val_freq == 0:
            val_miou, val_macc_raw, val_oa, val_ious = validate(model, val_loader, cfg)
            # `validate` may return a scalar mAcc or a per-level list; log the last (finest) entry.
            val_macc = val_macc_raw[-1] if isinstance(val_macc_raw, (list, tuple)) else val_macc_raw
            if val_miou > best_val:
                is_best = True
                best_val = val_miou
                macc_when_best = val_macc
                oa_when_best = val_oa
                ious_when_best = val_ious
                best_epoch = epoch
                with np.printoptions(precision=2, suppress=True):
                    logging.info(f'Find a better ckpt @E{epoch}, val_miou {val_miou:.2f} val_macc {macc_when_best:.2f}, val_oa {oa_when_best:.2f}')
        lr = optimizer.param_groups[0]['lr']
        logging.info(f'Epoch {epoch} LR {lr:.6f} '
                     f'train_miou {train_miou:.2f}, val_miou {val_miou:.2f}, best val miou {best_val:.2f}')
        if writer is not None:
            writer.add_scalar('val/best_val', best_val, epoch)
            writer.add_scalar('val/val_miou', val_miou, epoch)
            writer.add_scalar('val/macc_when_best', macc_when_best, epoch)
            writer.add_scalar('val/oa_when_best', oa_when_best, epoch)
            writer.add_scalar('val/val_macc', val_macc, epoch)
            writer.add_scalar('val/val_oa', val_oa, epoch)
            writer.add_scalar('train/loss_prior', loss_prior_l1, epoch)
            writer.add_scalar('train/loss_prior_supcon', loss_prior_supcon, epoch)
            writer.add_scalar('train/loss_main_l1', loss_main_l1, epoch)
            writer.add_scalar('train/loss_main_ce', loss_main_ce, epoch)
            writer.add_scalar('train/train_miou', train_miou, epoch)
            writer.add_scalar('train/train_macc', train_macc, epoch)
            writer.add_scalar('lr', lr, epoch)

        if cfg.sched_on_epoch:
            scheduler.step(epoch)
            if USE_PRIOR_DYNAMIC:
                for i in range(cfg.num_levels):
                    if scheduler_prior_list[i] is not None:
                        scheduler_prior_list[i].step(epoch)
        if cfg.rank == 0:
            save_checkpoint(cfg, model, epoch, optimizer, scheduler,
                            additioanl_dict={'best_val': best_val},
                            is_best=is_best, post_fix='main_model_ckpt')
            if USE_PRIOR_DYNAMIC:
                for i in range(cfg.num_levels):
                    if cfg.prior_active[i] and model_prior_list[i] is not None:
                        save_checkpoint(cfg, model_prior_list[i], epoch, optimizer_prior_list[i], scheduler_prior_list[i],
                                        additioanl_dict={'best_val': best_val},
                                        is_best=is_best, post_fix=f'prior_{i}_ckpt', is_logging=False)
            is_best = False

    # validate
    with np.printoptions(precision=2, suppress=True):
        logging.info(
            f'Best ckpt @E{best_epoch},  val_oa {oa_when_best:.2f}, val_macc {macc_when_best:.2f}, val_miou {best_val:.2f}, '
            f'\nEach cls IoU: {ious_when_best}')

    # test: only the main process evaluates the best checkpoint
    if cfg.rank == 0:
        try:
            test_loader_generic = build_dataloader_from_cfg(cfg.get('val_batch_size', cfg.batch_size),
                                                            cfg.dataset,
                                                            cfg.dataloader,
                                                            datatransforms_cfg=cfg.datatransforms,
                                                            split='test',
                                                            distributed=False)
        except Exception as err:  # best-effort: fall back to the validation loader below
            logging.warning(f"Could not build test loader ({err!r}); will evaluate on the validation loader.")
            test_loader_generic = None

        load_checkpoint(model, pretrained_path=os.path.join(cfg.ckpt_dir, f'{cfg.run_name}_main_model_ckpt_best.pth'))

        from main.test_utils import evaluate_loader_hierarchical

        if test_loader_generic is not None and len(test_loader_generic) > 0:
            all_level_metrics = evaluate_loader_hierarchical(model, test_loader_generic, cfg)
        else:
            logging.warning("Test split not provided; falling back to validation loader for evaluation.")
            all_level_metrics = evaluate_loader_hierarchical(model, val_loader, cfg)

        # Report results for whichever loader was evaluated
        _report_test_metrics(all_level_metrics, cfg, writer, epoch, best_epoch)

    return True


def gen_sample(numbers, mu, std=0.05):
    """Draw ``numbers`` Gaussian-perturbed copies of the prototype tensor ``mu``.

    mu: (class_num, prior_feas_dim). ``mu`` is tiled along a new leading axis ``numbers``
    times and each element receives independent N(0, 1) noise scaled by ``std``; for a 2-D
    ``mu`` the result has shape (numbers, class_num, prior_feas_dim).
    """
    tiled = mu.unsqueeze(0).repeat(numbers, 1, 1)
    noise = torch.randn_like(tiled)
    return tiled + std * noise

def create_affinity_matrix(xyz: torch.Tensor=None, feats: torch.Tensor=None, offset: torch.Tensor=None, target: torch.Tensor=None, sample: bool=False, size: int=2000):
    """Build feature affinity (dot-product Gram) matrices, per point cloud or per fixed-size chunk.

    Two modes:

    * ``sample=True``: ``xyz`` (N, 3), ``feats`` (N, C) and ``target`` hold a batch of point
      clouds packed along dim 0, and ``offset`` is a 1-D tensor of cumulative end indices
      (``offset[i]`` is where cloud ``i`` ends). From every cloud, ``size`` points are picked by
      ``furthest_point_sample`` and the dot products between their features are computed.
      Returns ``(affinity_matrix, select_target)``: the per-cloud matrices concatenated along a
      new leading axis, presumably (B, size, size), and the sampled targets concatenated along
      dim 0 (B * size entries for a 1-D ``target``).
    * ``sample=False``: ``feats`` (N, C) is cut into consecutive chunks of ``size`` rows and the
      per-chunk dot-product matrices are returned, shape (N // size, size, size). Only ``feats``
      and ``size`` are used. N must be a multiple of ``size``; otherwise the chunks cannot be
      stacked and torch raises.
    """
    if not sample:
        chunks = torch.split(feats, size, dim=0)
        stacked = torch.stack(chunks, dim=0)  # (num_chunks, size, C)
        return torch.einsum("...nc,...mc->...nm", [stacked, stacked])

    # offset is cumulative, so per-cloud point counts are its first differences. Converting
    # once via tolist() avoids one device->host sync per cloud and keeps every count a plain int.
    counts = torch.diff(offset, prepend=offset.new_zeros(1)).tolist()
    xyz_chunks = torch.split(xyz, counts, dim=0)
    feat_chunks = torch.split(feats, counts, dim=0)
    target_chunks = torch.split(target, counts, dim=0)

    affinity_matrix = []
    select_target = []
    for cloud_xyz, cloud_feats, cloud_target in zip(xyz_chunks, feat_chunks, target_chunks):
        # furthest_point_sample is given a batched (1, npoints, 3) tensor
        idx = furthest_point_sample(cloud_xyz.unsqueeze(0), size).long().squeeze(0)
        select_feats = cloud_feats[idx, :]
        affinity_matrix.append(torch.einsum("...nc,...mc->...nm", [select_feats, select_feats]).unsqueeze(0))
        select_target.append(cloud_target[idx])
    return torch.cat(affinity_matrix, dim=0), torch.cat(select_target, dim=0)

def train_one_epoch(model, train_loader, criterion_main_ce, optimizer, scheduler, epoch, cfg,
                    model_prior_list, optimizer_prior_list, scheduler_prior_list,
                    criterion_supcon, criterion_l1, USE_PRIOR_DYNAMIC, scaler):
    """Run one training epoch of the main model and, optionally, the prior branches.

    Per batch:

    1. Split the loader output into ``data_prior`` (prior-branch input) and
       ``data`` (main-branch input). Build the unlabeled-point mask for
       Semantic3D / Toronto3D (label 0) and ScanNetV2 (label 255), then move
       tensors to the GPU.
    2. Run every active prior branch and the main model under ``autocast``.
    3. Accumulate the losses:
       - prior SupCon;
       - prior L1 (prior features vs. main-branch prototypes);
       - per-level main CE;
       - optional consistency loss;
       - main L1 (main features vs. prior-branch prototypes).
    4. Back-propagate ``loss_main_ce + 10 * (auxiliary losses)`` through
       ``scaler``. Every ``cfg.step_per_update`` iterations, step the
       optimizers, with optional gradient clipping.
    5. Update a training confusion matrix on the last (finest) logits level.

    Args:
        model: main network; returns ``(logits_list[, feats_list[, proto_list]])``.
        train_loader: yields a 3-tuple ``(prior, coarse, data)``, a 2-tuple
            ``(prior, data)``, or a single dict (deep-copied for the prior branch).
        criterion_main_ce: CE criterion, or a list/tuple with one per level.
        optimizer, scheduler: main-model optimizer and scheduler.
        epoch: current epoch (passed to the schedulers and shown in the bar).
        cfg: experiment config.
        model_prior_list, optimizer_prior_list, scheduler_prior_list: per-level
            prior branches and their optimizers/schedulers (entries may be None).
        criterion_supcon: contrastive criterion applied to prior features.
        criterion_l1: feature-to-prototype criterion (may be None).
        USE_PRIOR_DYNAMIC: enables the prior-branch losses and updates.
        scaler: ``GradScaler``-like object exposing ``scale``, ``step``,
            ``update`` and ``unscale_``.

    Returns:
        ``(loss_prior_l1_avg, loss_prior_supcon_avg, loss_main_l1_avg,
        loss_main_ce_avg, miou, macc)``. The loss averages are over this epoch's
        iterations; ``miou``/``macc`` come from the training confusion matrix.
    """
    # AMP flag (default True). YAML may store it as a string; normalize to bool.
    amp_cfg_val = cfg.get('amp', True)
    if isinstance(amp_cfg_val, str):
        amp_enabled = amp_cfg_val.lower() in ['1', 'true', 'yes', 'on']
    else:
        amp_enabled = bool(amp_cfg_val)
    consistency_loss_fn = cfg.get('consistency_loss_fn', None)
    cons_weight = cfg.get('consistency_weight', 0.0)
    if isinstance(cons_weight, str):
        try:
            cons_weight = float(cons_weight)
        except ValueError:
            cons_weight = 0.0
    use_consistency = (consistency_loss_fn is not None) and bool(cons_weight and cons_weight > 0)

    # Lower-cased once so the mask logic and the label-restore step below
    # agree on which datasets they apply to.
    dataset_name = cfg.dataset.common.NAME.lower()
    ignores_label_zero = ('semantic3d' in dataset_name) or ('toronto3d' in dataset_name)

    def _ensure_batch_label(d):
        """Give 1-D label tensors a leading batch dimension: (N,) -> (1, N)."""
        if torch.is_tensor(d['y']) and d['y'].dim() == 1:
            d['y'] = d['y'].unsqueeze(0)

    loss_prior_l1_meter = AverageMeter()
    loss_prior_supcon_meter = AverageMeter()
    loss_main_l1_meter = AverageMeter()
    loss_main_ce_meter = AverageMeter()
    cm = ConfusionMatrix(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index)
    # set models to training mode
    model.train()
    if USE_PRIOR_DYNAMIC:
        for i in range(cfg.num_levels):
            if model_prior_list[i] is not None:
                model_prior_list[i].train()
    pbar = tqdm(enumerate(train_loader), total=train_loader.__len__())
    num_iter = 0
    for idx, data_ in pbar:
        if isinstance(data_, (list, tuple)) and len(data_) == 3:
            data_prior, _, data = data_  # coarse data is not used
        elif isinstance(data_, (list, tuple)) and len(data_) == 2:
            data_prior, data = data_
        elif isinstance(data_, dict):
            data = data_
            data_prior = copy.deepcopy(data)
        else:
            raise ValueError("Expect dataloader to return 2 or 3 items, got {}".format(len(data_) if isinstance(data_, (list, tuple)) else type(data_)))
        data['mask'], data_prior['mask'] = None, None
        # some datasets need to ignore the 'unlabeled' class
        if ignores_label_zero:
            data['mask'] = ~(data['y']==0)
            data_prior['mask'] = ~(data_prior['y']==0)
        elif 'scannetv2' in dataset_name:
            data['mask'] = ~(data['y']==255)
            data_prior['mask'] = ~(data_prior['y']==255)
        keys = data.keys() if callable(data.keys) else data.keys
        for key in keys:
            if data[key] is None:
                continue
            elif not isinstance(data[key], list):
                data_prior[key] = data_prior[key].cuda(non_blocking=True)
                data[key] = data[key].cuda(non_blocking=True)
            elif torch.is_tensor(data[key][0]):
                for i in range(len(data[key])):
                    data_prior[key][i] = data_prior[key][i].cuda(non_blocking=True)
                    data[key][i] = data[key][i].cuda(non_blocking=True)
        if idx == 0:
            logging.info(f"data['pos'] shape: {data['pos'].shape}")
            logging.info(f"data['x'] shape: {data['x'].shape}")
        num_iter += 1

        # --- MAIN branch targets: the last label column is the finest level.
        # NOTE(review): this assumes labels carry a trailing level axis, e.g.
        # (N, L) or (B, N, 1). A plain (B, N) label with N > 1 would be read
        # as hierarchical here; confirm against the dataset.
        if data['y'].shape[-1] > 1:
            target = data['y'][..., -1]
        else:
            target = data['y'].squeeze(-1)

        # Flatten targets to shape (B*N,) for loss calculation and indexing.
        target = target.reshape(-1)

        if cfg.model.get('dim_modified', True):
            # Build the input features expected by the models (the prior branch
            # always gets a 6-channel variant).
            data['x'] = get_scene_seg_features(cfg.model.in_channels, data['pos'], data['x'])
            data_prior['x'] = get_scene_seg_features(6, data_prior['pos'], data_prior['x'])

            # Ensure labels have a batch dimension consistent with pos/x.
            _ensure_batch_label(data)
            _ensure_batch_label(data_prior)

            # When the dataloader outputs flattened scenes (N, 3), expand xyz
            # to (1, N, 3) so downstream modules receive 3-D tensors.
            if data['pos'].dim() == 2:
                data['pos'] = data['pos'].unsqueeze(0)
            if data_prior['pos'].dim() == 2:
                data_prior['pos'] = data_prior['pos'].unsqueeze(0)

        # ---- forward through prior branches ---------------------------------
        prior_feas_list = [None] * cfg.num_levels
        prior_proto_list = [None] * cfg.num_levels
        for i in range(cfg.num_levels):
            if cfg.prior_active[i] and model_prior_list[i] is not None:
                with autocast(enabled=amp_enabled):
                    prior_feas_i, prior_proto_i = model_prior_list[i](
                        data_prior, is_train=True, mask=data_prior.get('mask', None), ignore_index=cfg.ignore_index
                    )
                prior_feas_list[i] = prior_feas_i
                prior_proto_list[i] = prior_proto_i

        with autocast(enabled=amp_enabled):
            outputs = model(data, is_train=True, mask=data.get('mask', None), ignore_index=cfg.ignore_index)

        logits_list = outputs[0]
        main_feas_list = outputs[1] if len(outputs) > 1 else [None]*len(logits_list)
        main_proto_list = outputs[2] if len(outputs) > 2 else [None]*len(logits_list)
        # Hierarchy is decided by the number of logits levels the model returns.
        hierarchical = len(logits_list) > 1

        num_levels = len(logits_list)

        # ------------------------------------------------------------------
        # Diagnostic checks for NaN / Inf in logits & feature tensors
        # ------------------------------------------------------------------
        for lvl, lg in enumerate(logits_list):
            if torch.isnan(lg).any():
                logging.warning(f"NaN detected in logits level {lvl}")
            if torch.isinf(lg).any():
                logging.warning(f"Inf detected in logits level {lvl}")
        for lvl, ft in enumerate(main_feas_list):
            if ft is not None:
                if torch.isnan(ft).any():
                    logging.warning(f"NaN detected in features level {lvl}")
                if torch.isinf(ft).any():
                    logging.warning(f"Inf detected in features level {lvl}")

        # ------------------
        # Prior loss
        # ------------------
        loss_prior_supcon = torch.tensor(0., device=logits_list[0].device)
        loss_prior_l1 = torch.tensor(0., device=logits_list[0].device)
        if USE_PRIOR_DYNAMIC:
            for i in range(num_levels):
                if cfg.prior_active[i] and prior_feas_list[i] is not None:
                    pf = prior_feas_list[i]
                    # Last column of the prior output holds the label; the rest are features.
                    loss_prior_supcon += cfg.loss_prior_supcon * criterion_supcon(pf[:, :-1], pf[:, -1])
                    # Use main-branch prototypes to supervise prior-branch features
                    # (cast to FP32, apply masks and numeric safeguards).
                    if main_proto_list[i] is not None:
                        buf = main_proto_list[i]
                        if buf.device != pf.device:
                            buf = buf.to(pf.device)
                        prior_labels = pf[:, -1].long()
                        # Valid label range: [0, num_classes_i)
                        num_classes_i = getattr(cfg.model, 'num_classes_per_level', [cfg.num_classes])[i] if hasattr(cfg.model, 'num_classes_per_level') else cfg.num_classes
                        valid_mask = (prior_labels >= 0) & (prior_labels < num_classes_i)
                        if valid_mask.any():
                            feats_prior = pf[:, :-1][valid_mask]
                            labels_valid = prior_labels[valid_mask]
                            prior_target = buf[labels_valid, :]
                            # FP32 + numeric safety handling
                            feats_prior = torch.nan_to_num(feats_prior.float(), nan=0.0, posinf=1e4, neginf=-1e4)
                            prior_target = torch.nan_to_num(prior_target.float(), nan=0.0, posinf=1e4, neginf=-1e4)
                            with autocast(enabled=False):
                                loss_prior_l1 += cfg.loss_prior * criterion_l1(feats_prior, prior_target)

        # ---------------- Main CE loss over all levels -------------------
        loss_main_ce = torch.tensor(0., device=logits_list[-1].device)
        hier_w = getattr(cfg, 'hier_coarse_weight', 1.0)

        for lvl, logits_lvl in enumerate(logits_list):
            target_lvl = data['y'][..., lvl] if hierarchical and data['y'].shape[-1] > 1 else target
            target_lvl_flat = target_lvl.reshape(-1)
            logits_lvl_fp32 = logits_lvl.float()  # cast to fp32 for stable loss calculation
            # Select criterion for this level (support list or single)
            crit_lvl = criterion_main_ce[lvl] if isinstance(criterion_main_ce, (list, tuple)) else criterion_main_ce
            # Apply dataset mask if present, then compute CE
            if data['mask'] is not None:
                mask_flat = data['mask'].reshape(-1)
                target_lvl_flat = target_lvl_flat[mask_flat]
                logits_lvl_fp32 = logits_lvl_fp32[mask_flat]
            ce = crit_lvl(logits_lvl_fp32, target_lvl_flat)

            # Coarse levels are weighted by hier_coarse_weight; the finest level by 1.
            w = hier_w if lvl < num_levels-1 else 1.0
            loss_main_ce += w * ce

        loss_main_l1 = torch.tensor(0., device=logits_list[-1].device)

        # ---------------- Consistency loss ----------------
        loss_consistency = torch.tensor(0., device=logits_list[-1].device)
        if use_consistency:
            loss_consistency = cons_weight * consistency_loss_fn(logits_list)

        # ---------------- Main L1 loss over all levels -------------------
        if USE_PRIOR_DYNAMIC and criterion_l1 is not None:
            for i in range(num_levels):
                if cfg.prior_active[i]:
                    # Select correct labels for current level (coarse -> fine)
                    if hierarchical and data['y'].shape[-1] > 1:
                        tgt_lvl = data['y'][..., i]
                    else:
                        tgt_lvl = target
                    tgt_lvl = tgt_lvl.reshape(-1).long()

                    # Use prior-branch prototypes to supervise main-branch features
                    # (cast to FP32, apply masks and numeric safeguards).
                    if prior_proto_list[i] is not None and (main_feas_list[i] is not None):
                        num_classes_i = getattr(cfg.model, 'num_classes_per_level', [cfg.num_classes])[i] if hasattr(cfg.model, 'num_classes_per_level') else cfg.num_classes
                        valid_mask = (tgt_lvl >= 0) & (tgt_lvl < num_classes_i)
                        # Combine with dataset mask (ignore unlabeled classes)
                        mask_flat = data.get('mask', None)
                        if mask_flat is not None:
                            mask_flat = mask_flat.reshape(-1)
                            if mask_flat.shape[0] == valid_mask.shape[0]:
                                valid_mask = valid_mask & mask_flat
                        if valid_mask.any():
                            feats_main = main_feas_list[i][valid_mask, :]
                            labels_valid = tgt_lvl[valid_mask]
                            target_main_feas = prior_proto_list[i][labels_valid, :]
                            # FP32 + numeric safety handling
                            feats_main = torch.nan_to_num(feats_main.float(), nan=0.0, posinf=1e4, neginf=-1e4)
                            target_main_feas = torch.nan_to_num(target_main_feas.float(), nan=0.0, posinf=1e4, neginf=-1e4)
                            with autocast(enabled=False):
                                loss_main_l1 += cfg.loss_main_l1 * criterion_l1(feats_main, target_main_feas)

        # === NaN / Inf monitoring ===
        for name_, tensor_ in {
            'loss_prior_supcon': loss_prior_supcon,
            'loss_prior_l1': loss_prior_l1,
            'loss_main_l1': loss_main_l1,
            'loss_main_ce': loss_main_ce,
            'loss_consistency': loss_consistency,
        }.items():
            if torch.isnan(tensor_).any() or torch.isinf(tensor_).any():
                logging.warning(f"NaN/Inf detected in {name_} at iter {idx} epoch {epoch}")

        # ------------------------------------------------------------------
        # Total loss: auxiliary terms are scaled by a fixed factor of 10.
        # ------------------------------------------------------------------
        loss_aux = loss_main_l1 + loss_prior_supcon + loss_prior_l1 + loss_consistency
        loss_aux = loss_aux * 10
        loss = loss_main_ce + loss_aux
        scaler.scale(loss).backward()

        # Average the logged loss values across GPUs (after backward, for logging only).
        if cfg.distributed:
            dist.all_reduce(loss_prior_l1)
            loss_prior_l1 /= cfg.world_size
            dist.all_reduce(loss_prior_supcon)
            loss_prior_supcon /= cfg.world_size
            dist.all_reduce(loss_main_l1)
            loss_main_l1 /= cfg.world_size
            dist.all_reduce(loss_main_ce)
            loss_main_ce /= cfg.world_size
            if use_consistency:
                dist.all_reduce(loss_consistency)
                loss_consistency /= cfg.world_size

        # optimize (gradients accumulate over cfg.step_per_update iterations)
        if num_iter == cfg.step_per_update:
            if cfg.get('grad_norm_clip') is not None and cfg.grad_norm_clip > 0.:
                # Unscale the gradients before clipping
                logging.debug(f"grad_norm_clip: {cfg.grad_norm_clip}")
                if amp_enabled:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_norm_clip, norm_type=2)
                if USE_PRIOR_DYNAMIC:
                    for i in range(cfg.num_levels):
                        if optimizer_prior_list[i] is not None:
                            if amp_enabled:
                                scaler.unscale_(optimizer_prior_list[i])
                            if model_prior_list[i] is not None:
                                torch.nn.utils.clip_grad_norm_(model_prior_list[i].parameters(), cfg.grad_norm_clip, norm_type=2)
            num_iter = 0
            # Optimizer step (scaled when AMP enabled)
            scaler.step(optimizer)
            optimizer.zero_grad()
            if USE_PRIOR_DYNAMIC:
                for i in range(cfg.num_levels):
                    if optimizer_prior_list[i] is not None:
                        scaler.step(optimizer_prior_list[i])
                        optimizer_prior_list[i].zero_grad()
            # update scaler for next iteration
            scaler.update()
            if not cfg.sched_on_epoch:
                scheduler.step(epoch)
                if USE_PRIOR_DYNAMIC:
                    for i in range(cfg.num_levels):
                        if scheduler_prior_list[i] is not None:
                            scheduler_prior_list[i].step(epoch)

        # Update the confusion matrix using only the finest level. Predictions,
        # targets and mask are flattened so the mask indexes them the same way
        # it does in the CE loss above.
        predits = logits_list[-1].argmax(dim=1).reshape(-1)
        target_lvl = data['y'][..., -1] if hierarchical and data['y'].shape[-1] > 1 else target
        target_lvl = target_lvl.reshape(-1)

        mask = data.get('mask', None)
        if mask is not None:
            mask = mask.reshape(-1)
            predits = predits[mask]
            target_lvl = target_lvl[mask]

        if ignores_label_zero:
            predits = predits + 1  # restore mapping in training: from [0,7] to [1,8]

        cm.update(predits, target_lvl)

        loss_prior_l1_meter.update(loss_prior_l1.item())
        loss_prior_supcon_meter.update(loss_prior_supcon.item())
        loss_main_l1_meter.update(loss_main_l1.item())
        loss_main_ce_meter.update(loss_main_ce.item())
        loss_log = loss_main_l1_meter.val + loss_main_ce_meter.val + loss_prior_l1_meter.val + loss_prior_supcon_meter.val

        # Refresh the progress bar every cfg.print_freq iterations.
        if idx % cfg.print_freq == 0:
            pbar.set_description(f"Train Epoch [{epoch}/{cfg.epochs}] "
                                f"Loss {loss_log:.3f} Acc {cm.overall_accuray:.2f}")
    miou, macc, _, _, _ = cm.all_metrics()
    return loss_prior_l1_meter.avg, loss_prior_supcon_meter.avg, loss_main_l1_meter.avg, loss_main_ce_meter.avg, miou, macc


@torch.no_grad()
def validate(model, val_loader, cfg):
    """Evaluate ``model`` on ``val_loader`` and report IoU metrics per logits level.

    On the first batch, one ``ConfusionMatrix`` is created per logits level,
    sized by ``logits.size(1)``. For Semantic3D / Toronto3D (label 0) and
    ScanNetV2 (label 255), the unlabeled points are excluded from the metrics
    using the same mask that is stored in ``data['mask']``. When
    ``cfg.distributed`` is set, the confusion statistics are all-reduced across
    ranks before the IoUs are computed.

    Args:
        model: segmentation network, called as ``model(data)``.
        val_loader: iterable of dict batches with at least ``'y'``, ``'pos'`` and ``'x'``.
        cfg: experiment config.

    Returns:
        ``(miou_last, miou_list, oa_last, ious_list)``:
        - the mIoU of the last (finest) level;
        - the mIoU of every level;
        - the overall accuracy of the last level;
        - the per-class IoUs of every level.
    """
    torch.cuda.empty_cache()
    # set model to eval mode
    model.eval()
    cm_list = []  # one ConfusionMatrix per logits level, built on the first batch
    dataset_name = cfg.dataset.common.NAME.lower()
    pbar = tqdm(enumerate(val_loader), total=val_loader.__len__())
    for idx, data in pbar:
        # some datasets need to ignore the 'unlabeled' class
        use_mask = False
        if ('semantic3d' in dataset_name) or ('toronto3d' in dataset_name):
            data['mask'] = ~(data['y']==0)
            use_mask = True
        elif 'scannetv2' in dataset_name:
            data['mask'] = ~(data['y']==255)
            use_mask = True
        keys = data.keys() if callable(data.keys) else data.keys
        for key in keys:
            if data[key] is None:
                continue
            if not isinstance(data[key], list):
                data[key] = data[key].cuda(non_blocking=True)
            elif torch.is_tensor(data[key][0]):
                for i in range(len(data[key])):
                    data[key][i] = data[key][i].cuda(non_blocking=True)
        if cfg.model.get('dim_modified', True):
            data['x'] = get_scene_seg_features(cfg.model.in_channels, data['pos'], data['x'])
            if data['pos'].dim() == 2:
                data['pos'] = data['pos'].unsqueeze(0)

        outputs = model(data)

        # New list-based interface first, then legacy output layouts.
        if isinstance(outputs, list):
            logits_list = outputs
        elif isinstance(outputs, tuple) and isinstance(outputs[0], list):
            logits_list = outputs[0]
        else:
            if torch.is_tensor(outputs):
                logits_list = [outputs]
            elif len(outputs)==3:
                logits_list = [outputs[0]]
            elif len(outputs)==6:
                logits_list = [outputs[1], outputs[0]]  # coarse, fine
            else:
                raise ValueError("Unexpected model outputs")

        # build cm_list on the first iteration
        if not cm_list:
            for lg in logits_list:
                cm_list.append(ConfusionMatrix(num_classes=lg.size(1), ignore_index=None))

        # Labels (and the mask derived from them) get a trailing level axis.
        y_all = data['y']
        mask_all = data['mask'] if use_mask else None
        if y_all.dim()==1:
            y_all = y_all.unsqueeze(-1)
            if mask_all is not None and mask_all.dim() == 1:
                mask_all = mask_all.unsqueeze(-1)

        for lvl, lg in enumerate(logits_list):
            # Levels without their own label column fall back to the last one.
            col = lvl if y_all.size(-1) > lvl else -1
            tgt_lvl = y_all[..., col].reshape(-1)
            pred_lvl = lg.argmax(dim=1).reshape(-1)
            if mask_all is not None:
                keep = mask_all[..., col].reshape(-1)
                if keep.numel() == tgt_lvl.numel() == pred_lvl.numel():
                    tgt_lvl = tgt_lvl[keep]
                    pred_lvl = pred_lvl[keep]
            cm_list[lvl].update(pred_lvl, tgt_lvl)

    miou_list = []
    oa_list = []
    ious_list = []  # store per-class IoU for each level

    for cm_ in cm_list:
        tp, union, count = cm_.tp, cm_.union, cm_.count
        if cfg.distributed:
            dist.all_reduce(tp); dist.all_reduce(union); dist.all_reduce(count)
        mi, ma, oa, ious, _ = get_mious(tp, union, count)
        miou_list.append(mi)
        oa_list.append(oa)
        ious_list.append(ious)

    logging.info(f'Validation mIoU per level: {miou_list}')
    for lvl, ious_lvl in enumerate(ious_list):
        logging.info(f'Level {lvl} IoU: {ious_lvl}')
    return miou_list[-1], miou_list, oa_list[-1], ious_list


if __name__ == "__main__":
    import shutil  # copy the config file without going through a shell

    def _str2bool(value):
        """Parse a command-line boolean; plain ``bool('False')`` would be True."""
        return str(value).strip().lower() in ('1', 'true', 'yes', 'on')

    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    # nargs='?' + const=True: a bare `--debug` enables it, and `--debug False` works.
    parser.add_argument('--debug', type=_str2bool, nargs='?', const=True, default=False, help='setting debug mode to control not create tensorboard')
    parser.add_argument('--cfg', type=str, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    args, opts = parser.parse_known_args()
    if args.cfg is None:
        parser.error('--cfg is required')
    # NOTE(review): unrecognised options collected in `opts` are not applied to cfg here.

    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)
    cfg.update(vars(args))    # command-line flags (cfg/debug/profile) overwrite the yml values

    # NOTE: for 'test'/'val'/'resume' runs, set cfg.mode (and cfg.pretrained_path /
    # cfg.prior_model_path) in the yml config.
    # Keep a seed set in the config (reproducibility); draw a random one only if unset.
    if cfg.get('seed', None) is None:
        cfg.seed = np.random.randint(1, 10000)
    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1

    # init log dir: e.g. cfgs/s3dis/pointnext-xl.yaml -> task 's3dis', basename 'pointnext-xl'
    cfg_dir, cfg_file = os.path.split(args.cfg)
    cfg.task_name = os.path.basename(cfg_dir)
    cfg.cfg_basename = os.path.splitext(cfg_file)[0]
    tags = [
        cfg.task_name,  # task name (the folder name under ./cfgs)
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
        f'seed{cfg.seed}',
    ]
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)

    cfg.is_training = cfg.mode in ['train', 'training', 'finetune', 'finetuning']
    if cfg.mode == 'train':
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
    else:  # resume from the existing ckpt and reuse the folder.
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    if not cfg.debug:
        with open(cfg_path, 'w') as f:
            yaml.dump(cfg, f, indent=2)
        try:
            shutil.copy(args.cfg, cfg.run_dir)
        except shutil.SameFileError:
            pass  # the config already lives in run_dir (e.g. when resuming)
    cfg.cfg_path = cfg_path

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        logging.info('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
