import torch
import torch.nn as nn
from mmdet3d.models.detectors.two_stage_fsd import FSD
from mmcv import Config
from mmdet3d.models import build_model
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from collections import OrderedDict
import torch.distributed as dist
from opencood.models.fuse_modules.cluster_fusion import ClusterFusion
from opencood.visualization.cppc_vis import VisUtil

class FSDIntermediate(nn.Module):
    def __init__(self, args):
        """Build the wrapped detector and attach the cluster fusion module.

        Args:
            args (dict): Model arguments. Keys read here:
                - ``detector_cfg`` (required): path handed to
                  ``Config.fromfile``; its ``model`` entry (plus optional
                  ``train_cfg`` / ``test_cfg``) is passed to ``build_model``.
                - ``fusion_cfg`` (optional, default ``{}``): config passed
                  to ``ClusterFusion``.
                - ``use_single_label`` (optional, default ``False``).
                - ``proj_first`` (optional, default ``False``).
        """
        super().__init__()

        cfg = Config.fromfile(args['detector_cfg'])
        self.detector_cfg = cfg
        self.detector = build_model(
            cfg.model,
            train_cfg=cfg.get('train_cfg'),
            test_cfg=cfg.get('test_cfg'),
        )

        # The fusion module is attached to the detector after it is built.
        fusion_cfg = args.get('fusion_cfg', {})
        self.detector.fusion_module = ClusterFusion(fusion_cfg)

        # Flags consulted by get_car_level_gt when choosing the label source.
        self.use_single_label = args.get('use_single_label', False)
        self.proj_first = args.get('proj_first', False)
    
    # Copied from mmdet3d.models.detector.base.py (line 176).
    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                if 'loss' in _key)

        # If the loss_vars has different length, GPUs will wait infinitely
        if dist.is_available() and dist.is_initialized():
            log_var_length = torch.tensor(len(log_vars), device=loss.device)
            dist.all_reduce(log_var_length)
            message = (f'rank {dist.get_rank()}' +
                    f' len(log_vars): {len(log_vars)}' + ' keys: ' +
                    ','.join(log_vars.keys()))
            assert log_var_length == len(log_vars) * dist.get_world_size(), \
                'loss log variables are different across GPUs!\n' + message

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
    
    def get_scene_level_gt(self, data_dict):
        """Collect scene-level ground-truth boxes and labels, one entry per sample.

        Args:
            data_dict (dict): Batch dict. Reads ``record_len`` (only its
                length, used as the batch size), ``object_bbx_center`` and
                ``object_bbx_mask``.

        Returns:
            tuple[list, list]: ``(gt_bboxes_3d, gt_labels_3d)``. For sample
                ``i`` the boxes are the rows of ``object_bbx_center[i]``
                whose mask is ``> 0``, wrapped in ``LiDARInstance3DBoxes``
                with origin ``(0.5, 0.5, 0.5)``; the labels are an all-zero
                ``long`` tensor with ``mask.sum()`` entries on the mask's
                device.
        """
        num_scenes = len(data_dict['record_len'])

        gt_labels_3d = []
        gt_bboxes_3d = []
        for scene in range(num_scenes):
            valid = data_dict['object_bbx_mask'][scene]

            # Every ground-truth object gets label 0.
            num_valid = int(valid.sum().item())
            gt_labels_3d.append(
                torch.zeros(num_valid, dtype=torch.long).to(valid.device))

            kept_centers = data_dict['object_bbx_center'][scene][valid > 0, :]
            gt_bboxes_3d.append(
                LiDARInstance3DBoxes(kept_centers, origin=(0.5, 0.5, 0.5)))

        return gt_bboxes_3d, gt_labels_3d

    def get_car_level_gt(self, data_dict):
        """Get car-level (per-agent) ground-truth boxes and labels.

        The labels are either loaded directly or re-mapped from the
        scene-level boxes, depending on which keys ``data_dict`` provides
        (checked in this order):

        1. ``object_bbx_center_single_v`` and ``object_bbx_center_single_i``
           are present and ``self.use_single_label or not self.proj_first``:
           per-sample vehicle ("v") then infrastructure ("i") boxes are
           used (set for DAIR-V2X).
        2. ``car_object_bbx_center`` is present: one box set per entry
           (set for V2X-Set when not projecting first).
        3. ``object_bbx_idx`` is present: scene-level boxes are selected per
           agent through the index mask (set for V2X-Set when projecting).
        4. Otherwise no car-level ground truth is available.

        Args:
            data_dict (dict): Batch dict. Always reads
                ``processed_lidar['points']`` and ``record_len``; other keys
                as listed above.

        Returns:
            tuple: ``(gt_bboxes_3d_cars, gt_labels_3d_cars)``, two lists with
                one entry per agent, or ``(None, None)`` when no car-level
                ground truth is available.

        Raises:
            AssertionError: If the number of box/label entries does not match
                ``len(points)`` (the check is skipped when
                ``object_bbx_center_single_v`` is present and
                ``len(points) > 2``).
        """
        points = data_dict['processed_lidar']['points']
        record_len = data_dict['record_len']
        bs = len(record_len)
        box_origin = (0.5, 0.5, 0.5)

        has_single_labels = ('object_bbx_center_single_v' in data_dict
                             and 'object_bbx_center_single_i' in data_dict)

        if has_single_labels and (self.use_single_label or not self.proj_first):
            # Original car-level gt boxes (loaded from the ego json) when not
            # projecting first (set for DAIR-V2X).
            single_keys = (
                ('object_bbx_center_single_v', 'object_bbx_mask_single_v'),
                ('object_bbx_center_single_i', 'object_bbx_mask_single_i'),
            )
            gt_bboxes_3d_cars = []
            gt_labels_3d_cars = []
            for i in range(bs):
                label_device = data_dict['object_bbx_mask'][i].device
                # Vehicle entry first, then infrastructure entry.
                for center_key, mask_key in single_keys:
                    mask = data_dict[mask_key][i].bool()
                    gt_bboxes_3d_cars.append(LiDARInstance3DBoxes(
                        data_dict[center_key][i][mask, :], origin=box_origin))
                    gt_labels_3d_cars.append(torch.zeros(
                        int(mask.sum().item()), dtype=torch.long).to(label_device))
        elif 'car_object_bbx_center' in data_dict:
            # Original car-level gt boxes when not projecting first
            # (set for V2X-Set).
            car_centers = data_dict['car_object_bbx_center']
            gt_bboxes_3d_cars = [
                LiDARInstance3DBoxes(car_centers[i], origin=box_origin)
                for i in range(len(car_centers))
            ]
            # Each label vector lives on the device of its own box set.
            gt_labels_3d_cars = [
                torch.zeros(car_centers[i].shape[0], dtype=torch.long).to(
                    gt_bboxes_3d_cars[i].device)
                for i in range(len(gt_bboxes_3d_cars))
            ]
        elif 'object_bbx_idx' in data_dict:
            # Re-map scene-level boxes to car-level boxes through
            # object_bbx_idx when projecting first (set for V2X-Set).
            per_scene_idx = [
                data_dict['object_bbx_idx'][i][:, :record_len[i]].transpose(0, 1)
                for i in range(bs)
            ]
            gt_bboxes_3d_cars = [
                LiDARInstance3DBoxes(
                    data_dict['object_bbx_center'][i][per_scene_idx[i][j] > 0, :],
                    origin=box_origin)
                for i in range(bs) for j in range(record_len[i])
            ]
            # [car_num*bs, max_object_num], boolean
            all_idx = torch.cat(per_scene_idx, dim=0)
            gt_labels_3d_cars = [
                torch.zeros(int(all_idx[k].sum().item()), dtype=torch.long).to(
                    all_idx.device)
                for k in range(all_idx.shape[0])
            ]
        else:
            # No car-level ground truth: return early instead of failing in
            # the length checks below with ``len(None)``.
            return None, None

        # NOTE(review): the length check is skipped for batches that carry
        # single-agent labels and more than two point clouds; the reason is
        # not visible in this file -- confirm against the data loader.
        if 'object_bbx_center_single_v' in data_dict and len(points) > 2:
            return gt_bboxes_3d_cars, gt_labels_3d_cars

        assert len(gt_labels_3d_cars) == len(points), (
            f'got {len(gt_labels_3d_cars)} car-level label entries '
            f'for {len(points)} point clouds')
        assert len(gt_bboxes_3d_cars) == len(points), (
            f'got {len(gt_bboxes_3d_cars)} car-level box entries '
            f'for {len(points)} point clouds')

        return gt_bboxes_3d_cars, gt_labels_3d_cars
    
    def forward(self, data_dict, return_loss=True, **kwargs):
        """Run the wrapped detector in training or inference mode.

        Builds scene-level and car-level ground truth, expands the per-sample
        meta information to one dict per agent, configures ``VisUtil`` and
        then dispatches to the detector.

        Args:
            data_dict (dict): Batch dict. Reads
                ``processed_lidar['points']``, ``record_len``,
                ``pairwise_t_matrix``, ``lidar_pose``, ``lidar_pose_clean``
                and, when present, ``metas``, ``frame_ids`` and
                ``time_delay``, plus the keys used by the ground-truth
                helpers.
            return_loss (bool): ``True`` to call ``forward_train`` and return
                losses, ``False`` to call ``simple_test``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict | object: With ``return_loss=True`` a dict with ``loss``,
                ``log_vars`` and ``num_samples`` (number of agents in the
                batch); otherwise whatever ``detector.simple_test`` returns.
        """
        points = data_dict['processed_lidar']['points']  # one entry per agent over the whole batch
        record_len = data_dict['record_len']  # number of agents per sample
        bs = len(record_len)

        gt_bboxes_3d, gt_labels_3d = self.get_scene_level_gt(data_dict)
        gt_bboxes_3d_cars, gt_labels_3d_cars = self.get_car_level_gt(data_dict)

        # Scene-level metas: one dict per sample, sharing the base entries.
        img_metas = {'box_type_3d': LiDARInstance3DBoxes, 'record_len': record_len}
        extra_metas = data_dict.get('metas', None)
        if extra_metas is not None:
            img_metas.update(extra_metas)
        img_meta_list = [img_metas.copy() for _ in range(bs)]

        has_frame_ids = 'frame_ids' in data_dict
        if has_frame_ids:
            # only used in DAIR-V2X dataset
            for img_meta, frame_id in zip(img_meta_list, data_dict['frame_ids']):
                img_meta['frame_id'] = frame_id

        # Extend the metas from scene level to car level: every agent of a
        # sample gets its own (shallow) copy of that sample's dict.
        img_meta_list = [img_meta_list[i].copy()
                         for i in range(bs) for _ in range(record_len[i])]

        # Attach the per-agent pose information. ``car_idx`` walks the flat
        # per-agent list in the same (sample, agent) order used above.
        car_idx = 0
        for i in range(bs):
            for j in range(record_len[i]):
                # default in each scene, idx 0 is ego car
                try:
                    time_delay = data_dict['time_delay'][i][j]
                except (KeyError, IndexError, TypeError):
                    # No usable time delay for this agent: fall back to 0.
                    time_delay = 0
                img_meta_list[car_idx].update({
                    'proj2ego_matrix': data_dict['pairwise_t_matrix'][i, j, 0],
                    'lidar_pose': data_dict['lidar_pose'][i][j],
                    'lidar_pose_clean': data_dict['lidar_pose_clean'][i][j],
                    'time_delay': time_delay,
                })
                car_idx += 1

        # A bare VoteSegmentor is itself the segmentor, so the range is read
        # from the detector directly; this module has no point_cloud_range
        # attribute of its own.
        segmentor_only = self.detector.__class__.__name__ == 'VoteSegmentor'
        if segmentor_only:
            lidar_range = self.detector.point_cloud_range
        else:
            lidar_range = self.detector.segmentor.point_cloud_range

        # NOTE(review): 'batch_idx' is only present if data_dict['metas']
        # provides it; a batch without it raises KeyError here -- confirm
        # the data loader always supplies it.
        VisUtil.set_args(
            frame_id=img_meta_list[0]['frame_id'] if has_frame_ids else '',
            idx=img_meta_list[0]['batch_idx'],
            lidar_range=lidar_range,
        )

        if not return_loss:
            return self.detector.simple_test(
                points, img_meta_list, rescale=True,
                gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)

        if segmentor_only:
            # pretrain segmentor without collaboration
            losses = self.detector.forward_train(
                points, img_meta_list,
                gt_bboxes_3d_cars, gt_labels_3d_cars,
            )
        else:
            # end-to-end training of the whole pipeline
            losses = self.detector.forward_train(
                points, img_meta_list,
                gt_bboxes_3d, gt_labels_3d,
                gt_bboxes_3d_cars, gt_labels_3d_cars,
            )

        loss, log_vars = self._parse_losses(losses)
        return dict(
            loss=loss, log_vars=log_vars, num_samples=len(img_meta_list))
            
       