import torch
import torch.nn as nn
from mmdet3d.models.detectors.two_stage_fsd import FSD
from mmcv import Config
from mmdet3d.models import build_model
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from collections import OrderedDict
import torch.distributed as dist

class FSDEarly(nn.Module):
    """Drive an mmdet3d detector from a single batch dictionary.

    The detector is built from the mmcv config file named by
    ``args['detector_cfg']``.  ``forward`` translates the incoming
    ``data_dict`` into the ``points`` / ``img_metas`` / ground-truth arguments
    handed to the detector's ``forward_train`` or ``simple_test``.
    """

    def __init__(self, args):
        """Build the wrapped detector.

        Args:
            args (dict): Must contain ``'detector_cfg'``, the path of an mmcv
                config file whose ``model`` entry describes the detector.
                Optional top-level ``train_cfg`` / ``test_cfg`` entries of
                that config are forwarded to ``build_model``.
        """
        super().__init__()
        self.detector_cfg = Config.fromfile(args['detector_cfg'])
        self.detector = build_model(
            self.detector_cfg.model,
            train_cfg=self.detector_cfg.get('train_cfg'),
            test_cfg=self.detector_cfg.get('test_cfg'))

    # copy from mmdet3d.models.detector.base.py, 176-th row
    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information. Every value must be
                a tensor or a list of tensors.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger (as Python \
                scalars, averaged over ranks when distributed).

        Raises:
            TypeError: If a value is neither a tensor nor a list of tensors.
            ValueError: If no entry whose name contains ``'loss'`` yields a
                tensor, i.e. there is nothing to optimise.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Only entries whose name contains 'loss' are optimised; the others
        # (recalls, counts, ...) are logged but do not contribute gradients.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)
        if not isinstance(loss, torch.Tensor):
            # ``sum`` over nothing returns the int 0, which would otherwise
            # fail further down with an opaque AttributeError.
            raise ValueError(
                "no tensor entry with 'loss' in its name was found; "
                'got keys: ' + ','.join(log_vars.keys()))

        # If the loss_vars has different length, GPUs will wait infinitely
        if dist.is_available() and dist.is_initialized():
            log_var_length = torch.tensor(len(log_vars), device=loss.device)
            dist.all_reduce(log_var_length)
            message = (f'rank {dist.get_rank()}' +
                       f' len(log_vars): {len(log_vars)}' + ' keys: ' +
                       ','.join(log_vars.keys()))
            assert log_var_length == len(log_vars) * dist.get_world_size(), \
                'loss log variables are different across GPUs!\n' + message

        log_vars['loss'] = loss
        # Values of existing keys are replaced in place (tensor -> float);
        # no key is added or removed, so iterating while assigning is safe.
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def _build_ground_truth(self, data_dict, batch_size):
        """Build the per-sample ground-truth boxes and labels.

        Both lists are derived from the same boolean mask, so the number of
        labels always equals the number of boxes of a sample.

        Args:
            data_dict (dict): Batch dictionary providing
                ``'object_bbx_center'`` and ``'object_bbx_mask'``, each
                indexable by sample.
            batch_size (int): Number of samples in the batch.

        Returns:
            tuple[list, list]: ``(gt_bboxes_3d, gt_labels_3d)``, one
                ``LiDARInstance3DBoxes`` and one ``torch.long`` zero tensor
                per sample.
        """
        gt_bboxes_3d = []
        gt_labels_3d = []
        for i in range(batch_size):
            valid = data_dict['object_bbx_mask'][i] > 0
            # origin=(0.5, 0.5, 0.5): presumably the stored box centres are
            # geometric centres rather than bottom centres -- see the
            # mmdet3d box documentation.
            gt_bboxes_3d.append(LiDARInstance3DBoxes(
                data_dict['object_bbx_center'][i][valid, :],
                origin=(0.5, 0.5, 0.5)))
            # Every box gets class index 0; allocate directly on the mask's
            # device instead of creating on CPU and copying.
            gt_labels_3d.append(torch.zeros(
                int(valid.sum().item()),
                dtype=torch.long,
                device=valid.device))
        return gt_bboxes_3d, gt_labels_3d

    def _build_img_metas(self, data_dict, batch_size):
        """Build the ``img_metas`` list: one independent dict per sample.

        Args:
            data_dict (dict): Batch dictionary; an optional ``'metas'``
                mapping is merged into every per-sample dict (its entries
                override the default ``'box_type_3d'``).
            batch_size (int): Number of samples in the batch.

        Returns:
            list[dict]: ``batch_size`` shallow copies of the meta dict, so a
                per-sample key written downstream does not leak to the other
                samples of the batch.
        """
        base_meta = {'box_type_3d': LiDARInstance3DBoxes}
        extra_meta = data_dict.get('metas', None)
        if extra_meta is not None:
            base_meta.update(extra_meta.items())
        return [dict(base_meta) for _ in range(batch_size)]

    def forward(self, data_dict, return_loss=True, **kwargs):
        """Run the wrapped detector in training or inference mode.

        Args:
            data_dict (dict): Batch dictionary. Always read:
                ``data_dict['processed_lidar']['points']`` (its ``len`` is
                taken as the batch size) and the optional ``'metas'``
                mapping. Read only when ``return_loss`` is true:
                ``'object_bbx_center'`` and ``'object_bbx_mask'``.
            return_loss (bool): If true, call ``detector.forward_train`` and
                return parsed losses; otherwise call
                ``detector.simple_test``.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            dict | object: When ``return_loss`` is true, a dict with keys
                ``'loss'`` (tensor), ``'log_vars'`` (dict of scalars) and
                ``'num_samples'`` (batch size). Otherwise whatever
                ``detector.simple_test(points, img_metas, rescale=True)``
                returns.
        """
        points = data_dict['processed_lidar']['points']
        bs = len(points)

        # 'img_metas' is a list of length == batch size whose elements are
        # dicts.
        img_meta_list = self._build_img_metas(data_dict, bs)

        if not return_loss:
            # Ground truth is not needed (and not built) for inference.
            return self.detector.simple_test(
                points, img_meta_list, rescale=True)

        gt_bboxes_3d, gt_labels_3d = self._build_ground_truth(data_dict, bs)
        losses = self.detector.forward_train(
            points,
            img_meta_list,
            gt_bboxes_3d,
            gt_labels_3d
        )
        # Example of the keys returned by the detector:
        # dict_keys(['loss_sem_seg', 'loss_vote', 'recall_Car', 'num_clusters',
        # 'num_fg_points', 'loss_cls.task0', 'loss_center.task0',
        # 'loss_size.task0', 'loss_rot.task0', 'loss_rcnn_cls', 'num_pos_rois',
        # 'num_neg_rois', 'loss_rcnn_bbox', 'loss_rcnn_corner'])

        # copy from mmdet3d.models.detector.base.py, 249-th row
        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(img_meta_list))
        return outputs
            
       