import torch, torch.nn as nn
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from torch.cuda.amp import autocast
from .utils import safe_inverse_sigmoid, safe_sigmoid


@MODELS.register_module()
class GaussianInit(BaseModule):
    """Build initial Gaussian anchors and their query features from a BEV map.

    Every anchor is a flat vector laid out as::

        [xy (2) | scale (2) | rots (2) | opacity (1 if include_opa else 0)
         | semantic (semantic_dim if semantics else 0)]

    * The first ``num_anchor`` anchors take their xy from the BEV cells with
      the largest channel-summed ``|feature|`` (chosen per sample in
      :meth:`forward`); their remaining attributes come from the learnable
      :attr:`anchor` parameter.
    * If ``random_samples_front + random_samples_back > 0``, that many extra
      anchors with randomly drawn, fixed xy are appended
      (:attr:`random_anchors`, ``requires_grad=False``).

    xy is expressed in normalised [0, 1] BEV coordinates (x indexes the width
    axis, y the height axis), optionally mapped through
    ``safe_inverse_sigmoid``.

    Args:
        num_anchor (int): Number of magnitude-selected anchors per sample.
            ``torch.topk`` in :meth:`forward` raises if this exceeds ``H * W``.
        embed_dims (int): Width of the per-anchor feature vectors.
        anchor_grad (bool): Whether :attr:`anchor` is trainable.
        feat_grad (bool): Whether :attr:`instance_feature` and
            :attr:`implicit_feature` are trainable; also gates their Xavier
            initialisation in :meth:`init_weights`.
        semantics (bool): Append ``semantic_dim`` standard-normal values to
            every anchor.
        semantic_dim (int | None): Required when ``semantics`` is True.
        include_opa (bool): Append one opacity value (inverse sigmoid of 0.5)
            to every anchor.
        xy_activation (str): ``"sigmoid"`` stores xy through
            ``safe_inverse_sigmoid``; any other value stores it unchanged.
        scale_activation (str): ``"sigmoid"`` stores scale through
            ``safe_inverse_sigmoid``; any other value stores it unchanged.
        pc_range (Sequence[float]): Copied into :attr:`pc_range` as a list;
            presumably ``[x_min, y_min, z_min, x_max, y_max, z_max]``. The
            normalised xy produced by this class does not depend on it.
        random_sampling (bool): Stored as :attr:`random_sampling`; not read by
            this class.
        random_samples_front (int): Random anchors with x in [0.5, 1) and
            y in [0, 1).
        random_samples_back (int): Random anchors with x in [0, 0.5) and
            y in [0.25, 0.75); also the row count of :attr:`implicit_feature`.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        num_anchor,
        embed_dims,
        anchor_grad=True,
        feat_grad=True,
        semantics=False,
        semantic_dim=None,
        include_opa=True,
        xy_activation="sigmoid",
        scale_activation="sigmoid",
        pc_range=(-32, -32, 0.2, 32, 32, 20.2),
        random_sampling=True,
        random_samples_front=448,
        random_samples_back=128,
        # implicit_smaples=64,
        **kwargs,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.xy_act = xy_activation
        self.scale_act = scale_activation
        self.include_opa = include_opa
        self.semantics = semantics
        self.semantic_dim = semantic_dim

        self.num_anchor = num_anchor
        self.num_candidate_anchor = num_anchor  # * 4
        self.random_samples_front = random_samples_front
        self.random_samples_back = random_samples_back
        self.random_samples = random_samples_front + random_samples_back
        # Built before the main anchor so the RNG draw order stays the same as
        # in earlier versions (random xy/semantics first, then anchor semantics).
        if self.random_samples > 0:
            self.random_anchors = self.init_random_anchors()

        # Non-xy attributes of the magnitude-selected anchors; xy is supplied
        # per sample in forward().
        self.anchor = nn.Parameter(
            self._gaussian_attributes(num_anchor),
            requires_grad=anchor_grad,
        )
        self.instance_feature = nn.Parameter(
            torch.zeros([num_anchor + self.random_samples, self.embed_dims]),
            requires_grad=feat_grad,
        )
        self.implicit_feature = nn.Parameter(
            torch.zeros([random_samples_back, self.embed_dims]),
            requires_grad=feat_grad,
        )

        # Copy so instances never share (or mutate) the default sequence.
        self.pc_range = list(pc_range)
        self.random_sampling = random_sampling

    def _gaussian_attributes(self, num):
        """Return a ``(num, D)`` float tensor of initial scale/rots/opacity/semantic.

        ``D = 2 + 2 + (1 if include_opa else 0) + (semantic_dim if semantics
        else 0)``. Only the semantic part consumes the global RNG.
        """
        scale = torch.ones(num, 2, dtype=torch.float) * 0.5
        if self.scale_act == "sigmoid":
            scale = safe_inverse_sigmoid(scale)

        # Both rotation components are always stored in inverse-sigmoid space,
        # initialised from (0.5, 1) regardless of scale/xy activations.
        rots = torch.zeros(num, 2, dtype=torch.float)
        rots[:, 0] = 0.5
        rots[:, 1] = 1
        rots = safe_inverse_sigmoid(rots)

        if self.include_opa:
            opacity = safe_inverse_sigmoid(
                0.5 * torch.ones((num, 1), dtype=torch.float)
            )
        else:
            # Zero-width column keeps the concatenation below uniform.
            opacity = torch.ones((num, 0), dtype=torch.float)

        if self.semantics:
            assert (
                self.semantic_dim is not None
            ), "semantic_dim must be set when semantics=True"
            semantic_dim = self.semantic_dim
        else:
            semantic_dim = 0
        semantic = torch.randn(num, semantic_dim, dtype=torch.float)

        return torch.cat([scale, rots, opacity, semantic], dim=-1)

    def init_random_anchors(self):
        """Create the fixed (non-trainable) randomly placed anchors.

        The first ``random_samples_front`` rows draw normalised x from
        [0.5, 1) and y from [0, 1); the remaining ``random_samples_back`` rows
        draw x from [0, 0.5) and y from [0.25, 0.75).

        Returns:
            nn.Parameter: ``(random_samples, D_full)`` anchors laid out as
            ``[xy, scale, rots, opacity, semantic]`` with
            ``requires_grad=False``.
        """
        x_front = (
            torch.rand(self.random_samples_front, 1, dtype=torch.float) * 0.5 + 0.5
        )
        y_front = torch.rand(self.random_samples_front, 1, dtype=torch.float)
        x_back = torch.rand(self.random_samples_back, 1, dtype=torch.float) * 0.5
        y_back = torch.rand(self.random_samples_back, 1, dtype=torch.float) * 0.5 + 0.25
        xy = torch.cat(
            [
                torch.cat([x_front, y_front], dim=-1),
                torch.cat([x_back, y_back], dim=-1),
            ],
            dim=0,
        )

        if self.xy_act == "sigmoid":
            xy = safe_inverse_sigmoid(xy)

        anchor = torch.cat(
            [xy, self._gaussian_attributes(self.random_samples)], dim=-1
        )
        return nn.Parameter(anchor, requires_grad=False)

    def init_weights(self):
        """Xavier-uniform initialise each trainable feature parameter.

        Parameters created with ``feat_grad=False`` keep their zero init.
        """
        if self.instance_feature.requires_grad:
            torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)
        if self.implicit_feature.requires_grad:
            # Previously this re-initialised instance_feature by mistake,
            # leaving implicit_feature at all zeros.
            torch.nn.init.xavier_uniform_(self.implicit_feature.data, gain=1)

    def forward(self, bev_features, **kwargs):
        """Place anchors at the strongest BEV cells and tile the parameters.

        The ``num_anchor`` cells with the largest ``|feature|`` summed over
        channels become anchor centres (cell centres, normalised to [0, 1]).
        Anchor attributes other than xy and the query features are the
        module's parameters, repeated over the batch.

        Args:
            bev_features (Tensor): 4-D BEV map of shape ``(B, C, H, W)``.
            **kwargs: Ignored.

        Returns:
            tuple[Tensor, Tensor, Tensor]:
                * anchor ``(B, num_anchor + random_samples, D_full)`` laid out
                  as ``[xy, scale, rots, opacity, semantic]``;
                * instance_feature
                  ``(B, num_anchor + random_samples, embed_dims)``;
                * implicit_feature ``(B, random_samples_back, embed_dims)``.
        """
        B, _, H, W = bev_features.shape

        # Channel-summed L1 magnitude per cell, flattened to (B, H*W).
        magnitude = bev_features.abs().sum(dim=1).flatten(1)
        _, topk_indices = torch.topk(
            magnitude, self.num_candidate_anchor, dim=1
        )  # (B, N)

        # Normalised cell-centre coordinates. Going to metric space via
        # pc_range and straight back cancels out, so compute them directly.
        anchor_x = ((topk_indices % W).float() + 0.5) / W  # (B, N)
        anchor_y = ((topk_indices // W).float() + 0.5) / H  # (B, N)
        xy = torch.stack([anchor_x, anchor_y], dim=-1)  # (B, N, 2)

        if self.xy_act == "sigmoid":
            xy = safe_inverse_sigmoid(xy)
        anchor = torch.cat([xy, torch.tile(self.anchor[None], (B, 1, 1))], dim=-1)

        if self.random_samples > 0:
            random_anchors = torch.tile(self.random_anchors[None], (B, 1, 1))
            anchor = torch.cat([anchor, random_anchors], dim=1)

        instance_feature = torch.tile(self.instance_feature[None], (B, 1, 1))
        implicit_feature = torch.tile(self.implicit_feature[None], (B, 1, 1))

        return (anchor, instance_feature, implicit_feature)
