"""
(Distributed) training script for scene segmentation
This file currently supports training and testing on S3DIS
If more than 1 GPU is provided, will launch multi processing distributed training by default
if you only wana use 1 GPU, set `CUDA_VISIBLE_DEVICES` accordingly
"""
import __init__
import argparse, yaml, os, logging, numpy as np, csv, wandb, glob
from tqdm import tqdm
import torch, torch.nn as nn
from torch import distributed as dist, multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from torch_scatter import scatter
from openpoints_student.utils import set_random_seed, save_checkpoint, load_checkpoint, resume_checkpoint, setup_logger_dist, \
    cal_model_parm_nums, Wandb, generate_exp_directory, resume_exp_directory, EasyConfig, dist_utils, find_free_port, load_checkpoint_inv
from openpoints_student.utils import AverageMeter, ConfusionMatrix, get_mious
from openpoints_student.dataset import build_dataloader_from_cfg, get_features_by_keys, get_class_weights
from openpoints_student.dataset.data_util import voxelize
from openpoints_student.dataset.semantic_kitti.semantickitti import load_label_kitti, load_pc_kitti, remap_lut_read, remap_lut_write, get_semantickitti_file_list
from openpoints_student.transforms import build_transforms_from_cfg
from openpoints_student.optim import build_optimizer_from_cfg
from openpoints_student.scheduler import build_scheduler_from_cfg
from openpoints_student.loss import build_criterion_from_cfg
from openpoints_student.models_student import build_model_student_from_cfg
from openpoints.models import build_model_from_cfg
from robust_overfiting import Robust_Overfitting
import warnings

# Ignore every FutureWarning process-wide so the training logs stay readable.
warnings.simplefilter(action='ignore', category=FutureWarning)
# Uncomment to pin this script to a single GPU (or export CUDA_VISIBLE_DEVICES in the shell).
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def accuracy(outputs, targets, topk=(1, )):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        outputs: scores of shape (num_samples, num_classes); classes on dim 1.
        targets: ground-truth class indices, one per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the fraction (0..1,
        not a percentage) of samples whose target is among the k best scores.
    """
    with torch.no_grad():
        num_samples = targets.size(0)
        # (k_max, num_samples): row j holds every sample's j-th best class.
        ranked = outputs.topk(max(topk), 1, True, True)[1].t()
        hits = ranked.eq(targets.reshape(1, -1).expand_as(ranked))
        return [hits[:k].reshape(-1).float().sum(dim=0, keepdim=True).mul_(1 / num_samples)
                for k in topk]

def get_weight_dict_from_teacher(teacher_model):
    """Collect the teacher's conv-related tensors keyed by their running position.

    Every ``state_dict`` entry whose name contains ``'conv'`` is kept, in
    ``state_dict`` order, and keyed 0, 1, 2, ...

    Returns:
        dict mapping consecutive ints to the selected tensors.
    """
    conv_tensors = (tensor for name, tensor in teacher_model.state_dict().items() if 'conv' in name)
    return dict(enumerate(conv_tensors))
class search_loss(nn.Module):
    """Combine the student's task loss, a distillation term and a cost term.

    With ``rkd_loss`` given, ``forward`` computes::

        network_loss = network_loss_func(output, target) + gamma * rkd_loss + kl
        latency_loss = log(max(cost) ** beta)
        loss         = alpha * network_loss * latency_loss

    and returns ``(loss, ce_loss, rkd_loss, kl, latency_loss)``. With
    ``rkd_loss`` set to None, ``network_loss`` is just the task loss and the
    return value is ``(loss, network_loss, latency_loss)``.

    ``kl`` (an output-level KL term against ``teacher_output``) is currently
    disabled and always 0.0; ``teacher_output`` is therefore unused.
    """

    def __init__(self, alpha, beta, network_loss_func, gamma):
        super(search_loss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.network_loss_func = network_loss_func

    def forward(self, output, teacher_output, cost, target, rkd_loss):
        kl = 0.0
        has_rkd = rkd_loss is not None
        ce_loss = self.network_loss_func(output, target)
        network_loss = ce_loss + self.gamma * rkd_loss + kl if has_rkd else ce_loss
        latency_loss = torch.log(torch.max(cost) ** self.beta)
        loss = self.alpha * network_loss * latency_loss
        if has_rkd:
            return loss, ce_loss, rkd_loss, kl, latency_loss
        return loss, network_loss, latency_loss

def write_to_csv(oa, macc, miou, ious, best_epoch, cfg_s, write_header=True, area=5):
    """Append one result row (optionally preceded by a header) to ``cfg_s.csv_path``.

    Metrics and per-class IoUs are written with two decimals. The row also
    records ``cfg_s.cfg_basename``, the area, ``best_epoch``, ``cfg_s.run_dir``
    and the wandb run URL (or '-' when wandb is disabled).
    """
    formatted_ious = [f'{value:.2f}' for value in ious]
    header = ['method', 'Area', 'OA', 'mACC', 'mIoU'] + cfg_s.classes + ['best_epoch', 'log_path', 'wandb link']
    row = ([cfg_s.cfg_basename, str(area), f'{oa:.2f}', f'{macc:.2f}', f'{miou:.2f}']
           + formatted_ious
           + [str(best_epoch), cfg_s.run_dir, wandb.run.get_url() if cfg_s.wandb.use_wandb else '-'])
    with open(cfg_s.csv_path, 'a', encoding='UTF8', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        if write_header:
            csv_writer.writerow(header)
        csv_writer.writerow(row)


def generate_data_list(cfg_s):
    """Return the list of test-time samples for the configured dataset.

    - S3DIS: full paths of the entries of ``<data_root>/raw`` whose name
      contains ``Area_<test_area>``, in sorted order.
    - ScanNet: the ``*.pth`` files found in ``<data_root>/<dataset.test.split>``.
    - SemanticKITTI: element 1 (split 'val') or 2 (any other split) of
      ``get_semantickitti_file_list(<data_root>/sequences, str(test_id + 11))``.

    Raises:
        Exception: if ``dataset.common.NAME`` matches none of the above.
    """
    dataset_name = cfg_s.dataset.common.NAME.lower()
    if 's3dis' in dataset_name:
        raw_root = os.path.join(cfg_s.dataset.common.data_root, 'raw')
        area_tag = 'Area_{}'.format(cfg_s.dataset.common.test_area)
        return [os.path.join(raw_root, item) for item in sorted(os.listdir(raw_root)) if area_tag in item]
    if 'scannet' in dataset_name:
        return glob.glob(os.path.join(cfg_s.dataset.common.data_root, cfg_s.dataset.test.split, "*.pth"))
    if 'semantickitti' in dataset_name:
        split_no = 1 if cfg_s.dataset.test.split == 'val' else 2
        return get_semantickitti_file_list(os.path.join(cfg_s.dataset.common.data_root, 'sequences'),
                                           str(cfg_s.dataset.test.test_id + 11))[split_no]
    raise Exception(f'dataset {cfg_s.dataset.common.NAME} not supported yet')


def load_data(data_path, cfg_s):
    """Load one test cloud and build the point-index subsets used for inference.

    Args:
        data_path: path of the sample; for SemanticKITTI a pair
            ``(point_cloud_path, label_path)``.
        cfg_s: config; reads ``dataset.common.NAME``, ``dataset.common.voxel_size``,
            ``dataset.test.split`` and (when voxelizing) ``test_mode``.

    Returns:
        coord: point coordinates shifted so each axis minimum is 0.
        feat: features rescaled/clipped to [0, 1] as float32 (S3DIS, ScanNet),
            otherwise None.
        label: per-point labels, or None for unlabeled ('test') splits.
        idx_points: list of index arrays, each selecting a subset of the points
            (all points when no voxel size is configured).
        voxel_idx, reverse_idx_part, reverse_idx_sort: voxelization bookkeeping;
            None when not computed.

    Raises:
        Exception: if ``dataset.common.NAME`` is not a supported dataset.
    """
    label, feat = None, None
    dataset_name = cfg_s.dataset.common.NAME.lower()
    if 's3dis' in dataset_name:
        data = np.load(data_path)  # xyzrgbl, N*7
        coord, feat, label = data[:, :3], data[:, 3:6], data[:, 6]
        feat = np.clip(feat / 255., 0, 1).astype(np.float32)
    elif 'scannet' in dataset_name:
        data = torch.load(data_path)  # indexable as (coord, feat[, label])
        coord, feat = data[0], data[1]
        if cfg_s.dataset.test.split != 'test':
            label = data[2]
        feat = np.clip((feat + 1) / 2., 0, 1).astype(np.float32)
    elif 'semantickitti' in dataset_name:
        coord = load_pc_kitti(data_path[0])
        if cfg_s.dataset.test.split != 'test':
            label = load_label_kitti(data_path[1], remap_lut_read)
    else:
        raise Exception(f'dataset {cfg_s.dataset.common.NAME} not supported yet')
    coord -= coord.min(0)

    idx_points = []
    voxel_idx, reverse_idx_part, reverse_idx_sort = None, None, None
    voxel_size = cfg_s.dataset.common.get('voxel_size', None)

    if voxel_size is not None:
        # idx_sort: original point indicies sorted by voxel NO.
        # voxel_idx: Voxel NO. for the sorted points
        idx_sort, voxel_idx, count = voxelize(coord, voxel_size, mode=1)
        if cfg_s.get('test_mode', 'multi_voxel') == 'nearest_neighbor':
            idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count
            idx_part = idx_sort[idx_select]
            npoints_subcloud = voxel_idx.max() + 1
            idx_shuffle = np.random.permutation(npoints_subcloud)
            idx_part = idx_part[idx_shuffle]  # idx_part: randomly sampled points of a voxel
            reverse_idx_part = np.argsort(idx_shuffle, axis=0)  # reverses idx_part to sorted order
            idx_points.append(idx_part)
            reverse_idx_sort = np.argsort(idx_sort, axis=0)
        else:
            for i in range(count.max()):
                idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
                idx_part = idx_sort[idx_select]
                np.random.shuffle(idx_part)
                idx_points.append(idx_part)
    else:
        # Size by the points, not the labels: label is None for unlabeled splits.
        idx_points.append(np.arange(coord.shape[0]))
    return coord, feat, label, idx_points, voxel_idx, reverse_idx_part, reverse_idx_sort


def _run_validation(model, val_loader, cfg_s, teacher_model, temperature, **kwargs):
    """Validate ``model`` with the routine matching the configured dataset.

    ``validate`` takes the teacher (the student consumes its intermediate
    outputs) and the distillation temperature; ``validate_sphere`` takes
    neither. Routing every call through here keeps the two signatures from
    being mixed up. Extra keyword arguments (``num_votes``, ``epoch``,
    ``total_iter``) are forwarded unchanged; callers unpack the result as
    ``(miou, macc, oa, ious, accs)``.
    """
    if 'sphere' in cfg_s.dataset.common.NAME.lower():
        return validate_sphere(model, val_loader, cfg_s, **kwargs)
    return validate(model, val_loader, cfg_s, teacher_model, temperature, **kwargs)


def main(gpu, cfg, cfg_s):
    """Run one (possibly distributed) training / evaluation process.

    Args:
        gpu: local GPU index; used as this process's rank when ``cfg_s.mp`` is set.
        cfg: teacher configuration (model definition, ``sync_bn`` and
            ``load_path`` of the pretrained teacher weights).
        cfg_s: student configuration; drives data, optimisation, logging and
            the run mode (``cfg_s.mode``: 'resume', 'val', 'test', encoder
            finetuning variants, or plain training).

    Returns:
        The validation mIoU in 'val' mode, the test mIoU in 'test' mode,
        otherwise None.
    """
    if cfg_s.distributed:
        if cfg_s.mp:
            cfg_s.rank = gpu
        dist.init_process_group(backend=cfg_s.dist_backend,
                                init_method=cfg_s.dist_url,
                                world_size=cfg_s.world_size,
                                rank=cfg_s.rank)
        dist.barrier()

    # logger
    setup_logger_dist(cfg_s.log_path, cfg_s.rank, name=cfg_s.dataset.common.NAME)
    if cfg_s.rank == 0:
        Wandb.launch(cfg_s, cfg_s.wandb.use_wandb)
        writer = SummaryWriter(log_dir=cfg_s.run_dir) if cfg_s.is_training else None
    else:
        writer = None
    set_random_seed(cfg_s.seed + cfg_s.rank, deterministic=cfg_s.deterministic)
    torch.backends.cudnn.enabled = True
    logging.info(cfg_s)

    if cfg.model.get('in_channels', None) is None:
        cfg.model.in_channels = cfg.model.encoder_args.in_channels
    if cfg_s.model.get('in_channels', None) is None:
        cfg_s.model.in_channels = cfg_s.model.encoder_args.in_channels
    model = build_model_student_from_cfg(cfg_s.model).to(cfg_s.rank)
    # The teacher's intermediate outputs are fed straight into the student, so
    # both must share a device. Only cfg_s.rank is set from ``gpu`` above.
    teacher_model = build_model_from_cfg(cfg.model).to(cfg_s.rank)
    model_size = cal_model_parm_nums(model)
    logging.info(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
    train_loss_func = nn.CrossEntropyLoss(reduction='mean')

    if cfg_s.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logging.info('Using Synchronized BatchNorm ...')
    if cfg.sync_bn:
        teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)
        logging.info('teacher_model Using Synchronized BatchNorm ...')
    if cfg_s.distributed:
        torch.cuda.set_device(gpu)
        model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[cfg_s.rank], output_device=cfg_s.rank)
        logging.info('Using Distributed Data parallel ...')

    # optimizer & scheduler
    optimizer = build_optimizer_from_cfg(model, lr=cfg_s.lr, **cfg_s.optimizer)
    scheduler = build_scheduler_from_cfg(cfg_s, optimizer)

    # build dataset
    val_loader = build_dataloader_from_cfg(cfg_s.get('val_batch_size', cfg_s.batch_size),
                                           cfg_s.dataset,
                                           cfg_s.dataloader,
                                           datatransforms_cfg=cfg_s.datatransforms,
                                           split='val',
                                           distributed=cfg_s.distributed
                                           )
    logging.info(f"length of validation dataset: {len(val_loader.dataset)}")
    num_classes = val_loader.dataset.num_classes if hasattr(val_loader.dataset, 'num_classes') else None
    if num_classes is not None:
        assert cfg_s.num_classes == num_classes
    logging.info(f"number of classes of the dataset: {num_classes}")
    cfg_s.classes = val_loader.dataset.classes if hasattr(val_loader.dataset, 'classes') else np.arange(num_classes)
    cfg_s.cmap = np.array(val_loader.dataset.cmap) if hasattr(val_loader.dataset, 'cmap') else None

    # Load the pretrained teacher; cfg.load_path is resolved against the
    # directory two levels above this script's folder.
    path = os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))
    model_path = os.path.join(path, cfg.load_path)
    load_checkpoint(teacher_model, pretrained_path=model_path)
    # The teacher is frozen: run it in inference mode so normalisation
    # statistics are not updated and dropout (if any) is disabled.
    teacher_model.eval()
    # Initial distillation temperature; decayed once per training epoch below.
    temperature = 5.0

    # optionally resume from a checkpoint
    model_module = model.module if hasattr(model, 'module') else model
    if cfg_s.pretrained_path is not None:
        if cfg_s.mode == 'resume':
            resume_checkpoint(cfg_s, model, optimizer, scheduler, pretrained_path=cfg_s.pretrained_path)
        else:
            if cfg_s.mode == 'val':
                best_epoch, best_val = load_checkpoint(model, pretrained_path=cfg_s.pretrained_path)
                val_miou, val_macc, val_oa, val_ious, val_accs = _run_validation(
                    model, val_loader, cfg_s, teacher_model, temperature, num_votes=1)
                with np.printoptions(precision=2, suppress=True):
                    logging.info(
                        f'Best ckpt @E{best_epoch},  val_oa , val_macc, val_miou: {val_oa:.2f} {val_macc:.2f} {val_miou:.2f}, '
                        f'\niou per cls is: {val_ious}')
                return val_miou
            elif cfg_s.mode == 'test':
                best_epoch, best_val = load_checkpoint(model, pretrained_path=cfg_s.pretrained_path)
                data_list = generate_data_list(cfg_s)
                logging.info(f"length of test dataset: {len(data_list)}")
                test_miou, test_macc, test_oa, test_ious, test_accs, _ = test(model, data_list, cfg_s)

                if test_miou is not None:
                    with np.printoptions(precision=2, suppress=True):
                        logging.info(
                            f'Best ckpt @E{best_epoch},  test_oa , test_macc, test_miou: {test_oa:.2f} {test_macc:.2f} {test_miou:.2f}, '
                            f'\niou per cls is: {test_ious}')
                    cfg_s.csv_path = os.path.join(cfg_s.run_dir, cfg_s.run_name + '_test.csv')
                    write_to_csv(test_oa, test_macc, test_miou, test_ious, best_epoch, cfg_s)
                return test_miou

            elif 'encoder' in cfg_s.mode:
                logging.info(f'Finetuning from {cfg_s.pretrained_path}')
                # model_module unwraps DDP, which has no ``encoder`` attribute itself.
                if 'inv' in cfg_s.mode:
                    load_checkpoint_inv(model_module.encoder, cfg_s.pretrained_path)
                else:
                    load_checkpoint(model_module.encoder, cfg_s.pretrained_path, cfg_s.get('pretrained_module', None))

            else:
                logging.info(f'Finetuning from {cfg_s.pretrained_path}')
                load_checkpoint(model, cfg_s.pretrained_path, cfg_s.get('pretrained_module', None))
    else:
        logging.info('Training from scratch')

    if 'freeze_blocks' in cfg_s.mode:
        for p in model_module.encoder.blocks.parameters():
            p.requires_grad = False

    train_loader = build_dataloader_from_cfg(cfg_s.batch_size,
                                             cfg_s.dataset,
                                             cfg_s.dataloader,
                                             datatransforms_cfg=cfg_s.datatransforms,
                                             split='train',
                                             distributed=cfg_s.distributed,
                                             )
    logging.info(f"length of training dataset: {len(train_loader.dataset)}")

    cfg_s.criterion_args.weight = None
    if cfg_s.get('cls_weighed_loss', False):
        if hasattr(train_loader.dataset, 'num_per_class'):
            cfg_s.criterion_args.weight = get_class_weights(train_loader.dataset.num_per_class, normalize=True)
        else:
            logging.info('`num_per_class` attribute is not founded in dataset')
    criterion = build_criterion_from_cfg(cfg_s.criterion_args).cuda()

    # ===> start training
    scaler = torch.cuda.amp.GradScaler() if cfg_s.use_amp else None

    val_miou, val_macc, val_oa, val_ious, val_accs = 0., 0., 0., [], []
    best_val, macc_when_best, oa_when_best, ious_when_best, best_epoch = 0., 0., 0., [], 0
    total_iter = 0
    for epoch in range(cfg_s.start_epoch, cfg_s.epochs + 1):
        if cfg_s.distributed:
            train_loader.sampler.set_epoch(epoch)
        if hasattr(train_loader.dataset, 'epoch'):  # some dataset sets the dataset length as a fixed steps.
            train_loader.dataset.epoch = epoch - 1
        (train_loss, train_miou, train_macc, train_oa, _, _, total_iter, cost_meter, network_loss_meter,
         rkd_loss_meter, kl_loss_meter, acc1_meter, acc5_meter) = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler, epoch, total_iter, cfg_s,
            teacher_model, train_loss_func, temperature, writer)
        temperature = temperature * np.exp(-0.045)
        is_best = False
        if epoch % cfg_s.val_freq == 0:
            val_miou, val_macc, val_oa, val_ious, val_accs = _run_validation(
                model, val_loader, cfg_s, teacher_model, temperature, epoch=epoch, total_iter=total_iter)
            if val_miou > best_val:
                is_best = True
                best_val = val_miou
                macc_when_best = val_macc
                oa_when_best = val_oa
                ious_when_best = val_ious
                best_epoch = epoch
                with np.printoptions(precision=2, suppress=True):
                    logging.info(
                        f'Find a better ckpt @E{epoch}, val_miou {val_miou:.2f} val_macc {macc_when_best:.2f}, val_oa {oa_when_best:.2f}'
                        f'\nmious: {val_ious}')

        lr = optimizer.param_groups[0]['lr']
        logging.info(f'Epoch {epoch} LR {lr:.6f} '
                     f'train_miou {train_miou:.2f}, val_miou {val_miou:.2f}, best val miou {best_val:.2f}, network_loss {network_loss_meter:.4f}, acc@1 {acc1_meter:.4f}, acc@5 {acc5_meter:.4f}')
        if writer is not None:
            writer.add_scalar('best_val', best_val, epoch)
            writer.add_scalar('val_miou', val_miou, epoch)
            writer.add_scalar('macc_when_best', macc_when_best, epoch)
            writer.add_scalar('oa_when_best', oa_when_best, epoch)
            writer.add_scalar('val_macc', val_macc, epoch)
            writer.add_scalar('val_oa', val_oa, epoch)
            writer.add_scalar('train_loss', train_loss, epoch)
            writer.add_scalar('train_miou', train_miou, epoch)
            writer.add_scalar('train_macc', train_macc, epoch)
            writer.add_scalar('lr', lr, epoch)
            writer.add_scalar('cost', cost_meter, epoch)
            writer.add_scalar('network_loss', network_loss_meter, epoch)
            writer.add_scalar('intermediate-kl', rkd_loss_meter, epoch)
            writer.add_scalar('kl', kl_loss_meter, epoch)
            writer.add_scalar('acc@1', acc1_meter, epoch)
            writer.add_scalar('acc@5', acc5_meter, epoch)

        if cfg_s.sched_on_epoch:
            scheduler.step(epoch)
        if cfg_s.rank == 0:
            save_checkpoint(cfg_s, model, epoch, optimizer, scheduler,
                            additioanl_dict={'best_val': best_val},
                            is_best=is_best
                            )
            is_best = False
    # Checkpoints are not uploaded to wandb, to save wandb storage space.

    # validate
    with np.printoptions(precision=2, suppress=True):
        logging.info(
            f'Best ckpt @E{best_epoch},  val_oa {oa_when_best:.2f}, val_macc {macc_when_best:.2f}, val_miou {best_val:.2f}, '
            f'\niou per cls is: {ious_when_best}')

    if cfg_s.world_size < 2:  # do not support multi gpu testing
        # test
        best_ckpt_path = os.path.join(cfg_s.ckpt_dir, f'{cfg_s.run_name}_ckpt_best.pth')
        load_checkpoint(model, pretrained_path=best_ckpt_path)
        cfg_s.csv_path = os.path.join(cfg_s.run_dir, cfg_s.run_name + '.csv')
        if 'sphere' in cfg_s.dataset.common.NAME.lower():
            # TODO:
            test_miou, test_macc, test_oa, test_ious, test_accs = validate_sphere(model, val_loader, cfg_s, epoch=epoch)
        else:
            data_list = generate_data_list(cfg_s)
            test_miou, test_macc, test_oa, test_ious, test_accs, _ = test(model, data_list, cfg_s)
        with np.printoptions(precision=2, suppress=True):
            logging.info(
                f'Best ckpt @E{best_epoch},  test_oa {test_oa:.2f}, test_macc {test_macc:.2f}, test_miou {test_miou:.2f}, '
                f'\niou per cls is: {test_ious}')
        if writer is not None:
            writer.add_scalar('test_miou', test_miou, epoch)
            writer.add_scalar('test_macc', test_macc, epoch)
            writer.add_scalar('test_oa', test_oa, epoch)
        write_to_csv(test_oa, test_macc, test_miou, test_ious, best_epoch, cfg_s, write_header=True)
        logging.info(f'save results in {cfg_s.csv_path}')
        if cfg_s.use_voting:
            load_checkpoint(model, pretrained_path=best_ckpt_path)
            set_random_seed(cfg_s.seed)
            val_miou, val_macc, val_oa, val_ious, val_accs = _run_validation(
                model, val_loader, cfg_s, teacher_model, temperature, num_votes=20, epoch=epoch)
            if writer is not None:
                writer.add_scalar('val_miou20', val_miou, cfg_s.epochs + 50)

            ious_table = [f'{item:.2f}' for item in val_ious]
            data = [cfg_s.cfg_basename, 'True', f'{val_oa:.2f}', f'{val_macc:.2f}', f'{val_miou:.2f}'] + ious_table + [
                str(best_epoch), cfg_s.run_dir]
            # Append (as write_to_csv does) so the test results above are kept, and use
            # a dedicated name so the TensorBoard ``writer`` closed below is not shadowed.
            with open(cfg_s.csv_path, 'a', encoding='UTF8', newline='') as f:
                csv.writer(f).writerow(data)
    else:
        logging.warning('Testing using multiple GPUs is not allowed for now. Running testing after this training is required.')
    if writer is not None:
        writer.close()
    # dist.destroy_process_group() # comment this line due to https://github.com/guochengqian/PointNeXt/issues/95
    # Default exit code: a non-zero value (True == 1) would mark the run as failed.
    wandb.finish()


def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, total_iter, cfg_s,
                    teacher_model, train_loss_func, temperature, writer):
    """Run one epoch of teacher-guided training of the student ``model``.

    For each batch the teacher runs without gradients; its second output
    (``output_layer``) is passed to the student together with ``temperature``.
    The student's logits, accumulated cost and KL terms are combined by
    ``search_loss(alpha=0.2, beta=0.5, gamma=1.0)`` and back-propagated. An
    optimizer step is taken every ``cfg_s.step_per_update`` iterations.

    Args:
        criterion: unused; the loss is built from ``train_loss_func``.
        total_iter: global iteration counter, incremented per batch and returned.
        writer: TensorBoard writer, or None (e.g. on non-zero ranks).

    Returns:
        (loss_avg, miou, macc, oa, ious, accs, total_iter, cost_avg,
         network_loss_avg, rkd_loss_avg, kl_loss_avg, acc1_avg, acc5_avg)
    """
    loss_meter = AverageMeter()
    network_loss_meter = AverageMeter()
    rkd_loss_meter = AverageMeter()
    cost_meter = AverageMeter()
    kl_loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    effective_channel_meter = AverageMeter()
    cm = ConfusionMatrix(num_classes=cfg_s.num_classes, ignore_index=cfg_s.ignore_index)
    # Loss hyper-parameters are fixed, so build the loss module once.
    search_loss_func = search_loss(0.2, 0.5, train_loss_func, 1.0)
    model.train()  # set model to training mode
    pbar = tqdm(enumerate(train_loader), total=train_loader.__len__())
    num_iter = 0
    for idx, data in pbar:
        keys = data.keys() if callable(data.keys) else data.keys
        for key in keys:
            data[key] = data[key].cuda(non_blocking=True)
        num_iter += 1
        target = data['y'].squeeze(-1)
        data['x'] = get_features_by_keys(data, cfg_s.feature_keys)
        data['epoch'] = epoch
        total_iter += 1
        data['iter'] = total_iter

        # Teacher forward without gradients; its intermediate outputs guide the student.
        with torch.no_grad():
            teacher_output, output_layer = teacher_model(data)
        with torch.cuda.amp.autocast(enabled=cfg_s.use_amp):
            logits, cost_accumulate, kl_accumulate, effective_channel = model(data, output_layer, temperature)
            loss, network_loss, rkd_loss, kl_loss, latency_loss = search_loss_func(
                logits, teacher_output, cost_accumulate, target, kl_accumulate)
            loss, network_loss, rkd_loss, latency_loss = \
                loss.mean(), network_loss.mean(), rkd_loss.mean(), latency_loss.mean()

        if cfg_s.use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # optimize
        if num_iter == cfg_s.step_per_update:
            if cfg_s.get('grad_norm_clip') is not None and cfg_s.grad_norm_clip > 0.:
                if cfg_s.use_amp:
                    # Gradients still carry the loss scale; unscale them so the
                    # clipping threshold applies to the true gradient norm.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg_s.grad_norm_clip, norm_type=2)
            num_iter = 0

            if cfg_s.use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            optimizer.zero_grad()
            if not cfg_s.sched_on_epoch:
                scheduler.step(epoch)

        # Classes are on dim 1 (see the argmax(dim=1) below). Move them last and
        # flatten so row i of flat_logits lines up with element i of target.view(-1);
        # a plain view() would interleave samples of different batch items.
        flat_logits = logits.movedim(1, -1).reshape(-1, logits.size(1))
        acc1, acc5 = accuracy(flat_logits, target.reshape(-1), topk=(1, 5))
        num = flat_logits.size(0)

        loss = loss.item()
        effective_channel_meter.update(effective_channel.item())
        if writer is not None:
            # total_iter keeps increasing across epochs, unlike idx.
            writer.add_scalar('effective_channel', effective_channel_meter.avg, total_iter)
        rkd_loss_meter.update(rkd_loss.item(), num)
        kl_loss_meter.update(kl_loss, num)
        cost_meter.update(latency_loss.item(), num)
        network_loss_meter.update(network_loss.item(), num)
        acc1_meter.update(acc1.item(), num)
        acc5_meter.update(acc5.item(), num)

        # update confusion matrix
        cm.update(logits.argmax(dim=1), target)
        loss_meter.update(loss)

        if idx % cfg_s.print_freq == 0:
            pbar.set_description(f"Train Epoch [{epoch}/{cfg_s.epochs}] "
                                 f"Loss {loss_meter.val:.3f} Acc {cm.overall_accuray:.2f} "
                                 f'cost {cost_meter.val:.4f} ({cost_meter.avg:.4f}) '
                                 f'network_loss {network_loss_meter.val:.4f} ({network_loss_meter.avg:.4f}) '
                                 f'intermediate-kl {rkd_loss_meter.val:.8f} ({rkd_loss_meter.avg:.8f}) '
                                 f'rkd_loss {0.2*rkd_loss_meter.val:.4f} ({0.2*rkd_loss_meter.avg:.4f}) '
                                 f'kl {kl_loss_meter.val:.4f} ({kl_loss_meter.avg:.4f}) '
                                 f'acc@1 {acc1_meter.val:.4f} ({acc1_meter.avg:.4f}) '
                                 f'acc@5 {acc5_meter.val:.4f} ({acc5_meter.avg:.4f})'
                                 )

    miou, macc, oa, ious, accs = cm.all_metrics()
    return (loss_meter.avg, miou, macc, oa, ious, accs, total_iter, cost_meter.avg, network_loss_meter.avg,
            rkd_loss_meter.avg, kl_loss_meter.avg, acc1_meter.avg, acc5_meter.avg)


@torch.no_grad()
def validate(model, val_loader, cfg_s, teacher_model, temperature, num_votes=1, data_transform=None, epoch=-1, total_iter=-1):
    """Evaluate the student on ``val_loader`` and return segmentation metrics.

    The teacher is run on every batch because the student consumes its
    intermediate outputs (``output_layer``) together with ``temperature``.
    Points are masked with ``data['mask']`` when the criterion name contains
    'mask' (unless ``cfg_s.use_maks`` is set).

    Args:
        num_votes, data_transform: accepted for interface compatibility; unused.

    Returns:
        (miou, macc, oa, ious, accs) from ``get_mious``, with the confusion
        counts all-reduced across processes when distributed.
    """
    model.eval()  # set model to eval mode
    # The frozen teacher must also run in inference mode: no normalisation
    # statistic updates from validation data, no dropout.
    teacher_model.eval()
    cm = ConfusionMatrix(num_classes=cfg_s.num_classes, ignore_index=cfg_s.ignore_index)
    use_all_points = 'mask' not in cfg_s.criterion_args.NAME or cfg_s.get('use_maks', False)
    pbar = tqdm(enumerate(val_loader), total=val_loader.__len__(), desc='Val')
    for _, data in pbar:
        keys = data.keys() if callable(data.keys) else data.keys
        for key in keys:
            data[key] = data[key].cuda(non_blocking=True)
        target = data['y'].squeeze(-1)
        data['x'] = get_features_by_keys(data, cfg_s.feature_keys)
        data['epoch'] = epoch
        data['iter'] = total_iter
        _, output_layer = teacher_model(data)
        logits, _, _, _ = model(data, output_layer, temperature)
        if use_all_points:
            cm.update(logits.argmax(dim=1), target)
        else:
            mask = data['mask'].bool()
            cm.update(logits.argmax(dim=1)[mask], target[mask])

    tp, union, count = cm.tp, cm.union, cm.count
    if cfg_s.distributed:
        dist.all_reduce(tp)
        dist.all_reduce(union)
        dist.all_reduce(count)
    miou, macc, oa, ious, accs = get_mious(tp, union, count)
    return miou, macc, oa, ious, accs


def _all_gather_rows(tensor):
    """Concatenate ``tensor`` from every rank along dim 0 (rank order preserved).

    ``dist.all_gather`` requires identical shapes on all ranks, so each rank first
    shares its row count, pads its tensor to the largest count, gathers, and the
    padding rows are dropped again before concatenation.
    """
    world_size = dist.get_world_size()
    local_rows = torch.tensor([tensor.shape[0]], dtype=torch.long, device=tensor.device)
    rows_per_rank = [torch.zeros_like(local_rows) for _ in range(world_size)]
    dist.all_gather(rows_per_rank, local_rows)
    rows_per_rank = [int(rows.item()) for rows in rows_per_rank]

    padded = tensor.new_zeros((max(rows_per_rank), *tensor.shape[1:]))
    padded[:tensor.shape[0]] = tensor
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    return torch.cat([chunk[:rows] for chunk, rows in zip(gathered, rows_per_rank)], dim=0)


@torch.no_grad()
def validate_sphere(model, val_loader, cfg_s, num_votes=1, data_transform=None, epoch=-1, total_iter=-1):
    """
    validation for sphere sampled input points with mask.
    in this case, between different batches, there are overlapped points.
    thus, one point can be evaluated multiple times.
    Logits of a point seen several times are averaged (via ``input_inds``) before the argmax,
    then mapped back to the original cloud through ``val_loader.dataset.projections[0]``.

    Args:
        model: network called as ``model(data)``; its output is reshaped from
            (batch, num_classes, points) to (-1, num_classes) below.
        val_loader: loader whose batches carry ``input_inds`` and whose dataset exposes
            ``clouds_points_labels``, ``projections`` (and, for visualisation,
            ``clouds_points``, ``clouds_points_colors``, ``clouds_rooms``).
        cfg_s: config; reads num_classes, ignore_index, feature_keys, distributed and,
            when ``visualize`` is set, run_dir, cmap, dataset.common.NAME, cfg_basename.
        num_votes, data_transform: unused; kept for interface compatibility.
        epoch, total_iter: forwarded to the model via ``data['epoch']`` / ``data['iter']``.

    Returns:
        (miou, macc, oa, ious, accs) as returned by ``ConfusionMatrix.all_metrics``.
    """
    model.eval()  # set model to eval mode
    cm = ConfusionMatrix(num_classes=cfg_s.num_classes, ignore_index=cfg_s.ignore_index)
    visualize = cfg_s.get('visualize', False)
    if visualize:
        from openpoints.dataset.vis3d import write_obj
        cfg_s.vis_dir = os.path.join(cfg_s.run_dir, 'visualization')
        os.makedirs(cfg_s.vis_dir, exist_ok=True)
        # Normalise a local copy: rescaling cfg_s.cmap in place would divide it by 255
        # again on every validation call.
        cmap = np.asarray(cfg_s.cmap).astype(np.float32) / 255.

    pbar = tqdm(enumerate(val_loader), total=val_loader.__len__())
    all_logits, idx_points = [], []
    for idx, data in pbar:
        for key in data.keys():
            data[key] = data[key].cuda(non_blocking=True)
        data['x'] = get_features_by_keys(data, cfg_s.feature_keys)
        data['epoch'] = epoch
        data['iter'] = total_iter
        logits = model(data)
        all_logits.append(logits)
        idx_points.append(data['input_inds'])
    # rows of all_logits line up with the flattened idx_points entries
    all_logits = torch.cat(all_logits, dim=0).transpose(1, 2).reshape(-1, cfg_s.num_classes)
    idx_points = torch.cat(idx_points, dim=0).flatten()

    if cfg_s.distributed:
        # Collect every rank's predictions and indices. Summing them with all_reduce
        # would also add up the point indices and scatter logits to wrong points.
        all_logits = _all_gather_rows(all_logits)
        idx_points = _all_gather_rows(idx_points)

    # average overlapped predictions to subsampled points
    all_logits = scatter(all_logits, idx_points, dim=0, reduce='mean')

    # now, project the original points to the subsampled points
    # these two targets would be very similar but not the same
    # val_points_targets = all_targets[val_points_projections]
    # torch.allclose(val_points_labels, val_points_targets)
    all_logits = all_logits.argmax(dim=1)
    val_points_labels = torch.from_numpy(val_loader.dataset.clouds_points_labels[0]).squeeze(-1).to(all_logits.device)
    val_points_projections = torch.from_numpy(val_loader.dataset.projections[0]).to(all_logits.device).long()
    val_points_preds = all_logits[val_points_projections]

    del all_logits, idx_points
    torch.cuda.empty_cache()

    cm.update(val_points_preds, val_points_labels)
    miou, macc, oa, ious, accs = cm.all_metrics()

    if visualize:
        dataset_name = cfg_s.dataset.common.NAME.lower()
        coord = val_loader.dataset.clouds_points[0]
        colors = val_loader.dataset.clouds_points_colors[0].astype(np.float32)
        gt = cmap[val_points_labels.cpu().numpy().squeeze(), :]
        pred = cmap[val_points_preds.cpu().numpy().squeeze(), :]
        # save one set of .obj files per room; rooms holds the room boundary offsets
        rooms = val_loader.dataset.clouds_rooms[0]

        for idx in tqdm(range(len(rooms)-1), desc='save visualization'):
            start_idx, end_idx = rooms[idx], rooms[idx+1]
            # input colours
            write_obj(coord[start_idx:end_idx], colors[start_idx:end_idx],
                        os.path.join(cfg_s.vis_dir, f'input-{dataset_name}-{idx}.obj'))
            # output ground truth labels
            write_obj(coord[start_idx:end_idx], gt[start_idx:end_idx],
                        os.path.join(cfg_s.vis_dir, f'gt-{dataset_name}-{idx}.obj'))
            # output pred labels
            write_obj(coord[start_idx:end_idx], pred[start_idx:end_idx],
                        os.path.join(cfg_s.vis_dir, f'{cfg_s.cfg_basename}-{dataset_name}-{idx}.obj'))
    return miou, macc, oa, ious, accs


# TODO: multi gpu support. Wrap to a dataloader.
@torch.no_grad()
def test(model, data_list, cfg_s, num_votes=1):
    """Evaluate ``model`` cloud by cloud, feeding parts of each original cloud to save memory.

    Args:
        model: network called as ``model(data)`` returning per-point logits.
        data_list: sequence of per-cloud entries, each passed to ``load_data(entry, cfg_s)``.
        cfg_s: config; reads num_classes, ignore_index, datatransforms, dataset, feature_keys,
            task_name, cfg_basename, run_dir, cmap, distributed and the optional
            ``visualize`` (default True), ``save_path`` and ``save_pred`` flags.
        num_votes (int, optional): unused; kept for interface compatibility. Defaults to 1.

    Returns:
        (miou, macc, oa, ious, accs, all_cm) aggregated over every cloud that has labels,
        or six ``None`` values when no cloud provided labels (e.g. a hidden test split).
    """
    model.eval()  # set model to eval mode
    all_cm = ConfusionMatrix(num_classes=cfg_s.num_classes, ignore_index=cfg_s.ignore_index)
    set_random_seed(0)
    cfg_s.visualize = cfg_s.get('visualize', True)
    if cfg_s.visualize:
        from openpoints.dataset.vis3d import write_obj
        cfg_s.vis_dir = os.path.join(cfg_s.run_dir, 'visualization')
        os.makedirs(cfg_s.vis_dir, exist_ok=True)
        # Normalise a local copy: rescaling cfg_s.cmap in place would compound on every call.
        cmap = np.asarray(cfg_s.cmap).astype(np.float32) / 255.

    # data
    trans_split = 'val' if cfg_s.datatransforms.get('test', None) is None else 'test'
    pipe_transform = build_transforms_from_cfg(trans_split, cfg_s.datatransforms)

    dataset_name = cfg_s.dataset.common.NAME.lower()
    is_kitti = 'semantickitti' in dataset_name
    is_scannet = 'scannet' in dataset_name
    is_variable = cfg_s.dataset.common.get('variable', False)
    len_data = len(data_list)

    cfg_s.save_path = cfg_s.get('save_path', f'results/{cfg_s.task_name}/{cfg_s.dataset.test.split}/{cfg_s.cfg_basename}')
    if is_kitti:
        cfg_s.save_path = os.path.join(cfg_s.save_path, str(cfg_s.dataset.test.test_id + 11), 'predictions')
    os.makedirs(cfg_s.save_path, exist_ok=True)

    gravity_dim = cfg_s.datatransforms.kwargs.gravity_dim
    has_labels = False  # whether any cloud contributed to all_cm
    for cloud_idx, data_path in enumerate(data_list):
        logging.info(f'Test [{cloud_idx}]/[{len_data}] cloud')
        cm = ConfusionMatrix(num_classes=cfg_s.num_classes, ignore_index=cfg_s.ignore_index)
        all_logits = []
        coord, feat, label, idx_points, voxel_idx, reverse_idx_part, reverse_idx = load_data(data_path, cfg_s)
        if label is not None:
            has_labels = True
            # np.int was removed in NumPy 1.24; int64 yields the same LongTensor
            label = torch.from_numpy(label.astype(np.int64).squeeze()).cuda(non_blocking=True)

        len_part = len(idx_points)
        # With a single part, logits are mapped back to every original point through the
        # reverse/voxel indices returned by load_data; otherwise overlapping parts are averaged.
        nearest_neighbor = len_part == 1
        pbar = tqdm(range(len_part))
        for idx_subcloud in pbar:
            pbar.set_description(f"Test on {cloud_idx}-th cloud [{idx_subcloud}]/[{len_part}]]")
            idx_part = idx_points[idx_subcloud]
            coord_part = coord[idx_part]
            coord_part -= coord_part.min(0)

            feat_part = feat[idx_part] if feat is not None else None
            data = {'pos': coord_part}
            if feat_part is not None:
                data['x'] = feat_part
            if pipe_transform is not None:
                data = pipe_transform(data)
            if 'heights' in cfg_s.feature_keys and 'heights' not in data.keys():
                heights = coord_part[:, gravity_dim:gravity_dim + 1]
                if is_kitti:
                    heights = heights - heights.min()
                data['heights'] = torch.from_numpy(heights.astype(np.float32)).unsqueeze(0)
            if not is_variable:
                # add the batch dimension expected by fixed-size models
                if 'x' in data.keys():
                    data['x'] = data['x'].unsqueeze(0)
                data['pos'] = data['pos'].unsqueeze(0)
            else:
                # offsets / batch ids must describe the points actually fed to the model
                # (the current part), not the whole original cloud
                num_points = data['pos'].shape[0]
                data['o'] = torch.IntTensor([num_points])
                data['batch'] = torch.LongTensor([0] * num_points)

            for key in data.keys():
                data[key] = data[key].cuda(non_blocking=True)
            data['x'] = get_features_by_keys(data, cfg_s.feature_keys)
            logits = model(data)
            all_logits.append(logits)

        all_logits = torch.cat(all_logits, dim=0)
        if not is_variable:
            all_logits = all_logits.transpose(1, 2).reshape(-1, cfg_s.num_classes)

        if not nearest_neighbor:
            # average merge overlapped multi voxels logits to original point set
            idx_points = torch.from_numpy(np.hstack(idx_points)).cuda(non_blocking=True)
            all_logits = scatter(all_logits, idx_points, dim=0, reduce='mean')
        else:
            # interpolate logits by nearest neighbor
            all_logits = all_logits[reverse_idx_part][voxel_idx][reverse_idx]
        pred = all_logits.argmax(dim=1)
        if label is not None:
            cm.update(pred, label)

        if cfg_s.visualize:
            # separate names so `pred` stays a label tensor for save_pred below
            pred_colors = cmap[pred.cpu().numpy().squeeze(), :]
            gt_colors = cmap[label.cpu().numpy().squeeze(), :] if label is not None else None
            if 's3dis' in dataset_name:
                file_name = f'{dataset_name}-Area{cfg_s.dataset.common.test_area}-{cloud_idx}'
            else:
                file_name = f'{dataset_name}-{cloud_idx}'

            # input features
            write_obj(coord, feat,
                      os.path.join(cfg_s.vis_dir, f'input-{file_name}.obj'))
            # output ground truth labels
            if gt_colors is not None:
                write_obj(coord, gt_colors,
                          os.path.join(cfg_s.vis_dir, f'gt-{file_name}.obj'))
            # output pred labels
            write_obj(coord, pred_colors,
                      os.path.join(cfg_s.vis_dir, f'{cfg_s.cfg_basename}-{file_name}.obj'))

        if cfg_s.get('save_pred', False):
            if is_kitti:
                pred_np = (pred + 1).cpu().numpy().squeeze().astype(np.uint32)
                upper_half = pred_np >> 16  # get upper half for instances
                lower_half = pred_np & 0xFFFF  # get lower half for semantics
                lower_half = remap_lut_write[lower_half]  # do the remapping of semantics
                pred_np = ((upper_half << 16) + lower_half).astype(np.uint32)  # reconstruct full label
                frame_id = data_path[0].split('/')[-1][:-4]
                pred_np.tofile(os.path.join(cfg_s.save_path, frame_id + '.label'))
            elif is_scannet:
                pred_np = pred.cpu().numpy().squeeze()
                # map train ids back to the benchmark's label ids
                label_int_mapping = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 14, 13: 16, 14: 24, 15: 28, 16: 33, 17: 34, 18: 36, 19: 39}
                pred_np = np.vectorize(label_int_mapping.get)(pred_np)
                name_parts = data_path.split('/')[-1].split('_')
                save_file_name = os.path.join(cfg_s.save_path, name_parts[0] + '_' + name_parts[1] + '.txt')
                np.savetxt(save_file_name, pred_np, fmt="%d")

        if label is not None:
            miou, macc, oa, ious, accs = get_mious(cm.tp, cm.union, cm.count)
            with np.printoptions(precision=2, suppress=True):
                logging.info(
                    f'[{cloud_idx}]/[{len_data}] cloud,  test_oa , test_macc, test_miou: {oa:.2f} {macc:.2f} {miou:.2f}, '
                    f'\niou per cls is: {ious}')
            all_cm.value += cm.value

    if is_scannet:
        logging.info(f" Please select and zip all the files (DON'T INCLUDE THE FOLDER) in {cfg_s.save_path} and submit it to"
                     f" Scannet Benchmark https://kaldir.vc.in.tum.de/scannet_benchmark/. ")

    if not has_labels:
        return None, None, None, None, None, None
    tp, union, count = all_cm.tp, all_cm.union, all_cm.count
    if cfg_s.distributed:
        dist.all_reduce(tp), dist.all_reduce(union), dist.all_reduce(count)
    miou, macc, oa, ious, accs = get_mious(tp, union, count)
    return miou, macc, oa, ious, accs, all_cm


if __name__ == "__main__":
    import shutil

    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--cfg_s', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    args, opts = parser.parse_known_args()
    cfg = EasyConfig()
    cfg_s = EasyConfig()
    cfg.load(args.cfg, recursive=True)
    cfg_s.load(args.cfg_s, recursive=True)
    # unrecognised command-line arguments overwrite the yml defaults of BOTH configs
    cfg.update(opts)
    cfg_s.update(opts)

    for conf in (cfg, cfg_s):
        if conf.seed is None:
            conf.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    for conf in (cfg, cfg_s):
        conf.rank, conf.world_size, conf.distributed, conf.mp = dist_utils.get_dist_info(conf)
        conf.sync_bn = conf.world_size > 1

    # init log dir: task_name is the folder holding the cfg file (e.g. s3dis),
    # cfg_basename its file name without extension (e.g. pointnext-xl).
    for conf, cfg_file in ((cfg, args.cfg), (cfg_s, args.cfg_s)):
        cfg_stem = os.path.splitext(cfg_file)[0]
        conf.task_name = os.path.basename(os.path.dirname(cfg_stem))
        conf.cfg_basename = os.path.basename(cfg_stem)
    tags = [cfg.task_name, cfg.mode, cfg.cfg_basename, f'ngpus{cfg.world_size}']
    tags_s = [cfg_s.task_name, cfg_s.mode, cfg_s.cfg_basename, f'ngpus{cfg_s.world_size}']

    # Overrides recorded for checking experiment configs from the logging file; machine-specific
    # entries (paths, ranks, wandb) are left out. Both configs received the same opts.
    excluded_keys = ('rank', 'dir', 'root', 'pretrain', 'path', 'wandb', '/')
    opt_list = [opt for opt in opts if not any(key in opt for key in excluded_keys)]
    opt_list_s = list(opt_list)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)
    cfg_s.root_dir = os.path.join(cfg_s.root_dir, cfg_s.task_name)
    cfg_s.opts = '-'.join(opt_list_s)

    eval_modes = ['test', 'testing', 'val', 'eval', 'evaluation']
    cfg.is_training = cfg.mode not in eval_modes
    cfg_s.is_training = cfg_s.mode not in eval_modes
    for conf, conf_tags in ((cfg, tags), (cfg_s, tags_s)):
        if conf.mode in ['resume', 'val', 'test']:
            resume_exp_directory(conf, pretrained_path=conf.pretrained_path)
            conf.wandb.tags = [conf.mode]
        else:
            generate_exp_directory(conf, conf_tags, additional_id=os.environ.get('MASTER_PORT', None))
            conf.wandb.tags = conf_tags

    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    os.environ["JOB_LOG_DIR_S"] = cfg_s.log_dir
    for conf, cfg_file in ((cfg, args.cfg), (cfg_s, args.cfg_s)):
        conf_path = os.path.join(conf.run_dir, "cfg.yaml")
        with open(conf_path, 'w') as f:
            yaml.dump(conf, f, indent=2)
        try:
            shutil.copy(cfg_file, conf.run_dir)
        except shutil.SameFileError:
            pass  # the cfg file already lives in run_dir (e.g. when resuming from it)
        conf.cfg_path = conf_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    cfg_s.wandb.name = cfg_s.run_name

    # multi processing.
    if cfg_s.mp:
        port = find_free_port()
        cfg_s.dist_url = f"tcp://localhost:{port}"
        cfg.dist_url = cfg_s.dist_url
        print('using mp spawn for distributed training')
        # mp.spawn calls main(process_index, *args): pass both configs, matching main(0, cfg, cfg_s)
        mp.spawn(main, nprocs=cfg_s.world_size, args=(cfg, cfg_s))
    else:
        main(0, cfg, cfg_s)
