"""Shared utilities for all main scripts."""

import argparse
import json
import os
import time

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter

from models.loss_helper import compute_hungarian_loss, get_ce_loss
from models.losses import HungarianMatcher, SetCriterion, ViewpointCriterion
from models import get_loss
from utils import get_scheduler, setup_logger, pc_util, vis_utils
import wandb

import ipdb
# Debugging shortcut: calling ``st()`` anywhere drops into an ipdb prompt.
st = ipdb.set_trace


def parse_option():
    """Parse the shared command line and return the resulting namespace.

    Unknown arguments are tolerated (``parse_known_args``) and discarded.
    After parsing, three fields are derived:

    * ``high_res`` -- True when ``--num_point`` is above 50000;
    * ``eval`` -- also enabled by ``--eval_train``;
    * ``size_cls_agnostic`` -- also enabled by ``--agnostic``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # ---- Model ----
    add('--width', default=1, type=int, help='backbone width')
    add('--num_target', type=int, default=256, help='Proposal number')
    add('--sampling', default='kps', type=str,
        help='Query points sampling method (kps, fps)')

    # ---- Transformer ----
    add('--nhead', default=8, type=int, help='number of attention heads')
    add('--num_encoder_layers', default=3, type=int)
    add('--num_decoder_layers', default=6, type=int)
    add('--dim_feedforward', default=2048, type=int)
    add('--transformer_dropout', default=0.1, type=float)
    add('--transformer_activation', default='relu', type=str)
    add('--self_position_embedding', default='loc_learned', type=str,
        help='(none, xyz_learned, loc_learned)')
    add('--cross_position_embedding', default='xyz_learned', type=str,
        help='(none, xyz_learned)')

    # ---- Loss ----
    add('--query_points_generator_loss_coef', default=0.8, type=float)
    add('--obj_loss_coef', default=0.1, type=float,
        help='Loss weight for objectness loss')
    add('--box_loss_coef', default=1, type=float,
        help='Loss weight for box loss')
    add('--sem_cls_loss_coef', default=0.1, type=float,
        help='Loss weight for classification loss')
    # Each regression term gets a loss type and a smooth-L1 delta.
    for part, delta in (('center', 0.04),
                        ('size', 0.111111111111),
                        ('heading', 1.0)):
        add(f'--{part}_loss_type', default='smoothl1', type=str,
            help='(smoothl1, l1)')
        add(f'--{part}_delta', default=delta, type=float,
            help=f'delta for smoothl1 loss in {part} loss')
    add('--query_points_obj_topk', default=4, type=int)
    add('--size_cls_agnostic', action='store_true',
        help='Use class-agnostic size prediction.')
    add('--regularizer_coef', default=0.05, type=float,
        help='Loss weight for offset regularizer')
    add('--use_deform_aux_losses', action='store_true')

    # ---- Data ----
    add('--batch_size', type=int, default=8,
        help='Batch Size during training')
    add('--dataset', type=str, default=['sr3d'], nargs='+',
        help='list of datasets to train on')
    add('--test_dataset', default='sr3d')
    add('--num_point', type=int, default=50000, help='Point Number')
    add('--data_root', default='./dataset/language_grounding/scans/')
    add('--use_height', action='store_true',
        help='Use height signal in input.')
    add('--use_color', action='store_true', help='Use RGB color in input.')
    add('--use_sunrgbd_v2', action='store_true',
        help='Use V2 box labels for SUN RGB-D dataset')
    add('--num_workers', type=int, default=4)

    # ---- Training ----
    add('--start_epoch', type=int, default=1)
    add('--max_epoch', type=int, default=400)
    add('--optimizer', type=str, default='adamW')
    add('--momentum', type=float, default=0.9)  # if applicable
    add('--weight_decay', type=float, default=0.0005)
    add("--lr", default=1e-4, type=float)
    add("--lr_backbone", default=1e-5, type=float)
    add("--text_encoder_lr", default=5e-5, type=float)
    add('--lr-scheduler', type=str, default='step',
        choices=["step", "cosine"])
    add('--warmup-epoch', type=int, default=-1)
    add('--warmup-multiplier', type=int, default=100)
    add('--lr_decay_epochs', type=int, default=[280, 340], nargs='+',
        help='when to decay lr, can be a list')
    add('--lr_decay_rate', type=float, default=0.1,
        help='for step scheduler. decay rate for lr')
    add('--clip_norm', default=0.1, type=float,
        help='gradient clipping max norm')
    add('--bn_momentum', type=float, default=0.1)
    add('--syncbn', action='store_true')

    # ---- I/O ----
    add('--checkpoint_path', default=None, help='Model checkpoint path')
    add('--log_dir', default='log', help='Dump dir to save model checkpoint')
    add('--print_freq', type=int, default=10)  # batch-wise
    add('--save_freq', type=int, default=10)  # epoch-wise
    add('--val_freq', type=int, default=5)  # epoch-wise

    # ---- Others ----
    add("--local_rank", type=int,
        help='local rank for DistributedDataParallel')
    add('--ap_iou_thresholds', type=float, default=[0.25, 0.5], nargs='+',
        help='A list of AP IoU thresholds')
    add("--rng_seed", type=int, default=0, help='manual seed')
    add("--debug", action='store_true', help="try to overfit few samples")
    add("--deformable", action='store_true',
        help="run deformable modulation model")
    add('--use_hungarian_loss', action='store_true',
        help='Use only the hungarian matching loss')
    # Plain on/off switches without help text; all default to False.
    plain_switches = (
        'eval', 'use_contrastive_align', 'use_soft_token_loss',
        'detect_intermediate', 'contrastive_hungarian', 'eval_anchors',
        'object_det_language', 'joint_det', 'visualize', 'eval_train',
        'use_gt_box', 'use_gt_class', 'no_sa_lang', 'no_sa_vis',
        'gt_with_bbox_loss', 'gt_with_bbox_sampling', 'new_contrastive',
        'use_gt_grounder', 'freeze_text_encoder', 'filter_relations',
        'apply_classifiers', 'ce_variant', 'agnostic', 'use_detected_boxes',
        'use_logits', 'use_oriented_boxes', 'no_augment', 'rotate_pc',
        'use_multiview', 'train_viewpoint_module',
        'train_viewpoint_prototype', 'teacher_forcing', 'butd', 'butd_gt',
        'butd_cls', 'box_only_butd', 'run_on_target_phrases',
    )
    for switch in plain_switches:
        add(f'--{switch}', action='store_true')

    options, _ = parser.parse_known_args()

    # Derived settings.
    options.high_res = options.num_point > 50000
    options.eval = options.eval or options.eval_train
    options.size_cls_agnostic = options.size_cls_agnostic or options.agnostic
    return options


def load_checkpoint(args, model, optimizer, scheduler):
    """Load state from ``args.checkpoint_path`` into the training objects.

    Sets ``args.start_epoch`` to the stored epoch plus one. When the stored
    epoch is missing or not numeric (``save_checkpoint`` writes the final
    checkpoint with ``epoch='last'``), ``args.start_epoch`` falls back to 0.

    Model weights are loaded with ``strict=False``, so missing/unexpected keys
    are tolerated. The optimizer and scheduler states are restored only when
    ``args.eval`` is falsy.
    """
    print("=> loading checkpoint '{}'".format(args.checkpoint_path))

    # Checkpoints written by ``save_checkpoint`` pickle the argparse config,
    # which torch>=2.6 refuses to load under its default
    # ``weights_only=True``. SECURITY: this unpickles arbitrary objects, so
    # only load checkpoints from a trusted source.
    try:
        checkpoint = torch.load(
            args.checkpoint_path, map_location='cpu', weights_only=False
        )
    except TypeError:
        # torch<1.13 has no ``weights_only`` argument.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')

    epoch = checkpoint.get('epoch')
    try:
        args.start_epoch = int(epoch) + 1
    except (TypeError, ValueError, OverflowError):
        # Missing (None) or a label such as 'last'.
        args.start_epoch = 0
    model.load_state_dict(checkpoint['model'], strict=False)
    if not args.eval:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # Print the value read above: indexing checkpoint['epoch'] again would
    # raise KeyError for checkpoints without an epoch entry.
    print("=> loaded successfully '{}' (epoch {})".format(
        args.checkpoint_path, epoch
    ))

    del checkpoint
    torch.cuda.empty_cache()


def save_checkpoint(args, epoch, model, optimizer, scheduler, save_cur=False):
    """Write ``ckpt_epoch_<epoch>.pth`` into ``args.log_dir`` when due.

    A save is due when ``save_cur`` is truthy or ``epoch`` is a multiple of
    ``args.save_freq``. The modulo is only evaluated when ``save_cur`` is
    falsy, so a non-numeric label such as ``'last'`` needs ``save_cur=True``.
    The saved dict holds the config, its own path, the model / optimizer /
    scheduler state dicts and the epoch. When no save is due, a notice is
    printed instead.
    """
    if not (save_cur or epoch % args.save_freq == 0):
        print("not saving checkpoint")
        return

    target = os.path.join(args.log_dir, f'ckpt_epoch_{epoch}.pth')
    payload = {
        'config': args,
        'save_path': target,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(payload, target)
    print("Saved in {}".format(target))


class BaseTrainTester:
    """Basic train/test class to be inherited."""

    def __init__(self, args):
        """Create a fresh run directory, the logger and (rank 0) TensorBoard.

        ``args.log_dir`` is rewritten in place to a new, not-yet-existing
        sub-directory ``<log_dir>/<dataset>_<unix time>/<random int>``.
        ``args.dataset`` is interpolated as-is, so a list renders like
        ``['sr3d']``.

        Rank 0 also dumps the full config to ``config.json`` in that
        directory and creates ``self.writer``, a ``SummaryWriter`` under
        ``tb/``; other ranks have no ``writer``. When ``args.visualize`` is
        set, a wandb run is started on every rank.
        """
        # Logger name: last path component of the user-supplied log dir.
        # normpath drops a trailing separator, so "log/exp/" yields "exp"
        # rather than an empty name.
        name = os.path.basename(os.path.normpath(args.log_dir))
        # Create a log dir that does not exist yet
        log_dir = ''
        while log_dir == '' or os.path.exists(log_dir):
            log_dir = os.path.join(
                args.log_dir,
                f'{args.dataset}_{int(time.time())}',
                f'{np.random.randint(100000000)}'
            )
        args.log_dir = log_dir
        os.makedirs(args.log_dir, exist_ok=True)

        # Create logger
        self.logger = setup_logger(
            output=args.log_dir, distributed_rank=dist.get_rank(),
            name=name
        )
        if args.visualize:
            wandb.init(project="3d_grounder", name="debug")
        # Save config file and initialize tb writer
        if dist.get_rank() == 0:
            path = os.path.join(args.log_dir, "config.json")
            with open(path, 'w') as f:
                json.dump(vars(args), f, indent=2)
            self.logger.info("Full config saved to {}".format(path))
            self.logger.info(str(vars(args)))
            self.writer = SummaryWriter(os.path.join(args.log_dir, 'tb'))

    @staticmethod
    def get_datasets(args):
        """Initialize datasets."""
        dataset_config = None
        train_dataset = None
        test_dataset = None
        return train_dataset, test_dataset, dataset_config

    def get_loaders(self, args):
        """Initialize data loaders."""
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)
        # Datasets
        train_dataset, test_dataset, train_dataset100, dataset_config = \
            self.get_datasets(args)
        # Samplers and loaders
        train_sampler = DistributedSampler(train_dataset)
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=my_worker_init_fn,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True
        )
        test_sampler = DistributedSampler(test_dataset, shuffle=False)
        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=my_worker_init_fn,
            pin_memory=True,
            sampler=test_sampler,
            drop_last=False
        )
        if train_dataset100 is not None:
            train_sampler100 = DistributedSampler(train_dataset100, shuffle=False)
            train_loader100 = DataLoader(
                train_dataset100,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_workers,
                worker_init_fn=my_worker_init_fn,
                pin_memory=True,
                sampler=train_sampler100,
                drop_last=False
            )
        else:
            train_loader100 = None
        return train_loader, test_loader, train_loader100, dataset_config

    @staticmethod
    def get_model(args, dataset_config):
        """Initialize the model."""
        return None

    @staticmethod
    def get_criterion(args):
        """Build the loss callables used by ``_compute_loss``.

        Returns ``(criterion, set_criterion, viewpoint_criterion)``:

        * ``criterion``: ``compute_hungarian_loss`` when
          ``args.use_hungarian_loss``; otherwise ``get_ce_loss`` when
          ``args.ce_variant``; otherwise ``get_loss``.
        * ``set_criterion``: a ``SetCriterion`` moved to CUDA for the
          Hungarian loss, else ``None``.
        * ``viewpoint_criterion``: always built. Its ``weight`` is 25 when
          ``args.train_viewpoint_module`` is set and 0 otherwise.
        """
        viewpoint_criterion = ViewpointCriterion(
            weight=0,
            train_viewpoint_prototype=args.train_viewpoint_prototype)

        if not args.use_hungarian_loss:
            set_criterion = None
            criterion = get_ce_loss if args.ce_variant else get_loss
        else:
            # NOTE(review): (1, 5, 2) presumably are the class / bbox / giou
            # matching costs, mirroring the weight dict below; confirm
            # against HungarianMatcher's signature.
            matcher = HungarianMatcher(
                1, 5, 2,
                args.use_soft_token_loss,
                use_detected_boxes=args.use_detected_boxes
            )
            weight_dict = {
                'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2,
                'loss_contrastive_align':
                    1 if args.use_contrastive_align else 0
            }
            # (loss name, enabled?) in the order SetCriterion receives them.
            candidates = (
                ('boxes', not args.use_gt_box or args.gt_with_bbox_loss),
                ('labels', not args.contrastive_hungarian),
                ('contrastive_align', args.use_contrastive_align),
            )
            losses = [loss_name for loss_name, on in candidates if on]
            set_criterion = SetCriterion(
                num_classes=1, matcher=matcher, weight_dict=weight_dict,
                eos_coef=0.1, losses=losses, temperature=0.07,
                soft_token=args.use_soft_token_loss,
                contrastive_hungarian=args.contrastive_hungarian,
                use_gt_box=args.use_gt_box,
                new_contrastive=args.new_contrastive,
                detect_intermediate=args.detect_intermediate
            ).cuda()
            criterion = compute_hungarian_loss

        if args.train_viewpoint_module:
            viewpoint_criterion.weight = 25

        return criterion, set_criterion, viewpoint_criterion

    @staticmethod
    def get_optimizer(args, model):
        """Initialize optimizer."""
        param_dicts = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if "backbone_net" not in n and "text_encoder" not in n
                    and p.requires_grad
                ]
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if "backbone_net" in n and p.requires_grad
                ],
                "lr": args.lr_backbone
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if "text_encoder" in n and p.requires_grad
                ],
                "lr": args.text_encoder_lr
            }
        ]
        optimizer = optim.AdamW(param_dicts,
                                lr=args.lr,
                                weight_decay=args.weight_decay)
        return optimizer

    def main(self, args):
        """Run the main training / testing pipeline.

        Builds loaders, model, criteria, optimizer and scheduler, then wraps
        the model in ``DistributedDataParallel``. If
        ``args.checkpoint_path`` is set, the run resumes from it.

        * With ``args.eval``: evaluates once, first on ``train_loader100``
          when it exists and then on the test loader, and returns ``None``.
        * Otherwise: trains epochs ``args.start_epoch`` ..
          ``args.max_epoch`` (inclusive). Every ``args.val_freq`` epochs it
          checkpoints on rank 0 (subject to ``save_freq``) and evaluates.
          At the end it saves ``ckpt_epoch_last.pth`` from rank 0, evaluates
          once more and returns that checkpoint's path.

        Raises:
            FileNotFoundError: ``args.checkpoint_path`` is set but is not an
                existing file.
        """
        # Get loaders
        train_loader, test_loader, train_loader100, dataset_config = \
            self.get_loaders(args)
        n_data = len(train_loader.dataset)
        self.logger.info(f"length of training dataset: {n_data}")
        n_data = len(test_loader.dataset)
        self.logger.info(f"length of testing dataset: {n_data}")

        # Get model
        model = self.get_model(args, dataset_config)

        # Get criterion
        criterion, set_criterion, viewpoint_criterion = self.get_criterion(args)

        # Get optimizer
        optimizer = self.get_optimizer(args, model)

        # Get scheduler
        scheduler = get_scheduler(optimizer, len(train_loader), args)

        # Move model to devices
        if torch.cuda.is_available():
            model = model.cuda()
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            broadcast_buffers=False, find_unused_parameters=True
        )

        # Check for a checkpoint (after the DDP wrap, matching how
        # save_checkpoint is called with the wrapped model).
        if args.checkpoint_path:
            # Explicit check instead of ``assert`` (stripped under -O).
            if not os.path.isfile(args.checkpoint_path):
                raise FileNotFoundError(args.checkpoint_path)
            load_checkpoint(args, model, optimizer, scheduler)

        # Just eval and end execution
        if args.eval:
            if train_loader100 is not None:
                print("Training evaluation...................")
                self.evaluate_one_epoch(
                    args.start_epoch, train_loader100, dataset_config,
                    model, criterion, set_criterion, viewpoint_criterion, args
                )
            print("Testing evaluation.....................")
            self.evaluate_one_epoch(
                args.start_epoch, test_loader, dataset_config,
                model, criterion, set_criterion, viewpoint_criterion, args
            )
            return

        # Training loop
        for epoch in range(args.start_epoch, args.max_epoch + 1):
            train_loader.sampler.set_epoch(epoch)
            tic = time.time()
            self.train_one_epoch(
                epoch, train_loader, dataset_config, model,
                criterion, set_criterion, viewpoint_criterion,
                optimizer, scheduler, args
            )
            # In get_optimizer, group 0 holds the base parameters and
            # group 1 the backbone ones.
            self.logger.info(
                'epoch {}, total time {:.2f}, '
                'lr_base {:.5f}, lr_backbone {:.5f}'.format(
                    epoch, (time.time() - tic),
                    optimizer.param_groups[0]['lr'],
                    optimizer.param_groups[1]['lr']
                )
            )
            if epoch % args.val_freq == 0:
                if dist.get_rank() == 0:  # save model
                    save_checkpoint(args, epoch, model, optimizer, scheduler)
                if not args.debug and train_loader100 is not None:
                    print("Train 100 evaluation.......")
                    self.evaluate_one_epoch(
                        epoch, train_loader100, dataset_config,
                        model, criterion, set_criterion,
                        viewpoint_criterion, args
                    )
                print("Test evaluation.......")
                self.evaluate_one_epoch(
                    epoch, test_loader, dataset_config,
                    model, criterion, set_criterion,
                    viewpoint_criterion, args
                )

        # Training is over, evaluate.
        # Only rank 0 writes the final checkpoint, as for the periodic saves
        # above; otherwise every rank would write the same file at once.
        if dist.get_rank() == 0:
            save_checkpoint(args, 'last', model, optimizer, scheduler, True)
        saved_path = os.path.join(args.log_dir, 'ckpt_epoch_last.pth')
        self.logger.info("Saved in {}".format(saved_path))
        self.evaluate_one_epoch(
            args.max_epoch, test_loader, dataset_config,
            model, criterion, set_criterion, viewpoint_criterion, args
        )
        return saved_path

    @staticmethod
    def _to_gpu(data_dict):
        if torch.cuda.is_available():
            for key in data_dict:
                if isinstance(data_dict[key], torch.Tensor):
                    data_dict[key] = data_dict[key].cuda(non_blocking=True)
        return data_dict

    @staticmethod
    def _get_inputs(args, batch_data):
        return {
            'point_clouds': batch_data['point_clouds'].float(),
            'text': batch_data['utterances']
        }

    @staticmethod
    def _compute_loss(end_points, criterion, set_criterion, viewpoint_criterion, args,
                      dataset_config):
        if set_criterion is not None:
            loss, end_points = criterion(
                end_points, args.num_decoder_layers,
                set_criterion,
                args.query_points_generator_loss_coef,
                query_points_obj_topk=args.query_points_obj_topk,
                contrastive_hungarian=args.contrastive_hungarian,
                use_gt_box=args.use_gt_box,
                gt_with_bbox_loss=args.gt_with_bbox_loss,
                gt_grounder=args.use_gt_grounder,
                use_detected_boxes=args.use_detected_boxes,
                train_viewpoint_module=args.train_viewpoint_module,
                viewpoint_criterion=viewpoint_criterion,
            )
        elif args.ce_variant:
            loss, end_points = criterion(end_points, args.num_decoder_layers)
        else:
            loss, end_points = criterion(
                end_points, dataset_config,
                args.num_decoder_layers,
                args.query_points_generator_loss_coef,
                obj_loss_coef=args.obj_loss_coef,
                box_loss_coef=args.box_loss_coef,
                sem_cls_loss_coef=args.sem_cls_loss_coef,
                query_points_obj_topk=args.query_points_obj_topk,
                center_loss_type=args.center_loss_type,
                center_delta=args.center_delta,
                size_loss_type=args.size_loss_type,
                size_delta=args.size_delta,
                heading_loss_type=args.heading_loss_type,
                heading_delta=args.heading_delta,
                size_cls_agnostic=args.size_cls_agnostic,
                num_encoder_layers=args.num_encoder_layers,
                do_deformable=args.deformable,
                offset_regularizer_coef=args.regularizer_coef
            )
        return loss, end_points

    def _tb_logs(self, stat_dict, niter, len_accum, split='Train'):
        self.writer.add_scalar(
            f'{split}/Loss', stat_dict['loss'] / len_accum, niter)
        for k in stat_dict.keys():
            res = stat_dict[k] / len_accum
            if '_loss' in k:
                self.writer.add_scalar(f"{split}/More_Losses/{k}", res, niter)
            elif 'ratio' in k:
                self.writer.add_scalar(f"{split}/Ratios/{k}", res, niter)
            elif 'acc' in k:
                self.writer.add_scalar(f'{split}/Acc', res, niter)

    @staticmethod
    def _accumulate_stats(stat_dict, end_points):
        for key in end_points:
            if 'loss' in key or 'acc' in key or 'ratio' in key:
                if key not in stat_dict:
                    stat_dict[key] = 0
                if isinstance(end_points[key], (float, int)):
                    stat_dict[key] += end_points[key]
                else:
                    stat_dict[key] += end_points[key].item()
        return stat_dict

    def train_one_epoch(self, epoch, train_loader, dataset_config, model,
                        criterion, set_criterion, viewpoint_criterion,
                        optimizer, scheduler, args):
        """
        Run a single epoch.

        Some of the args:
            dataset_config: a class like ReferitDatasetConfig
            model: a nn.Module that returns end_points (dict)
            criterion: a function that returns (loss, end_points)
        """
        stat_dict = {}  # collect statistics
        model.train()  # set model to training mode
        if set_criterion is not None:
            set_criterion.train()

        if viewpoint_criterion is not None:
            viewpoint_criterion.train()

        # Loop over batches
        for batch_idx, batch_data in enumerate(train_loader):
            # Move to GPU
            batch_data = self._to_gpu(batch_data)
            inputs = self._get_inputs(args, batch_data)
            if "train" not in inputs:
                inputs.update({"train": True})
            else:
                inputs['train'] = True

            # Forward pass
            end_points = model(inputs)

            # Compute loss and gradients, update parameters.
            for key in batch_data:
                assert (key not in end_points)
                end_points[key] = batch_data[key]
            loss, end_points = self._compute_loss(
                end_points, criterion, set_criterion, viewpoint_criterion,
                args, dataset_config
            )
            optimizer.zero_grad()
            loss.backward()
            if args.clip_norm > 0:
                grad_total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_norm
                )
                stat_dict['grad_norm'] = grad_total_norm
            optimizer.step()
            scheduler.step()

            # Accumulate statistics and print out
            stat_dict = self._accumulate_stats(stat_dict, end_points)

            if (batch_idx + 1) % args.print_freq == 0:
                # Terminal logs
                self.logger.info(
                    f'Train: [{epoch}][{batch_idx + 1}/{len(train_loader)}]  '
                )
                self.logger.info(''.join([
                    f'{key} {stat_dict[key] / args.print_freq:.4f} \t'
                    for key in sorted(stat_dict.keys())
                    if 'loss' in key and 'proposal_' not in key
                    and 'last_' not in key and 'head_' not in key
                ]))
                # Tensorboard logs
                if dist.get_rank() == 0:
                    niter = epoch*len(train_loader) + batch_idx
                    self._tb_logs(stat_dict, niter, args.print_freq)

                for key in sorted(stat_dict.keys()):
                    stat_dict[key] = 0

    @torch.no_grad()
    def _main_eval_branch(self, batch_idx, batch_data, test_loader, model,
                          stat_dict,
                          criterion, set_criterion, viewpoint_criterion, args, dataset_config):
        # Move to GPU
        batch_data = self._to_gpu(batch_data)
        inputs = self._get_inputs(args, batch_data)
        if "train" not in inputs:
            inputs.update({"train": False})
        else:
            inputs["train"] = False

        # Forward pass
        end_points = model(inputs)

        # Compute loss
        for key in batch_data:
            assert (key not in end_points)
            end_points[key] = batch_data[key]
        _, end_points = self._compute_loss(
            end_points, criterion, set_criterion, viewpoint_criterion,
            args, dataset_config
        )
        for key in end_points:
            if 'pred_size' in key:
                end_points[key] = torch.clamp(end_points[key], min=1e-6)

        # Accumulate statistics and print out
        stat_dict = self._accumulate_stats(stat_dict, end_points)
        if (batch_idx + 1) % args.print_freq == 0:
            self.logger.info(f'Eval: [{batch_idx + 1}/{len(test_loader)}]  ')
            self.logger.info(''.join([
                f'{key} {stat_dict[key] / (float(batch_idx + 1)):.4f} \t'
                for key in sorted(stat_dict.keys())
                if 'loss' in key and 'proposal_' not in key
                and 'last_' not in key and 'head_' not in key
            ]))
        return stat_dict, end_points

    @torch.no_grad()
    def evaluate_one_epoch(self, epoch, test_loader, dataset_config,
                           model, criterion, set_criterion,
                           viewpoint_criterion, args):
        """
        Eval grounding after a single epoch.

        Some of the args:
            dataset_config: a class like ReferitDatasetConfig
            model: a nn.Module that returns end_points (dict)
            criterion: a function that returns (loss, end_points)
        """
        return None

    @staticmethod
    def _viz_dets(step, end_points):
        b = 0
        scene_id = end_points['scan_ids'][b]
        mesh = vis_utils.align_mesh(scene_id)

        wandb.log(
            {
                "scene_mesh": wandb.Object3D(mesh)
            }
        )

        # Parse annotations
        labelled_pc = np.concatenate(  # colored point cloud
            (
                end_points['point_clouds'][b, :, :3].detach().cpu().numpy(),
                end_points['og_color'][b].detach().cpu().numpy()*256
            ),
            axis=-1
        )
        all_boxes = end_points['all_bboxes'][b].detach().cpu().numpy()
        all_boxes_points = pc_util.box2points(all_boxes)
        # Target
        target_box = all_boxes_points[end_points['target_id'][b].item()]
        # Distractors
        distractors_boxes = all_boxes_points[[
            i.item() for i in end_points['distractor_ids'][b] if i != -1
        ]]
        # Anchors
        anchors_boxes = all_boxes_points[[
            i.item() for i in end_points['anchor_ids'][b] if i != -1
        ]]

        # Visualise all boxes
        gt_boxes = (
            [end_points['target_id'][b].item()]
            + end_points['distractor_ids'][b].cpu().numpy().tolist()
            + end_points['anchor_ids'][b].cpu().numpy().tolist()
        )
        wandb.log({
            "ground_truth_point_scene": wandb.Object3D({
                "type": "lidar/beta",
                "points": labelled_pc,
                "boxes": np.array(
                    [  # target
                        {
                            "corners": target_box.tolist(),
                            "label": "target",
                            "color": [0, 255, 0]
                        }
                    ]
                    + [  # anchors
                        {
                            "corners": c.tolist(),
                            "label": "anchor",
                            "color": [0, 0, 255]
                        }
                        for c in anchors_boxes
                    ]
                    + [  # distractors
                        {
                            "corners": c.tolist(),
                            "label": "distractor",
                            "color": [0, 255, 255]
                        }
                        for c in distractors_boxes
                    ]
                    + [  # other
                        {
                            "corners": c.tolist(),
                            "label": "other",
                            "color": [255, 0, 0]
                        }
                        for i, c in enumerate(all_boxes_points)
                        if i not in gt_boxes
                    ]
                )
            }),
            "utterance": wandb.Html(end_points['utterances'][b])
        }, step=step)

        # Plot target, anchors, distractors and the top-5 predictions
        if 'last_center' in end_points:
            pred_boxes_points = pc_util.box2points(
                np.concatenate((
                    end_points['last_center'][b].detach().cpu().numpy(),
                    end_points['last_pred_size'][b].detach().cpu().numpy()
                ), axis=-1)  # Q, 6
            )
        else:
            pred_boxes_points = pc_util.box2points(
                end_points['last_pred_boxes_bbf'][b].detach().cpu().numpy()
            )
        # pred_anchor_boxes = pc_util.box2points(
        #     end_points['anchor_estimates'][b].detach().cpu().numpy()
        # )
        wandb.log(
            {
                "gt_and_top5": wandb.Object3D({
                    "type": "lidar/beta",
                    "points": labelled_pc,
                    "boxes": np.array(
                        [  # target
                            {
                                "corners": target_box.tolist(),
                                "label": "gt",
                                "color": [0, 255, 0]
                            }
                        ]
                        + [  # anchors
                            {
                                "corners": c.tolist(),
                                "label": "anchor",
                                "color": [0, 0, 255]
                            }
                            for i, c in enumerate(anchors_boxes)
                        ]
                        + [  # predicted
                            {
                                "corners": c.tolist(),
                                "label": "top%d_bbf" % i,
                                "color": [255, 0, 0]
                            }
                            for i, c in enumerate(pred_boxes_points[:5])
                        ]
                        + [  # distractors
                            {
                                "corners": c.tolist(),
                                "label": "distractor",
                                "color": [0, 255, 255]
                            }
                            for c in distractors_boxes
                        ]
                    )
                }),
                "sentence": wandb.Html(end_points['utterances'][b])
            },
            step=step
        )

        # All predicted boxes (irrespective of confidence)
        wandb.log(
            {
                "all_predicted_boxes": wandb.Object3D(
                    {
                        "type": "lidar/beta",
                        "points": labelled_pc,
                        "boxes": np.array(
                            [
                                {
                                    "corners": c.tolist(),
                                    "label": "q#" + str(idx),
                                    "color": [255, 0, 0]
                                }
                                for idx, c in enumerate(pred_boxes_points)
                            ]
                            # +
                            # [
                            #     {
                            #         "corners": c.tolist(),
                            #         "label": "a#" + str(idx),
                            #         "color": [0, 0, 255]
                            #     }
                            #     for idx, c in enumerate(pred_anchor_boxes)
                            # ]
                        )
                    }
                )
            },
            step=step
        )

        # Predictions using soft-token assignment to the root
        if 'last_pred_boxes_bbs' in end_points:
            top10_bbs_points = pc_util.box2points(
                end_points['last_pred_boxes_bbs'][b].detach().cpu().numpy()
            )
            wandb.log(
                {
                    "top10_boxes_BBS": wandb.Object3D(
                        {
                            "type": "lidar/beta",
                            "points": labelled_pc,
                            "boxes": np.array([
                                {
                                    "corners": c.tolist(),
                                    "label": "#" + str(idx),
                                    "color": [255, 0, 0]
                                }
                                for idx, c in enumerate(top10_bbs_points)
                            ])
                        }
                    )
                },
                step=step
            )

        # Predictions using contrastive matching of the root
        if 'last_pred_boxes_bbf' in end_points:
            top10_bbf_points = pc_util.box2points(
                end_points['last_pred_boxes_bbf'][b].detach().cpu().numpy()
            )
            wandb.log(
                {
                    "top10_boxes_BBF": wandb.Object3D(
                        {
                            "type": "lidar/beta",
                            "points": labelled_pc,
                            "boxes": np.array([
                                {
                                    "corners": c.tolist(),
                                    "label": "#" + str(idx),
                                    "color": [255, 0, 0]
                                }
                                for idx, c in enumerate(top10_bbf_points)
                            ])
                        }
                    )
                },
                step=step
            )

        # Plot where query points lie
        if 'query_points_xyz' in end_points:
            att_pc = np.copy(labelled_pc)
            att_pc[:, 3:] = 127
            q_color = np.zeros((len(end_points['query_points_xyz'][b]), 3))
            q_color[:, 0] = 255
            q_points = np.concatenate([
                end_points['query_points_xyz'][b].detach().cpu().numpy(),
                q_color
            ], 1)
            wandb.log(
                {
                    "query_in_red": wandb.Object3D({
                        "type": "lidar/beta",
                        "points": np.concatenate([att_pc, q_points])
                    })
                },
                commit=False
            )

        # # Plot where language attends
        # if 'lv_attention0' in end_points:
        #     att_pc = np.copy(labelled_pc)
        #     att_pc[:, 3:] = 127
        #     for i in range(3):
        #         att = end_points['lv_attention%d' % i][b]
        #         att = att * end_points['positive_map'][b][0][:len(att)][:, None]
        #         att = att.sum(0)  # (num_points,)
        #         pc_ = np.concatenate([
        #             end_points['seed_xyz'][b][att.topk(64)[1]].detach().cpu().numpy(),
        #             np.zeros((64, 3))
        #         ], 1)
        #         pc_[:, i + 3] = 255
        #         if i == 2:
        #             pc_[:, 4] = 255
        #         att_pc = np.concatenate([att_pc, pc_])
        #     wandb.log(
        #         {
        #             "attention_per_layer": wandb.Object3D({
        #                 "type": "lidar/beta",
        #                 "points": att_pc,
        #                 "boxes": np.array(
        #                     [  # target
        #                         {
        #                             "corners": target_box.tolist(),
        #                             "label": "gt: " + names[0],
        #                             "color": [0, 255, 0]
        #                         }
        #                     ]
        #                     + [  # anchors
        #                         {
        #                             "corners": c.tolist(),
        #                             "label": "anchor: " + names[1 + num_distr + i],
        #                             "color": [0, 0, 255]
        #                         }
        #                         for i, c in enumerate(anchors_boxes)
        #                     ]
        #                     + [  # predicted
        #                         {
        #                             "corners": c.tolist(),
        #                             "label": "top1_bbs",
        #                             "color": [255, 0, 0]
        #                         }
        #                         for c in pred_boxes_points[0:1]
        #                     ]
        #                     + [  # distractors
        #                         {
        #                             "corners": c.tolist(),
        #                             "label": "distractor",
        #                             "color": [0, 255, 255]
        #                         }
        #                         for c in distractors_boxes
        #                     ]
        #                 )
        #             })
        #         },
        #         commit=False
        #     )
