from inspect import signature

import torch

from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
    DETECTORS,
    BaseDetector,
    build_backbone,
    build_head,
    build_neck,
)
from .grid_mask import GridMask
import time
try:
    from ..ops import feature_maps_format

    DAF_VALID = True
except Exception:
    # The compiled deformable-aggregation ops are optional: if they cannot be
    # imported, fall back and let ``ATDRIVE`` reject
    # ``use_deformable_func=True``. Catch ``Exception`` rather than using a
    # bare ``except`` so ``KeyboardInterrupt``/``SystemExit`` still propagate.
    DAF_VALID = False

__all__ = ["ATDRIVE"]


@DETECTORS.register_module()
class ATDRIVE(BaseDetector):
    """Image-based detector built from a backbone, an optional neck and a head.

    Images are encoded by ``img_backbone`` (followed by ``img_neck`` when
    configured). The per-camera multi-level feature maps are passed to
    ``head``, which produces the raw model outputs, the training losses
    (``head.loss``) and the test-time results (``head.post_process``). An
    optional ``depth_branch`` contributes a dense-depth loss during training.

    Args:
        img_backbone (dict): Config for the image backbone.
        head (dict): Config for the head.
        img_neck (dict, optional): Config for the image neck. When ``None``,
            the backbone outputs are used directly.
        init_cfg (dict, optional): Initialization config for ``BaseDetector``.
        train_cfg (dict, optional): Accepted but not used by this class.
        test_cfg (dict, optional): Accepted but not used by this class.
        pretrained (str, optional): Written into the backbone config as its
            ``pretrained`` entry before the backbone is built.
        use_grid_mask (bool): Apply ``GridMask`` to images before the
            backbone.
        use_deformable_func (bool): Pass feature maps through
            ``feature_maps_format``; requires the optional compiled ops.
        depth_branch (dict, optional): Config of a ``PLUGIN_LAYERS`` module
            that predicts depth from the feature maps.
    """

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super().__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # Previously this referenced an undefined ``backbone`` name and
            # raised NameError whenever ``pretrained`` was given.
            img_backbone["pretrained"] = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # Always define ``img_neck`` (possibly as None): ``extract_feat``
        # tests it, and leaving it unset raised AttributeError for
        # neck-less configs.
        self.img_neck = build_neck(img_neck) if img_neck is not None else None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        self.depth_branch = (
            build_from_cfg(depth_branch, PLUGIN_LAYERS)
            if depth_branch is not None
            else None
        )
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )
        # Visualizer used by ``vis_data``; it is constructed unconditionally.
        from .vis_dataset import Visualizer

        self.visualizer = Visualizer()

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Encode images into per-camera multi-level feature maps.

        Args:
            img (torch.Tensor): Image batch. A 5-D tensor is treated as
                multi-view with the camera axis at dim 1; any other rank is
                treated as a single camera.
            return_depth (bool): Also return the depth-branch output.
            metas (dict, optional): Passed to the backbone when its
                ``forward`` accepts a ``metas`` argument; its ``"focal"``
                entry (if any) is passed to the depth branch.

        Returns:
            Feature maps, each reshaped to ``(bs, num_cams, ...)`` and then
            passed through ``feature_maps_format`` when
            ``use_deformable_func`` is set. When ``return_depth`` is True, a
            ``(feature_maps, depths)`` tuple is returned instead, with
            ``depths`` None if no depth branch is configured.
        """
        bs = img.shape[0]
        if img.dim() == 5:
            # Multi-view input: fold the camera axis into the batch axis.
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = self.img_neck(feature_maps)

        # Backbones/necks may return a tuple of levels, which does not
        # support the item assignment the old loop used; build a new list
        # instead. A bare tensor is treated as a single level.
        if isinstance(feature_maps, torch.Tensor):
            feature_maps = [feature_maps]
        feature_maps = [
            torch.reshape(feat, (bs, num_cams) + feat.shape[1:])
            for feat in feature_maps
        ]

        if return_depth and self.depth_branch is not None:
            focal = metas.get("focal") if metas is not None else None
            depths = self.depth_branch(feature_maps, focal)
        else:
            depths = None
        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        """Dispatch to ``forward_train`` or ``forward_test`` by ``self.training``."""
        if self.training:
            return self.forward_train(img, **data)
        return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Run the training forward pass and compute losses.

        Returns:
            tuple: ``(losses, model_outs)``. ``losses`` is the dict returned
            by ``head.loss``, plus ``"loss_dense_depth"`` when the depth
            branch produced output and ``data`` contains ``"gt_depth"``.
            ``model_outs`` is the raw head output.
        """
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output, model_outs

    def forward_test(self, img, **data):
        """Run inference; a list of images is routed to ``aug_test``."""
        if isinstance(img, list):
            return self.aug_test(img, **data)
        return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Run inference on a single (non-augmented) batch.

        Returns:
            list[dict]: One ``{"img_bbox": result}`` dict per entry returned
            by ``head.post_process``.
        """
        feature_maps = self.extract_feat(img)
        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)
        return [{"img_bbox": result} for result in results]

    def aug_test(self, img, **data):
        """Fake test-time augmentation: evaluate only the first view.

        Every list-valued entry of ``data`` is replaced by its first element,
        and ``img[0]`` is passed to ``simple_test``.
        """
        data = {
            key: value[0] if isinstance(value, list) else value
            for key, value in data.items()
        }
        return self.simple_test(img[0], **data)

    def train_step(self, data, optimizer):
        """Run one training iteration: forward pass plus loss parsing.

        Back-propagation and optimizer updates are not done here.

        Args:
            data (dict): The output of the dataloader. It must contain
                ``"img_metas"``, whose length is reported as ``num_samples``.
            optimizer (:obj:`torch.optim.Optimizer` | dict): Unused; kept
                for interface compatibility.

        Returns:
            tuple: ``(outputs, model_outs)``. ``outputs`` is a dict with
            ``loss`` and ``log_vars`` (from ``_parse_losses``) and
            ``num_samples``. ``model_outs`` is the raw head output returned
            by ``forward_train``.
            NOTE(review): this is a tuple, not the plain dict described by
            the previous docstring; confirm the training runner in use
            unpacks it.
        """
        losses, model_outs = self(**data)
        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data["img_metas"])
        )
        return outputs, model_outs

    def val_step(self, data, optimizer=None):
        """Run one validation iteration and suffix log keys with ``_val``.

        Returns:
            dict: ``loss``, ``log_vars`` (every key suffixed with ``"_val"``)
            and ``num_samples`` (the length of ``data["img_metas"]``). The
            head outputs are discarded.
        """
        # NOTE(review): ``self(**data)`` dispatches on ``self.training``. If
        # this runs with the model in eval mode, ``forward`` returns test
        # results rather than ``(losses, model_outs)`` and the unpacking
        # below will fail. Confirm the caller keeps the model in train mode.
        losses, _ = self(**data)
        loss, log_vars = self._parse_losses(losses)
        log_vars_ = {f"{name}_val": value for name, value in log_vars.items()}
        return dict(
            loss=loss, log_vars=log_vars_, num_samples=len(data["img_metas"])
        )

    def vis_data(self, img, data):
        """Forward ``img`` and ``data`` to the visualizer's ``add_vis``."""
        self.visualizer.add_vis(img, data)
