from typing import Any, Dict, Optional

import torch
import torch.nn as nn

from model.vision.basic_modules import _get_clones, init_weights
from model.vision.transformers import TransformerEncoderLayer
from model.vision.lpss_self_diff import DatasetAwareEarlyLPSS
from pipeline.registry import registry


@registry.register_vision_model("unified_encoder_v2")
class UnifiedSpatialCrossEncoderV2(nn.Module):
    """
    Unified encoder for 3D visual-language understanding with self-diff Early LPSS.

    Text tokens and object tokens are concatenated along the sequence
    dimension and passed through a stack of ``TransformerEncoderLayer``
    modules. Before every layer the object tokens receive a learned projection
    of their locations plus a token-type embedding, and the text tokens
    receive their own token-type embedding.

    When ``use_self_diff_lpss`` is enabled, a ``DatasetAwareEarlyLPSS`` module
    produces an attention mask from the text stream; that mask is passed to
    the first encoder layer only.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        num_layers: int = 4,
        dim_loc: int = 6,
        use_self_diff_lpss: bool = False,
        self_diff_lpss_num_subspaces: int = 6,
        self_diff_lpss_txt_len: int = 50,
        self_diff_lpss_obj_len: int = 120,
        self_diff_lpss_temperature: float = 0.5,
        self_diff_lpss_bias_scale_init: float = 12.0,
        self_diff_lpss_dropout: float = 0.1,
        self_diff_lpss_use_gumbel: bool = True,
        self_diff_lpss_top_k: int = 2,
        self_diff_lpss_min_weight: float = 0.02,
        self_diff_lpss_grad_scale: float = 10.0,
        self_diff_lpss_orthogonal_weight: float = 0.1,
        self_diff_lpss_entropy_weight: float = 0.2,
        self_diff_lpss_diversity_weight: float = 0.05,
        self_diff_lpss_use_enhanced_routing: bool = True,
        self_diff_lpss_enabled_subspaces: Optional[list] = None,
        self_diff_lpss_use_cross_structure: bool = True,
        self_diff_lpss_use_txt_structure: bool = True,
        self_diff_lpss_use_obj_structure: bool = True,
    ):
        """Build the encoder stack and, optionally, the Early LPSS module.

        Args:
            hidden_size: Width of the token embeddings; also used as the
                ``routing_dim`` of the LPSS module.
            num_attention_heads: Number of attention heads for every encoder
                layer; also forwarded to the LPSS module as ``num_heads``.
            num_layers: Number of cloned encoder layers.
            dim_loc: Size of the last dimension of ``obj_locs``.
            use_self_diff_lpss: Whether to instantiate and use
                ``DatasetAwareEarlyLPSS``.
            self_diff_lpss_*: Forwarded unchanged to ``DatasetAwareEarlyLPSS``
                under the same name without the ``self_diff_lpss_`` prefix.
                See that class for their meaning.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.use_self_diff_lpss = use_self_diff_lpss

        unified_encoder_layer = TransformerEncoderLayer(hidden_size, num_attention_heads)
        self.unified_encoder = _get_clones(unified_encoder_layer, num_layers)

        # A single location projection is shared by all encoder layers; it is
        # kept in a cloned container so the parameter names stay stable.
        loc_layer = nn.Sequential(
            nn.Linear(dim_loc, hidden_size),
            nn.LayerNorm(hidden_size),
        )
        self.loc_layers = _get_clones(loc_layer, 1)

        # Index 0 -> language tokens, index 1 -> object tokens (see forward).
        self.token_type_embeddings = nn.Embedding(2, hidden_size)

        if use_self_diff_lpss:
            self.self_diff_lpss = DatasetAwareEarlyLPSS(
                num_subspaces=self_diff_lpss_num_subspaces,
                txt_len=self_diff_lpss_txt_len,
                obj_len=self_diff_lpss_obj_len,
                num_heads=num_attention_heads,
                routing_dim=hidden_size,
                hidden_dim=256,
                dropout=self_diff_lpss_dropout,
                temperature=self_diff_lpss_temperature,
                bias_scale_init=self_diff_lpss_bias_scale_init,
                use_gumbel=self_diff_lpss_use_gumbel,
                top_k=self_diff_lpss_top_k,
                min_weight=self_diff_lpss_min_weight,
                grad_scale=self_diff_lpss_grad_scale,
                orthogonal_weight=self_diff_lpss_orthogonal_weight,
                entropy_weight=self_diff_lpss_entropy_weight,
                diversity_weight=self_diff_lpss_diversity_weight,
                use_enhanced_routing=self_diff_lpss_use_enhanced_routing,
                enabled_subspaces=self_diff_lpss_enabled_subspaces,
                use_cross_structure=self_diff_lpss_use_cross_structure,
                use_txt_structure=self_diff_lpss_use_txt_structure,
                use_obj_structure=self_diff_lpss_use_obj_structure,
            )
        else:
            self.self_diff_lpss = None

        # Info returned by the most recent LPSS call. Both attributes always
        # hold the same object; two names are kept for the two accessor
        # families below.
        self._last_self_diff_lpss_info = None
        self._last_early_lpss_info = None

        self.apply(init_weights)

    def forward(
        self,
        txt_embeds,
        txt_masks,
        obj_embeds,
        obj_locs,
        obj_masks,
    ):
        """Jointly encode text and object tokens.

        Args:
            txt_embeds: Text token embeddings; dim 1 is the text sequence
                length (assumed ``(batch, txt_len, hidden_size)``).
            txt_masks: Text mask, concatenated with ``obj_masks`` on dim 1 and
                inverted before being passed as the key-padding mask
                (presumably ``True`` marks valid tokens - confirm with caller).
            obj_embeds: Object token embeddings; dim 1 is the object sequence
                length (assumed ``(batch, obj_len, hidden_size)``).
            obj_locs: Object locations with a last dimension of ``dim_loc``.
            obj_masks: Object mask, same convention as ``txt_masks``.

        Returns:
            Tuple ``(txt_embeds, obj_embeds)``: the joint sequence after the
            last layer, split back into its text and object parts.
        """
        txt_len = txt_embeds.shape[1]
        obj_len = obj_embeds.shape[1]

        early_lpss_attn_mask = None
        if self.use_self_diff_lpss and self.self_diff_lpss is not None:
            early_lpss_attn_mask, lpss_info = self.self_diff_lpss(
                txt_embeds, txt_masks, return_info=True
            )
            self._last_self_diff_lpss_info = lpss_info
            self._last_early_lpss_info = lpss_info

        # The additive terms below do not depend on the layer index (there is
        # one shared location projection and one embedding table), so they are
        # computed once instead of once per layer.
        query_pos = self.loc_layers[0](obj_locs)

        # Create the type ids on the inputs' device rather than forcing CUDA,
        # so the module also runs on CPU and on non-default accelerators.
        pc_token_type_ids = torch.ones(
            obj_embeds.shape[0:2], dtype=torch.long, device=obj_embeds.device
        )
        lang_token_type_ids = torch.zeros(
            txt_embeds.shape[0:2], dtype=torch.long, device=txt_embeds.device
        )
        pc_type_embeds = self.token_type_embeddings(pc_token_type_ids)
        lang_type_embeds = self.token_type_embeddings(lang_token_type_ids)

        joint_masks = torch.cat((txt_masks, obj_masks), dim=1)
        joint_padding_mask = joint_masks.logical_not()

        for i, unified_layer in enumerate(self.unified_encoder):
            # Location and type information is re-injected before every layer.
            obj_embeds = obj_embeds + query_pos + pc_type_embeds
            txt_embeds = txt_embeds + lang_type_embeds

            joint_embeds = torch.cat((txt_embeds, obj_embeds), dim=1)

            # The LPSS attention mask is applied to the first layer only.
            layer_attn_mask = early_lpss_attn_mask if i == 0 else None
            joint_embeds, _ = unified_layer(
                joint_embeds,
                tgt_mask=layer_attn_mask,
                tgt_key_padding_mask=joint_padding_mask,
            )

            txt_embeds, obj_embeds = torch.split(joint_embeds, [txt_len, obj_len], dim=1)

        return txt_embeds, obj_embeds

    def get_early_lpss_diagnostics(self) -> Optional[Dict[str, Any]]:
        """Alias of :meth:`get_self_diff_lpss_diagnostics`."""
        return self.get_self_diff_lpss_diagnostics()

    def get_last_early_lpss_info(self) -> Optional[Dict[str, Any]]:
        """Return the info from the latest LPSS call, or ``None`` if there is none."""
        return getattr(self, "_last_early_lpss_info", None)

    def get_self_diff_lpss_auxiliary_loss(self) -> Optional[torch.Tensor]:
        """Return the LPSS module's ``auxiliary_loss()``, or ``None`` if LPSS is disabled."""
        if self.self_diff_lpss is not None:
            return self.self_diff_lpss.auxiliary_loss()
        return None

    def get_self_diff_lpss_diagnostics(self) -> Optional[Dict[str, Any]]:
        """Return the LPSS monitor metrics as a dict, or ``None`` if LPSS is disabled."""
        if self.self_diff_lpss is not None:
            metrics = self.self_diff_lpss.get_monitor_metrics()
            return metrics.to_dict()
        return None

    def get_last_self_diff_lpss_info(self) -> Optional[Dict[str, Any]]:
        """Return the info from the latest LPSS call, or ``None`` if there is none."""
        return getattr(self, "_last_self_diff_lpss_info", None)
