"""Main script for object detection with our loaders."""

import os

import numpy as np
import torch
import torch.distributed as dist

from main_utils import parse_option, BaseTrainTester
from models import APCalculator, parse_predictions, parse_groundtruths
# from models.lang_detector import GroupFreeLangDetector
from models.detector import GroupFreeDetector
from scannet.model_util_scannet import ScannetDatasetConfig
from src.scannet_dataset import ScanNetDataset


class TrainTester(BaseTrainTester):
    """Train/test a Group-Free 3D detector on ScanNet and report detection mAP."""

    # Eval-time mapping from columns of `last_sem_cls_scores` to ScanNet
    # classes: column _TOKEN_IDX[k] is summed into class _WORD_IDX[k].
    # Several columns may feed the same class (e.g. class 0 gets 1, 2, 3).
    # NOTE(review): these indices look tied to one fixed class-name text
    # prompt and its tokenization (presumably one score column per token)
    # -- confirm against the text encoder before editing either tuple.
    _WORD_IDX = (
        0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8,
        9, 10, 11, 12, 13, 13, 14, 15, 16, 16, 17, 17,
    )
    _TOKEN_IDX = (
        1, 2, 3, 5, 7, 9, 11, 13, 15, 17, 18, 19, 21,
        23, 25, 27, 29, 31, 32, 34, 36, 38, 39, 41, 42,
    )
    # Number of folded output classes (== max(_WORD_IDX) + 1).
    _NUM_EVAL_CLASSES = 18

    def __init__(self, args):
        """Initialize."""
        super().__init__(args)

    @staticmethod
    def get_datasets(args):
        """
        Initialize datasets.

        Returns:
            (train_dataset, test_dataset, None, dataset_config). The third
            slot is unused by this script and is always None.
        """
        dataset_config = ScannetDatasetConfig(18, agnostic=args.agnostic)
        # In debug mode, train on the (overfit) val split.
        train_dataset = ScanNetDataset(
            'train' if not args.debug else 'val',
            num_points=args.num_point,
            use_color=args.use_color, use_height=args.use_height,
            overfit=args.debug, agnostic=args.agnostic,
            use_multiview=args.use_multiview
        )
        # `eval_train` evaluates on the training split instead of val.
        test_dataset = ScanNetDataset(
            'val' if not args.eval_train else 'train',
            num_points=args.num_point,
            use_color=args.use_color, use_height=args.use_height,
            overfit=args.debug, agnostic=args.agnostic,
            use_multiview=args.use_multiview
        )
        return train_dataset, test_dataset, None, dataset_config

    @staticmethod
    def get_model(args, dataset_config):
        """
        Initialize the model.

        The per-point input feature width is the sum of the optional
        feature groups enabled in `args`; each flag contributes
        independently, mirroring the flags passed to ScanNetDataset.
        """
        num_input_channel = int(args.use_color) * 3  # RGB
        if args.use_height:
            # Height is a single extra scalar per point (previously this
            # doubled the colour channel count, giving 0 or 6 channels).
            # NOTE(review): confirm ScanNetDataset emits exactly 1 height
            # channel.
            num_input_channel += 1
        if args.use_multiview:
            # Independent of use_height: both may be enabled together.
            num_input_channel += 128
        model = GroupFreeDetector(
            num_class=dataset_config.num_class,
            num_heading_bin=dataset_config.num_heading_bin,
            num_size_cluster=dataset_config.num_size_cluster,
            mean_size_arr=dataset_config.mean_size_arr,
            input_feature_dim=num_input_channel,
            width=args.width,
            bn_momentum=args.bn_momentum,
            sync_bn=args.syncbn,
            num_proposal=args.num_target,
            sampling=args.sampling,
            dropout=args.transformer_dropout,
            activation=args.transformer_activation,
            nhead=args.nhead,
            num_decoder_layers=args.num_decoder_layers,
            dim_feedforward=args.dim_feedforward,
            self_position_embedding=args.self_position_embedding,
            size_cls_agnostic=args.size_cls_agnostic,
            agnostic=args.agnostic
        )
        return model

    @staticmethod
    def _get_inputs(batch_data):
        """Select the model inputs (float point clouds and utterances) from a batch."""
        return {
            'point_clouds': batch_data['point_clouds'].float(),
            'text': batch_data['utterances']
        }

    @classmethod
    def _fold_token_scores(cls, token_scores):
        """Sum per-column scores into `_NUM_EVAL_CLASSES` per-class scores.

        Only the last dimension is remapped; leading dimensions, dtype and
        device are preserved.
        """
        class_scores = token_scores.new_zeros(
            (*token_scores.shape[:-1], cls._NUM_EVAL_CLASSES))
        for word, token in zip(cls._WORD_IDX, cls._TOKEN_IDX):
            class_scores[..., word] += token_scores[..., token]
        return class_scores

    @torch.no_grad()
    def evaluate_one_epoch(self, epoch, test_loader, dataset_config,
                           model, criterion, set_criterion, args):
        """
        Evaluate detection mAP after a single epoch.

        Only the final decoder head (the 'last_' outputs) is evaluated.

        Some of the args:
            epoch: used as the TensorBoard step for the 'Val' logs
            test_loader: iterable of evaluation batches
            dataset_config: a class like ScannetDatasetConfig
            model: a nn.Module that returns end_points (dict)
            criterion: a function that returns (loss, end_points)
            set_criterion: optional module, switched to eval mode if given
            args: parsed options (ap_iou_thresholds, size_cls_agnostic, ...)

        Returns:
            None; metrics are written to self.logger (and TensorBoard on
            rank 0).
        """
        # Used for AP calculation
        config_dict = {
            'remove_empty_box': False, 'use_3d_nms': True,
            'nms_iou': 0.25, 'use_old_type_nms': False, 'cls_nms': True,
            'per_class_proposal': True, 'conf_thresh': 0.0,
            'dataset_config': dataset_config,
            'hungarian_loss': args.use_hungarian_loss
        }
        stat_dict = {}
        model.eval()  # set model to eval mode (for bn and dp)
        if set_criterion is not None:
            set_criterion.eval()

        # Intermediate heads ('proposal_', '{i}head_') are not evaluated.
        prefix = 'last_'
        ap_calculator_list = [
            APCalculator(iou_thresh, dataset_config.class2type)
            for iou_thresh in args.ap_iou_thresholds
        ]

        batch_pred_map_cls_list = []
        batch_gt_map_cls_list = []
        num_batches = 0  # stays 0 for an empty loader (avoids NameError)

        # Main eval branch
        for batch_idx, batch_data in enumerate(test_loader):
            stat_dict, end_points = self._main_eval_branch(
                batch_idx, batch_data, test_loader, model, stat_dict,
                criterion, set_criterion, args, dataset_config
            )
            # Collapse per-token scores to per-class scores before parsing.
            end_points['last_sem_cls_scores'] = self._fold_token_scores(
                end_points['last_sem_cls_scores'])

            batch_pred_map_cls_list.append(parse_predictions(
                end_points, config_dict, prefix,
                size_cls_agnostic=args.size_cls_agnostic))
            batch_gt_map_cls_list.append(parse_groundtruths(
                end_points, config_dict,
                size_cls_agnostic=args.size_cls_agnostic))
            num_batches = batch_idx + 1

        if dist.get_rank() == 0:
            self._tb_logs(stat_dict, epoch, num_batches, 'Val')

        for batch_pred_map_cls, batch_gt_map_cls in zip(batch_pred_map_cls_list,
                                                        batch_gt_map_cls_list):
            for ap_calculator in ap_calculator_list:
                ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)

        # Evaluate average precision
        summary = []  # (iou_thresh, mAP) per threshold
        for iou_thresh, ap_calculator in zip(args.ap_iou_thresholds,
                                             ap_calculator_list):
            metrics_dict = ap_calculator.compute_metrics()
            self.logger.info(f'=====================>{prefix} IOU THRESH: {iou_thresh}<=====================')
            for key, value in metrics_dict.items():
                self.logger.info(f'{key} {value}')
            summary.append((iou_thresh, metrics_dict['mAP']))
            ap_calculator.reset()

        for iou_thresh, map_value in summary:
            self.logger.info(f'IoU[{iou_thresh}]:\t' + f'{prefix}: {map_value:.4f} \t')

        return None


if __name__ == '__main__':
    # Presumably silences the Hugging Face tokenizers fork/parallelism
    # warning once data-loader workers fork; must be set before use.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    options = parse_option()

    # Bind this process to its local GPU, then join the NCCL process group
    # (rendezvous info is read from environment variables).
    torch.cuda.set_device(options.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing,
    # which works against the reproducibility that deterministic=True aims
    # for; kept as-is to preserve current behaviour.
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    TrainTester(options).main(options)
