# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Type

import torch
from torch import nn
import torch.nn.functional as F  # ## SPATIAL MOD ##: Import F for functional API

from sam2.modeling.sam2_utils import LayerNorm2d, MLP

# ## SPATIAL MOD ##: A new helper module for the core fusion logic.
# This could be a more complex Transformer layer, but a single MHA is a great starting point.


class SpatialFusionAttention(nn.Module):
    """
    Performs cross-attention from an observer slice to a window of target slices,
    injecting relative positional embeddings (RPE) into the attention mechanism.

    The RPE bias is handed to ``nn.MultiheadAttention`` as a float ``attn_mask``,
    which PyTorch adds to the attention logits before the softmax.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        # Exposed so callers can size the per-head bias without reaching into the
        # wrapped nn.MultiheadAttention.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(
        self,
        observer_feats: torch.Tensor,
        target_feats: torch.Tensor,
        rpe_bias: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            observer_feats (torch.Tensor): The query. Shape (B, L, D); L is 1 for a
                single pooled observer vector.
            target_feats (torch.Tensor): The key and value. Shape (B, N, D).
            rpe_bias (torch.Tensor): The relative position bias, one value per head
                and target. Shape (B, NumHeads, 1, N) or (B, NumHeads, L, N); a
                singleton query dimension is broadcast to all L queries.

        Returns:
            torch.Tensor: The context-enriched observer feature. Shape (B, L, D).

        Raises:
            ValueError: If ``rpe_bias`` cannot be broadcast to (B, NumHeads, L, N).
        """
        B, L, _ = observer_feats.shape
        N = target_feats.shape[1]
        num_heads = self.attention.num_heads

        if (
            rpe_bias.dim() != 4
            or rpe_bias.shape[0] != B
            or rpe_bias.shape[1] != num_heads
            or rpe_bias.shape[2] not in (1, L)
            or rpe_bias.shape[3] != N
        ):
            raise ValueError(
                f"rpe_bias must have shape ({B}, {num_heads}, 1 or {L}, {N}), "
                f"got {tuple(rpe_bias.shape)}"
            )

        # A 3-D attn_mask must be (B * NumHeads, L, N), batch-major then head, which
        # is exactly the element order of reshaping a (B, NumHeads, L, N) tensor.
        attn_mask = rpe_bias.expand(B, num_heads, L, N).reshape(B * num_heads, L, N)
        # The additive mask must share the query's floating dtype (e.g. fp16/bf16).
        attn_mask = attn_mask.to(dtype=observer_feats.dtype)

        # Only the attention output is needed; the averaged weights are discarded.
        context_enriched_feature, _ = self.attention(
            query=observer_feats,
            key=target_feats,
            value=target_feats,
            attn_mask=attn_mask,
        )
        return context_enriched_feature


class FusionMaskDecoder(nn.Module):  # ## SPATIAL MOD ##: Derived from SAM2's MaskDecoder
    """
    SAM2 mask decoder extended with spatial (inter-slice) context fusion.

    Before the standard SAM2 mask prediction runs, the observer slice's image
    embedding is enriched with a context vector. That vector comes from
    cross-attending from the spatially pooled observer slice to a window of pooled
    target slices, with a learned per-head relative-position bias. An optional
    auxiliary head predicts each target slice's relative position.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        # ## SPATIAL MOD ##: New parameters for our spatial fusion logic
        spatial_fusion_num_heads: int = 8,
        max_rel_pos: int = 16,  # Max relative distance (e.g., for a window of 33 slices)
        # -- End of new parameters
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid=False,
        dynamic_multimask_via_stability=False,
        dynamic_multimask_stability_delta=0.05,
        dynamic_multimask_stability_thresh=0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
        use_position_prediction: bool = False,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings. This version accepts a
        spatial context window and enriches the features of the "observer" slice
        before prediction.

        Args:
            transformer_dim: Channel dimension of image embeddings and tokens.
            transformer: Module called as ``transformer(src, pos_src, tokens)`` that
                returns ``(hs, src)``.
            spatial_fusion_num_heads: Number of heads in the spatial fusion attention.
                It is also the width of each relative-position embedding (one bias per
                head). ``nn.MultiheadAttention`` requires ``transformer_dim`` to be
                divisible by it.
            max_rel_pos: Relative slice offsets are clamped to
                ``[-max_rel_pos, max_rel_pos]``, giving ``2 * max_rel_pos + 1``
                learned relative-position entries.
            num_multimask_outputs: Number of multimask outputs; one extra
                single-mask token is always added.
            activation: Activation class used in the output upscaling path.
            iou_head_depth: Number of layers in the IoU prediction MLP.
            iou_head_hidden_dim: Hidden width of the IoU prediction MLP.
            use_high_res_features: Add FPN features while upscaling mask embeddings.
            iou_prediction_use_sigmoid: Apply a sigmoid to the IoU head output.
            dynamic_multimask_via_stability: In eval mode with single-mask output,
                fall back to the best multimask output when the single mask is unstable.
            dynamic_multimask_stability_delta: Logit threshold margin for stability.
            dynamic_multimask_stability_thresh: Minimum stability score that keeps
                the single-mask output.
            pred_obj_scores: Add an object-score token and head.
            pred_obj_scores_mlp: Use a 3-layer MLP, rather than a linear layer, for
                the object-score head.
            use_multimask_token_for_obj_ptr: In multimask mode, return the multimask
                tokens instead of the single-mask token.
            use_position_prediction: When True, build the auxiliary
                relative-position prediction head and return its cross-entropy loss
                from ``forward``. When False (default), no extra parameters are
                created and ``forward`` returns a zero loss.

        Raises:
            ValueError: If ``max_rel_pos`` is negative.
        """
        super().__init__()
        if max_rel_pos < 0:
            raise ValueError(f"max_rel_pos must be non-negative, got {max_rel_pos}")

        self.transformer_dim = transformer_dim
        self.transformer = transformer

        # ## SPATIAL MOD ##: Modules for spatial reasoning
        self.max_rel_pos = max_rel_pos
        self.num_rel_pos = 2 * self.max_rel_pos + 1  # offsets -k, ..., 0, ..., +k

        # 1. RPE embedding: one learned scalar bias per attention head for every
        #    clamped relative offset.
        self.relative_pos_embedding = nn.Embedding(self.num_rel_pos, spatial_fusion_num_heads)

        # 2. Spatial fusion: observer -> target cross-attention.
        self.spatial_fusion = SpatialFusionAttention(
            embed_dim=transformer_dim,
            num_heads=spatial_fusion_num_heads,
        )

        # 3. Optional auxiliary head: classifies each pooled target slice into one of
        #    the clamped relative offsets. Only built when requested, so the default
        #    parameter set (and checkpoint layout) is unchanged.
        self.use_position_prediction = use_position_prediction
        if use_position_prediction:
            self.position_prediction_head = MLP(transformer_dim, 256, self.num_rel_pos, 3)
            self.position_loss_fn = nn.CrossEntropyLoss()

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )
            self.conv_s1 = nn.Conv2d(
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )

        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

    # ## SPATIAL MOD ##: The `forward` signature accepts the `spatial_out` dict.
    def forward(
        self,
        image_embeddings: torch.Tensor,
        spatial_out: dict,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        # high_res_features is handled inside, using spatial_out['observer_fpn']
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Fuse spatial context into the observer slice, then predict masks.

        Steps:
            1. Fuse spatial context into the observer features (attention with RPE).
            2. Compute the auxiliary position loss (zero when disabled).
            3. Run the original SAM2 mask prediction on the enriched features.

        Args:
            image_embeddings: Unused. It is kept for signature compatibility with
                the SAM2 decoder; the decoded embedding is built from
                ``spatial_out['observer_features']`` instead.
            spatial_out: Dict with 'observer_features', 'target_features',
                'relative_positions' and, when ``use_high_res_features`` is set,
                'observer_fpn'.
            image_pe: Dense positional encoding; must have batch size 1.
            sparse_prompt_embeddings: Prompt tokens, (batch, num_prompt_tokens,
                transformer_dim).
            dense_prompt_embeddings: Added to the (enriched) image embedding.
            multimask_output: Return the multimask outputs instead of the single mask.
            repeat_image: Repeat the image embedding per prompt-batch entry.

        Returns:
            (masks, iou_pred, sam_tokens_out, object_score_logits, pos_loss).
        """
        # Step 1: spatial fusion (+ optional position-prediction loss).
        # The enriched observer map replaces `image_embeddings` for decoding.
        image_embeddings, pos_loss = self._spatial_fusion_forward(spatial_out)

        # Step 2: pick the observer's high-resolution FPN features.
        if self.use_high_res_features:
            # Each FPN level is presumably (B*C, N, FPN_D, FPN_H, FPN_W) with the
            # observer replicated along N, so slot 0 is taken (TODO confirm upstream).
            high_res_features = [fpn_level[:, 0, ...] for fpn_level in spatial_out['observer_fpn']]
        else:
            high_res_features = None

        # Step 3: original SAM2 mask prediction.
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Step 4: select outputs (original SAM2 logic). Token 0 is the single-mask
        # output, tokens 1: are the multimask outputs.
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]
        else:
            sam_tokens_out = mask_tokens_out[:, 0:1]

        return masks, iou_pred, sam_tokens_out, object_score_logits, pos_loss

    def _spatial_fusion_forward(self, spatial_out: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Enrich the observer feature map with context from the target slices.

        Expected inputs (shapes as produced upstream — verify against the caller):
            observer_features: (B*C, N, D, Hf, Wf), observer replicated along N.
            target_features:   (B*C, N, D, Hf, Wf).
            relative_positions: (B*C, N) integer slice offsets.

        Returns:
            final_enriched_feature: (B*C, D, Hf, Wf), the observer map plus a
                broadcast context vector (residual-style).
            loss: Scalar auxiliary position cross-entropy, or a zero scalar when
                ``use_position_prediction`` is disabled.
        """
        observer_features = spatial_out['observer_features']
        target_features = spatial_out['target_features']
        relative_positions = spatial_out['relative_positions']

        # The observer is identical across N, so slot 0 is representative.
        observer_map = observer_features[:, 0, ...]  # (B*C, D, Hf, Wf)

        # Global average pooling over (Hf, Wf): one D-vector per slice. Same as
        # AdaptiveAvgPool2d(1), without building a module or reshaping the 5-D input
        # (which `.view` cannot do for non-contiguous tensors).
        query_pooled = observer_map.mean(dim=(-2, -1)).unsqueeze(1)  # (B*C, 1, D)
        target_pooled = target_features.mean(dim=(-2, -1))  # (B*C, N, D)

        # --- RPE: clamp offsets to [-k, k] and shift to embedding indices [0, 2k].
        # int64 is required by CrossEntropyLoss targets.
        rel_pos_indices = (
            torch.clamp(relative_positions.long(), -self.max_rel_pos, self.max_rel_pos)
            + self.max_rel_pos
        )  # (B*C, N)
        rpe_embeddings = self.relative_pos_embedding(rel_pos_indices)  # (B*C, N, H)
        rpe_bias = rpe_embeddings.permute(0, 2, 1).unsqueeze(2)  # (B*C, H, 1, N)

        # --- Spatial fusion attention through the module's own forward, which
        # reshapes the bias into the (B*C*H, 1, N) mask nn.MultiheadAttention needs.
        context_feature_pooled = self.spatial_fusion(
            query_pooled, target_pooled, rpe_bias
        )  # (B*C, 1, D)

        # Broadcast the context vector over every spatial location of the observer
        # map (residual-style enrichment before the SAM transformer).
        context_enrichment = context_feature_pooled.squeeze(1)[..., None, None]  # (B*C, D, 1, 1)
        final_enriched_feature = observer_map + context_enrichment

        # --- Optional auxiliary position prediction from the pooled target features.
        if self.use_position_prediction:
            predicted_pos_logits = self.position_prediction_head(target_pooled)  # (B*C, N, num_rel_pos)
            loss = self.position_loss_fn(
                predicted_pos_logits.reshape(-1, self.num_rel_pos),
                rel_pos_indices.reshape(-1),
            )
        else:
            loss = final_enriched_feature.new_zeros(())

        return final_enriched_feature, loss

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks (original SAM2 logic). See ``forward`` for argument details.

        Returns:
            (masks, iou_pred, mask_tokens_out, object_score_logits). Masks and IoU
            predictions cover all ``num_mask_tokens`` outputs.
        """
        # Concatenate output tokens: [obj_score?] + iou + mask tokens.
        s = 0
        if self.pred_obj_scores:
            output_tokens = torch.cat(
                [
                    self.obj_score_token.weight,
                    self.iou_token.weight,
                    self.mask_tokens.weight,
                ],
                dim=0,
            )
            s = 1
        else:
            output_tokens = torch.cat(
                [self.iou_token.weight, self.mask_tokens.weight], dim=0
            )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert (
            image_pe.size(0) == 1
        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1: (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            # Inject FPN features after each transposed-conv stage.
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Without an object-score head, report a constant high logit.
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Score mask stability as the ratio of pixels with logit > +delta to pixels
        with logit > -delta (an IoU between the two thresholded masks for
        delta >= 0). Returns 1.0 where no pixel exceeds -delta.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        Keep the single-mask output (token 0) where it is stable. Elsewhere, use the
        multimask output with the highest predicted IoU.

        Returns:
            (mask_logits_out, iou_scores_out), each with a single mask per batch entry.
        """
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(
            multimask_iou_scores.size(0), device=all_iou_scores.device
        )
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds].unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds].unsqueeze(1)
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
