import mmcv
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from mmcv.runner import BaseModule, force_fp32, auto_fp16
from mmdet.core import multi_apply
from mmdet.models import DETECTORS
from mmdet.models.builder import build_backbone, build_detector
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet3d.datasets.pipelines import Compose
from projects.mmdet3d_plugin.datasets.pipelines import TimeCompose
from projects.mmdet3d_plugin.utils import CUDATimer
from projects.mmdet3d_plugin.models.utils.detr3d_transformer import DeformableFeatureAggregationCuda

from typing import Optional


@DETECTORS.register_module()
class PipelineDetector(BaseModule):
    """Base wrapper pairing a config-built detector with data pipelines.

    The wrapped detector is built via ``build_detector``. The optional image
    (``img``) and points (``pts``) pipeline configs are wrapped in either
    ``TimeCompose`` or ``Compose``; ``use_time_compose`` selects which one.
    Subclasses implement ``load_img`` / ``load_pts`` / ``forward``.

    Note: ``train_cfg``, ``test_cfg`` and ``pretrained`` are accepted but not
    used here (presumably for config-signature compatibility — confirm).
    """

    def __init__(self,
                 detector,
                 img_pipeline=None,
                 pts_pipeline=None,
                 use_time_compose=False,
                 dataset=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.detector = build_detector(detector)

        # Choose the pipeline container once; the same one wraps both configs.
        composer = TimeCompose if use_time_compose else Compose

        def _maybe_compose(transforms):
            # A missing pipeline config stays None rather than an empty pipeline.
            return None if transforms is None else composer(transforms)

        self.img_pipeline = _maybe_compose(img_pipeline)
        self.pts_pipeline = _maybe_compose(pts_pipeline)
        # Dataset handle used by subclasses to build per-sample inputs.
        self.dataset = dataset

    def load_img(self):
        """Load image inputs for one sample; must be overridden."""
        raise NotImplementedError()

    def load_pts(self):
        """Load points inputs for one sample; must be overridden."""
        raise NotImplementedError()

    def forward(self):
        """Run the wrapped detector; must be overridden."""
        raise NotImplementedError()

@DETECTORS.register_module()
class CmtPipelineDetector(PipelineDetector):
    """Pipeline wrapper that loads camera and LiDAR inputs separately.

    Both loaders unpack ``(outputs, time_dict)`` from the configured pipeline.
    This presumably requires building with ``use_time_compose=True`` — TODO
    confirm against ``TimeCompose``.
    """

    def load_img(self, sample_idx=None, data_info=None, used_cameras=None):
        """Run the image pipeline for a single sample.

        Exactly one of ``sample_idx`` / ``data_info`` must be given.

        Args:
            sample_idx: Dataset index. When given, the data info is built via
                ``dataset.custom_get_data_info(..., use_lidar=False)`` and
                then passed through ``dataset.pre_pipeline``.
            data_info: Pre-built data info. It is fed to the pipeline as-is:
                ``pre_pipeline`` is not called and ``used_cameras`` is ignored.
            used_cameras: Forwarded to ``custom_get_data_info`` when building
                from ``sample_idx``.

        Returns:
            The pipeline outputs with ``time_data_loading`` and
            ``time_data_processing`` entries added.

        Raises:
            AssertionError: If both or neither of ``sample_idx`` and
                ``data_info`` are given.
            TypeError: If ``img_pipeline`` was not configured.
        """
        assert (sample_idx is None) ^ (data_info is None), 'One and only one of `sample_idx` and `data_info` can be non-None.'

        if data_info is None:
            data_info = self.dataset.custom_get_data_info(sample_idx, used_cameras=used_cameras, use_lidar=False)
            self.dataset.pre_pipeline(data_info)
        return self._run_pipeline(self.img_pipeline, data_info, 'img_pipeline')

    def load_pts(self, sample_idx=None, data_info=None):
        """Run the points pipeline for a single sample.

        Exactly one of ``sample_idx`` / ``data_info`` must be given.

        Args:
            sample_idx: Dataset index. When given, the data info is built via
                ``dataset.custom_get_data_info(..., used_cameras=[],
                use_lidar=True)`` and then passed through
                ``dataset.pre_pipeline``.
            data_info: Pre-built data info, fed to the pipeline as-is
                (``pre_pipeline`` is not called).

        Returns:
            The pipeline outputs with ``time_data_loading`` and
            ``time_data_processing`` entries added.

        Raises:
            AssertionError: If both or neither of ``sample_idx`` and
                ``data_info`` are given.
            TypeError: If ``pts_pipeline`` was not configured.
        """
        assert (sample_idx is None) ^ (data_info is None), 'One and only one of `sample_idx` and `data_info` can be non-None.'

        if data_info is None:
            # No cameras: this path only requests LiDAR data.
            data_info = self.dataset.custom_get_data_info(sample_idx, used_cameras=[], use_lidar=True)
            self.dataset.pre_pipeline(data_info)
        return self._run_pipeline(self.pts_pipeline, data_info, 'pts_pipeline')

    def _run_pipeline(self, pipeline, data_info, pipeline_name):
        """Run ``pipeline`` on ``data_info`` and attach aggregated timings."""
        if pipeline is None:
            # Calling None would raise an opaque TypeError. Keep the same
            # exception type, but name the pipeline that is missing.
            raise TypeError(
                f'`{pipeline_name}` is not configured (None) on '
                f'{type(self).__name__}; pass it to the constructor.')
        outputs, time_dict = pipeline(data_info)
        time_data_loading, time_data_processing = postprocess_time_dict(time_dict)
        outputs.update({'time_data_loading': time_data_loading, 'time_data_processing': time_data_processing})
        return outputs

    def forward(self):
        """Not implemented for this wrapper."""
        raise NotImplementedError()

@DETECTORS.register_module()
class RepDetr3DPipelineDetector(PipelineDetector):
    """Pipeline wrapper with a runtime knob for the number of valid cameras.

    At construction it caches every ``DeformableFeatureAggregationCuda``
    submodule of the wrapped detector. ``set_num_cams_valid`` can then push
    a ``num_cams_valid`` value to all of them.
    """

    def __init__(self,
                 detector,
                 img_pipeline=None,
                 pts_pipeline=None,
                 use_time_compose=False,
                 dataset=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 **kwargs):
        super().__init__(
            detector=detector,
            img_pipeline=img_pipeline,
            pts_pipeline=pts_pipeline,
            use_time_compose=use_time_compose,
            dataset=dataset,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            **kwargs)
        # Plain list: these modules are already registered under
        # self.detector, so they are only referenced here, not re-registered.
        self._deform_feature_aggregation_modules = [
            submodule for submodule in self.detector.modules()
            if isinstance(submodule, DeformableFeatureAggregationCuda)
        ]

    def load_img(self, sample_idx, used_cameras=None):
        """Build, pre-process and run the image pipeline for ``sample_idx``.

        Returns:
            The pipeline outputs with ``time_data_loading`` and
            ``time_data_processing`` entries added.
        """
        info = self.dataset.custom_get_data_info(sample_idx, used_cameras=used_cameras)
        self.dataset.pre_pipeline(info)
        results, timings = self.img_pipeline(info)
        loading_time, processing_time = postprocess_time_dict(timings)
        results.update({'time_data_loading': loading_time, 'time_data_processing': processing_time})
        return results

    def load_pts(self):
        """Not supported by this wrapper."""
        raise NotImplementedError()

    def set_num_cams_valid(self, num_cams_valid: Optional[int] = None):
        """Set ``num_cams_valid`` on every cached aggregation module.

        ``None`` is forwarded as-is. Any other value must be at least 1
        (checked with ``assert``).
        """
        assert num_cams_valid is None or num_cams_valid >= 1
        for agg_module in self._deform_feature_aggregation_modules:
            agg_module.num_cams_valid = num_cams_valid

    def forward(self):
        """Not implemented for this wrapper."""
        raise NotImplementedError()


@DETECTORS.register_module()
class VisionTransformerPipelineDetector(PipelineDetector):
    """Pipeline wrapper whose dataset yields ready-made image samples."""

    def load_img(self, sample_idx):
        """Fetch ``(imgs, token, time_dict)`` from the dataset and flatten it.

        Returns:
            A dict with ``imgs`` and ``token``, plus every entry of the
            dataset-provided timing dict merged in (later keys win).
        """
        images, sample_token, timings = self.dataset[sample_idx]
        result = dict(imgs=images, token=sample_token)
        result.update(timings)
        return result

    def load_pts(self):
        """Not supported by this wrapper."""
        raise NotImplementedError()

    def forward(self, imgs):
        """Delegate directly to the wrapped detector."""
        return self.detector(imgs)


def postprocess_time_dict(time_dict):
    """Split per-step timings into (loading, processing) totals.

    A step counts as *loading* when the case-sensitive substring ``'Load'``
    appears in its key. Every other step counts as *processing*. Each value
    must be a mapping that carries an ``'elapsed'`` entry.

    Returns:
        A ``(time_loading, time_processing)`` tuple of floats.
    """
    LOADING, PROCESSING = 0, 1
    totals = [0., 0.]
    for step_name, step_times in time_dict.items():
        bucket = LOADING if 'Load' in step_name else PROCESSING
        totals[bucket] += step_times['elapsed']
    return totals[LOADING], totals[PROCESSING]
