from mmengine.registry import MODELS
from mmengine.model import BaseModule
from mmengine import build_from_cfg
from mmengine.model import xavier_init, constant_init
import torch, torch.nn as nn
import numpy as np
from typing import List, Optional
from .utils import safe_sigmoid
from .utils import get_rotation_matrix
from .utils import linear_relu_ln
from .ops import DeformableAggregationFunction as DAF


@MODELS.register_module()
class SparseGaussianKeyPointsGenerator_LiDAR(BaseModule):
    """Generate 2D key points around each anchor.

    Every anchor contributes ``len(fix_scale)`` fixed offsets plus, when an
    instance feature is supplied, ``num_learnable_pts`` offsets predicted by
    a linear layer. The offsets are multiplied by the anchor's scale
    (channels ``2:4``), transformed with the matrix built from the anchor's
    rotation channels (``4:6``) and shifted by the anchor's centre
    (channels ``0:2``) mapped into ``pc_range``.

    Args:
        embed_dims: Input width of the linear layer predicting the
            learnable offsets.
        num_learnable_pts: Number of offsets predicted per anchor.
        learnable_fixed_scale: Multiplier applied to the predicted offsets,
            which lie in ``(-0.5, 0.5)`` before scaling.
        fix_scale: Sequence of fixed ``(x, y)`` offsets; defaults to a single
            ``(0.0, 0.0)`` offset.
        pc_range: Sequence indexed at 0/1 (x/y minimum) and 3/4 (x/y
            maximum) when mapping centres to metric coordinates.
        scale_range: ``(min, max)`` pair the anchor scale is mapped into.
        xy_activation: ``"sigmoid"`` applies ``safe_sigmoid`` to the centre
            channels; any other value uses them as they are.
        scale_activation: Same switch for the scale channels.
        rot_activation: ``"sigmoid"`` maps the rotation channels to
            ``2 * safe_sigmoid(x) - 1``; any other value uses them as is.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        embed_dims=256,
        num_learnable_pts=0,
        learnable_fixed_scale=1,
        fix_scale=None,
        pc_range=None,
        scale_range=None,
        xy_activation="sigmoid",
        scale_activation="sigmoid",
        rot_activation="sigmoid",
        **kwargs,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_learnable_pts = num_learnable_pts
        self.learnable_fixed_scale = learnable_fixed_scale

        # Fall back to a single offset located at the anchor centre.
        self.fix_scale = np.array(
            ((0.0, 0.0),) if fix_scale is None else fix_scale
        )
        self.num_pts = len(self.fix_scale) + num_learnable_pts

        if num_learnable_pts > 0:
            # Two outputs (x and y offset) per learnable point.
            self.learnable_fc = nn.Linear(
                self.embed_dims, num_learnable_pts * 2
            )

        self.pc_range = pc_range
        self.scale_range = scale_range
        self.xy_act = xy_activation
        self.scale_act = scale_activation
        self.rot_act = rot_activation

    def init_weight(self):
        """Xavier-initialise the learnable-offset layer, if there is one."""
        if self.num_learnable_pts <= 0:
            return
        xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
    ):
        """Compute the key points of every anchor.

        Args:
            anchor: Tensor whose first two dimensions are
                ``(batch, num_anchor)`` and whose last dimension holds at
                least six channels: centre ``0:2``, scale ``2:4`` and
                rotation ``4:6``.
            instance_feature: Optional tensor fed to the learnable-offset
                layer. When it is ``None`` only the fixed offsets are used.

        Returns:
            Tensor of shape ``(batch, num_anchor, P, 2)`` where ``P`` is the
            number of fixed offsets plus the learnable ones (if generated).
        """
        batch, n_anchor = anchor.shape[:2]

        # Fixed offsets, replicated for every anchor of every sample.
        base_offsets = anchor.new_tensor(self.fix_scale)
        offsets = base_offsets[None, None].tile([batch, n_anchor, 1, 1])

        wants_learnable = (
            self.num_learnable_pts > 0 and instance_feature is not None
        )
        if wants_learnable:
            raw = self.learnable_fc(instance_feature)
            raw = raw.reshape(batch, n_anchor, self.num_learnable_pts, 2)
            # Centre the sigmoid output around zero before scaling it.
            learned = safe_sigmoid(raw) - 0.5
            offsets = torch.cat(
                [offsets, learned * self.learnable_fixed_scale], dim=-2
            )

        # Anchor scale mapped into [scale_range[0], scale_range[1]]; the
        # extra axis broadcasts it over the points of an anchor.
        extent = anchor[..., None, 2:4]
        if self.scale_act == "sigmoid":
            extent = safe_sigmoid(extent)
        extent = (
            self.scale_range[0]
            + (self.scale_range[1] - self.scale_range[0]) * extent
        )
        local_pts = offsets * extent

        # NOTE(review): get_rotation_matrix presumably returns one 3x3
        # matrix per anchor (the points are padded to three components
        # below) - confirm against .utils.
        heading = anchor[..., 4:6]
        if self.rot_act == "sigmoid":
            heading = 2 * safe_sigmoid(heading) - 1
        rot = get_rotation_matrix(heading)

        # Pad with a constant 1, transform, then keep the x/y components.
        padding = torch.ones_like(local_pts[..., :1])
        padded_pts = torch.cat([local_pts, padding], dim=-1)
        transformed = torch.matmul(rot[:, :, None], padded_pts[..., None])
        local_pts = transformed.squeeze(-1)[..., :2]

        # Anchor centre mapped from its activated value into pc_range.
        center = anchor[..., :2]
        if self.xy_act == "sigmoid":
            center = safe_sigmoid(center)
        center_x = (
            center[..., 0] * (self.pc_range[3] - self.pc_range[0])
            + self.pc_range[0]
        )
        center_y = (
            center[..., 1] * (self.pc_range[4] - self.pc_range[1])
            + self.pc_range[1]
        )
        center = torch.stack([center_x, center_y], dim=-1)

        return local_pts + center.unsqueeze(2)


@MODELS.register_module()
class DeformableFeatureAggregation_LiDAR(BaseModule):
    """Aggregate feature-map features at the key points of each anchor.

    Key points come from the configured key-point generator. A linear layer
    predicts one weight per (view, level, point, group); features sampled at
    the key points are combined with those weights, summed over the points,
    projected and merged with the input instance feature.

    Args:
        embed_dims: Feature width; must be divisible by ``num_groups``.
        num_groups: Number of weight groups the channels are split into.
        num_levels: Number of feature-map levels.
        num_cams: Size of the view axis of the predicted weights.
        proj_drop: Dropout probability applied to the projected output.
        attn_drop: Probability of masking out an individual weight during
            training.
        kps_generator: Config dict for the key-point generator, built from
            the ``MODELS`` registry. It is copied, never modified.
        use_deformable_func: Use ``DAF`` for sampling. The constructor
            currently asserts that this is enabled and ``DAF`` is available.
        use_camera_embed: Builds ``camera_encoder`` and a ``weights_fc``
            without the view factor. NOTE(review): ``camera_encoder`` is not
            used by any method of this class - confirm the intended use.
        residual_mode: ``"add"`` adds the instance feature to the output,
            ``"cat"`` concatenates it on the last axis; any other value
            returns the projected output alone.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_groups: int = 8,
        num_levels: int = 4,
        num_cams: int = 6,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        kps_generator: dict = None,
        use_deformable_func=False,
        use_camera_embed=False,
        residual_mode="add",
    ):
        super().__init__()
        if embed_dims % num_groups != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_groups, "
                f"but got {embed_dims} and {num_groups}"
            )
        self.group_dims = embed_dims // num_groups
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_groups = num_groups
        self.num_cams = num_cams
        self.use_deformable_func = use_deformable_func and DAF is not None
        # Only the DAF-based path is supported at the moment.
        assert self.use_deformable_func
        self.attn_drop = attn_drop
        self.residual_mode = residual_mode
        self.proj_drop = nn.Dropout(proj_drop)

        if kps_generator is None:
            raise TypeError("kps_generator config must be provided")
        # Work on a shallow copy so the caller's config is left untouched.
        kps_cfg = kps_generator.copy()
        kps_cfg["embed_dims"] = embed_dims
        self.kps_generator = build_from_cfg(kps_cfg, MODELS)
        self.num_pts = self.kps_generator.num_pts
        self.pc_range = self.kps_generator.pc_range

        self.output_proj = nn.Linear(embed_dims, embed_dims)

        if use_camera_embed:
            self.camera_encoder = nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, 12))
            self.weights_fc = nn.Linear(
                embed_dims, num_groups * num_levels * self.num_pts
            )
        else:
            self.camera_encoder = None
            self.weights_fc = nn.Linear(
                embed_dims, num_groups * num_cams * num_levels * self.num_pts
            )

    def init_weight(self):
        """Zero the weight predictor and Xavier-initialise the projection."""
        constant_init(self.weights_fc, val=0.0, bias=0.0)
        xavier_init(self.output_proj, distribution="uniform", bias=0.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        feature_maps: List[torch.Tensor],
        **kwargs: dict,
    ):
        """Sample and fuse features for every anchor.

        Args:
            instance_feature: Tensor whose first two dimensions are
                ``(batch, num_anchor)``; its last dimension feeds
                ``weights_fc``.
            anchor: Anchor tensor passed to the key-point generator.
            anchor_embed: Tensor added to ``instance_feature`` before the
                weights are predicted.
            feature_maps: Feature maps to sample from; reformatted with
                ``DAF.feature_maps_format`` on the deformable path.
            **kwargs: Accepted and ignored.

        Returns:
            The projected feature combined with ``instance_feature``
            according to ``residual_mode``.
        """
        bs, num_anchor = instance_feature.shape[:2]
        # Key points of every anchor, derived from anchor and feature.
        key_points_bev = self.kps_generator(anchor, instance_feature)
        if self.use_deformable_func:
            feature_maps = DAF.feature_maps_format(feature_maps)

        weights, weight_mask = self._get_weights(instance_feature, anchor_embed)

        if self.use_deformable_func:
            # (bs, anchor, view, level, point, group)
            #   -> (bs, anchor, point, view, level, group).
            # contiguous() also yields a private copy, so the in-place
            # masking below cannot touch the output of weights_fc.
            weights = weights.permute(0, 1, 4, 2, 3, 5).contiguous()
            weight_mask = weight_mask.permute(0, 1, 4, 2, 3, 5).contiguous()

            points_2d, mask = self.project_points(key_points_bev, self.pc_range)
            # NOTE: project_points adds a singleton view axis, so this
            # reshape only succeeds when num_cams == 1.
            points_2d = points_2d.permute(0, 2, 3, 1, 4).reshape(
                bs, num_anchor * self.num_pts, self.num_cams, 2
            )
            # A weight is kept only if its point is inside pc_range and it
            # was not dropped by attention dropout.
            mask = mask.permute(0, 2, 3, 1)
            mask = mask[..., None, None] & weight_mask
            # Per (anchor, group): True when every weight is masked out.
            all_miss = mask.sum(dim=[2, 3, 4], keepdim=True) == 0
            all_miss = all_miss.expand(
                -1, -1, self.num_pts, self.num_cams, self.num_levels, -1
            )
            weights[~mask] = -torch.inf
            # A softmax over only -inf would produce NaN; use finite values
            # here and zero the result after the softmax instead.
            weights[all_miss] = 0.0
            # Softmax jointly over points, views and levels.
            weights = (
                weights.flatten(2, 4)
                .softmax(dim=-2)
                .reshape(
                    bs,
                    num_anchor * self.num_pts,
                    self.num_cams,
                    self.num_levels,
                    self.num_groups,
                )
            )
            weights = weights * (1 - all_miss.flatten(1, 2).float())

            features = DAF.apply(*feature_maps, points_2d, weights).reshape(
                bs, num_anchor, self.num_pts, self.embed_dims
            )
        else:
            features = self.feature_sampling(
                feature_maps,
                key_points_bev,
                self.pc_range,
            )
            features = self.multi_view_level_fusion(features, weights)

        features = features.sum(dim=2)  # fuse multi-point features
        output = self.proj_drop(self.output_proj(features))
        if self.residual_mode == "add":
            output = output + instance_feature
        elif self.residual_mode == "cat":
            output = torch.cat([output, instance_feature], dim=-1)
        return output

    def _get_weights(self, instance_feature, anchor_embed, metas=None):
        """Predict raw (pre-softmax) weights and a keep-mask.

        Args:
            instance_feature: Tensor whose first two dimensions are
                ``(batch, num_anchor)``.
            anchor_embed: Tensor added to ``instance_feature``.
            metas: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(weights, mask)``, both shaped ``(batch, num_anchor,
            num_cams, num_levels, num_pts, num_groups)``. ``mask`` is
            boolean: entries are dropped with probability ``attn_drop`` in
            training mode and are all ``True`` otherwise.
        """
        bs, num_anchor = instance_feature.shape[:2]
        feature = instance_feature + anchor_embed
        weights = self.weights_fc(feature).reshape(
            bs,
            num_anchor,
            self.num_cams,
            self.num_levels,
            self.num_pts,
            self.num_groups,
        )
        if self.training and self.attn_drop > 0:
            mask = torch.rand_like(weights) > self.attn_drop
        else:
            mask = torch.ones_like(weights) > 0
        return weights, mask

    @staticmethod
    def project_points(key_points, pc_range):
        """Normalise key points by ``pc_range`` and flag in-range points.

        The input tensor is not modified.

        Args:
            key_points: Tensor whose last axis holds x at index 0 and y at
                index 1; any further channels are passed through unchanged.
            pc_range: Sequence indexed at 0/1 (x/y minimum) and 3/4 (x/y
                maximum).

        Returns:
            Tuple ``(points, mask)``. ``points`` is a copy of ``key_points``
            with x and y rescaled so that ``pc_range`` maps onto ``[0, 1]``;
            ``mask`` is ``True`` where both lie in ``[0, 1)``. Both carry an
            additional axis inserted at position 1.
        """
        # Clone first: the original overwrote the caller's tensor in place.
        points = key_points.clone()
        points[..., 0] = (points[..., 0] - pc_range[0]) / (
            pc_range[3] - pc_range[0]
        )
        points[..., 1] = (points[..., 1] - pc_range[1]) / (
            pc_range[4] - pc_range[1]
        )

        mask = (
            (points[..., 0] >= 0)
            & (points[..., 1] >= 0)
            & (points[..., 0] < 1)
            & (points[..., 1] < 1)
        )
        return points.unsqueeze(1), mask.unsqueeze(1)

    @staticmethod
    def feature_sampling(
        feature_maps: List[torch.Tensor],
        key_points: torch.Tensor,
        pc_range: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample every feature map at the key points with ``grid_sample``.

        Args:
            feature_maps: One tensor per level; dimension 1 of the first
                tensor is read as the number of views, and the first two
                dimensions are flattened before sampling.
            key_points: Tensor whose first three dimensions are
                ``(batch, num_anchor, num_pts)``.
            pc_range: Range passed on to ``project_points``.

        Returns:
            Tensor laid out as ``(batch, num_anchor, num_cams, num_levels,
            num_pts, embed_dims)``.
        """
        num_levels = len(feature_maps)
        num_cams = feature_maps[0].shape[1]
        bs, num_anchor, num_pts = key_points.shape[:3]

        points_2d, _ = DeformableFeatureAggregation_LiDAR.project_points(
            key_points, pc_range
        )
        # grid_sample expects coordinates in [-1, 1].
        points_2d = points_2d * 2 - 1
        points_2d = points_2d.flatten(end_dim=1)

        features = [
            torch.nn.functional.grid_sample(fm.flatten(end_dim=1), points_2d)
            for fm in feature_maps
        ]
        features = torch.stack(features, dim=1)
        features = features.reshape(
            bs, num_cams, num_levels, -1, num_anchor, num_pts
        ).permute(
            0, 4, 1, 2, 5, 3
        )  # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims

        return features

    def multi_view_level_fusion(
        self,
        features: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Weight sampled features and sum them over views and levels.

        Args:
            features: Tensor whose last axis (``embed_dims``) is split into
                ``(num_groups, group_dims)`` before weighting.
            weights: Tensor with one weight per group, broadcast over the
                channels of that group; its first two dimensions are
                ``(batch, num_anchor)``.

        Returns:
            Tensor of shape ``(batch, num_anchor, num_pts, embed_dims)``.
        """
        bs, num_anchor = weights.shape[:2]
        features = weights[..., None] * features.reshape(
            features.shape[:-1] + (self.num_groups, self.group_dims)
        )
        # Sum out axis 2 twice: first the view axis, then the level axis.
        features = features.sum(dim=2).sum(dim=2)
        features = features.reshape(bs, num_anchor, self.num_pts, self.embed_dims)
        return features
