from __future__ import annotations

"""
Runtime patches for third-party libraries.

Currently we fix a couple of `torch.Tensor.view(...)` usages inside the
Transformers CLIP implementation that break autograd when we request
gradients for intermediate hidden states. The upstream code assumes a
contiguous tensor, but hidden states captured for TCAV analysis are
often non-contiguous, which leads to the infamous
`RuntimeError: view size is not compatible ...`.

We monkey-patch the relevant methods to use `.reshape(...)` (which
handles non-contiguous tensors by materialising a contiguous copy when
needed) so that backpropagating to hidden states succeeds.
"""

from typing import Optional, Callable

# Process-wide flag: flipped to True by `patch_clip_for_tcav()` once the CLIP
# patch has been applied (or found to be applied already), so that later
# calls return immediately.
_CLIP_PATCHED: bool = False


def patch_clip_for_tcav() -> None:
    """
    Idempotently patch HuggingFace CLIP modules so that intermediate
    tensors use `.reshape(...)` instead of `.view(...)`.

    The patch replaces `CLIPAttention.forward` and
    `CLIPVisionEmbeddings.interpolate_pos_encoding` on
    `transformers.models.clip.modeling_clip`. The call is a no-op when:

    * a previous call already applied the patch (module-level
      `_CLIP_PATCHED`, or the `_tcav_clip_patched` marker stored on
      `modeling_clip`);
    * `modeling_clip` cannot be imported (best-effort: any import-time
      failure is treated as "Transformers not available");
    * `modeling_clip` lacks one of the symbols the replacement methods
      rely on; a `RuntimeWarning` is emitted and nothing is modified.

    All required symbols are resolved *before* anything is mutated, so the
    patch is applied either completely or not at all.
    """
    global _CLIP_PATCHED
    if _CLIP_PATCHED:
        return
    try:
        from transformers.models.clip import modeling_clip  # type: ignore
    except Exception:
        # Transformers not available in the current environment.
        return

    if getattr(modeling_clip, "_tcav_clip_patched", False):
        # Another copy of this patch already ran against the same module.
        _CLIP_PATCHED = True
        return

    # Everything the replacement methods need from the upstream module.
    # Checking up front avoids (a) an opaque AttributeError on Transformers
    # versions with a different CLIP implementation and (b) a half-applied
    # patch where only one of the two methods got replaced.
    required = (
        "torch",
        "nn",
        "ALL_ATTENTION_FUNCTIONS",
        "eager_attention_forward",
        "logger",
        "torch_int",
        "CLIPAttention",
        "CLIPVisionEmbeddings",
    )
    missing = [name for name in required if not hasattr(modeling_clip, name)]
    if missing:
        import warnings

        warnings.warn(
            "patch_clip_for_tcav: transformers.models.clip.modeling_clip is missing "
            f"{', '.join(missing)}; the installed transformers version is not supported "
            "and CLIP was left unpatched.",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    torch = modeling_clip.torch
    nn = modeling_clip.nn
    ALL_ATTENTION_FUNCTIONS = modeling_clip.ALL_ATTENTION_FUNCTIONS
    eager_attention_forward = modeling_clip.eager_attention_forward
    logger = modeling_clip.logger
    torch_int = modeling_clip.torch_int

    # ------------------------------------------------------------------
    # Patch CLIPAttention.forward: swap `.view` calls for `.reshape`
    # ------------------------------------------------------------------
    def patched_attention_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Replacement for `CLIPAttention.forward` that only reshapes via
        `.reshape(...)`, so non-contiguous hidden states are accepted.

        Returns a tuple `(attn_output, attn_weights)`; `attn_weights` is
        `None` unless `output_attentions` is truthy.
        """
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # (batch, seq, embed) -> (batch, seq, -1, head_dim) -> swap axes 1 and 2.
        # `.reshape` copies only when the input layout makes a view impossible.
        queries = queries.reshape(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        keys = keys.reshape(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        values = values.reshape(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)

        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        if self.config._attn_implementation == "flash_attention_2":
            # Masks are not merged on this path; only record whether a causal
            # mask was supplied.
            self.is_causal = causal_attention_mask is not None
        else:
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask

        # Default to the eager kernel; switch to the configured backend unless
        # the caller asked for attention weights under SDPA (warned below).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
                    "`output_attentions=True`. Falling back to eager attention. This warning can be "
                    'removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            # Dropout is disabled outside of training mode.
            dropout=0.0 if not self.training else self.dropout,
            output_attentions=output_attentions,
        )

        # Merge the per-head outputs back into (batch, seq, embed).
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights

    # ------------------------------------------------------------------
    # Patch CLIPVisionEmbeddings.interpolate_pos_encoding
    # ------------------------------------------------------------------
    def patched_interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Replacement for `CLIPVisionEmbeddings.interpolate_pos_encoding`
        that flattens the resized grid with `.reshape(...)`.

        Returns the stored position embeddings unchanged when no resizing
        is needed; otherwise the first position concatenated with the
        bicubically resized remaining positions.
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # Fast path: not tracing, patch count matches the stored table and the
        # requested size is square -> reuse the stored embeddings as they are.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        # The first position is kept untouched; only the rest is resized.
        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The stored patch positions are laid out as a square grid, moved to
        # channels-first for `interpolate`, then flattened back afterwards.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        # `permute` yields a non-contiguous tensor; `.reshape` handles that.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    # Both replacements are installed together, only after every prerequisite
    # above has been resolved successfully.
    modeling_clip.CLIPAttention.forward = patched_attention_forward  # type: ignore
    modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding = patched_interpolate_pos_encoding  # type: ignore

    modeling_clip._tcav_clip_patched = True
    _CLIP_PATCHED = True
