from mmengine.registry import MODELS
from mmengine.model import BaseModule
from mmcv.cnn import Scale
from functools import partial
import torch.nn as nn, torch
import torch.nn.functional as F
from .utils import linear_relu_ln, GaussianPrediction, cartesian, reverse_cartesian
from .utils import safe_sigmoid

# from .vis_utils import vis_gaussian_in_img


@MODELS.register_module()
class SparseGaussianRefinementModuleV2(BaseModule):
    """Refine 2D Gaussian anchors from per-instance features.

    A small MLP maps ``instance_feature + anchor_embed`` to ``output_dim``
    channels. ``forward`` reads those channels as::

        [0:2]  xy offset logits
        [2:4]  scale logits
        [4:6]  rotation logits
        [6:6 + include_opa]  opacity logit (empty slice when disabled)
        [semantic_start : semantic_start + semantic_dim]  semantic logits

    Args:
        embed_dims: Width of the input features and of the hidden layers.
        pc_range: Forwarded unchanged to ``cartesian`` / ``reverse_cartesian``.
        scale_range: Indexable ``(low, high)`` pair; the activated scale is
            mapped linearly onto ``low + (high - low) * scale``.
        unit_xy: Per-axis multiplier applied to the squashed xy offset.
            Must be convertible by ``torch.tensor`` (the ``None`` default is
            not usable).
        semantics: Whether semantic channels are predicted. When true,
            ``semantic_dim`` must be given.
        semantic_dim: Number of semantic channels (forced to 0 when
            ``semantics`` is false).
        include_opa: Whether an opacity channel is read from the output.
        include_ele: Only adds one channel to ``output_dim``; that channel is
            not read anywhere in this class.
        semantics_activation: ``"softmax"``, ``"softplus"``; any other value
            leaves the semantic logits untouched.
        xy_activation: ``"sigmoid"`` turns on ``use_sigmoid`` in the
            ``cartesian`` / ``reverse_cartesian`` helpers.
        scale_activation: Only ``"sigmoid"`` is implemented; ``forward``
            raises ``ValueError`` for anything else.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        embed_dims=256,
        pc_range=None,
        scale_range=None,
        unit_xy=None,
        semantics=False,
        semantic_dim=None,
        include_opa=True,
        include_ele=True,
        semantics_activation="softmax",
        xy_activation="sigmoid",
        scale_activation="sigmoid",
        **kwargs,
    ):
        super().__init__()
        self.embed_dims = embed_dims

        if semantics:
            assert semantic_dim is not None
        else:
            semantic_dim = 0

        self.output_dim = 8 + int(include_opa) + int(include_ele) + semantic_dim
        # NOTE(review): forward() consumes channels 0..5 (+ opacity at 6), yet
        # semantics start at 8 (+ opacity), so the channels in between are
        # never read. Kept as-is because changing the layout would alter how
        # existing weights are interpreted -- confirm this is intended.
        self.semantic_start = 8 + int(include_opa)
        self.semantic_dim = semantic_dim
        self.include_opa = include_opa
        self.include_ele = include_ele
        self.semantics_activation = semantics_activation
        self.xy_act = xy_activation
        self.scale_act = scale_activation

        self.pc_range = pc_range
        self.scale_range = scale_range
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer(
            "unit_xy", torch.tensor(unit_xy, dtype=torch.float), persistent=False
        )

        # Both conversions share the same range and sigmoid setting so that
        # they stay consistent with each other.
        use_sigmoid = xy_activation == "sigmoid"
        self.get_xy = partial(cartesian, pc_range=pc_range, use_sigmoid=use_sigmoid)
        self.reverse_xy = partial(
            reverse_cartesian, pc_range=pc_range, use_sigmoid=use_sigmoid
        )

        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            nn.Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )

    def forward(
        self,
        instance_feature: torch.Tensor,
        implicit_features: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        **kwargs,
    ):
        """Predict refined anchors and the corresponding Gaussian parameters.

        Args:
            instance_feature: Per-instance features; summed with
                ``anchor_embed`` before the MLP and also stored as
                ``features`` on the returned prediction.
            implicit_features: Stored unchanged as ``im_features`` on the
                returned prediction.
            anchor: Current anchors; only ``anchor[..., :2]`` is read here.
            anchor_embed: Anchor embedding, added to ``instance_feature``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(output, gaussian)`` pair. ``output`` concatenates, on the
            last dim, the refined xy (in the anchor's encoding, i.e. after
            ``reverse_xy``) with the raw scale, rotation, opacity and
            semantic logits. ``gaussian`` is a ``GaussianPrediction`` built
            from the activated values.

        Raises:
            ValueError: If ``scale_activation`` is not ``"sigmoid"``.
        """
        output = self.layers(instance_feature + anchor_embed)

        #### for xy
        # Squash the logits (to (-1, 1) assuming safe_sigmoid is sigmoid-like)
        # and scale per axis by unit_xy.
        delta_xy = (2 * safe_sigmoid(output[..., :2]) - 1.0) * self.unit_xy[None, None]
        original_xy = self.get_xy(anchor[..., :2])
        # Apply the offset in get_xy space, then map back to the anchor encoding.
        anchor_xy = self.reverse_xy(original_xy + delta_xy)

        #### for scale
        anchor_scale = output[..., 2:4]

        #### for rotation
        # Kept raw in `output`; squashed only for the GaussianPrediction below.
        anchor_rotation = output[..., 4:6]

        #### for opacity
        anchor_opa = output[..., 6 : (6 + int(self.include_opa))]

        #### for semantic
        sem_end = self.semantic_start + self.semantic_dim
        anchor_sem = output[..., self.semantic_start : sem_end]

        output = torch.cat(
            [anchor_xy, anchor_scale, anchor_rotation, anchor_opa, anchor_sem], dim=-1
        )

        xy = self.get_xy(anchor_xy)

        if self.scale_act == "sigmoid":
            scale = safe_sigmoid(anchor_scale)
        else:
            # Previously this fell through and crashed with UnboundLocalError
            # on the next statement; fail with an actionable message instead.
            raise ValueError(
                f"Unsupported scale_activation {self.scale_act!r}; "
                "only 'sigmoid' is implemented."
            )
        scale_low, scale_high = self.scale_range[0], self.scale_range[1]
        scale = scale_low + (scale_high - scale_low) * scale

        if self.semantics_activation == "softmax":
            semantics = anchor_sem.softmax(dim=-1)
        elif self.semantics_activation == "softplus":
            semantics = F.softplus(anchor_sem)
        else:
            semantics = anchor_sem

        gaussian = GaussianPrediction(
            means=xy,
            scales=scale,
            rotations=2 * safe_sigmoid(anchor_rotation) - 1,
            opacities=safe_sigmoid(anchor_opa),
            semantics=semantics,
            original_means=original_xy,
            delta_means=delta_xy,
            features=instance_feature,
            im_features=implicit_features,
        )
        return output, gaussian
