from typing import List, Optional
import torch, torch.nn as nn

from mmdet.registry import MODELS
from mmengine import build_from_cfg
from mmengine.model import BaseModule
from .utils import GaussianPrediction


@MODELS.register_module()
class GaussianEncoder(BaseModule):
    """Iterative encoder that runs a configurable sequence of operations.

    The encoder holds one (optional) sub-module per entry of
    ``operation_order`` and applies them in order inside :meth:`forward`.
    Each ``"refine"``-like operation produces an updated anchor plus a
    gaussian prediction, which is collected into the output list.

    Args:
        anchor_encoder: Config of the module that embeds the anchor.
        norm_layer: Config used for every ``"norm"`` operation.
        ffn: Config used for every ``"ffn"`` operation.
        deformable_model_pts: Config used for ``"deformable_pts"``.
        deformable_model_img: Config used for ``"deformable_img"``.
        implicit_fusion: Config used for ``"implicit_fusion"``.
        refine_layer: Config used for ``"refine"``.
        mid_refine_layer: Config used for ``"mid_refine"`` (optional).
        spconv_layer: Config used for ``"spconv"`` (optional).
        num_decoder: Number of times the default operation sequence is
            repeated when ``operation_order`` is not given.
        operation_order: Explicit list of operation names. Names without a
            config entry (e.g. ``"identity"``, ``"add"``) get no module.
        init_cfg: Forwarded to ``BaseModule.__init__``.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(
        self,
        anchor_encoder: dict,
        norm_layer: dict,
        ffn: dict,
        deformable_model_pts: dict,
        deformable_model_img: dict,
        implicit_fusion: dict,
        refine_layer: dict,
        mid_refine_layer: Optional[dict] = None,
        spconv_layer: Optional[dict] = None,
        num_decoder: int = 6,
        operation_order: Optional[List[str]] = None,
        init_cfg=None,
        **kwargs,
    ):
        super().__init__(init_cfg)
        self.num_decoder = num_decoder

        if operation_order is None:
            operation_order = [
                "spconv",
                "norm",
                "deformable_img",
                "deformable_pts",
                "norm",
                "ffn",
                "norm",
                "refine",
            ] * num_decoder
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            # A missing config yields no module for that slot.
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.anchor_encoder = build(anchor_encoder, MODELS)
        self.op_config_map = {
            "norm": [norm_layer, MODELS],
            "ffn": [ffn, MODELS],
            "deformable_img": [deformable_model_img, MODELS],
            "deformable_pts": [deformable_model_pts, MODELS],
            "implicit_fusion": [implicit_fusion, MODELS],
            "refine": [refine_layer, MODELS],
            "mid_refine": [mid_refine_layer, MODELS],
            "spconv": [spconv_layer, MODELS],
        }
        # One slot per operation so that ``self.layers[i]`` lines up with
        # ``self.operation_order[i]``; slots for module-less ops hold None.
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )

    def init_weights(self):
        """Xavier-initialise layer weights, then run custom ``init_weight`` hooks.

        Every parameter with more than one dimension is re-initialised with
        ``xavier_uniform_``, except in layers whose operation name is exactly
        ``"refine"``. Afterwards any sub-module exposing an ``init_weight``
        method has it called.
        """
        # NOTE(review): only the exact name "refine" is skipped here, while
        # forward() treats every op containing "refine" (e.g. "mid_refine")
        # alike -- confirm whether "mid_refine" should be skipped as well.
        for op, layer in zip(self.operation_order, self.layers):
            if layer is None or op == "refine":
                continue
            for p in layer.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if hasattr(m, "init_weight"):
                m.init_weight()

    @staticmethod
    def _flatten_last_level(feature_maps, arg_name):
        """Return the last feature level with dims 1-3 flattened, dim 2 moved last.

        Args:
            feature_maps: Sequence of feature tensors; the last entry must be
                5-D (presumably ``(B, N, C, H, W)`` -- TODO confirm).
            arg_name: Name of the ``forward`` argument, used in the error.

        Raises:
            ValueError: If ``feature_maps`` is ``None`` or empty.
        """
        if feature_maps is None or len(feature_maps) == 0:
            raise ValueError(
                f'"implicit_fusion" requires `{arg_name}` to be provided.'
            )
        return feature_maps[-1].permute(0, 1, 3, 4, 2).flatten(1, 3)

    def forward(
        self,
        representation,
        rep_features,
        implicit_features,
        ms_img_feats=None,
        ms_pts_feats=None,
        metas=None,
        **kwargs,
    ):
        """Run all configured operations and collect the gaussian predictions.

        Args:
            representation: Initial anchor passed to the anchor encoder and
                to the anchor-aware layers.
            rep_features: Initial instance feature.
            implicit_features: Implicit feature consumed by the refine layers
                and updated by ``"implicit_fusion"``.
            ms_img_feats: Image feature map(s); a single tensor is wrapped
                into a one-element list.
            ms_pts_feats: Point feature map(s); a single tensor is wrapped
                into a one-element list.
            metas: Extra metadata forwarded to ``"deformable_img"`` layers.
            **kwargs: Ignored.

        Returns:
            dict: ``{"representation": [{"gaussian": ...}, ...]}`` with one
            entry per refine operation, in execution order.

        Raises:
            ValueError: If ``"implicit_fusion"`` is scheduled but a required
                feature map input is missing.
            RuntimeError: If ``"add"`` runs before any ``"identity"``.
            NotImplementedError: If an operation name is not supported.
        """
        img_feature_maps = ms_img_feats
        pts_feature_maps = ms_pts_feats
        if isinstance(img_feature_maps, torch.Tensor):
            img_feature_maps = [img_feature_maps]
        if isinstance(pts_feature_maps, torch.Tensor):
            pts_feature_maps = [pts_feature_maps]

        instance_feature = rep_features
        anchor = representation
        anchor_embed = self.anchor_encoder(anchor)

        # The flattened last-level features are only consumed by
        # "implicit_fusion". Building them unconditionally crashed whenever
        # the (optional) feature inputs were left at their None default.
        img_features = None
        pts_features = None
        if "implicit_fusion" in self.operation_order:
            img_features = self._flatten_last_level(
                img_feature_maps, "ms_img_feats"
            )
            pts_features = self._flatten_last_level(
                pts_feature_maps, "ms_pts_feats"
            )

        identity = None  # set by "identity", consumed by "add"
        last_index = len(self.operation_order) - 1
        prediction = []
        for i, op in enumerate(self.operation_order):
            if op == "spconv":
                instance_feature = self.layers[i](instance_feature, anchor)
            elif op == "norm" or op == "ffn":
                instance_feature = self.layers[i](instance_feature)
            elif op == "identity":
                identity = instance_feature
            elif op == "add":
                if identity is None:
                    raise RuntimeError(
                        '"add" requires a preceding "identity" operation.'
                    )
                instance_feature = instance_feature + identity
            elif op == "deformable_img":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    img_feature_maps,
                    metas,
                )
            elif op == "deformable_pts":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    pts_feature_maps,
                )
            elif op == "implicit_fusion":
                implicit_features = self.layers[i](
                    instance_feature,
                    implicit_features,
                    anchor,
                    img_features,
                    pts_features,
                )
            elif "refine" in op:
                # Matches both "refine" and "mid_refine".
                anchor, gaussian = self.layers[i](
                    instance_feature,
                    implicit_features,
                    anchor,
                    anchor_embed,
                )

                prediction.append({"gaussian": gaussian})

                # Re-embed the updated anchor for the following operations;
                # pointless after the very last operation.
                if i != last_index:
                    anchor_embed = self.anchor_encoder(anchor)

            else:
                raise NotImplementedError(f"{op} is not supported.")

        return {"representation": prediction}

    # def forward(
    #     self,
    #     representation,
    #     rep_features,
    #     ms_img_feats=None,
    #     ms_pts_feats=None,
    #     metas=None,
    #     **kwargs,
    # ):
    #     img_feature_maps = ms_img_feats
    #     pts_feature_maps = ms_pts_feats
    #     if isinstance(img_feature_maps, torch.Tensor):
    #         img_feature_maps = [img_feature_maps]
    #     if isinstance(pts_feature_maps, torch.Tensor):
    #         pts_feature_maps = [pts_feature_maps]
    #     instance_feature = rep_features
    #     anchor = representation

    #     latency_list = []
    #     starter = torch.cuda.Event(enable_timing=True)
    #     ender = torch.cuda.Event(enable_timing=True)
    #     anchor_embed = self.anchor_encoder(anchor)

    #     prediction = []
    #     for i, op in enumerate(self.operation_order):
    #         if op == "spconv":
    #             torch.cuda.synchronize()
    #             starter.record()
    #             instance_feature = self.layers[i](instance_feature, anchor)
    #             ender.record()
    #             torch.cuda.synchronize()
    #             elapsed_time_ms = starter.elapsed_time(ender)
    #             latency_list.append({"spconv": elapsed_time_ms})

    #         elif op == "norm" or op == "ffn":
    #             torch.cuda.synchronize()
    #             starter.record()
    #             instance_feature = self.layers[i](instance_feature)
    #             ender.record()
    #             torch.cuda.synchronize()
    #             elapsed_time_ms = starter.elapsed_time(ender)
    #             latency_list.append({op: elapsed_time_ms})
    #         elif op == "identity":
    #             identity = instance_feature
    #         elif op == "add":
    #             instance_feature = instance_feature + identity
    #         elif op == "deformable_img":
    #             torch.cuda.synchronize()
    #             starter.record()
    #             instance_feature = self.layers[i](
    #                 instance_feature,
    #                 anchor,
    #                 anchor_embed,
    #                 img_feature_maps,
    #                 metas,
    #             )
    #             ender.record()
    #             torch.cuda.synchronize()
    #             elapsed_time_ms = starter.elapsed_time(ender)
    #             latency_list.append({"img": elapsed_time_ms})
    #         elif op == "deformable_pts":
    #             torch.cuda.synchronize()
    #             starter.record()
    #             instance_feature = self.layers[i](
    #                 instance_feature,
    #                 anchor,
    #                 anchor_embed,
    #                 pts_feature_maps,
    #             )
    #             ender.record()
    #             torch.cuda.synchronize()
    #             elapsed_time_ms = starter.elapsed_time(ender)
    #             latency_list.append({"pts": elapsed_time_ms})
    #         elif "refine" in op:
    #             torch.cuda.synchronize()
    #             starter.record()
    #             anchor, gaussian = self.layers[i](
    #                 instance_feature,
    #                 anchor,
    #                 anchor_embed,
    #             )
    #             ender.record()
    #             torch.cuda.synchronize()
    #             elapsed_time_ms = starter.elapsed_time(ender)
    #             latency_list.append({"refine": elapsed_time_ms})

    #             prediction.append({"gaussian": gaussian})
    #             if i != len(self.operation_order) - 1:
    #                 anchor_embed = self.anchor_encoder(anchor)
    #         else:
    #             raise NotImplementedError(f"{op} is not supported.")

    #     return {"representation": prediction}
