"""Main script for language modulation."""

import os

import torch
import torch.distributed as dist

from main_utils import parse_option, BaseTrainTester
from scannet.model_util_scannet import ReferitDatasetConfig
from scannet.model_util_scannet import ScannetDatasetConfig
from src.joint_det_dataset import Joint3DDataset
from src.grounding_evaluator import GroundingEvaluator, GroundingEvaluatorGTBoxes
from models import GroupFreeModulator, DeformableGroupFreeModulator, GroupFreeGTModulator
from models import APCalculator, parse_predictions, parse_groundtruths


import ipdb
# Short debugging alias: calling st() anywhere in this module invokes
# ipdb.set_trace and drops into an interactive breakpoint.
st = ipdb.set_trace


class TrainTester(BaseTrainTester):
    """Train/test a language grounder."""

    def __init__(self, args):
        """Set up the train/tester by delegating to the base class."""
        super(TrainTester, self).__init__(args)

    @staticmethod
    def get_datasets(args):
        """
        Build the train and test datasets and pick the dataset config.

        Args:
            args: parsed options. Reads dataset, joint_det, test_dataset,
                debug, eval_train, run_on_target_phrases and the
                point-cloud / box options forwarded to Joint3DDataset.

        Returns:
            tuple: (train_dataset, test_dataset, train_dataset100,
            dataset_config). train_dataset100 is always None here.
        """
        # Dataset name -> integer handed to Joint3DDataset
        # (presumably a sampling/repetition weight - confirm in the dataset).
        dataset_dict = {dset: 1 for dset in args.dataset}
        if args.joint_det:
            dataset_dict['scannet'] = 10
        print('Loading datasets:', sorted(dataset_dict))

        # Options that must stay identical between the train and test
        # datasets; kept in one place so the two cannot drift apart.
        shared_kwargs = dict(
            dataset_dict=dataset_dict,
            test_dataset=args.test_dataset,
            num_points=args.num_point,
            use_color=args.use_color,
            use_height=args.use_height,
            overfit=args.debug,
            detect_intermediate=args.detect_intermediate,
            filter_relations=args.filter_relations,
            use_oriented_boxes=args.use_oriented_boxes,
            augment=not args.no_augment,
            rotate_pc=args.rotate_pc,
            use_multiview=args.use_multiview,
            train_viewpoint_module=args.train_viewpoint_module,
            butd=args.butd,
            butd_gt=args.butd_gt,
            butd_cls=args.butd_cls,
        )

        # In debug mode training runs on the 'val' split
        train_dataset = Joint3DDataset(
            split='train' if not args.debug else 'val',
            **shared_kwargs
        )
        # With eval_train set, evaluation runs on the 'train' split
        test_dataset = Joint3DDataset(
            split='val' if not args.eval_train else 'train',
            run_on_target_phrases=args.run_on_target_phrases,
            **shared_kwargs
        )
        # The 'train100' subset is not built; callers still get the slot
        train_dataset100 = None

        # Detection-only runs on scannet use the Scannet class config,
        # every other dataset combination uses the Referit one.
        if args.dataset == ['scannet']:
            dataset_config = ScannetDatasetConfig()
        else:
            dataset_config = ReferitDatasetConfig()
        return train_dataset, test_dataset, train_dataset100, dataset_config

    @staticmethod
    def get_model(args, dataset_config):
        """
        Build the grounding network selected by the options.

        args.deformable picks DeformableGroupFreeModulator, otherwise
        args.use_gt_grounder picks GroupFreeGTModulator, otherwise the
        default GroupFreeModulator is returned.
        """
        # Extra per-point input channels: 3 for color, 1 for height,
        # 128 for multiview features.
        num_input_channel = int(args.use_color) * 3
        num_input_channel += 1 if args.use_height else 0
        num_input_channel += 128 if args.use_multiview else 0

        # The soft-token loss uses a fixed 256-way output
        num_class = (
            256 if args.use_soft_token_loss else dataset_config.num_class
        )

        if args.deformable:
            print("Loading Deformable Attention Model")

        # Keyword arguments shared by all three model classes
        common = dict(
            num_heading_bin=dataset_config.num_heading_bin,
            num_size_cluster=dataset_config.num_size_cluster,
            mean_size_arr=dataset_config.mean_size_arr,
            input_feature_dim=num_input_channel,
            width=args.width,
            bn_momentum=args.bn_momentum,
            sync_bn=args.syncbn,
            num_proposal=args.num_target,
            sampling=args.sampling,
            dropout=args.transformer_dropout,
            activation=args.transformer_activation,
            nhead=args.nhead,
            num_decoder_layers=args.num_decoder_layers,
            dim_feedforward=args.dim_feedforward,
            self_position_embedding=args.self_position_embedding,
            size_cls_agnostic=args.size_cls_agnostic,
        )

        if args.deformable:
            return DeformableGroupFreeModulator(
                num_class,
                cross_position_embedding=args.cross_position_embedding,
                num_encoder_layers=args.num_encoder_layers,
                **common
            )

        # Keyword arguments shared by the two non-deformable grounders
        grounder = dict(
            contrastive_align_loss=args.use_contrastive_align,
            contrastive_hungarian=args.contrastive_hungarian,
            sa_lang=not args.no_sa_lang,
            sa_vis=not args.no_sa_vis,
            use_gt_box=args.use_gt_box,
            use_gt_class=args.use_gt_class,
            num_obj_classes=485,
            gt_with_bbox_loss=args.gt_with_bbox_loss,
            gt_with_bbox_sampling=args.gt_with_bbox_sampling,
        )

        if args.use_gt_grounder:
            return GroupFreeGTModulator(
                num_class,
                freeze_text_encoder=args.freeze_text_encoder,
                use_logits=args.use_logits,
                use_oriented_boxes=args.use_oriented_boxes,
                use_detected_boxes=args.use_detected_boxes,
                **common,
                **grounder
            )

        return GroupFreeModulator(
            num_class,
            train_viewpoint_module=args.train_viewpoint_module,
            train_viewpoint_prototype=args.train_viewpoint_prototype,
            teacher_forcing=args.teacher_forcing,
            # Any of the three butd flags enables the butd stream
            butd=args.butd or args.butd_gt or args.butd_cls,
            use_class_for_butd=not args.box_only_butd,
            **common,
            **grounder
        )

    @staticmethod
    def _get_inputs(args, batch_data):
        """
        Assemble the model input dict from a collated batch.

        When args.use_detected_boxes is set, the box / mask / class
        entries are taken from the detector outputs in the batch instead
        of the annotated ones.
        """
        # Pick the source keys once; only the selected ones are read
        if args.use_detected_boxes:
            box_key = 'all_detected_boxes'
            mask_key = 'all_detected_bbox_label_mask'
            class_key = 'all_detected_class_ids'
        else:
            box_key = 'all_bboxes'
            mask_key = 'all_bbox_label_mask'
            class_key = 'sem_cls_label'

        inputs = {}
        inputs['point_clouds'] = batch_data['point_clouds'].float()
        inputs['text'] = batch_data['utterances']
        inputs['all_bboxes'] = batch_data[box_key]
        inputs['all_bbox_label_mask'] = batch_data[mask_key]
        inputs['all_classes'] = batch_data[class_key]
        # Entries forwarded unchanged under the same key
        for name in (
            'positive_map', 'target_name', 'center_label', 'size_gts',
            'points_to_boxes', 'all_detected_boxes',
            'all_detected_bbox_label_mask', 'all_detected_class_ids'
        ):
            inputs[name] = batch_data[name]
        return inputs

    @torch.no_grad()
    def evaluate_one_epoch(self, epoch, test_loader, dataset_config,
                           model, criterion, set_criterion, viewpoint_criterion, args):
        """
        Eval grounding after a single epoch.

        Args:
            epoch: epoch index, used as the step for the 'Val' logs
            test_loader: iterable of evaluation batches
            dataset_config: a class like ReferitDatasetConfig
            model: a nn.Module that returns end_points (dict)
            criterion: a function that returns (loss, end_points)
            set_criterion: optional module, switched to eval mode
            viewpoint_criterion: optional module, switched to eval mode
            args: parsed options (num_decoder_layers, use_gt_grounder,
                ce_variant, use_gt_box, use_detected_boxes, visualize, ...)

        Returns:
            None. Statistics are logged/printed by rank 0 only.
        """
        stat_dict = {}
        model.eval()  # set model to eval mode (for bn and dp)
        if set_criterion is not None:
            set_criterion.eval()

        if viewpoint_criterion is not None:
            viewpoint_criterion.eval()

        # Output heads to score: the last decoder layer, the proposal
        # stage (not for the GT grounder) and the intermediate layers.
        if args.num_decoder_layers > 0:
            prefixes = ['last_']
            if not args.use_gt_grounder:
                prefixes.append('proposal_')
        else:
            prefixes = ['proposal_']  # only proposal
        prefixes += [f'{i}head_' for i in range(args.num_decoder_layers - 1)]

        if args.ce_variant:
            evaluator = GroundingEvaluatorGTBoxes(
                only_root=not args.eval_anchors,
                topks=[1, 5, 10], prefixes=prefixes, gt_variant=True,
                apply_classifiers=args.apply_classifiers
            )
        elif args.use_gt_box and not args.use_detected_boxes:
            print('GT EVAL')
            evaluator = GroundingEvaluatorGTBoxes(
                only_root=not args.eval_anchors,
                topks=[1, 5, 10], prefixes=prefixes,
                apply_classifiers=args.apply_classifiers
            )
        else:
            evaluator = GroundingEvaluator(
                only_root=not args.eval_anchors, thresholds=[0.25, 0.5],
                topks=[1, 5, 10], prefixes=prefixes,
                use_detected_boxes=args.use_detected_boxes,
                evaluate_viewpoint=args.train_viewpoint_module,
                train_viewpoint_prototype=args.train_viewpoint_prototype,
                run_on_target_phrases=args.run_on_target_phrases
            )

        # Counters for the plain-accuracy fallback that runs when
        # `evaluator` is set to None; defined for every configuration so
        # the fallback never hits an undefined name.
        acc = {prefix: 0 for prefix in prefixes}
        nsamples = 0
        num_batches = 0  # stays 0 if the loader yields nothing

        # Main eval branch
        for batch_idx, batch_data in enumerate(test_loader):
            stat_dict, end_points = self._main_eval_branch(
                batch_idx, batch_data, test_loader, model, stat_dict,
                criterion, set_criterion, viewpoint_criterion, args, dataset_config
            )
            num_batches = batch_idx + 1
            if evaluator is not None:
                for prefix in prefixes:
                    evaluator.evaluate(end_points, prefix)
            else:
                for prefix in prefixes:
                    if prefix == 'proposal_':
                        continue
                    acc[prefix] += (
                        end_points[f'{prefix}sem_cls_scores'].argmax(1)
                        == end_points['target_id']
                    ).sum().item()
                nsamples += len(end_points['target_id'])
            if args.visualize:
                self._viz_dets(batch_idx, end_points)

        # Only rank 0 reports
        if dist.get_rank() == 0:
            if num_batches > 0:
                # Nothing was accumulated for an empty loader, so skip logs
                self._tb_logs(stat_dict, epoch, num_batches, 'Val')
            if evaluator is not None:
                evaluator.print_stats()
            else:
                for prefix in prefixes:
                    # max(..., 1) avoids dividing by zero on empty data
                    print(prefix, acc[prefix] / max(nsamples, 1))
        return None

    @torch.no_grad()
    def evaluate_one_epoch_old(self, epoch, test_loader, dataset_config,
                               model, criterion, set_criterion, viewpoint_criterion, args):
        """
        Eval detection average precision after a single epoch.

        Args:
            epoch: epoch index, used as the step for the 'Val' logs
            test_loader: iterable of evaluation batches
            dataset_config: a class like ReferitDatasetConfig
            model: a nn.Module that returns end_points (dict)
            criterion: a function that returns (loss, end_points)
            set_criterion: optional module, switched to eval mode
            viewpoint_criterion: optional module, switched to eval mode
            args: parsed options (num_decoder_layers, ap_iou_thresholds,
                size_cls_agnostic, use_hungarian_loss, ...)

        Returns:
            None. Per-prefix metrics are written through self.logger.
        """
        # Used for AP calculation
        CONFIG_DICT = {
            'remove_empty_box': False, 'use_3d_nms': True,
            'nms_iou': 0.25, 'use_old_type_nms': False, 'cls_nms': True,
            'per_class_proposal': True, 'conf_thresh': 0.0,
            'dataset_config': dataset_config,
            'hungarian_loss': args.use_hungarian_loss
        }
        stat_dict = {}
        model.eval()  # set model to eval mode (for bn and dp)
        if set_criterion is not None:
            set_criterion.eval()

        if viewpoint_criterion is not None:
            viewpoint_criterion.eval()

        if args.num_decoder_layers > 0:
            prefixes = ['last_', 'proposal_']
            prefixes += [
                f'{i}head_' for i in range(args.num_decoder_layers - 1)
            ]
        else:
            prefixes = ['proposal_']  # only proposal

        # One calculator per IoU threshold, reused (and reset) per prefix
        ap_calculator_list = [
            APCalculator(iou_thresh, dataset_config.class2type)
            for iou_thresh in args.ap_iou_thresholds
        ]

        batch_pred_map_cls_dict = {k: [] for k in prefixes}
        batch_gt_map_cls_dict = {k: [] for k in prefixes}
        num_batches = 0  # stays 0 if the loader yields nothing

        # Main eval branch
        for batch_idx, batch_data in enumerate(test_loader):
            # viewpoint_criterion is forwarded so this call matches the
            # argument order used by evaluate_one_epoch.
            stat_dict, end_points = self._main_eval_branch(
                batch_idx, batch_data, test_loader, model, stat_dict,
                criterion, set_criterion, viewpoint_criterion, args, dataset_config
            )
            num_batches = batch_idx + 1

            # Parse predictions
            for prefix in prefixes:
                batch_pred_map_cls = parse_predictions(
                    end_points, CONFIG_DICT, prefix,
                    size_cls_agnostic=args.size_cls_agnostic)
                batch_gt_map_cls = parse_groundtruths(
                    end_points, CONFIG_DICT,
                    size_cls_agnostic=args.size_cls_agnostic)
                batch_pred_map_cls_dict[prefix].append(batch_pred_map_cls)
                batch_gt_map_cls_dict[prefix].append(batch_gt_map_cls)

        if dist.get_rank() == 0 and num_batches > 0:
            # Nothing was accumulated for an empty loader, so skip logs
            self._tb_logs(stat_dict, epoch, num_batches, 'Val')

        for prefix in prefixes:
            # Feed every stored batch of this prefix to each calculator
            for (batch_pred_map_cls, batch_gt_map_cls) in zip(
                batch_pred_map_cls_dict[prefix], batch_gt_map_cls_dict[prefix]
            ):
                for ap_calculator in ap_calculator_list:
                    ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)
            # Evaluate average precision
            for i, ap_calculator in enumerate(ap_calculator_list):
                metrics_list = ap_calculator.compute_accuracy()
                self.logger.info(
                    '=====================>%s Accuracy @ %s: %s @IoU: %s'
                    % (
                        prefix,
                        ', '.join(str(key) for key in metrics_list.keys()),
                        ', '.join(str(val) for val in metrics_list.values()),
                        args.ap_iou_thresholds[i]
                    )
                )
                # Clear state before the next prefix reuses the calculator
                ap_calculator.reset()

        return None

    @torch.no_grad()
    def export_viewpoints(self, epoch, loader, dataset_config,
                          model, criterion, set_criterion,
                          viewpoint_criterion, args, split='val'):
        """
        Run the model over a loader and collect predicted viewpoints.

        Args:
            epoch: unused here, kept for interface compatibility
            loader: iterable of batches
            dataset_config: forwarded to self._compute_loss
            model: a nn.Module that returns end_points (dict)
            criterion, set_criterion, viewpoint_criterion: forwarded to
                self._compute_loss
            args: parsed options
            split: split name used in the output directory name

        Returns:
            None.
        """
        # Output directory is per split; create it once, parents included
        demo_dir = f'./dataset/language_grounding/viewpoints_data_{split}'
        os.makedirs(demo_dir, exist_ok=True)

        for batch_data in loader:
            # Move to GPU
            batch_data = self._to_gpu(batch_data)
            inputs = self._get_inputs(args, batch_data)

            # Forward pass
            end_points = model(inputs)

            # Merge the batch into end_points; keys must not collide
            for key in batch_data:
                assert (key not in end_points)
                end_points[key] = batch_data[key]

            # Compute loss
            _, end_points = self._compute_loss(
                end_points, criterion, set_criterion, viewpoint_criterion, args, dataset_config
            )

            # Without a prediction there is nothing to export for this
            # batch; skipping avoids reusing a value from an earlier batch.
            if 'last_pred_viewpoint' not in end_points:
                continue
            pred_viewpoint = (
                end_points['last_pred_viewpoint'].detach().cpu().numpy()
            )

            # NOTE(review): data_dict is assembled but never written to
            # demo_dir in this method - persisting it is still TODO.
            data_dict = {
                "scan_id": end_points['scan_ids'],
                "pred_viewpoint": pred_viewpoint,
            }


if __name__ == '__main__':
    # Set first, before options are parsed or any model is built
    # (presumably silences the tokenizers fork-parallelism warning - confirm).
    os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
    opt = parse_option()

    # Bind this process to its local GPU, then join the NCCL process group
    # configured through environment variables.
    torch.cuda.set_device(opt.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    # cuDNN switches
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    train_tester = TrainTester(opt)
    ckpt_path = train_tester.main(opt)
