from inspect import signature

import torch

from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
    DETECTORS,
    BaseDetector,
    build_backbone,
    build_head,
    build_neck,
)
from .grid_mask import GridMask
import time
# Optional dependency: the ``feature_maps_format`` op from the sibling ``ops``
# package. ``DAF_VALID`` records whether that import succeeded so that the
# detector can refuse ``use_deformable_func=True`` when the op is missing.
try:
    from ..ops import feature_maps_format
    DAF_VALID = True
except Exception:
    # Best-effort fallback for any import/extension-load failure, without
    # swallowing KeyboardInterrupt / SystemExit as a bare ``except:`` would.
    DAF_VALID = False

__all__ = ["ADDI"]


@DETECTORS.register_module()
class ADDI(BaseDetector):
    """Detector composed of an image backbone, an optional neck and a head.

    Images are passed through the backbone (and the neck, when configured);
    the resulting feature maps are reshaped to ``(bs, num_cams, ...)`` and
    handed to the head. An optional depth branch can produce an additional
    dense-depth loss during training.

    Args:
        img_backbone: Config passed to ``build_backbone``.
        head: Config passed to ``build_head``.
        img_neck: Optional config passed to ``build_neck``.
        init_cfg: Forwarded to ``BaseDetector.__init__``.
        train_cfg: Accepted for config compatibility; not used by this class.
        test_cfg: Accepted for config compatibility; not used by this class.
        pretrained: If not None, stored under the ``"pretrained"`` key of
            ``img_backbone`` before the backbone is built.
        use_grid_mask: Whether to run ``GridMask`` on the input images.
        use_deformable_func: Whether to convert feature maps with
            ``feature_maps_format``; requires the optional ``ops`` import to
            have succeeded (``DAF_VALID``).
        depth_branch: Optional config built from the ``PLUGIN_LAYERS``
            registry.
    """

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super(ADDI, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # Previously this assigned to an undefined name ``backbone`` and
            # raised NameError. Item assignment works for both plain dicts
            # and attribute-style config dicts.
            img_backbone["pretrained"] = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # Always define ``img_neck`` so ``extract_feat`` can test it for None
        # instead of hitting AttributeError when no neck is configured.
        self.img_neck = build_neck(img_neck) if img_neck is not None else None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Compute per-camera feature maps (and optionally depths).

        Args:
            img: Image tensor. A 5-D tensor is treated as multi-view and its
                first two dimensions are flattened together before the
                backbone; otherwise a single camera is assumed.
            return_depth: If True, return a ``(feature_maps, depths)`` tuple.
            metas: Optional mapping. It is forwarded to the backbone when its
                ``forward`` accepts a ``metas`` argument, and its ``"focal"``
                entry (if any) is passed to the depth branch.

        Returns:
            The feature maps, each reshaped to ``(bs, num_cams, ...)`` (or
            whatever ``feature_maps_format`` returns when
            ``use_deformable_func`` is set). When ``return_depth`` is True, a
            tuple ``(feature_maps, depths)`` where ``depths`` is None if no
            depth branch is configured.
        """
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        # Some backbones take the camera count and metas as extra inputs.
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)

        if self.img_neck is not None:
            feature_maps = self.img_neck(feature_maps)

        # Build a fresh list so this also works when the backbone/neck
        # returns an immutable tuple, and split the flattened leading
        # dimension back into (batch, camera).
        feature_maps = [
            torch.reshape(feat, (bs, num_cams) + feat.shape[1:])
            for feat in feature_maps
        ]

        depths = None
        if return_depth and self.depth_branch is not None:
            # ``metas`` defaults to None; avoid calling ``.get`` on it.
            focal = metas.get("focal") if metas is not None else None
            depths = self.depth_branch(feature_maps, focal)

        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        """Dispatch to ``forward_train`` or ``forward_test`` by module mode."""
        if self.training:
            return self.forward_train(img, **data)
        return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Run a training step and return the loss dict.

        The dict comes from ``self.head.loss``. When the depth branch
        produced depths and ``data`` contains ``"gt_depth"``, a
        ``"loss_dense_depth"`` entry is added.
        """
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        """Run inference; a list ``img`` is routed to ``aug_test``."""
        if isinstance(img, list):
            return self.aug_test(img, **data)
        return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Run inference without augmentation.

        Returns:
            A list with one ``{"img_bbox": result}`` dict per result returned
            by ``self.head.post_process``.
        """
        feature_maps = self.extract_feat(img)
        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)
        return [{"img_bbox": result} for result in results]

    def aug_test(self, img, **data):
        """Fake test-time augmentation: only the first entry is evaluated.

        Every list-valued entry of ``data`` is replaced by its first element
        and ``simple_test`` is run on ``img[0]``.
        """
        data = {
            key: value[0] if isinstance(value, list) else value
            for key, value in data.items()
        }
        return self.simple_test(img[0], **data)
